feat: 添加RUL-Mamba模型及相关组件
新增锂电池剩余使用寿命预测模型RUL-Mamba,包含以下主要组件: 1. 添加Mamba模块作为核心时序建模组件 2. 实现特征注意力网络(FAN)和门控残差网络(GRN) 3. 新增数据预处理和归一化层 4. 添加模型训练和评估脚本 5. 补充README文档说明使用方法 6. 添加可视化辅助工具Helper_Plot.py 7. 实现多种时间序列处理层(Embedding、AutoCorrelation等) 8. 添加配置文件requirements.txt 9. 补充测试数据集TJU battery data
This commit is contained in:
@@ -0,0 +1,417 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat, einsum
|
||||
|
||||
# from .layers.Embed import DataEmbedding
|
||||
# import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
import math
|
||||
|
||||
|
||||
class PositionalEmbedding(nn.Module):
|
||||
def __init__(self, d_model, max_len=5000):
|
||||
super(PositionalEmbedding, self).__init__()
|
||||
# Compute the positional encodings once in log space.
|
||||
pe = torch.zeros(max_len, d_model).float()
|
||||
pe.require_grad = False
|
||||
|
||||
position = torch.arange(0, max_len).float().unsqueeze(1)
|
||||
div_term = (torch.arange(0, d_model, 2).float()
|
||||
* -(math.log(10000.0) / d_model)).exp()
|
||||
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
|
||||
pe = pe.unsqueeze(0)
|
||||
self.register_buffer('pe', pe)
|
||||
|
||||
def forward(self, x):
|
||||
return self.pe[:, :x.size(1)]
|
||||
|
||||
|
||||
class TokenEmbedding(nn.Module):
|
||||
def __init__(self, c_in, d_model):
|
||||
super(TokenEmbedding, self).__init__()
|
||||
padding = 1 if torch.__version__ >= '1.5.0' else 2
|
||||
self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
|
||||
kernel_size=3, padding=padding, padding_mode='circular', bias=False)
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_in', nonlinearity='leaky_relu')
|
||||
|
||||
def forward(self, x):
|
||||
x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class FixedEmbedding(nn.Module):
|
||||
def __init__(self, c_in, d_model):
|
||||
super(FixedEmbedding, self).__init__()
|
||||
|
||||
w = torch.zeros(c_in, d_model).float()
|
||||
w.require_grad = False
|
||||
|
||||
position = torch.arange(0, c_in).float().unsqueeze(1)
|
||||
div_term = (torch.arange(0, d_model, 2).float()
|
||||
* -(math.log(10000.0) / d_model)).exp()
|
||||
|
||||
w[:, 0::2] = torch.sin(position * div_term)
|
||||
w[:, 1::2] = torch.cos(position * div_term)
|
||||
|
||||
self.emb = nn.Embedding(c_in, d_model)
|
||||
self.emb.weight = nn.Parameter(w, requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.emb(x).detach()
|
||||
|
||||
|
||||
class TemporalEmbedding(nn.Module):
|
||||
def __init__(self, d_model, embed_type='fixed', freq='h'):
|
||||
super(TemporalEmbedding, self).__init__()
|
||||
|
||||
minute_size = 4
|
||||
hour_size = 24
|
||||
weekday_size = 7
|
||||
day_size = 32
|
||||
month_size = 13
|
||||
|
||||
Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
|
||||
if freq == 't':
|
||||
self.minute_embed = Embed(minute_size, d_model)
|
||||
self.hour_embed = Embed(hour_size, d_model)
|
||||
self.weekday_embed = Embed(weekday_size, d_model)
|
||||
self.day_embed = Embed(day_size, d_model)
|
||||
self.month_embed = Embed(month_size, d_model)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.long()
|
||||
minute_x = self.minute_embed(x[:, :, 4]) if hasattr(
|
||||
self, 'minute_embed') else 0.
|
||||
hour_x = self.hour_embed(x[:, :, 3])
|
||||
weekday_x = self.weekday_embed(x[:, :, 2])
|
||||
day_x = self.day_embed(x[:, :, 1])
|
||||
month_x = self.month_embed(x[:, :, 0])
|
||||
|
||||
return hour_x + weekday_x + day_x + month_x + minute_x
|
||||
|
||||
|
||||
class TimeFeatureEmbedding(nn.Module):
|
||||
def __init__(self, d_model, embed_type='timeF', freq='h'):
|
||||
super(TimeFeatureEmbedding, self).__init__()
|
||||
|
||||
freq_map = {'h': 4, 't': 5, 's': 6,
|
||||
'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
|
||||
d_inp = freq_map[freq]
|
||||
self.embed = nn.Linear(d_inp, d_model, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.embed(x)
|
||||
|
||||
|
||||
class DataEmbedding(nn.Module):
|
||||
def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
|
||||
super(DataEmbedding, self).__init__()
|
||||
|
||||
self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
|
||||
self.position_embedding = PositionalEmbedding(d_model=d_model)
|
||||
self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
|
||||
freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
|
||||
d_model=d_model, embed_type=embed_type, freq=freq)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def forward(self, x, x_mark):
|
||||
if x_mark is None:
|
||||
x = self.value_embedding(x) + self.position_embedding(x)
|
||||
else:
|
||||
x = self.value_embedding(
|
||||
x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class DataEmbedding_inverted(nn.Module):
|
||||
def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
|
||||
super(DataEmbedding_inverted, self).__init__()
|
||||
self.value_embedding = nn.Linear(c_in, d_model)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def forward(self, x, x_mark):
|
||||
x = x.permute(0, 2, 1)
|
||||
# x: [Batch Variate Time]
|
||||
if x_mark is None:
|
||||
x = self.value_embedding(x)
|
||||
else:
|
||||
x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
|
||||
# x: [Batch Variate d_model]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class DataEmbedding_wo_pos(nn.Module):
|
||||
def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
|
||||
super(DataEmbedding_wo_pos, self).__init__()
|
||||
|
||||
self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
|
||||
self.position_embedding = PositionalEmbedding(d_model=d_model)
|
||||
self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
|
||||
freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
|
||||
d_model=d_model, embed_type=embed_type, freq=freq)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def forward(self, x, x_mark):
|
||||
if x_mark is None:
|
||||
x = self.value_embedding(x)
|
||||
else:
|
||||
x = self.value_embedding(x) + self.temporal_embedding(x_mark)
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class PatchEmbedding(nn.Module):
|
||||
def __init__(self, d_model, patch_len, stride, padding, dropout):
|
||||
super(PatchEmbedding, self).__init__()
|
||||
# Patching
|
||||
self.patch_len = patch_len
|
||||
self.stride = stride
|
||||
self.padding_patch_layer = nn.ReplicationPad1d((0, padding))
|
||||
|
||||
# Backbone, Input encoding: projection of feature vectors onto a d-dim vector space
|
||||
self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
|
||||
|
||||
# Positional embedding
|
||||
self.position_embedding = PositionalEmbedding(d_model)
|
||||
|
||||
# Residual dropout
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
# do patching
|
||||
n_vars = x.shape[1]
|
||||
x = self.padding_patch_layer(x)
|
||||
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
|
||||
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
|
||||
# Input encoding
|
||||
x = self.value_embedding(x) + self.position_embedding(x)
|
||||
return self.dropout(x), n_vars
|
||||
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, d_inner, dt_rank, d_model=16, d_ff=32,d_conv=4, top_k=5,):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
self.mixer = MambaBlock(d_inner, dt_rank, d_model=d_model, d_ff=d_ff,d_conv=d_conv, top_k=top_k)
|
||||
self.norm = RMSNorm(d_model)
|
||||
|
||||
def forward(self, x):
|
||||
output = self.mixer(self.norm(x)) + x
|
||||
return output
|
||||
|
||||
class MambaBlock(nn.Module):
|
||||
def __init__(self, d_inner, dt_rank, d_model=16, d_ff=32,d_conv=4, top_k=5,):
|
||||
super(MambaBlock, self).__init__()
|
||||
self.d_inner = d_inner
|
||||
self.dt_rank = dt_rank
|
||||
|
||||
self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
|
||||
|
||||
self.conv1d = nn.Conv1d(
|
||||
in_channels = self.d_inner,
|
||||
out_channels = self.d_inner,
|
||||
bias = True,
|
||||
kernel_size = d_conv,
|
||||
padding = d_conv - 1,
|
||||
groups = self.d_inner
|
||||
)
|
||||
|
||||
# takes in x and outputs the input-specific delta, B, C
|
||||
self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_ff * 2, bias=False)
|
||||
|
||||
# projects delta
|
||||
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
|
||||
|
||||
A = repeat(torch.arange(1, d_ff + 1), "n -> d n", d=self.d_inner)
|
||||
self.A_log = nn.Parameter(torch.log(A))
|
||||
self.D = nn.Parameter(torch.ones(self.d_inner))
|
||||
|
||||
self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Figure 3 in Section 3.4 in the paper
|
||||
"""
|
||||
(b, l, d) = x.shape
|
||||
|
||||
x_and_res = self.in_proj(x) # [B, L, 2 * d_inner]
|
||||
(x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)
|
||||
|
||||
x = rearrange(x, "b l d -> b d l")
|
||||
x = self.conv1d(x)[:, :, :l]
|
||||
x = rearrange(x, "b d l -> b l d")
|
||||
|
||||
x = F.silu(x)
|
||||
|
||||
y = self.ssm(x)
|
||||
y = y * F.silu(res)
|
||||
|
||||
output = self.out_proj(y)
|
||||
return output
|
||||
|
||||
|
||||
def ssm(self, x):
|
||||
"""
|
||||
Algorithm 2 in Section 3.2 in the paper
|
||||
"""
|
||||
|
||||
(d_in, n) = self.A_log.shape
|
||||
|
||||
A = -torch.exp(self.A_log.float()) # [d_in, n]
|
||||
D = self.D.float() # [d_in]
|
||||
|
||||
x_dbl = self.x_proj(x) # [B, L, d_rank + 2 * d_ff]
|
||||
(delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1) # delta: [B, L, d_rank]; B, C: [B, L, n]
|
||||
delta = F.softplus(self.dt_proj(delta)) # [B, L, d_in]
|
||||
y = self.selective_scan(x, delta, A, B, C, D)
|
||||
|
||||
return y
|
||||
|
||||
def selective_scan(self, u, delta, A, B, C, D):
|
||||
(b, l, d_in) = u.shape
|
||||
n = A.shape[1]
|
||||
|
||||
deltaA = torch.exp(einsum(delta, A, "b l d, d n -> b l d n")) # A is discretized using zero-order hold (ZOH) discretization
|
||||
deltaB_u = einsum(delta, B, u, "b l d, b l n, b l d -> b l d n") # B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors: "A is the more important term and the performance doesn't change much with the simplification on B"
|
||||
|
||||
# selective scan, sequential instead of parallel
|
||||
x = torch.zeros((b, d_in, n), device=deltaA.device)
|
||||
ys = []
|
||||
for i in range(l):
|
||||
x = deltaA[:, i] * x + deltaB_u[:, i]
|
||||
y = einsum(x, C[:, i, :], "b d n, b n -> b d")
|
||||
ys.append(y)
|
||||
|
||||
y = torch.stack(ys, dim=1) # [B, L, d_in]
|
||||
y = y + u * D
|
||||
|
||||
return y
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, d_model, eps=1e-5):
|
||||
super(RMSNorm, self).__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(d_model))
|
||||
|
||||
def forward(self, x):
|
||||
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
|
||||
return output
|
||||
|
||||
|
||||
class MambaSimple(nn.Module):
|
||||
"""
|
||||
Mamba, linear-time sequence modeling with selective state spaces O(L)
|
||||
Paper link: https://arxiv.org/abs/2312.00752
|
||||
Implementation refernce: https://github.com/johnma2006/mamba-minimal/
|
||||
"""
|
||||
|
||||
def __init__(self, task_name='short_term_forecast',seq_len=96, label_len=48, pred_len=96, enc_in=7, dec_in=7, c_out=1, e_layers=2, d_layers=1, n_heads=8,factor=3,
|
||||
d_model=16, d_ff=32, des='Exp', expand=2, d_conv=4, top_k=5, embed='timeF',freq='h', dropout=0.1,num_kernels=6,
|
||||
moving_avg=25,channel_independence=1, decomp_method='moving_avg', use_norm=1,
|
||||
version='fourier', mode_select='random', modes=32, activation='gelu',seasonal_patterns='Monthly',
|
||||
inverse=False, mask_rate=0.25, anomaly_ratio=0.25,output_attention=False,down_sampling_layers=0, down_sampling_window=1, down_sampling_method=None,
|
||||
seg_len=48, num_workers=0, itr=1, train_epochs=100, batch_size=32, patience=3, learning_rate=0.0001, loss='MSE',
|
||||
lradj='type1', use_amp=False, use_gpu=True, gpu=0, use_multi_gpu=False, devices='0,1,2,3', p_hidden_dims=[128, 128],
|
||||
p_hidden_layers=2, use_dtw=False, augmentation_ratio=0, seed=2, jitter=False, scaling=False, permutation=False,
|
||||
randompermutation=False, magwarp=False, timewarp=False, windowslice=False, windowwarp=False, rotation=False,
|
||||
spawner=False, dtwwarp=False, shapedtwwarp=False, wdba=False, discdtw=False, discsdtw=False, extra_tag='',
|
||||
**kwargs):
|
||||
super(MambaSimple, self).__init__()
|
||||
self.task_name = task_name
|
||||
self.pred_len = pred_len
|
||||
|
||||
self.d_inner = d_model * expand
|
||||
self.dt_rank = math.ceil(d_model / 16)
|
||||
|
||||
self.embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout)
|
||||
|
||||
self.layers = nn.ModuleList([ResidualBlock(self.d_inner, self.dt_rank,d_model=d_model, d_ff=d_ff,d_conv=d_conv, top_k=top_k) for _ in range(e_layers)])
|
||||
self.norm = RMSNorm(d_model)
|
||||
|
||||
self.out_layer = nn.Linear(d_model, c_out, bias=False)
|
||||
self.projection_final = nn.Linear(pred_len*enc_in, pred_len*c_out, bias=True)
|
||||
# def short_term_forecast(self, x_enc, x_mark_enc):
|
||||
def forecast(self, x_enc, x_mark_enc):
|
||||
mean_enc = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - mean_enc
|
||||
std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
|
||||
x_enc = x_enc / std_enc
|
||||
|
||||
x = self.embedding(x_enc, x_mark_enc)
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
|
||||
x = self.norm(x)
|
||||
x_out = self.out_layer(x)
|
||||
|
||||
x_out = x_out * std_enc + mean_enc
|
||||
return x_out
|
||||
|
||||
# def long_term_forecast(self, x_enc, x_mark_enc):
|
||||
# x = self.embedding(x_enc, x_mark_enc)
|
||||
# for layer in self.layers:
|
||||
# x = layer(x)
|
||||
|
||||
# x = self.norm(x)
|
||||
# x_out = self.out_layer(x)
|
||||
# return x_out
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None, mask=None):
|
||||
if self.task_name in ['short_term_forecast', 'long_term_forecast']:
|
||||
x_out = self.forecast(x_enc, x_mark_enc)
|
||||
x_out=x_out[:, -self.pred_len:, :]
|
||||
x_out = self.projection_final(x_out.view(x_out.shape[0], -1))
|
||||
return x_out
|
||||
|
||||
# other tasks not implemented
|
||||
|
||||
from pytorch_forecasting.models import BaseModel
|
||||
from typing import Dict
|
||||
|
||||
class MambaSimpleNetModel(BaseModel):
|
||||
def __init__(self,seq_len=24, label_len=0, pred_len=1, enc_in=7, dec_in=7, c_out=1, e_layers=2, d_layers=1, factor=3,
|
||||
d_model=16, d_ff=32, des='Exp', itr=1, top_k=5,embed='timeF',freq='h', dropout=0.1,num_kernels=6, **kwargs):
|
||||
# saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this
|
||||
self.save_hyperparameters()
|
||||
# pass additional arguments to BaseModel.__init__, mandatory call - do not skip this
|
||||
super().__init__(**kwargs)
|
||||
self.network = MambaSimple(
|
||||
seq_len=seq_len, label_len=label_len, pred_len=pred_len, enc_in=enc_in, dec_in=dec_in, c_out=c_out, e_layers=e_layers, d_layers=d_layers, factor=factor,
|
||||
d_model=d_model, d_ff=d_ff, des=des, itr=itr, top_k=top_k, embed=embed, freq=freq, dropout=dropout, num_kernels=num_kernels
|
||||
)
|
||||
|
||||
|
||||
# 修改,锂电池预测
|
||||
def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
|
||||
x_enc = x["encoder_cont"][:,:,:-1]
|
||||
# 输出
|
||||
prediction = self.network(x_enc, x_mark_enc=None)
|
||||
# 输出rescale, rescale predictions into target space
|
||||
prediction = self.transform_output(prediction, target_scale=x["target_scale"])
|
||||
|
||||
# 返回一个字典,包含输出结果(prediction)
|
||||
return self.to_network_output(prediction=prediction)
|
||||
|
||||
|
||||
if __name__=='__main__':
|
||||
N,L,C=100,24,11
|
||||
label_len = 0
|
||||
c_out = 1
|
||||
pred_len=1
|
||||
x_enc=torch.ones((N,L,C))
|
||||
x_mark_enc=torch.ones((N, L, 4))
|
||||
x_dec = torch.ones((N, pred_len, C))
|
||||
x_mark_dec=torch.ones((N, pred_len, 4))
|
||||
model=MambaSimple(seq_len=L, enc_in=C, dec_in=C, label_len = label_len, pred_len=pred_len, c_out=1) # pred_len 被限制了
|
||||
out=model(x_enc=x_enc, x_mark_enc=x_mark_enc, x_dec=x_dec, x_mark_dec=x_mark_dec)
|
||||
print(out.shape)
|
||||
Reference in New Issue
Block a user