Source code for ltsm.models.ltsm_ts_tokenizer
import torch
from transformers import AutoConfig, AutoModel
from transformers.modeling_utils import PreTrainedModel

from .base_config import LTSMConfig
class LTSM_Tokenizer(PreTrainedModel):
    """LTSM tokenizer model: feeds discretized time-series token ids to a pretrained LLM backbone."""

    config_class = LTSMConfig

    def __init__(self, configs):
        super().__init__(configs)
        self.patch_size = configs.patch_size
        self.pretrain = configs.pretrain
        self.d_type = torch.bfloat16
        self.pred_len = configs.pred_len

        if configs.pretrain:
            print("Loading the pretrained weights.")
            self.llm_config = AutoConfig.from_pretrained(configs.model_name_or_path)
            # Load the pretrained backbone (e.g., a GPT-2 base model,
            # depending on model_name_or_path).
            self.llm = AutoModel.from_pretrained(configs.model_name_or_path)
        else:
            raise NotImplementedError("You must load the pretrained weights.")

        # Truncate the backbone to the configured number of layers.
        self.model_prune(configs)
        print("backbone = {}".format(self.llm))
    def model_prune(self, configs):
        """Keep only the first `configs.gpt_layers` transformer blocks of the backbone."""
        if "gpt2" in configs.model_name_or_path:
            self.llm.h = self.llm.h[:configs.gpt_layers]
        elif "phi" in configs.model_name_or_path or "llama" in configs.model_name_or_path or "gemma" in configs.model_name_or_path:
            self.llm.layers = self.llm.layers[:configs.gpt_layers]
        else:
            raise NotImplementedError(f"No implementation in model_prune for {self.llm}.")
    def forward(self, x):
        # Cast the discretized series to integer token ids and add a trailing
        # dimension before the backbone's embedding lookup.
        x = x.int().unsqueeze(-1)
        outputs = self.llm(input_ids=x).last_hidden_state
        # Return only the hidden states of the last `pred_len` positions.
        outputs = outputs[:, -self.pred_len:, :]
        return outputs
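
A minimal usage sketch, assuming LTSMConfig accepts the fields read above (patch_size, pretrain, pred_len, model_name_or_path, gpt_layers) as keyword arguments and that the backbone checkpoint is available locally or from the Hugging Face Hub; the field values below are illustrative, not defaults from the repository:

import torch

from ltsm.models.base_config import LTSMConfig
from ltsm.models.ltsm_ts_tokenizer import LTSM_Tokenizer

# Illustrative configuration (assumed keyword-argument constructor).
config = LTSMConfig(
    patch_size=16,                # read in __init__, not used by forward()
    pretrain=True,                # must be True, or __init__ raises
    pred_len=96,                  # trailing positions returned by forward()
    model_name_or_path="gpt2",    # any backbone handled by model_prune()
    gpt_layers=3,                 # keep only the first 3 transformer blocks
)
model = LTSM_Tokenizer(config)

# A batch of 4 token-id sequences of length 336, drawn from the backbone's
# vocabulary; real inputs would come from a time-series tokenizer.
token_ids = torch.randint(0, model.llm_config.vocab_size, (4, 336))
hidden = model(token_ids)
print(hidden.shape)  # hidden states for the last pred_len positions

Pruning with gpt_layers trades capacity for speed: only the earliest transformer blocks of the pretrained backbone are kept, which is why __init__ refuses to run without pretrained weights.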