Source code for ltsm.models.ltsm_stat_model

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from .base_config import LTSMConfig
from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
from transformers import AutoModel, AutoConfig, AutoTokenizer

class LTSM(PreTrainedModel):
    config_class = LTSMConfig

    def __init__(self, configs, *model_args, **model_kwargs):
        super().__init__(configs)
        self.patch_size = configs.patch_size
        self.pretrain = configs.pretrain
        self.stride = configs.stride
        self.patch_num = (configs.seq_len + configs.prompt_len - self.patch_size) // self.stride + 1
        self.d_type = torch.bfloat16

        self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride))
        self.patch_num += 1  # padding adds one extra patch
        self.configs = configs

        if configs.pretrain:
            print("Loading the pretraining weight.")
            self.llm_config = AutoConfig.from_pretrained(configs.model_name_or_path)
            # Load the pretrained backbone named by model_name_or_path (e.g. a GPT-2 base model)
            self.llm = AutoModel.from_pretrained(configs.model_name_or_path)
        else:
            raise NotImplementedError("You must load the pretraining weight.")

        self.model_prune(configs)
        print("model = {}".format(self.llm))

        self.in_layer = nn.Linear(configs.patch_size, self.llm_config.hidden_size)
        self.out_layer = nn.Linear(self.llm_config.hidden_size * self.patch_num, configs.pred_len)
        self.cnt = 0
    def model_prune(self, configs):
        if "gpt2" in configs.model_name_or_path:
            self.llm.h = self.llm.h[:configs.gpt_layers]
        elif "phi" in configs.model_name_or_path or "llama" in configs.model_name_or_path or "gemma" in configs.model_name_or_path:
            self.llm.layers = self.llm.layers[:configs.gpt_layers]
        else:
            raise NotImplementedError(f"No implementation in model prune for {self.llm}.")
    def forward(self, x, return_feature=False):
        B, L, M = x.shape

        # Instance normalization: subtract the per-series mean and divide by the std
        means = x.mean(1, keepdim=True).detach()
        x = x - means
        stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x /= stdev

        # Split each series into overlapping patches of length patch_size
        x = rearrange(x, 'b l m -> b m l')
        x = self.padding_patch_layer(x)
        x = x.unfold(dimension=-1, size=self.patch_size, step=self.stride)
        x = rearrange(x, 'b m n p -> (b m) n p')

        # Embed patches and pass them through the pruned LLM backbone
        outputs = self.in_layer(x).to(dtype=torch.bfloat16)
        outputs = self.llm(inputs_embeds=outputs).last_hidden_state

        if return_feature:
            return outputs

        # Project the flattened hidden states to the prediction horizon and de-normalize
        outputs = outputs.to(dtype=x.dtype)
        outputs = self.out_layer(outputs.reshape(B * M, -1))
        outputs = rearrange(outputs, '(b m) l -> b l m', b=B)
        outputs = outputs * stdev
        outputs = outputs + means
        return outputs
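
A minimal usage sketch, not part of the source: the config fields below mirror the attributes the class reads from `configs`, but the concrete values and the assumption that `LTSMConfig` accepts them as keyword arguments are illustrative only.

# Usage sketch -- assumes LTSMConfig accepts these fields as keyword arguments;
# the values are illustrative, not library defaults.
config = LTSMConfig(
    model_name_or_path="gpt2",  # backbone is pruned to the first `gpt_layers` blocks
    pretrain=True,
    seq_len=336,
    prompt_len=133,
    pred_len=96,
    patch_size=16,
    stride=8,
    gpt_layers=3,
)
model = LTSM(config)

# Input shape is (batch, seq_len + prompt_len, num_series).
x = torch.randn(4, config.seq_len + config.prompt_len, 7)
y = model(x)                            # (batch, pred_len, num_series)
feats = model(x, return_feature=True)   # per-patch hidden states from the backbone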