scottbotai / app.py
tomkr000's picture
Update app.py
4fb8c33
raw history blame
No virus
1.09 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoModelWithLMHead, AutoTokenizer
# Tokenizer comes from the base DialoGPT-small checkpoint, the model weights
# from the fine-tuned Space checkpoint.
# NOTE(review): assumes tomkr000/scottbotai was fine-tuned from
# microsoft/DialoGPT-small and shares its vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-small', padding_side='right')
# AutoModelWithLMHead is deprecated in transformers; AutoModelForCausalLM is the
# recommended auto class for causal (decoder-only) language models like this one.
model = AutoModelForCausalLM.from_pretrained('tomkr000/scottbotai')
def chat(message, history):
    """Generate a bot reply to ``message`` and record the exchange in ``history``.

    The model context is rebuilt on every call from the prior turns stored in
    ``history``. Each past user message and bot reply is encoded with the
    tokenizer's EOS token appended, and the new user message is added last.

    Args:
        message: The user's latest message (the "text" input).
        history: List of ``(user_message, bot_response)`` pairs from earlier
            turns (the "state" input). ``None`` or empty starts a new
            conversation.

    Returns:
        ``(history, history)``: the updated list, once for the Chatbot display
        and once to be written back into the session state.
    """
    # Total token budget per generate() call (same as before), and the most
    # prompt context we keep so at least 50 tokens remain for the reply.
    max_length = 200
    max_context_tokens = 150

    # NOTE(review): the "state" input presumably starts as None on the first
    # call — confirm for the pinned Gradio version. Either way, start empty.
    if history is None:
        history = []

    eos = tokenizer.eos_token
    # The previous version referenced undefined `step` / `chat_history_ids`
    # and raised NameError. Re-encode every stored turn so the model sees
    # the conversation so far.
    turn_ids = [
        tokenizer.encode(text + eos, return_tensors='pt')
        for pair in history
        for text in pair
    ]
    turn_ids.append(tokenizer.encode(message + eos, return_tensors='pt'))
    # Keep only the most recent tokens so a long conversation cannot consume
    # the whole max_length budget and leave no room for a reply.
    bot_input_ids = torch.cat(turn_ids, dim=-1)[:, -max_context_tokens:]

    output_ids = model.generate(
        bot_input_ids,
        # pad_token_id equals the EOS id and the prompt contains EOS separators,
        # so pass an explicit all-ones mask instead of letting generate() infer one.
        attention_mask=torch.ones_like(bot_input_ids),
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        do_sample=True,
        top_k=100,
        top_p=0.7,
        temperature=0.8,
    )
    # The output starts with the prompt tokens; decode only the continuation.
    response = tokenizer.decode(
        output_ids[0, bot_input_ids.shape[-1]:], skip_special_tokens=True
    )
    history.append((message, response))
    return history, history
# Conversation display widget; the two colours presumably map to the user
# and bot sides respectively.
# NOTE(review): `.style(color_map=...)` is a Gradio 3.x-era API that later
# Gradio releases dropped — confirm against the Gradio version this Space pins.
chatbot = gr.Chatbot().style(color_map=("green", "pink"))

# Wire `chat` into an Interface: the user's text plus a session-state value
# go in; the rendered conversation plus the updated state come back out.
demo = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=[chatbot, "state"],
    allow_flagging="never",
)

demo.launch()