callmeclover commited on
Commit
2faeeee
1 Parent(s): 0495618

Create models/blenderbot.py

Browse files
Files changed (1) hide show
  1. models/blenderbot.py +33 -0
models/blenderbot.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from:
3
+ https://mandgie.medium.com/how-to-build-your-own-chatbot-f5848ebcba8d
4
+ """
5
+
6
+
7
+ from transformers import BlenderbotSmallTokenizer, BlenderbotSmallForConditionalGeneration
8
+ import os
9
+
10
+
11
+ class BlenderBot:
12
+ def __init__(
13
+ self,
14
+ model_name: str ='facebook/blenderbot_small-90M',
15
+ ):
16
+ if not os.path.exists('./models/blenderbot'):
17
+ BlenderbotSmallForConditionalGeneration.from_pretrained(model_name).save_pretrained('./models/blenderbot')
18
+ BlenderbotSmallTokenizer.from_pretrained(model_name).save_pretrained('./models/blenderbot')
19
+
20
+ self.model = BlenderbotSmallForConditionalGeneration.from_pretrained('./models/blenderbot')
21
+ self.tokenizer = BlenderbotSmallTokenizer.from_pretrained('./models/blenderbot')
22
+
23
+ def __call__(self, inputs: str) -> str:
24
+ inputs_tokenized = self.tokenizer(inputs, return_tensors='pt')
25
+ reply_ids = self.model.generate(**inputs_tokenized)
26
+ reply = self.tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
27
+
28
+ return reply
29
+
30
+ def run(self):
31
+ while True:
32
+ user_input = input("User: ")
33
+ print("Bot:", self(user_input))