Files
eason 79db6e5c96 feat: 添加RUL-Mamba模型及相关组件
新增锂电池剩余使用寿命预测模型RUL-Mamba,包含以下主要组件:
1. 添加Mamba模块作为核心时序建模组件
2. 实现特征注意力网络(FAN)和门控残差网络(GRN)
3. 新增数据预处理和归一化层
4. 添加模型训练和评估脚本
5. 补充README文档说明使用方法
6. 添加可视化辅助工具Helper_Plot.py
7. 实现多种时间序列处理层(Embedding、AutoCorrelation等)
8. 添加配置文件requirements.txt
9. 补充测试数据集TJU battery data
2026-01-09 08:53:50 +08:00

417 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, einsum
# from .layers.Embed import DataEmbedding
# import torch.nn.functional as F
from torch.nn.utils import weight_norm
import math
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional-encoding table.

    The table is computed once at construction time and registered as the
    buffer ``pe`` with shape ``(1, max_len, d_model)``. Even columns hold
    sines and odd columns hold cosines of ``position * div_term``.

    Args:
        d_model: Width of each encoding vector. Odd widths are supported.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        # Registered as a buffer below, so it is never a trainable parameter.
        pe = torch.zeros(max_len, d_model).float()

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so trim the cosine block to the number of odd slots.
        pe[:, 1::2] = torch.cos(angles)[:, :d_model // 2]

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return the first ``x.size(1)`` rows of the table.

        Only the size of ``x`` along dimension 1 is read; its values are not
        used. The result has shape ``(1, min(x.size(1), max_len), d_model)``.
        """
        return self.pe[:, :x.size(1)]
class TokenEmbedding(nn.Module):
    """Embed each time step with a kernel-3 circular 1-D convolution.

    Args:
        c_in: Number of input channels (size of the last input dimension).
        d_model: Number of output channels produced per time step.
    """

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        padding = self._padding_for_version(torch.__version__)
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding,
                                   padding_mode='circular', bias=False)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_in', nonlinearity='leaky_relu')

    @staticmethod
    def _padding_for_version(version):
        """Return the conv padding for the given torch version string.

        Versions at or above 1.5 get padding 1, older ones get padding 2.
        The comparison is done on the numeric (major, minor) pair: comparing
        the raw strings would sort '1.10.0' below '1.5.0' and pick the wrong
        padding for torch 1.10 - 1.13.
        """
        numbers = []
        for piece in str(version).split('.')[:2]:
            digits = ''
            for ch in piece:
                if not ch.isdigit():
                    break
                digits += ch
            numbers.append(int(digits) if digits else 0)
        while len(numbers) < 2:
            numbers.append(0)
        return 1 if tuple(numbers) >= (1, 5) else 2

    def forward(self, x):
        """Apply the convolution along dimension 1 of ``x``.

        ``x`` is permuted with ``(0, 2, 1)`` before the convolution and
        transposed back afterwards, so channels are expected on the last
        dimension of the input and of the output.
        """
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x
class FixedEmbedding(nn.Module):
    """Non-trainable embedding table filled with sinusoidal values.

    Row ``i`` of the table holds sines (even columns) and cosines (odd
    columns) of ``i * div_term``. The table is stored as the weight of
    ``self.emb`` with ``requires_grad=False``.

    Args:
        c_in: Number of rows (distinct integer indices) in the table.
        d_model: Width of each embedding vector. Odd widths are supported.
    """

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()
        w = torch.zeros(c_in, d_model).float()

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()
        angles = position * div_term

        w[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so trim the cosine block to the number of odd slots.
        w[:, 1::2] = torch.cos(angles)[:, :d_model // 2]

        self.emb = nn.Embedding(c_in, d_model)
        # Replace the randomly initialised weight with the frozen table.
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        """Look up integer indices ``x`` and return a detached result."""
        return self.emb(x).detach()
class TemporalEmbedding(nn.Module):
    """Sum of per-field embeddings for integer calendar features.

    One lookup table is created per field: hour (24 rows), weekday (7),
    day (32) and month (13), plus minute (4) when ``freq == 't'``.

    Args:
        d_model: Width of every embedding vector.
        embed_type: ``'fixed'`` selects ``FixedEmbedding``; any other value
            selects a trainable ``nn.Embedding``.
        freq: Only the value ``'t'`` is inspected; it enables the minute table.
    """

    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super().__init__()

        if embed_type == 'fixed':
            table_cls = FixedEmbedding
        else:
            table_cls = nn.Embedding

        # (attribute name, number of rows); the minute table, when present,
        # is created first so module registration order is unchanged.
        table_specs = []
        if freq == 't':
            table_specs.append(('minute_embed', 4))
        table_specs.extend([
            ('hour_embed', 24),
            ('weekday_embed', 7),
            ('day_embed', 32),
            ('month_embed', 13),
        ])
        for attr_name, n_rows in table_specs:
            setattr(self, attr_name, table_cls(n_rows, d_model))

    def forward(self, x):
        """Embed the fields found on the last dimension of ``x`` and sum them.

        Columns are read as 0 = month, 1 = day, 2 = weekday, 3 = hour and,
        only when the minute table exists, 4 = minute. ``x`` is cast to long
        before the lookups.
        """
        fields = x.long()

        if hasattr(self, 'minute_embed'):
            minute_part = self.minute_embed(fields[:, :, 4])
        else:
            minute_part = 0.
        hour_part = self.hour_embed(fields[:, :, 3])
        weekday_part = self.weekday_embed(fields[:, :, 2])
        day_part = self.day_embed(fields[:, :, 1])
        month_part = self.month_embed(fields[:, :, 0])

        return hour_part + weekday_part + day_part + month_part + minute_part
class TimeFeatureEmbedding(nn.Module):
    """Bias-free linear projection of time features to ``d_model``.

    The input width is chosen from ``freq``: h=4, t=5, s=6, m=1, a=1, w=2,
    d=3, b=3.

    Args:
        d_model: Output width of the projection.
        embed_type: Accepted but not used by this class.
        freq: Frequency code; must be one of the keys listed above.

    Raises:
        KeyError: If ``freq`` is not a known frequency code.
    """

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super().__init__()
        n_features_by_freq = {
            'h': 4, 't': 5, 's': 6,
            'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3,
        }
        self.embed = nn.Linear(n_features_by_freq[freq], d_model, bias=False)

    def forward(self, x):
        """Project the last dimension of ``x`` to ``d_model``."""
        projected = self.embed(x)
        return projected
class DataEmbedding(nn.Module):
    """Value + (optional) temporal + positional embedding with dropout.

    Args:
        c_in: Number of input channels passed to ``TokenEmbedding``.
        d_model: Width of every embedding.
        embed_type: ``'timeF'`` selects ``TimeFeatureEmbedding`` for the time
            marks; any other value selects ``TemporalEmbedding``.
        freq: Frequency code forwarded to the temporal embedding.
        dropout: Dropout probability applied to the summed embedding.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type == 'timeF':
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x``; the temporal term is added only when ``x_mark`` is given."""
        embedded = self.value_embedding(x)
        if x_mark is not None:
            embedded = embedded + self.temporal_embedding(x_mark)
        embedded = embedded + self.position_embedding(x)
        return self.dropout(embedded)
class DataEmbedding_inverted(nn.Module):
    """Embed whole series per variate with a single linear layer.

    Args:
        c_in: Input width of the linear layer, i.e. the size of dimension 1
            of ``x`` (it becomes the last dimension after the permute).
        d_model: Output width of the linear layer.
        embed_type: Accepted but not used by this class.
        freq: Accepted but not used by this class.
        dropout: Dropout probability applied to the embedding.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()
        self.value_embedding = nn.Linear(c_in, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Swap the last two dimensions of ``x`` and project them.

        When ``x_mark`` is given it is permuted the same way and concatenated
        to ``x`` along dimension 1, so its channels are embedded as extra rows.
        """
        # [Batch, Time, Variate] -> [Batch, Variate, Time]
        tokens = x.permute(0, 2, 1)
        if x_mark is not None:
            tokens = torch.cat([tokens, x_mark.permute(0, 2, 1)], 1)
        # -> [Batch, Variate (+ mark channels), d_model]
        return self.dropout(self.value_embedding(tokens))
class DataEmbedding_wo_pos(nn.Module):
    """Value + (optional) temporal embedding with dropout, no positional term.

    A ``PositionalEmbedding`` is still constructed and stored on the module,
    but ``forward`` never uses it.

    Args:
        c_in: Number of input channels passed to ``TokenEmbedding``.
        d_model: Width of every embedding.
        embed_type: ``'timeF'`` selects ``TimeFeatureEmbedding`` for the time
            marks; any other value selects ``TemporalEmbedding``.
        freq: Frequency code forwarded to the temporal embedding.
        dropout: Dropout probability applied to the summed embedding.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        # Kept so the module's buffers/state dict stay the same.
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type == 'timeF':
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x``; the temporal term is added only when ``x_mark`` is given."""
        embedded = self.value_embedding(x)
        if x_mark is not None:
            embedded = embedded + self.temporal_embedding(x_mark)
        return self.dropout(embedded)
class PatchEmbedding(nn.Module):
    """Split the last dimension into patches and embed every patch.

    Args:
        d_model: Output width of the patch projection.
        patch_len: Length of each patch (window size of ``unfold``).
        stride: Step between consecutive patches.
        padding: Number of replicated values appended on the right of the
            last dimension before patching.
        dropout: Dropout probability applied to the embedded patches.
    """

    def __init__(self, d_model, patch_len, stride, padding, dropout):
        super().__init__()
        # Patching configuration
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))
        # Bias-free projection of each patch onto a d_model-dim vector
        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
        # Positional term added per patch index
        self.position_embedding = PositionalEmbedding(d_model)
        # Residual dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(embedded_patches, n_vars)``.

        ``n_vars`` is ``x.shape[1]``. The first two dimensions of the
        unfolded tensor are merged into one before the projection.
        """
        n_vars = x.shape[1]

        # Pad on the right, then cut sliding windows along the last dimension.
        padded = self.padding_patch_layer(x)
        patches = padded.unfold(dimension=-1, size=self.patch_len, step=self.stride)

        dims = patches.shape
        merged = torch.reshape(patches, (dims[0] * dims[1], dims[2], dims[3]))

        # Input encoding: projection plus positional term
        encoded = self.value_embedding(merged) + self.position_embedding(merged)
        return self.dropout(encoded), n_vars
class ResidualBlock(nn.Module):
    """Pre-norm residual wrapper: ``MambaBlock(RMSNorm(x)) + x``.

    Args:
        d_inner: Inner width forwarded to ``MambaBlock``.
        dt_rank: Rank of the delta projection forwarded to ``MambaBlock``.
        d_model: Width of the block's input and output.
        d_ff: Forwarded to ``MambaBlock``.
        d_conv: Forwarded to ``MambaBlock``.
        top_k: Forwarded to ``MambaBlock``.
    """

    def __init__(self, d_inner, dt_rank, d_model=16, d_ff=32,d_conv=4, top_k=5,):
        super().__init__()
        self.mixer = MambaBlock(
            d_inner,
            dt_rank,
            d_model=d_model,
            d_ff=d_ff,
            d_conv=d_conv,
            top_k=top_k,
        )
        self.norm = RMSNorm(d_model)

    def forward(self, x):
        """Normalise ``x``, run the mixer and add the skip connection."""
        normed = self.norm(x)
        return self.mixer(normed) + x
class MambaBlock(nn.Module):
    """Single Mamba mixer: gated depthwise conv followed by a selective scan.

    Args:
        d_inner: Width of the expanded inner representation.
        dt_rank: Rank of the low-dimensional delta projection.
        d_model: Width of the block's input and output.
        d_ff: Size of the state dimension ``n`` (columns of ``A_log``).
        d_conv: Kernel size of the depthwise convolution.
        top_k: Accepted but not used by this class.
    """

    def __init__(self, d_inner, dt_rank, d_model=16, d_ff=32,d_conv=4, top_k=5,):
        super().__init__()
        self.d_inner = d_inner
        self.dt_rank = dt_rank

        # One projection produces both the conv/SSM branch and the gate.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Depthwise conv; the extra right-hand outputs from the padding are
        # cut off in forward().
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=True,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
        )

        # Takes in x and outputs the input-specific delta, B, C.
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_ff * 2, bias=False)

        # Projects delta from dt_rank up to d_inner.
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        state_index = repeat(torch.arange(1, d_ff + 1), "n -> d n", d=self.d_inner)
        self.A_log = nn.Parameter(torch.log(state_index))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    def forward(self, x):
        """
        Figure 3 in Section 3.4 in the paper
        """
        (_, seq_len, _) = x.shape

        projected = self.in_proj(x)  # [B, L, 2 * d_inner]
        (branch, gate) = projected.split(
            split_size=[self.d_inner, self.d_inner], dim=-1)

        branch = rearrange(branch, "b l d -> b d l")
        # Keep only the first L outputs of the padded convolution.
        branch = self.conv1d(branch)[:, :, :seq_len]
        branch = rearrange(branch, "b d l -> b l d")
        branch = F.silu(branch)

        gated = self.ssm(branch) * F.silu(gate)
        return self.out_proj(gated)

    def ssm(self, x):
        """
        Algorithm 2 in Section 3.2 in the paper
        """
        (_, n) = self.A_log.shape

        A = -torch.exp(self.A_log.float())  # [d_in, n]
        D = self.D.float()  # [d_in]

        x_dbl = self.x_proj(x)  # [B, L, d_rank + 2 * d_ff]
        # delta: [B, L, d_rank]; B, C: [B, L, n]
        (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1)
        delta = F.softplus(self.dt_proj(delta))  # [B, L, d_in]

        return self.selective_scan(x, delta, A, B, C, D)

    def selective_scan(self, u, delta, A, B, C, D):
        """Run the recurrence step by step along dimension 1 of ``u``.

        For each step: ``state = deltaA * state + deltaB_u`` and the output
        is the contraction of ``state`` with ``C`` over the state dimension;
        ``u * D`` is added to the stacked outputs as a skip term.
        """
        (batch, _, d_in) = u.shape
        n_state = A.shape[1]

        # A is discretized using zero-order hold (ZOH) discretization.
        deltaA = torch.exp(einsum(delta, A, "b l d, d n -> b l d n"))
        # B is discretized using a simplified Euler discretization instead of
        # ZOH. From a discussion with authors: "A is the more important term
        # and the performance doesn't change much with the simplification on B"
        deltaB_u = einsum(delta, B, u, "b l d, b l n, b l d -> b l d n")

        # Selective scan, sequential instead of parallel.
        state = torch.zeros((batch, d_in, n_state), device=deltaA.device)
        outputs = []
        steps = zip(deltaA.unbind(1), deltaB_u.unbind(1), C.unbind(1))
        for dA_t, dBu_t, C_t in steps:
            state = dA_t * state + dBu_t
            outputs.append(einsum(state, C_t, "b d n, b n -> b d"))

        scanned = torch.stack(outputs, dim=1)  # [B, L, d_in]
        return scanned + u * D
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Computes ``x * rsqrt(mean(x ** 2, dim=-1) + eps) * weight``.

    Args:
        d_model: Length of the learnable scale vector ``weight``.
        eps: Constant added to the mean square before the reciprocal root.
    """

    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Scale ``x`` by its inverse RMS, then by ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
class MambaSimple(nn.Module):
    """
    Mamba, linear-time sequence modeling with selective state spaces O(L)
    Paper link: https://arxiv.org/abs/2312.00752
    Implementation reference: https://github.com/johnma2006/mamba-minimal/

    Only these constructor arguments are read by this class: ``task_name``,
    ``pred_len``, ``enc_in``, ``c_out``, ``e_layers``, ``d_model``, ``d_ff``,
    ``expand``, ``d_conv``, ``top_k``, ``embed``, ``freq`` and ``dropout``.
    Every other argument (and ``**kwargs``) is accepted and ignored.

    Args:
        task_name: ``forward`` supports ``'short_term_forecast'`` and
            ``'long_term_forecast'`` only.
        pred_len: Number of trailing time steps kept before the final
            projection.
        enc_in: Number of input channels given to the embedding.
        c_out: Output width of ``out_layer``; the final projection emits
            ``pred_len * c_out`` values per sample.
        e_layers: Number of stacked ``ResidualBlock`` layers.
        d_model: Model width.
        d_ff: State size forwarded to each ``ResidualBlock``.
        expand: Multiplier giving ``d_inner = d_model * expand``.
        d_conv: Convolution kernel size forwarded to each ``ResidualBlock``.
        top_k: Forwarded to each ``ResidualBlock``.
        embed: Embedding type forwarded to ``DataEmbedding``.
        freq: Frequency code forwarded to ``DataEmbedding``.
        dropout: Dropout probability forwarded to ``DataEmbedding``.
    """

    def __init__(self, task_name='short_term_forecast',seq_len=96, label_len=48, pred_len=96, enc_in=7, dec_in=7, c_out=1, e_layers=2, d_layers=1, n_heads=8,factor=3,
                 d_model=16, d_ff=32, des='Exp', expand=2, d_conv=4, top_k=5, embed='timeF',freq='h', dropout=0.1,num_kernels=6,
                 moving_avg=25,channel_independence=1, decomp_method='moving_avg', use_norm=1,
                 version='fourier', mode_select='random', modes=32, activation='gelu',seasonal_patterns='Monthly',
                 inverse=False, mask_rate=0.25, anomaly_ratio=0.25,output_attention=False,down_sampling_layers=0, down_sampling_window=1, down_sampling_method=None,
                 seg_len=48, num_workers=0, itr=1, train_epochs=100, batch_size=32, patience=3, learning_rate=0.0001, loss='MSE',
                 lradj='type1', use_amp=False, use_gpu=True, gpu=0, use_multi_gpu=False, devices='0,1,2,3', p_hidden_dims=[128, 128],
                 p_hidden_layers=2, use_dtw=False, augmentation_ratio=0, seed=2, jitter=False, scaling=False, permutation=False,
                 randompermutation=False, magwarp=False, timewarp=False, windowslice=False, windowwarp=False, rotation=False,
                 spawner=False, dtwwarp=False, shapedtwwarp=False, wdba=False, discdtw=False, discsdtw=False, extra_tag='',
                 **kwargs):
        super(MambaSimple, self).__init__()
        self.task_name = task_name
        self.pred_len = pred_len
        self.d_inner = d_model * expand
        self.dt_rank = math.ceil(d_model / 16)

        self.embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout)
        self.layers = nn.ModuleList([
            ResidualBlock(self.d_inner, self.dt_rank, d_model=d_model,
                          d_ff=d_ff, d_conv=d_conv, top_k=top_k)
            for _ in range(e_layers)
        ])
        self.norm = RMSNorm(d_model)
        self.out_layer = nn.Linear(d_model, c_out, bias=False)
        # NOTE(review): the input width is pred_len * enc_in, so the tensor
        # returned by forecast() must end up with enc_in channels. With
        # c_out == 1 that appears to rely on broadcasting against the
        # per-channel statistics in forecast() - confirm this is intended.
        self.projection_final = nn.Linear(pred_len*enc_in, pred_len*c_out, bias=True)

    def forecast(self, x_enc, x_mark_enc):
        """Normalise, embed, run the Mamba stack and de-normalise.

        The mean and standard deviation are taken over dimension 1 of
        ``x_enc`` and detached, then re-applied to the model output.
        """
        mean_enc = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x_enc = x_enc / std_enc

        x = self.embedding(x_enc, x_mark_enc)
        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        x_out = self.out_layer(x)
        x_out = x_out * std_enc + mean_enc
        return x_out

    def forward(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None, mask=None):
        """Return the flattened forecast of shape ``(batch, pred_len * c_out)``.

        ``x_dec``, ``x_mark_dec`` and ``mask`` are accepted but not used.

        Raises:
            NotImplementedError: If ``task_name`` is not one of the two
                forecast tasks (previously this silently returned ``None``).
        """
        if self.task_name in ['short_term_forecast', 'long_term_forecast']:
            x_out = self.forecast(x_enc, x_mark_enc)
            # Keep the last pred_len steps, then flatten time and channels.
            x_out = x_out[:, -self.pred_len:, :]
            # reshape() also handles a non-contiguous slice, unlike view().
            x_out = self.projection_final(x_out.reshape(x_out.shape[0], -1))
            return x_out

        # Other tasks are not implemented.
        raise NotImplementedError(
            "MambaSimple only supports 'short_term_forecast' and "
            "'long_term_forecast', got task_name={!r}".format(self.task_name))
from pytorch_forecasting.models import BaseModel
from typing import Dict
class MambaSimpleNetModel(BaseModel):
    """pytorch_forecasting ``BaseModel`` wrapper around ``MambaSimple``.

    The named constructor arguments are forwarded unchanged to
    ``MambaSimple``; ``**kwargs`` goes to ``BaseModel.__init__``.
    """

    def __init__(self,seq_len=24, label_len=0, pred_len=1, enc_in=7, dec_in=7, c_out=1, e_layers=2, d_layers=1, factor=3,
                 d_model=16, d_ff=32, des='Exp', itr=1, top_k=5,embed='timeF',freq='h', dropout=0.1,num_kernels=6, **kwargs):
        # saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this
        self.save_hyperparameters()
        # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this
        super().__init__(**kwargs)

        network_kwargs = dict(
            seq_len=seq_len,
            label_len=label_len,
            pred_len=pred_len,
            enc_in=enc_in,
            dec_in=dec_in,
            c_out=c_out,
            e_layers=e_layers,
            d_layers=d_layers,
            factor=factor,
            d_model=d_model,
            d_ff=d_ff,
            des=des,
            itr=itr,
            top_k=top_k,
            embed=embed,
            freq=freq,
            dropout=dropout,
            num_kernels=num_kernels,
        )
        self.network = MambaSimple(**network_kwargs)

    # Adapted for lithium-battery prediction.
    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the network on ``x["encoder_cont"]`` and rescale the result.

        The last channel of ``encoder_cont`` is dropped before the network is
        called (presumably the target column - confirm against the dataset).
        No time marks are passed to the network.
        """
        encoder_features = x["encoder_cont"][:, :, :-1]
        # Raw network output
        raw_prediction = self.network(encoder_features, x_mark_enc=None)
        # Rescale predictions into target space
        rescaled = self.transform_output(raw_prediction, target_scale=x["target_scale"])
        # Wrap the prediction in the network-output container
        return self.to_network_output(prediction=rescaled)
if __name__ == '__main__':
    # Smoke test: build the model on all-ones inputs and print the output shape.
    batch, seq_len, n_channels = 100, 24, 11
    n_marks = 4
    horizon = 1

    inputs = {
        'x_enc': torch.ones((batch, seq_len, n_channels)),
        'x_mark_enc': torch.ones((batch, seq_len, n_marks)),
        'x_dec': torch.ones((batch, horizon, n_channels)),
        'x_mark_dec': torch.ones((batch, horizon, n_marks)),
    }

    # pred_len is constrained here
    model = MambaSimple(
        seq_len=seq_len,
        enc_in=n_channels,
        dec_in=n_channels,
        label_len=0,
        pred_len=horizon,
        c_out=1,
    )
    out = model(**inputs)
    print(out.shape)