Files
eason 79db6e5c96 feat: 添加RUL-Mamba模型及相关组件
新增锂电池剩余使用寿命预测模型RUL-Mamba,包含以下主要组件:
1. 添加Mamba模块作为核心时序建模组件
2. 实现特征注意力网络(FAN)和门控残差网络(GRN)
3. 新增数据预处理和归一化层
4. 添加模型训练和评估脚本
5. 补充README文档说明使用方法
6. 添加可视化辅助工具Helper_Plot.py
7. 实现多种时间序列处理层(Embedding、AutoCorrelation等)
8. 添加配置文件requirements.txt
9. 补充测试数据集TJU battery data
2026-01-09 08:53:50 +08:00

159 lines
7.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
from torch import nn
from ModelsModify.layers.Transformer_EncDec import Encoder, EncoderLayer
from ModelsModify.layers.SelfAttention_Family import FullAttention, AttentionLayer
from ModelsModify.layers.Embed import PatchEmbedding
class Transpose(nn.Module):
    """Swap tensor dimensions as an ``nn.Module`` so it can sit in ``nn.Sequential``.

    Args:
        *dims: dimension indices forwarded verbatim to ``Tensor.transpose``.
        contiguous: when truthy, return a contiguous copy of the transposed
            tensor instead of a strided view.
    """

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        swapped = x.transpose(*self.dims)
        return swapped.contiguous() if self.contiguous else swapped
class FlattenHead(nn.Module):
    """Flatten the last two axes of the input, then apply a linear projection and dropout.

    Args:
        n_vars: number of variables; stored on the module but not used in forward.
        nf: size of the flattened last two axes (input features of the linear layer).
        target_window: output size of the linear projection.
        head_dropout: dropout probability applied after the projection.
    """

    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        # x: [bs x nvars x d_model x patch_num] -> [bs x nvars x target_window]
        flat = self.flatten(x)
        projected = self.linear(flat)
        return self.dropout(projected)
class PatchTST(nn.Module):
    """PatchTST encoder with a flatten head and a final cross-channel projection.

    Paper link: https://arxiv.org/pdf/2211.14730.pdf

    Pipeline: normalise each channel over the time axis, split it into
    patches, encode the patches with a Transformer encoder, project every
    channel to ``pred_len`` steps, de-normalise, and finally mix all channels
    with ``projection_final`` into ``pred_len * c_out`` outputs per sample.

    Only these constructor arguments are used: patch_len, stride, task_name,
    seq_len, pred_len, enc_in, c_out, e_layers, n_heads, factor, d_model,
    d_ff, dropout, activation and output_attention. All other arguments
    (and ``**kwargs``) are accepted and ignored, so a shared experiment
    configuration can be passed through unchanged.
    """

    def __init__(self, patch_len=16, stride=8, task_name='short_term_forecast', seq_len=96, label_len=48, pred_len=96, enc_in=7, dec_in=7, c_out=1, e_layers=2, d_layers=1, n_heads=8, factor=3,
                 d_model=16, d_ff=32, des='Exp', expand=2, d_conv=4, top_k=5, embed='timeF', freq='h', dropout=0.1, num_kernels=6,
                 moving_avg=25, channel_independence=1, decomp_method='moving_avg', use_norm=1,
                 version='fourier', mode_select='random', modes=32, activation='gelu', seasonal_patterns='Monthly',
                 inverse=False, mask_rate=0.25, anomaly_ratio=0.25, output_attention=False, down_sampling_layers=0, down_sampling_window=1, down_sampling_method=None,
                 seg_len=48, num_workers=0, itr=1, train_epochs=100, batch_size=32, patience=3, learning_rate=0.0001, loss='MSE',
                 lradj='type1', use_amp=False, use_gpu=True, gpu=0, use_multi_gpu=False, devices='0,1,2,3', p_hidden_dims=(128, 128),
                 p_hidden_layers=2, use_dtw=False, augmentation_ratio=0, seed=2, jitter=False, scaling=False, permutation=False,
                 randompermutation=False, magwarp=False, timewarp=False, windowslice=False, windowwarp=False, rotation=False,
                 spawner=False, dtwwarp=False, shapedtwwarp=False, wdba=False, discdtw=False, discsdtw=False, extra_tag='', **kwargs):
        """
        patch_len: int, patch length for the patch embedding.
        stride: int, stride between patches; also used as the embedding padding.

        Raises:
            ValueError: if seq_len, patch_len and stride yield fewer than one patch.
        """
        super().__init__()
        self.task_name = task_name
        self.seq_len = seq_len
        self.pred_len = pred_len
        padding = stride
        # patching and embedding
        self.patch_embedding = PatchEmbedding(
            d_model, patch_len, stride, padding, dropout)
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, factor, attention_dropout=dropout,
                                      output_attention=output_attention), d_model, n_heads),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation
                ) for _ in range(e_layers)
            ],
            norm_layer=nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2))
        )
        # Prediction Head
        # Patch count = floor((seq_len - patch_len) / stride) + 2. This assumes
        # PatchEmbedding pads the series by `padding` (= stride) steps before
        # unfolding -- verify against layers.Embed.PatchEmbedding. Integer floor
        # division matches the previous int(float) form for every valid config
        # and avoids float rounding.
        patch_num = (seq_len - patch_len) // stride + 2
        if patch_num < 1:
            raise ValueError(
                f"PatchTST needs at least one patch, but seq_len={seq_len}, "
                f"patch_len={patch_len}, stride={stride} give {patch_num}")
        self.head_nf = d_model * patch_num
        self.head = FlattenHead(enc_in, self.head_nf, pred_len,
                                head_dropout=dropout)
        # Mixes the forecasts of all enc_in channels into pred_len * c_out outputs.
        self.projection_final = nn.Linear(pred_len * enc_in, pred_len * c_out, bias=True)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast from an input window.

        Args:
            x_enc: input series, assumed [B, seq_len, C]. Dim 1 is treated as
                time, and C must equal ``enc_in`` for ``projection_final``.
            x_mark_enc, x_dec, x_mark_dec: unused; kept for interface
                compatibility with other forecasting models.

        Returns:
            Tensor of shape [B, pred_len * c_out].
        """
        # Normalization from Non-stationary Transformer (stats over the time axis)
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place on purpose: torch.var may have saved x_enc for backward,
        # and an in-place `/=` would then break autograd when the input requires grad.
        x_enc = x_enc / stdev
        # do patching and embedding: [B, L, C] -> [B, C, L]
        x_enc = x_enc.permute(0, 2, 1)
        # u: [bs * nvars x patch_num x d_model]
        enc_out, n_vars = self.patch_embedding(x_enc)
        # Encoder
        # z: [bs * nvars x patch_num x d_model]; attention maps are not used
        enc_out, _ = self.encoder(enc_out)
        # z: [bs x nvars x patch_num x d_model]
        enc_out = torch.reshape(
            enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
        # z: [bs x nvars x d_model x patch_num]
        enc_out = enc_out.permute(0, 1, 3, 2)
        # Decoder
        dec_out = self.head(enc_out)  # z: [bs x nvars x pred_len]
        dec_out = dec_out.permute(0, 2, 1)  # [bs x pred_len x nvars]
        # De-Normalization: means/stdev are [B, 1, C] and broadcast over the pred_len axis
        dec_out = dec_out * stdev + means
        # Flatten per sample and mix channels: [B, pred_len * C] -> [B, pred_len * c_out]
        dec_out = self.projection_final(dec_out.reshape(dec_out.shape[0], -1))
        return dec_out
from pytorch_forecasting.models import BaseModel
from typing import Dict
class PatchTSTNetModel(BaseModel):
    """pytorch-forecasting ``BaseModel`` wrapper around ``PatchTST`` (battery-prediction variant).

    Constructor arguments are forwarded to ``PatchTST``. Extra ``**kwargs``
    go to ``BaseModel.__init__``.
    """

    def __init__(self, patch_len=6, stride=3, seq_len=24, pred_len=1, enc_in=7, c_out=1, e_layers=2, d_layers=1, factor=3,
                 d_model=16, d_ff=32, des='Exp', itr=1, top_k=5, embed='timeF', freq='h', dropout=0.1, num_kernels=6, **kwargs):
        # saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this
        self.save_hyperparameters()
        # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this
        super().__init__(**kwargs)
        network_args = dict(
            patch_len=patch_len, stride=stride, seq_len=seq_len, pred_len=pred_len,
            enc_in=enc_in, c_out=c_out, e_layers=e_layers, d_layers=d_layers, factor=factor,
            d_model=d_model, d_ff=d_ff, des=des, itr=itr, top_k=top_k, embed=embed,
            freq=freq, dropout=dropout, num_kernels=num_kernels,
        )
        self.network = PatchTST(**network_args)

    # Modified for lithium-battery prediction
    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the network on a pytorch-forecasting batch and rescale to target space.

        The last continuous encoder column is dropped before the network sees
        the data (presumably the target itself -- TODO confirm against the
        dataset's feature order).
        """
        features = x["encoder_cont"][:, :, :-1]
        # model output
        raw_prediction = self.network(features, x_mark_enc=None, x_dec=None, x_mark_dec=None)
        # rescale predictions into target space
        scaled_prediction = self.transform_output(raw_prediction, target_scale=x["target_scale"])
        # return a dict containing the output (prediction)
        return self.to_network_output(prediction=scaled_prediction)
if __name__ == '__main__':
    # Smoke test: run PatchTST on a dummy batch and print the output shape.
    N, L, C = 100, 96, 15
    label_len = 16
    c_out = 1
    pred_len = 16
    x_enc = torch.ones((N, L, C))
    # NOTE(original author): pred_len is restricted.
    model = PatchTST(seq_len=L, enc_in=C, dec_in=C, label_len=label_len,
                     pred_len=pred_len, c_out=c_out)
    # Inference only: disable dropout / use BatchNorm running stats and skip autograd.
    model.eval()
    with torch.no_grad():
        out = model(x_enc=x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None)
    print(out.shape)  # expected: [N, pred_len * c_out]