# chat-llm/app.py
import streamlit as st
import os
# Expose the HF token to the transformers library for gated model downloads.
os.environ["HUGGINGFACEHUB_API_TOKEN"] = st.secrets["HF_TOKEN"]
import pickle
import torch
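# NOTE: Conversation and the "conversational" pipeline were deprecated and
# later removed from transformers (v4.42+), so this app needs an older release.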
from transformers import Conversation, pipeline, AutoTokenizer, AutoModelForCausalLM
from upload import get_file, upload_file
from utils import clear_uploader, undo, restart
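
# Token used when downloading gated checkpoints such as Llama 2.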
TOKEN = st.secrets["HF_TOKEN"]
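# Session-state keys that are serialized into a share link.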
share_keys = ["messages", "model_name"]
MODELS = [
    "meta-llama/Llama-2-7b-chat-hf",
    "mistralai/Mistral-7B-Instruct-v0.2",
    # "google/flan-t5-small",
    # "google/flan-t5-base",
    # "google/flan-t5-large",
    # "google/flan-t5-xl",
    # "google/flan-t5-xxl",
]
default_model = MODELS[0]

st.set_page_config(
    page_title="LLM",
    page_icon="📚",
)
if "model_name" not in st.session_state:
st.session_state.model_name = default_model
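
# Cache the pipeline so the model is loaded once per model name,
# not on every Streamlit rerun.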
@st.cache_resource
def get_pipeline(model_name):
    # NOTE: the selected model is overridden here, presumably so the demo
    # fits on the Space's hardware; drop this line to honor the selectbox.
    model_name = "gpt2-medium"
    device = 0 if torch.cuda.is_available() else -1
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=TOKEN)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=TOKEN)
    # Pass the computed device so the model runs on GPU when one is available.
    chatbot = pipeline("conversational", model=model, tokenizer=tokenizer, device=device)
    return chatbot
chatbot = get_pipeline(st.session_state.model_name)
if "messages" not in st.session_state:
st.session_state.messages = []
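
# If the URL carries a share id, restore the pickled chat into session state.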
if len(st.session_state.messages) == 0 and "id" in st.query_params:
    with st.spinner("Loading chat..."):
        id = st.query_params["id"]
        data = get_file(id)
        obj = pickle.loads(data)
        for k, v in obj.items():
            st.session_state[k] = v
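
# Pickle the shareable session keys, upload them, and show the resulting URL.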
def share():
    obj = {}
    for k in share_keys:
        if k in st.session_state:
            obj[k] = st.session_state[k]
    data = pickle.dumps(obj)
    id = upload_file(data)
    url = f"https://umbc-nlp-chat-llm.hf.space/?id={id}"
    st.markdown(f"[share](/?id={id})")
    st.success(f"Share URL: {url}")

with st.sidebar:
    st.title(":blue[LLM Only]")
    st.subheader("Model")
    model_name = st.selectbox("Model", MODELS, key="model_name")
    if st.button("Share", use_container_width=True):
        share()
    cols = st.columns(2)
    with cols[0]:
        if st.button("Restart", type="primary", use_container_width=True):
            restart()
    with cols[1]:
        if st.button("Undo", use_container_width=True):
            undo()
    append = st.checkbox("Append to previous message", value=False)
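
# Replay the stored history so the conversation persists across reruns.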
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

def push_message(role, content):
    message = {"role": role, "content": content}
    st.session_state.messages.append(message)
    return message
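
# Handle a new user turn: echo it, then (unless "append" is checked) build a
# Conversation from the full history and generate the assistant's reply.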
if prompt := st.chat_input("Type a message", key="chat_input"):
    push_message("user", prompt)
    with st.chat_message("user"):
        st.markdown(prompt)
    if not append:
        with st.chat_message("assistant"):
            # Rebuild the Conversation from the stored history on every turn.
            chat = Conversation()
            for m in st.session_state.messages:
                chat.add_message(m)
            with st.spinner("Generating response..."):
                response = chatbot(chat)
            # The pipeline returns the Conversation with the reply appended last.
            response = response[-1]["content"]
            st.write(response)
            push_message("assistant", response)
    clear_uploader()