import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import logging

# Maximum number of utterances (user and bot turns combined, including the
# new user message) that are fed to the model as context.
MAX_HISTORY = 7
MODEL_PATH = 'llongpre/DialoGPT-small-miles'

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)

# Earlier implementation, kept for reference:
#
# def predict(input, history=[]):
#     # tokenize the new input sentence
#     new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
#
#     # append the new user input tokens to the chat history
#     bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
#
#     # generate a response
#     history = model.generate(
#         bot_input_ids,
#         max_length=1000,
#         pad_token_id=tokenizer.eos_token_id,
#         no_repeat_ngram_size=3,
#         top_p = 0.92,
#         top_k = 50
#     ).tolist()
#
#     # convert the tokens to text, and then split the responses into lines
#     response = tokenizer.decode(history[0]).split("<|endoftext|>")
#     response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)]  # convert to tuples of list
#
#     return response, history

logging.set_verbosity_info()
logger = logging.get_logger("transformers")
logger.info("INFO")


def generate_answer(input, history=None):
    """Generate the bot's reply to ``input`` given the conversation so far.

    Args:
        input: The new user message (plain text).
        history: Conversation state as a list of ``(user_text, bot_text)``
            pairs from earlier calls, or ``None``/empty for a new
            conversation.

    Returns:
        A ``(chat, state)`` tuple. Both elements are the updated list of
        ``(user_text, bot_text)`` pairs: the first is meant for display in
        the chatbot output, the second is the state to pass back in on the
        next call.
    """
    # Copy so neither a shared default nor the caller's list is mutated.
    history = [] if history is None else list(history)

    # Flatten previous turns into one utterance sequence, add the new user
    # message, and keep only the most recent MAX_HISTORY utterances.
    utterances = [text for turn in history for text in turn]
    utterances.append(input)
    context = utterances[-MAX_HISTORY:]
    logger.info(context)

    # Each utterance is terminated with EOS and tokenized separately, then
    # concatenated along the sequence dimension to form the model prompt.
    turn_ids = [
        tokenizer.encode(text + tokenizer.eos_token, return_tensors='pt')
        for text in context
    ]
    bot_input_ids = torch.cat(turn_ids, dim=-1)

    chat_history_ids = model.generate(
        bot_input_ids,
        # Pad with EOS, as the earlier `predict` implementation did, rather
        # than relying on the tokenizer having a dedicated pad token.
        pad_token_id=tokenizer.eos_token_id,
        max_length=1000,
        do_sample=True,
        # top_k=150,  # sample from the top k words sorted descending by probability
        top_p=0.7,  # choose smallest possible words whose cumulative probability exceeds p
        temperature=0.95,  # 0 greedy, inf is random
        no_repeat_ngram_size=3,
    )

    # The generated sequence starts with the prompt; keep only the new tokens.
    response_ids = chat_history_ids[:, bot_input_ids.shape[-1]:]
    output = tokenizer.decode(response_ids[0], skip_special_tokens=True)

    history.append((input, output))
    return history, history
# Build the Gradio UI around generate_answer and start serving it.
# Inputs: the user's text plus the conversation state; outputs: the chatbot
# display plus the updated state.
# NOTE(review): the title says "DialoGPT-large" — confirm this matches the
# model that is actually loaded.
demo = gr.Interface(
    fn=generate_answer,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title="DialoGPT-large",
)
demo.launch()