import os
os.environ["HF_HOME"] = "/scratch/sroydip1/cache/hf/"
os.environ["HUGGINGFACEHUB_API_TOKEN"] = ""
import pickle

import streamlit as st
from upload import get_file, upload_file
from utils import clear_uploader, undo, restart
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceHub
share_keys = ["messages", "model_name"]
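# Session-state keys that get pickled when a chat is shared via the Share button.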
MODELS = [
"mistralai/Mistral-7B-Instruct-v0.2",
"google/flan-t5-small",
"google/flan-t5-base",
"google/flan-t5-large",
"google/flan-t5-xl",
"google/flan-t5-xxl",
]
default_model = "mistralai/Mistral-7B-Instruct-v0.2"
# default_model = "meta-llama/Llama-2-7b-chat-hf"
st.set_page_config(
    page_title="LLM",
    page_icon="π",
)
if "model_name" not in st.session_state:
st.session_state.model_name = default_model
template = """You are a friendly chatbot engaging in a conversation with a human.
Previous conversation:
{chat_history}
New human question: {question}
Response:"""

def get_pipeline(model_name):
    # Inference runs remotely on the Hugging Face Hub; no model is loaded locally.
    llm = HuggingFaceHub(
        repo_id=model_name,
        task="text-generation",
        model_kwargs={
            "max_new_tokens": 512,
            "top_k": 30,
            "temperature": 0.1,
            "repetition_penalty": 1.03,
        },
    )
    return llm
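
# Chain wiring: every call renders the template with the buffered history and
# the new question, sends it to the Hub model, and stores the turn in memory.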
chatbot = get_pipeline(st.session_state.model_name)
memory = ConversationBufferMemory(memory_key="chat_history")
prompt_template = PromptTemplate.from_template(template)
conversation = LLMChain(llm=chatbot, prompt=prompt_template, verbose=True, memory=memory)
if "messages" not in st.session_state:
st.session_state.messages = []
if len(st.session_state.messages) == 0 and "id" in st.query_params:
    with st.spinner("Loading chat..."):
        # Restore a shared session; assumes the pickled payload comes from a trusted upload.
        id = st.query_params["id"]
        data = get_file(id)
        obj = pickle.loads(data)
        for k, v in obj.items():
            st.session_state[k] = v

def share():
    obj = {}
    for k in share_keys:
        if k in st.session_state:
            obj[k] = st.session_state[k]
    data = pickle.dumps(obj)
    id = upload_file(data)
    url = f"https://umbc-nlp-chat-llm.hf.space/?id={id}"
    st.markdown(f"[share](/?id={id})")
    st.success(f"Share URL: {url}")
with st.sidebar:
    st.title(":blue[LLM Only]")
    st.subheader("Model")
    model_name = st.selectbox(
        "Model", MODELS, index=MODELS.index(st.session_state.model_name)
    )
    # Persist the selection so the pipeline picks it up on the next rerun.
    st.session_state.model_name = model_name
    if st.button("Share", use_container_width=True):
        share()
    cols = st.columns(2)
    with cols[0]:
        if st.button("Restart", type="primary", use_container_width=True):
            restart()
    with cols[1]:
        if st.button("Undo", use_container_width=True):
            undo()
    append = st.checkbox("Append to previous message", value=False)
# Replay the stored conversation so the history survives Streamlit reruns.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

def push_message(role, content):
    message = {"role": role, "content": content}
    st.session_state.messages.append(message)
    return message
if prompt := st.chat_input("Type a message", key="chat_input"):
    push_message("user", prompt)
    with st.chat_message("user"):
        st.markdown(prompt)
    # With "append" checked, the message is stored without triggering a model response.
    if not append:
        with st.chat_message("assistant"):
            print(conversation)  # debug: log the chain configuration
            with st.spinner("Generating response..."):
                response = conversation({"question": prompt})
            print(response)  # debug: log the raw chain output
            response = response["text"]
            st.write(response)
            push_message("assistant", response)
    clear_uploader()