# from ltsm.data_provider.data_factory import get_data_loader, get_data_loaders, get_dataset
import argparse
import ipdb
import pandas as pd
import numpy as np
# import tsfel
from pandas import read_csv, read_feather
import matplotlib.pyplot as plt
import sys, os
import torch
from sklearn.preprocessing import StandardScaler
from sklearn import manifold
[docs]
def get_args():
    """Build the ``LTSM`` argument parser and parse the command line.

    Returns:
        argparse.Namespace: One attribute per option registered below, each
        holding either the value given on the command line or its default.
    """
    # (flag, type, default) triples, registered in this exact order.
    option_table = (
        ('--data_path', str, 'dataset/weather.csv'),
        ('--data', str, 'custom'),
        ('--freq', str, "h"),
        ('--target', str, 'OT'),
        ('--embed', str, 'timeF'),
        ('--percent', int, 10),
        ('--batch_size', int, 512),
        ('--max_len', int, -1),
        ('--seq_len', int, 512),
        ('--pred_len', int, 96),
        ('--label_len', int, 48),
        ('--features', str, 'M'),
    )

    cli = argparse.ArgumentParser(description='LTSM')
    for flag, kind, fallback in option_table:
        cli.add_argument(flag, type=kind, default=fallback)
    return cli.parse_args()
[docs]
def prompt_generation(ts):
    """Extract a TSFEL feature "prompt" from a time series.

    Args:
        ts: Time-series data, handed unchanged to
            ``tsfel.time_series_features_extractor``.

    Returns:
        The object returned by ``tsfel.time_series_features_extractor`` when
        run with the configuration from ``tsfel.get_features_by_domain()``
        called without arguments.

    Raises:
        ImportError: If the ``tsfel`` package is not installed.
    """
    # The file-level ``import tsfel`` is commented out, which made every call
    # fail with NameError. Importing here keeps tsfel an optional dependency
    # that is only required when this function is actually used.
    import tsfel

    cfg = tsfel.get_features_by_domain()
    return tsfel.time_series_features_extractor(cfg, ts)
[docs]
def prompt_prune(pt):
    """Drop every ``"0_FFT mean coefficient..."`` entry from ``pt`` in place.

    Args:
        pt: Object exposing ``to_dict()`` and supporting ``del pt[key]``
            (presumably a pandas Series/DataFrame of features; confirm
            against the caller).

    Returns:
        The same ``pt`` object, after the matching keys have been deleted.
    """
    unwanted_prefix = "0_FFT mean coefficient"
    # Snapshot the keys first: entries are removed from ``pt`` while looping.
    for feature_name in tuple(pt.to_dict()):
        if not feature_name.startswith(unwanted_prefix):
            continue
        del pt[feature_name]
    return pt
if __name__ == "__main__":
    # Script entry point: sample stored prompt files for every dataset/split,
    # project them to 2-D with t-SNE and save one scatter plot in which colour
    # encodes the dataset and marker shape encodes the split.
    root_path = "./prompt_bank/stat-prompt/prompt_data_split/"
    figure_path = "./figures/stat_prompt_tsne.png"
    dataset_name = [
        "electricity",
        "ETT-small",
        "exchange_rate",
        "illness",
        "traffic",
        "weather",
    ]
    split_buf = ["train", "val", "test"]
    K = 100  # upper bound on files sampled from each <split>/<dataset> directory

    data_path_buf = []
    dataset_dir_buf = []    # dataset index (into dataset_name) of each sampled file
    dataset_split_buf = []  # split index (into split_buf) of each sampled file

    # Walk the directories split-major (all train dirs, then val, then test).
    # Labels are taken from the loop indices instead of substring tests on the
    # path, so a root_path that happens to contain a split or dataset name can
    # neither mislabel samples nor desynchronise the label buffers.
    for split_index, split in enumerate(split_buf):
        for dataset_index, dname in enumerate(dataset_name):
            dataset_dir = os.path.join(root_path, split, dname)
            candidates = [os.path.join(dataset_dir, path) for path in os.listdir(dataset_dir)]
            # Random subset of at most K files; directories with fewer files
            # contribute all of them.
            sample_idx = np.random.permutation(len(candidates))[:K]
            sampled = [candidates[i] for i in sample_idx]
            data_path_buf.extend(sampled)
            dataset_dir_buf.extend(len(sampled) * [dataset_index])
            dataset_split_buf.extend(len(sampled) * [split_index])

    prompt_data_buf = []
    for dataset_path in data_path_buf:
        # SECURITY NOTE: torch.load unpickles the file, so only trusted prompt
        # files must be placed under root_path.
        # NOTE(review): on PyTorch releases where ``weights_only`` defaults to
        # True this call may need ``weights_only=False`` to load pandas
        # objects -- confirm against the installed version.
        prompt_data = torch.load(dataset_path)
        prompt_data_buf.append(prompt_data)
        print("Import from {}".format(dataset_path))

    # Stack all loaded prompts row-wise into one 2-D array for t-SNE.
    prompt_data_all = pd.concat(prompt_data_buf, axis=0).values
    print(prompt_data_all.shape)

    tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
    prompt_data_tsne = tsne.fit_transform(prompt_data_all)

    color_buf = ["red", "blue", "black", "green", "pink", "brown"]  # one per dataset
    marker_buf = [".", "^", "x"]  # one per split

    # Build the label arrays once instead of on every plotting iteration.
    dataset_labels = np.asarray(dataset_dir_buf)
    split_labels = np.asarray(dataset_split_buf)
    for index, dname in enumerate(dataset_name):
        data_index = dataset_labels == index
        for sindex, split_fold in enumerate(split_buf):
            plot_index = data_index & (split_labels == sindex)
            # linewidth=0 turns plt.plot into a marker-only scatter.
            plt.plot(
                prompt_data_tsne[plot_index, 0],
                prompt_data_tsne[plot_index, 1],
                linewidth=0,
                marker=marker_buf[sindex],
                label=str(dname[0:8] + "-" + split_fold),
                color=color_buf[index],
            )
    plt.legend(loc="right")

    # savefig does not create missing directories; make sure the target exists.
    os.makedirs(os.path.dirname(figure_path), exist_ok=True)
    plt.savefig(figure_path)
    plt.close()
# ipdb.set_trace()
# plt.xticks([])
# plt.yticks([])
# print(prompt_data_all)
# , color = plt.cm.Set1(dataset_dir_buf[index])
# print(prompt_data_transform)
# print(prompt_data_transform_array.mean(axis=0))
# print(prompt_data_transform_array.std(axis=0))
# print(prompt_data_transform.loc[5])