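"""Streamlit chat app that serves a Hugging Face-hosted LLM through LangChain."""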
import os

os.environ["HF_HOME"] = "/scratch/sroydip1/cache/hf/"
os.environ["HUGGINGFACEHUB_API_TOKEN"] = ""
# import torch
import pickle
import streamlit as st
from upload import get_file, upload_file
from utils import clear_uploader, undo, restart

from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceHub


# Session-state keys copied into a shared conversation's payload.
share_keys = ["messages", "model_name"]
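# Models selectable from the sidebar; all are hosted on the Hugging Face Hub.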
MODELS = [
    "mistralai/Mistral-7B-Instruct-v0.2",
    "google/flan-t5-small",
    "google/flan-t5-base",
    "google/flan-t5-large",
    "google/flan-t5-xl",
    "google/flan-t5-xxl",
]
default_model = "mistralai/Mistral-7B-Instruct-v0.2"
# default_model = "meta-llama/Llama-2-7b-chat-hf"

st.set_page_config(
    page_title="LLM",
    page_icon="πŸ“š",
)

if "model_name" not in st.session_state:
    st.session_state.model_name = default_model

template = """You are a friendly chatbot engaging in a conversation with a human.

Previous conversation:
{chat_history}

New human question: {question}
Response:"""


def get_pipeline(model_name):
    """Return a HuggingFaceHub LLM handle for the given model repo id."""
    llm = HuggingFaceHub(
        repo_id=model_name,
        task="text-generation",
        model_kwargs={
            "max_new_tokens": 512,
            "top_k": 30,
            "temperature": 0.1,
            "repetition_penalty": 1.03,
        },
    )
    return llm


# Build the chain once and cache it so its memory survives Streamlit reruns.
if "conversation" not in st.session_state:
    st.session_state.conversation = LLMChain(
        llm=get_pipeline(st.session_state.model_name),
        prompt=PromptTemplate.from_template(template),
        verbose=True,
        memory=ConversationBufferMemory(memory_key="chat_history"),
    )
conversation = st.session_state.conversation


if "messages" not in st.session_state:
    st.session_state.messages = []

# Restore a shared conversation when the page is opened with ?id=...
if len(st.session_state.messages) == 0 and "id" in st.query_params:
    with st.spinner("Loading chat..."):
        file_id = st.query_params["id"]
        data = get_file(file_id)
        obj = pickle.loads(data)  # trusted payloads only: pickle can execute code
        for k, v in obj.items():
            st.session_state[k] = v
        st.session_state.pop("conversation", None)  # rebuild the chain for the restored model


def share():
    """Pickle the shareable session state, upload it, and show the share URL."""
    obj = {k: st.session_state[k] for k in share_keys if k in st.session_state}
    data = pickle.dumps(obj)
    file_id = upload_file(data)
    url = f"https://umbc-nlp-chat-llm.hf.space/?id={file_id}"
    st.markdown(f"[share](/?id={file_id})")
    st.success(f"Share URL: {url}")


with st.sidebar:
    st.title(":blue[LLM Only]")

    st.subheader("Model")
    model_name = st.selectbox(
        "Model", MODELS, index=MODELS.index(st.session_state.model_name)
    )
    # Apply a model change: store the choice and rebuild the chain with it.
    if model_name != st.session_state.model_name:
        st.session_state.model_name = model_name
        st.session_state.pop("conversation", None)
        st.rerun()

    if st.button("Share", use_container_width=True):
        share()

    cols = st.columns(2)
    with cols[0]:
        if st.button("Restart", type="primary", use_container_width=True):
            restart()

    with cols[1]:
        if st.button("Undo", use_container_width=True):
            undo()

    append = st.checkbox("Append to previous message", value=False)
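    # When checked, the next user message is recorded without generating a reply.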


# Replay the stored transcript; Streamlit reruns this script on every interaction.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])


def push_message(role, content):
    """Append a message to the transcript kept in session state."""
    message = {"role": role, "content": content}
    st.session_state.messages.append(message)
    return message


if prompt := st.chat_input("Type a message", key="chat_input"):
    push_message("user", prompt)
    with st.chat_message("user"):
        st.markdown(prompt)

    # Skip generation when "Append to previous message" is checked.
    if not append:
        with st.chat_message("assistant"):
            with st.spinner("Generating response..."):
                response = conversation({"question": prompt})["text"]
                st.write(response)

        push_message("assistant", response)
    clear_uploader()