DialoGPT-uk / pipeline.py
import os
from typing import List, Dict

import ctranslate2
import transformers
class PreTrainedPipeline():
    def __init__(self, path: str):
        # Init DialoGPT: int8 CTranslate2 generator + the original tokenizer
        dialogpt_path = os.path.join(path, "dialogpt")
        self.generator = ctranslate2.Generator(dialogpt_path, device="cpu", compute_type="int8")
        self.tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")

        # Init M2M100: int8 CTranslate2 translator + the original tokenizer
        m2m100_path = os.path.join(path, "m2m100")
        self.translator = ctranslate2.Translator(m2m100_path, device="cpu", compute_type="int8")
        self.m2m100_tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/m2m100_418M")
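    # The two model directories above are assumed to be int8 CTranslate2 conversions of the
    # original checkpoints, e.g. produced with CTranslate2's converter (output paths are
    # illustrative, not taken from this repository):
    #   ct2-transformers-converter --model microsoft/DialoGPT-medium --output_dir dialogpt --quantization int8
    #   ct2-transformers-converter --model facebook/m2m100_418M --output_dir m2m100 --quantization int8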
    def __call__(self, inputs: str) -> List[Dict]:
        # Translate the Ukrainian input to English
        en_text = self.m2m100(inputs, "uk", "en")
        # Generate an English reply with DialoGPT
        generated_text = self.dialogpt(en_text)
        # Translate the reply back to Ukrainian
        uk_text = self.m2m100(generated_text, "en", "uk")
        return [{"generated_text": uk_text}]
    def dialogpt(self, inputs: str) -> str:
        # Append the EOS token so the prompt looks like a finished dialogue turn
        text = inputs + self.tokenizer.eos_token
        start_tokens = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))
        # Generate; the prompt tokens are included in the returned sequence
        results = self.generator.generate_batch([start_tokens])
        output = results[0].sequences[0]
        # Keep only the tokens that come after the prompt's EOS token
        tokens = self.tokenizer.convert_tokens_to_ids(output)
        eos_index = tokens.index(self.tokenizer.eos_token_id)
        answer_tokens = tokens[eos_index + 1:]
        generated_text = self.tokenizer.decode(answer_tokens)
        return generated_text
    def m2m100(self, inputs: str, from_lang: str, to_lang: str) -> str:
        # Tokenize with the source language set so the correct language token is prepended
        self.m2m100_tokenizer.src_lang = from_lang
        source = self.m2m100_tokenizer.convert_ids_to_tokens(self.m2m100_tokenizer.encode(inputs))
        # Force the target language token as the first token of the translation
        target_prefix = [self.m2m100_tokenizer.lang_code_to_token[to_lang]]
        results = self.translator.translate_batch([source], target_prefix=[target_prefix])
        # Drop the language token before decoding the hypothesis
        target = results[0].hypotheses[0][1:]
        translated_text = self.m2m100_tokenizer.decode(self.m2m100_tokenizer.convert_tokens_to_ids(target))
        return translated_text
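# Minimal local usage sketch (not part of the Hugging Face pipeline contract): assumes the
# current directory holds the converted "dialogpt" and "m2m100" model folders.
if __name__ == "__main__":
    pipeline = PreTrainedPipeline(".")
    print(pipeline("Привіт! Як справи?"))  # -> [{"generated_text": "..."}]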