|
import transformers |
|
import ctranslate2 |
|
|
|
from typing import List, Dict |
|
import os |
|
|
|
|
|
|
|
class PreTrainedPipeline:
    """Serve a CTranslate2-converted DialoGPT model for single-turn replies.

    The converted model is loaded from the ``dialogpt`` subdirectory of the
    directory passed to the constructor and runs on CPU with int8 weights.
    Text is tokenized with a Hugging Face ``AutoTokenizer``.
    """

    def __init__(self, path: str, tokenizer_name: str = "microsoft/DialoGPT-medium"):
        """Load the CTranslate2 generator and the tokenizer.

        Args:
            path: Directory that contains the converted model in a
                ``dialogpt`` subdirectory.
            tokenizer_name: Identifier or local path handed to
                ``transformers.AutoTokenizer.from_pretrained``. Defaults to the
                tokenizer this pipeline has always used.
        """
        dialogpt_path = os.path.join(path, "dialogpt")

        # CPU inference with int8-quantized weights.
        self.generator = ctranslate2.Generator(dialogpt_path, device="cpu", compute_type="int8")

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)

    @staticmethod
    def _strip_prompt(prompt_tokens: List[str], output_tokens: List[str]) -> List[str]:
        """Return *output_tokens* without a leading copy of *prompt_tokens*.

        Depending on the CTranslate2 version and options, generation results
        may or may not echo the prompt. Checking for the prefix, instead of
        slicing unconditionally, works either way.
        """
        prefix_len = len(prompt_tokens)
        if list(output_tokens[:prefix_len]) == list(prompt_tokens):
            return list(output_tokens[prefix_len:])
        return list(output_tokens)

    def __call__(self, inputs: str) -> List[Dict]:
        """Generate the model's reply to *inputs*.

        Args:
            inputs: The user utterance.

        Returns:
            A one-element list ``[{"generated_text": reply}]``. ``reply`` holds
            only the generated response: the prompt and special tokens are
            removed.
        """
        # DialoGPT separates conversation turns with the EOS token, so the
        # prompt is the utterance followed by EOS.
        text = inputs + self.tokenizer.eos_token

        # CTranslate2 consumes string tokens, not ids.
        start_tokens = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))

        # Uses CTranslate2's default decoding options.
        results = self.generator.generate_batch([start_tokens])

        output = results[0].sequences[0]

        # Drop the echoed prompt (user text + EOS) so only the reply is decoded.
        reply_tokens = self._strip_prompt(start_tokens, output)

        generated_text = self.tokenizer.decode(
            self.tokenizer.convert_tokens_to_ids(reply_tokens),
            # Avoid leaking literal markers such as "<|endoftext|>" into the text.
            skip_special_tokens=True,
        )

        return [{"generated_text": generated_text}]