AIdeaText committed on
Commit
c0c171a
1 Parent(s): 7054950

Update modules/chatbot.py

Browse files
Files changed (1) hide show
  1. modules/chatbot.py +41 -6
modules/chatbot.py CHANGED
@@ -3,14 +3,49 @@ import torch
3
 
4
class MultilingualChatbot:
    """Chatbot backed by the mBART-50 many-to-many multilingual model."""

    def __init__(self):
        # Model and tokenizer come from the same checkpoint; keep the name
        # in one place so they cannot drift apart.
        checkpoint = "facebook/mbart-large-50-many-to-many-mmt"
        self.model = MBartForConditionalGeneration.from_pretrained(checkpoint)
        self.tokenizer = MBart50TokenizerFast.from_pretrained(checkpoint)

    def generate_response(self, prompt, src_lang):
        """Generate a reply to *prompt*, treating it as text in *src_lang*.

        *src_lang* is an mBART language code assigned to the tokenizer
        before encoding; the first decoded sequence is returned as a string.
        """
        self.tokenizer.src_lang = src_lang
        inputs = self.tokenizer(prompt, return_tensors="pt")
        outputs = self.model.generate(**inputs, max_length=100)
        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return decoded[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
def initialize_chatbot():
    """Build and return a fresh MultilingualChatbot instance."""
    chatbot = MultilingualChatbot()
    return chatbot
 
3
 
4
class MultilingualChatbot:
    """Chatbot that routes each request to a language-specific GPT-2 model.

    Supported language codes: 'en' (DialoGPT), 'es', 'fr'. Any other code
    falls back to the English model/tokenizer.
    """

    def __init__(self):
        # NOTE: eagerly downloads/loads all three checkpoints up front;
        # construction is slow and memory-heavy by design.
        self.models = {
            'en': GPT2LMHeadModel.from_pretrained("microsoft/DialoGPT-medium"),
            'es': GPT2LMHeadModel.from_pretrained("DeepESP/gpt2-spanish"),
            'fr': GPT2LMHeadModel.from_pretrained("asi/gpt-fr-cased-small"),
        }
        self.tokenizers = {
            'en': GPT2Tokenizer.from_pretrained("microsoft/DialoGPT-medium"),
            'es': GPT2Tokenizer.from_pretrained("DeepESP/gpt2-spanish"),
            'fr': GPT2Tokenizer.from_pretrained("asi/gpt-fr-cased-small"),
        }
        # GPT-2 tokenizers define no pad token; reuse EOS so generate() can pad.
        for tokenizer in self.tokenizers.values():
            tokenizer.pad_token = tokenizer.eos_token

    def generate_response(self, prompt, src_lang):
        """Generate a reply to *prompt* with the model for *src_lang*.

        Unsupported language codes fall back to English. Returns only the
        newly generated text — the prompt tokens are sliced off before
        decoding.
        """
        model = self.models.get(src_lang, self.models['en'])
        tokenizer = self.tokenizers.get(src_lang, self.tokenizers['en'])

        # Encode prompt + EOS and build an explicit attention mask: because
        # pad_token == eos_token, generate() cannot reliably infer the mask
        # on its own (Transformers warns about exactly this configuration).
        encoded = tokenizer(prompt + tokenizer.eos_token, return_tensors='pt')
        input_ids = encoded['input_ids'].to(model.device)
        attention_mask = encoded['attention_mask'].to(model.device)

        # Inference only — disable autograd bookkeeping.
        with torch.no_grad():
            chat_history_ids = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=1000,
                pad_token_id=tokenizer.eos_token_id,
                no_repeat_ngram_size=3,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.7,
                num_return_sequences=1,
                length_penalty=1.0,
                repetition_penalty=1.2,
            )
        # Drop the echoed prompt; decode only the generated continuation.
        return tokenizer.decode(
            chat_history_ids[:, input_ids.shape[-1]:][0],
            skip_special_tokens=True,
        )
43
+
44
def initialize_chatbot():
    """Build and return a fresh MultilingualChatbot instance."""
    chatbot = MultilingualChatbot()
    return chatbot
46
+
47
def get_chatbot_response(chatbot, prompt, src_lang):
    """Convenience wrapper: delegate straight to *chatbot*'s generator.

    Equivalent to ``chatbot.generate_response(prompt, src_lang)``.
    """
    return chatbot.generate_response(prompt, src_lang)
49
 
50
def initialize_chatbot():
    """Build and return a fresh MultilingualChatbot instance.

    NOTE(review): this duplicates the ``initialize_chatbot`` defined earlier
    in this module; whichever definition comes later silently shadows the
    other. One of the two should be removed.
    """
    return MultilingualChatbot()