# app.py — Streamlit chat UI for the microsoft/DialoGPT-medium chatbot.
# (Removed non-Python page residue from a file-view header: author avatar
# line, commit message "Update app.py", commit hash f433de4.)
import streamlit as st
from streamlit_chat import message as st_message
import numpy as np
from transformers import AutoModelForCausalLM , AutoTokenizer
import torch
# Load the tokenizer and model once per Streamlit server process.
# Streamlit re-executes this whole script on every user interaction, so
# without caching the ~1.5 GB model would be re-instantiated on each rerun.
@st.cache_resource
def _load_dialogpt():
    """Return the (tokenizer, model) pair for microsoft/DialoGPT-medium."""
    tok = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
    mdl = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
    return tok, mdl

tokenizer, model = _load_dialogpt()
def run(user_text, chat_history_ids):
    """Generate one chatbot reply with DialoGPT.

    Args:
        user_text: The user's latest message.
        chat_history_ids: Token-id tensor of the conversation so far,
            or None at the start of a conversation.

    Returns:
        A (response_text, updated_history_ids) tuple; the updated history
        should be passed back in on the next turn.
    """
    max_length = 1000  # total generation budget (prompt + reply tokens)

    # DialoGPT expects each turn terminated by the EOS token.
    input_ids = tokenizer.encode(user_text + tokenizer.eos_token,
                                 return_tensors="pt")
    if chat_history_ids is None:
        bot_input_ids = input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1)

    # Bug fix: once the accumulated history reaches max_length, generate()
    # has no room left to produce a reply (it returns the prompt unchanged,
    # yielding an empty response). Keep only the most recent tokens so
    # there is always headroom for the answer.
    headroom = 200
    if bot_input_ids.shape[-1] > max_length - headroom:
        bot_input_ids = bot_input_ids[:, -(max_length - headroom):]

    chat_history_ids = model.generate(bot_input_ids, max_length=max_length,
                                      pad_token_id=tokenizer.eos_token_id)

    # The reply is whatever generate() appended after the prompt tokens.
    resp = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0],
                            skip_special_tokens=True)
    return resp, chat_history_ids
# --- Streamlit UI: session state, input handling, chat transcript ---

# Token-id history tensor fed back into the model on each turn.
if "chat_history_ids" not in st.session_state:
    st.session_state["chat_history_ids"] = None
# Display transcript: list of {"message": str, "is_user": bool} dicts.
if "book" not in st.session_state:
    st.session_state["book"] = []
# Last input we already answered, to avoid double-processing on reruns.
if "last_input" not in st.session_state:
    st.session_state["last_input"] = None

txt = st.text_input("Type Here")

# Bug fix: Streamlit reruns this script on every interaction and
# st.text_input keeps its value between reruns, so without this guard the
# same message would be re-sent to the model and duplicated in the
# transcript on every subsequent rerun.
if txt and txt != st.session_state["last_input"]:
    resp, hist = run(txt, st.session_state["chat_history_ids"])
    st.session_state["chat_history_ids"] = hist
    st.session_state["book"].append({"message": txt, "is_user": True})
    st.session_state["book"].append({"message": resp, "is_user": False})
    st.session_state["last_input"] = txt

# Render the whole transcript; each chat-message widget needs a unique key.
for i, chat in enumerate(st.session_state["book"]):
    st_message(**chat, key=str(i))