# code from https://github.com/yuqinie98/PatchTST, with minor modifications
import torch
from torch import Tensor
from .base_config import PatchTSTConfig
from ltsm.layers.PatchTST_backbone import PatchTST_backbone
from ltsm.layers.PatchTST_layers import series_decomp
from transformers import PreTrainedModel
class PatchTST(PreTrainedModel):
    config_class = PatchTSTConfig
    
    def __init__(self, config: PatchTSTConfig, **kwargs):
        super().__init__(config)
        self.decomposition = config.decomposition

        # Every branch passes the same long positional argument list to
        # PatchTST_backbone, so build it once and reuse it.
        backbone_args = (
            config.enc_in,
            config.seq_len,
            config.pred_len,
            config.patch_len,
            config.stride,
            config.max_seq_len,
            config.n_layers,
            config.d_model,
            config.n_heads,
            config.d_k,
            config.d_v,
            config.d_ff,
            config.norm,
            config.attn_dropout,
            config.dropout,
            config.activation,
            config.key_padding_mask,
            config.padding_var,
            config.attn_mask,
            config.res_attention,
            config.pre_norm,
            config.store_attn,
            config.pe,
            config.learn_pe,
            config.fc_dropout,
            config.head_dropout,
            config.padding_patch,
            config.pretrain_head,
            config.head_type,
            config.individual,
            config.revin,
            config.affine,
            config.subtract_last,
            config.verbose,
        )
        if self.decomposition:
            # Decompose the series into trend and residual components and
            # forecast each with its own backbone.
            self.decomp_module = series_decomp(config.kernel_size)
            self.model_trend = PatchTST_backbone(*backbone_args)
            self.model_res = PatchTST_backbone(*backbone_args)
        else:
            self.model = PatchTST_backbone(*backbone_args)
    def forward(self, x: Tensor):
        """Forecast from an input series.

        Args:
            x: Input of shape [Batch, Input length, Channel].

        Returns:
            Forecast of shape [Batch, Prediction length, Channel].
        """
        if self.decomposition:
            # Split the series into residual and trend components, forecast
            # each with its own backbone, then recombine.
            res_init, trend_init = self.decomp_module(x)
            res_init, trend_init = res_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)  # [Batch, Channel, Input length]
            res = self.model_res(res_init)
            trend = self.model_trend(trend_init)
            x = res + trend
            x = x.permute(0, 2, 1)  # [Batch, Prediction length, Channel]
        else:
            x = x.permute(0, 2, 1)  # [Batch, Channel, Input length]
            x = self.model(x)
            x = x.permute(0, 2, 1)  # [Batch, Prediction length, Channel]
        return x
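

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal smoke test of the non-decomposition path. It assumes that
# PatchTSTConfig (defined in base_config.py) accepts the fields read above as
# keyword arguments and provides defaults for the remaining backbone options;
# the concrete values below (7 channels, 336-step window, 96-step horizon)
# are hypothetical.
if __name__ == "__main__":
    config = PatchTSTConfig(
        enc_in=7,             # number of input channels
        seq_len=336,          # length of the input window
        pred_len=96,          # forecast horizon
        patch_len=16,         # patch length
        stride=8,             # patch stride
        decomposition=False,  # use the single-backbone path
    )
    model = PatchTST(config)
    x = torch.randn(4, config.seq_len, config.enc_in)  # [Batch, Input length, Channel]
    y = model(x)
    print(y.shape)  # expected: torch.Size([4, 96, 7]), i.e. [Batch, Prediction length, Channel]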