79db6e5c96
新增锂电池剩余使用寿命预测模型RUL-Mamba,包含以下主要组件: 1. 添加Mamba模块作为核心时序建模组件 2. 实现特征注意力网络(FAN)和门控残差网络(GRN) 3. 新增数据预处理和归一化层 4. 添加模型训练和评估脚本 5. 补充README文档说明使用方法 6. 添加可视化辅助工具Helper_Plot.py 7. 实现多种时间序列处理层(Embedding、AutoCorrelation等) 8. 添加配置文件requirements.txt 9. 补充测试数据集TJU battery data
61 lines
2.0 KiB
Python
61 lines
2.0 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
class RevIN(nn.Module):
    """Reversible Instance Normalization (RevIN).

    Calling the module with ``mode='norm'`` records per-instance statistics of
    the input and returns the normalized tensor; calling it with
    ``mode='denorm'`` applies the inverse transform using the statistics stored
    by the most recent ``'norm'`` call.

    Statistics are reduced over every axis except the first and the last
    (``dim2reduce = range(1, x.ndim - 1)``), with ``keepdim=True`` so they
    broadcast back onto the input. The optional affine parameters have shape
    ``(num_features,)`` and therefore broadcast over the last axis. Presumably
    inputs are ``(batch, seq_len, num_features)`` -- TODO confirm with callers.

    NOTE(review): statistics are stored as module attributes, so every
    ``'denorm'`` call must be paired with a preceding ``'norm'`` call on the
    same instance; calling ``'denorm'`` first raises ``AttributeError``.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
        """
        :param num_features: the number of features or channels (size of the
            input's last axis when ``affine`` is True)
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, center each instance on its last time
            step (``x[:, -1, :]``) instead of its mean; scaling by the standard
            deviation is applied in both cases. Requires an input with at
            least 3 dimensions.
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """Normalize (``mode='norm'``) or de-normalize (``mode='denorm'``) ``x``.

        :param x: input tensor; statistics are taken over all axes except the
            first and the last
        :param mode: ``'norm'`` to record statistics and normalize, or
            ``'denorm'`` to invert the last normalization
        :return: tensor of the same shape as ``x``
        :raises NotImplementedError: if ``mode`` is neither ``'norm'`` nor
            ``'denorm'``
        """
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError(
                f"RevIN mode must be 'norm' or 'denorm', got {mode!r}"
            )
        return x

    def _init_params(self):
        # initialize RevIN params: (C,) -- identity transform at start
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        """Store the centering offset and the standard deviation of ``x``."""
        # Reduce over every axis between the first and the last one.
        dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            # Last step along axis 1, re-expanded so it broadcasts against x.
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        # The stdev is needed in both branches: subtract_last only changes the
        # centering, while _normalize/_denormalize always scale by stdev.
        self.stdev = torch.sqrt(
            torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps
        ).detach()

    def _normalize(self, x):
        """Center ``x``, scale it by the stored stdev, then apply the affine map."""
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        """Invert ``_normalize`` using the statistics from the last ``'norm'`` call."""
        if self.affine:
            x = x - self.affine_bias
            # eps**2 keeps the division finite if a learned weight reaches zero.
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x