feat: 添加RUL-Mamba模型及相关组件

新增锂电池剩余使用寿命预测模型RUL-Mamba,包含以下主要组件:
1. 添加Mamba模块作为核心时序建模组件
2. 实现特征注意力网络(FAN)和门控残差网络(GRN)
3. 新增数据预处理和归一化层
4. 添加模型训练和评估脚本
5. 补充README文档说明使用方法
6. 添加可视化辅助工具Helper_Plot.py
7. 实现多种时间序列处理层(Embedding、AutoCorrelation等)
8. 添加配置文件requirements.txt
9. 补充测试数据集TJU battery data
This commit is contained in:
2026-01-09 08:53:50 +08:00
parent c56bd0cc42
commit 79db6e5c96
53 changed files with 21021 additions and 1 deletions
+159
View File
@@ -0,0 +1,159 @@
import torch
from torch import nn
from ModelsModify.layers.Transformer_EncDec import Encoder, EncoderLayer
from ModelsModify.layers.SelfAttention_Family import FullAttention, AttentionLayer
from ModelsModify.layers.Embed import PatchEmbedding
class Transpose(nn.Module):
def __init__(self, *dims, contiguous=False):
super().__init__()
self.dims, self.contiguous = dims, contiguous
def forward(self, x):
if self.contiguous: return x.transpose(*self.dims).contiguous()
else: return x.transpose(*self.dims)
class FlattenHead(nn.Module):
def __init__(self, n_vars, nf, target_window, head_dropout=0):
super().__init__()
self.n_vars = n_vars
self.flatten = nn.Flatten(start_dim=-2)
self.linear = nn.Linear(nf, target_window)
self.dropout = nn.Dropout(head_dropout)
def forward(self, x): # x: [bs x nvars x d_model x patch_num]
x = self.flatten(x)
x = self.linear(x)
x = self.dropout(x)
return x
class PatchTST(nn.Module):
"""
Paper link: https://arxiv.org/pdf/2211.14730.pdf
"""
def __init__(self,patch_len=16, stride=8,task_name='short_term_forecast',seq_len=96, label_len=48, pred_len=96, enc_in=7, dec_in=7, c_out=1, e_layers=2, d_layers=1, n_heads=8,factor=3,
d_model=16, d_ff=32, des='Exp', expand=2, d_conv=4, top_k=5, embed='timeF',freq='h', dropout=0.1,num_kernels=6,
moving_avg=25,channel_independence=1, decomp_method='moving_avg', use_norm=1,
version='fourier', mode_select='random', modes=32, activation='gelu',seasonal_patterns='Monthly',
inverse=False, mask_rate=0.25, anomaly_ratio=0.25,output_attention=False,down_sampling_layers=0, down_sampling_window=1, down_sampling_method=None,
seg_len=48, num_workers=0, itr=1, train_epochs=100, batch_size=32, patience=3, learning_rate=0.0001, loss='MSE',
lradj='type1', use_amp=False, use_gpu=True, gpu=0, use_multi_gpu=False, devices='0,1,2,3', p_hidden_dims=[128, 128],
p_hidden_layers=2, use_dtw=False, augmentation_ratio=0, seed=2, jitter=False, scaling=False, permutation=False,
randompermutation=False, magwarp=False, timewarp=False, windowslice=False, windowwarp=False, rotation=False,
spawner=False, dtwwarp=False, shapedtwwarp=False, wdba=False, discdtw=False, discsdtw=False, extra_tag='',**kwargs):
"""
patch_len: int, patch len for patch_embedding
stride: int, stride for patch_embedding
"""
super().__init__()
self.task_name = task_name
self.seq_len = seq_len
self.pred_len = pred_len
padding = stride
# patching and embedding
self.patch_embedding = PatchEmbedding(
d_model, patch_len, stride, padding, dropout)
# Encoder
self.encoder = Encoder(
[
EncoderLayer(
AttentionLayer(
FullAttention(False, factor, attention_dropout=dropout,
output_attention=output_attention), d_model, n_heads),
d_model,
d_ff,
dropout=dropout,
activation=activation
) for l in range(e_layers)
],
norm_layer=nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
)
# Prediction Head
self.head_nf = d_model * \
int((seq_len - patch_len) / stride + 2)
self.head = FlattenHead(enc_in, self.head_nf, pred_len,
head_dropout=dropout)
self.projection_final = nn.Linear(pred_len*enc_in, pred_len*c_out, bias=True)
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
# Normalization from Non-stationary Transformer
means = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - means
stdev = torch.sqrt(
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_enc /= stdev
# do patching and embedding
x_enc = x_enc.permute(0, 2, 1)
# u: [bs * nvars x patch_num x d_model]
enc_out, n_vars = self.patch_embedding(x_enc)
# Encoder
# z: [bs * nvars x patch_num x d_model]
enc_out, attns = self.encoder(enc_out)
# z: [bs x nvars x patch_num x d_model]
enc_out = torch.reshape(
enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
# z: [bs x nvars x d_model x patch_num]
enc_out = enc_out.permute(0, 1, 3, 2)
# Decoder
dec_out = self.head(enc_out) # z: [bs x nvars x target_window]
dec_out = dec_out.permute(0, 2, 1)
# De-Normalization from Non-stationary Transformer
dec_out = dec_out * \
(stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
dec_out = dec_out + \
(means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
dec_out=dec_out[:, -self.pred_len:, :] # [B, L, D]
dec_out=self.projection_final(dec_out.reshape(dec_out.shape[0], -1))
return dec_out
from pytorch_forecasting.models import BaseModel
from typing import Dict
class PatchTSTNetModel(BaseModel):
def __init__(self,patch_len=6, stride=3,seq_len=24, pred_len=1, enc_in=7, c_out=1, e_layers=2, d_layers=1, factor=3,
d_model=16, d_ff=32, des='Exp', itr=1, top_k=5,embed='timeF',freq='h', dropout=0.1,num_kernels=6, **kwargs):
# saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this
self.save_hyperparameters()
# pass additional arguments to BaseModel.__init__, mandatory call - do not skip this
super().__init__(**kwargs)
self.network = PatchTST(
patch_len=patch_len,stride=stride,seq_len=seq_len, pred_len=pred_len, enc_in=enc_in, c_out=c_out, e_layers=e_layers, d_layers=d_layers, factor=factor,
d_model=d_model, d_ff=d_ff, des=des, itr=itr, top_k=top_k, embed=embed, freq=freq, dropout=dropout, num_kernels=num_kernels
)
# 修改,锂电池预测
def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
x_enc = x["encoder_cont"][:,:,:-1]
# 输出
prediction = self.network(x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None)
# 输出rescale rescale predictions into target space
prediction = self.transform_output(prediction, target_scale=x["target_scale"])
# 返回一个字典,包含输出结果(prediction)
return self.to_network_output(prediction=prediction)
if __name__=='__main__':
N,L,C=100,96,15
label_len = 16
c_out = 1
pred_len=16
x_enc=torch.ones((N,L,C))
# x_mark_enc=torch.ones((N, L, 4))
# x_dec = torch.ones((N, pred_len, C))
# x_mark_dec=torch.ones((N, pred_len, 4))
model=PatchTST(seq_len=L, enc_in=C, dec_in=C, label_len = label_len, pred_len=pred_len, c_out=1) # pred_len 被限制了
out=model(x_enc=x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None)
print(out.shape)