import os

os.environ["HF_HOME"] = "/scratch/sroydip1/cache/hf/"
os.environ["HUGGINGFACEHUB_API_TOKEN"] = ""

import pickle
import streamlit as st
from upload import get_file, upload_file
from utils import clear_uploader, undo, restart

from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceHub


share_keys = ["messages", "model_name"]
MODELS = [
    "mistralai/Mistral-7B-Instruct-v0.2",
    "google/flan-t5-small",
    "google/flan-t5-base",
    "google/flan-t5-large",
    "google/flan-t5-xl",
    "google/flan-t5-xxl",
]
default_model = "mistralai/Mistral-7B-Instruct-v0.2"
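# NOTE: the flan-t5 checkpoints are text2text-generation models, so the
# hosted Inference API may reject them when they are called with the
# "text-generation" task configured in get_pipeline below.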


st.set_page_config(
    page_title="LLM",
    page_icon="π",
)

if "model_name" not in st.session_state:
    st.session_state.model_name = default_model

template = """You are a friendly chatbot engaging in a conversation with a human.

Previous conversation:
{chat_history}

New human question: {question}
Response:"""
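# {chat_history} is filled in by ConversationBufferMemory at call time;
# {question} comes from the chain input.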


def get_pipeline(model_name):
    # Low temperature with a mild repetition penalty keeps replies focused;
    # generation is capped at 512 new tokens.
    llm = HuggingFaceHub(
        repo_id=model_name,
        task="text-generation",
        model_kwargs={
            "max_new_tokens": 512,
            "top_k": 30,
            "temperature": 0.1,
            "repetition_penalty": 1.03,
        },
    )
    return llm
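

# Usage sketch (not exercised in this app): the returned object is an
# ordinary LangChain LLM, so it can also be queried outside the chain,
# e.g. get_pipeline(default_model).invoke("Hi!").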


chatbot = get_pipeline(st.session_state.model_name)
memory = ConversationBufferMemory(memory_key="chat_history")
prompt_template = PromptTemplate.from_template(template)
conversation = LLMChain(
    llm=chatbot, prompt=prompt_template, verbose=True, memory=memory
)
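# Above, the chain ties the pieces together: memory renders earlier turns
# into the {chat_history} slot, the template formats the full prompt, and
# the Hub LLM generates the reply.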


if "messages" not in st.session_state:
    st.session_state.messages = []

if len(st.session_state.messages) == 0 and "id" in st.query_params:
    with st.spinner("Loading chat..."):
        chat_id = st.query_params["id"]
        data = get_file(chat_id)
        obj = pickle.loads(data)
        for k, v in obj.items():
            st.session_state[k] = v
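
# Streamlit reruns this script on every interaction, so the memory object
# above is rebuilt empty each time while the transcript survives in
# st.session_state. A minimal sketch to keep {chat_history} in sync
# (assuming every stored message has role "user" or "assistant"):
for _msg in st.session_state.messages:
    if _msg["role"] == "user":
        memory.chat_memory.add_user_message(_msg["content"])
    else:
        memory.chat_memory.add_ai_message(_msg["content"])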


def share():
    obj = {}
    for k in share_keys:
        if k in st.session_state:
            obj[k] = st.session_state[k]
    data = pickle.dumps(obj)
    file_id = upload_file(data)
    url = f"https://umbc-nlp-chat-llm.hf.space/?id={file_id}"
    st.markdown(f"[share](/?id={file_id})")
    st.success(f"Share URL: {url}")
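

# share() is the inverse of the loader above: the selected session-state keys
# are pickled and uploaded, and get_file restores them when the app is opened
# with ?id=<file id>.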


with st.sidebar:
    st.title(":blue[LLM Only]")

    st.subheader("Model")
    model_name = st.selectbox(
        "Model", MODELS, index=MODELS.index(st.session_state.model_name)
    )
    # Store the selection back into session state so get_pipeline uses it
    # on the next rerun.
    st.session_state.model_name = model_name

    if st.button("Share", use_container_width=True):
        share()

    cols = st.columns(2)
    with cols[0]:
        if st.button("Restart", type="primary", use_container_width=True):
            restart()

    with cols[1]:
        if st.button("Undo", use_container_width=True):
            undo()

    append = st.checkbox("Append to previous message", value=False)


# Replay the stored transcript so the chat stays visible across reruns.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])


def push_message(role, content):
    message = {"role": role, "content": content}
    st.session_state.messages.append(message)
    return message


if prompt := st.chat_input("Type a message", key="chat_input"):
    push_message("user", prompt)
    with st.chat_message("user"):
        st.markdown(prompt)

    # When "Append to previous message" is checked, the user turn is stored
    # without triggering a model reply.
    if not append:
        with st.chat_message("assistant"):
            with st.spinner("Generating response..."):
                response = conversation({"question": prompt})
                response = response["text"]
                st.write(response)

        push_message("assistant", response)
        clear_uploader()