# Source code for ltsm.prompt_reader.stat_prompt.prompt_generate_split

# from ltsm.data_provider.data_factory import get_data_loader, get_data_loaders, get_dataset
import argparse
import ipdb
import pandas as pd
import numpy as np
import tsfel
from pandas import read_csv, read_feather
import matplotlib.pyplot as plt
import sys, os
import torch

# def parse_list(arg):
#     """parse a string of comma-separated values into a list
#     e.g. python ./ltsm/prompt_reader/stat_prompt/prompt_generate_split.py --dataset_name ETT-small, illness
#     """
#     return arg.split(',')

def _str2bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so the value has to be parsed explicitly.

    Args:
        value (str | bool): raw value from the command line.

    Returns:
        bool: the parsed value.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def get_args(argv=None):
    """Parse the command-line options for statistical prompt generation.

    Args:
        argv (list[str] | None): arguments to parse. None (the default) parses
            ``sys.argv[1:]``, exactly as before; passing a list makes the parser
            usable from code and tests.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description='LTSM')
    parser.add_argument('--root_path', type=str, default='./datasets/', help='Root path for datasets')
    parser.add_argument('--output_path', type=str, default='./prompt_bank/stat-prompt/prompt_data_split/', help='Output path for prompt data')
    parser.add_argument('--dataset_name', nargs='+', default=[])
    parser.add_argument('--save_format', type=str, default='pth.tar', choices=["pth.tar", "csv", "npz"], help='The format to save the data')
    # type=bool would turn "--test False" into True (bool("False") is True);
    # parse the text instead, and let a bare "--test" mean True.
    parser.add_argument('--test', type=_str2bool, nargs='?', const=True, default=False,
                        help='Reload and verify every saved prompt (true/false; bare --test means true)')
    parser.add_argument('--data_path', type=str, default='dataset/weather.csv')
    parser.add_argument('--data', type=str, default='custom')
    parser.add_argument('--freq', type=str, default="h")
    parser.add_argument('--target', type=str, default='OT')
    parser.add_argument('--embed', type=str, default='timeF')
    parser.add_argument('--percent', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--max_len', type=int, default=-1)
    parser.add_argument('--seq_len', type=int, default=512)
    parser.add_argument('--pred_len', type=int, default=96)
    parser.add_argument('--label_len', type=int, default=48)
    parser.add_argument('--features', type=str, default='M')
    return parser.parse_args(argv)
def prompt_prune(pt):
    """Remove every "0_FFT mean coefficient*" entry from a tsfel feature table.

    Entries are deleted from ``pt`` itself with ``del`` and the very same
    object is returned, so callers may use either the argument or the result.

    Args:
        pt (pd.DataFrame | pd.Series): feature table; its keys (column labels
            for a DataFrame, index labels for a Series) are feature names.

    Returns:
        The same ``pt`` object with the matching entries removed.
    """
    fft_prefix = "0_FFT mean coefficient"
    doomed = [label for label in pt.to_dict() if label.startswith(fft_prefix)]
    for label in doomed:
        del pt[label]
    return pt
def prompt_generation_single(ts):
    """Extract the statistical prompt features of a single time series.

    The tsfel extractor is run with the configuration returned by
    ``tsfel.get_features_by_domain()`` (called with no arguments). The
    "0_FFT mean coefficient*" features are then dropped via ``prompt_prune``.

    Args:
        ts (pd.Series): input time-series data.

    Returns:
        The pruned tsfel feature table. It is presumably a single-row DataFrame
        with one column per feature; confirm against the installed tsfel version.
    """
    extracted = tsfel.time_series_features_extractor(tsfel.get_features_by_domain(), ts)
    return prompt_prune(extracted)
def _split_bounds(ts_len, ts_name):
    """Return the positional ``(train_end, val_end)`` split points for a series of length ``ts_len``.

    Names containing "ETT" use a 60/20/20 train/val/test split; all others use
    70/10/20. The test split takes everything after ``val_end``, so rounding
    leftovers end up in the test split.
    """
    train_ratio, val_ratio = (0.6, 0.2) if "ETT" in ts_name else (0.7, 0.1)
    train_end = int(train_ratio * ts_len)
    return train_end, train_end + int(val_ratio * ts_len)


def prompt_generation(ts, ts_name):
    """Generate per-split statistical prompt features for every column of ``ts``.

    Each column is split positionally into train/val/test segments (see
    ``_split_bounds``). ``prompt_generation_single`` is then applied to each
    segment independently, in the order train, val, test.

    Args:
        ts (pd.DataFrame): input time-series data, one series per column.
        ts_name (str): name of the time-series data; names containing "ETT"
            select the 60/20/20 split instead of 70/10/20.

    Returns:
        dict | None: None when ``ts`` has exactly one column. Otherwise a dict
        with keys "train", "val" and "test". Each value is a DataFrame with one
        column per input column ("/" in names replaced by "-") and one row per
        extracted feature, indexed 0..n_features-1.
    """
    if ts.shape[1] == 1:
        return None

    # Split points depend only on the series length, so compute them once.
    train_end, val_end = _split_bounds(len(ts), ts_name)
    features = {"train": {}, "val": {}, "test": {}}

    # Iterate columns directly rather than via ts.T, which copies the whole
    # frame and coerces mixed column dtypes to a common (possibly object) dtype.
    for name, col in ts.items():
        key = name.replace("/", "-")
        segments = (
            ("train", col.iloc[:train_end]),
            ("val", col.iloc[train_end:val_end].reset_index(drop=True)),
            ("test", col.iloc[val_end:].reset_index(drop=True)),
        )
        for split, segment in segments:
            # Size the output from what the extractor actually returns instead
            # of assuming a fixed feature count (previously hard-coded as 133).
            features[split][key] = np.asarray(prompt_generation_single(segment)).ravel()

    prompt_buf_total = {split: pd.DataFrame(columns) for split, columns in features.items()}
    print(prompt_buf_total)
    return prompt_buf_total
def _write_prompt(prompt_data, file_path, save_format):
    """Serialise one prompt Series to ``file_path`` in ``save_format``.

    Raises:
        ValueError: if ``save_format`` is not "pth.tar", "csv" or "npz".
    """
    if save_format == "pth.tar":
        torch.save(prompt_data, file_path)
    elif save_format == "csv":
        # CSV is a text format, so values may lose some floating-point precision.
        prompt_data.to_csv(file_path, index=False)
    elif save_format == "npz":
        np.savez(file_path, data=prompt_data.values, index=prompt_data.index, name=prompt_data.name)
    else:
        raise ValueError(f"Unsupported save format: {save_format}")


def _read_prompt(file_path, save_format):
    """Load a prompt written by ``_write_prompt`` back as a pd.Series."""
    if save_format == "pth.tar":
        try:
            # torch>=2.6 defaults to weights_only=True, which refuses to unpickle
            # pandas objects. These files were just written by this script, so
            # full unpickling is acceptable here.
            loaded = torch.load(file_path, weights_only=False)
        except TypeError:
            # Older torch releases do not accept the weights_only keyword.
            loaded = torch.load(file_path)
    elif save_format == "csv":
        loaded = pd.read_csv(file_path)
    elif save_format == "npz":
        # Use the NpzFile as a context manager so the underlying file is closed.
        with np.load(file_path) as archive:
            loaded = pd.Series(data=archive["data"], index=archive["index"], name=archive["name"].item())
    else:
        raise ValueError(f"Unsupported save format: {save_format}")
    if isinstance(loaded, pd.DataFrame):
        # axis=1 keeps a single-row, single-column frame as a Series instead of a scalar.
        loaded = loaded.squeeze(axis=1)
    return loaded


def _check_roundtrip(loaded, original, file_path, save_format):
    """Raise AssertionError unless ``loaded`` faithfully reproduces ``original``.

    The checks use explicit raises instead of ``assert`` so they still run under ``python -O``.
    """
    if type(loaded) is not type(original):  # type should be pd.Series
        raise AssertionError(f"Type mismatch: {type(loaded)} vs {type(original)}")
    if loaded.shape != original.shape:
        raise AssertionError(f"Shape mismatch: {loaded.shape} vs {original.shape}")
    if not loaded.index.equals(original.index):
        raise AssertionError("Index mismatch")
    if loaded.name != original.name:
        raise AssertionError(f"Series names mismatch: {loaded.name} vs {original.name}")
    # equal_nan: features that are NaN in the original must not fail the check.
    if not np.allclose(loaded.values, original.values, rtol=1e-8, atol=1e-8, equal_nan=True):
        raise AssertionError("Data values mismatch")
    # CSV round-trips through text, so only approximate equality is expected there.
    if save_format != "csv" and not loaded.equals(original):
        raise AssertionError(f"Data mismatch: {loaded} vs {original}")
    print("All tests passed for", file_path)


def prompt_save(prompt_buf, output_path, data_name, save_format="pth.tar", ifTest=False):
    """Save each column of each split's prompts to its own file in the output path.

    Files are written to
    ``<output_path>/<split>/<data_name>_<column>_prompt.<save_format>``. Any missing
    directories are created; ``data_name`` may contain a sub-directory.

    Args:
        prompt_buf (dict): dictionary containing prompt DataFrames for the
            "train", "val" and "test" splits (one column per series)
        output_path (str): path to save the prompt data
        data_name (str): name of the dataset
        save_format (str): format to save the prompt data: "pth.tar", "csv" or "npz"
        ifTest (bool): if True, test if the saved prompt data is loaded back. Can
            be used during generating data.

    Returns:
        None. Nothing is written when the "train" prompts have exactly one column.

    Raises:
        ValueError: if ``save_format`` is not supported.
        AssertionError: if ``ifTest`` is True and a saved file does not round-trip.
    """
    if prompt_buf["train"].shape[1] == 1:
        return None
    if save_format not in ("pth.tar", "csv", "npz"):
        raise ValueError(f"Unsupported save format: {save_format}")

    for split in ["train", "val", "test"]:
        split_dir = os.path.join(output_path, split)
        # items() yields (column label, Series named by that label): the same
        # pairs as .T.iterrows(), without transposing the frame.
        for index, prompt_data in prompt_buf[split].items():
            file_path = os.path.join(split_dir, f"{data_name}_{index}_prompt.{save_format}")
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            print("Type of prompt data", type(prompt_data), "Shape of prompt data", prompt_data.shape)
            _write_prompt(prompt_data, file_path, save_format)
            if ifTest:
                _check_roundtrip(_read_prompt(file_path, save_format), prompt_data, file_path, save_format)
            print("Export", file_path, prompt_data.shape)
    return None
def _relative_name(path, root_path, extension):
    """Strip ``root_path`` (as a prefix) and ``extension`` (as a suffix) from ``path``.

    Leading separators left over from a root given without a trailing slash
    are removed. Otherwise the name would look absolute, and a later
    ``os.path.join(output_path, ..., name)`` would discard ``output_path``.
    Separators are normalised to "/".
    """
    name = path
    if root_path and name.startswith(root_path):
        name = name[len(root_path):].lstrip("/" + os.sep)
    name = name.replace(os.sep, "/")
    if name.endswith(extension):
        name = name[: -len(extension)]
    return name


def data_import(path, format="feather", anomaly=False, root_path=None):
    """Load one dataset file and derive its name and sub-directory relative to the dataset root.

    Args:
        path (str): path of the file to load.
        format (str): "feather" reads a feather file; any other value reads a CSV file.
        anomaly (bool): not used by this function; kept for backward compatibility.
        root_path (str | None): dataset root to strip from ``path``. When None,
            the module-level ``root_path`` global set by the ``__main__`` block is
            used, or "" if that global is not defined.

    Returns:
        tuple: ``(data, data_name, data_dir)``.
        - ``data``: for feather files, the ``value`` column; for CSV files, the
          whole frame with any "date" and "anomaly" columns dropped.
        - ``data_name``: the path relative to the root, without extension, using
          "/" separators (e.g. "ETT-small/ETTh1").
        - ``data_dir``: the directory part of ``data_name``, or "" for a file
          directly under the root.
    """
    if root_path is None:
        # Backward compatibility: the script entry point sets a global root_path.
        root_path = globals().get("root_path") or ""

    if format == "feather":
        data = read_feather(path).value
        extension = ".feather"
    else:
        data = read_csv(path)
        extension = ".csv"
        if "date" in data.columns:
            data = data.drop("date", axis=1)
        if "anomaly" in data.columns:
            data = data.drop("anomaly", axis=1)
            print("Drop anomaly column")

    data_name = _relative_name(path, root_path, extension)
    # dirname instead of rfind("/"): rfind returns -1 for top-level files, which truncated the name.
    data_dir = os.path.dirname(data_name)
    return data, data_name, data_dir
def create_data_dir(dir_name):
    """Create ``dir_name``, including any missing parent directories, if it does not exist.

    Args:
        dir_name (str): directory path to create.
    """
    # makedirs handles nested targets such as the default
    # "./prompt_bank/stat-prompt/prompt_data_split/" and nested dataset
    # sub-directories, which os.mkdir rejects when a parent is missing.
    # exist_ok=True also removes the exists-then-create race.
    os.makedirs(dir_name, exist_ok=True)
if __name__ == "__main__":
    args = get_args()
    # Keep these at module level: data_import() reads the global ``root_path``.
    root_path = args.root_path
    output_path = args.output_path
    save_format = args.save_format
    ifTest = args.test

    # Without --dataset_name, every sub-directory of the root is treated as a dataset.
    dataset_name = args.dataset_name or [
        entry for entry in os.listdir(root_path)
        if os.path.isdir(os.path.join(root_path, entry))
    ]
    if not dataset_name:
        print("No dataset found in the root path.")
        sys.exit(0)

    # Recursively collect every CSV file under the selected dataset directories.
    data_path_buf = [
        os.path.join(walk_root, fname)
        for dataset_dir in (os.path.join(root_path, ds) for ds in dataset_name)
        for walk_root, _subdirs, fnames in os.walk(dataset_dir)
        for fname in fnames
        if fname.endswith(".csv")
    ]
    print(data_path_buf)

    create_data_dir(output_path)
    splits = ("train", "val", "test")
    for path in data_path_buf:
        data, data_name, data_dir = data_import(path, "csv")
        print("*****************Data Name: ", data_name)
        if data.shape[0] < 20:
            print(path, "Skip too short time-series data.", data.shape)
            continue
        print("Import", path, "data shape", data.shape)

        for split in splits:
            create_data_dir(os.path.join(output_path, split))
        for split in splits:
            create_data_dir(os.path.join(output_path, split, data_dir))

        prompt_data_buf = prompt_generation(data, data_name)
        if prompt_data_buf is None:
            continue
        prompt_save(prompt_data_buf, output_path, data_name, save_format, ifTest)