import streamlit as st
import os
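
# Expose the Hugging Face token from Streamlit secrets to huggingface_hub
# before any model is downloaded.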
os.environ["HUGGINGFACEHUB_API_TOKEN"] = st.secrets["HF_TOKEN"]

import pickle
import torch
from transformers import Conversation, pipeline, AutoTokenizer, AutoModelForCausalLM
from upload import get_file, upload_file
from utils import clear_uploader, undo, restart

TOKEN = st.secrets["HF_TOKEN"]
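
# Only these session-state keys are serialized when a chat is shared.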
share_keys = ["messages", "model_name"]

MODELS = [
    "meta-llama/Llama-2-7b-chat-hf",
    "mistralai/Mistral-7B-Instruct-v0.2",
]
default_model = MODELS[0]

st.set_page_config(
    page_title="LLM",
    page_icon="π",
)

if "model_name" not in st.session_state:
    st.session_state.model_name = default_model
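
# st.cache_resource keeps one loaded pipeline per model name, shared across
# reruns and sessions, so weights are not reloaded on every interaction.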
@st.cache_resource
def get_pipeline(model_name):
    device = 0 if torch.cuda.is_available() else -1

    tokenizer = AutoTokenizer.from_pretrained(model_name, token=TOKEN)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=TOKEN)

    chatbot = pipeline("conversational", model=model, tokenizer=tokenizer, device=device)
    return chatbot
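
# Load (or fetch from the cache) the pipeline for the selected model.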
chatbot = get_pipeline(st.session_state.model_name)

if "messages" not in st.session_state:
    st.session_state.messages = []
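
# If the page was opened through a share link, restore the saved session
# state from the downloaded pickle before rendering anything.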
if len(st.session_state.messages) == 0 and "id" in st.query_params:
    with st.spinner("Loading chat..."):
        id = st.query_params["id"]
        data = get_file(id)
        obj = pickle.loads(data)
        for k, v in obj.items():
            st.session_state[k] = v
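
# Serialize the shareable session-state keys, upload them, and surface a
# link that reloads this chat via the ?id= query parameter.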
def share():
    obj = {}
    for k in share_keys:
        if k in st.session_state:
            obj[k] = st.session_state[k]
    data = pickle.dumps(obj)
    id = upload_file(data)
    url = f"https://umbc-nlp-chat-llm.hf.space/?id={id}"
    st.markdown(f"[share](/?id={id})")
    st.success(f"Share URL: {url}")

with st.sidebar:
    st.title(":blue[LLM Only]")

    st.subheader("Model")
    model_name = st.selectbox("Model", MODELS, key="model_name")

    if st.button("Share", use_container_width=True):
        share()

    cols = st.columns(2)
    with cols[0]:
        if st.button("Restart", type="primary", use_container_width=True):
            restart()

    with cols[1]:
        if st.button("Undo", use_container_width=True):
            undo()

    append = st.checkbox("Append to previous message", value=False)
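
# Replay the conversation so far; Streamlit reruns the script on every
# interaction, so the full history must be re-rendered each time.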
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

def push_message(role, content):
    message = {"role": role, "content": content}
    st.session_state.messages.append(message)
    return message
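
# Handle a new user turn. When "Append to previous message" is checked, the
# turn is stored without generating a model reply.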
if prompt := st.chat_input("Type a message", key="chat_input"):
    push_message("user", prompt)
    with st.chat_message("user"):
        st.markdown(prompt)

    if not append:
        with st.chat_message("assistant"):
            chat = Conversation()
            for m in st.session_state.messages:
                chat.add_message(m)
            with st.spinner("Generating response..."):
                response = chatbot(chat)
            response = response[-1]["content"]
            st.write(response)

        push_message("assistant", response)

    clear_uploader()
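
# Assumed local helpers (inferred from usage, not shown here): upload.py
# provides upload_file(data: bytes) -> id and get_file(id) -> bytes, and
# utils.py provides clear_uploader(), undo(), and restart().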