import argparse
import os
import sys

import numpy as np
import pandas as pd
import torch
import tsfel
from pandas import read_csv, read_feather
def get_args():
    parser = argparse.ArgumentParser(description='LTSM')
    parser.add_argument('--root_path', type=str, default='./datasets/', help='Root path for datasets')
    parser.add_argument('--output_path', type=str, default='./prompt_bank/stat-prompt/prompt_data_split/', help='Output path for prompt data')
    parser.add_argument('--dataset_name', nargs='+', default=[], help='Dataset folder names under root_path; defaults to all folders found there')
    parser.add_argument('--save_format', type=str, default='pth.tar', choices=["pth.tar", "csv", "npz"], help='The format to save the data')
    parser.add_argument('--test', action='store_true', help='If set, reload each saved prompt file and verify it matches the data in memory')
    parser.add_argument('--data_path', type=str, default='dataset/weather.csv')
    parser.add_argument('--data', type=str, default='custom')
    parser.add_argument('--freq', type=str, default="h")
    parser.add_argument('--target', type=str, default='OT')
    parser.add_argument('--embed', type=str, default='timeF')
    parser.add_argument('--percent', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--max_len', type=int, default=-1)
    parser.add_argument('--seq_len', type=int, default=512)
    parser.add_argument('--pred_len', type=int, default=96)
    parser.add_argument('--label_len', type=int, default=48)
    parser.add_argument('--features', type=str, default='M')
    args = parser.parse_args()
    return args 
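
# Example invocation (dataset names are illustrative; any folders under
# --root_path work, and --dataset_name takes a space-separated list):
#   python ./ltsm/prompt_reader/stat_prompt/prompt_generate_split.py \
#       --root_path ./datasets/ --dataset_name ETT-small illness --save_format pth.tar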

def prompt_prune(pt):
    """Remove the "0_FFT mean coefficient_*" features from an extracted prompt.

    Args:
        pt (pd.DataFrame): single-row feature frame produced by tsfel
    """
    fft_cols = [col for col in pt.columns if col.startswith("0_FFT mean coefficient")]
    return pt.drop(columns=fft_cols)

def prompt_generation_single(ts):
    """Generate prompt data for a single input time series.

    Args:
        ts (pd.Series): input time-series data

    Returns:
        pd.DataFrame: single-row frame of tsfel statistical features, pruned by prompt_prune
    """
    cfg = tsfel.get_features_by_domain()
    prompt = tsfel.time_series_features_extractor(cfg, ts)
    prompt = prompt_prune(prompt)
    return prompt
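
# Illustrative usage, assuming tsfel's full feature configuration, which
# (after pruning) yields the 133 features the buffers below hard-code:
#   prompt = prompt_generation_single(pd.Series(np.sin(np.arange(1000.0))))
#   prompt.shape  # -> (1, 133)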

def prompt_generation(ts, ts_name):
    """Generate train/val/test prompt data for the input time-series data.

    ETT datasets are split 0.6/0.2/0.2 along the time axis; all other datasets
    are split 0.7/0.1/0.2. Univariate inputs are skipped.

    Args:
        ts (pd.DataFrame): input time-series data
        ts_name (str): name of the time-series data

    Returns:
        dict or None: maps "train"/"val"/"test" to a (133 x n_series) DataFrame,
        or None if ts has a single column
    """
    if ts.shape[1] == 1:
        return None
    else:
        # "/" cannot appear in file names, so sanitize the column names up front
        column_name = [name.replace("/", "-") for name in list(ts.columns)]
        # tsfel's full-domain config yields 133 features per series after pruning
        prompt_buf_train = pd.DataFrame(np.zeros((133, ts.shape[1])), columns=column_name)
        prompt_buf_val = pd.DataFrame(np.zeros((133, ts.shape[1])), columns=column_name)
        prompt_buf_test = pd.DataFrame(np.zeros((133, ts.shape[1])), columns=column_name)
        for index, col in ts.T.iterrows():
            # ETT datasets use a 0.6/0.2/0.2 split; everything else 0.7/0.1/0.2
            ts_len = len(ts)
            if "ETT" in ts_name:
                train_ratio, val_ratio = 0.6, 0.2
            else:
                train_ratio, val_ratio = 0.7, 0.1
            t1 = int(train_ratio * ts_len)
            t2 = t1 + int(val_ratio * ts_len)
            ts_train = col[:t1]
            ts_val = col[t1:t2].reset_index(drop=True)
            ts_test = col[t2:].reset_index(drop=True)
            prompt_train = prompt_generation_single(ts_train)
            prompt_val = prompt_generation_single(ts_val)
            prompt_test = prompt_generation_single(ts_test)
            prompt_buf_train[index.replace("/", "-")] = prompt_train.T.values
            prompt_buf_val[index.replace("/", "-")] = prompt_val.T.values
            prompt_buf_test[index.replace("/", "-")] = prompt_test.T.values
    prompt_buf_total = {"train": prompt_buf_train, "val": prompt_buf_val, "test": prompt_buf_test}
    print(prompt_buf_total)
    return prompt_buf_total 
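
# Illustrative usage: the returned buffer maps each split name to a
# (133 features x n_series) DataFrame:
#   buf = prompt_generation(df, "ETTh1")
#   buf["train"].shape  # -> (133, df.shape[1])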

def prompt_save(prompt_buf, output_path, data_name, save_format="pth.tar", ifTest=False):
    """Save prompts to per-split files under the output path.

    Args:
        prompt_buf (dict): dictionary containing prompts for the train, val, and test splits
        output_path (str): path to save the prompt data
        data_name (str): name of the dataset
        save_format (str): format to save the prompt data ("pth.tar", "csv", or "npz")
        ifTest (bool): if True, reload each saved file and verify it matches the data in memory
    """
    if prompt_buf["train"].shape[1] == 1:
        # ipdb.set_trace()
        return None
        # prompt_train_fname = os.path.join(prompt_train_data_dir, data_name + "_prompt.pth.tar")
        # prompt_train = prompt_buf["train"]
        # print("Export", prompt_train_fname, prompt_train.shape)
        #
        # prompt_val_fname = os.path.join(prompt_val_data_dir, data_name + "_prompt.pth.tar")
        # prompt_val = prompt_buf["val"]
        # torch.save(prompt_val, prompt_val_fname)
        # print("Export", prompt_val_fname, prompt_val.shape)
        #
        # prompt_test_fname = os.path.join(prompt_test_data_dir, data_name + "_prompt.pth.tar")
        # prompt_test = prompt_buf["test"]
        # torch.save(prompt_test, prompt_test_fname)
        # print("Export", prompt_test_fname, prompt_test.shape)
    else:
        for split in ["train", "val", "test"]:
            split_dir = os.path.join(output_path, split)
            for index, col in prompt_buf[split].T.iterrows():
                file_name = f"{data_name}_{index}_prompt.{save_format}"
                file_path = os.path.join(split_dir, file_name)
                # each row of the transposed buffer is one series' 133-feature vector
                prompt_data = col
                print("Type of prompt data", type(prompt_data), "Shape of prompt data", prompt_data.shape)
                if save_format == "pth.tar":
                    torch.save(prompt_data, file_path)
                elif save_format == "csv":
                    prompt_data.to_csv(file_path, index=False)  # use csv may result in some loss of precision
                elif save_format == "npz":
                    np.savez(file_path, data=prompt_data.values, index=prompt_data.index, name=prompt_data.name)
                else:
                    raise ValueError(f"Unsupported save format: {save_format}")
                if ifTest:
                    if save_format == "pth.tar":
                        load_data = torch.load(file_path)
                    elif save_format == "csv":
                        load_data = pd.read_csv(file_path)
                        if isinstance(load_data, pd.DataFrame):
                            load_data = load_data.squeeze()
                    elif save_format == "npz":
                        loaded = np.load(file_path)
                        load_data = pd.Series(data=loaded["data"], index=loaded["index"], name=loaded["name"].item())
                        if isinstance(load_data, pd.DataFrame):
                            load_data = load_data.squeeze()
                    assert type(load_data) == type(prompt_data), f"Type mismatch: {type(load_data)} vs {type(prompt_data)}"  # type should be pd.Series
                    assert load_data.shape == prompt_data.shape, f"Shape mismatch: {load_data.shape} vs {prompt_data.shape}"
                    assert load_data.index.equals(prompt_data.index), "Index mismatch"
                    assert load_data.name == prompt_data.name, f"Series names mismatch: {load_data.name} vs {prompt_data.name}"
                    assert np.allclose(load_data.values, prompt_data.values, rtol=1e-8, atol=1e-8), "Data values mismatch"
                    if save_format != "csv":
                        assert load_data.equals(prompt_data), f"Data mismatch: {load_data} vs {prompt_data}"
                    print("All tests passed for", file_path)
                print("Export", file_path, prompt_data.shape) 

def data_import(path, root_path, file_format="feather"):
    """Load one data file and derive its dataset-relative name and directory.

    Args:
        path (str): full path to the data file
        root_path (str): dataset root prefix, stripped from path to form data_name
        file_format (str): "feather" or "csv"
    """
    if file_format == "feather":
        data = read_feather(path)
        data_name = path.replace(root_path, "").replace(".feather", "")
        # feather files are expected to store the series in a "value" column
        data = data.value
    else:
        data = read_csv(path)
        data_name = path.replace(root_path, "").replace(".csv", "")
        if "date" in data.columns:
            data = data.drop("date", axis=1)
        if "anomaly" in data.columns:
            data = data.drop("anomaly", axis=1)
            print("Drop anomaly column")
    data_dir = os.path.dirname(data_name)
    return data, data_name, data_dir
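
# Illustrative usage for a csv file under the dataset root:
#   data, data_name, data_dir = data_import("./datasets/ETT-small/ETTh1.csv", "./datasets/", "csv")
#   # data_name -> "ETT-small/ETTh1", data_dir -> "ETT-small"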

def create_data_dir(dir_name):
    """Create dir_name (including any missing parents) if it does not exist."""
    os.makedirs(dir_name, exist_ok=True)

if __name__ == "__main__":
    args = get_args()
    root_path = args.root_path
    output_path = args.output_path
    dataset_name = args.dataset_name
    save_format = args.save_format
    ifTest = args.test
    # if the dataset_name is not provided, use all the datasets in the dataset root path
    if not dataset_name:
        dataset_name = [name for name in os.listdir(root_path) if os.path.isdir(os.path.join(root_path, name))]
    if len(dataset_name) == 0:
        print("No dataset found in the root path.")
        sys.exit(0)
    dataset_fullname = [os.path.join(root_path, name) for name in dataset_name]
    data_path_buf = []
    for dataset_dir in dataset_fullname:
        for root, dirs, files in os.walk(dataset_dir):
            for file_name in files:
                if file_name.endswith(".csv"):
                    file_path = os.path.join(root, file_name)
                    data_path_buf.append(file_path)
    print(data_path_buf)
    create_data_dir(output_path)
    for path_idx, path in enumerate(data_path_buf):
        data, data_name, data_dir = data_import(path, root_path, "csv")
        print("*****************Data Name: ", data_name)
        # print("Data Shape:", data.shape)
        if data.shape[0] < 20:
            print(path, "Skip too short time-series data.", data.shape)
            continue
        else:
            print("Import", path, "data shape", data.shape)
        create_data_dir(os.path.join(output_path, "train"))
        create_data_dir(os.path.join(output_path, "val"))
        create_data_dir(os.path.join(output_path, "test"))
        create_data_dir(os.path.join(output_path, "train", data_dir))
        create_data_dir(os.path.join(output_path, "val", data_dir))
        create_data_dir(os.path.join(output_path, "test", data_dir))
        prompt_data_buf = prompt_generation(data, data_name)
        if prompt_data_buf is not None:
            prompt_save(prompt_data_buf, output_path, data_name, save_format, ifTest)