79db6e5c96
新增锂电池剩余使用寿命预测模型RUL-Mamba,包含以下主要组件: 1. 添加Mamba模块作为核心时序建模组件 2. 实现特征注意力网络(FAN)和门控残差网络(GRN) 3. 新增数据预处理和归一化层 4. 添加模型训练和评估脚本 5. 补充README文档说明使用方法 6. 添加可视化辅助工具Helper_Plot.py 7. 实现多种时间序列处理层(Embedding、AutoCorrelation等) 8. 添加配置文件requirements.txt 9. 补充测试数据集TJU battery data
137 lines
4.8 KiB
Python
137 lines
4.8 KiB
Python
import math
|
||
import torch
|
||
import torch.nn as nn
|
||
from torch.distributions.normal import Normal
|
||
import numpy as np
|
||
from ModelsModify.layers.AMS import AMS
|
||
from ModelsModify.layers.Layer import WeightGenerator, CustomLinear
|
||
from ModelsModify.layers.RevIN import RevIN
|
||
from functools import reduce
|
||
from operator import mul
|
||
|
||
from typing import Dict, List, Tuple, Union
|
||
|
||
from matplotlib import pyplot as plt
|
||
import numpy as np
|
||
import torch
|
||
from torch import nn
|
||
from torchmetrics import Metric as LightningMetric
|
||
|
||
from pytorch_forecasting.data import TimeSeriesDataSet
|
||
from pytorch_forecasting.data.encoders import NaNLabelEncoder, EncoderNormalizer, MultiNormalizer, TorchNormalizer
|
||
from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric, MultiLoss, QuantileLoss
|
||
from pytorch_forecasting.models.base_model import BaseModelWithCovariates
|
||
from pytorch_forecasting.models.nn import LSTM, MultiEmbedding
|
||
|
||
|
||
from pytorch_forecasting.utils import autocorrelation, create_mask, detach, integer_histogram, padded_stack, to_list
|
||
|
||
from pytorch_forecasting.models import BaseModel
|
||
from typing import Dict
|
||
|
||
|
||
import torch.nn.functional as F
|
||
from einops import rearrange
|
||
|
||
class PathFormer(nn.Module):
    """PathFormer network: RevIN (optional) -> scalar embedding -> stacked AMS layers -> linear head.

    Each scalar of the input series is lifted to ``d_model`` channels by
    ``start_fc``. The result is passed through ``layer_nums`` AMS layers,
    flattened per node to ``seq_len * d_model`` features, and projected to
    ``pred_len`` steps. Only the last of the ``num_nodes`` channels is
    returned by :meth:`forward`.

    Args:
        layer_nums: Number of stacked AMS (pathway) layers.
        num_nodes: Number of input channels (last dim of the input); also
            RevIN's ``num_features``.
        pred_len: Number of predicted steps (stored as ``self.pre_len``).
        seq_len: Input sequence length; used as AMS input/output size and for
            the projection's ``seq_len * d_model`` input features.
        k: Forwarded to every AMS layer as ``k`` (presumably the top-k
            expert count -- TODO confirm against AMS).
        num_experts_list: Experts per AMS layer; needs at least ``layer_nums``
            entries. Defaults to ``[4, 4, 4]`` (a fresh list per instance).
        patch_size_list: Forwarded to every AMS layer as ``patch_size``. Each
            size should divide ``seq_len`` and the list length should equal the
            number of experts. Defaults to ``[8, 6, 4, 2]`` (fresh per instance).
        d_model: Embedding width produced by ``start_fc`` and used by AMS.
        d_ff: Forwarded to AMS as ``d_ff``.
        residual_connection: Forwarded to AMS.
        revin: Truthy to normalize inputs / de-normalize outputs with RevIN.
        gpu: CUDA device index for ``self.device``; CPU is used when CUDA is
            not available.

    Raises:
        ValueError: If ``num_experts_list`` has fewer than ``layer_nums`` entries.
    """

    def __init__(self,
                 layer_nums=3,
                 num_nodes=15,
                 pred_len=1,
                 seq_len=24,
                 k=3,
                 num_experts_list=None,
                 patch_size_list=None,
                 d_model=16,
                 d_ff=64,
                 residual_connection=True,
                 revin=1,
                 gpu=0
                 ):
        super().__init__()

        # Build list defaults per instance: a list literal as a default
        # argument would be one object shared by every PathFormer.
        if num_experts_list is None:
            num_experts_list = [4, 4, 4]
        if patch_size_list is None:
            # Sizes must divide seq_len; length must equal the number of experts.
            patch_size_list = [8, 6, 4, 2]
        if len(num_experts_list) < layer_nums:
            raise ValueError(
                f"num_experts_list needs at least layer_nums={layer_nums} "
                f"entries, got {len(num_experts_list)}"
            )

        self.layer_nums = layer_nums  # number of pathway (AMS) layers
        self.num_nodes = num_nodes
        self.pre_len = pred_len  # attribute name kept for existing callers
        self.seq_len = seq_len
        self.k = k
        self.num_experts_list = num_experts_list
        self.patch_size_list = patch_size_list
        self.d_model = d_model
        self.d_ff = d_ff
        self.residual_connection = residual_connection
        self.revin = revin
        self.gpu = gpu

        if self.revin:
            self.revin_layer = RevIN(num_features=self.num_nodes, affine=False, subtract_last=False)

        # Lift every scalar input value to a d_model-dimensional embedding.
        self.start_fc = nn.Linear(in_features=1, out_features=self.d_model)
        self.AMS_lists = nn.ModuleList()
        # Device handed to each AMS layer. Fall back to CPU instead of
        # hard-coding CUDA so the model can be built on CPU-only machines.
        self.device = torch.device(
            f'cuda:{self.gpu}' if torch.cuda.is_available() else 'cpu')

        for num in range(self.layer_nums):
            self.AMS_lists.append(
                AMS(self.seq_len, self.seq_len, self.num_experts_list[num], self.device, k=self.k,
                    num_nodes=self.num_nodes, patch_size=self.patch_size_list, noisy_gating=True,
                    d_model=self.d_model, d_ff=self.d_ff, layer_number=num + 1,
                    residual_connection=self.residual_connection))
        # Flattened per-node features (seq_len * d_model) -> pred_len steps.
        self.projections = nn.Sequential(
            nn.Linear(self.seq_len * self.d_model, self.pre_len)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict ``pred_len`` steps for the last input channel.

        Args:
            x: Input series, expected as (batch, seq_len, num_nodes). The
               flatten/projection below requires this, assuming each AMS layer
               preserves its input shape -- TODO confirm against AMS.

        Returns:
            Tensor of shape (batch, pred_len): the prediction for the last of
            the ``num_nodes`` channels, de-normalized when RevIN is enabled.
        """
        if self.revin:
            x = self.revin_layer(x, 'norm')

        batch_size = x.shape[0]
        # (B, L, N) -> (B, L, N, 1) -> (B, L, N, d_model)
        out = self.start_fc(x.unsqueeze(-1))

        for layer in self.AMS_lists:
            # AMS returns a pair; the second value (presumably the expert
            # balance loss) is discarded here.
            out, _ = layer(out)

        # (B, L, N, d_model) -> (B, N, L * d_model)
        out = out.permute(0, 2, 1, 3).reshape(batch_size, self.num_nodes, -1)
        # (B, N, pred_len) -> (B, pred_len, N)
        out = self.projections(out).transpose(2, 1)

        if self.revin:
            out = self.revin_layer(out, 'denorm')

        # Keep only the last channel. reshape (rather than view) does not
        # depend on the tensor's memory layout.
        return out[:, :, -1].reshape(batch_size, self.pre_len)
|
||
|
||
class PathFormerModel(BaseModel):
    """pytorch-forecasting ``BaseModel`` wrapper around :class:`PathFormer`.

    Modified for lithium-battery prediction. The wrapped network is built
    from the saved hyperparameters. Every other PathFormer setting (layer
    count, experts per layer, ``d_model``, RevIN, device index, ...) keeps
    PathFormer's constructor default.

    Args:
        enc_in: Number of input channels given to PathFormer as ``num_nodes``.
            Presumably this equals the ``encoder_cont`` width minus the column
            dropped in :meth:`forward` -- TODO confirm against the dataset setup.
        seq_len: Encoder length seen by PathFormer.
        pred_len: Prediction horizon produced by PathFormer.
        k: Forwarded to PathFormer's ``k``.
        patch_size_list: Forwarded to PathFormer. Each size should divide
            ``seq_len`` and the list length should match PathFormer's experts
            per layer (4 by default).
        **kwargs: Forwarded unchanged to ``BaseModel.__init__``.
    """

    def __init__(self,
                 enc_in: int,
                 seq_len: int,
                 pred_len: int,
                 k: int,
                 patch_size_list: list,
                 **kwargs):
        # Store the constructor arguments in ``self.hparams``. Mandatory, do
        # not skip: the network below is configured from ``self.hparams``.
        self.save_hyperparameters()

        super().__init__(**kwargs)

        hp = self.hparams
        self.network = PathFormer(
            num_nodes=hp.enc_in,
            seq_len=hp.seq_len,
            pred_len=hp.pred_len,
            k=hp.k,
            patch_size_list=hp.patch_size_list,
        )

    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run PathFormer on one batch and rescale its output to target space.

        Args:
            x: Batch dictionary. ``x["encoder_cont"]`` supplies the continuous
               encoder features and ``x["target_scale"]`` the scale passed to
               ``transform_output``.

        Returns:
            The result of ``to_network_output`` carrying the rescaled
            ``prediction``.
        """
        # The last continuous encoder column is excluded from the network input.
        features = x["encoder_cont"][:, :, :-1]

        raw_prediction = self.network(features)
        # Map the network output back into the target's original scale.
        scaled = self.transform_output(raw_prediction, target_scale=x["target_scale"])

        return self.to_network_output(prediction=scaled)