# EasyPrompt / utils.py
# Trace2333's picture
# initial commit
# c700ce7
# raw
# history blame
# 1.81 kB
import os
import json
from typing import Dict
from torch.utils.data import Dataset
from datasets import Dataset as AdvancedDataset
from transformers import AutoTokenizer, AutoModelForCausalLM
# Filenames of the prepared OpenPrompt JSON files, resolved against the data
# directory handed to the loader functions below.
# NOTE(review): the "train" constant points at test_openprompt.json and the
# "test" constant at train_openprompt.json — looks swapped; confirm intent
# before relying on which split each loader returns.
DEFAULT_TRAIN_DATA_NAME = "test_openprompt.json"
DEFAULT_TEST_DATA_NAME = "train_openprompt.json"
# Column-oriented dict-of-lists JSON consumed by get_dict_dataset.
DEFAULT_DICT_DATA_NAME = "dataset_openprompt.json"
def get_open_prompt_data(path_for_data):
    """Load the train/test OpenPrompt JSON splits from a data directory.

    Parameters
    ----------
    path_for_data : str | os.PathLike
        Directory containing the split files named by
        ``DEFAULT_TRAIN_DATA_NAME`` / ``DEFAULT_TEST_DATA_NAME``.

    Returns
    -------
    tuple
        ``(train_data, test_data)`` exactly as decoded by ``json.load``.

    Raises
    ------
    FileNotFoundError
        If either split file is missing.
    json.JSONDecodeError
        If either file is not valid JSON.
    """
    # Explicit UTF-8: the platform-default encoding (e.g. cp1252 on Windows)
    # can mis-decode UTF-8 JSON; json files are UTF-8 per spec.
    with open(os.path.join(path_for_data, DEFAULT_TRAIN_DATA_NAME),
              encoding="utf-8") as f:
        train_data = json.load(f)
    with open(os.path.join(path_for_data, DEFAULT_TEST_DATA_NAME),
              encoding="utf-8") as f:
        test_data = json.load(f)
    return train_data, test_data
def get_tok_and_model(path_for_model):
    """Load a tokenizer/model pair from a local checkpoint directory.

    No download is attempted: a missing directory is treated as "no cached
    model" and raises ``RuntimeError``.

    Returns a ``(tokenizer, model)`` tuple.
    """
    if not os.path.exists(path_for_model):
        raise RuntimeError("no cached model.")
    # Left padding so generation prompts align at the sequence end.
    tokenizer = AutoTokenizer.from_pretrained(path_for_model, padding_side='left')
    # 50256 is GPT-2's eos token id, reused here as the pad token for
    # open-ended generation — presumably a GPT-2-family checkpoint; verify
    # before using with other model families.
    tokenizer.pad_token_id = 50256
    causal_lm = AutoModelForCausalLM.from_pretrained(path_for_model)
    return tokenizer, causal_lm
class OpenPromptDataset(Dataset):
    """Minimal map-style Dataset over an in-memory, indexable collection.

    Samples are returned exactly as stored — no transformation is applied.
    """

    def __init__(self, data) -> None:
        """Keep a reference to *data* (any sized, indexable collection)."""
        super().__init__()
        self.data = data

    def __len__(self):
        """Number of samples held."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the raw sample stored at *index*."""
        return self.data[index]
def get_dataset(train_data, test_data):
    """Wrap the raw train/test splits as a pair of OpenPromptDataset objects.

    Returns ``(train_dataset, test_dataset)``.
    """
    return OpenPromptDataset(train_data), OpenPromptDataset(test_data)
def get_dict_dataset(path_for_data):
    """Load the column-oriented dict dataset JSON from a data directory.

    Parameters
    ----------
    path_for_data : str | os.PathLike
        Directory containing ``DEFAULT_DICT_DATA_NAME``.

    Returns
    -------
    The decoded JSON object — presumably a dict of column lists suitable
    for ``get_advance_dataset``; verify against the file's producer.

    Raises
    ------
    FileNotFoundError
        If the file is missing.
    json.JSONDecodeError
        If the file is not valid JSON.
    """
    # Explicit UTF-8 to avoid platform-default decoding of UTF-8 JSON.
    with open(os.path.join(path_for_data, DEFAULT_DICT_DATA_NAME),
              encoding="utf-8") as f:
        dict_data = json.load(f)
    return dict_data
def get_advance_dataset(dict_data):
    """Convert a column-oriented mapping into a ``datasets.Dataset``.

    Parameters
    ----------
    dict_data : dict
        Mapping of column name to a list of per-row values, as accepted by
        ``datasets.Dataset.from_dict``.

    Returns
    -------
    datasets.Dataset

    Raises
    ------
    RuntimeError
        If *dict_data* is not a dict.
    """
    # isinstance against the builtin `dict`, not the typing.Dict alias —
    # same runtime behavior, idiomatic and future-proof.
    if not isinstance(dict_data, dict):
        raise RuntimeError("dict_data is not a dict.")
    return AdvancedDataset.from_dict(dict_data)