79db6e5c96
新增锂电池剩余使用寿命预测模型RUL-Mamba,包含以下主要组件: 1. 添加Mamba模块作为核心时序建模组件 2. 实现特征注意力网络(FAN)和门控残差网络(GRN) 3. 新增数据预处理和归一化层 4. 添加模型训练和评估脚本 5. 补充README文档说明使用方法 6. 添加可视化辅助工具Helper_Plot.py 7. 实现多种时间序列处理层(Embedding、AutoCorrelation等) 8. 添加配置文件requirements.txt 9. 补充测试数据集TJU battery data
159 lines
7.1 KiB
Python
159 lines
7.1 KiB
Python
import torch
|
||
from torch import nn
|
||
from ModelsModify.layers.Transformer_EncDec import Encoder, EncoderLayer
|
||
from ModelsModify.layers.SelfAttention_Family import FullAttention, AttentionLayer
|
||
from ModelsModify.layers.Embed import PatchEmbedding
|
||
class Transpose(nn.Module):
    """Swap two tensor dimensions, optionally returning a contiguous result.

    Wrapping ``Tensor.transpose`` in an ``nn.Module`` lets it be placed inside
    ``nn.Sequential``. In this file it brackets a ``BatchNorm1d`` so that the
    normalisation runs over the ``d_model`` axis of the encoder output.

    Args:
        *dims: the two dimension indices forwarded to ``Tensor.transpose``.
        contiguous: when truthy, ``.contiguous()`` is called on the result.
    """

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        swapped = x.transpose(*self.dims)
        return swapped.contiguous() if self.contiguous else swapped
|
||
|
||
|
||
class FlattenHead(nn.Module):
    """Linear prediction head over flattened per-variable encoder features.

    The last two input dimensions are merged (``nn.Flatten(start_dim=-2)``),
    projected to ``target_window`` values, and passed through dropout.

    Args:
        n_vars: number of variables; stored as ``self.n_vars`` but not used in
            ``forward``.
        nf: size of the merged last two dimensions, i.e. the input width of the
            linear layer.
        target_window: output width of the linear layer.
        head_dropout: dropout probability applied to the projected output.
    """

    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        """Map ``[..., a, b]`` (with ``a * b == nf``) to ``[..., target_window]``.

        PatchTST calls this with ``[bs, nvars, d_model, patch_num]``.
        """
        return self.dropout(self.linear(self.flatten(x)))
|
||
|
||
|
||
class PatchTST(nn.Module):
    """PatchTST forecaster with a final flattened linear projection.

    Paper link: https://arxiv.org/pdf/2211.14730.pdf

    Pipeline (see ``forward``):
      1. Per-sample normalisation over the time axis (from Non-stationary
         Transformer).
      2. Patching and embedding of each variable.
      3. Transformer encoder.
      4. Per-variable flatten/linear head.
      5. De-normalisation.
      6. A linear layer mapping the ``pred_len * enc_in`` forecast values of a
         sample to ``pred_len * c_out`` outputs.

    Only these constructor arguments affect the model: ``patch_len``,
    ``stride``, ``seq_len``, ``pred_len``, ``enc_in``, ``c_out``, ``e_layers``,
    ``n_heads``, ``factor``, ``d_model``, ``d_ff``, ``dropout``,
    ``activation`` and ``output_attention``. ``task_name`` is stored but does
    not change ``forward``. All other arguments (and ``**kwargs``) are accepted
    so that a shared experiment config can be passed unchanged; they are
    ignored. In particular, ``use_norm`` is ignored and normalisation is
    always applied.
    """

    def __init__(self, patch_len=16, stride=8, task_name='short_term_forecast', seq_len=96, label_len=48,
                 pred_len=96, enc_in=7, dec_in=7, c_out=1, e_layers=2, d_layers=1, n_heads=8, factor=3,
                 d_model=16, d_ff=32, des='Exp', expand=2, d_conv=4, top_k=5, embed='timeF', freq='h',
                 dropout=0.1, num_kernels=6,
                 moving_avg=25, channel_independence=1, decomp_method='moving_avg', use_norm=1,
                 version='fourier', mode_select='random', modes=32, activation='gelu',
                 seasonal_patterns='Monthly',
                 inverse=False, mask_rate=0.25, anomaly_ratio=0.25, output_attention=False,
                 down_sampling_layers=0, down_sampling_window=1, down_sampling_method=None,
                 seg_len=48, num_workers=0, itr=1, train_epochs=100, batch_size=32, patience=3,
                 learning_rate=0.0001, loss='MSE',
                 lradj='type1', use_amp=False, use_gpu=True, gpu=0, use_multi_gpu=False, devices='0,1,2,3',
                 p_hidden_dims=(128, 128),  # unused; tuple avoids a shared mutable default
                 p_hidden_layers=2, use_dtw=False, augmentation_ratio=0, seed=2, jitter=False, scaling=False,
                 permutation=False,
                 randompermutation=False, magwarp=False, timewarp=False, windowslice=False, windowwarp=False,
                 rotation=False,
                 spawner=False, dtwwarp=False, shapedtwwarp=False, wdba=False, discdtw=False, discsdtw=False,
                 extra_tag='', **kwargs):
        """Build the patch embedding, encoder and prediction heads.

        Args:
            patch_len: int, patch length for the patch embedding.
            stride: int, stride between patches; also used as the end padding.
            seq_len: length of the input window.
            pred_len: forecast horizon produced by the per-variable head.
            enc_in: number of input variables (must match the last dim of the
                ``x_enc`` passed to ``forward``).
            c_out: number of outputs per forecast step after the final projection.

        Raises:
            ValueError: if ``patch_len``/``stride`` are not positive, or if no
                complete patch fits in the padded input window.
        """
        super().__init__()
        self.task_name = task_name
        self.seq_len = seq_len
        self.pred_len = pred_len

        if patch_len <= 0 or stride <= 0:
            raise ValueError(
                f"patch_len and stride must be positive, got patch_len={patch_len}, stride={stride}")
        padding = stride
        # Patches are cut from the series after `padding` steps are appended.
        # Presumably PatchEmbedding does this with ReplicationPad1d + unfold, as
        # in the reference PatchTST implementation (TODO confirm in Embed.py).
        # The result equals the original int((seq_len - patch_len) / stride + 2)
        # for every valid configuration.
        patch_num = (seq_len + padding - patch_len) // stride + 1
        if patch_num < 1:
            raise ValueError(
                f"seq_len={seq_len} is too short for patch_len={patch_len} and stride={stride}: "
                f"no complete patch fits in the padded input")

        # Patching and embedding.
        self.patch_embedding = PatchEmbedding(d_model, patch_len, stride, padding, dropout)

        # Encoder. BatchNorm1d normalises over d_model, so the sequence is
        # transposed to put d_model on dim 1 and then transposed back.
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, factor, attention_dropout=dropout,
                                      output_attention=output_attention),
                        d_model, n_heads),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation,
                ) for _ in range(e_layers)
            ],
            norm_layer=nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2)),
        )

        # Prediction head: flattened [d_model x patch_num] -> pred_len, per variable.
        self.head_nf = d_model * patch_num
        self.head = FlattenHead(enc_in, self.head_nf, pred_len, head_dropout=dropout)

        # Mixes all variables and horizons of a sample into pred_len * c_out outputs.
        self.projection_final = nn.Linear(pred_len * enc_in, pred_len * c_out, bias=True)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast from an input window.

        Args:
            x_enc: input series indexed as ``[batch, time, variables]``. Mean and
                variance are taken over dim 1, and the tensor is permuted to
                ``[batch, variables, time]`` before patching.
            x_mark_enc, x_dec, x_mark_dec: unused; kept so the call signature
                matches other forecasting models.

        Returns:
            Tensor of shape ``[batch, pred_len * c_out]``.
        """
        # Normalisation from Non-stationary Transformer.
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Divide out of place: torch.var saved `x_enc` for its backward pass, so
        # an in-place division would break autograd when the input requires grad.
        x_enc = x_enc / stdev

        # Patching and embedding: [bs, nvars, seq_len] -> [bs * nvars, patch_num, d_model]
        x_enc = x_enc.permute(0, 2, 1)
        enc_out, n_vars = self.patch_embedding(x_enc)

        # Encoder: [bs * nvars, patch_num, d_model]
        enc_out, _ = self.encoder(enc_out)
        # -> [bs, nvars, patch_num, d_model] -> [bs, nvars, d_model, patch_num]
        enc_out = torch.reshape(enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
        enc_out = enc_out.permute(0, 1, 3, 2)

        # Head: [bs, nvars, pred_len] -> [bs, pred_len, nvars]
        dec_out = self.head(enc_out).permute(0, 2, 1)

        # De-normalisation: stdev and means are [bs, 1, nvars] and broadcast over time.
        dec_out = dec_out * stdev + means

        # [bs, pred_len * nvars] -> [bs, pred_len * c_out]
        return self.projection_final(dec_out.reshape(dec_out.shape[0], -1))
|
||
|
||
|
||
from pytorch_forecasting.models import BaseModel
|
||
from typing import Dict
|
||
|
||
class PatchTSTNetModel(BaseModel):
    """pytorch-forecasting ``BaseModel`` wrapper around ``PatchTST`` (lithium battery prediction).

    The explicitly listed constructor arguments are forwarded to ``PatchTST``.
    Everything in ``**kwargs`` goes to ``BaseModel.__init__``.
    """

    def __init__(self, patch_len=6, stride=3, seq_len=24, pred_len=1, enc_in=7, c_out=1, e_layers=2,
                 d_layers=1, factor=3, d_model=16, d_ff=32, des='Exp', itr=1, top_k=5, embed='timeF',
                 freq='h', dropout=0.1, num_kernels=6, **kwargs):
        # Store the signature arguments in `.hparams`. This call is mandatory; do not skip it.
        self.save_hyperparameters()
        # Pass the remaining arguments to BaseModel.__init__. This call is mandatory; do not skip it.
        super().__init__(**kwargs)

        network_kwargs = dict(
            patch_len=patch_len, stride=stride, seq_len=seq_len, pred_len=pred_len,
            enc_in=enc_in, c_out=c_out, e_layers=e_layers, d_layers=d_layers, factor=factor,
            d_model=d_model, d_ff=d_ff, des=des, itr=itr, top_k=top_k, embed=embed,
            freq=freq, dropout=dropout, num_kernels=num_kernels,
        )
        self.network = PatchTST(**network_kwargs)

    # Modified for lithium battery prediction.
    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run PatchTST on the continuous encoder inputs and rescale the result.

        Args:
            x: pytorch-forecasting batch dict; reads ``"encoder_cont"`` and
                ``"target_scale"``.

        Returns:
            The output of ``self.to_network_output`` built from the rescaled
            prediction.
        """
        # All continuous encoder features except the last one. Presumably the
        # last column is the target (TODO confirm the feature order in the dataset).
        encoder_features = x["encoder_cont"][:, :, :-1]
        # PatchTST does not use the time-mark or decoder inputs.
        raw_prediction = self.network(encoder_features, x_mark_enc=None, x_dec=None, x_mark_dec=None)
        # Rescale the predictions into target space.
        rescaled = self.transform_output(raw_prediction, target_scale=x["target_scale"])
        # Return a dict containing the output (prediction).
        return self.to_network_output(prediction=rescaled)
|
||
|
||
|
||
if __name__ == '__main__':
    # Smoke test: push a dummy batch through PatchTST and check the output shape.
    N, L, C = 100, 96, 15  # batch size, input window length, number of input variables
    label_len = 16
    c_out = 1
    pred_len = 16

    x_enc = torch.ones((N, L, C))
    # PatchTST does not use the time-mark or decoder inputs, so they are passed as None.
    model = PatchTST(seq_len=L, enc_in=C, dec_in=C, label_len=label_len, pred_len=pred_len, c_out=c_out)
    model.eval()  # disable dropout; BatchNorm uses its running statistics
    with torch.no_grad():
        out = model(x_enc=x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None)

    # The final projection flattens horizon and outputs into one axis.
    assert out.shape == (N, pred_len * c_out), out.shape
    print(out.shape)