def batch_as_list(a, batch_size: int = 100_000):
    """Split the iterable ``a`` into consecutive lists of at most ``batch_size`` items."""
    batches = []
    for ele in a:
        # Start a new batch when none exists yet or the current one is full.
        if not batches or len(batches[-1]) >= batch_size:
            batches.append([])
        batches[-1].append(ele)
    return batches
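
# Quick sanity check (illustrative, not from the original file):
# batch_as_list(range(5), batch_size=2) == [[0, 1], [2, 3], [4]]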
class Obj:
    """Thin wrapper around a seq2seq model/tokenizer pair for text generation."""

    def __init__(self, model, tokenizer, device="cpu"):
        # Move the model to the requested device so it matches the encoded inputs.
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device
    def predict(
        self,
        source_text: str,
        max_length: int = 512,
        num_return_sequences: int = 1,
        num_beams: int = 2,
        top_k: int = 50,
        top_p: float = 0.95,
        do_sample: bool = True,
        repetition_penalty: float = 2.5,
        length_penalty: float = 1.0,
        early_stopping: bool = True,
        skip_special_tokens: bool = True,
        clean_up_tokenization_spaces: bool = True,
    ):
        """Generate ``num_return_sequences`` output strings for ``source_text``."""
        # Tokenize the source text and move it to the model's device.
        input_ids = self.tokenizer.encode(
            source_text, return_tensors="pt", add_special_tokens=True
        )
        input_ids = input_ids.to(self.device)
        generated_ids = self.model.generate(
            input_ids=input_ids,
            num_beams=num_beams,
            max_length=max_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,  # without this, top_k/top_p have no effect
            num_return_sequences=num_return_sequences,
        )
        # Decode each generated sequence back to plain text.
        preds = [
            self.tokenizer.decode(
                g,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            )
            for g in generated_ids
        ]
        return preds
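
# Minimal usage sketch, not part of the original Space: the "t5-small" checkpoint
# and the prompt below are placeholder assumptions; any Hugging Face seq2seq
# model/tokenizer pair works the same way.
if __name__ == "__main__":
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    obj = Obj(model, tokenizer, device="cpu")
    print(obj.predict("summarize: The quick brown fox jumps over the lazy dog."))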