Files
RUL-Mamba/Helper_Plot.py
T
eason 79db6e5c96 feat: 添加RUL-Mamba模型及相关组件
新增锂电池剩余使用寿命预测模型RUL-Mamba,包含以下主要组件:
1. 添加Mamba模块作为核心时序建模组件
2. 实现特征注意力网络(FAN)和门控残差网络(GRN)
3. 新增数据预处理和归一化层
4. 添加模型训练和评估脚本
5. 补充README文档说明使用方法
6. 添加可视化辅助工具Helper_Plot.py
7. 实现多种时间序列处理层(Embedding、AutoCorrelation等)
8. 添加配置文件requirements.txt
9. 补充测试数据集TJU battery data
2026-01-09 08:53:50 +08:00

96 lines
6.3 KiB
Python

import os
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('agg')
import matplotlib.colors as mcolors
from matplotlib.font_manager import FontProperties
# import seaborn as sns
import pandas as pd
plt.rcParams['font.family'] = 'Times New Roman'
def _style_tick_labels(axes, font_prop):
    """Apply *font_prop* to every major tick label currently on *axes*."""
    for tick_label in axes.get_xticklabels() + axes.get_yticklabels():
        tick_label.set_fontproperties(font_prop)


def _plot_capacity_curves(axes, x, real_data, predictions, start_points, threshold, with_markers):
    """Draw the real curve, the prediction curves and the threshold line on *axes*.

    When *with_markers* is true every curve also gets a small per-curve marker
    (used for the zoomed inset, where individual cycles are distinguishable).
    """
    def marker_kwargs(marker):
        return {'marker': marker, 'markersize': 2} if with_markers else {}

    axes.plot(x, real_data, color='b', label='Real data', linewidth=1, zorder=2, **marker_kwargs('>'))

    colors = (mcolors.CSS4_COLORS['red'], mcolors.CSS4_COLORS['yellow'], mcolors.CSS4_COLORS['lime'])
    markers = ('s', '.', 'v')
    zorders = (2.5, 3, 3.5)
    for pred, start, color, marker, zorder in zip(predictions, start_points, colors, markers, zorders):
        # A prediction that starts at cycle `start` is drawn from x[start - 1]
        # onwards, so it must have exactly len(x) - start + 1 points.
        axes.plot(x[start - 1:], pred, color=color, label="SP = {}".format(start),
                  linewidth=1, zorder=zorder, **marker_kwargs(marker))

    # Dashed horizontal threshold line, drawn on top of every curve.
    axes.plot(x, threshold, color=mcolors.CSS4_COLORS['black'], linestyle='--', linewidth=1, zorder=4)


def single_model_draw_test_CY25_1_plt(real_data, pred_data_1, pred_data_2, pred_data_3, save_filename,
                                      save_figure_dir, Rated_Capacity, model, start_points=(200, 300, 400)):
    """Plot real vs. predicted capacity curves with a zoomed inset and save the figure.

    The figure is written to ``save_figure_dir`` as ``<save_filename>.png``
    (300 dpi) and ``<save_filename>.svg``.

    Args:
        real_data: Sequence of measured capacities, one value per cycle
            (cycle numbers on the x axis start at 1).
        pred_data_1, pred_data_2, pred_data_3: Predicted capacity sequences.
            The k-th sequence is drawn from cycle ``start_points[k]`` to the
            last cycle, so its length must be
            ``len(real_data) - start_points[k] + 1``.
        save_filename: Output file name without extension.
        save_figure_dir: Existing directory the figure files are written to.
        Rated_Capacity: Rated capacity; a dashed line is drawn at 70% of it.
        model: Text drawn near the top centre of the plot.
        start_points: The three cycles at which the predictions start.
            Defaults to ``(200, 300, 400)``, the values that were previously
            hard-coded.

    Raises:
        ValueError: If ``start_points`` does not contain exactly three entries.
    """
    start_points = tuple(start_points)
    if len(start_points) != 3:
        raise ValueError('start_points must contain exactly three entries, got {}'.format(len(start_points)))
    predictions = (pred_data_1, pred_data_2, pred_data_3)

    # OTF font files; replace with the actual paths of the fonts on your machine.
    regular_font_path = 'fonts/otf/Times-Newer-Roman-Regular.otf'
    bold_font_path = 'fonts/otf/Times-Newer-Roman-Bold.otf'
    tick_font = FontProperties(fname=regular_font_path, size=12)
    axis_label_font = FontProperties(fname=regular_font_path, size=15)
    model_font = FontProperties(fname=bold_font_path, size=16)

    cycle_count = len(real_data)
    x = list(range(1, cycle_count + 1))
    threshold = [Rated_Capacity * 0.7] * cycle_count

    fig, ax = plt.subplots()
    try:
        # Main axes: fixed view and tick positions.
        ax.set_xlim([0, 1000])
        ax.set_ylim([0.8, 2.6])
        ax.set_xticks([0, 200, 400, 600, 800, 1000])
        ax.set_yticks([0.8, 1.1, 1.4, 1.7, 2.0, 2.3, 2.6])
        # Style the labels only after the final tick positions are set, and on
        # the axes object directly instead of through pyplot's "current axes".
        _style_tick_labels(ax, tick_font)
        _plot_capacity_curves(ax, x, real_data, predictions, start_points, threshold, with_markers=False)

        # Inset axes showing the same curves, zoomed onto a specific region.
        axins = ax.inset_axes([0.1, 0.1, 0.47, 0.47])
        _plot_capacity_curves(axins, x, real_data, predictions, start_points, threshold, with_markers=True)
        axins.set_xlim(765, 795)
        axins.set_ylim(1.73, 1.77)
        axins.set_xticks([765, 775, 785, 795])
        axins.set_yticks([1.73, 1.75, 1.77])
        # plt.xticks()/plt.yticks() would style the main axes here, because
        # creating an inset does not change pyplot's current axes.
        _style_tick_labels(axins, tick_font)

        # Connector lines showing which region of the main plot the inset magnifies.
        ax.indicate_inset_zoom(axins, edgecolor="black")

        ax.text(500, 2.5, model, horizontalalignment='center', verticalalignment='center',
                color='k', fontsize=16, fontproperties=model_font)
        ax.set_xlabel('Cycle', fontsize=15, fontproperties=axis_label_font)
        ax.set_ylabel('Capacity (Ah)', fontsize=15, fontproperties=axis_label_font)

        # Legend for the four curves only (the threshold line is left out); the frame is kept.
        legend_labels = ['Real data'] + ["SP = {}".format(start) for start in start_points]
        ax.legend(ax.get_lines()[:4], legend_labels, loc='upper right', frameon=True, prop=tick_font)

        fig.savefig(os.path.join(save_figure_dir, save_filename + '.png'), dpi=300, bbox_inches='tight')
        # SVG is a vector format, so no dpi is given.
        fig.savefig(os.path.join(save_figure_dir, save_filename + '.svg'), bbox_inches='tight')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
if __name__ == '__main__':
    # Script entry point: load stored capacity / prediction results and draw
    # one figure per independent experiment.
    import argparse
    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='iTransformer', help='Model name.')
    parser.add_argument('--seq_len', type=int, default=64, help='input sequence length')
    parser.add_argument('--label_len', type=int, default=0, help='start token length')
    parser.add_argument('--pred_len', type=int, default=1, help='prediction sequence length')
    # nargs='+' collects one or more names into a list; `type=list` would
    # split a single command-line value into individual characters.
    parser.add_argument('--Battery_list', type=str, nargs='+', default=['CY25_1', 'CY25_2', 'CY25_3'], help='Battery data')
    parser.add_argument('--data_dir', type=str, default='data/TJU data/Dataset_3_NCM_NCA_battery/', help='path of the data file')
    parser.add_argument('--Rated_Capacity', type=float, default=2.5, help='Rate Capacity')
    parser.add_argument('--test_name', type=str, default='CY25_3', help='Battery data used for test')
    # nargs='+' keeps the parsed value a list of ints, matching the default.
    parser.add_argument('--start_point_list', type=int, nargs='+', default=[200, 400, 600], help='The cycle when prediction gets started.')
    parser.add_argument('--seed', type=int, default=1, help='Random seed.')
    parser.add_argument('--root_dir', type=str, default='TJU_RUL_prediction_sl_64', help='root path of the store file')
    parser.add_argument('--count', type=int, default=1, help='The number of independent experiment.')
    parser.add_argument('--batch_size', type=int, default=128, help='The batch size.')
    parser.add_argument('--max_epochs', type=int, default=200, help='max train epochs')
    parser.add_argument('--patch_len', type=int, default=2, help='patch length')
    parser.add_argument('--d_model', type=int, default=16, help='hidden dimensions of model')
    args = parser.parse_args()

    # NOTE(review): these paths are fixed and do not follow --test_name / --model.
    # torch.load unpickles the file contents, so only load files you trust.
    real_data = torch.load('results/Capacity_CY25_3_real_data.pth')
    predict_results = torch.load('results/RUL_CY25_3_iTransformer.pth')

    root_dir = 'results_{}/{}/{}/'.format(args.root_dir, args.test_name, args.model)
    save_figure_dir = os.path.join(root_dir, 'figures')
    # Creates root_dir as well; exist_ok avoids the check-then-create race.
    os.makedirs(save_figure_dir, exist_ok=True)

    if args.test_name == 'CY25_3':
        # NOTE(review): the plotting function labels and offsets its three
        # prediction curves at cycles 200/300/400, while the results passed
        # here are stored under SP200/SP400/SP600 - confirm which is intended.
        for i in range(args.count):
            single_model_draw_test_CY25_1_plt(
                real_data,
                predict_results['SP200'][i],
                predict_results['SP400'][i],
                predict_results['SP600'][i],
                save_filename='best_model_{}'.format(i) + '_RUL_Prediction',
                save_figure_dir=save_figure_dir,
                Rated_Capacity=args.Rated_Capacity,
                model=args.model,
            )