Source code for ltsm.models

from typing import Optional

from transformers import PretrainedConfig, PreTrainedModel

from .ltsm_stat_model import LTSM
from .ltsm_wordprompt import LTSM_WordPrompt
from .ltsm_ts_tokenizer import LTSM_Tokenizer
from .PatchTST import PatchTST
from .DLinear import DLinear
from .Informer import Informer

# Registry mapping a model name to the class registered under it.
# Filled by ``register_model``; looked up by ``get_model``.
model_dict = {}


def register_model(module, module_name: str):
    """
    Registers a PreTrainedModel module into the model dictionary.

    Args:
        module: A Python module or class that implements a PreTrainedModel.
        module_name (str): The key name for the module in the model dictionary.

    Raises:
        AssertionError: If a model with the same name is already registered.
    """
    if module_name in model_dict:
        # Explicit raise instead of ``assert`` so the duplicate check still runs
        # under ``python -O``; AssertionError is kept for callers that catch it.
        raise AssertionError(f"Model {module_name} already registered")
    model_dict[module_name] = module
# Register the built-in architectures, each under its own class name.
register_model(module=LTSM, module_name="LTSM")
register_model(module=LTSM_WordPrompt, module_name="LTSM_WordPrompt")
register_model(module=LTSM_Tokenizer, module_name="LTSM_Tokenizer")
register_model(module=PatchTST, module_name="PatchTST")
register_model(module=DLinear, module_name="DLinear")
register_model(module=Informer, module_name="Informer")
def get_model(
    config: PretrainedConfig,
    model_name: str,
    local_pretrain: Optional[str] = None,
    hf_hub_model: Optional[str] = None,
) -> PreTrainedModel:
    """
    Factory method to create a registered model by name.

    The weight source is chosen in this order:

    1. ``hf_hub_model`` is not ``None``: load via ``from_pretrained`` from the
       Hugging Face Hub.
    2. ``local_pretrain`` is ``None`` or the string ``"None"``: construct the
       model directly from ``config`` (no checkpoint is loaded here).
    3. Otherwise: read the config from ``local_pretrain`` and load the model
       from that local path.

    Args:
        config (PretrainedConfig): The configuration for the model. Not used
            when loading from ``local_pretrain`` (the config stored at that
            path is used instead).
        model_name (str): Registry key of the model to instantiate
            (as passed to ``register_model``).
        local_pretrain (str, optional): Path to a local pretrained checkpoint.
            ``None`` or the literal string ``"None"`` means "no checkpoint".
        hf_hub_model (str, optional): Hugging Face Hub model name. Takes
            precedence over ``local_pretrain``.

    Returns:
        PreTrainedModel: The instantiated model.

    Raises:
        ValueError: If the model name is not found in model_dict.
    """
    if model_name not in model_dict:
        raise ValueError(f"Model {model_name} is not registered. Available models: {list(model_dict.keys())}")
    model_cls = model_dict[model_name]

    # Load pretrained weights if hf_hub_model is provided.
    # NOTE(review): ``config`` is passed positionally; in transformers'
    # ``PreTrainedModel.from_pretrained`` the second positional slot is
    # ``*model_args`` (forwarded to ``__init__``), not ``config=``. Confirm the
    # registered model classes expect this before changing it.
    if hf_hub_model is not None:
        return model_cls.from_pretrained(hf_hub_model, config)

    # The string "None" is treated like None, presumably because the value
    # can arrive as text from CLI/config files — TODO confirm with callers.
    if local_pretrain is None or local_pretrain == "None":
        return model_cls(config)

    # Local checkpoint: the config saved alongside the weights is used,
    # the ``config`` argument is ignored on this path.
    model_config = PretrainedConfig.from_pretrained(local_pretrain)
    return model_cls.from_pretrained(local_pretrain, model_config)
# Public API of this package. ``__all__`` must be a sequence of *name
# strings*: ``from ltsm.models import *`` raises TypeError when it holds the
# objects themselves, and a set would also make the export order
# nondeterministic.
__all__ = [
    "register_model",
    "get_model",
    "PatchTST",
    "DLinear",
    "Informer",
    "LTSM",
    "LTSM_WordPrompt",
    "LTSM_Tokenizer",
]