Spaces:
Runtime error
Runtime error
import os | |
import json | |
from typing import Dict | |
from torch.utils.data import Dataset | |
from datasets import Dataset as AdvancedDataset | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# File names that the loaders in this module look up inside the data directory.
# Each constant names the file of the split it stands for (train -> train_*,
# test -> test_*), so the training split is never read from the test file.
DEFAULT_TRAIN_DATA_NAME = "train_openprompt.json"
DEFAULT_TEST_DATA_NAME = "test_openprompt.json"
DEFAULT_DICT_DATA_NAME = "dataset_openprompt.json"
def get_open_prompt_data(path_for_data):
    """Load the train and test splits from JSON files in a directory.

    Args:
        path_for_data: Directory containing the files named by
            ``DEFAULT_TRAIN_DATA_NAME`` and ``DEFAULT_TEST_DATA_NAME``.

    Returns:
        A ``(train_data, test_data)`` tuple holding the parsed JSON content
        of the two files.

    Raises:
        OSError: If either file cannot be opened (e.g. it does not exist).
        json.JSONDecodeError: If either file is not valid JSON.
    """
    train_path = os.path.join(path_for_data, DEFAULT_TRAIN_DATA_NAME)
    test_path = os.path.join(path_for_data, DEFAULT_TEST_DATA_NAME)

    # JSON text is UTF-8; pin the encoding so decoding does not depend on the
    # platform's locale default (which breaks non-ASCII prompts on some hosts).
    with open(train_path, encoding="utf-8") as train_file:
        train_data = json.load(train_file)
    with open(test_path, encoding="utf-8") as test_file:
        test_data = json.load(test_file)

    return train_data, test_data
def get_tok_and_model(path_for_model):
    """Load a tokenizer and a causal language model from a local path.

    Args:
        path_for_model: Filesystem path of the saved model; it must exist.

    Returns:
        A ``(tokenizer, model)`` tuple. The tokenizer pads on the left and
        has its pad token id set to 50256.

    Raises:
        RuntimeError: If ``path_for_model`` does not exist.
    """
    model_is_cached = os.path.exists(path_for_model)
    if not model_is_cached:
        raise RuntimeError("no cached model.")

    tokenizer = AutoTokenizer.from_pretrained(path_for_model, padding_side='left')
    # Fixed pad id used as the default for open-ended generation.
    # NOTE(review): 50256 is presumably the GPT-2 end-of-text id -- confirm
    # before using this loader with a tokenizer that has a different vocab.
    tokenizer.pad_token_id = 50256

    language_model = AutoModelForCausalLM.from_pretrained(path_for_model)
    return tokenizer, language_model
class OpenPromptDataset(Dataset):
    """Map-style dataset that exposes an indexable container unchanged."""

    def __init__(self, data) -> None:
        """Keep a reference to ``data``; no copy or preprocessing is done."""
        super().__init__()
        # Stored as-is: length and item access are delegated straight to it.
        self.data = data

    def __len__(self) -> int:
        """Return the number of entries in the wrapped container."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the entry of the wrapped container at ``index``."""
        return self.data[index]
def get_dataset(train_data, test_data):
    """Wrap the train and test data in ``OpenPromptDataset`` objects.

    Args:
        train_data: Container holding the training examples.
        test_data: Container holding the test examples.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple, in that order.
    """
    return OpenPromptDataset(train_data), OpenPromptDataset(test_data)
def get_dict_dataset(path_for_data):
    """Load the dictionary-form dataset from a JSON file in a directory.

    Args:
        path_for_data: Directory containing the file named by
            ``DEFAULT_DICT_DATA_NAME``.

    Returns:
        The parsed JSON content of that file.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        json.JSONDecodeError: If the file is not valid JSON.
    """
    dict_path = os.path.join(path_for_data, DEFAULT_DICT_DATA_NAME)
    # JSON text is UTF-8; pin the encoding so decoding does not depend on the
    # platform's locale default.
    with open(dict_path, encoding="utf-8") as dict_file:
        return json.load(dict_file)
def get_advance_dataset(dict_data):
    """Build an ``AdvancedDataset`` from a dictionary.

    Args:
        dict_data: Dictionary handed to ``AdvancedDataset.from_dict``.

    Returns:
        The dataset produced by ``AdvancedDataset.from_dict(dict_data)``.

    Raises:
        RuntimeError: If ``dict_data`` is not a ``dict`` instance.
    """
    if isinstance(dict_data, dict):
        return AdvancedDataset.from_dict(dict_data)
    raise RuntimeError("dict_data is not a dict.")