test2 / modules /chatbot.py
AIdeaText's picture
Update modules/chatbot.py
904f07a verified
raw
history blame
No virus
860 Bytes
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch
class MultilingualChatbot:
    """Text generator backed by the mBART-50 many-to-many checkpoint.

    The pretrained model and tokenizer are loaded once at construction time
    and reused for every call to :meth:`generate_response`.
    """

    MODEL_NAME = "facebook/mbart-large-50-many-to-many-mmt"

    def __init__(self, model_name: str = MODEL_NAME):
        """Load the pretrained model and its matching fast tokenizer.

        Args:
            model_name: Hub identifier or local path of the checkpoint to
                load. Defaults to the mBART-50 many-to-many model.
        """
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)

    def _lang_token_id(self, lang_code: str) -> int:
        """Return the vocabulary id of *lang_code*, rejecting unknown codes."""
        token_id = self.tokenizer.convert_tokens_to_ids(lang_code)
        # Unknown tokens map to the unknown-token id instead of raising, which
        # would silently produce output in an arbitrary language.
        if token_id is None or token_id == self.tokenizer.unk_token_id:
            raise ValueError(f"Unsupported language code: {lang_code!r}")
        return token_id

    def generate_response(self, prompt: str, src_lang: str, tgt_lang=None,
                          *, max_length: int = 100) -> str:
        """Generate text for *prompt* and return the decoded string.

        Args:
            prompt: Input text written in ``src_lang``.
            src_lang: mBART-50 language code of the prompt (e.g. ``"en_XX"``).
            tgt_lang: mBART-50 language code of the output. When ``None``
                the output language is the same as ``src_lang``.
            max_length: Upper bound on the generated sequence length.

        Returns:
            The first generated sequence, decoded with special tokens removed.

        Raises:
            ValueError: If ``src_lang`` or ``tgt_lang`` is not a language
                code known to the tokenizer.
        """
        if tgt_lang is None:
            tgt_lang = src_lang
        self._lang_token_id(src_lang)  # validate before touching tokenizer state
        forced_bos_token_id = self._lang_token_id(tgt_lang)

        self.tokenizer.src_lang = src_lang
        encoded_input = self.tokenizer(prompt, return_tensors="pt")
        # The many-to-many checkpoint needs the target language forced as the
        # first generated token; without it the output language is undefined.
        generated_tokens = self.model.generate(
            **encoded_input,
            forced_bos_token_id=forced_bos_token_id,
            max_length=max_length,
        )
        decoded = self.tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )
        return decoded[0]
def initialize_chatbot():
    """Create a fresh ``MultilingualChatbot`` and hand it back to the caller.

    Constructing the chatbot loads its pretrained model and tokenizer, so
    callers will typically want to do this once and reuse the instance.
    """
    chatbot = MultilingualChatbot()
    return chatbot
def get_chatbot_response(chatbot, prompt, src_lang):
    """Ask *chatbot* for a reply to *prompt* and return that reply.

    This is a thin functional wrapper: it forwards ``prompt`` and
    ``src_lang`` positionally to ``chatbot.generate_response`` and returns
    whatever that method returns.
    """
    reply = chatbot.generate_response(prompt, src_lang)
    return reply