"""Generate a synthetic EHR dataset with a pretrained generator model.

Pipeline:
  1. Load the generator checkpoint (``./generator.pth``) and the
     index -> clinical-code mapping (``indexToCode.pkl``).
  2. Autoregressively sample synthetic patient records and append them
     to ``./data/raw_data.csv``.
  3. Attach synthetic admission timestamps, translate token indices to
     ICD codes, and write ``./data/middle_state.parquet``.
"""

import argparse
import os
import pickle
import random
import time
from datetime import timedelta

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from tqdm import tqdm

from model import *

parser = argparse.ArgumentParser(description='add parameters')
parser.add_argument('--seed', type=int, default=2024)
parser.add_argument('--samples', type=int, default=20)
# Parse the arguments
args = parser.parse_args()

# Seed every RNG the script touches so runs are reproducible.
SEED = args.seed
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Mapping from token index -> clinical code string.
index_to_code = pickle.load(open("indexToCode.pkl", "rb"))


class Config():
    """Vocabulary layout of the generator's one-hot token space.

    Token-index layout implied by the code below (confirm against the
    training code):
      [0, code_vocab_size)                                medical-code tokens
      [code_vocab_size, code_vocab_size+label_vocab_size) label tokens
      code_vocab_size + label_vocab_size                  start-of-record token
      code_vocab_size + label_vocab_size + 1              end-of-record token
    """
    total_vocab_size = 9100
    code_vocab_size = 9072
    label_vocab_size = 25
    n_ctx = 54


config = Config()

# NOTE(review): torch.load unpickles arbitrary objects -- only load
# checkpoints produced by a trusted source.
model = torch.load('./generator.pth', map_location=device)
model = model.to(device)
model.eval()


def sample_sequence(model, length, context, batch_size, device='cuda', sample=True):
    """Autoregressively sample a batch of one-hot EHR sequences.

    Args:
        model: generator exposing ``sample(seq, sample)`` (from model.py).
        length: maximum sequence length, including the start token.
        context: 1-D start-token vector of size ``config.total_vocab_size``.
        batch_size: number of records sampled in parallel.
        device: torch device used for the sampling tensors.
        sample: forwarded to ``model.sample`` (presumably stochastic vs.
            greedy decoding -- confirm against model.py).

    Returns:
        np.ndarray of shape (batch_size, <=length, total_vocab_size).
    """
    empty = torch.zeros((1, 1, config.total_vocab_size), device=device,
                        dtype=torch.float32).repeat(batch_size, 1, 1)
    context = torch.tensor(context, device=device,
                           dtype=torch.float32).unsqueeze(0).repeat(batch_size, 1)
    prev = context.unsqueeze(1)
    context = None
    # Index of the end-of-record token; loop-invariant, so hoist it.
    end_idx = config.code_vocab_size + config.label_vocab_size + 1
    with torch.no_grad():
        for _ in range(length - 1):
            prev = model.sample(torch.cat((prev, empty), dim=1), sample)
            # Stop early once every record in the batch has emitted the
            # end-of-record token at least once.
            n_finished = torch.sum(
                torch.sum(prev[:, :, end_idx], dim=1).bool().int(), dim=0).item()
            if n_finished == batch_size:
                break
    ehr = prev.cpu().detach().numpy()
    prev = None
    empty = None
    return ehr


def convert_ehr(ehrs, index_to_code=None):
    """Convert sampled one-hot sequences into per-patient visit/label lists.

    Row 0 of each record is the start token and row 1 carries the label
    one-hots; rows 2+ are visits, terminated by the end-of-record token.
    When ``index_to_code`` is given, indices are translated to code
    strings; otherwise raw integer indices are kept.

    Returns:
        list of ``{'visits': [...], 'labels': [...]}`` dicts.
    """
    end_token = config.code_vocab_size + config.label_vocab_size + 1
    ehr_outputs = []
    for ehr in ehrs:
        labels_output = ehr[1][config.code_vocab_size:
                               config.code_vocab_size + config.label_vocab_size]
        if index_to_code is not None:
            # NOTE(review): assumes index_to_code also covers the label
            # index range (offset by code_vocab_size) -- confirm against
            # indexToCode.pkl.
            labels_output = [index_to_code[idx + config.code_vocab_size]
                             for idx in np.nonzero(labels_output)[0]]
        ehr_output = []
        for visit in ehr[2:]:
            visit_output = []
            end = False
            for idx in np.nonzero(visit)[0]:
                if idx < config.code_vocab_size:
                    # Cast to plain int so downstream CSV round-trips are
                    # parseable as Python literals regardless of numpy's
                    # repr for np.int64.
                    visit_output.append(index_to_code[idx]
                                        if index_to_code is not None else int(idx))
                elif idx == end_token:
                    end = True
            if visit_output:
                ehr_output.append(visit_output)
            if end:
                break
        ehr_outputs.append({'visits': ehr_output, 'labels': labels_output})
    return ehr_outputs


generate_batch_size = 10
# Generate Synthetic EHR dataset
generate_num_samples = args.samples
synthetic_ehr_dataset = []

# Start-of-record token as a one-hot vector.
stoken = np.zeros(config.total_vocab_size)
stoken[config.code_vocab_size + config.label_vocab_size] = 1

filepath = './data/raw_data.csv'
for i in tqdm(range(0, generate_num_samples, generate_batch_size)):
    start_time = time.time()
    bs = min(generate_num_samples - i, generate_batch_size)
    batch_synthetic_ehrs = sample_sequence(
        model, config.n_ctx, stoken, batch_size=bs, device=device, sample=True)
    batch_synthetic_ehrs = convert_ehr(batch_synthetic_ehrs)
    execution_time_minutes = (time.time() - start_time) / 60
    print(f"Execution time of every batch of {generate_batch_size} samples: "
          f"{execution_time_minutes:.6f} minutes")
    # Append to the CSV; emit the header only when the file does not
    # exist yet so repeated batches do not repeat it.
    write_header = not os.path.isfile(filepath)
    tdf = pd.DataFrame(batch_synthetic_ehrs)[['visits']]
    tdf.to_csv(filepath, mode='a+', index=False, header=write_header)


def random_dates(n=10):
    """Return ``n`` random timestamps uniformly drawn from [1800, 2000)."""
    start_u = pd.Timestamp(1800, 1, 1).timestamp()
    end_u = pd.Timestamp(2000, 1, 1).timestamp()
    return pd.to_datetime(np.random.randint(start_u, end_u, n), unit='s')


def sample_time_interval(num_visit=2):
    """Sample admission timestamps for ``num_visit`` visits.

    The first visit gets a random date; each subsequent visit falls
    20-59 days after the previous one.

    Returns:
        list of ``pd.Timestamp`` (length ``max(num_visit, 1)``).
    """
    generate_itvdays = [0]
    for _ in range(num_visit - 1):
        generate_itvdays.append(np.random.randint(20, 60) + generate_itvdays[-1])
    sampled_visit_time = random_dates(1)[0]
    return [sampled_visit_time + timedelta(days=item) for item in generate_itvdays]


# synthetic timestamp
SynDF = pd.read_csv('./data/raw_data.csv')
print("The Length of Whole Generated Samples are:{}".format(len(SynDF)))
# NOTE(review): eval() on file contents is dangerous if the CSV can be
# tampered with; it is kept because pre-existing rows may contain
# np.int64 reprs that ast.literal_eval rejects. New rows are written as
# plain ints (see convert_ehr), so migrating to ast.literal_eval is the
# preferred follow-up once old files are regenerated.
SynDF['visits'] = SynDF['visits'].apply(lambda x: eval(x))
SynDF['NumVisit'] = SynDF['visits'].apply(len)
SynDF['ADMITTIME'] = SynDF['NumVisit'].apply(sample_time_interval)
# convert_ehr above ran without index_to_code, so visits hold raw token
# indices; translate them to code strings here.
SynDF['ICD'] = SynDF['visits'].apply(
    lambda x: [[index_to_code[v] for v in vs] if isinstance(vs, list) else []
               for vs in x])
saveDF = SynDF[['ICD', 'ADMITTIME']]
table = pa.Table.from_pandas(saveDF)
pq.write_table(table, './data/middle_state.parquet')