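"""Streamlit chat UI for the Samba-CoE_v0.1 model.

The app keeps generation settings and chat history in st.session_state,
builds an [INST]-tagged prompt from the most recent turns, and posts it to
a backend endpoint configured via st.secrets["backend_url"].
"""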
import time

import requests
import streamlit as st

st.set_page_config(page_title="Samba-CoE_v0.1 Chatbot")
# Initialize default generation parameters once per session. Guard on a key
# this block actually sets, so that values submitted through the sidebar form
# are not overwritten on the next rerun.
if 'temperature' not in st.session_state:
    st.session_state['temperature'] = 0.7
    st.session_state['top_p'] = 0.1
    st.session_state['max_length'] = 512
    st.session_state['history_buffer_length'] = 3  # chat turns sent as context
    st.session_state['top_k'] = 40  # fixed; not exposed in the sidebar form
with st.sidebar:
    st.info("Samba-CoE_v0.1 Chatbot")
    # Inside a form context, widgets must use the bare st.* API; st.sidebar.*
    # calls would render outside the form and bypass batched submission.
    with st.form("Chatbot Settings"):
        temperature = st.slider('temperature', min_value=0.01, max_value=1.0, value=0.7, step=0.01)
        top_p = st.slider('top_p', min_value=0.01, max_value=1.0, value=0.1, step=0.01)
        max_length = st.slider('max_length', min_value=32, max_value=1024, value=512, step=8)
        submitted = st.form_submit_button("Submit")
    if submitted:
        st.session_state['max_length'] = max_length
        st.session_state['temperature'] = temperature
        st.session_state['top_p'] = top_p
        st.success(
            "Generation parameters:\n"
            f"\n * temperature: {st.session_state.temperature}\n"
            f"\n * top_p: {st.session_state.top_p}\n"
            f"\n * max_tokens: {st.session_state.max_length}\n"
            "\n * skip_special_tokens: True\n"
            "\n * do_sample: True\n"
        )
def clear_chat_history():
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]

st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
# Store LLM-generated responses across reruns
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
# Replay the stored chat history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])
def generate_response(prompt_input):
    # prompt_input is already the last entry in st.session_state.messages
    # (appended before this call), so the history slice below includes it.
    # Keep only the most recent turns to bound the prompt length.
    history_buffer_length = st.session_state.history_buffer_length
    message_history = st.session_state.messages[-history_buffer_length:]
    string_dialogue = "<s> [INST] You are a helpful assistant developed by SambaNova Systems as part of its Composition of Experts (CoE) effort. Always assist with care, respect, and truth. Respond with utmost utility yet securely and professionally. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness, positivity, and an engaging conversation. [/INST] \n"
    for dict_message in message_history:
        if dict_message["role"] == "user":
            string_dialogue += '[INST]' + dict_message["content"] + '[/INST]' + '\n'
        else:
            string_dialogue += dict_message["content"] + "</s>" + '\n'
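    # Example of the assembled prompt after one user turn (illustrative):
    #   <s> [INST] <system instructions> [/INST]
    #   [INST]Hello there[/INST]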
    payload = {
        'prompt': string_dialogue,
        'max_tokens': st.session_state['max_length'],
        'n': 1,
        'do_sample': True,
        'temperature': st.session_state['temperature'],
        'top_p': st.session_state['top_p'],
        'top_k': st.session_state['top_k'],
        'skip_special_tokens': True,
        'repetition_penalty': 1.15,
        # Cut generation off before the model starts a new instruction turn
        'stop_sequences': ['INST', '[INST]', '[/INST]'],
    }
    r = requests.post(st.secrets["backend_url"], json=payload)
    r.raise_for_status()
    response = r.json()['choices'][0]['text']
    return response
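# Illustrative backend exchange; the response schema is inferred from the
# 'choices' access above and may differ on the actual endpoint:
#   request:  {"prompt": "<s> [INST] ... [/INST]\n...", "max_tokens": 512, ...}
#   response: {"choices": [{"text": "Hello! How can I help you today?"}]}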
# Append the user-provided prompt to the history and echo it
if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)
def clean_response(response):
    # Drop anything the model generated past a stray instruction tag
    if '[INST' in response:
        return response.split('[INST')[0]
    elif '[/INST' in response:
        return response.split('[/INST')[0]
    return response
def stream_response(response):
    # Maintenance mode: ignore the model output and stream a static notice
    response = "We are undergoing a speed upgrade and will be back in a moment! Stay tuned!"
    for word in response.split(" "):
        yield word + " "
        time.sleep(0.02)
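# st.write_stream consumes the generator, renders each chunk as it arrives,
# and returns the full concatenated string once the stream is exhausted.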
# Generate a new response if the last message is not from the assistant
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # Backend generation is disabled while the maintenance notice is up:
            # response = generate_response(prompt)
            # response = clean_response(response)
            # Capture the streamed text so the history stores the notice, not None
            response = st.write_stream(stream_response(None))
    message = {"role": "assistant", "content": response}
    st.session_state.messages.append(message)