Files
eason 79db6e5c96 feat: 添加RUL-Mamba模型及相关组件
新增锂电池剩余使用寿命预测模型RUL-Mamba,包含以下主要组件:
1. 添加Mamba模块作为核心时序建模组件
2. 实现特征注意力网络(FAN)和门控残差网络(GRN)
3. 新增数据预处理和归一化层
4. 添加模型训练和评估脚本
5. 补充README文档说明使用方法
6. 添加可视化辅助工具Helper_Plot.py
7. 实现多种时间序列处理层(Embedding、AutoCorrelation等)
8. 添加配置文件requirements.txt
9. 补充测试数据集TJU battery data
2026-01-09 08:53:50 +08:00

348 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
import math
from pytorch_forecasting.models import BaseModel
from typing import Dict
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding.

    A table of shape [1, max_len, d_model] is computed once and registered as
    a buffer, so it is saved with the module and follows ``.to()`` but is not
    returned by ``.parameters()`` and is never trained.

    Args:
        d_model: encoding width. Odd widths are supported: the even columns
            use ceil(d_model / 2) frequencies, the odd columns the first
            floor(d_model / 2) of them.
        max_len: number of positions precomputed.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than frequencies;
        # trim div_term so the assignment shapes match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A buffer (not a Parameter), so no gradient is ever tracked for it.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Only the length of dim 1 of ``x`` is used; its values are ignored.
        Output shape is [1, x.size(1), d_model] (shorter if x.size(1) > max_len).
        """
        return self.pe[:, :x.size(1)]
class TokenEmbedding(nn.Module):
    """Map each time step's ``c_in`` features to ``d_model`` channels.

    Uses a bias-free, width-3 ``nn.Conv1d`` with circular padding over the
    time axis. Input is indexed [B, T, c_in]; output is [B, T', d_model].
    """

    def __init__(self, c_in, d_model):
        super().__init__()
        # Padding choice depends on the installed PyTorch version.
        # NOTE(review): this is a string comparison; recent PyTorch exposes
        # torch.__version__ as a version object that compares semantically.
        if torch.__version__ >= '1.5.0':
            pad = 1
        else:
            pad = 2
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=pad,
            padding_mode='circular',
            bias=False,
        )
        # Kaiming (fan_in, leaky_relu) initialisation for every Conv1d here.
        conv_layers = [mod for mod in self.modules() if isinstance(mod, nn.Conv1d)]
        for conv in conv_layers:
            nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        # Conv1d wants channels first: [B, T, C] -> [B, C, T], then back.
        channels_first = x.permute(0, 2, 1)
        convolved = self.tokenConv(channels_first)
        return convolved.transpose(1, 2)
class FixedEmbedding(nn.Module):
    """Embedding lookup whose table is the sinusoidal encoding, frozen.

    Args:
        c_in: number of distinct integer codes (rows of the table).
        d_model: embedding width. Odd widths are supported.
    """

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()
        w = torch.zeros(c_in, d_model).float()
        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()
        w[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than frequencies;
        # trim div_term so the assignment shapes match.
        w[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.emb = nn.Embedding(c_in, d_model)
        # Replace the random init with the sinusoidal table and freeze it.
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        # x holds integer codes; detach() keeps the output out of autograd.
        return self.emb(x).detach()
class TemporalEmbedding(nn.Module):
    """Sum of per-field embeddings for integer-coded calendar marks.

    ``forward`` reads the mark columns as: 0 = month, 1 = day, 2 = weekday,
    3 = hour, and (only when built with ``freq == 't'``) 4 = minute.

    Args:
        d_model: embedding width.
        embed_type: ``'fixed'`` uses frozen sinusoidal ``FixedEmbedding``
            tables; any other value uses trainable ``nn.Embedding``.
        freq: ``'t'`` adds the minute table; other values omit it.
    """

    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super().__init__()
        # Vocabulary size of each calendar field.
        n_minute, n_hour, n_weekday, n_day, n_month = 4, 24, 7, 32, 13
        if embed_type == 'fixed':
            table = FixedEmbedding
        else:
            table = nn.Embedding
        if freq == 't':
            self.minute_embed = table(n_minute, d_model)
        self.hour_embed = table(n_hour, d_model)
        self.weekday_embed = table(n_weekday, d_model)
        self.day_embed = table(n_day, d_model)
        self.month_embed = table(n_month, d_model)

    def forward(self, x):
        codes = x.long()
        if hasattr(self, 'minute_embed'):
            minute_term = self.minute_embed(codes[:, :, 4])
        else:
            minute_term = 0.
        hour_term = self.hour_embed(codes[:, :, 3])
        weekday_term = self.weekday_embed(codes[:, :, 2])
        day_term = self.day_embed(codes[:, :, 1])
        month_term = self.month_embed(codes[:, :, 0])
        return hour_term + weekday_term + day_term + month_term + minute_term
class TimeFeatureEmbedding(nn.Module):
    """Bias-free linear projection of continuous time features to ``d_model``.

    The number of input features is looked up from ``freq``; an unknown
    ``freq`` raises ``KeyError``. ``embed_type`` is accepted for signature
    compatibility and is not used.
    """

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super().__init__()
        n_features_by_freq = dict(h=4, t=5, s=6, m=1, a=1, w=2, d=3, b=3)
        n_features = n_features_by_freq[freq]
        self.embed = nn.Linear(n_features, d_model, bias=False)

    def forward(self, x):
        projected = self.embed(x)
        return projected
class DataEmbedding(nn.Module):
    """Value + (optional) temporal + positional embedding, followed by dropout.

    Args:
        c_in: number of input value channels.
        d_model: embedding width.
        embed_type: ``'timeF'`` selects ``TimeFeatureEmbedding`` for the time
            marks; anything else selects ``TemporalEmbedding``.
        freq: frequency code forwarded to the temporal embedding.
        dropout: dropout probability applied to the summed embedding.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type == 'timeF':
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        # The temporal term is skipped entirely when no time marks are given.
        summed = self.value_embedding(x)
        if x_mark is not None:
            summed = summed + self.temporal_embedding(x_mark)
        summed = summed + self.position_embedding(x)
        return self.dropout(summed)
class Inception_Block_V1(nn.Module):
    """Average of parallel square ``nn.Conv2d`` branches of growing size.

    Branch ``j`` uses kernel size ``2 * j + 1`` with padding ``j``, for
    ``j = 0 .. num_kernels - 1``. The branch outputs are averaged.

    Args:
        in_channels: input channels of every branch.
        out_channels: output channels of every branch.
        num_kernels: number of branches.
        init_weight: if True, apply Kaiming (fan_out, relu) init to the conv
            weights and zero the biases.
    """

    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        self.kernels = nn.ModuleList(
            nn.Conv2d(in_channels, out_channels, kernel_size=2 * j + 1, padding=j)
            for j in range(num_kernels)
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        conv_layers = (mod for mod in self.modules() if isinstance(mod, nn.Conv2d))
        for conv in conv_layers:
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        branch_outputs = [conv(x) for conv in self.kernels]
        return torch.stack(branch_outputs, dim=-1).mean(-1)
def FFT_for_Period(x, k=2):
    """Find the ``k`` dominant periods of ``x`` along its time axis via rfft.

    Args:
        x: tensor indexed [B, T, C]; the FFT runs over dim 1.
        k: number of frequencies to select. ``torch.topk`` raises if ``k``
            exceeds the number of rfft bins (T // 2 + 1).

    Returns:
        period: numpy integer array of length ``k``, ``T // f`` for each
            selected frequency index ``f`` (ranked by amplitude averaged over
            batch and channels). Always >= 1.
        weight: tensor [B, k], the per-sample amplitude (averaged over
            channels) at each selected frequency.
    """
    # [B, T, C] -> [B, T // 2 + 1, C] complex spectrum
    xf = torch.fft.rfft(x, dim=1)
    amplitude = xf.abs()
    # Rank frequencies by amplitude averaged over batch and channels.
    frequency_list = amplitude.mean(0).mean(-1)
    # Suppress the DC (mean) component so it is not preferred as a period.
    frequency_list[0] = 0
    _, top_list = torch.topk(frequency_list, k)
    top_list = top_list.detach().cpu().numpy()
    # The DC bin can still be selected (k equals the number of bins, or the
    # other amplitudes tie at zero). Treat it as frequency 1 (period = T)
    # instead of dividing by zero, which would yield period 0.
    period = x.shape[1] // top_list.clip(1, None)
    return period, amplitude.mean(-1)[:, top_list]
class TimesBlock(nn.Module):
    """TimesNet block: fold the sequence by its dominant periods, apply 2-D
    inception convolutions, and merge the branches weighted by amplitude.

    Args:
        seq_len, pred_len: stored for reference; TimesNet feeds this block
            sequences of length ``seq_len + pred_len``. ``forward`` uses the
            actual input length, which equals that sum in TimesNet.
        top_k: number of dominant periods (branches) used per forward pass.
        d_model: channel dimension of the input and output.
        d_ff: hidden channel dimension inside the inception stack.
        num_kernels: number of parallel kernels per inception block.
    """

    def __init__(self, seq_len, pred_len, top_k, d_model, d_ff, num_kernels):
        super(TimesBlock, self).__init__()
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.k = top_k
        # parameter-efficient design: d_model -> d_ff -> d_model
        self.conv = nn.Sequential(
            Inception_Block_V1(d_model, d_ff, num_kernels=num_kernels),
            nn.GELU(),
            Inception_Block_V1(d_ff, d_model, num_kernels=num_kernels),
        )

    def forward(self, x):
        """Map x [B, T, N] to [B, T, N] (includes a residual connection)."""
        B, T, N = x.size()
        period_list, period_weight = FFT_for_Period(x, self.k)
        res = []
        for i in range(self.k):
            period = int(period_list[i])
            # Zero-pad the time axis up to a multiple of the period so it can
            # be folded; new_zeros matches x's dtype and device.
            if T % period != 0:
                length = (T // period + 1) * period
                padding = x.new_zeros(B, length - T, N)
                out = torch.cat([x, padding], dim=1)
            else:
                length = T
                out = x
            # 1-D -> 2-D variation: [B, length, N] -> [B, N, length // period, period]
            out = out.reshape(B, length // period, period, N).permute(0, 3, 1, 2).contiguous()
            out = self.conv(out)
            # Back to [B, length, N], then drop the padding.
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            res.append(out[:, :T, :])
        res = torch.stack(res, dim=-1)  # [B, T, N, k]
        # Adaptive aggregation: softmax over the k amplitudes ([B, k]),
        # broadcast across the time and channel axes.
        period_weight = F.softmax(period_weight, dim=1)
        res = torch.sum(res * period_weight[:, None, None, :], -1)
        # residual connection
        return res + x
class TimesNet(nn.Module):
    """TimesNet forecaster (encoder only).

    Paper link: https://openreview.net/pdf?id=ju_Uqw384Oq

    Only seq_len, pred_len, enc_in, c_out, e_layers, d_model, d_ff, top_k,
    embed, freq, dropout and num_kernels affect the network. label_len is
    stored. dec_in, d_layers, factor, des and itr are accepted for signature
    compatibility and unused.
    """

    def __init__(self, seq_len=24, label_len=0, pred_len=1, enc_in=7, dec_in=7, c_out=1, e_layers=2, d_layers=1,
                 factor=3, d_model=16, d_ff=32, des='Exp', itr=1, top_k=5, embed='timeF', freq='h', dropout=0.1,
                 num_kernels=6):
        super(TimesNet, self).__init__()
        self.task_name = 'TimesNet'
        self.seq_len = seq_len
        self.label_len = label_len
        self.pred_len = pred_len
        self.model = nn.ModuleList([TimesBlock(seq_len, pred_len, top_k, d_model, d_ff, num_kernels)
                                    for _ in range(e_layers)])
        self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout)
        self.layer = e_layers
        self.layer_norm = nn.LayerNorm(d_model)
        # Extends the time axis from seq_len to seq_len + pred_len.
        self.predict_linear = nn.Linear(self.seq_len, self.pred_len + self.seq_len)
        self.projection = nn.Linear(d_model, c_out, bias=True)

    def forward(self, x_enc, x_mark_enc=None):
        """Forecast ``pred_len`` steps from an input window.

        Args:
            x_enc: input window [B, seq_len, enc_in]; dim 1 must equal seq_len
                because it becomes the in_features axis of predict_linear.
                The mean/std of the LAST channel de-normalise the output —
                assumes the target is stored last; TODO confirm with callers.
            x_mark_enc: optional time marks for the temporal embedding;
                None skips the temporal term.

        Returns:
            Tensor [B, pred_len, c_out].
        """
        # Non-stationary Transformer normalisation: subtract the mean and
        # divide by the standard deviation over the time axis.
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place: torch.var saved x_enc for backward, so an in-place
        # division would break autograd when the input requires grad.
        x_enc = x_enc / stdev
        target_mean = means[:, :, -1]    # [B, 1]
        target_stdev = stdev[:, :, -1]   # [B, 1]

        # [B, seq_len, enc_in] -> [B, seq_len, d_model]
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        # Align the temporal dimension: [B, seq_len, d_model] -> [B, seq_len + pred_len, d_model]
        enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(0, 2, 1)

        for block in self.model:
            enc_out = self.layer_norm(block(enc_out))

        # Project back: [B, seq_len + pred_len, d_model] -> [B, seq_len + pred_len, c_out]
        dec_out = self.projection(enc_out)
        # De-normalise with the target channel's statistics ([B, 1, 1] broadcast).
        dec_out = dec_out * target_stdev.unsqueeze(1) + target_mean.unsqueeze(1)
        return dec_out[:, -self.pred_len:, :]
class TimesNetModel(BaseModel):
    """pytorch-forecasting wrapper around ``TimesNet`` (battery-life variant).

    Every named constructor argument is forwarded to ``TimesNet``; extra
    keyword arguments go to ``BaseModel.__init__``.
    """

    def __init__(self, seq_len=24, label_len=0, pred_len=1, enc_in=7, dec_in=7, c_out=1, e_layers=2, d_layers=1,
                 factor=3, d_model=16, d_ff=32, des='Exp', itr=1, top_k=5, embed='timeF', freq='h', dropout=0.1,
                 num_kernels=6, **kwargs):
        # Must stay the first statement: saves the signature's arguments to
        # `.hparams` (mandatory call - do not skip this).
        self.save_hyperparameters()
        # Pass the remaining arguments to BaseModel.__init__ (mandatory call).
        super().__init__(**kwargs)
        network_config = dict(
            seq_len=seq_len, label_len=label_len, pred_len=pred_len, enc_in=enc_in, dec_in=dec_in,
            c_out=c_out, e_layers=e_layers, d_layers=d_layers, factor=factor, d_model=d_model,
            d_ff=d_ff, des=des, itr=itr, top_k=top_k, embed=embed, freq=freq, dropout=dropout,
            num_kernels=num_kernels,
        )
        self.network = TimesNet(**network_config)

    # Modified for lithium-battery prediction.
    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Every continuous encoder feature except the last column is fed to
        # the network (the original note shows e.g. torch.Size([100, 10, 9])
        # after slicing — presumably [batch, time, features]; confirm).
        features = x["encoder_cont"][:, :, :-1]
        raw_prediction = self.network(features)
        # Rescale predictions into target space.
        scaled_prediction = self.transform_output(raw_prediction, target_scale=x["target_scale"])
        # Wrap in the network-output structure expected by BaseModel.
        return self.to_network_output(prediction=scaled_prediction)
if __name__ == '__main__':
    # Smoke test: 100 sequences of 96 time steps with 5 channels.
    N, L, C = 100, 96, 5
    dummy_input = torch.ones((N, L, C))  # avoid shadowing the builtin `input`
    model = TimesNet(seq_len=L, enc_in=C, pred_len=1)
    model.eval()  # disable dropout for a deterministic check
    with torch.no_grad():
        out = model(dummy_input)
    print(out.shape)  # [N, pred_len, c_out] -> torch.Size([100, 1, 1]) with defaults