import numpy as np
import torch
import torch.nn as nn
from math import sqrt
from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
class Normalize(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
        """
        Reversible instance normalization (RevIN) over the feature/channel dimension.

        :param num_features: the number of features or channels
        :param eps: a value added to the variance for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, subtract the last time step instead of the mean
        :param non_norm: if True, the layer acts as an identity (no normalization)
        """
        super(Normalize, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.non_norm = non_norm
        if self.affine:
            self._init_params()
    def forward(self, x, mode: str):
        # 'norm' computes the statistics of x and normalizes it;
        # 'denorm' reverses the transformation using the stored statistics.
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x
    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        # the standard deviation is needed by _normalize in both branches
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
    def _normalize(self, x):
        if self.non_norm:
            return x
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.non_norm:
            return x
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x
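

# Illustrative usage sketch (not part of the original module): round-trips a
# random batch through Normalize in 'norm' and 'denorm' mode. The shapes and
# the `_demo_normalize` name are assumptions made for this example only.
def _demo_normalize():
    layer = Normalize(num_features=7, affine=True)
    x = torch.randn(32, 96, 7)                      # (batch, time, channels)
    x_norm = layer(x, mode='norm')                  # per-instance normalization
    x_back = layer(x_norm, mode='denorm')           # reverse the transformation
    print(torch.allclose(x, x_back, atol=1e-4))     # expected to print True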
class FlattenHead(nn.Module):
    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)
    def forward(self, x):
        x = self.flatten(x)     # flatten the last two dimensions into nf
        x = self.linear(x)      # project to the target forecast window
        x = self.dropout(x)
        return x
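

# Illustrative usage sketch (not part of the original module): shapes follow the
# typical patching setup where nf = d_model * num_patches; the concrete numbers
# below and the `_demo_flatten_head` name are assumptions for this example only.
def _demo_flatten_head():
    n_vars, d_model, num_patches, target_window = 7, 32, 12, 96
    head = FlattenHead(n_vars, nf=d_model * num_patches, target_window=target_window)
    x = torch.randn(8, n_vars, d_model, num_patches)  # (batch, n_vars, d_model, num_patches)
    y = head(x)                                       # flatten last two dims, then project
    print(y.shape)                                    # torch.Size([8, 7, 96])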
class ReprogrammingLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_keys=None, d_llm=None, attention_dropout=0.1):
        super(ReprogrammingLayer, self).__init__()
        d_keys = d_keys or (d_model // n_heads)
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.value_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.out_projection = nn.Linear(d_keys * n_heads, d_llm)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(attention_dropout)
    def forward(self, target_embedding, source_embedding, value_embedding):
        B, L, _ = target_embedding.shape
        S, _ = source_embedding.shape
        H = self.n_heads

        # project queries, keys, and values, then split them into H heads
        target_embedding = self.query_projection(target_embedding).view(B, L, H, -1)
        source_embedding = self.key_projection(source_embedding).view(S, H, -1)
        value_embedding = self.value_projection(value_embedding).view(S, H, -1)

        out = self.reprogramming(target_embedding, source_embedding, value_embedding)
        out = out.reshape(B, L, -1)
        return self.out_projection(out)
    def reprogramming(self, target_embedding, source_embedding, value_embedding):
        B, L, H, E = target_embedding.shape
        scale = 1. / sqrt(E)

        # scaled dot-product cross-attention between the target queries and the
        # shared source keys/values (source has no batch dimension)
        scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)

        return reprogramming_embedding
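

# Illustrative usage sketch (not part of the original module): maps target
# embeddings of width d_model onto a set of source embeddings of width d_llm
# via multi-head cross-attention. All sizes and the `_demo_reprogramming`
# name below are assumptions for this example only.
def _demo_reprogramming():
    d_model, d_llm, n_heads = 32, 768, 8
    layer = ReprogrammingLayer(d_model, n_heads, d_llm=d_llm)
    target = torch.randn(4, 64, d_model)    # (batch, num_patches, d_model)
    source = torch.randn(1000, d_llm)       # (num_source_tokens, d_llm)
    out = layer(target, source, source)     # source serves as both keys and values
    print(out.shape)                        # torch.Size([4, 64, 768])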
def freeze_parameters(model: PreTrainedModel):
    """
    Freezes parameters whose names match a predefined list of layers to freeze,
    then re-enables gradients for parameters whose names match a predefined list
    of layers that should stay trainable.
    """
    # freeze every parameter belonging to the GPT-2 backbone
    freeze_param_buf = ["gpt2"]
    for n, p in model.named_parameters():
        if any(fp in n for fp in freeze_param_buf):
            p.requires_grad = False
            print(f"{n} has been frozen")

    # keep layer norms, positional embeddings, the input/output projections,
    # and any LoRA parameters trainable
    trainable_param_buf = ["ln", "wpe", "in_layer", "out_layer", "lora"]
    for n, p in model.named_parameters():
        if any(fp in n for fp in trainable_param_buf):
            p.requires_grad = True
def print_trainable_parameters(model):
    """
    Prints the names of parameters in the model that are trainable.
    """
    for n, p in model.named_parameters():
        if p.requires_grad:
            print(f"{n} is trainable...")