# blenderbot_chat / app.py
import gradio as gr
from transformers import BlenderbotForConditionalGeneration, BlenderbotTokenizerFast
# Alternative loaders that also work:
#   AutoTokenizer / AutoModelForSeq2SeqLM with "facebook/blenderbot-400M-distill"
#   the larger checkpoint "facebook/blenderbot-3B"
mname = "facebook/blenderbot-400M-distill"

tokenizer = BlenderbotTokenizerFast.from_pretrained(mname)
model = BlenderbotForConditionalGeneration.from_pretrained(mname)
# ----------- new chat -----------
print(mname + " model loaded")
def predict(input, history=[]):
    history.append(input)
    # Keep only the last three turns as context, joined with sentence separator tokens
    context = "</s> <s>".join(str(turn) for turn in history[-3:])
    input_ids = tokenizer([context], return_tensors="pt", max_length=512, truncation=True)
    next_reply_ids = model.generate(**input_ids, max_length=512, pad_token_id=tokenizer.eos_token_id)
    response = tokenizer.batch_decode(next_reply_ids, skip_special_tokens=True)[0]
    history.append(response)
    # Pair up (user, bot) turns for the chatbot component
    chat_pairs = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
    return chat_pairs, history
demo = gr.Interface(fn=predict, inputs=["text", "state"], outputs=["chatbot", "state"])
demo.launch(share=True)