Spaces:
Runtime error
Runtime error
rexthecoder
commited on
Commit
·
23dd0e5
1
Parent(s):
5dfb7cd
chore:
Browse files- src/agent/tools/conversation.py +24 -37
src/agent/tools/conversation.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1 |
import logging
|
2 |
from telegram import Update
|
|
|
3 |
import torch
|
4 |
-
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BlenderbotForConditionalGeneration, BlenderbotForCausalLM, BlenderbotTokenizer
|
5 |
-
|
6 |
from telegram.ext import (
|
7 |
CallbackContext,
|
8 |
)
|
@@ -19,8 +18,10 @@ GET_CON = range(1)
|
|
19 |
|
20 |
|
21 |
class Conversation():
|
22 |
-
tokenizer =
|
23 |
-
|
|
|
|
|
24 |
|
25 |
# async def talk(self, message: str):
|
26 |
# logging.info(f"{message}")
|
@@ -29,41 +30,27 @@ class Conversation():
|
|
29 |
# bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
|
30 |
# chat_history_ids =self.model.generate(bot_input_ids, max_length=1000, pad_token_id=self.tokenizer.eos_token_id)
|
31 |
# return "{}".format(self.tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True))
|
32 |
-
def predict(self, input, history=[]):
|
33 |
-
# tokenize the new input sentence
|
34 |
-
new_user_input_ids = self.tokenizer.encode(input + self.tokenizer.eos_token, return_tensors='pt')
|
35 |
-
|
36 |
-
# append the new user input tokens to the chat history
|
37 |
-
bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
|
38 |
-
|
39 |
-
# generate a response
|
40 |
-
history = self.model.generate(bot_input_ids, max_length=1000, pad_token_id=self.tokenizer.eos_token_id).tolist()
|
41 |
-
|
42 |
-
# convert the tokens to text, and then split the responses into the right format
|
43 |
-
response = self.tokenizer.decode(history[0]).replace("<s>","").split("</s>")
|
44 |
-
#response = [(response[i], response[i+1]) for i in range(0, len(response), 2)] # convert to tuples of list
|
45 |
-
return f'{response}'
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
|
62 |
async def process_conversation(self, update: Update, context: CallbackContext) -> int:
|
63 |
message = update.message.text
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
text = self.
|
69 |
await update.message.reply_text(f'{text}')
|
|
|
1 |
import logging
|
2 |
from telegram import Update
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
4 |
import torch
|
|
|
|
|
5 |
from telegram.ext import (
|
6 |
CallbackContext,
|
7 |
)
|
|
|
18 |
|
19 |
|
20 |
class Conversation():
|
21 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
22 |
+
"microsoft/GODEL-v1_1-large-seq2seq")
|
23 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
24 |
+
"microsoft/GODEL-v1_1-large-seq2seq")
|
25 |
|
26 |
# async def talk(self, message: str):
|
27 |
# logging.info(f"{message}")
|
|
|
30 |
# bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
|
31 |
# chat_history_ids =self.model.generate(bot_input_ids, max_length=1000, pad_token_id=self.tokenizer.eos_token_id)
|
32 |
# return "{}".format(self.tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
+
def generate(self, instruction, knowledge, dialog):
|
35 |
+
if knowledge != '':
|
36 |
+
knowledge = '[KNOWLEDGE] ' + knowledge
|
37 |
+
dialog = ' EOS '.join(dialog)
|
38 |
+
query = f"{instruction} [CONTEXT] {dialog} {knowledge}"
|
39 |
+
input_ids = self.tokenizer(f"{query}", return_tensors="pt").input_ids
|
40 |
+
outputs = self.model.generate(
|
41 |
+
input_ids, max_length=128,
|
42 |
+
min_length=8,
|
43 |
+
top_p=0.9,
|
44 |
+
do_sample=True,
|
45 |
+
)
|
46 |
+
output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
47 |
+
return output
|
48 |
|
49 |
async def process_conversation(self, update: Update, context: CallbackContext) -> int:
|
50 |
message = update.message.text
|
51 |
+
instruction = f'Instruction: given a dialog context, you need to response empathically.'
|
52 |
+
knowledge = ''
|
53 |
+
dialog = []
|
54 |
+
dialog .append(message)
|
55 |
+
text = self.generate(instruction, knowledge, dialog)
|
56 |
await update.message.reply_text(f'{text}')
|