import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, logging
import torch

# Maximum number of stored turns (user and bot) kept in the conversation state.
MAX_HISTORY = 7
MODEL_PATH = 'llongpre/DialoGPT-small-miles'

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
# Surface the transformers library's informational log messages.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")
logger.info("Loaded tokenizer and model from %s", MODEL_PATH)


def generate_answer(input, history=[]):
    """Generate the next bot reply and return the updated chat display and state."""
    # The state may arrive as None on the first call, depending on the Gradio version.
    history = history or []

    # Encode the new user turn (terminated with EOS) and add it to the list of
    # per-turn token tensors kept in the session state.
    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
    history.append(new_user_input_ids)
    logger.info(history)

    # Keep only the most recent turns so the prompt stays short.
    if len(history) > MAX_HISTORY:
        history = history[-MAX_HISTORY:]
    bot_input_ids = torch.cat(history, dim=-1)

    chat_history_ids = model.generate(
        bot_input_ids,
        pad_token_id=tokenizer.eos_token_id,  # DialoGPT has no pad token, so reuse EOS
        max_length=1000,
        do_sample=True,
        top_p=0.7,
        temperature=0.95,
        no_repeat_ngram_size=3,
    )

    # The reply is everything generated after the prompt tokens.
    response = chat_history_ids[:, bot_input_ids.shape[-1]:]
    history.append(response)

    # Decode the stored turns and pair them as (user, bot) messages for the Chatbot component.
    turns = [tokenizer.decode(ids[0], skip_special_tokens=True) for ids in history]
    pairs = [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)]
    return pairs, history
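# Illustrative only: the function can also be exercised outside Gradio, e.g.
#   pairs, state = generate_answer("Hi there!", [])
#   print(pairs[-1])   # most recent (user, bot) exchange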
gr.Interface(
    fn=generate_answer,
    title="DialoGPT-small-miles",
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
).launch()
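# launch(share=True) would also create a temporary public link for the demo,
# and launch(server_name="0.0.0.0") exposes it on the local network.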