feat: 添加RUL-Mamba模型及相关组件
新增锂电池剩余使用寿命预测模型RUL-Mamba,包含以下主要组件: 1. 添加Mamba模块作为核心时序建模组件 2. 实现特征注意力网络(FAN)和门控残差网络(GRN) 3. 新增数据预处理和归一化层 4. 添加模型训练和评估脚本 5. 补充README文档说明使用方法 6. 添加可视化辅助工具Helper_Plot.py 7. 实现多种时间序列处理层(Embedding、AutoCorrelation等) 8. 添加配置文件requirements.txt 9. 补充测试数据集TJU battery data
This commit is contained in:
@@ -0,0 +1,198 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
class PositionalEmbedding(nn.Module):
|
||||
def __init__(self, d_model, n_position=1024):
|
||||
super(PositionalEmbedding, self).__init__()
|
||||
|
||||
# Not a parameter
|
||||
self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_model))
|
||||
|
||||
def _get_sinusoid_encoding_table(self, n_position, d_model):
|
||||
''' Sinusoid position encoding table '''
|
||||
def get_position_angle_vec(position):
|
||||
return [position / np.power(10000, 2 * (hid_j // 2) / d_model) for hid_j in range(d_model)]
|
||||
|
||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])
|
||||
|
||||
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
return self.pos_table[:, :x.size(1)].clone().detach()
|
||||
|
||||
|
||||
class TokenEmbedding(nn.Module):
|
||||
def __init__(self, c_in, d_model):
|
||||
super(TokenEmbedding, self).__init__()
|
||||
padding = 1 if torch.__version__ >= '1.5.0' else 2
|
||||
self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
|
||||
kernel_size=3, padding=padding, padding_mode='circular', bias=False)
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
|
||||
|
||||
def forward(self, x):
|
||||
x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class FixedEmbedding(nn.Module):
|
||||
def __init__(self, c_in, d_model):
|
||||
super(FixedEmbedding, self).__init__()
|
||||
|
||||
w = torch.zeros(c_in, d_model).float()
|
||||
w.require_grad = False
|
||||
|
||||
position = torch.arange(0, c_in).float().unsqueeze(1)
|
||||
div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
|
||||
|
||||
w[:, 0::2] = torch.sin(position * div_term)
|
||||
w[:, 1::2] = torch.cos(position * div_term)
|
||||
|
||||
self.emb = nn.Embedding(c_in, d_model)
|
||||
self.emb.weight = nn.Parameter(w, requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.emb(x).detach()
|
||||
|
||||
|
||||
class TemporalEmbedding(nn.Module):
|
||||
def __init__(self, d_model, embed_type='fixed', freq='h'):
|
||||
super(TemporalEmbedding, self).__init__()
|
||||
|
||||
minute_size = 4
|
||||
hour_size = 24
|
||||
weekday_size = 7
|
||||
day_size = 32
|
||||
month_size = 13
|
||||
|
||||
Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
|
||||
if freq == 't':
|
||||
self.minute_embed = Embed(minute_size, d_model)
|
||||
self.hour_embed = Embed(hour_size, d_model)
|
||||
self.weekday_embed = Embed(weekday_size, d_model)
|
||||
self.day_embed = Embed(day_size, d_model)
|
||||
self.month_embed = Embed(month_size, d_model)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.long()
|
||||
|
||||
minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, 'minute_embed') else 0.
|
||||
hour_x = self.hour_embed(x[:, :, 3])
|
||||
weekday_x = self.weekday_embed(x[:, :, 2])
|
||||
day_x = self.day_embed(x[:, :, 1])
|
||||
month_x = self.month_embed(x[:, :, 0])
|
||||
|
||||
return hour_x + weekday_x + day_x + month_x + minute_x
|
||||
|
||||
|
||||
class TimeFeatureEmbedding(nn.Module):
|
||||
def __init__(self, d_model, embed_type='timeF', freq='h'):
|
||||
super(TimeFeatureEmbedding, self).__init__()
|
||||
|
||||
freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
|
||||
d_inp = freq_map[freq]
|
||||
self.embed = nn.Linear(d_inp, d_model, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.embed(x)
|
||||
|
||||
|
||||
class DataEmbedding(nn.Module):
|
||||
def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
|
||||
super(DataEmbedding, self).__init__()
|
||||
|
||||
self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
|
||||
self.position_embedding = PositionalEmbedding(d_model=d_model)
|
||||
self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
|
||||
freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
|
||||
d_model=d_model, embed_type=embed_type, freq=freq)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def forward(self, x, x_mark):
|
||||
x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class DataEmbedding_wo_temp(nn.Module):
|
||||
def __init__(self, c_in, d_model, dropout=0.1):
|
||||
super(DataEmbedding_wo_temp, self).__init__()
|
||||
|
||||
self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
|
||||
self.position_embedding = PositionalEmbedding(d_model=d_model)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def forward(self, x, x_mark=None):
|
||||
x = self.value_embedding(x) + self.position_embedding(x)
|
||||
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
|
||||
|
||||
def PositionalEncoding(q_len, d_model, normalize=True):
|
||||
pe = torch.zeros(q_len, d_model)
|
||||
position = torch.arange(0, q_len).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
if normalize:
|
||||
pe = pe - pe.mean()
|
||||
pe = pe / (pe.std() * 10)
|
||||
return pe
|
||||
|
||||
SinCosPosEncoding = PositionalEncoding
|
||||
|
||||
def Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=1e-3, verbose=False):
|
||||
x = .5 if exponential else 1
|
||||
i = 0
|
||||
for i in range(100):
|
||||
cpe = 2 * (torch.linspace(0, 1, q_len).reshape(-1, 1) ** x) * (torch.linspace(0, 1, d_model).reshape(1, -1) ** x) - 1
|
||||
pv(f'{i:4.0f} {x:5.3f} {cpe.mean():+6.3f}', verbose)
|
||||
if abs(cpe.mean()) <= eps: break
|
||||
elif cpe.mean() > eps: x += .001
|
||||
else: x -= .001
|
||||
i += 1
|
||||
if normalize:
|
||||
cpe = cpe - cpe.mean()
|
||||
cpe = cpe / (cpe.std() * 10)
|
||||
return cpe
|
||||
|
||||
def Coord1dPosEncoding(q_len, exponential=False, normalize=True):
|
||||
cpe = (2 * (torch.linspace(0, 1, q_len).reshape(-1, 1)**(.5 if exponential else 1)) - 1)
|
||||
if normalize:
|
||||
cpe = cpe - cpe.mean()
|
||||
cpe = cpe / (cpe.std() * 10)
|
||||
return cpe
|
||||
|
||||
def positional_encoding(pe, learn_pe, q_len, d_model):
|
||||
# Positional encoding
|
||||
if pe == None:
|
||||
W_pos = torch.empty((q_len, d_model)) # pe = None and learn_pe = False can be used to measure impact of pe
|
||||
nn.init.uniform_(W_pos, -0.02, 0.02)
|
||||
learn_pe = False
|
||||
elif pe == 'zero':
|
||||
W_pos = torch.empty((q_len, 1))
|
||||
nn.init.uniform_(W_pos, -0.02, 0.02)
|
||||
elif pe == 'zeros':
|
||||
W_pos = torch.empty((q_len, d_model))
|
||||
nn.init.uniform_(W_pos, -0.02, 0.02)
|
||||
elif pe == 'normal' or pe == 'gauss':
|
||||
W_pos = torch.zeros((q_len, 1))
|
||||
torch.nn.init.normal_(W_pos, mean=0.0, std=0.1)
|
||||
elif pe == 'uniform':
|
||||
W_pos = torch.zeros((q_len, 1))
|
||||
nn.init.uniform_(W_pos, a=0.0, b=0.1)
|
||||
elif pe == 'lin1d': W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True)
|
||||
elif pe == 'exp1d': W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True)
|
||||
elif pe == 'lin2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True)
|
||||
elif pe == 'exp2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=True)
|
||||
elif pe == 'sincos': W_pos = PositionalEncoding(q_len, d_model, normalize=True)
|
||||
else: raise ValueError(f"{pe} is not a valid pe (positional encoder. Available types: 'gauss'=='normal', \
|
||||
'zeros', 'zero', uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None.)")
|
||||
return nn.Parameter(W_pos, requires_grad=learn_pe)
|
||||
Reference in New Issue
Block a user