"""TimeMixer forecasting network and the Autoformer-style layers it depends on.

The layers that would normally be imported from ``.layers.Autoformer_EncDec``,
``.layers.Embed`` and ``.layers.StandardNorm`` are defined inline in this file.
``TimeMixerNetModel`` adapts ``TimeMixer`` to pytorch-forecasting and is only
defined when that package can be imported.
"""
import math
import re
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm

try:
    from pytorch_forecasting.models import BaseModel
except ImportError:  # optional: only TimeMixerNetModel needs pytorch-forecasting
    BaseModel = None


def _circular_conv_padding():
    """Return the padding used for kernel-3 circular ``Conv1d`` layers in this file.

    The original code chose 1 for PyTorch >= 1.5 and 2 for older releases
    (presumably because of older circular-padding semantics). It compared
    version *strings*, which is wrong for e.g. ``'1.10.0' >= '1.5.0'`` (False),
    so the version is parsed numerically here.
    """
    match = re.match(r'(\d+)\.(\d+)', torch.__version__)
    if match is None:  # unparseable version string: assume a modern release
        return 1
    major, minor = int(match.group(1)), int(match.group(2))
    return 1 if (major, minor) >= (1, 5) else 2


class Normalize(nn.Module):
    """Reversible instance normalization (RevIN).

    ``forward(x, 'norm')`` stores per-sample statistics (reduced over every axis
    except the first and the last) and normalizes ``x``; ``forward(x, 'denorm')``
    inverts the transform using the statistics from the latest ``'norm'`` call.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
        """
        :param num_features: the number of features or channels (size of the last axis)
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, centre on the last time step instead of the mean
        :param non_norm: if True, normalizing and denormalizing return the input unchanged
        """
        super(Normalize, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.non_norm = non_norm
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """Normalize (``mode='norm'``) or denormalize (``mode='denorm'``) ``x``.

        :raises NotImplementedError: for any other ``mode``.
        """
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        # Reduce over every axis between the batch axis (first) and the feature axis (last).
        dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        if self.non_norm:
            return x
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.non_norm:
            return x
        if self.affine:
            x = x - self.affine_bias
            # eps**2 keeps the division finite if an affine weight reaches zero.
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x


class my_Layernorm(nn.Module):
    """
    Special designed layernorm for the seasonal part
    """

    def __init__(self, channels):
        super(my_Layernorm, self).__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):
        x_hat = self.layernorm(x)
        # Remove the mean over axis 1 so the output is zero-centred along that axis.
        bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)
        return x_hat - bias


class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # Pad both ends of axis 1 by repeating the first/last step so the
        # (odd-kernel, stride-1) average keeps the original length.
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class series_decomp(nn.Module):
    """
    Series decomposition block: returns (residual, moving average).
    """

    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean


class series_decomp_multi(nn.Module):
    """
    Multiple Series decomposition block from FEDformer.

    Averages the seasonal parts and the trend parts of one ``series_decomp``
    per kernel size in ``kernel_size``.
    """

    def __init__(self, kernel_size):
        super(series_decomp_multi, self).__init__()
        self.kernel_size = kernel_size
        # ModuleList (not a plain list) so the sub-blocks are registered submodules.
        self.series_decomp = nn.ModuleList([series_decomp(kernel) for kernel in kernel_size])

    def forward(self, x):
        moving_means = []
        residuals = []
        for func in self.series_decomp:
            sea, trend = func(x)
            moving_means.append(trend)
            residuals.append(sea)
        sea = sum(residuals) / len(residuals)
        moving_mean = sum(moving_means) / len(moving_means)
        return sea, moving_mean


class EncoderLayer(nn.Module):
    """
    Autoformer encoder layer with the progressive decomposition architecture.

    ``attention`` is called as ``attention(q, k, v, attn_mask=...)`` and must
    return an ``(output, attn)`` pair.
    """

    def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        new_x, attn = self.attention(
            x, x, x,
            attn_mask=attn_mask
        )
        x = x + self.dropout(new_x)
        x, _ = self.decomp1(x)
        y = x
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        res, _ = self.decomp2(x + y)
        return res, attn


class Encoder(nn.Module):
    """
    Autoformer encoder
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        attns = []
        if self.conv_layers is not None:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns


class DecoderLayer(nn.Module):
    """
    Autoformer decoder layer with the progressive decomposition architecture.

    Returns the seasonal output and the sum of the three extracted trends,
    projected from ``d_model`` to ``c_out`` channels.
    """

    def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None,
                 moving_avg=25, dropout=0.1, activation="relu"):
        super(DecoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.decomp3 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1,
                                    padding_mode='circular', bias=False)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        x = x + self.dropout(self.self_attention(
            x, x, x,
            attn_mask=x_mask
        )[0])
        x, trend1 = self.decomp1(x)
        x = x + self.dropout(self.cross_attention(
            x, cross, cross,
            attn_mask=cross_mask
        )[0])
        x, trend2 = self.decomp2(x)
        y = x
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        x, trend3 = self.decomp3(x + y)

        residual_trend = trend1 + trend2 + trend3
        residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2)
        return x, residual_trend


class Decoder(nn.Module):
    """
    Autoformer decoder
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        """Run all decoder layers, accumulating their residual trends.

        When ``trend`` is None (the default) accumulation starts from the first
        layer's residual trend instead of failing on ``None + tensor``.
        """
        for layer in self.layers:
            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            trend = residual_trend if trend is None else trend + residual_trend

        if self.norm is not None:
            x = self.norm(x)

        if self.projection is not None:
            x = self.projection(x)
        return x, trend


class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding, stored as a non-trainable buffer."""

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Only the length of axis 1 of x is used.
        return self.pe[:, :x.size(1)]


class TokenEmbedding(nn.Module):
    """Embed each time step's ``c_in`` values with a kernel-3 circular Conv1d."""

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        padding = _circular_conv_padding()
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        # Conv1d wants channels on axis 1, so swap the last two axes around the conv.
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x


class FixedEmbedding(nn.Module):
    """Embedding table frozen to sinusoidal values (not trained)."""

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        w = torch.zeros(c_in, d_model).float()
        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        w[:, 1::2] = torch.cos(position * div_term)

        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        return self.emb(x).detach()


class TemporalEmbedding(nn.Module):
    """Sum of embeddings of integer calendar features.

    Reads feature index 0 (month), 1 (day), 2 (weekday), 3 (hour) and, when
    ``freq == 't'``, 4 (minute) from the last axis of ``x``.
    """

    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super(TemporalEmbedding, self).__init__()

        minute_size = 4
        hour_size = 24
        weekday_size = 7
        day_size = 32
        month_size = 13

        Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
        if freq == 't':
            self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def forward(self, x):
        x = x.long()
        minute_x = self.minute_embed(x[:, :, 4]) if hasattr(
            self, 'minute_embed') else 0.
        hour_x = self.hour_embed(x[:, :, 3])
        weekday_x = self.weekday_embed(x[:, :, 2])
        day_x = self.day_embed(x[:, :, 1])
        month_x = self.month_embed(x[:, :, 0])

        return hour_x + weekday_x + day_x + month_x + minute_x


class TimeFeatureEmbedding(nn.Module):
    """Linear embedding of continuous time features; input width depends on ``freq``."""

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TimeFeatureEmbedding, self).__init__()

        freq_map = {'h': 4, 't': 5, 's': 6,
                    'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        d_inp = freq_map[freq]
        self.embed = nn.Linear(d_inp, d_model, bias=False)

    def forward(self, x):
        return self.embed(x)


class DataEmbedding(nn.Module):
    """Value + positional (+ temporal, when marks are given) embedding with dropout."""

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
                                                    freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
            d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        if x_mark is None:
            x = self.value_embedding(x) + self.position_embedding(x)
        else:
            x = self.value_embedding(
                x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
        return self.dropout(x)


class DataEmbedding_inverted(nn.Module):
    """Embed each variate's whole series (time axis) into ``d_model`` with one Linear."""

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding_inverted, self).__init__()
        self.value_embedding = nn.Linear(c_in, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        x = x.permute(0, 2, 1)
        # x: [Batch Variate Time]
        if x_mark is None:
            x = self.value_embedding(x)
        else:
            # Time-feature series are appended as extra variates.
            x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
        # x: [Batch Variate d_model]
        return self.dropout(x)


class DataEmbedding_wo_pos(nn.Module):
    """Value (+ temporal, when marks are given) embedding without positional encoding.

    ``position_embedding`` is still constructed (its buffer is part of the
    state_dict) but is not used in ``forward``.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding_wo_pos, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
                                                    freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
            d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        if x_mark is None:
            x = self.value_embedding(x)
        else:
            x = self.value_embedding(x) + self.temporal_embedding(x_mark)
        return self.dropout(x)


class PatchEmbedding(nn.Module):
    """Split each series into (possibly overlapping) patches and embed every patch."""

    def __init__(self, d_model, patch_len, stride, padding, dropout):
        super(PatchEmbedding, self).__init__()
        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))

        # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space
        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)

        # Positional embedding
        self.position_embedding = PositionalEmbedding(d_model)

        # Residual dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # do patching; assumes x is [batch, n_vars, length] — TODO confirm with callers
        n_vars = x.shape[1]
        x = self.padding_patch_layer(x)
        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
        # Input encoding
        x = self.value_embedding(x) + self.position_embedding(x)
        return self.dropout(x), n_vars


# Season/trend decomposition based on the Fourier transform: the top-k
# frequency components form the seasonal part, everything else the trend.
class DFT_series_decomp(nn.Module):
    """
    Series decomposition block based on the DFT.

    The FFT runs over the LAST axis of ``x``. For every 1-D slice along that axis
    the ``top_k`` largest-magnitude non-zero-frequency components are kept as the
    seasonal part; the zero-frequency (mean) component and all remaining
    frequencies end up in the trend (``x - season``).

    NOTE(review): inside PastDecomposableMixing the last axis is the d_model
    embedding axis, not time — confirm that this is intended.
    """

    def __init__(self, top_k=5):
        super(DFT_series_decomp, self).__init__()
        self.top_k = top_k

    def forward(self, x):
        xf = torch.fft.rfft(x)
        freq = abs(xf)
        # Exclude the zero-frequency bin of every series from selection. (The
        # original ``freq[0] = 0`` zeroed the whole first batch sample instead.)
        freq[..., 0] = 0
        # topk fails if k exceeds the number of frequency bins.
        k = min(self.top_k, freq.shape[-1])
        top_k_freq, _ = torch.topk(freq, k)  # sorted descending per slice
        # Per-slice threshold so one sample never influences another's split.
        keep = freq >= top_k_freq[..., -1:]
        keep[..., 0] = False
        xf[~keep] = 0
        # ``n`` restores the original length; without it odd lengths shrink by one.
        x_season = torch.fft.irfft(xf, n=x.shape[-1])
        x_trend = x - x_season
        return x_season, x_trend


class MultiScaleSeasonMixing(nn.Module):
    """
    Bottom-up mixing season pattern.

    Each scale i is mapped to the next coarser scale by a Linear-GELU-Linear
    block acting on the last (time) axis and added to that scale's own series.
    """

    def __init__(self, seq_len, down_sampling_window=2, down_sampling_layers=3):
        super(MultiScaleSeasonMixing, self).__init__()

        self.down_sampling_layers = torch.nn.ModuleList(
            [
                nn.Sequential(
                    torch.nn.Linear(
                        seq_len // (down_sampling_window ** i),
                        seq_len // (down_sampling_window ** (i + 1)),
                    ),
                    nn.GELU(),
                    torch.nn.Linear(
                        seq_len // (down_sampling_window ** (i + 1)),
                        seq_len // (down_sampling_window ** (i + 1)),
                    ),

                )
                for i in range(down_sampling_layers)
            ]
        )

    def forward(self, season_list):
        """Mix from fine (high resolution) to coarse; outputs have time on axis 1."""
        out_high = season_list[0]
        out_low = season_list[1]
        out_season_list = [out_high.permute(0, 2, 1)]

        for i in range(len(season_list) - 1):
            out_low_res = self.down_sampling_layers[i](out_high)
            out_low = out_low + out_low_res
            out_high = out_low
            if i + 2 <= len(season_list) - 1:
                out_low = season_list[i + 2]
            out_season_list.append(out_high.permute(0, 2, 1))

        return out_season_list


class MultiScaleTrendMixing(nn.Module):
    """
    Top-down mixing trend pattern.

    Each scale is mapped to the next finer scale by a Linear-GELU-Linear block
    acting on the last (time) axis and added to that scale's own series.
    """

    def __init__(self, seq_len, down_sampling_window=2, down_sampling_layers=3):
        super(MultiScaleTrendMixing, self).__init__()

        self.up_sampling_layers = torch.nn.ModuleList(
            [
                nn.Sequential(
                    torch.nn.Linear(
                        seq_len // (down_sampling_window ** (i + 1)),
                        seq_len // (down_sampling_window ** i),
                    ),
                    nn.GELU(),
                    torch.nn.Linear(
                        seq_len // (down_sampling_window ** i),
                        seq_len // (down_sampling_window ** i),
                    ),
                )
                for i in reversed(range(down_sampling_layers))
            ])

    def forward(self, trend_list):
        """Mix from coarse to fine; outputs are returned fine-first with time on axis 1."""
        trend_list_reverse = trend_list.copy()
        trend_list_reverse.reverse()
        out_low = trend_list_reverse[0]
        out_high = trend_list_reverse[1]
        out_trend_list = [out_low.permute(0, 2, 1)]

        for i in range(len(trend_list_reverse) - 1):
            out_high_res = self.up_sampling_layers[i](out_low)
            out_high = out_high + out_high_res
            out_low = out_high
            if i + 2 <= len(trend_list_reverse) - 1:
                out_high = trend_list_reverse[i + 2]
            out_trend_list.append(out_low.permute(0, 2, 1))

        out_trend_list.reverse()
        return out_trend_list


class PastDecomposableMixing(nn.Module):
    """TimeMixer PDM block: decompose every scale, mix seasons bottom-up and trends top-down.

    ``decomp_method`` is ``'moving_avg'`` or ``'dft_decomp'``; anything else
    raises ``ValueError``.
    """

    def __init__(self, seq_len=96, pred_len=1, d_model=16, d_ff=32, top_k=5, dropout=0.1, moving_avg=25,
                 channel_independence=1, decomp_method='moving_avg', down_sampling_layers=3, down_sampling_window=2):
        super(PastDecomposableMixing, self).__init__()
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.down_sampling_window = down_sampling_window

        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.channel_independence = channel_independence

        if decomp_method == 'moving_avg':
            self.decompsition = series_decomp(moving_avg)
        elif decomp_method == "dft_decomp":
            self.decompsition = DFT_series_decomp(top_k)
        else:
            raise ValueError('decompsition is error')

        if channel_independence == 0:
            self.cross_layer = nn.Sequential(
                nn.Linear(in_features=d_model, out_features=d_ff),
                nn.GELU(),
                nn.Linear(in_features=d_ff, out_features=d_model),
            )

        # Mixing season
        self.mixing_multi_scale_season = MultiScaleSeasonMixing(seq_len=seq_len,
                                                                down_sampling_window=down_sampling_window,
                                                                down_sampling_layers=down_sampling_layers)

        # Mixing trend
        self.mixing_multi_scale_trend = MultiScaleTrendMixing(seq_len=seq_len,
                                                              down_sampling_window=down_sampling_window,
                                                              down_sampling_layers=down_sampling_layers)

        self.out_cross_layer = nn.Sequential(
            nn.Linear(in_features=d_model, out_features=d_ff),
            nn.GELU(),
            nn.Linear(in_features=d_ff, out_features=d_model),
        )

    def forward(self, x_list):
        """Mix a list of per-scale tensors laid out [batch, time, d_model]; returns the same layout."""
        length_list = []
        for x in x_list:
            _, T, _ = x.size()
            length_list.append(T)

        # Decompose to obtain the season and trend
        season_list = []
        trend_list = []
        for x in x_list:
            season, trend = self.decompsition(x)
            if self.channel_independence == 0:
                season = self.cross_layer(season)
                trend = self.cross_layer(trend)
            # The mixing blocks act on the last axis, so move time there.
            season_list.append(season.permute(0, 2, 1))
            trend_list.append(trend.permute(0, 2, 1))

        # bottom-up season mixing
        out_season_list = self.mixing_multi_scale_season(season_list)
        # top-down trend mixing
        out_trend_list = self.mixing_multi_scale_trend(trend_list)

        out_list = []
        for ori, out_season, out_trend, length in zip(x_list, out_season_list, out_trend_list,
                                                      length_list):
            out = out_season + out_trend
            if self.channel_independence:
                out = ori + self.out_cross_layer(out)
            out_list.append(out[:, :length, :])
        return out_list


class TimeMixer(nn.Module):
    """TimeMixer: multi-scale decomposable mixing for time-series forecasting.

    ``forward(x_enc, x_mark_enc)`` takes ``x_enc`` laid out [batch, seq_len,
    enc_in] and optional time marks ``x_mark_enc`` [batch, seq_len, n_marks]
    and returns a [batch, pred_len * c_out] tensor produced by
    ``projection_final``. Only the forecasting ``task_name`` values build the
    prediction heads. Many constructor arguments are accepted for configuration
    compatibility but are not used by this module.
    """

    def __init__(self, task_name='short_term_forecast', seq_len=96, label_len=0, pred_len=96, enc_in=7, dec_in=7,
                 c_out=1, e_layers=2, d_layers=1, n_heads=8, factor=3, d_model=16, d_ff=32, des='Exp', expand=2,
                 d_conv=4, top_k=5, embed='timeF', freq='h', dropout=0.1, num_kernels=6, moving_avg=25,
                 channel_independence=1, decomp_method='moving_avg', use_norm=1, down_sampling_layers=3,
                 down_sampling_window=2, down_sampling_method='avg', version='fourier', mode_select='random',
                 modes=32, activation='gelu', seasonal_patterns='Monthly', inverse=False, mask_rate=0.25,
                 anomaly_ratio=0.25, output_attention=False, seg_len=48, num_workers=0, itr=1, train_epochs=100,
                 batch_size=32, patience=10, learning_rate=0.0001, loss='MSE', lradj='type1', use_amp=False,
                 use_gpu=True, gpu=0, use_multi_gpu=False, devices='0,1,2,3', p_hidden_dims=(128, 128),
                 p_hidden_layers=2, use_dtw=False, augmentation_ratio=0, seed=2, jitter=False, scaling=False,
                 permutation=False, randompermutation=False, magwarp=False, timewarp=False, windowslice=False,
                 windowwarp=False, rotation=False, spawner=False, dtwwarp=False, shapedtwwarp=False, wdba=False,
                 discdtw=False, discsdtw=False, extra_tag='', **kwargs):
        """Build the model.

        :raises ValueError: if ``down_sampling_method`` is not 'max', 'avg' or 'conv'.
        """
        super(TimeMixer, self).__init__()
        self.task_name = task_name
        self.seq_len = seq_len
        self.label_len = label_len
        self.pred_len = pred_len
        self.down_sampling_window = down_sampling_window
        self.down_sampling_layers = down_sampling_layers
        self.channel_independence = channel_independence
        self.c_out = c_out
        self.pdm_blocks = nn.ModuleList([PastDecomposableMixing(seq_len=seq_len, pred_len=pred_len, d_model=d_model,
                                                                d_ff=d_ff, top_k=top_k, dropout=dropout,
                                                                moving_avg=moving_avg,
                                                                channel_independence=channel_independence,
                                                                decomp_method=decomp_method,
                                                                down_sampling_layers=down_sampling_layers,
                                                                down_sampling_window=down_sampling_window)
                                         for _ in range(e_layers)])
        self.down_sampling_method = down_sampling_method
        # Built once and registered so a 'conv' down-sampler is trainable, saved
        # with the model and moved with .to(device); it used to be re-created
        # with fresh random weights on every forward call.
        self.down_pool = self._build_down_pool(down_sampling_method, enc_in, down_sampling_window)

        self.preprocess = series_decomp(moving_avg)
        self.enc_in = enc_in

        if self.channel_independence == 1:
            self.enc_embedding = DataEmbedding_wo_pos(1, d_model, embed, freq, dropout)
        else:
            self.enc_embedding = DataEmbedding_wo_pos(enc_in, d_model, embed, freq, dropout)

        self.layer = e_layers

        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.predict_layers = torch.nn.ModuleList(
                [
                    torch.nn.Linear(
                        seq_len // (down_sampling_window ** i),
                        pred_len,
                    )
                    for i in range(down_sampling_layers + 1)
                ]
            )

            if self.channel_independence == 1:
                self.projection_layer = nn.Linear(
                    d_model, 1, bias=True)
            else:
                self.projection_layer = nn.Linear(
                    d_model, c_out, bias=True)

                # Only the channel_independence == 0 path (out_projection) uses these.
                self.out_res_layers = torch.nn.ModuleList([
                    torch.nn.Linear(
                        seq_len // (down_sampling_window ** i),
                        seq_len // (down_sampling_window ** i),
                    )
                    for i in range(down_sampling_layers + 1)
                ])

                self.regression_layers = torch.nn.ModuleList(
                    [
                        torch.nn.Linear(
                            seq_len // (down_sampling_window ** i),
                            pred_len,
                        )
                        for i in range(down_sampling_layers + 1)
                    ]
                )

        self.normalize_layers = torch.nn.ModuleList(
            [
                Normalize(self.enc_in, affine=True, non_norm=True if use_norm == 0 else False)
                for i in range(down_sampling_layers + 1)
            ]
        )
        self.projection_final = nn.Linear(pred_len * enc_in, pred_len * c_out, bias=True)

    @staticmethod
    def _build_down_pool(method, enc_in, window):
        """Return the module that halves (by ``window``) the time axis of a [B, C, T] tensor."""
        if method == 'max':
            return torch.nn.MaxPool1d(window, return_indices=False)
        if method == 'avg':
            return torch.nn.AvgPool1d(window)
        if method == 'conv':
            return nn.Conv1d(in_channels=enc_in, out_channels=enc_in,
                             kernel_size=3, padding=_circular_conv_padding(),
                             stride=window,
                             padding_mode='circular',
                             bias=False)
        raise ValueError(f"Unknown down_sampling_method: {method!r} (expected 'max', 'avg' or 'conv')")

    def out_projection(self, dec_out, i, out_res):
        """Project ``dec_out`` to output channels and add the regressed trend residual of scale ``i``."""
        dec_out = self.projection_layer(dec_out)
        out_res = out_res.permute(0, 2, 1)
        out_res = self.out_res_layers[i](out_res)
        out_res = self.regression_layers[i](out_res).permute(0, 2, 1)
        dec_out = dec_out + out_res
        return dec_out

    def pre_enc(self, x_list):
        """Return ``(inputs, None)`` or, without channel independence, ``(seasonal parts, trend parts)``."""
        if self.channel_independence == 1:
            return (x_list, None)
        else:
            out1_list = []
            out2_list = []
            for x in x_list:
                x_1, x_2 = self.preprocess(x)
                out1_list.append(x_1)
                out2_list.append(x_2)
            return (out1_list, out2_list)

    def __multi_scale_process_inputs(self, x_enc, x_mark_enc):
        """Build ``down_sampling_layers + 1`` progressively down-sampled copies of the inputs.

        Returns a list of [B, T_i, C] tensors and either a matching list of time
        marks (sub-sampled with stride ``down_sampling_window``) or None.
        """
        # B,T,C -> B,C,T
        x_enc = x_enc.permute(0, 2, 1)

        x_enc_ori = x_enc
        x_mark_enc_mark_ori = x_mark_enc

        x_enc_sampling_list = [x_enc.permute(0, 2, 1)]
        x_mark_sampling_list = [x_mark_enc]

        for _ in range(self.down_sampling_layers):
            x_enc_sampling = self.down_pool(x_enc_ori)

            x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1))
            x_enc_ori = x_enc_sampling

            if x_mark_enc is not None:
                x_mark_enc_mark_ori = x_mark_enc_mark_ori[:, ::self.down_sampling_window, :]
                x_mark_sampling_list.append(x_mark_enc_mark_ori)

        x_mark_enc = x_mark_sampling_list if x_mark_enc is not None else None
        return x_enc_sampling_list, x_mark_enc

    def forward(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None):
        """Forecast from ``x_enc`` [B, seq_len, enc_in]; ``x_dec``/``x_mark_dec`` are unused."""
        # Build the multi-scale inputs, e.g. [100, 96, 10] -> list of [100, 96/48/24/12, 10].
        x_enc, x_mark_enc = self.__multi_scale_process_inputs(x_enc, x_mark_enc)

        x_list = []  # normalized (and, if channel-independent, flattened) inputs per scale
        x_mark_list = []
        if x_mark_enc is not None:
            for i, (x, x_mark) in enumerate(zip(x_enc, x_mark_enc)):
                B, T, N = x.size()
                x = self.normalize_layers[i](x, 'norm')
                if self.channel_independence == 1:
                    # (B*N, T, 1): row b*N + n holds variate n of sample b.
                    x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
                    # Repeat each sample's marks N times consecutively so row b*N + n
                    # gets sample b's marks (``repeat`` tiled the batch and misaligned them).
                    x_mark = x_mark.repeat_interleave(N, dim=0)
                x_list.append(x)
                x_mark_list.append(x_mark)
        else:
            for i, x in enumerate(x_enc):
                B, T, N = x.size()
                x = self.normalize_layers[i](x, 'norm')
                if self.channel_independence == 1:
                    x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
                x_list.append(x)

        # Embedding (after the optional seasonal/trend pre-decomposition).
        enc_out_list = []
        x_list = self.pre_enc(x_list)
        if x_mark_enc is not None:
            for x, x_mark in zip(x_list[0], x_mark_list):
                enc_out = self.enc_embedding(x, x_mark)  # [B,T,C] e.g. [1000,96,1] -> [1000,96,d_model]
                enc_out_list.append(enc_out)
        else:
            for x in x_list[0]:
                enc_out = self.enc_embedding(x, None)  # [B,T,C]
                enc_out_list.append(enc_out)

        # Past Decomposable Mixing as encoder for past
        for i in range(self.layer):
            enc_out_list = self.pdm_blocks[i](enc_out_list)

        # Future Multipredictor Mixing as decoder for future: one [B, pred_len, enc_in] per scale.
        dec_out_list = self.future_multi_mixing(B, enc_out_list, x_list)

        # Sum the per-scale predictions, undo the normalization, then map enc_in -> c_out.
        dec_out = torch.stack(dec_out_list, dim=-1).sum(-1)
        dec_out = self.normalize_layers[0](dec_out, 'denorm')
        dec_out = self.projection_final(dec_out.reshape(dec_out.shape[0], -1))
        return dec_out

    def future_multi_mixing(self, B, enc_out_list, x_list):
        """Predict ``pred_len`` steps from every scale's encoder output."""
        dec_out_list = []
        if self.channel_independence == 1:
            x_list = x_list[0]
            for i, enc_out in zip(range(len(x_list)), enc_out_list):
                dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
                    0, 2, 1)  # align temporal dimension
                dec_out = self.projection_layer(dec_out)
                # Undo the (B*N) flattening: back to [B, pred_len, enc_in].
                dec_out = dec_out.reshape(B, self.enc_in, self.pred_len).permute(0, 2, 1).contiguous()
                dec_out_list.append(dec_out)
        else:
            for i, enc_out, out_res in zip(range(len(x_list[0])), enc_out_list, x_list[1]):
                dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
                    0, 2, 1)  # align temporal dimension
                dec_out = self.out_projection(dec_out, i, out_res)
                dec_out_list.append(dec_out)

        return dec_out_list


if BaseModel is not None:

    class TimeMixerNetModel(BaseModel):
        """pytorch-forecasting wrapper around ``TimeMixer``."""

        def __init__(self, seq_len=24, label_len=0, pred_len=1, enc_in=7, c_out=1, e_layers=2, d_layers=1,
                     factor=3, d_model=16, d_ff=32, des='Exp', itr=1, top_k=5, embed='timeF', freq='h',
                     dropout=0.1, num_kernels=6, **kwargs):
            # saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this
            self.save_hyperparameters()
            # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this
            super().__init__(**kwargs)
            self.network = TimeMixer(
                seq_len=seq_len,
                label_len=label_len,
                pred_len=pred_len,
                enc_in=enc_in,
                c_out=c_out,
                e_layers=e_layers,
                d_layers=d_layers,
                factor=factor,
                d_model=d_model,
                d_ff=d_ff,
                des=des,
                itr=itr,
                top_k=top_k,
                embed=embed,
                freq=freq,
                dropout=dropout,
                num_kernels=num_kernels
            )

        # Modified for lithium-battery prediction.
        def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
            # The last continuous encoder column is dropped — presumably the
            # target itself; confirm against the dataset configuration.
            x_enc = x["encoder_cont"][:, :, :-1]
            # Output
            prediction = self.network(x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None)
            # Rescale predictions into target space.
            prediction = self.transform_output(prediction, target_scale=x["target_scale"])
            # Return a dict containing the output (prediction).
            return self.to_network_output(prediction=prediction)


if __name__ == '__main__':
    N, L, C = 100, 96, 10
    label_len = 0
    c_out = 1
    pred_len = 1
    x_enc = torch.ones((N, L, C))
    x_mark_enc = torch.ones((N, L, 4))
    x_dec = torch.ones((N, pred_len, C))
    x_mark_dec = torch.ones((N, pred_len, 4))
    model = TimeMixer(seq_len=L, enc_in=C, dec_in=C, label_len=label_len, pred_len=pred_len,
                      c_out=c_out)  # pred_len is constrained
    out = model(x_enc=x_enc, x_mark_enc=x_mark_enc, x_dec=x_dec, x_mark_dec=x_mark_dec)
    print(out.shape)