79db6e5c96
新增锂电池剩余使用寿命预测模型RUL-Mamba,包含以下主要组件: 1. 添加Mamba模块作为核心时序建模组件 2. 实现特征注意力网络(FAN)和门控残差网络(GRN) 3. 新增数据预处理和归一化层 4. 添加模型训练和评估脚本 5. 补充README文档说明使用方法 6. 添加可视化辅助工具Helper_Plot.py 7. 实现多种时间序列处理层(Embedding、AutoCorrelation等) 8. 添加配置文件requirements.txt 9. 补充测试数据集TJU battery data
1454 lines
59 KiB
Python
1454 lines
59 KiB
Python
import torch
|
||
import numpy as np
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from torch import Tensor
|
||
from typing import List, Tuple
|
||
import math
|
||
from functools import partial
|
||
from torch import nn, einsum, diagonal
|
||
from math import log2, ceil
|
||
import pdb
|
||
from sympy import Poly, legendre, Symbol, chebyshevt
|
||
from scipy.special import eval_legendre
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
import math
|
||
import numpy as np
|
||
from pytorch_forecasting.models import BaseModel
|
||
from typing import Dict
|
||
# from layers.Embed import DataEmbedding
|
||
# from .layers.AutoCorrelation import AutoCorrelationLayer
|
||
# from .layers.FourierCorrelation import FourierBlock, FourierCrossAttention
|
||
# from .layers.MultiWaveletCorrelation import MultiWaveletCross, MultiWaveletTransform
|
||
# from .layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp
|
||
|
||
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding.

    The table is computed once at construction and stored as a buffer of
    shape ``(1, max_len, d_model)``; even columns hold sines, odd columns
    hold cosines.

    Args:
        d_model: embedding width (odd values are supported).
        max_len: number of positions pre-computed.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine
        # column; trim the frequencies so the assignment shapes agree.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # A buffer is saved/moved with the module but is not a trainable
        # parameter (the table is created without requires_grad).
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions, shape (1, L, d_model)."""
        return self.pe[:, :x.size(1)]
|
||
|
||
|
||
class TokenEmbedding(nn.Module):
    """Project ``c_in`` features per time step to ``d_model`` with a 1-D convolution.

    Uses a size-3 kernel with circular padding (padding 1 on torch >= 1.5),
    so the sequence length is preserved. Input and output are
    (batch, length, channels).
    """

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # Older torch releases applied circular padding differently and needed 2.
        padding = 1 if self._torch_at_least(1, 5) else 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_in', nonlinearity='leaky_relu')

    @staticmethod
    def _torch_at_least(major, minor):
        """Return True if the installed torch version is >= ``major.minor``.

        Compares numeric components; a plain string comparison would order
        '1.10.0' before '1.5.0'.
        """
        import re
        match = re.match(r'(\d+)\.(\d+)', str(torch.__version__))
        if match is None:
            # Unparseable version string: assume a modern release.
            return True
        return (int(match.group(1)), int(match.group(2))) >= (major, minor)

    def forward(self, x):
        """Map (batch, length, c_in) to (batch, length, d_model)."""
        # Conv1d expects channels in dim 1, so swap in and back out.
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x
|
||
|
||
|
||
class FixedEmbedding(nn.Module):
    """Non-trainable embedding table initialised with sinusoidal values.

    Row ``i`` holds sin/cos features of position ``i``, built the same way
    as a sinusoidal positional encoding. Lookups are detached, so no
    gradient flows through this module.

    Args:
        c_in: number of rows (vocabulary size).
        d_model: embedding width (odd values are supported).
    """

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        w = torch.zeros(c_in, d_model).float()
        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        # Odd d_model: one fewer cosine column than sine column.
        w[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.emb = nn.Embedding(c_in, d_model)
        # Replace the random init with the fixed table and freeze it.
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        """Look up integer indices ``x`` and return detached embeddings."""
        return self.emb(x).detach()
|
||
|
||
|
||
class TemporalEmbedding(nn.Module):
    """Sum of embeddings of integer calendar fields.

    ``forward`` casts ``x`` to long and reads its last dimension as
    month (0), day (1), weekday (2), hour (3) and, only when built with
    ``freq == 't'``, minute (4).

    Args:
        d_model: embedding width.
        embed_type: 'fixed' uses FixedEmbedding tables; any other value uses
            trainable ``nn.Embedding`` tables.
        freq: 't' additionally creates a 4-bucket minute embedding.
    """

    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super(TemporalEmbedding, self).__init__()

        minute_size, hour_size, weekday_size, day_size, month_size = 4, 24, 7, 32, 13
        table = FixedEmbedding if embed_type == 'fixed' else nn.Embedding

        if freq == 't':
            self.minute_embed = table(minute_size, d_model)
        self.hour_embed = table(hour_size, d_model)
        self.weekday_embed = table(weekday_size, d_model)
        self.day_embed = table(day_size, d_model)
        self.month_embed = table(month_size, d_model)

    def forward(self, x):
        """Return the summed field embeddings, shape (batch, length, d_model)."""
        idx = x.long()
        if hasattr(self, 'minute_embed'):
            minute_part = self.minute_embed(idx[:, :, 4])
        else:
            minute_part = 0.
        hour_part = self.hour_embed(idx[:, :, 3])
        weekday_part = self.weekday_embed(idx[:, :, 2])
        day_part = self.day_embed(idx[:, :, 1])
        month_part = self.month_embed(idx[:, :, 0])
        return hour_part + weekday_part + day_part + month_part + minute_part
|
||
|
||
|
||
class TimeFeatureEmbedding(nn.Module):
    """Bias-free linear projection of continuous time features to ``d_model``.

    The number of input features depends on ``freq``
    (h=4, t=5, s=6, m=1, a=1, w=2, d=3, b=3). An unknown ``freq`` raises
    KeyError. ``embed_type`` is accepted for signature compatibility but unused.
    """

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TimeFeatureEmbedding, self).__init__()
        n_features = dict(h=4, t=5, s=6, m=1, a=1, w=2, d=3, b=3)[freq]
        self.embed = nn.Linear(n_features, d_model, bias=False)

    def forward(self, x):
        """Project the last dimension of ``x`` to ``d_model``."""
        projected = self.embed(x)
        return projected
|
||
|
||
|
||
class DataEmbedding(nn.Module):
    """Value + (optional) temporal + positional embedding, followed by dropout.

    Args:
        c_in: number of input value features.
        d_model: embedding width.
        embed_type: 'timeF' selects TimeFeatureEmbedding for ``x_mark``;
            anything else selects TemporalEmbedding with that embed_type.
        freq: frequency code forwarded to the temporal embedding.
        dropout: dropout probability applied to the sum.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type == 'timeF':
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x``; ``x_mark`` adds the temporal term when it is not None."""
        embedded = self.value_embedding(x)
        if x_mark is not None:
            embedded = embedded + self.temporal_embedding(x_mark)
        embedded = embedded + self.position_embedding(x)
        return self.dropout(embedded)
|
||
|
||
|
||
class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.

    Args:
        mask_flag: stored; not used by this class.
        factor: the number of aggregated delays is ``int(factor * ln(length))``,
            clamped to ``[1, length]``.
        scale: stored; not used by this class.
        attention_dropout: rate of ``self.dropout`` (created, not applied in forward).
        output_attention: if True, ``forward`` also returns the correlation map.
    """

    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def _top_k(self, length):
        """Number of delays to aggregate, clamped so topk/stack never see k=0 or k>length."""
        return max(1, min(length, int(self.factor * math.log(length))))

    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.

        ``values`` and ``corr`` are (batch, head, channel, length). The top-k
        delays are picked from the batch-averaged correlation and shared by
        every sample; each sample still uses its own softmax weights.
        """
        length = values.shape[3]
        top_k = self._top_k(length)
        # Average over heads and channels -> (batch, length).
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = mean_value[:, index]
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(values, -int(index[i]), -1)
            delays_agg = delays_agg + pattern * tmp_corr[:, i].reshape(-1, 1, 1, 1)
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.

        Unlike the training path, each sample selects its own top-k delays.
        """
        batch, head, channel, length = values.shape
        # Build the index on the inputs' device (this used to be hard-coded to
        # .cuda(), which made CPU inference crash).
        init_index = torch.arange(length, device=values.device).view(1, 1, 1, length).repeat(batch, head, channel, 1)
        top_k = self._top_k(length)
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        tmp_corr = torch.softmax(weights, dim=-1)
        # Doubling the sequence lets gather() read position t + d (< 2 * length)
        # directly, which equals a circular shift by d.
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i].reshape(batch, 1, 1, 1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * tmp_corr[:, i].reshape(batch, 1, 1, 1)
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation

        Top-k delays are chosen independently for every (batch, head, channel).
        """
        batch, head, channel, length = values.shape
        init_index = torch.arange(length, device=values.device).view(1, 1, 1, length).repeat(batch, head, channel, 1)
        top_k = self._top_k(length)
        weights, delay = torch.topk(corr, top_k, dim=-1)
        tmp_corr = torch.softmax(weights, dim=-1)
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        """
        Args:
            queries: (B, L, H, E).
            keys: assumed (B, S, H, E) — it is padded/truncated like ``values``.
            values: (B, S, H, D).
            attn_mask: unused; accepted for attention-API compatibility.

        Returns:
            (V, corr) with V of shape (B, L, H, D); corr is (B, L, H, E) when
            ``output_attention`` is True, otherwise None.
        """
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        # Align key/value length with the query length (zero-pad or truncate).
        if L > S:
            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies: cross-correlation via FFT (Wiener-Khinchin).
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        corr = torch.fft.irfft(res, dim=-1)

        # time delay agg
        if self.training:
            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        else:
            return (V.contiguous(), None)
|
||
|
||
|
||
class AutoCorrelationLayer(nn.Module):
    """Multi-head projection wrapper around an inner correlation module.

    Inputs of width ``d_model`` are linearly projected and reshaped to
    (batch, length, n_heads, head_dim), passed to ``correlation`` as
    ``correlation(q, k, v, attn_mask)``, and the merged output is projected
    back to ``d_model``.

    Args:
        correlation: module returning ``(out, attn)``.
        d_model: model width.
        n_heads: number of heads.
        d_keys: per-head query/key width (default ``d_model // n_heads``).
        d_values: per-head value width (default ``d_model // n_heads``).
    """

    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        key_dim = d_keys or (d_model // n_heads)
        value_dim = d_values or (d_model // n_heads)

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, key_dim * n_heads)
        self.key_projection = nn.Linear(d_model, key_dim * n_heads)
        self.value_projection = nn.Linear(d_model, value_dim * n_heads)
        self.out_projection = nn.Linear(value_dim * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        """Return ``(projected_output, attn)``; output is (batch, query_len, d_model)."""
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        projected_q = self.query_projection(queries).view(batch, q_len, heads, -1)
        projected_k = self.key_projection(keys).view(batch, kv_len, heads, -1)
        projected_v = self.value_projection(values).view(batch, kv_len, heads, -1)

        out, attn = self.inner_correlation(projected_q, projected_k, projected_v, attn_mask)
        merged = out.view(batch, q_len, -1)
        return self.out_projection(merged), attn
|
||
|
||
def get_frequency_modes(seq_len, modes=64, mode_select_method='random'):
    """
    Choose which frequency-domain modes to keep, returned as a sorted list.

    At most ``min(modes, seq_len // 2)`` indices from ``range(seq_len // 2)``:
    'random' draws them with ``np.random.shuffle`` (global NumPy RNG);
    any other method keeps the lowest ones.
    """
    half = seq_len // 2
    n_modes = min(modes, half)
    if mode_select_method != 'random':
        return list(range(0, n_modes))
    candidates = list(range(0, half))
    np.random.shuffle(candidates)
    return sorted(candidates[:n_modes])
|
||
|
||
|
||
# ########## fourier layer #############
|
||
class FourierBlock(nn.Module):
    """
    1D Fourier block. It performs representation learning on frequency domain,
    it does FFT, linear transform, and Inverse FFT.

    Args:
        in_channels, out_channels: total channel counts; each is split evenly
            over ``num_heads`` for the per-head complex weights.
        seq_len: sequence length used to pick the frequency modes.
        modes: maximum number of frequency modes kept.
        mode_select_method: passed to ``get_frequency_modes``.
        num_heads: number of heads the weights are split into. The default of 8
            matches the value that used to be hard-coded.
    """

    def __init__(self, in_channels, out_channels, seq_len, modes=0, mode_select_method='random', num_heads=8):
        super(FourierBlock, self).__init__()
        print('fourier enhanced block used!')
        # get modes on frequency domain
        self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method)
        print('modes={}, index={}'.format(modes, self.index))

        self.scale = (1 / (in_channels * out_channels))
        weight_shape = (num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index))
        # Real and imaginary parts of the per-mode complex weights.
        self.weights1 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.float))
        self.weights2 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.float))

    # Complex multiplication
    def compl_mul1d(self, order, x, weights):
        """Einsum ``order`` over complex operands, promoting real inputs to complex.

        Returns a real einsum only when neither operand was complex.
        """
        x_flag = torch.is_complex(x)
        w_flag = torch.is_complex(weights)
        if not x_flag:
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not w_flag:
            weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
        if x_flag or w_flag:
            return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
                                 torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, q, k, v, mask):
        """Filter ``q`` in the frequency domain; ``k``, ``v`` and ``mask`` are unused.

        Args:
            q: (B, L, H, E).

        Returns:
            (x, None) with x of shape (B, H, E, L). The mixed output channels
            are written into E, so out_channels // num_heads must equal E.
        """
        B, L, H, E = q.shape
        x = q.permute(0, 2, 3, 1)
        # Compute Fourier coefficients
        x_ft = torch.fft.rfft(x, dim=-1)
        # Perform Fourier neural operations
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
        # The complex weight tensor does not change inside the loop: build it once.
        weights = torch.complex(self.weights1, self.weights2)
        for wi, i in enumerate(self.index):
            if i >= x_ft.shape[3] or wi >= out_ft.shape[3]:
                continue
            out_ft[:, :, :, wi] = self.compl_mul1d("bhi,hio->bho", x_ft[:, :, :, i], weights[:, :, :, wi])
        # Return to time domain
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return (x, None)
|
||
|
||
|
||
# ########## Fourier Cross Former ####################
|
||
class FourierCrossAttention(nn.Module):
    """
    1D Fourier Cross Attention layer: FFT of queries and keys, attention over a
    selected subset of frequency modes, a per-head complex linear map, and an
    inverse FFT back to the time domain.

    NOTE(review): the attention-weighted spectrum is built from the KEY modes;
    ``v`` is permuted but not otherwise used. This mirrors the code as
    written — confirm it is intended.
    """

    def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=64, mode_select_method='random',
                 activation='tanh', policy=0, num_heads=8):
        super(FourierCrossAttention, self).__init__()
        print(' fourier enhanced cross attention used!')
        self.activation = activation
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Frequency modes kept for the queries and for the keys (& values).
        self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method)
        self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method)

        print('modes_q={}, index_q={}'.format(len(self.index_q), self.index_q))
        print('modes_kv={}, index_kv={}'.format(len(self.index_kv), self.index_kv))

        self.scale = (1 / (in_channels * out_channels))
        weight_shape = (num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q))
        # Real (weights1) and imaginary (weights2) parts of the mixing weights.
        self.weights1 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.float))
        self.weights2 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.float))

    def compl_mul1d(self, order, x, weights):
        """Complex einsum; real operands are promoted, all-real inputs give a real result."""
        x_was_complex = torch.is_complex(x)
        w_was_complex = torch.is_complex(weights)
        if not x_was_complex:
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not w_was_complex:
            weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
        if not (x_was_complex or w_was_complex):
            return torch.einsum(order, x.real, weights.real)
        real_part = torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag)
        imag_part = torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real)
        return torch.complex(real_part, imag_part)

    @staticmethod
    def _gather_modes(spectrum, modes, batch, heads, width, device):
        """Copy the listed frequency bins of ``spectrum`` into a compact complex tensor.

        Bins at or beyond the spectrum's last index are left as zeros.
        """
        selected = torch.zeros(batch, heads, width, len(modes), device=device, dtype=torch.cfloat)
        n_bins = spectrum.shape[3]
        for slot, freq in enumerate(modes):
            if freq < n_bins:
                selected[:, :, :, slot] = spectrum[:, :, :, freq]
        return selected

    def forward(self, q, k, v, mask):
        """Return ``(out, None)`` with ``out`` of shape (B, H, E, L) for q of shape (B, L, H, E)."""
        B, L, H, E = q.shape
        xq = q.permute(0, 2, 3, 1)  # (B, H, E, L)
        xk = k.permute(0, 2, 3, 1)
        xv = v.permute(0, 2, 3, 1)

        # Fourier coefficients restricted to the selected modes.
        xq_ft_ = self._gather_modes(torch.fft.rfft(xq, dim=-1), self.index_q, B, H, E, xq.device)
        xk_ft_ = self._gather_modes(torch.fft.rfft(xk, dim=-1), self.index_kv, B, H, E, xq.device)

        # Attention scores between query modes and key modes.
        scores = self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)
        if self.activation == 'tanh':
            scores = torch.complex(scores.real.tanh(), scores.imag.tanh())
        elif self.activation == 'softmax':
            magnitudes = torch.softmax(abs(scores), dim=-1)
            scores = torch.complex(magnitudes, torch.zeros_like(magnitudes))
        else:
            raise Exception('{} actiation function is not implemented'.format(self.activation))

        attended = self.compl_mul1d("bhxy,bhey->bhex", scores, xk_ft_)
        mixed = self.compl_mul1d("bhex,heox->bhox", attended, torch.complex(self.weights1, self.weights2))

        # Scatter the mixed modes back to their original frequency bins.
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
        n_mixed = mixed.shape[3]
        n_out = out_ft.shape[3]
        for slot, freq in enumerate(self.index_q):
            if slot < n_mixed and freq < n_out:
                out_ft[:, :, :, freq] = mixed[:, :, :, slot]

        # Return to time domain
        out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1))
        return (out, None)
|
||
|
||
|
||
|
||
def legendreDer(k, x):
    """Evaluate the derivative of the Legendre polynomial P_k at ``x``.

    Uses P_k'(x) = sum over i = k-1, k-3, ... (i >= 0) of (2i + 1) * P_i(x).
    ``x`` may be a scalar or NumPy array; for k <= 0 the sum is empty and the
    integer 0 is returned.
    """
    degrees = np.arange(k - 1, -1, -2)
    return sum((2 * n + 1) * eval_legendre(n, x) for n in degrees)
|
||
|
||
|
||
def phi_(phi_c, x, lb=0, ub=1):
    """Evaluate the power-series polynomial with coefficients ``phi_c`` at ``x``.

    ``phi_c`` is in increasing-degree order (NumPy ``Polynomial`` convention).
    The result is zeroed wherever ``x < lb`` or ``x > ub``.
    """
    polynomial = np.polynomial.polynomial.Polynomial(phi_c)
    outside = np.logical_or(x < lb, x > ub)
    return polynomial(x) * np.where(outside, 0.0, 1.0)
|
||
|
||
|
||
def get_phi_psi(k, base):
    """Build the k scaling functions (phi) and the two halves of the wavelets (psi1, psi2).

    Coefficient rows are power-series coefficients on [0, 1]; the psi rows are
    orthogonalised against phi and against earlier psi rows, then normalised.

    Args:
        k: number of basis functions.
        base: 'legendre' or 'chebyshev'.

    Returns:
        (phi, psi1, psi2): three lists of k callables. For 'legendre' they are
        ``np.poly1d`` objects; for 'chebyshev' they are ``partial(phi_, ...)``
        wrappers, with psi1 supported on [0, 0.5] and psi2 on (0.5, 1].

    Raises:
        ValueError: if ``base`` is not supported.
    """
    if base not in ('legendre', 'chebyshev'):
        # Without this check an unknown base fell through both branches and
        # failed with UnboundLocalError at the return statement.
        raise ValueError('Base not supported')

    x = Symbol('x')
    phi_coeff = np.zeros((k, k))
    phi_2x_coeff = np.zeros((k, k))
    if base == 'legendre':
        for ki in range(k):
            coeff_ = Poly(legendre(ki, 2 * x - 1), x).all_coeffs()
            phi_coeff[ki, :ki + 1] = np.flip(np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64))
            coeff_ = Poly(legendre(ki, 4 * x - 1), x).all_coeffs()
            phi_2x_coeff[ki, :ki + 1] = np.flip(np.sqrt(2) * np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64))

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))
        for ki in range(k):
            psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
            # Remove the projection onto every phi_i (integrals over [0, 0.5]).
            for i in range(k):
                a = phi_2x_coeff[ki, :ki + 1]
                b = phi_coeff[i, :i + 1]
                prod_ = np.convolve(a, b)
                prod_[np.abs(prod_) < 1e-8] = 0
                proj_ = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum()
                psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
                psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]
            # Remove the projection onto previously built wavelets.
            for j in range(ki):
                a = phi_2x_coeff[ki, :ki + 1]
                b = psi1_coeff[j, :]
                prod_ = np.convolve(a, b)
                prod_[np.abs(prod_) < 1e-8] = 0
                proj_ = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum()
                psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
                psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

            # Squared norm: psi1 integrated over [0, 0.5], psi2 over [0.5, 1].
            a = psi1_coeff[ki, :]
            prod_ = np.convolve(a, a)
            prod_[np.abs(prod_) < 1e-8] = 0
            norm1 = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum()

            a = psi2_coeff[ki, :]
            prod_ = np.convolve(a, a)
            prod_[np.abs(prod_) < 1e-8] = 0
            norm2 = (prod_ * 1 / (np.arange(len(prod_)) + 1) * (1 - np.power(0.5, 1 + np.arange(len(prod_))))).sum()
            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki, :] /= norm_
            psi2_coeff[ki, :] /= norm_
            psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

        phi = [np.poly1d(np.flip(phi_coeff[i, :])) for i in range(k)]
        psi1 = [np.poly1d(np.flip(psi1_coeff[i, :])) for i in range(k)]
        psi2 = [np.poly1d(np.flip(psi2_coeff[i, :])) for i in range(k)]

    elif base == 'chebyshev':
        for ki in range(k):
            if ki == 0:
                phi_coeff[ki, :ki + 1] = np.sqrt(2 / np.pi)
                phi_2x_coeff[ki, :ki + 1] = np.sqrt(2 / np.pi) * np.sqrt(2)
            else:
                coeff_ = Poly(chebyshevt(ki, 2 * x - 1), x).all_coeffs()
                phi_coeff[ki, :ki + 1] = np.flip(2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))
                coeff_ = Poly(chebyshevt(ki, 4 * x - 1), x).all_coeffs()
                phi_2x_coeff[ki, :ki + 1] = np.flip(
                    np.sqrt(2) * 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))

        phi = [partial(phi_, phi_coeff[i, :]) for i in range(k)]

        # Gauss-Chebyshev quadrature nodes (roots of T_{2k} mapped to [0, 1]).
        x = Symbol('x')
        kUse = 2 * k
        roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
        # not needed for our purpose here, we use even k always to avoid
        wm = np.pi / kUse / 2

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))

        psi1 = [[] for _ in range(k)]
        psi2 = [[] for _ in range(k)]

        for ki in range(k):
            psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
            for i in range(k):
                proj_ = (wm * phi[i](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum()
                psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
                psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]

            for j in range(ki):
                proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum()
                psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
                psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

            psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5)
            psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5, ub=1)

            norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum()
            norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum()

            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki, :] /= norm_
            psi2_coeff[ki, :] /= norm_
            psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

            # Final supports: psi1 on [0, 0.5], psi2 just above 0.5 up to 1.
            psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5 + 1e-16)
            psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5 + 1e-16, ub=1)

    return phi, psi1, psi2
|
||
|
||
|
||
def get_filter(base, k):
    """Compute the multiwavelet two-scale filter matrices by quadrature.

    Args:
        base: 'legendre' or 'chebyshev'.
        k: number of basis functions.

    Returns:
        (H0, H1, G0, G1, PHI0, PHI1), each a (k, k) float array. H0/H1 come
        from the phi functions of ``get_phi_psi`` and G0/G1 from its psi
        functions, evaluated on [0, 0.5] and [0.5, 1]. PHI0/PHI1 are identity
        for 'legendre' and quadrature Gram matrices for 'chebyshev'. Entries
        with magnitude below 1e-8 are set to 0.

    Raises:
        ValueError: for an unsupported ``base``. It is a subclass of Exception,
            so existing ``except Exception`` handlers still catch it.
    """

    def psi(psi1, psi2, i, inp):
        # Piecewise wavelet: psi1 for inp <= 0.5, psi2 otherwise.
        mask = (inp <= 0.5) * 1.0
        return psi1[i](inp) * mask + psi2[i](inp) * (1 - mask)

    if base not in ['legendre', 'chebyshev']:
        raise ValueError('Base not supported')

    x = Symbol('x')
    H0 = np.zeros((k, k))
    H1 = np.zeros((k, k))
    G0 = np.zeros((k, k))
    G1 = np.zeros((k, k))
    PHI0 = np.zeros((k, k))
    PHI1 = np.zeros((k, k))
    phi, psi1, psi2 = get_phi_psi(k, base)
    if base == 'legendre':
        # Gauss-Legendre nodes on [0, 1] and their quadrature weights.
        roots = Poly(legendre(k, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = 1 / k / legendreDer(k, 2 * x_m - 1) / eval_legendre(k - 1, 2 * x_m - 1)

        for ki in range(k):
            for kpi in range(k):
                H0[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum()
                G0[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum()
                H1[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum()
                G1[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum()

        PHI0 = np.eye(k)
        PHI1 = np.eye(k)

    elif base == 'chebyshev':
        # Gauss-Chebyshev nodes (roots of T_{2k} mapped to [0, 1]) with equal weights.
        kUse = 2 * k
        roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
        # not needed for our purpose here, we use even k always to avoid
        wm = np.pi / kUse / 2

        for ki in range(k):
            for kpi in range(k):
                H0[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum()
                G0[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum()
                H1[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum()
                G1[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum()

                PHI0[ki, kpi] = (wm * phi[ki](2 * x_m) * phi[kpi](2 * x_m)).sum() * 2
                PHI1[ki, kpi] = (wm * phi[ki](2 * x_m - 1) * phi[kpi](2 * x_m - 1)).sum() * 2

        PHI0[np.abs(PHI0) < 1e-8] = 0
        PHI1[np.abs(PHI1) < 1e-8] = 0

    H0[np.abs(H0) < 1e-8] = 0
    H1[np.abs(H1) < 1e-8] = 0
    G0[np.abs(G0) < 1e-8] = 0
    G1[np.abs(G1) < 1e-8] = 0

    return H0, H1, G0, G1, PHI0, PHI1
|
||
|
||
|
||
# NOTE: an earlier, byte-identical copy of ``MultiWaveletTransform`` used to be
# defined here. Python binds a class name to the last definition executed in
# the module, so that copy was dead code and has been removed; the single
# remaining definition of ``MultiWaveletTransform`` follows later in this file.
|
||
|
||
class MultiWaveletTransform(nn.Module):
    """
    1D multiwavelet block.

    ``values`` are projected to ``c * k`` features, passed through ``nCZ``
    stacked ``MWT_CZ1d`` stages (ReLU between stages), and projected back to
    ``ich`` features. ``queries`` only supply the target length; ``keys`` are
    padded/truncated alongside ``values`` but otherwise unused.
    ``attention_dropout`` is accepted but unused.
    """

    def __init__(self, ich=1, k=8, alpha=16, c=128,
                 nCZ=1, L=0, base='legendre', attention_dropout=0.1):
        super(MultiWaveletTransform, self).__init__()
        print('base', base)
        self.k = k
        self.c = c
        self.L = L
        self.nCZ = nCZ
        self.Lk0 = nn.Linear(ich, c * k)
        self.Lk1 = nn.Linear(c * k, ich)
        self.ich = ich
        self.MWT_CZ = nn.ModuleList([MWT_CZ1d(k, alpha, L, c, base) for _ in range(nCZ)])

    def forward(self, queries, keys, values, attn_mask):
        """Return ``(out, None)``; out has shape (B, L, H_v, D) for queries (B, L, H, E), values (B, S, H_v, D)."""
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        # Match the value length to the query length.
        if L > S:
            pad = torch.zeros_like(queries[:, :(L - S), :]).float()
            values = torch.cat([values, pad], dim=1)
            keys = torch.cat([keys, pad], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        hidden = self.Lk0(values.view(B, L, -1)).view(B, L, self.c, -1)
        last_stage = self.nCZ - 1
        for stage in range(self.nCZ):
            hidden = self.MWT_CZ[stage](hidden)
            if stage < last_stage:
                hidden = F.relu(hidden)

        out = self.Lk1(hidden.view(B, L, -1)).view(B, L, -1, D)
        return (out.contiguous(), None)
|
||
|
||
|
||
class MultiWaveletCross(nn.Module):
|
||
"""
|
||
1D Multiwavelet Cross Attention layer.
|
||
"""
|
||
|
||
def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes, c=64,
|
||
k=8, ich=512,
|
||
L=0,
|
||
base='legendre',
|
||
mode_select_method='random',
|
||
initializer=None, activation='tanh',
|
||
**kwargs):
|
||
super(MultiWaveletCross, self).__init__()
|
||
print('base', base)
|
||
|
||
self.c = c
|
||
self.k = k
|
||
self.L = L
|
||
H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
|
||
H0r = H0 @ PHI0
|
||
G0r = G0 @ PHI0
|
||
H1r = H1 @ PHI1
|
||
G1r = G1 @ PHI1
|
||
|
||
H0r[np.abs(H0r) < 1e-8] = 0
|
||
H1r[np.abs(H1r) < 1e-8] = 0
|
||
G0r[np.abs(G0r) < 1e-8] = 0
|
||
G1r[np.abs(G1r) < 1e-8] = 0
|
||
self.max_item = 3
|
||
|
||
self.attn1 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
|
||
seq_len_kv=seq_len_kv, modes=modes, activation=activation,
|
||
mode_select_method=mode_select_method)
|
||
self.attn2 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
|
||
seq_len_kv=seq_len_kv, modes=modes, activation=activation,
|
||
mode_select_method=mode_select_method)
|
||
self.attn3 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
|
||
seq_len_kv=seq_len_kv, modes=modes, activation=activation,
|
||
mode_select_method=mode_select_method)
|
||
self.attn4 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
|
||
seq_len_kv=seq_len_kv, modes=modes, activation=activation,
|
||
mode_select_method=mode_select_method)
|
||
self.T0 = nn.Linear(k, k)
|
||
self.register_buffer('ec_s', torch.Tensor(
|
||
np.concatenate((H0.T, H1.T), axis=0)))
|
||
self.register_buffer('ec_d', torch.Tensor(
|
||
np.concatenate((G0.T, G1.T), axis=0)))
|
||
|
||
self.register_buffer('rc_e', torch.Tensor(
|
||
np.concatenate((H0r, G0r), axis=0)))
|
||
self.register_buffer('rc_o', torch.Tensor(
|
||
np.concatenate((H1r, G1r), axis=0)))
|
||
|
||
self.Lk = nn.Linear(ich, c * k)
|
||
self.Lq = nn.Linear(ich, c * k)
|
||
self.Lv = nn.Linear(ich, c * k)
|
||
self.out = nn.Linear(c * k, ich)
|
||
self.modes1 = modes
|
||
|
||
def forward(self, q, k, v, mask=None):
|
||
B, N, H, E = q.shape # (B, N, H, E) torch.Size([3, 768, 8, 2])
|
||
_, S, _, _ = k.shape # (B, S, H, E) torch.Size([3, 96, 8, 2])
|
||
|
||
q = q.view(q.shape[0], q.shape[1], -1)
|
||
k = k.view(k.shape[0], k.shape[1], -1)
|
||
v = v.view(v.shape[0], v.shape[1], -1)
|
||
q = self.Lq(q)
|
||
q = q.view(q.shape[0], q.shape[1], self.c, self.k)
|
||
k = self.Lk(k)
|
||
k = k.view(k.shape[0], k.shape[1], self.c, self.k)
|
||
v = self.Lv(v)
|
||
v = v.view(v.shape[0], v.shape[1], self.c, self.k)
|
||
|
||
if N > S:
|
||
zeros = torch.zeros_like(q[:, :(N - S), :]).float()
|
||
v = torch.cat([v, zeros], dim=1)
|
||
k = torch.cat([k, zeros], dim=1)
|
||
else:
|
||
v = v[:, :N, :, :]
|
||
k = k[:, :N, :, :]
|
||
|
||
ns = math.floor(np.log2(N))
|
||
nl = pow(2, math.ceil(np.log2(N)))
|
||
extra_q = q[:, 0:nl - N, :, :]
|
||
extra_k = k[:, 0:nl - N, :, :]
|
||
extra_v = v[:, 0:nl - N, :, :]
|
||
q = torch.cat([q, extra_q], 1)
|
||
k = torch.cat([k, extra_k], 1)
|
||
v = torch.cat([v, extra_v], 1)
|
||
|
||
Ud_q = torch.jit.annotate(List[Tuple[Tensor]], [])
|
||
Ud_k = torch.jit.annotate(List[Tuple[Tensor]], [])
|
||
Ud_v = torch.jit.annotate(List[Tuple[Tensor]], [])
|
||
|
||
Us_q = torch.jit.annotate(List[Tensor], [])
|
||
Us_k = torch.jit.annotate(List[Tensor], [])
|
||
Us_v = torch.jit.annotate(List[Tensor], [])
|
||
|
||
Ud = torch.jit.annotate(List[Tensor], [])
|
||
Us = torch.jit.annotate(List[Tensor], [])
|
||
|
||
# decompose
|
||
for i in range(ns - self.L):
|
||
# print('q shape',q.shape)
|
||
d, q = self.wavelet_transform(q)
|
||
Ud_q += [tuple([d, q])]
|
||
Us_q += [d]
|
||
for i in range(ns - self.L):
|
||
d, k = self.wavelet_transform(k)
|
||
Ud_k += [tuple([d, k])]
|
||
Us_k += [d]
|
||
for i in range(ns - self.L):
|
||
d, v = self.wavelet_transform(v)
|
||
Ud_v += [tuple([d, v])]
|
||
Us_v += [d]
|
||
for i in range(ns - self.L):
|
||
dk, sk = Ud_k[i], Us_k[i]
|
||
dq, sq = Ud_q[i], Us_q[i]
|
||
dv, sv = Ud_v[i], Us_v[i]
|
||
Ud += [self.attn1(dq[0], dk[0], dv[0], mask)[0] + self.attn2(dq[1], dk[1], dv[1], mask)[0]]
|
||
Us += [self.attn3(sq, sk, sv, mask)[0]]
|
||
v = self.attn4(q, k, v, mask)[0]
|
||
|
||
# reconstruct
|
||
for i in range(ns - 1 - self.L, -1, -1):
|
||
v = v + Us[i]
|
||
v = torch.cat((v, Ud[i]), -1)
|
||
v = self.evenOdd(v)
|
||
v = self.out(v[:, :N, :, :].contiguous().view(B, N, -1))
|
||
return (v.contiguous(), None)
|
||
|
||
def wavelet_transform(self, x):
    """Project ``x`` one scale coarser into detail and smooth coefficients.

    Even- and odd-indexed steps along dim 1 are concatenated on the last
    dimension and multiplied by the ``ec_d`` / ``ec_s`` filter buffers, so each
    result has half as many steps along dim 1. Assumes an even number of steps
    (the visible caller pads the sequence length to a power of two).

    Returns:
        (d, s): detail coefficients and smooth coefficients.
    """
    evens = x[:, 0::2, :, :]
    odds = x[:, 1::2, :, :]
    paired = torch.cat((evens, odds), dim=-1)
    detail = paired @ self.ec_d
    smooth = paired @ self.ec_s
    return detail, smooth
|
||
|
||
def evenOdd(self, x):
    """Reconstruct one finer scale from concatenated (smooth, detail) features.

    ``x`` has ``2 * self.k`` features on its last dimension. It is projected
    with the ``rc_e`` and ``rc_o`` buffers, and the two results are interleaved
    along dim 1: output step ``2n`` comes from ``rc_e`` and step ``2n + 1`` from
    ``rc_o``. The number of steps therefore doubles.

    The output keeps ``x``'s dtype and device. A fixed float32 buffer would
    silently downcast float64 models and mismatch their float64 layers.

    Args:
        x: tensor of shape (B, N, c, 2 * k).

    Returns:
        Tensor of shape (B, 2 * N, c, k).
    """
    B, N, c, ich = x.shape  # (B, N, c, 2k)
    assert ich == 2 * self.k

    x_e = torch.matmul(x, self.rc_e)  # (B, N, c, k) -> even output steps
    x_o = torch.matmul(x, self.rc_o)  # (B, N, c, k) -> odd output steps

    # stack -> (B, N, 2, c, k); flattening (N, 2) yields the order e0, o0, e1, o1, ...
    return torch.stack((x_e, x_o), dim=2).reshape(B, N * 2, c, self.k)
|
||
|
||
|
||
class FourierCrossAttentionW(nn.Module):
    """Cross attention computed on the low-frequency Fourier modes of q and k.

    ``q`` and ``k`` are moved to the frequency domain with ``rfft`` along the
    length axis, and only the lowest modes are kept. An attention map between
    query and key modes is formed and passed through a tanh or softmax. It is
    applied back onto the key modes, and the result is returned to the time
    domain with ``irfft``.

    Args:
        in_channels, out_channels: only used to rescale the output spectrum.
        seq_len_q, seq_len_kv, mode_select_method: accepted for interface
            compatibility; not used by this class.
        modes: upper bound on the number of low-frequency modes kept.
        activation: ``'tanh'`` or ``'softmax'``.
    """

    def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=16, activation='tanh',
                 mode_select_method='random'):
        super(FourierCrossAttentionW, self).__init__()
        print('corss fourier correlation used!')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes
        self.activation = activation

    def compl_mul1d(self, order, x, weights):
        """``einsum(order, x, weights)`` with complex support via real/imag parts.

        Real operands are promoted to complex first. A complex result is returned
        unless neither operand was complex, in which case a real einsum is used.
        """
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
        if x_flag or w_flag:
            return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
                                 torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, q, k, v, mask):
        """Apply Fourier cross attention.

        Args:
            q, k, v: tensors of shape (B, L, E, H); ``v`` only determines how many
                key/value modes are kept.
            mask: unused.

        Returns:
            ``(out, None)`` with ``out`` of shape (B, L, E, H).
        """
        B, L, E, H = q.shape

        xq = q.permute(0, 3, 2, 1)  # [B, H, E, L]
        xk = k.permute(0, 3, 2, 1)
        xv = v.permute(0, 3, 2, 1)
        # The kept modes are always the contiguous low band [0, n), so slicing
        # the spectrum is equivalent to copying the modes one by one.
        self.index_q = list(range(0, min(int(L // 2), self.modes1)))
        self.index_k_v = list(range(0, min(int(xv.shape[3] // 2), self.modes1)))
        n_q = len(self.index_q)
        n_kv = len(self.index_k_v)

        # Compute Fourier coefficients. Spectra are promoted to at least
        # complex64, and complex128 spectra (float64 inputs) are kept, not downcast.
        spec_q = torch.fft.rfft(xq, dim=-1)
        spec_k = torch.fft.rfft(xk, dim=-1)
        xq_ft_ = spec_q[..., :n_q].to(torch.promote_types(spec_q.dtype, torch.cfloat))
        xk_ft_ = spec_k[..., :n_kv].to(torch.promote_types(spec_k.dtype, torch.cfloat))

        xqk_ft = self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)
        if self.activation == 'tanh':
            xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
        elif self.activation == 'softmax':
            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
        else:
            raise Exception('{} actiation function is not implemented'.format(self.activation))

        # NOTE(review): the attention map is applied to the *key* spectrum
        # (xk_ft_); the spectrum of v is never computed. Confirm this is intended.
        xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)

        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=xqkv_ft.dtype)
        out_ft[..., :n_q] = xqkv_ft

        out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)).permute(0, 3, 2, 1)
        # size = [B, L, E, H]
        return (out, None)
|
||
|
||
|
||
class sparseKernelFT1d(nn.Module):
    """Learned linear operator on the lowest Fourier modes along the length axis.

    The input (B, N, c, k) is flattened to ``c * k`` channels and transformed
    with ``rfft`` along N. The first ``min(alpha, N // 2 + 1)`` modes are mixed
    by the complex weight ``weights1 + i * weights2``. All higher modes are
    zeroed, and the result is transformed back with ``irfft``.

    Args:
        k, c: the flattened channel count is ``c * k``.
        alpha: maximum number of Fourier modes with learned weights.
        nl, initializer, kwargs: accepted for interface compatibility; unused.
    """

    def __init__(self,
                 k, alpha, c=1,
                 nl=1,
                 initializer=None,
                 **kwargs):
        super(sparseKernelFT1d, self).__init__()

        self.modes1 = alpha
        self.scale = (1 / (c * k * c * k))
        # Real and imaginary parts of the spectral weights; nn.Parameter already
        # has requires_grad=True.
        self.weights1 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float))
        self.weights2 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float))
        self.k = k

    def compl_mul1d(self, order, x, weights):
        """``einsum(order, x, weights)`` with complex support via real/imag parts.

        Real operands are promoted to complex first. A complex result is returned
        unless neither operand was complex, in which case a real einsum is used.
        """
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
        if x_flag or w_flag:
            return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
                                 torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, x):
        """Apply the spectral kernel.

        Args:
            x: tensor of shape (B, N, c, k).

        Returns:
            Tensor of shape (B, N, c, k) with the same real dtype as the spectral
            product (float64 in a ``.double()`` module).
        """
        B, N, c, k = x.shape  # (B, N, c, k)

        # reshape (not view) so non-contiguous inputs are accepted too.
        x = x.reshape(B, N, -1).permute(0, 2, 1)  # (B, c*k, N)
        x_fft = torch.fft.rfft(x)

        # Multiply relevant Fourier modes
        l = min(self.modes1, N // 2 + 1)
        weights = torch.complex(self.weights1, self.weights2)[:, :, :l]
        low_modes = self.compl_mul1d("bix,iox->box", x_fft[:, :, :l], weights)

        # Allocate the spectrum in the product's dtype instead of a fixed cfloat,
        # so double-precision modules are not silently downcast to float32.
        out_ft = torch.zeros(B, c * k, N // 2 + 1, device=x.device, dtype=low_modes.dtype)
        out_ft[:, :, :l] = low_modes

        x = torch.fft.irfft(out_ft, n=N)
        return x.permute(0, 2, 1).reshape(B, N, c, k)
|
||
|
||
|
||
# ##
|
||
class MWT_CZ1d(nn.Module):
    """Multi-wavelet block: decompose, transform every scale, reconstruct.

    ``forward`` does the following:
    1. Pads the length axis up to a power of two by wrapping around.
    2. Applies ``floor(log2(N)) - L`` levels of ``wavelet_transform``.
    3. Transforms each level's coefficients with the spectral kernels
       ``A`` / ``B`` / ``C``, and the coarsest smooth part with ``T0``.
    4. Reconstructs level by level with ``evenOdd`` and crops back to N steps.

    Args:
        k: number of wavelet basis functions (size of the input's last dim).
        alpha: number of Fourier modes given to each ``sparseKernelFT1d``.
        L: reduces the number of decomposition levels to ``floor(log2(N)) - L``.
        c: channel multiplicity given to the spectral kernels.
        base: filter family passed to ``get_filter``.
        initializer, kwargs: accepted for interface compatibility; unused.
    """

    def __init__(self,
                 k=3, alpha=64,
                 L=0, c=1,
                 base='legendre',
                 initializer=None,
                 **kwargs):
        super(MWT_CZ1d, self).__init__()

        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1

        # Flush numerical noise in the reconstruction filters to exact zeros.
        H0r[np.abs(H0r) < 1e-8] = 0
        H1r[np.abs(H1r) < 1e-8] = 0
        G0r[np.abs(G0r) < 1e-8] = 0
        G1r[np.abs(G1r) < 1e-8] = 0
        self.max_item = 3  # not read anywhere in this class

        self.A = sparseKernelFT1d(k, alpha, c)
        self.B = sparseKernelFT1d(k, alpha, c)
        self.C = sparseKernelFT1d(k, alpha, c)

        self.T0 = nn.Linear(k, k)

        # Decomposition filters (2k -> k), used by wavelet_transform.
        self.register_buffer('ec_s', torch.Tensor(
            np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer('ec_d', torch.Tensor(
            np.concatenate((G0.T, G1.T), axis=0)))

        # Reconstruction filters (2k -> k) for even / odd output steps, used by evenOdd.
        self.register_buffer('rc_e', torch.Tensor(
            np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer('rc_o', torch.Tensor(
            np.concatenate((H1r, G1r), axis=0)))

    def forward(self, x):
        """Transform ``x`` of shape (B, N, c, k) and return the same shape."""
        B, N, c, k = x.shape  # (B, N, c, k)
        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        # Pad to nl (next power of two) by appending the first nl - N steps.
        extra_x = x[:, 0:nl - N, :, :]
        x = torch.cat([x, extra_x], 1)
        Ud = torch.jit.annotate(List[Tensor], [])
        Us = torch.jit.annotate(List[Tensor], [])
        # decompose
        for i in range(ns - self.L):
            d, x = self.wavelet_transform(x)
            Ud += [self.A(d) + self.B(x)]
            Us += [self.C(d)]
        x = self.T0(x)  # coarsest scale transform

        # reconstruct, from the coarsest level back to the finest
        for i in range(ns - 1 - self.L, -1, -1):
            x = x + Us[i]
            x = torch.cat((x, Ud[i]), -1)
            x = self.evenOdd(x)
        x = x[:, :N, :, :]

        return x

    def wavelet_transform(self, x):
        """Split even/odd steps of dim 1 and project them to (detail, smooth)."""
        xa = torch.cat([x[:, ::2, :, :],
                        x[:, 1::2, :, :],
                        ], -1)
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        """Reconstruct one finer scale: (B, N, c, 2k) -> (B, 2N, c, k).

        Output step ``2n`` comes from ``rc_e`` and step ``2n + 1`` from ``rc_o``.
        The result keeps ``x``'s dtype and device. A fixed float32 buffer would
        silently downcast float64 modules.
        """
        B, N, c, ich = x.shape  # (B, N, c, 2k)
        assert ich == 2 * self.k

        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)

        # stack -> (B, N, 2, c, k); flattening (N, 2) interleaves e0, o0, e1, o1, ...
        return torch.stack((x_e, x_o), dim=2).reshape(B, N * 2, c, self.k)
|
||
|
||
class my_Layernorm(nn.Module):
    """
    Special designed layernorm for the seasonal part.

    Applies ``nn.LayerNorm`` over the last dimension. It then subtracts the
    mean over dim 1 (presumably the time axis of a (batch, length, channels)
    input), so the output is zero-mean along that dimension.
    """

    def __init__(self, channels):
        super(my_Layernorm, self).__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):
        x_hat = self.layernorm(x)
        # keepdim + broadcasting gives the same values as unsqueeze/repeat
        # without materialising a copy, and works for any input rank >= 2.
        bias = torch.mean(x_hat, dim=1, keepdim=True)
        return x_hat - bias
|
||
|
||
|
||
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.

    The series is edge-padded (first and last steps repeated) along dim 1 and
    then average-pooled along it. With ``stride=1`` the output has the same
    length as the input for both odd and even ``kernel_size``.
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # x is indexed as (batch, time, channels): time is dim 1, pooled after permute.
        # Distribute the kernel_size - 1 padding steps over both ends. For even
        # kernels the extra step goes to the end; symmetric (k-1)//2 padding would
        # shorten the output by one and break `x - moving_mean` in series_decomp.
        # For odd kernels both ends get (k-1)//2, exactly as before.
        pad_front = (self.kernel_size - 1) // 2
        pad_end = self.kernel_size - 1 - pad_front
        front = x[:, 0:1, :].repeat(1, pad_front, 1)
        end = x[:, -1:, :].repeat(1, pad_end, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        return x.permute(0, 2, 1)
|
||
|
||
|
||
class series_decomp(nn.Module):
    """
    Series decomposition block.

    Estimates the trend with a stride-1 ``moving_avg`` of width ``kernel_size``
    and returns ``(x - trend, trend)``: the residual (seasonal) part first, then
    the trend.
    """

    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        return x - trend, trend
|
||
|
||
|
||
class series_decomp_multi(nn.Module):
    """
    Multiple Series decomposition block from FEDformer.

    Runs one ``series_decomp`` per entry of ``kernel_size``. It returns the
    element-wise average of their seasonal parts and the element-wise average
    of their trends, as ``(seasonal, trend)``.

    Raises:
        ValueError: if ``kernel_size`` is empty (there would be nothing to average).
    """

    def __init__(self, kernel_size):
        super(series_decomp_multi, self).__init__()
        self.kernel_size = kernel_size
        # nn.ModuleList (not a plain list) registers the sub-blocks as children,
        # so they appear in repr()/modules() and follow .to()/.train()/.apply().
        # They hold no parameters or buffers, so state_dict keys are unaffected.
        self.series_decomp = nn.ModuleList([series_decomp(kernel) for kernel in kernel_size])
        if len(self.series_decomp) == 0:
            raise ValueError('series_decomp_multi needs at least one kernel size')

    def forward(self, x):
        seasonal_parts = []
        trends = []
        for decomp in self.series_decomp:
            seasonal, trend = decomp(x)
            trends.append(trend)
            seasonal_parts.append(seasonal)

        sea = sum(seasonal_parts) / len(seasonal_parts)
        moving_mean = sum(trends) / len(trends)
        return sea, moving_mean
|
||
|
||
|
||
class EncoderLayer(nn.Module):
    """
    Autoformer encoder layer with the progressive decomposition architecture.

    The layer runs: attention, residual add, series decomposition, a pointwise
    (kernel-size-1 conv) feed-forward, residual add, series decomposition.
    Only the seasonal parts are kept; the trends returned by both
    decompositions are discarded.

    Returns ``(seasonal_output, attn)``, where ``attn`` is whatever the wrapped
    attention module returns as its second value.
    """

    def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        if not d_ff:
            d_ff = 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.gelu if activation != "relu" else F.relu

    def forward(self, x, attn_mask=None):
        attended, attn = self.attention(x, x, x, attn_mask=attn_mask)
        x, _ = self.decomp1(x + self.dropout(attended))

        # Pointwise feed-forward; Conv1d expects channels on dim 1.
        hidden = self.conv1(x.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))

        seasonal, _ = self.decomp2(x + hidden)
        return seasonal, attn
|
||
|
||
|
||
class Encoder(nn.Module):
    """
    Autoformer encoder.

    Runs the ``attn_layers`` in sequence. When ``conv_layers`` is given, each
    conv layer runs after the attention layer it is zipped with, and then the
    last attention layer runs once more on its own. ``conv_layers`` is therefore
    presumably meant to hold one element fewer than ``attn_layers``. An optional
    ``norm_layer`` is applied to the final output.

    Every attention call receives the same ``attn_mask``.

    Returns ``(x, attns)``, with one entry in ``attns`` per attention call.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        attns = []
        if self.conv_layers is not None:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            # The final attention layer must see the same mask as the others.
            x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns
|
||
|
||
|
||
class DecoderLayer(nn.Module):
    """
    Autoformer decoder layer with the progressive decomposition architecture.

    The layer runs three residual sub-steps: self attention, cross attention
    over ``cross``, and a pointwise feed-forward. Each is followed by a series
    decomposition. The three extracted trends are summed and mapped from
    ``d_model`` to ``c_out`` channels by a circular-padded width-3 conv.

    Returns ``(seasonal_output, residual_trend)``.
    """

    def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None,
                 moving_avg=25, dropout=0.1, activation="relu"):
        super(DecoderLayer, self).__init__()
        if not d_ff:
            d_ff = 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.decomp3 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1,
                                    padding_mode='circular', bias=False)
        self.activation = F.gelu if activation != "relu" else F.relu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        self_out = self.self_attention(x, x, x, attn_mask=x_mask)[0]
        x, trend1 = self.decomp1(x + self.dropout(self_out))

        cross_out = self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
        x, trend2 = self.decomp2(x + self.dropout(cross_out))

        # Pointwise feed-forward; Conv1d expects channels on dim 1.
        ff = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        ff = self.dropout(self.conv2(ff).transpose(-1, 1))
        x, trend3 = self.decomp3(x + ff)

        trend_sum = trend1 + trend2 + trend3
        residual_trend = self.projection(trend_sum.permute(0, 2, 1)).transpose(1, 2)
        return x, residual_trend
|
||
|
||
|
||
class Decoder(nn.Module):
    """
    Autoformer decoder.

    Runs ``layers`` in sequence. Each layer returns ``(x, residual_trend)``,
    and the residual trends are accumulated onto ``trend``. An optional
    ``norm_layer`` and ``projection`` are then applied to the seasonal output.

    ``trend=None`` means "start from zero": the accumulated trend is then just
    the sum of the layers' residual trends (None if there are no layers).

    Returns ``(x, trend)``.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        for layer in self.layers:
            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            # Avoid `None + Tensor` when no initial trend was supplied.
            trend = residual_trend if trend is None else trend + residual_trend

        if self.norm is not None:
            x = self.norm(x)

        if self.projection is not None:
            x = self.projection(x)
        return x, trend
|
||
|
||
|
||
class FEDformer(nn.Module):
    """
    FEDformer performs the attention mechanism on frequency domain and achieved O(N) complexity
    Paper link: https://proceedings.mlr.press/v162/zhou22g.html

    Example argument namespace:
    Namespace(task_name='long_term_forecast', is_training=1, model_id='ETTh1_96_96', model='FEDformer', data='ETTh1',
    root_path='./dataset/ETT-small/', data_path='ETTh1.csv', features='M', target='OT', freq='h', checkpoints='./checkpoints/',
    seq_len=96, label_len=48, pred_len=96, seasonal_patterns='Monthly', inverse=False, mask_rate=0.25, anomaly_ratio=0.25,
    expand=2, d_conv=4, top_k=5, num_kernels=6, enc_in=7, dec_in=7, c_out=7, d_model=16, n_heads=8, e_layers=2, d_layers=1,
    d_ff=32, moving_avg=25, factor=3, distil=True, dropout=0.1, embed='timeF', activation='gelu', output_attention=False,
    channel_independence=1, decomp_method='moving_avg', use_norm=1, down_sampling_layers=0, down_sampling_window=1, down_sampling_method=None,
    seg_len=48, num_workers=0, itr=1, train_epochs=100, batch_size=32, patience=3, learning_rate=0.0001, des="'Exp'", loss='MSE',
    lradj='type1', use_amp=False, use_gpu=True, gpu=0, use_multi_gpu=False, devices='0,1,2,3', p_hidden_dims=[128, 128],
    p_hidden_layers=2, use_dtw=False, augmentation_ratio=0, seed=2, jitter=False, scaling=False, permutation=False,
    randompermutation=False, magwarp=False, timewarp=False, windowslice=False, windowwarp=False, rotation=False,
    spawner=False, dtwwarp=False, shapedtwwarp=False, wdba=False, discdtw=False, discsdtw=False, extra_tag='')
    """

    def __init__(self,task_name='short_term_forecast',seq_len=96, label_len=48, pred_len=96, enc_in=7, dec_in=7, c_out=1, e_layers=2, d_layers=1, n_heads=8,factor=3,
                 d_model=16, d_ff=32, des='Exp', expand=2, d_conv=4, top_k=5, embed='timeF',freq='h', dropout=0.1,num_kernels=6,
                 moving_avg=25,channel_independence=1, decomp_method='moving_avg', use_norm=1,
                 version='fourier', mode_select='random', modes=32, activation='gelu',seasonal_patterns='Monthly',
                 inverse=False, mask_rate=0.25, anomaly_ratio=0.25,output_attention=False,down_sampling_layers=0, down_sampling_window=1, down_sampling_method=None,
                 seg_len=48, num_workers=0, itr=1, train_epochs=100, batch_size=32, patience=3, learning_rate=0.0001, loss='MSE',
                 lradj='type1', use_amp=False, use_gpu=True, gpu=0, use_multi_gpu=False, devices='0,1,2,3', p_hidden_dims=[128, 128],
                 p_hidden_layers=2, use_dtw=False, augmentation_ratio=0, seed=2, jitter=False, scaling=False, permutation=False,
                 randompermutation=False, magwarp=False, timewarp=False, windowslice=False, windowwarp=False, rotation=False,
                 spawner=False, dtwwarp=False, shapedtwwarp=False, wdba=False, discdtw=False, discsdtw=False, extra_tag='',
                 **kwargs):
        """
        version: str, for FEDformer, there are two versions to choose, options: [Fourier, Wavelets].
            Only the exact string 'Wavelets' selects the multi-wavelet blocks; any
            other value (including the default 'fourier') selects the Fourier blocks.
        mode_select: str, for FEDformer, there are two mode selection method, options: [random, low].
        modes: int, modes to be selected.

        Many other keyword arguments (e.g. factor, des, expand, learning_rate)
        mirror a run configuration and are accepted but not read here.
        """
        super(FEDformer, self).__init__()
        self.task_name = task_name
        self.seq_len = seq_len
        self.label_len = label_len
        self.pred_len = pred_len

        self.version = version
        self.mode_select = mode_select
        self.modes = modes

        # Decomp
        self.decomp = series_decomp(moving_avg)
        self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout)
        self.dec_embedding = DataEmbedding(dec_in, d_model, embed, freq, dropout)

        if self.version == 'Wavelets':
            encoder_self_att = MultiWaveletTransform(ich=d_model, L=1, base='legendre')
            decoder_self_att = MultiWaveletTransform(ich=d_model, L=1, base='legendre')
            decoder_cross_att = MultiWaveletCross(in_channels=d_model,
                                                  out_channels=d_model,
                                                  seq_len_q=self.seq_len // 2 + self.pred_len,
                                                  seq_len_kv=self.seq_len,
                                                  modes=self.modes,
                                                  ich=d_model,
                                                  base='legendre',
                                                  activation='tanh')
        else:
            encoder_self_att = FourierBlock(in_channels=d_model,
                                            out_channels=d_model,
                                            seq_len=self.seq_len,
                                            modes=self.modes,
                                            mode_select_method=self.mode_select)
            decoder_self_att = FourierBlock(in_channels=d_model,
                                            out_channels=d_model,
                                            seq_len=self.seq_len // 2 + self.pred_len,
                                            modes=self.modes,
                                            mode_select_method=self.mode_select)
            decoder_cross_att = FourierCrossAttention(in_channels=d_model,
                                                      out_channels=d_model,
                                                      seq_len_q=self.seq_len // 2 + self.pred_len,
                                                      seq_len_kv=self.seq_len,
                                                      modes=self.modes,
                                                      mode_select_method=self.mode_select,
                                                      num_heads=n_heads)
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AutoCorrelationLayer(
                        encoder_self_att,  # instead of multi-head attention in transformer
                        d_model, n_heads),
                    d_model,
                    d_ff,
                    moving_avg=moving_avg,
                    dropout=dropout,
                    activation=activation
                ) for l in range(e_layers)
            ],
            norm_layer=my_Layernorm(d_model)
        )
        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AutoCorrelationLayer(
                        decoder_self_att,
                        d_model, n_heads),
                    AutoCorrelationLayer(
                        decoder_cross_att,
                        d_model, n_heads),
                    d_model,
                    c_out,
                    d_ff,
                    moving_avg=moving_avg,
                    dropout=dropout,
                    activation=activation,
                )
                for l in range(d_layers)
            ],
            norm_layer=my_Layernorm(d_model),
            projection=nn.Linear(d_model, c_out, bias=True)
        )
        # Maps the flattened last pred_len steps to pred_len * c_out outputs.
        self.projection_final=nn.Linear(pred_len*enc_in, pred_len*c_out, bias=True)

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        """Forecast from the encoder window.

        Args:
            x_enc: encoder input; assumed (batch, seq_len, enc_in). TODO confirm.
            x_mark_enc, x_mark_dec: time features passed to the encoder and
                decoder embeddings (may be None).
            x_dec: accepted for interface compatibility but not read; the decoder
                input is derived from the decomposition of ``x_enc``. It may be None.
            mask: unused.

        Returns:
            Tensor of shape (batch, pred_len * c_out), the output of ``projection_final``.
        """
        # ----------------------------- Step 1: decomposition -----------------
        seasonal_init, trend_init = self.decomp(x_enc)  # x - moving_avg, moving_avg
        # decoder input
        if self.label_len != 0:
            # Trend: last label_len trend steps, then the encoder mean repeated
            # pred_len times. Seasonal: last label_len steps, zero-padded by pred_len.
            mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1)
            trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)
            seasonal_init = F.pad(seasonal_init[:, -self.label_len:, :], (0, 0, 0, self.pred_len))
        # label_len == 0: the full-length decomposition of x_enc is used unchanged.

        # enc
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        # dec
        dec_out = self.dec_embedding(seasonal_init, x_mark_dec)
        seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None, trend=trend_init)
        # final
        dec_out = trend_part + seasonal_part
        dec_out = self.projection_final(dec_out[:, -self.pred_len:, :].reshape(dec_out.shape[0], -1))
        return dec_out
|
||
|
||
class FEDFormerNetModel(BaseModel):
    """pytorch_forecasting ``BaseModel`` wrapper around ``FEDformer``.

    ``forward`` builds the encoder/decoder inputs from the batch dict, runs the
    network, and rescales the prediction into target space.
    """

    def __init__(self,seq_len=24, label_len=0, pred_len=1, enc_in=7, dec_in=7, c_out=1, e_layers=2,
                 d_layers=1, factor=3, d_model=16, d_ff=32, des='Exp', itr=1, top_k=5,embed='timeF',freq='h', dropout=0.1,num_kernels=6, **kwargs):

        # saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this
        self.save_hyperparameters()
        # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this
        super().__init__(**kwargs)
        self.network = FEDformer(
            seq_len=seq_len, label_len=label_len, pred_len=pred_len, enc_in=enc_in, dec_in=dec_in, c_out=c_out, e_layers=e_layers, d_layers=d_layers, factor=factor,
            d_model=d_model, d_ff=d_ff, des=des, itr=itr, top_k=top_k, embed=embed, freq=freq, dropout=dropout, num_kernels=num_kernels
        )
        self.label_len=label_len

    # Modified for lithium-battery prediction
    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the network on a pytorch_forecasting batch dict.

        Reads ``x["encoder_cont"]``, ``x["decoder_cont"]`` and ``x["target_scale"]``.
        The last feature column of the continuous inputs is dropped; presumably
        this is the target. Confirm against the dataset definition.
        """
        encoder_cont = x["encoder_cont"]
        x_enc = encoder_cont[:, :, :-1]  # e.g. torch.Size([100, 10, 9])

        # Last label_len encoder steps. Slicing with `-self.label_len:` is wrong
        # for label_len == 0, because `-0:` selects the whole window. The
        # max(0, ...) keeps the old clamp-to-whole-window behaviour when
        # label_len exceeds the encoder length.
        label_start = max(0, encoder_cont.shape[1] - self.label_len)
        x_dec = torch.cat([encoder_cont[:, label_start:, :-1], x["decoder_cont"][:, :, :-1]],
                          dim=1)  # (batch, label_len + decoder steps, features)

        # output
        prediction = self.network(x_enc=x_enc,x_mark_enc=None,x_dec=x_dec,x_mark_dec=None)

        # rescale predictions into target space
        prediction = self.transform_output(prediction, target_scale=x["target_scale"])

        # return a dict-like network output containing the prediction
        return self.to_network_output(prediction=prediction)
|
||
|
||
|
||
|
||
if __name__ == '__main__':
    # Smoke test: batch N, encoder length L, C input channels.
    N, L, C = 100, 8, 11
    label_len = 0
    pred_len = 1
    x_enc = torch.ones((N, L, C))
    x_mark_enc = torch.ones((N, L, 4))
    x_mark_dec = torch.ones((N, L + label_len, 4))
    # Placeholder decoder input shaped (batch, label_len + pred_len, C).
    # FEDformer.forward may read x_dec's shape, so pass a tensor rather than None.
    x_dec = torch.ones((N, label_len + pred_len, C))
    model = FEDformer(seq_len=L, enc_in=C, dec_in=C, label_len=label_len, pred_len=pred_len, c_out=1)  # pred_len is restricted (author's note)
    out = model(x_enc=x_enc, x_mark_enc=x_mark_enc, x_dec=x_dec, x_mark_dec=x_mark_dec)
    print(out.shape)
|
||
|
||
|
||
|
||
|
||
|