import os

import streamlit as st

# Hugging Face environment must be configured before transformers is imported.
# os.environ['HF_HOME'] = '/scratch/sroydip1/cache/hf/'
os.environ["HUGGINGFACEHUB_API_TOKEN"] = st.secrets["HF_TOKEN"]

import pickle

import torch
from transformers import Conversation, pipeline, AutoTokenizer, AutoModelForCausalLM

from upload import get_file, upload_file
from utils import clear_uploader, undo, restart

TOKEN = st.secrets["HF_TOKEN"]

# Session-state keys included in a shared chat snapshot.
share_keys = ["messages", "model_name"]

MODELS = [
    "meta-llama/Llama-2-7b-chat-hf",
    "mistralai/Mistral-7B-Instruct-v0.2",
    # "google/flan-t5-small",
    # "google/flan-t5-base",
    # "google/flan-t5-large",
    # "google/flan-t5-xl",
    # "google/flan-t5-xxl",
]
default_model = MODELS[0]

st.set_page_config(
    page_title="LLM",
    page_icon="📚",
)

if "model_name" not in st.session_state:
    st.session_state.model_name = default_model


@st.cache_resource
def get_pipeline(model_name):
    """Load and cache a conversational pipeline for the selected model.

    Note: the "conversational" task is deprecated in recent transformers
    releases; this assumes a version that still ships it.
    """
    device = 0 if torch.cuda.is_available() else -1
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=TOKEN)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=TOKEN)
    chatbot = pipeline("conversational", model=model, tokenizer=tokenizer, device=device)
    return chatbot


chatbot = get_pipeline(st.session_state.model_name)

if "messages" not in st.session_state:
    st.session_state.messages = []

# Restore a shared chat if the URL carries an ?id=... query parameter.
if len(st.session_state.messages) == 0 and "id" in st.query_params:
    with st.spinner("Loading chat..."):
        file_id = st.query_params["id"]
        data = get_file(file_id)
        obj = pickle.loads(data)
        for k, v in obj.items():
            st.session_state[k] = v


def share():
    """Pickle the shareable session state, upload it, and display the share URL."""
    obj = {}
    for k in share_keys:
        if k in st.session_state:
            obj[k] = st.session_state[k]
    data = pickle.dumps(obj)
    file_id = upload_file(data)
    url = f"https://umbc-nlp-chat-llm.hf.space/?id={file_id}"
    st.markdown(f"[share](/?id={file_id})")
    st.success(f"Share URL: {url}")


with st.sidebar:
    st.title(":blue[LLM Only]")
    st.subheader("Model")
    model_name = st.selectbox("Model", MODELS, key="model_name")
    if st.button("Share", use_container_width=True):
        share()
    cols = st.columns(2)
    with cols[0]:
        if st.button("Restart", type="primary", use_container_width=True):
            restart()
    with cols[1]:
        if st.button("Undo", use_container_width=True):
            undo()
    append = st.checkbox("Append to previous message", value=False)

# Replay the conversation so far.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])


def push_message(role, content):
    message = {"role": role, "content": content}
    st.session_state.messages.append(message)
    return message


if prompt := st.chat_input("Type a message", key="chat_input"):
    push_message("user", prompt)
    with st.chat_message("user"):
        st.markdown(prompt)

    # When "append" is checked, the user message is stored without
    # triggering a model response.
    if not append:
        with st.chat_message("assistant"):
            chat = Conversation()
            for m in st.session_state.messages:
                chat.add_message(m)
            with st.spinner("Generating response..."):
                response = chatbot(chat)
            # The pipeline returns the Conversation; its last message is
            # the model's reply.
            response = response[-1]["content"]
            st.write(response)
            push_message("assistant", response)

    clear_uploader()
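
# Usage sketch (the filename app.py is illustrative, not from the source):
# put HF_TOKEN in .streamlit/secrets.toml, make sure the local upload and
# utils modules are on the path, then launch with:
#   streamlit run app.py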