DialoGPT-uk / pipeline.py
theodotus's picture
Added dialogpt pipeline
651c45f
raw
history blame
862 Bytes
import transformers
import ctranslate2
from typing import List, Dict
import os
class PreTrainedPipeline:
    """Text-generation pipeline backed by a CTranslate2-converted DialoGPT model.

    The converted model is loaded from the ``dialogpt`` subdirectory of the
    directory passed to the constructor. Text is tokenized with a Hugging Face
    tokenizer loaded by name.
    """

    def __init__(
        self,
        path: str,
        *,
        device: str = "cpu",
        compute_type: str = "int8",
        tokenizer_name: str = "microsoft/DialoGPT-medium",
    ):
        """Load the CTranslate2 generator and the tokenizer.

        Args:
            path: Directory whose ``dialogpt`` subdirectory holds the
                converted model.
            device: Device passed to ``ctranslate2.Generator``
                (defaults to ``"cpu"``, the previously hard-coded value).
            compute_type: Compute type passed to ``ctranslate2.Generator``
                (defaults to ``"int8"``, the previously hard-coded value).
            tokenizer_name: Name or path given to
                ``transformers.AutoTokenizer.from_pretrained``.
        """
        dialogpt_path = os.path.join(path, "dialogpt")
        self.generator = ctranslate2.Generator(
            dialogpt_path, device=device, compute_type=compute_type
        )
        # NOTE(review): by default the tokenizer is fetched from the
        # microsoft/DialoGPT-medium repo rather than from `path`. This assumes
        # the converted model shares that vocabulary — confirm.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)

    def __call__(self, inputs: str, **generate_kwargs) -> List[Dict]:
        """Generate text continuing a single user message.

        Args:
            inputs: The user's message.
            **generate_kwargs: Extra keyword arguments forwarded unchanged to
                ``ctranslate2.Generator.generate_batch`` (e.g. ``max_length``).
                When none are given, the call is the same as before.

        Returns:
            A one-element list ``[{"generated_text": text}]``, where ``text``
            is the decoded first hypothesis returned by the generator.
        """
        # Append EOS as the end-of-turn marker (presumably the DialoGPT
        # dialogue convention).
        prompt = inputs + self.tokenizer.eos_token
        # The generator is fed token strings, so convert the encoded ids back
        # to tokens.
        token_ids = self.tokenizer.encode(prompt)
        start_tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
        results = self.generator.generate_batch([start_tokens], **generate_kwargs)
        # First (only) batch entry, first hypothesis (presumably the best-scored).
        output_tokens = results[0].sequences[0]
        # NOTE(review): decode() is called without skip_special_tokens. If the
        # returned sequence includes the prompt (ctranslate2's default in
        # recent versions), the text will contain the decoded EOS separator
        # — confirm this is the intended output format.
        generated_text = self.tokenizer.decode(
            self.tokenizer.convert_tokens_to_ids(output_tokens)
        )
        return [{"generated_text": generated_text}]