from transformers import BlenderbotForConditionalGeneration, BlenderbotTokenizerFast
import gradio as gr

mname = "facebook/blenderbot-400M-distill"
tokenizer = BlenderbotTokenizerFast.from_pretrained(mname)
model = BlenderbotForConditionalGeneration.from_pretrained(mname)

# -----------new chat-----------
print(mname + ' model loaded')


def predict(message, history=[]):
    # Gradio seeds the 'state' input with this parameter's default, so each
    # session starts from an empty conversation history.
    history.append(message)

    # Use only the last three turns as context. BlenderBot-400M-distill
    # supports at most 128 input positions, so truncate to that limit.
    context = ' '.join(str(elem) for elem in history[-3:])
    inputs = tokenizer([context], return_tensors="pt", max_length=128, truncation=True)

    # The decoder is limited to 128 positions as well, so cap generation there
    # and pad with the model's own pad token.
    next_reply_ids = model.generate(**inputs, max_length=128,
                                    pad_token_id=tokenizer.pad_token_id)
    response = tokenizer.batch_decode(next_reply_ids, skip_special_tokens=True)[0]
    history.append(response)

    # Convert the flat history into (user, bot) tuples for the chatbot widget.
    pairs = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
    return pairs, history


demo = gr.Interface(fn=predict, inputs=["text", "state"], outputs=["chatbot", "state"])
demo.launch()
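
# Usage sketch (assumes this script is saved as app.py; the filename is not
# given in the original): run `python app.py`, then open the local URL Gradio
# prints in the terminal (http://127.0.0.1:7860 by default). Each browser
# session keeps its own conversation through the paired 'state' input/output,
# so histories do not leak between users.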