rexthecoder committed
Commit 23dd0e5 · 1 Parent(s): 5dfb7cd
Files changed (1)
src/agent/tools/conversation.py +24 -37
src/agent/tools/conversation.py CHANGED
@@ -1,8 +1,7 @@
 import logging
 from telegram import Update
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BlenderbotForConditionalGeneration, BlenderbotForCausalLM, BlenderbotTokenizer
-
 from telegram.ext import (
     CallbackContext,
 )
@@ -19,8 +18,10 @@ GET_CON = range(1)


 class Conversation():
-    tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
-    model = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill",add_cross_attention=False)
+    tokenizer = AutoTokenizer.from_pretrained(
+        "microsoft/GODEL-v1_1-large-seq2seq")
+    model = AutoModelForSeq2SeqLM.from_pretrained(
+        "microsoft/GODEL-v1_1-large-seq2seq")

     # async def talk(self, message: str):
     #     logging.info(f"{message}")
@@ -29,41 +30,27 @@ class Conversation():
     #     bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
     #     chat_history_ids =self.model.generate(bot_input_ids, max_length=1000, pad_token_id=self.tokenizer.eos_token_id)
     #     return "{}".format(self.tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True))
-    def predict(self, input, history=[]):
-        # tokenize the new input sentence
-        new_user_input_ids = self.tokenizer.encode(input + self.tokenizer.eos_token, return_tensors='pt')
-
-        # append the new user input tokens to the chat history
-        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
-
-        # generate a response
-        history = self.model.generate(bot_input_ids, max_length=1000, pad_token_id=self.tokenizer.eos_token_id).tolist()
-
-        # convert the tokens to text, and then split the responses into the right format
-        response = self.tokenizer.decode(history[0]).replace("<s>","").split("</s>")
-        #response = [(response[i], response[i+1]) for i in range(0, len(response), 2)] # convert to tuples of list
-        return f'{response}'

-    # def generate(self, instruction, knowledge, dialog):
-    #     if knowledge != '':
-    #         knowledge = '[KNOWLEDGE] ' + knowledge
-    #     dialog = ' EOS '.join(dialog)
-    #     query = f"{instruction} [CONTEXT] {dialog} {knowledge}"
-    #     input_ids = self.tokenizer(f"{query}", return_tensors="pt").input_ids
-    #     outputs = self.model.generate(
-    #         input_ids, max_length=128,
-    #         min_length=8,
-    #         top_p=0.9,
-    #         do_sample=True,
-    #     )
-    #     output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-    #     return output
+    def generate(self, instruction, knowledge, dialog):
+        if knowledge != '':
+            knowledge = '[KNOWLEDGE] ' + knowledge
+        dialog = ' EOS '.join(dialog)
+        query = f"{instruction} [CONTEXT] {dialog} {knowledge}"
+        input_ids = self.tokenizer(f"{query}", return_tensors="pt").input_ids
+        outputs = self.model.generate(
+            input_ids, max_length=128,
+            min_length=8,
+            top_p=0.9,
+            do_sample=True,
+        )
+        output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return output

     async def process_conversation(self, update: Update, context: CallbackContext) -> int:
         message = update.message.text
-        # instruction = f'Instruction: given a dialog context, you need to response empathically.'
-        # knowledge = ''
-        # dialog = []
-        # dialog .append(message)
-        text = self.predict(message)
+        instruction = f'Instruction: given a dialog context, you need to response empathically.'
+        knowledge = ''
+        dialog = []
+        dialog .append(message)
+        text = self.generate(instruction, knowledge, dialog)
         await update.message.reply_text(f'{text}')
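
For reference, the new code path can be exercised outside the Telegram handler. The sketch below is a minimal standalone version of the GODEL flow this commit introduces; the model name, prompt format, and generation parameters are taken from the diff above, while the chat_once helper and the sample dialog are illustrative only, not part of the repository.

    # Minimal standalone sketch of the GODEL generation path added in this commit.
    # Assumes transformers and torch are installed; chat_once is a hypothetical
    # helper, not a function in this repository.
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    tokenizer = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
    model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")

    def chat_once(instruction, knowledge, dialog):
        # GODEL prompts tag grounding text with [KNOWLEDGE] and join dialog
        # turns with ' EOS ', mirroring Conversation.generate above.
        if knowledge != '':
            knowledge = '[KNOWLEDGE] ' + knowledge
        query = f"{instruction} [CONTEXT] {' EOS '.join(dialog)} {knowledge}"
        input_ids = tokenizer(query, return_tensors="pt").input_ids
        # Nucleus sampling with the same bounds the commit uses.
        outputs = model.generate(input_ids, max_length=128, min_length=8,
                                 top_p=0.9, do_sample=True)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    print(chat_once(
        'Instruction: given a dialog context, you need to response empathically.',
        '',
        ['I failed my exam today.'],
    ))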
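
A note on the committed generation settings: with do_sample=True and top_p=0.9, replies are stochastic, so the same message can yield different responses across runs; dropping do_sample would make process_conversation deterministic at the cost of flatter replies. This is an observation about the parameters in the diff, not a recommendation documented by the GODEL authors.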