Files
eason 79db6e5c96 feat: 添加RUL-Mamba模型及相关组件
新增锂电池剩余使用寿命预测模型RUL-Mamba,包含以下主要组件:
1. 添加Mamba模块作为核心时序建模组件
2. 实现特征注意力网络(FAN)和门控残差网络(GRN)
3. 新增数据预处理和归一化层
4. 添加模型训练和评估脚本
5. 补充README文档说明使用方法
6. 添加可视化辅助工具Helper_Plot.py
7. 实现多种时间序列处理层(Embedding、AutoCorrelation等)
8. 添加配置文件requirements.txt
9. 补充测试数据集TJU battery data
2026-01-09 08:53:50 +08:00

280 lines
9.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict
# from mamba_ssm import Mamba2
# from MambaSimple import MambaBlock as Mamba
from ModelsModify.MambaSimple import MambaBlock as Mamba
from pytorch_forecasting.models import BaseModel
from einops import rearrange
import math
class MyGLU(nn.Module):
    """Gated linear unit fed by two independent linear projections.

    Both projections map ``input_size`` features to ``output_size`` features.
    Their outputs are concatenated on the last axis and passed to ``nn.GLU``,
    so the result has ``output_size`` features on the last axis.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # linear_1 feeds the first half of the GLU input, linear_2 the second.
        self.linear_1 = nn.Linear(input_size, output_size)
        self.linear_2 = nn.Linear(input_size, output_size)
        self.glu = nn.GLU()

    def forward(self, x):
        """Project ``x`` twice, concatenate on the last axis and apply the GLU."""
        halves = (self.linear_1(x), self.linear_2(x))
        return self.glu(torch.cat(halves, dim=-1))
class GRN(nn.Module):
    """Gated residual block: ``LayerNorm(x + body(inputs))``.

    ``body`` is Linear -> ELU -> Linear -> Dropout -> MyGLU and always maps
    back to ``input_size`` features so it can be added to ``x``.

    Args:
        input_size: feature width of ``x`` (and of the block output).
        hidden_size: width of the hidden linear layer inside ``body``.
        dropout: dropout probability applied before the GLU.
        external: when True the body's first layer expects ``2 * input_size``
            features, i.e. ``x`` concatenated with an extra input ``e``
            (assumes ``e`` has ``input_size`` features -- confirm with callers).
    """

    def __init__(self, input_size, hidden_size, dropout, external=False):
        super().__init__()
        self.shortcut = nn.Identity()
        body_input_size = 2 * input_size if external else input_size
        self.body = nn.Sequential(
            nn.Linear(body_input_size, hidden_size),
            nn.ELU(),
            nn.Linear(hidden_size, input_size),
            nn.Dropout(dropout),
            MyGLU(input_size, input_size),
        )
        self.norm = nn.LayerNorm(input_size)

    def forward(self, x, e=None):
        """Return the normalised sum of ``x`` and the gated body output.

        When ``e`` is given it is concatenated to ``x`` on the last axis before
        entering the body; the skip path always carries ``x`` alone.
        """
        residual = self.shortcut(x)
        body_in = x if e is None else torch.cat([x, e], dim=-1)
        return self.norm(residual + self.body(body_in))
class VarEncoder(nn.Module):
    """Embed every input variable with its own ``Linear(1, dim)`` layer.

    Args:
        n_var: number of variables (channels) to embed; one layer is created
            per variable.
        dim: embedding width produced for each variable.
    """

    def __init__(self, n_var, dim):
        super().__init__()
        self.n_var = n_var
        self.dim = dim
        self.layers = nn.ModuleList([nn.Sequential(nn.Linear(1, dim))
                                     for _ in range(n_var)])

    def forward(self, x):
        """Map ``x`` of shape [B, L, C] to embeddings of shape [B, L, n_var, dim].

        Only the first ``n_var`` channels of ``x`` are read.  The result keeps
        the dtype produced by the linear layers instead of being forced to the
        global default dtype.
        """
        if self.n_var == 0:
            # torch.stack rejects an empty list; keep returning an empty tensor.
            return x.new_zeros([x.shape[0], x.shape[1], 0, self.dim])
        # Channel i -> [B, L, 1] -> [B, L, dim]
        embedded = [layer(x[:, :, i].unsqueeze(-1))
                    for i, layer in enumerate(self.layers)]
        # Stacking (instead of writing into a pre-allocated float32 buffer)
        # preserves dtype and avoids in-place autograd writes.
        return torch.stack(embedded, dim=2)
class DecoderBlock(nn.Module):
    """Mamba block followed by a GRN that also receives the block's own input.

    The sequence is run through ``mamba_dec``; the GRN then takes the Mamba
    output as its main input and the untouched sequence as its external input.

    Args:
        d_inner, dt_rank, d_model, d_ff: forwarded to the Mamba block.
        dropout: dropout probability used inside the GRN.
    """

    def __init__(self, d_inner,dt_rank,d_model,d_ff, dropout):
        super().__init__()
        self.shortcut = nn.Identity()
        self.mamba_dec = Mamba(
            d_inner=d_inner,
            dt_rank=dt_rank,
            d_model=d_model,
            d_ff=d_ff,
            d_conv=4,
            top_k=5,
        )
        self.grn = GRN(d_model, 2 * d_model, dropout, external=True)

    def forward(self, x, c):
        """Decode from context ``c`` and, optionally, a decoder input ``x``.

        * ``x is None`` (lithium-battery capacity prediction): only ``c`` is
          processed and the full GRN output is returned.
        * otherwise (weather forecasting): ``x`` and ``c`` are concatenated on
          axis -2, processed, and the first position on axis 1 is dropped.
        """
        if x is None:
            # Lithium-battery capacity prediction: context only.
            skip = self.shortcut(c)
            return self.grn(self.mamba_dec(c), skip)

        # Weather forecasting: decoder input followed by the context.
        sequence = torch.cat([x, c], dim=-2)
        skip = self.shortcut(sequence)
        decoded = self.grn(self.mamba_dec(sequence), skip)
        return decoded[:, 1:, :]
# Squeeze-and-excitation channel attention:
# global average pooling -> linear -> ReLU -> linear -> Sigmoid
# (the two "1x1 convolutions" are implemented as bias-free nn.Linear layers).
class AttentionSE2D(nn.Module):
    """Re-weight the channels of a 4-D tensor with learned per-channel gates.

    Args:
        inchannel: number of channels on axis 1 of the input.
        ratio: reduction factor of the bottleneck (``inchannel // ratio`` units).
    """

    def __init__(self, inchannel, ratio=2):
        super().__init__()
        bottleneck = inchannel // ratio
        # Squeeze: collapse the two trailing axes to 1x1 per channel.
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        # Excitation: c -> c/r -> c, squashed into (0, 1) by the sigmoid.
        self.fc = nn.Sequential(
            nn.Linear(inchannel, bottleneck, bias=False),
            nn.ReLU(),
            nn.Linear(bottleneck, inchannel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale ``x`` ([b, c, h, w]) channel-wise; the shape is unchanged."""
        batch, channels, _, _ = x.size()
        # Squeeze: one pooled value per (sample, channel).
        pooled = self.gap(x).view(batch, channels)
        # Excitation: per-channel gate reshaped for broadcasting over h and w.
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        # Scale: broadcast multiply of the gate with the original feature map.
        return x * gate
class GatedResidualNetwork(nn.Module):
    """Feed-forward block: Linear -> ELU -> Linear -> Dropout -> Linear -> GLU -> LayerNorm.

    NOTE(review): despite the name, ``forward`` does not add a skip connection
    from the input to the output.

    Args:
        input_size: feature width of the input.
        hidden_size: width after the first linear layer.
        output_size: feature width of the result.
        dropout: dropout probability applied before the gating layer.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.elu = nn.ELU(alpha=1.0)
        self.fc2 = nn.Linear(hidden_size, output_size)
        # The linear layer doubles the width so the GLU can halve it again.
        self.gate_norm = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(output_size, 2 * output_size),
            nn.GLU(dim=-1),
            nn.LayerNorm(output_size),
        )

    def forward(self, x):
        """Transform the last axis of ``x`` from ``input_size`` to ``output_size``."""
        hidden = self.elu(self.fc1(x))
        return self.gate_norm(self.fc2(hidden))
class VariableAttentionNetwork(nn.Module):
    """Per-variable embedding plus softmax weights over the variables.

    Each variable is a scalar slice of the input's last axis.  It is lifted to
    ``input_sizes[i]`` features by its own ``Linear(1, input_sizes[i])``
    prescaler and then mapped to ``hidden_size`` features by its own
    ``GatedResidualNetwork``.  A further network over the concatenated
    prescaled variables yields one weight per variable.

    Args:
        input_sizes: one embedding width per variable; its length is the
            number of variables.
        hidden_size: output width of every per-variable network.
        dropout: dropout probability passed to every ``GatedResidualNetwork``.
        calc_sum: when True the weighted per-variable outputs are summed over
            the variable axis; when False they are returned separately.
    """

    def __init__(
        self,
        input_sizes: list,
        hidden_size: int,
        dropout: float = 0.1,
        calc_sum=True
    ):
        super().__init__()
        self.calc_sum = calc_sum
        self.hidden_size = hidden_size
        self.input_sizes = input_sizes
        self.dropout = dropout

        self.single_variable_grns = nn.ModuleList()
        self.prescalers = nn.ModuleList()
        for size in input_sizes:
            variable_grn = GatedResidualNetwork(
                size,
                min(size, hidden_size),
                output_size=hidden_size,
                dropout=dropout,
            )
            self.single_variable_grns.append(variable_grn)
            self.prescalers.append(nn.Linear(1, size))

        n_vars = len(input_sizes)
        # Produces one logit per variable from all prescaled variables together.
        self.flattened_grn = GatedResidualNetwork(
            sum(input_sizes),
            min(hidden_size, n_vars),
            n_vars,
            dropout,
        )
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor):
        """Return ``(outputs, sparse_weights)``.

        ``x`` is assumed to have shape [B, L, M, D] with D the number of
        variables on the last axis.  ``sparse_weights`` is the softmax over
        variables with an extra axis inserted at position -2.  ``outputs`` is
        the per-variable result stacked on a new last axis and multiplied by
        those weights, then summed over that axis when ``calc_sum`` is True.
        """
        var_outputs = []
        weight_inputs = []
        pairs = zip(self.prescalers, self.single_variable_grns)
        for idx, (prescale, variable_grn) in enumerate(pairs):
            # Scalar slice of variable idx, lifted to its embedding width.
            embedding = prescale(x[..., idx:idx + 1])
            weight_inputs.append(embedding)
            var_outputs.append(variable_grn(embedding))

        stacked = torch.stack(var_outputs, dim=-1)
        logits = self.flattened_grn(torch.cat(weight_inputs, dim=-1))
        sparse_weights = self.softmax(logits).unsqueeze(-2)

        outputs = stacked * sparse_weights
        if self.calc_sum:
            outputs = outputs.sum(dim=-1)
        return outputs, sparse_weights
class RULMambaVAN(nn.Module):
    """Encoder/decoder network: variable attention -> SE attention -> Mamba -> decoder -> projection.

    Args:
        lookback: stored on the instance; not otherwise used in this class.
        predict, dec_in, n_enc_layer: accepted for interface compatibility but
            not used in this class.
        enc_in: number of encoder input variables.
        d_model: model width.
        d_ff: accepted but ignored -- the feed-forward width is always
            ``2 * d_model`` (NOTE(review): confirm this is intentional).
        n_dec_layer: number of ``DecoderBlock`` instances created.
        dropout: dropout probability of the decoder blocks.  The variable
            attention network uses a fixed 0.1 regardless of this value.
        expand: multiplier giving the Mamba inner width ``d_model * expand``.
        chain_decoders: when False (default, the historical behaviour) every
            decoder block receives the same inputs and only the LAST block's
            output reaches the projection, so earlier blocks do not influence
            the result.  When True the blocks are stacked: each block's output
            is fed to the next one.
    """

    def __init__(self, lookback, predict, enc_in, dec_in, d_model, d_ff,n_enc_layer, n_dec_layer, dropout,expand=2, chain_decoders=False):
        super().__init__()
        # NOTE: emb_enc is constructed (its parameters are part of the state
        # dict) but forward() does not call it.
        self.emb_enc = VarEncoder(enc_in, d_model)
        self.encoder_VSN = VariableAttentionNetwork(input_sizes=[d_model]*enc_in,hidden_size=d_model,dropout=0.1,calc_sum=False)
        self.se_attention = AttentionSE2D(d_model)
        self.lookback = lookback
        self.d_inner = d_model * expand
        self.dt_rank = math.ceil(d_model / 16)
        self.d_ff = 2*d_model
        self.mamba_enc = Mamba(d_inner=self.d_inner, dt_rank=self.dt_rank, d_model=d_model, d_ff=self.d_ff)
        self.dec = nn.ModuleList([DecoderBlock(d_inner=self.d_inner, dt_rank=self.dt_rank, d_model=d_model, d_ff=self.d_ff,dropout=dropout) for _ in range(n_dec_layer)])
        self.proj = nn.Linear(d_model, 1)
        self.chain_decoders = chain_decoders

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Run the network on ``x_enc`` ([B, L, M]) and return the projection.

        ``x_mark_enc`` and ``x_mark_dec`` are not used.  ``x_dec`` is handed to
        the decoder blocks and may be None; in that case the result is
        [B, 1, 1].
        """
        x_enc = x_enc.unsqueeze(-2)
        # ---------------- encoder variable selection (feature weighting) ----------------
        x_enc, _ = self.encoder_VSN(x_enc)
        x_enc = x_enc.squeeze(2)
        x_enc = rearrange(x_enc, 'b l d m -> b d l m')
        x_enc = self.se_attention(x_enc)  # [B,D,L,M] -> [B,D,L,M]
        x_enc = rearrange(x_enc, 'b d l m -> b l m d')
        x_enc = torch.sum(x_enc, dim=-2)  # [B,L,M,D] -> [B,L,D]
        enc_out = self.mamba_enc(x_enc)  # [B,L,D] -> [B,L,D]
        context = enc_out[:, -1:, :]  # [B,L,D] -> [B,1,D]

        # With no decoder blocks the context itself is projected; previously
        # this case raised UnboundLocalError.
        dec_out = context
        dec_input, dec_context = x_dec, context
        for block in self.dec:
            dec_out = block(dec_input, dec_context)  # [B,1,D] -> [B,1,D] when x_dec is None
            if self.chain_decoders:
                if dec_input is None:
                    # Context-only mode: the block output is the next context.
                    dec_context = dec_out
                else:
                    # NOTE(review): assumes the block output is the next
                    # block's decoder input -- confirm against the intended
                    # architecture before enabling with a non-None x_dec.
                    dec_input = dec_out
        out = self.proj(dec_out)  # [B,1,D] -> [B,1,1]
        return out
'''
--lookback 24 --predict 24 --advance_features False --future_info True --n_trials 30
'''
class RULMambaVANNetModel(BaseModel):
    """pytorch_forecasting ``BaseModel`` wrapper around ``RULMambaVAN``.

    All named constructor arguments are forwarded unchanged to
    ``RULMambaVAN``; any remaining keyword arguments go to ``BaseModel``.
    """

    def __init__(self, lookback=24, predict=24, enc_in=11, dec_in=4, d_model=16, d_ff=32,n_enc_layer=1, n_dec_layer=2, dropout=0.01, **kwargs):
        # Kept ahead of BaseModel.__init__, following the pytorch_forecasting
        # convention for recording the constructor arguments.
        self.save_hyperparameters()
        super().__init__(**kwargs)
        self.network = RULMambaVAN(
            lookback=lookback,
            predict=predict,
            enc_in=enc_in,
            dec_in=dec_in,
            d_model=d_model,
            d_ff=d_ff,
            n_enc_layer=n_enc_layer,
            n_dec_layer=n_dec_layer,
            dropout=dropout,
        )

    # Adapted for lithium-battery prediction.
    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Predict from ``x["encoder_cont"]`` and rescale with ``x["target_scale"]``.

        The last feature of ``encoder_cont`` is dropped before it reaches the
        network (presumably the target column -- confirm against the dataset
        definition).  No decoder input or time marks are supplied.
        """
        # [batch_size, seq_len, enc_in] after removing the last feature.
        encoder_features = x["encoder_cont"][:, :, :-1]
        raw_prediction = self.network(
            x_enc=encoder_features,
            x_mark_enc=None,
            x_dec=None,
            x_mark_dec=None,
        )
        # Rescale the network output back into the target space.
        rescaled = self.transform_output(raw_prediction, target_scale=x["target_scale"])
        # Wrap the prediction in the network-output structure BaseModel expects.
        return self.to_network_output(prediction=rescaled)
if __name__=='__main__':
    # Smoke test: push an all-ones batch through the network in context-only
    # mode (no decoder input, no time marks) and print the output shape.
    batch_size, seq_len, n_features = 128, 10, 1
    # pred_len is restricted
    model = RULMambaVAN(
        lookback=10,
        predict=1,
        enc_in=1,
        dec_in=1,
        d_model=16,
        d_ff=32,
        n_enc_layer=1,
        n_dec_layer=2,
        dropout=0.01,
    )
    dummy_batch = torch.ones((batch_size, seq_len, n_features))
    out = model(x_enc=dummy_batch, x_mark_enc=None, x_dec=None, x_mark_dec=None)
    print(out.shape)