# blenderbot_chat / app.py
# Mandar Patil
# Add req file
# 042d014
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import gradio as gr
from transformers import BlenderbotTokenizer
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration, BlenderbotConfig
from transformers import BlenderbotTokenizerFast
import contextlib
# Hugging Face checkpoint to serve. Other BlenderBot checkpoints
# (e.g. "facebook/blenderbot-3B") may be substituted here.
mname = "facebook/blenderbot-400M-distill"

# Load the tokenizer and model once, at import time, so that every chat
# request handled by this module reuses the same instances.
tokenizer = BlenderbotTokenizerFast.from_pretrained(mname)
model = BlenderbotForConditionalGeneration.from_pretrained(mname)

print(f"{mname} model loaded")
def predict(input, history=None):
    """Generate BlenderBot's reply to ``input`` and return the updated chat.

    Used as the Gradio ``Interface`` callback: the arguments come from the
    "text" and "state" inputs, and the two return values feed the "chatbot"
    and "state" outputs.

    Args:
        input: The user's latest message. (The name shadows the builtin but is
            kept so the callback's interface is unchanged.)
        history: Flat list of alternating user/bot utterances from earlier
            turns, or ``None`` to start a new conversation. When a list is
            given it is extended in place.

    Returns:
        A ``(pairs, history)`` tuple: ``pairs`` is a list of
        ``(user_message, bot_reply)`` tuples for the chatbot widget, and
        ``history`` is the updated flat list to store back into the state.
    """
    # A fresh list per conversation: a mutable default (``history=[]``) is
    # created once and shared by every call that omits the argument, which
    # leaks one chat into the next and grows without bound.
    if history is None:
        history = []
    history.append(input)

    # Send only the last three utterances (previous user turn, previous bot
    # reply, new user turn). BlenderBot separates turns with "</s> <s>".
    context = "</s> <s>".join(str(turn) for turn in history[-3:])

    # The model's learned position embeddings bound both the encoder input
    # and the generated output length. Going past them would index beyond the
    # embedding table, so clamp to the model's own limit rather than a
    # hard-coded 512.
    # NOTE(review): truncation keeps the start of the context and drops the
    # newest text by default; consider tokenizer.truncation_side = "left".
    max_len = model.config.max_position_embeddings
    inputs = tokenizer(
        [context], return_tensors="pt", max_length=max_len, truncation=True
    )
    reply_ids = model.generate(
        **inputs, max_length=max_len, pad_token_id=tokenizer.eos_token_id
    )
    response = tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
    history.append(response)

    # Pair up (user, bot) utterances for Gradio's "chatbot" output.
    pairs = list(zip(history[::2], history[1::2]))
    return pairs, history
# Web UI: a text box plus per-session state go in; the chat transcript plus
# the updated state come out. The server is started as soon as the module runs.
demo = gr.Interface(
    fn=predict,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
)
demo.launch()