Spaces:
Runtime error
Runtime error
# -*-coding:utf-8 -*- | |
import re | |
import json | |
import random | |
import pandas as pd | |
class Instance(object):
    """
    A pool of (input, output) samples that can be split into few-shot
    training groups and a disjoint evaluation set.

    By default, few-shot prompting is used for both generation and evaluation.
    """

    def __init__(self, loader=None):
        """
        Args:
            loader: zero-argument callable returning the sample list, where each
                sample is an (input, output) pair. ``None`` gives an empty pool
                instead of failing on ``None()``.
        """
        self.samples = loader() if loader is not None else []
        self.n_few_shot = 0
        self.n_train = 0
        self.n_eval = 0
        self.train_iter = None
        self.train_samples = []
        self.eval_samples = []

    def n_sample(self):
        """Return the total number of samples in the pool."""
        return len(self.samples)

    def sample(self, n_train, n_few_shot, n_eval):
        """
        Randomly draw disjoint training and evaluation samples (no replacement).

        Args:
            n_train: number of few-shot training groups.
            n_few_shot: number of samples in each training group.
            n_eval: number of evaluation samples.

        Raises:
            ValueError: if ``n_train * n_few_shot + n_eval`` exceeds the pool size.
        """
        self.n_train = n_train
        self.n_few_shot = n_few_shot
        self.n_eval = n_eval
        # Each training group holds n_few_shot samples.
        n_train_total = n_train * n_few_shot
        if n_train_total + n_eval > len(self.samples):
            raise ValueError(f'Train + Eval > total samples {len(self.samples)}, decrease them')
        index = random.sample(range(len(self.samples)), n_train_total + n_eval)
        train_index, eval_index = index[:n_train_total], index[n_train_total:]
        self.train_samples = [self.samples[i] for i in train_index]
        self.eval_samples = [self.samples[i] for i in eval_index]

    def get_train_iter(self):
        """Yield the training samples as ``n_train`` consecutive groups of ``n_few_shot``."""
        for i in range(self.n_train):
            start = i * self.n_few_shot
            yield self.train_samples[start:start + self.n_few_shot]

    @staticmethod
    def display(samples):
        """Format samples as ``'input >> output'`` lines, each terminated by a newline."""
        return ''.join(f'{s[0]} >> {s[1]}\n' for s in samples)

    @classmethod
    def from_file(cls, loader):
        """Build an instance from a zero-argument loader function."""
        return cls(loader)

    @classmethod
    def from_list(cls, tuple_list):
        """Build an instance directly from a list of (input, output) tuples."""
        def func():
            return tuple_list
        return cls(func)
def load_paraphase(file='./ape/data/paraphase_train.csv'):
    """
    Load paraphrase pairs from a GBK-encoded CSV file.

    The first row is read as the header by ``pd.read_csv``. The first column
    of each data row becomes the input and the second column the output.

    Args:
        file: path or file-like object accepted by ``pd.read_csv``.

    Returns:
        list of (input, output) tuples, one per data row.
    """
    df = pd.read_csv(file, encoding='GBK')
    # Positional column access via .iloc; integer keys on a row Series
    # (row[0]) are deprecated as positional lookups in recent pandas.
    return list(zip(df.iloc[:, 0], df.iloc[:, 1]))
def load_intent(file='./ape/data/intent_train.csv'):
    """
    Load search-intent pairs from a UTF-8, tab-separated file.

    The first row is read as the header by ``pd.read_csv``. The first column
    of each data row becomes the input and the second column the output.

    Args:
        file: path or file-like object accepted by ``pd.read_csv``.

    Returns:
        list of (input, output) tuples, one per data row.
    """
    df = pd.read_csv(file, encoding='UTF8', sep='\t')
    # Positional column access via .iloc; integer keys on a row Series
    # (row[0]) are deprecated as positional lookups in recent pandas.
    return list(zip(df.iloc[:, 0], df.iloc[:, 1]))
def load_qa(file='./ape/data/qa_train.json'):
    """
    Load QA-generation samples from a UTF-8 JSON file.

    The file holds a list of records, each with a ``'text'`` string and an
    ``'annotations'`` list of dicts with ``'Q'`` and ``'A'`` keys. Only one QA
    pair is kept per record, otherwise the prompt easily exceeds the model's
    input length. The kept pair is the one whose JSON serialization is
    shortest, taking the first one on ties.

    Args:
        file: path to the JSON file.

    Returns:
        list of (text, qa_json) tuples, where qa_json is
        ``json.dumps({'问题': Q, '回答': A}, ensure_ascii=False)``.
        Records with no annotations are skipped.
    """
    with open(file, encoding='UTF8') as f:
        raw_data = json.load(f)
    data = []
    for record in raw_data:
        qa_pairs = [
            json.dumps({'问题': ann["Q"], '回答': ann["A"]}, ensure_ascii=False)
            for ann in record['annotations']
        ]
        if not qa_pairs:
            # No target to pair the text with.
            continue
        # min() returns the first shortest item, like sorted(..., key=len)[0].
        data.append((record['text'], min(qa_pairs, key=len)))
    return data
def upload_file(file):
    """
    Read (input, output) pairs from a UTF-8 text file.

    Each non-blank line must contain exactly two fields separated by a single
    space. The line break is stripped so it does not leak into the output
    field, and blank lines are skipped.

    Args:
        file: path to the text file.

    Returns:
        list of (input, output) string tuples.

    Raises:
        ValueError: if a non-blank line does not split into exactly two fields.
    """
    tuple_list = []
    with open(file, 'r', encoding='UTF-8') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.rstrip('\r\n')
            if not line.strip():
                continue
            fields = line.split(' ')
            if len(fields) != 2:
                raise ValueError(
                    f'{file}:{line_no}: expected "input output" separated by a single space, got {line!r}'
                )
            tuple_list.append((fields[0], fields[1]))
    return tuple_list
def load_entity(file='./ape/data/entity_train.json', label='DRUG_EFFICACY', max_len=200):
    """
    Load single-entity-type extraction samples from a UTF-8 JSON file.

    Each record's ``'text'`` is first truncated to ``max_len`` characters to
    keep it short, and then stripped of all whitespace. Its ``'labels'`` are
    filtered to a single entity type, splitting the task into one task per
    type. Each label is a sequence whose second element is the entity type
    and whose last element is the entity string. An entity is kept only if
    it still occurs in the processed text.

    Args:
        file: path to the JSON file.
        label: entity type to extract.
        max_len: number of characters kept from each text before whitespace
            removal.

    Returns:
        list of (text, entities_json) tuples, where entities_json is a JSON
        array (possibly empty) of entity strings.
    """
    with open(file, encoding='UTF8') as f:
        raw_data = json.load(f)
    whitespace = re.compile(r'\s+')
    data = []
    for record in raw_data:
        text = whitespace.sub('', record['text'][:max_len])
        entities = [lab[-1] for lab in record['labels']
                    if lab[1] == label and lab[-1] in text]
        data.append((text, json.dumps(entities, ensure_ascii=False)))
    return data
# Registry from task name to the zero-argument loader that produces its
# (input, output) samples.
LoadFactory = dict(
    paraphase=load_paraphase,
    search_intent=load_intent,
    qa_generation=load_qa,
    entity=load_entity,
)
if __name__ == '__main__':
    # Demo: draw few-shot training groups from a file-backed pool and from an
    # in-memory list of antonym pairs.
    n_train, few_shot, n_eval = 2, 3, 2

    csv_pool = Instance.from_file(load_paraphase)
    csv_pool.sample(n_train, few_shot, n_eval)
    print(csv_pool.display(csv_pool.train_samples))

    antonym_pairs = [
        ('sane', 'insane'), ('direct', 'indirect'), ('informally', 'formally'),
        ('unpopular', 'popular'), ('subtractive', 'additive'),
        ('nonresidential', 'residential'), ('inexact', 'exact'),
        ('uptown', 'downtown'), ('incomparable', 'comparable'),
        ('powerful', 'powerless'), ('gaseous', 'solid'),
        ('evenly', 'unevenly'), ('formality', 'informality'),
        ('deliberately', 'accidentally'), ('off', 'on'),
    ]
    list_pool = Instance.from_list(antonym_pairs)
    list_pool.sample(n_train, few_shot, n_eval)
    print(list_pool.display(list_pool.train_samples))

    group_iter = list_pool.get_train_iter()
    print(next(group_iter))