Spaces:
Runtime error
Runtime error
import streamlit as st | |
from streamlit_chat import message as st_message | |
import numpy as np | |
from transformers import AutoModelForCausalLM , AutoTokenizer | |
import torch | |
@st.cache_resource  # NOTE(review): needs Streamlit >= 1.18 -- confirm the Space's pinned version
def _load_dialogpt(model_name="microsoft/DialoGPT-medium"):
    """Load and return ``(tokenizer, model)`` for *model_name*.

    Streamlit re-executes this whole script on every user interaction.
    Caching the loader means the tokenizer and weights are loaded once per
    server process instead of on every rerun.
    """
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForCausalLM.from_pretrained(model_name),
    )


tokenizer, model = _load_dialogpt()
def run(user_text, chat_history_ids, max_length=1000, max_reply_tokens=200):
    """Generate the model's reply to *user_text*, continuing a conversation.

    Args:
        user_text: The user's new message.
        chat_history_ids: Token-id tensor returned by the previous call, or
            ``None`` to start a new conversation.
        max_length: Total length (prompt + reply) passed to
            ``model.generate``; 1000 matches the original hard-coded value.
        max_reply_tokens: Positions reserved for the reply. The oldest history
            tokens are dropped so the prompt never exceeds
            ``max_length - max_reply_tokens`` tokens.

    Returns:
        ``(resp, chat_history_ids)``: the decoded reply text, and the updated
        token-id history (prompt followed by the generated tokens) to pass back
        in on the next turn.

    Raises:
        ValueError: If ``max_reply_tokens`` is not in ``[1, max_length)``.
    """
    if not 0 < max_reply_tokens < max_length:
        raise ValueError(
            f"max_reply_tokens must be in [1, {max_length}), got {max_reply_tokens}"
        )

    # Each user turn ends with EOS; presumably this is the turn separator the model expects.
    input_ids = tokenizer.encode(user_text + tokenizer.eos_token, return_tensors="pt")
    if chat_history_ids is None:
        bot_input_ids = input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1)

    # The history grows every turn and max_length bounds prompt + reply
    # together. Without a cap, the prompt would eventually fill max_length and
    # leave no room for a reply. Keep only the newest tokens.
    max_prompt_tokens = max_length - max_reply_tokens
    bot_input_ids = bot_input_ids[:, -max_prompt_tokens:]

    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens (everything after the prompt).
    resp = tokenizer.decode(
        chat_history_ids[0, bot_input_ids.shape[-1]:],
        skip_special_tokens=True,
    )
    return resp, chat_history_ids
def _handle_user_text():
    """Answer the text just submitted in the input box, then clear the box.

    This is the ``on_change`` callback of the text input. It fires only when
    the submitted value changes, not on every script rerun, so a message is
    sent to the model and appended to the transcript exactly once.
    """
    user_text = st.session_state["user_text"]
    if not user_text:
        return
    resp, hist = run(user_text, st.session_state["chat_history_ids"])
    st.session_state["chat_history_ids"] = hist
    st.session_state["book"].append({"message": user_text, "is_user": True})
    st.session_state["book"].append({"message": resp, "is_user": False})
    # Clearing the box makes the next message a new change, even if it repeats
    # the previous text, so this callback fires again for it.
    st.session_state["user_text"] = ""


if "chat_history_ids" not in st.session_state:
    st.session_state["chat_history_ids"] = None
if "book" not in st.session_state:
    st.session_state["book"] = []

st.text_input("Type Here", key="user_text", on_change=_handle_user_text)

# Render the full transcript; each message widget needs a unique key.
for i, chat in enumerate(st.session_state["book"]):
    st_message(**chat, key=str(i))