import os
import json
from typing import Dict

from torch.utils.data import Dataset
from datasets import Dataset as AdvancedDataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# NOTE(review): the TRAIN constant points at "test_openprompt.json" and the
# TEST constant at "train_openprompt.json" — these look swapped. Confirm
# against the files actually produced by the data-preparation step before
# changing; on-disk filenames may already follow this (confusing) convention.
DEFAULT_TRAIN_DATA_NAME = "test_openprompt.json"
DEFAULT_TEST_DATA_NAME = "train_openprompt.json"
DEFAULT_DICT_DATA_NAME = "dataset_openprompt.json"


def get_open_prompt_data(path_for_data):
    """Load the train and test JSON files from *path_for_data*.

    Args:
        path_for_data: Directory containing the default-named JSON files.

    Returns:
        tuple: ``(train_data, test_data)`` as parsed by ``json.load``.
    """
    # Explicit encoding: JSON files are UTF-8 by spec; don't rely on the
    # platform default.
    with open(os.path.join(path_for_data, DEFAULT_TRAIN_DATA_NAME),
              encoding="utf-8") as f:
        train_data = json.load(f)
    with open(os.path.join(path_for_data, DEFAULT_TEST_DATA_NAME),
              encoding="utf-8") as f:
        test_data = json.load(f)
    return train_data, test_data


def get_tok_and_model(path_for_model):
    """Load a cached tokenizer and causal-LM model from a local path.

    Args:
        path_for_model: Filesystem path of a pre-downloaded model directory.

    Returns:
        tuple: ``(tokenizer, model)``.

    Raises:
        RuntimeError: If *path_for_model* does not exist on disk.
    """
    if not os.path.exists(path_for_model):
        raise RuntimeError("no cached model.")
    # Left padding so generation continues directly from the prompt's last
    # token in batched decoding.
    tok = AutoTokenizer.from_pretrained(path_for_model, padding_side='left')
    # 50256 is the GPT-2 EOS token id, used as pad for open-ended generation.
    tok.pad_token_id = 50256  # default for open-ended generation
    model = AutoModelForCausalLM.from_pretrained(path_for_model)
    return tok, model


class OpenPromptDataset(Dataset):
    """Thin ``torch.utils.data.Dataset`` wrapper over an in-memory sequence."""

    def __init__(self, data) -> None:
        super().__init__()
        # data: any indexable sequence of examples (e.g. a list from JSON).
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


def get_dataset(train_data, test_data):
    """Wrap raw train/test sequences in ``OpenPromptDataset`` instances.

    Returns:
        tuple: ``(train_dataset, test_dataset)``.
    """
    train_dataset = OpenPromptDataset(train_data)
    test_dataset = OpenPromptDataset(test_data)
    return train_dataset, test_dataset


def get_dict_dataset(path_for_data):
    """Load the dict-form dataset JSON from *path_for_data*.

    Returns:
        The parsed JSON object (expected to be a column-oriented dict).
    """
    with open(os.path.join(path_for_data, DEFAULT_DICT_DATA_NAME),
              encoding="utf-8") as f:
        dict_data = json.load(f)
    return dict_data


def get_advance_dataset(dict_data):
    """Build a HuggingFace ``datasets.Dataset`` from a column-oriented dict.

    Args:
        dict_data: Mapping of column name -> list of values.

    Returns:
        datasets.Dataset: The constructed dataset.

    Raises:
        RuntimeError: If *dict_data* is not a dict.
    """
    # Use the builtin ``dict`` for the runtime check; ``typing.Dict`` is
    # deprecated as an isinstance() target (behavior is identical here).
    if not isinstance(dict_data, dict):
        raise RuntimeError("dict_data is not a dict.")
    dataset = AdvancedDataset.from_dict(dict_data)
    return dataset