import os
from typing import Dict, List

import ctranslate2
import transformers


class PreTrainedPipeline:
    def __init__(self, path: str):
        # Init DialoGPT (CTranslate2 generator + Hugging Face tokenizer)
        dialogpt_path = os.path.join(path, "dialogpt")
        self.generator = ctranslate2.Generator(dialogpt_path, device="cpu", compute_type="int8")
        self.tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")

        # Init M2M100 (CTranslate2 translator + Hugging Face tokenizer)
        m2m100_path = os.path.join(path, "m2m100")
        self.translator = ctranslate2.Translator(m2m100_path, device="cpu", compute_type="int8")
        self.m2m100_tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/m2m100_418M")

    def __call__(self, inputs: str) -> List[Dict]:
        # Translate the Ukrainian input to English
        en_text = self.m2m100(inputs, "uk", "en")

        # Generate a reply with DialoGPT
        generated_text = self.dialogpt(en_text)

        # Translate the reply back to Ukrainian
        uk_text = self.m2m100(generated_text, "en", "uk")

        return [{"generated_text": uk_text}]

    def dialogpt(self, inputs: str) -> str:
        # Encode the input and append the end-of-sequence token
        text = inputs + self.tokenizer.eos_token
        start_tokens = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))

        # Generate (the prompt is included in the result by default)
        results = self.generator.generate_batch([start_tokens])
        output = results[0].sequences[0]

        # Keep only the answer: the tokens after the prompt's EOS token
        tokens = self.tokenizer.convert_tokens_to_ids(output)
        eos_index = tokens.index(self.tokenizer.eos_token_id)
        answer_tokens = tokens[eos_index + 1:]
        generated_text = self.tokenizer.decode(answer_tokens)

        return generated_text

    def m2m100(self, inputs: str, from_lang: str, to_lang: str) -> str:
        # Tokenize the source text with the source language set
        self.m2m100_tokenizer.src_lang = from_lang
        source = self.m2m100_tokenizer.convert_ids_to_tokens(self.m2m100_tokenizer.encode(inputs))

        # Force the target language token as the decoding prefix
        target_prefix = [self.m2m100_tokenizer.lang_code_to_token[to_lang]]
        results = self.translator.translate_batch([source], target_prefix=[target_prefix])

        # Drop the language prefix token and decode the translation
        target = results[0].hypotheses[0][1:]
        translated_text = self.m2m100_tokenizer.decode(self.m2m100_tokenizer.convert_tokens_to_ids(target))

        return translated_text
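

# Minimal local usage sketch (not part of the original pipeline). It assumes the
# CTranslate2-converted models sit in hypothetical "model/dialogpt" and "model/m2m100"
# subdirectories (e.g. produced with ct2-transformers-converter); adjust the path to
# your own layout. The pipeline takes a Ukrainian string and returns a Ukrainian reply.
if __name__ == "__main__":
    pipeline = PreTrainedPipeline("model")

    # Ukrainian prompt in, Ukrainian reply out
    result = pipeline("Привіт! Як справи?")
    print(result)  # -> [{"generated_text": "..."}]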