Source code for ltsm.data_pipeline.model_manager

from ltsm.models import get_model
from ltsm.models.base_config import LTSMConfig
import torch
from torch import nn
import numpy as np
from peft import get_peft_model, LoraConfig
from transformers import (
    EvalPrediction,
)

# Deprecated: This class is not used in the current implementation.
class ModelManager:
    """
    Manages model creation, parameter settings, the optimizer, and evaluation metrics for training.

    Attributes:
        args (argparse.Namespace): Configuration and hyperparameters for model training.
        model (torch.nn.Module): The model to be trained, created based on the configuration.
        optimizer (torch.optim.Optimizer): Optimizer for model parameter updates.
        scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
    """

    def __init__(self, args):
        """
        Initializes the ModelManager with the provided arguments and default values
        for the model, optimizer, and scheduler.

        Args:
            args (argparse.Namespace): Training configurations and hyperparameters.
        """
        self.args = args
        self.model = None
        self.optimizer = None
        self.scheduler = None

    def print_trainable_parameters(self):
        """
        Prints the names of parameters in the model that are trainable.
        """
        for n, p in self.model.named_parameters():
            if p.requires_grad:
                print(f"{n} is trainable...")

    def create_model(self):
        """
        Initializes and configures the model based on the specified arguments,
        including options for freezing parameters or applying LoRA (Low-Rank Adaptation).

        Returns:
            torch.nn.Module: The configured model, ready for training.
        """
        model_config = LTSMConfig(**vars(self.args))
        self.model = get_model(model_config)

        if self.args.lora:
            # Wrap the base model with LoRA adapters on the attention projections.
            peft_config = LoraConfig(
                target_modules=["c_attn"],
                inference_mode=False,
                r=self.args.lora_dim,
                lora_alpha=32,
                lora_dropout=0.1,
            )
            self.model = get_peft_model(self.model, peft_config)
            self.model.print_trainable_parameters()
        elif self.args.freeze:
            self.freeze_parameters()
            self.print_trainable_parameters()

        return self.model

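    # freeze_parameters() is called above but its body is not shown on this
    # page. The following is a minimal sketch, assuming a GPT-2-style backbone:
    # it freezes the pretrained transformer weights and re-enables layer norms,
    # positional embeddings, and the input/output projection layers. The name
    # patterns here are illustrative assumptions, not confirmed by this module.
    def freeze_parameters(self):
        """
        Freezes the pretrained backbone while keeping task-specific layers
        trainable. (Sketch; the parameter-name patterns are assumptions.)
        """
        freeze_patterns = ["gpt2"]  # assumed backbone name prefix
        train_patterns = ["ln", "wpe", "in_layer", "out_layer", "lora"]  # assumed trainable layers
        for n, p in self.model.named_parameters():
            if any(fp in n for fp in freeze_patterns):
                p.requires_grad = False
        for n, p in self.model.named_parameters():
            if any(tp in n for tp in train_patterns):
                p.requires_grad = True
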
    def set_optimizer(self):
        """
        Configures the optimizer and learning rate scheduler for model training.
        Uses the Adam optimizer and a cosine annealing learning rate scheduler.
        """
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.args.tmax, eta_min=1e-8
        )

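    # Note: CosineAnnealingLR anneals the learning rate from its initial value
    # toward eta_min along a cosine curve over `tmax` scheduler steps. The
    # surrounding training loop is expected to call self.scheduler.step()
    # once per epoch (or per iteration), so `tmax` should be chosen to match
    # that granularity.
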
    def compute_metrics(self, p: EvalPrediction):
        """
        Computes evaluation metrics for model predictions.

        Args:
            p (EvalPrediction): Contains predictions and label IDs.

        Returns:
            dict: Dictionary containing the Mean Squared Error (MSE) and
                Mean Absolute Error (MAE).
        """
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.squeeze(preds)
        # Align shapes before computing element-wise errors.
        if preds.shape != p.label_ids.shape:
            label_ids = np.squeeze(p.label_ids)
        else:
            label_ids = p.label_ids
        return {
            "mse": ((preds - label_ids) ** 2).mean().item(),
            "mae": (np.abs(preds - label_ids)).mean().item(),
        }

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Computes the loss for model training.

        Args:
            model (torch.nn.Module): The model used for predictions.
            inputs (dict): Input data and labels.
            return_outputs (bool): If True, returns both the loss and the model outputs.

        Returns:
            torch.Tensor or tuple: The computed loss, and optionally the outputs.
        """
        if self.args.model == 'Informer':
            # Informer additionally consumes timestamp encodings for the
            # encoder and decoder inputs.
            input_data_mark = inputs["timestamp_input"].to(model.module.device)
            label_mark = inputs["timestamp_labels"].to(model.module.device)
            outputs = model(inputs["input_data"], input_data_mark, inputs["labels"], label_mark)
        else:
            outputs = model(inputs["input_data"])
        loss = nn.functional.mse_loss(outputs, inputs["labels"])
        return (loss, outputs) if return_outputs else loss

    @torch.no_grad()
    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        """
        Makes a prediction step, computing the loss and returning model outputs
        without gradients.

        Args:
            model (torch.nn.Module): The model used for predictions.
            inputs (dict): Input data and labels.
            prediction_loss_only (bool): If True, returns only the loss.
            ignore_keys (list): Keys to ignore in the inputs.

        Returns:
            tuple: The loss, outputs, and labels.
        """
        input_data = inputs["input_data"].to(model.module.device)
        labels = inputs["labels"].to(model.module.device)
        if self.args.model == 'Informer':
            input_data_mark = inputs["timestamp_input"].to(model.module.device)
            label_mark = inputs["timestamp_labels"].to(model.module.device)
            outputs = model(input_data, input_data_mark, labels, label_mark)
        else:
            outputs = model(input_data)
        loss = nn.functional.mse_loss(outputs, labels)
        return (loss, outputs, labels)

    def collate_fn(self, batch):
        """
        Collates a batch of data samples into tensors for model training.

        Args:
            batch (list): List of data samples with 'input_data' and 'labels'
                keys (plus 'timestamp_input' and 'timestamp_labels' for Informer).

        Returns:
            dict: Collated batch with 'input_data' and 'labels' tensors.
        """
        if self.args.model == 'Informer':
            return {
                'input_data': torch.from_numpy(np.stack([x['input_data'] for x in batch])).type(torch.float32),
                'labels': torch.from_numpy(np.stack([x['labels'] for x in batch])).type(torch.float32),
                'timestamp_input': torch.from_numpy(np.stack([x['timestamp_input'] for x in batch])).type(torch.float32),
                'timestamp_labels': torch.from_numpy(np.stack([x['timestamp_labels'] for x in batch])).type(torch.float32),
            }
        return {
            'input_data': torch.from_numpy(np.stack([x['input_data'] for x in batch])).type(torch.float32),
            'labels': torch.from_numpy(np.stack([x['labels'] for x in batch])).type(torch.float32),
        }

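
# A minimal usage sketch (illustrative, not part of the original module).
# compute_metrics and collate_fn need no model, so they can be exercised
# directly. A full run would also call manager.create_model() and
# manager.set_optimizer(), which additionally require the configuration
# fields referenced above (lora, lora_dim, freeze, learning_rate, tmax, ...)
# on `args`; the field values below are assumptions.
if __name__ == "__main__":
    import argparse

    args = argparse.Namespace(model="DLinear")  # assumed minimal config
    manager = ModelManager(args)

    # compute_metrics consumes raw arrays wrapped in an EvalPrediction.
    preds = np.random.randn(4, 96).astype(np.float32)
    labels = np.random.randn(4, 96).astype(np.float32)
    print(manager.compute_metrics(EvalPrediction(predictions=preds, label_ids=labels)))

    # collate_fn stacks per-sample numpy arrays into float32 tensors, e.g.
    # when passed as the collate_fn of a torch.utils.data.DataLoader.
    batch = [{"input_data": np.zeros(96), "labels": np.zeros(48)} for _ in range(4)]
    tensors = manager.collate_fn(batch)
    print(tensors["input_data"].shape, tensors["labels"].shape)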