Source code for ltsm.models.utils

import numpy as np
import torch
import torch.nn as nn
from math import sqrt
from transformers.modeling_utils import PreTrainedModel, PretrainedConfig

class Normalize(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, subtract the last time step instead of the mean
        :param non_norm: if True, skip normalization and return the input unchanged
        """
        super(Normalize, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.non_norm = non_norm
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        # stdev is used by _normalize in both branches, so compute it unconditionally
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        if self.non_norm:
            return x
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.non_norm:
            return x
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x
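# Usage sketch (illustrative only; the shapes below are assumptions, not part of
# the module). 'norm' standardizes each series instance and caches its statistics,
# 'denorm' re-applies them, e.g. to map model outputs back to the original scale:
#
#   >>> rev = Normalize(num_features=7)
#   >>> x = torch.randn(32, 96, 7)              # (batch, seq_len, channels)
#   >>> x_norm = rev(x, mode='norm')            # zero-mean, unit-variance per instance
#   >>> x_back = rev(x_norm, mode='denorm')     # recovers x up to numerical error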
class FlattenHead(nn.Module):
    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        x = self.flatten(x)
        x = self.linear(x)
        x = self.dropout(x)
        return x
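# Usage sketch (illustrative; the d_model and patch-count values are assumptions).
# The head flattens the last two dimensions and projects them to the forecast
# horizon, so nf must equal the product of those two dimensions:
#
#   >>> head = FlattenHead(n_vars=7, nf=16 * 12, target_window=96)
#   >>> z = torch.randn(32, 7, 16, 12)          # (batch, n_vars, d_model, patch_num)
#   >>> head(z).shape
#   torch.Size([32, 7, 96])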
class ReprogrammingLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_keys=None, d_llm=None, attention_dropout=0.1):
        super(ReprogrammingLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)

        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.value_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.out_projection = nn.Linear(d_keys * n_heads, d_llm)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, target_embedding, source_embedding, value_embedding):
        B, L, _ = target_embedding.shape
        S, _ = source_embedding.shape
        H = self.n_heads

        target_embedding = self.query_projection(target_embedding).view(B, L, H, -1)
        source_embedding = self.key_projection(source_embedding).view(S, H, -1)
        value_embedding = self.value_projection(value_embedding).view(S, H, -1)

        out = self.reprogramming(target_embedding, source_embedding, value_embedding)
        out = out.reshape(B, L, -1)

        return self.out_projection(out)

    def reprogramming(self, target_embedding, source_embedding, value_embedding):
        B, L, H, E = target_embedding.shape

        scale = 1. / sqrt(E)

        scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)

        return reprogramming_embedding
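# Usage sketch (illustrative; all sizes are assumptions). Patch embeddings of width
# d_model attend over a shared bank of source/value embeddings of width d_llm via
# scaled dot-product attention and are projected back into the d_llm space:
#
#   >>> layer = ReprogrammingLayer(d_model=32, n_heads=4, d_llm=768)
#   >>> patches = torch.randn(8, 64, 32)        # (batch, num_patches, d_model)
#   >>> prototypes = torch.randn(1000, 768)     # (num_prototypes, d_llm)
#   >>> layer(patches, prototypes, prototypes).shape
#   torch.Size([8, 64, 768])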
def freeze_parameters(model: PreTrainedModel):
    """
    Sets certain model parameters to non-trainable and others to trainable,
    based on predefined lists of layer-name substrings to freeze or keep trainable.
    """
    freeze_param_buf = ["gpt2"]
    for n, p in model.named_parameters():
        if any(fp in n for fp in freeze_param_buf):
            p.requires_grad = False
            print(f"{n} has been frozen")

    trainable_param_buf = ["ln", "wpe", "in_layer", "out_layer", "lora"]
    for n, p in model.named_parameters():
        if any(fp in n for fp in trainable_param_buf):
            p.requires_grad = True
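# Usage sketch (illustrative; "wrapper_model" is hypothetical and assumes a module
# that stores its GPT-2 backbone under an attribute named "gpt2", so backbone
# parameter names contain the substring "gpt2" while layer norms, position
# embeddings, and any LoRA/in_layer/out_layer weights stay trainable):
#
#   >>> freeze_parameters(wrapper_model)
#   >>> trainable = [n for n, p in wrapper_model.named_parameters() if p.requires_grad]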