|
import torch
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
)
|
|
|
class MultilingualChatbot:
    """Chatbot that routes prompts to a language-specific GPT-2 model.

    Holds one causal-LM model + tokenizer pair per supported language
    ('en', 'es', 'fr'); unknown language codes fall back to English.
    Models are downloaded from the HuggingFace Hub on construction.
    """

    def __init__(self):
        # One checkpoint per ISO 639-1 language code.
        self.models = {
            'en': GPT2LMHeadModel.from_pretrained("microsoft/DialoGPT-medium"),
            'es': GPT2LMHeadModel.from_pretrained("DeepESP/gpt2-spanish"),
            'fr': GPT2LMHeadModel.from_pretrained("asi/gpt-fr-cased-small"),
        }
        self.tokenizers = {
            'en': GPT2Tokenizer.from_pretrained("microsoft/DialoGPT-medium"),
            'es': GPT2Tokenizer.from_pretrained("DeepESP/gpt2-spanish"),
            'fr': GPT2Tokenizer.from_pretrained("asi/gpt-fr-cased-small"),
        }
        # GPT-2 tokenizers ship without a pad token; reuse EOS so that
        # generate() has a valid pad id.
        for tokenizer in self.tokenizers.values():
            tokenizer.pad_token = tokenizer.eos_token

    def generate_response(self, prompt, src_lang):
        """Generate a sampled reply to *prompt* in language *src_lang*.

        Args:
            prompt: User utterance (plain text).
            src_lang: Language code ('en', 'es', 'fr'); anything else
                falls back to the English model.

        Returns:
            The generated continuation only (prompt tokens stripped),
            decoded without special tokens.
        """
        model = self.models.get(src_lang, self.models['en'])
        tokenizer = self.tokenizers.get(src_lang, self.tokenizers['en'])

        # Tokenize via __call__ so we also get an attention mask: with
        # pad_token == eos_token, generate() cannot infer the mask and
        # warns about unreliable output otherwise.
        encoded = tokenizer(prompt + tokenizer.eos_token, return_tensors='pt')
        input_ids = encoded['input_ids'].to(model.device)
        attention_mask = encoded['attention_mask'].to(model.device)

        # Inference only — don't build an autograd graph.
        with torch.no_grad():
            chat_history_ids = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=1000,
                pad_token_id=tokenizer.eos_token_id,
                no_repeat_ngram_size=3,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.7,
                num_return_sequences=1,
                length_penalty=1.0,
                repetition_penalty=1.2,
            )

        # Slice off the prompt so only the new tokens are returned.
        return tokenizer.decode(
            chat_history_ids[:, input_ids.shape[-1]:][0],
            skip_special_tokens=True,
        )
|
|
|
def initialize_chatbot():
    """Construct and return a ready-to-use MultilingualChatbot.

    Note: triggers model downloads from the HuggingFace Hub on first use.
    """
    return MultilingualChatbot()


def get_chatbot_response(chatbot, prompt, src_lang):
    """Delegate to *chatbot* to answer *prompt* in language *src_lang*.

    Thin convenience wrapper around MultilingualChatbot.generate_response
    for callers that prefer a functional interface.
    """
    return chatbot.generate_response(prompt, src_lang)