Source code for ltsm.prompt_reader.stat_prompt.prompt_tsne

# from ltsm.data_provider.data_factory import get_data_loader, get_data_loaders, get_dataset
import argparse
import ipdb
import pandas as pd
import numpy as np
# import tsfel
from pandas import read_csv, read_feather
import matplotlib.pyplot as plt
import sys, os
import torch
from sklearn.preprocessing import StandardScaler
from sklearn import manifold


def get_args(argv=None):
    """Build the ``LTSM`` command-line parser and parse the given arguments.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, which
            is what the function did before this parameter existed.

    Returns:
        argparse.Namespace: One attribute per option registered below, each
        holding either the parsed value or the listed default.
    """
    parser = argparse.ArgumentParser(description='LTSM')

    # (flag, type, default) for every supported option.
    option_specs = (
        ('--data_path', str, 'dataset/weather.csv'),
        ('--data', str, 'custom'),
        ('--freq', str, "h"),
        ('--target', str, 'OT'),
        ('--embed', str, 'timeF'),
        ('--percent', int, 10),
        ('--batch_size', int, 512),
        ('--max_len', int, -1),
        ('--seq_len', int, 512),
        ('--pred_len', int, 96),
        ('--label_len', int, 48),
        ('--features', str, 'M'),
    )
    for flag, value_type, default in option_specs:
        parser.add_argument(flag, type=value_type, default=default)

    return parser.parse_args(argv)
def prompt_generation(ts):
    """Run TSFEL feature extraction on a time series.

    Args:
        ts: Time-series data, passed unchanged to
            ``tsfel.time_series_features_extractor``.

    Returns:
        The value returned by ``tsfel.time_series_features_extractor`` when
        given the configuration from ``tsfel.get_features_by_domain()``.

    Raises:
        ImportError: If the ``tsfel`` package is not installed.
    """
    # Imported lazily: the module-level ``import tsfel`` is commented out, so
    # without this the name ``tsfel`` is undefined and every call raises
    # NameError. A local import keeps the rest of the script usable when
    # tsfel is not installed.
    import tsfel

    feature_config = tsfel.get_features_by_domain()
    return tsfel.time_series_features_extractor(feature_config, ts)
def prompt_prune(pt):
    """Delete, in place, every entry of ``pt`` keyed by an FFT-mean-coefficient name.

    Args:
        pt: Object exposing ``to_dict()`` and supporting ``del pt[key]``.
            # presumably a pandas DataFrame of prompt features — confirm
            # against callers.

    Returns:
        The same ``pt`` object, with every key that starts with
        ``"0_FFT mean coefficient"`` removed.
    """
    fft_prefix = "0_FFT mean coefficient"

    # Snapshot the keys first so deleting from ``pt`` cannot disturb iteration.
    for label in tuple(pt.to_dict()):
        if not label.startswith(fft_prefix):
            continue
        del pt[label]

    return pt
def _collect_prompt_files(root_path, dataset_names, splits, max_per_dir):
    """Pick up to ``max_per_dir`` random files from each split/dataset directory.

    Directories are visited split-major (every dataset of the first split,
    then every dataset of the next split, ...). Labels are taken from the
    loop indices rather than recovered by substring-matching names against
    the directory path, so they cannot be confused by a root path that
    happens to contain a split or dataset name.

    Args:
        root_path: Directory containing one sub-directory per split.
        dataset_names: Dataset directory names expected inside each split.
        splits: Split directory names.
        max_per_dir: Maximum number of files sampled from each directory.

    Returns:
        tuple: ``(paths, dataset_ids, split_ids)`` — three parallel lists
        holding, per sampled file, its path, its index into
        ``dataset_names`` and its index into ``splits``.

    Raises:
        OSError: Propagated from ``os.listdir`` when a directory is missing.
    """
    paths, dataset_ids, split_ids = [], [], []
    for split_id, split in enumerate(splits):
        for dataset_id, name in enumerate(dataset_names):
            dataset_dir = os.path.join(root_path, split, name)
            files = [os.path.join(dataset_dir, entry) for entry in os.listdir(dataset_dir)]
            # Unseeded random subset, so the selection differs between runs.
            chosen = np.random.permutation(len(files))[:max_per_dir]
            paths.extend(files[i] for i in chosen)
            dataset_ids.extend([dataset_id] * len(chosen))
            split_ids.extend([split_id] * len(chosen))
    return paths, dataset_ids, split_ids


def _load_prompts(paths, dataset_ids, split_ids):
    """Load every prompt file and stack them into one feature matrix.

    Labels are repeated once per loaded row, so they stay aligned with the
    stacked matrix even when a file contributes more than one row.

    Returns:
        tuple: ``(features, row_dataset_ids, row_split_ids)`` where
        ``features`` is the ``.values`` array of the row-wise concatenation
        and the two id arrays have one entry per row of ``features``.

    Raises:
        ValueError: If ``paths`` is empty, i.e. there is nothing to embed.
    """
    frames, row_dataset_ids, row_split_ids = [], [], []
    for path, dataset_id, split_id in zip(paths, dataset_ids, split_ids):
        # NOTE(review): torch.load unpickles arbitrary objects, so only point
        # this at trusted files. Recent PyTorch releases default to
        # weights_only=True, which may reject pickled pandas objects —
        # confirm against the installed version.
        frame = torch.load(path)
        frames.append(frame)
        print("Import from {}".format(path))
        # assumes each file holds a pandas DataFrame (it is handed to
        # pd.concat below), so len() is its row count — TODO confirm.
        n_rows = len(frame)
        row_dataset_ids.extend([dataset_id] * n_rows)
        row_split_ids.extend([split_id] * n_rows)

    if not frames:
        raise ValueError("No prompt files were found; nothing to embed")

    features = pd.concat(frames, axis=0).values
    return features, np.asarray(row_dataset_ids), np.asarray(row_split_ids)


def _plot_embedding(embedding, row_dataset_ids, row_split_ids, dataset_names, splits, output_path):
    """Scatter the 2-D embedding (colour = dataset, marker = split) and save it."""
    color_buf = ["red", "blue", "black", "green", "pink", "brown"]
    marker_buf = [".", "^", "x"]

    for dataset_id, name in enumerate(dataset_names):
        # The dataset mask does not depend on the split, so build it once.
        in_dataset = row_dataset_ids == dataset_id
        for split_id, split in enumerate(splits):
            selected = in_dataset & (row_split_ids == split_id)
            plt.plot(
                embedding[selected, 0],
                embedding[selected, 1],
                linewidth=0,
                marker=marker_buf[split_id],
                label=name[0:8] + "-" + split,
                color=color_buf[dataset_id],
            )

    plt.legend(loc="right")
    # savefig does not create missing directories; make sure the target exists.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    plt.savefig(output_path)
    plt.close()


def main():
    """Embed sampled statistical prompts with t-SNE and save the scatter plot."""
    root_path = "./prompt_bank/stat-prompt/prompt_data_split/"
    dataset_name = [
        "electricity",
        "ETT-small",
        "exchange_rate",
        "illness",
        "traffic",
        "weather",
    ]
    split_buf = ["train", "val", "test"]
    max_per_dir = 100  # cap on files sampled from each split/dataset directory

    paths, dataset_ids, split_ids = _collect_prompt_files(
        root_path, dataset_name, split_buf, max_per_dir
    )
    prompt_data_all, row_dataset_ids, row_split_ids = _load_prompts(
        paths, dataset_ids, split_ids
    )
    print(prompt_data_all.shape)

    # NOTE(review): NaN values are not filtered here; scikit-learn estimators
    # normally reject NaN input — confirm the prompt files are NaN-free.
    tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
    prompt_data_tsne = tsne.fit_transform(prompt_data_all)

    _plot_embedding(
        prompt_data_tsne,
        row_dataset_ids,
        row_split_ids,
        dataset_name,
        split_buf,
        "./figures/stat_prompt_tsne.png",
    )


if __name__ == "__main__":
    main()