Files
eason 79db6e5c96 feat: 添加RUL-Mamba模型及相关组件
新增锂电池剩余使用寿命预测模型RUL-Mamba,包含以下主要组件:
1. 添加Mamba模块作为核心时序建模组件
2. 实现特征注意力网络(FAN)和门控残差网络(GRN)
3. 新增数据预处理和归一化层
4. 添加模型训练和评估脚本
5. 补充README文档说明使用方法
6. 添加可视化辅助工具Helper_Plot.py
7. 实现多种时间序列处理层(Embedding、AutoCorrelation等)
8. 添加配置文件requirements.txt
9. 补充测试数据集TJU battery data
2026-01-09 08:53:50 +08:00

137 lines
4.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import math
import torch
import torch.nn as nn
from torch.distributions.normal import Normal
import numpy as np
from ModelsModify.layers.AMS import AMS
from ModelsModify.layers.Layer import WeightGenerator, CustomLinear
from ModelsModify.layers.RevIN import RevIN
from functools import reduce
from operator import mul
from typing import Dict, List, Tuple, Union
from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn
from torchmetrics import Metric as LightningMetric
from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.data.encoders import NaNLabelEncoder, EncoderNormalizer, MultiNormalizer, TorchNormalizer
from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric, MultiLoss, QuantileLoss
from pytorch_forecasting.models.base_model import BaseModelWithCovariates
from pytorch_forecasting.models.nn import LSTM, MultiEmbedding
from pytorch_forecasting.utils import autocorrelation, create_mask, detach, integer_histogram, padded_stack, to_list
from pytorch_forecasting.models import BaseModel
from typing import Dict
import torch.nn.functional as F
from einops import rearrange
class PathFormer(nn.Module):
    """Pathformer-style forecaster built from a stack of ``AMS`` layers.

    Pipeline, as implemented in ``forward``:

    1. Optional RevIN normalisation.
    2. Per-value linear embedding to ``d_model``.
    3. ``layer_nums`` AMS layers.
    4. Flatten (time x d_model) per node.
    5. Linear projection to ``pred_len``.
    6. Optional RevIN de-normalisation.
    7. Keep only the last node/channel.

    Args:
        layer_nums: Number of stacked AMS layers (depth of the pathway).
        num_nodes: Number of input variables (RevIN feature count and the
            channel axis of the input).
        pred_len: Forecast horizon, i.e. number of output steps.
        seq_len: Input window length.
        k: Forwarded to every AMS layer as ``k`` (presumably the number of
            experts selected by the gate -- confirm against AMS).
        num_experts_list: Number of experts for each AMS layer; must hold at
            least ``layer_nums`` entries (indexed per layer).
        patch_size_list: Patch sizes passed to every AMS layer. Per the
            original author, each size must divide ``seq_len`` and the list
            length must equal the number of experts.
        d_model: Embedding width.
        d_ff: Feed-forward width forwarded to AMS.
        residual_connection: Forwarded to AMS.
        revin: Truthy to wrap the network in RevIN norm/denorm.
        gpu: CUDA device index handed to AMS; the CPU device is used instead
            when CUDA is not available.
    """

    def __init__(
        self,
        layer_nums: int = 3,
        num_nodes: int = 15,
        pred_len: int = 1,
        seq_len: int = 24,
        k: int = 3,
        num_experts_list: Union[List[int], Tuple[int, ...]] = (4, 4, 4),
        patch_size_list: Union[List[int], Tuple[int, ...]] = (8, 6, 4, 2),  # must divide seq_len; length == num_experts
        d_model: int = 16,
        d_ff: int = 64,
        residual_connection: bool = True,
        revin: int = 1,
        gpu: int = 0,
    ):
        super().__init__()
        self.layer_nums = layer_nums  # number of AMS layers in the pathway
        self.num_nodes = num_nodes
        # The attribute keeps its historical name ``pre_len`` for compatibility.
        self.pre_len = pred_len
        self.seq_len = seq_len
        self.k = k
        # Copy into fresh lists. The defaults are immutable tuples, so there is
        # no shared mutable default, and later edits to a caller's list cannot
        # leak into this model.
        self.num_experts_list = list(num_experts_list)
        self.patch_size_list = list(patch_size_list)
        self.d_model = d_model
        self.d_ff = d_ff
        self.residual_connection = residual_connection
        self.revin = revin
        self.gpu = gpu
        if self.revin:
            self.revin_layer = RevIN(num_features=self.num_nodes, affine=False, subtract_last=False)
        # Embeds every scalar value into a d_model-dimensional vector.
        self.start_fc = nn.Linear(in_features=1, out_features=self.d_model)
        # Previously hard-coded to CUDA, which broke CPU-only hosts.
        self.device = self._resolve_device(self.gpu)
        # Built in the same order as before, so parameter initialisation
        # consumes the RNG identically.
        self.AMS_lists = nn.ModuleList(
            AMS(self.seq_len, self.seq_len, self.num_experts_list[num], self.device, k=self.k,
                num_nodes=self.num_nodes, patch_size=self.patch_size_list, noisy_gating=True,
                d_model=self.d_model, d_ff=self.d_ff, layer_number=num + 1,
                residual_connection=self.residual_connection)
            for num in range(self.layer_nums)
        )
        # Maps each node's flattened (seq_len * d_model) features to pred_len steps.
        self.projections = nn.Sequential(
            nn.Linear(self.seq_len * self.d_model, self.pre_len)
        )

    @staticmethod
    def _resolve_device(gpu: int) -> torch.device:
        """Return ``cuda:<gpu>`` when CUDA is available, otherwise the CPU device."""
        if torch.cuda.is_available():
            return torch.device(f"cuda:{gpu}")
        return torch.device("cpu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forecast ``pred_len`` steps and return the last channel only.

        Args:
            x: Input window. The code treats it as (batch, seq_len, num_nodes):
                RevIN is built over ``num_nodes`` features, and each scalar is
                embedded via ``unsqueeze(-1)``.

        Returns:
            Tensor of shape (batch, pred_len): the prediction for the last
            channel (index ``num_nodes - 1``).
        """
        if self.revin:
            x = self.revin_layer(x, 'norm')
        batch_size = x.shape[0]
        # (B, L, N) -> (B, L, N, d_model)
        out = self.start_fc(x.unsqueeze(-1))
        for layer in self.AMS_lists:
            # The second output (presumably an auxiliary/balance loss) is
            # discarded, so it never reaches the training objective.
            out, _ = layer(out)
        # Assumes AMS preserves (B, L, N, d_model).
        # -> (B, N, L, d_model) -> (B, N, L * d_model)
        out = out.permute(0, 2, 1, 3).reshape(batch_size, self.num_nodes, -1)
        # -> (B, N, pred_len) -> (B, pred_len, N)
        out = self.projections(out).transpose(2, 1)
        if self.revin:
            out = self.revin_layer(out, 'denorm')
        # Select the last channel. The result is already (B, pred_len), so the
        # former trailing ``.view(...)`` was a no-op and has been dropped.
        return out[:, :, -1]
class PathFormerModel(BaseModel):
    """pytorch-forecasting wrapper around the ``PathFormer`` network.

    Only ``enc_in`` (as ``num_nodes``), ``seq_len``, ``pred_len``, ``k`` and
    ``patch_size_list`` are forwarded to the network; every other
    ``PathFormer`` argument keeps its default. All remaining keyword
    arguments are passed to ``BaseModel``.
    """

    def __init__(self,
                 enc_in: int,
                 seq_len: int,
                 pred_len: int,
                 k: int,
                 patch_size_list: list,
                 **kwargs):
        # Record the signature arguments on ``self.hparams``. pytorch-forecasting
        # requires this call and expects it before ``BaseModel.__init__``.
        self.save_hyperparameters()
        super().__init__(**kwargs)
        hp = self.hparams
        self.network = PathFormer(
            num_nodes=hp.enc_in,
            seq_len=hp.seq_len,
            pred_len=hp.pred_len,
            k=hp.k,
            patch_size_list=hp.patch_size_list,
        )

    # Adapted for lithium-battery forecasting.
    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the network on the encoder window and rescale to target space.

        Args:
            x: Batch dictionary; this method reads ``"encoder_cont"`` and
                ``"target_scale"``.

        Returns:
            The output of ``to_network_output`` with ``prediction`` set to the
            rescaled network forecast.
        """
        encoder_cont = x["encoder_cont"]
        # Every continuous encoder channel except the last is fed to the network.
        # NOTE(review): which variable the dropped channel holds depends on the
        # dataset's feature order -- confirm this is intentional and that
        # ``enc_in`` equals the number of remaining channels.
        network_input = encoder_cont[:, :, :-1]
        raw_prediction = self.network(network_input)
        # Rescale predictions from normalised space back into target space.
        rescaled = self.transform_output(raw_prediction, target_scale=x["target_scale"])
        return self.to_network_output(prediction=rescaled)