# Hugging Face Spaces status: Runtime error
import asyncio
import logging

from telegram import Update
from telegram.ext import (
    CallbackContext,
)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Tool metadata: a display name and a free-text description of this module's
# capability (presumably read by a tool/plugin registry elsewhere — TODO confirm).
NAME = "Conversation"
DESCRIPTION = """
Useful for building up conversation.
Input: A normal chat text
Output: A text
"""
# Conversation-state id, presumably used as a ConversationHandler state key.
# The trailing comma unpacks the single value produced by range(1), so
# GET_CON is the int 0 rather than the range object itself.
(GET_CON,) = range(1)
class Conversation:
    """Telegram chat handler that answers messages with Microsoft's GODEL model.

    The tokenizer and model are loaded via ``from_pretrained`` when this class
    body executes (i.e. at import time) and are shared by every instance.
    """

    tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/GODEL-v1_1-large-seq2seq")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "microsoft/GODEL-v1_1-large-seq2seq")

    def generate(self, instruction, knowledge, dialog, *, max_length=128,
                 min_length=8, top_p=0.9, do_sample=True):
        """Generate a single reply from the GODEL model.

        Args:
            instruction: Task instruction placed at the start of the prompt.
            knowledge: Optional grounding text. An empty string (or None)
                omits the ``[KNOWLEDGE]`` section entirely.
            dialog: Iterable of dialog turns (strings), joined with
                ``" EOS "`` separators to form the ``[CONTEXT]`` section.
            max_length: Maximum generated length, forwarded to ``model.generate``.
            min_length: Minimum generated length, forwarded to ``model.generate``.
            top_p: Nucleus-sampling threshold, forwarded to ``model.generate``.
            do_sample: Whether to sample, forwarded to ``model.generate``.

        Returns:
            The decoded reply text with special tokens removed.
        """
        knowledge = f"[KNOWLEDGE] {knowledge}" if knowledge else ""
        context = " EOS ".join(dialog)
        query = f"{instruction} [CONTEXT] {context} {knowledge}"
        input_ids = self.tokenizer(query, return_tensors="pt").input_ids
        outputs = self.model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            top_p=top_p,
            do_sample=do_sample,
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    async def process_conversation(self, update: Update, context: CallbackContext) -> None:
        """Reply to an incoming Telegram text message with a generated response.

        Each message is answered on its own: no history is kept between
        calls, so the dialog context is only the current message.

        Args:
            update: The incoming Telegram update.
            context: Handler callback context (not used here).

        Returns:
            None, i.e. no explicit conversation-state transition.
        """
        message = update.message
        # Updates that carry no message, or a message without text (e.g. media),
        # give the model nothing to respond to; previously these either
        # raised AttributeError or prompted the model with the string "None".
        if message is None or not message.text:
            return None

        instruction = 'Instruction: given a dialog context, you need to response empathically.'
        # model.generate is synchronous and compute-heavy; running it in the
        # default executor keeps the bot's event loop responsive meanwhile.
        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(
            None, self.generate, instruction, "", [message.text])
        await message.reply_text(text)
        return None