# NOTE(review): non-Python banner text captured together with this file; kept as comments.
# Spaces: Runtime error
# Runtime error
# coding=utf-8
import sys

# Add the parent directory to the module search path before the remaining imports run.
sys.path.append("../")
from collections import defaultdict
from .utils import is_float, load_txt
import random

# Seed the global `random` generator at import time so its sequence is deterministic.
random.seed(1234)
class CreateDataset:
    """Hold the prompt text, length limit, example separator and sample stores.

    Attributes set by ``__init__``:
        prompt: whatever ``load_txt`` returns for ``../prompt/dataset_character.txt``.
        max_input_len: upper bound for the input length.
        example_split_flag: separator placed between examples
            (a newline, twenty dashes, a newline).
        dataset: ``defaultdict(list)`` that starts empty.
        manual_dataset: list that starts empty.
    """

    def __init__(self, max_input_len=1500):
        """Load the prompt file and initialise empty containers.

        Args:
            max_input_len: maximum input length; it must be smaller than
                (seq-length) - (max-gen-length).
        """
        # NOTE(review): the path is relative; presumably it is resolved against the
        # current working directory by `load_txt` -- confirm against the helper.
        self.prompt = load_txt("../prompt/dataset_character.txt")

        # Must stay below (seq-length) - (max-gen-length).
        self.max_input_len = max_input_len

        self.example_split_flag = "\n" + "-" * 20 + "\n"

        self.manual_dataset = []
        self.dataset = defaultdict(list)
def choose_examples(similar_examples,
                    max_length,
                    train_flag=False,
                    dialog=None,
                    example_split_flag=f"\n{'-' * 20}\n"):
    """Pick de-duplicated examples, in order, until a character budget is reached.

    Args:
        similar_examples: either one string holding examples joined by
            ``example_split_flag``, or an iterable of examples. Each example
            may be a string, an iterable of strings (joined with newlines),
            or a 2-item list/tuple ``(score, example)`` whose first item
            satisfies the project helper ``is_float``; in that case only the
            second item is used.
        max_length: maximum total number of characters (``len``) of the
            selected examples including the separators between them.
        train_flag: when true, an example is dropped if it contains, or is
            contained in, ``dialog`` or an example that was already kept.
            When false, only exact duplicates are dropped.
        dialog: text compared against each example when ``train_flag`` is set.
        example_split_flag: separator used to split a string input and to
            join the selected examples.

    Returns:
        The selected examples joined by ``example_split_flag`` and stripped
        of surrounding whitespace; an empty string when nothing fits.

    Raises:
        TypeError: if an example cannot be joined into a string (for instance
            it contains non-string items).
    """
    if isinstance(similar_examples, str):
        # String input is only split and stripped; no de-duplication is applied.
        stripped_parts = (part.strip() for part in similar_examples.split(example_split_flag))
        # Empty pieces carry no content and would only add dangling separators.
        candidates = [part for part in stripped_parts if part]
    else:
        candidates = []
        for example in similar_examples:
            if isinstance(example, (list, tuple)) and len(example) == 2 and is_float(example[0]):
                # The entry carries a score; keep only the example itself.
                example = example[1]

            if isinstance(example, str):
                # Joining a plain string would insert a newline between every character.
                example = example.strip()
            else:
                try:
                    example = "\n".join(example).strip()
                except TypeError as err:
                    raise TypeError(f"example: {example}") from err

            if not example:
                # An empty string is a substring of every string, so keeping it would
                # make the containment checks below reject all later examples.
                continue

            if train_flag and dialog and (example in dialog or dialog in example):
                continue

            if train_flag:
                # Partial overlaps with an already kept example count as duplicates.
                if any(example in kept or kept in example for kept in candidates):
                    continue
            elif example in candidates:
                continue

            candidates.append(example)

    results = []
    total_length = 0
    for example in candidates:
        # Every example after the first also costs one separator.
        total_length += len(example) + (len(example_split_flag) if results else 0)
        if total_length > max_length:
            break
        results.append(example)

    return example_split_flag.join(results).strip()