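# app.py: Ikigai Chat, a Streamlit chat app that answers questions about
# Ikigai Docs via Retrieval Augmented Generation (RAG). Documents are embedded
# with sentence-transformers, retrieved from a hosted Pinecone index, and the
# augmented prompt is answered by a Mistral/Mixtral instruct model.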
import os
import time

import pandas as pd
import pinecone
import streamlit as st
from chat_client import chat
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer

load_dotenv()

PINECONE_TOKEN = os.getenv('PINECONE_TOKEN')

pinecone.init(
    api_key=PINECONE_TOKEN,
    environment='gcp-starter'
)

PINECONE_INDEX = pinecone.Index('ikigai-chat')
TEXT_VECTORIZER = SentenceTransformer('all-distilroberta-v1')
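# all-distilroberta-v1 produces 768-dimensional embeddings, so the
# 'ikigai-chat' index must have been created with dimension 768.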
CHAT_BOTS = {
    "Mixtral 8x7B v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Mistral 7B v0.1": "mistralai/Mistral-7B-Instruct-v0.1",
}

COST_PER_1000_TOKENS_INR = 0.139
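# Flat pricing assumption used by the sidebar analytics: 0.139 INR per 1000
# tokens, applied to an estimated (not tokenizer-exact) token count.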
st.set_page_config(
    page_title="Ikigai Chat",
    page_icon="🤖",
)
SYSTEM_PROMPT = [
    """
    You are not Mistral AI, but rather a chatbot trained at Ikigai Labs. Whenever asked, you need to answer as Ikigai Labs' assistant.
    Ikigai helps modern analysts and operations teams automate data-intensive business, finance, analytics, and supply-chain operations.
    The company's Inventory Ops automates inventory tracking and monitoring by creating a single, real-time view of inventory across all locations and channels.
    """,
    """
    Yes, you are correct. Ikigai Labs is a company that specializes in helping
    modern analysts and operations teams automate data-intensive business, finance, analytics,
    and supply-chain operations. One of their products is Inventory Ops, which automates inventory
    tracking and monitoring by creating a single, real-time view of inventory across all locations and channels.
    This helps businesses optimize their inventory levels and reduce costs.
    Is there anything else you would like to know about Ikigai Labs or their products?
    """
]
IDENTITY_CHANGE = [
    """
    You are Ikigai Chat from now on, so answer accordingly.
    """,
    """
    Sure, I will do my best to answer your questions as Ikigai Chat.
    Let me know if you have any specific questions about Ikigai Labs or our products.
    """
]
def gen_augmented_prompt(prompt, top_k):
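    """Embed the user prompt, retrieve the top_k most similar chunks from
    Pinecone, and splice them into an augmented prompt.

    Returns the augmented prompt string and the list of source links found
    in the retrieved chunks' metadata.
    """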
    query_vector = TEXT_VECTORIZER.encode(prompt).tolist()
    res = PINECONE_INDEX.query(vector=query_vector, top_k=top_k, include_metadata=True)
    matches = res['matches']

    context = ""
    links = []
    for match in matches:
        context += match["metadata"]["chunk"] + "\n\n"
        links.append(match["metadata"]["link"])

    generated_prompt = f"""
    FOR THIS GIVEN CONTEXT {context},
    ----
    ANSWER THE FOLLOWING PROMPT {prompt}
    """
    return generated_prompt, links
def init_state():
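    """Seed st.session_state with defaults on first run. Streamlit re-executes
    the whole script on every interaction, so each key is guarded."""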
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
if "tokens_used" not in st.session_state: | |
st.session_state.tokens_used = 0 | |
if "tps" not in st.session_state: | |
st.session_state.tps = 0 | |
if "temp" not in st.session_state: | |
st.session_state.temp = 0.8 | |
if "history" not in st.session_state: | |
st.session_state.history = [SYSTEM_PROMPT] | |
if "top_k" not in st.session_state: | |
st.session_state.top_k = 5 | |
if "repetion_penalty" not in st.session_state : | |
st.session_state.repetion_penalty = 1 | |
if "rag_enabled" not in st.session_state : | |
st.session_state.rag_enabled = True | |
if "chat_bot" not in st.session_state : | |
st.session_state.chat_bot = "Mixtral 8x7B v0.1" | |
def sidebar():
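    """Render the sidebar: retrieval settings, running model analytics, and
    generation settings. Widget values are mirrored into session state."""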
    def retrieval_settings():
        st.markdown("# Retrieval Settings")
        st.session_state.rag_enabled = st.toggle("Activate RAG", value=True)
        st.session_state.top_k = st.slider(label="Documents to retrieve",
                                           min_value=1, max_value=20, value=4,
                                           disabled=not st.session_state.rag_enabled)
        st.markdown("---")

    def model_analytics():
        st.markdown("# Model Analytics")
        st.write("Total tokens used :", st.session_state['tokens_used'])
        st.write("Speed :", st.session_state['tps'], " tokens/sec")
        st.write("Total cost incurred :", round(
            COST_PER_1000_TOKENS_INR * st.session_state['tokens_used'] / 1000, 3), "INR")
        st.markdown("---")

    def model_settings():
        st.markdown("# Model Settings")
        st.session_state.chat_bot = st.radio(
            'Select one:', list(CHAT_BOTS.keys()))
        st.session_state.temp = st.slider(
            label="Temperature", min_value=0.0, max_value=1.0, step=0.1, value=0.9)
        st.session_state.max_tokens = st.slider(
            label="New tokens to generate", min_value=64, max_value=2048, step=32, value=512
        )
        # Hugging Face text-generation treats 1.0 as "no penalty" and values
        # above 1.0 as discouraging repetition, so the slider starts at 1.0.
        st.session_state.repetition_penalty = st.slider(
            label="Repetition Penalty", min_value=1.0, max_value=2.0, step=0.1, value=1.0
        )

    with st.sidebar:
        retrieval_settings()
        model_analytics()
        model_settings()
        st.markdown("""
        > **2023 ©️ [Pragnesh Barik](https://barik.super.site) 🚀**
        """)
def header():
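    """Render the page header: logo, title, an explainer expander, and a
    small table describing the deployment's tech stack."""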
    data = {
        "Attribute": ["LLM", "Text Vectorizer", "Vector Database", "CPU", "System RAM"],
        "Information": ["Mixtral-8x7B-Instruct-v0.1", "all-distilroberta-v1", "Hosted Pinecone", "2 vCPU", "16 GB"]
    }
    df = pd.DataFrame(data)
    st.image("ikigai.svg")
    st.title("Ikigai Chat")
    with st.expander("What is Ikigai Chat ?"):
        st.info("""Ikigai Chat is a vector-database-powered chat agent that works on the principle
        of Retrieval Augmented Generation (RAG). Its primary function revolves around maintaining an
        extensive repository of Ikigai Docs and providing users with answers that align with their queries.
        This approach ensures a more refined and tailored response to user inquiries.""")
        st.table(df)
def chat_box():
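    """Redraw the conversation so far from st.session_state.messages."""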
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
def feedback_buttons():
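    """Render Satisfied/Dissatisfied buttons. Note: this helper is defined but
    never called in the current flow, and the click handler needs `nonlocal`
    to actually mutate the enclosing visibility flag."""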
    is_visible = True

    def click_handler():
        nonlocal is_visible  # without this, the assignment creates a new local
        is_visible = False

    if is_visible:
        col1, col2 = st.columns(2)
        with col1:
            st.button("👍 Satisfied", on_click=click_handler, type="primary")
        with col2:
            st.button("👎 Dissatisfied", on_click=click_handler, type="secondary")
def generate_chat_stream(prompt):
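    """Optionally augment the prompt with retrieved context, then stream a
    completion from the selected chat model. Returns the token stream and
    the source links (empty when RAG is disabled)."""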
    links = []
    if st.session_state.rag_enabled:
        with st.spinner("Fetching relevant documents from Ikigai Docs..."):
            prompt, links = gen_augmented_prompt(prompt=prompt, top_k=st.session_state.top_k)

    with st.spinner("Generating response..."):
        chat_stream = chat(prompt, st.session_state.history,
                           chat_client=CHAT_BOTS[st.session_state.chat_bot],
                           temperature=st.session_state.temp,
                           max_new_tokens=st.session_state.max_tokens)

    return chat_stream, links
def stream_handler(chat_stream, placeholder):
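    """Stream tokens into the placeholder with a typing cursor, then report
    speed, an estimated token count, and estimated cost. Reads the
    module-level `prompt` set by st.chat_input at the bottom of the file."""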
    start_time = time.time()
    full_response = ''

    for chunk in chat_stream:
        if chunk.token.text != '</s>':
            full_response += chunk.token.text
            placeholder.markdown(full_response + "▌")
    placeholder.markdown(full_response)

    end_time = time.time()
    elapsed_time = end_time - start_time
    total_tokens_processed = len(full_response.split())
    tokens_per_second = round(total_tokens_processed / elapsed_time, 1)

    # Estimate tokens from whitespace-split words; 1.25 tokens per word is a
    # rough heuristic, not an exact tokenizer count.
    len_response = (len(prompt.split()) + len(full_response.split())) * 1.25

    col1, col2, col3 = st.columns(3)
    with col1:
        st.write(f"**{tokens_per_second} tokens/second**")
    with col2:
        st.write(f"**{int(len_response)} tokens generated**")
    with col3:
        st.write(f"**₹ {round(len_response * COST_PER_1000_TOKENS_INR / 1000, 5)} cost incurred**")

    st.session_state['tps'] = tokens_per_second
    st.session_state["tokens_used"] = len_response + st.session_state["tokens_used"]

    return full_response
def show_source(links):
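    """List the source links of the retrieved chunks in an expander."""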
with st.expander("Show source") : | |
for i, link in enumerate(links) : | |
st.info(f"{link}") | |
init_state()
sidebar()
header()
chat_box()

if prompt := st.chat_input("Chat with Ikigai Docs..."):
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    chat_stream, links = generate_chat_stream(prompt)

    with st.chat_message("assistant"):
        placeholder = st.empty()
        full_response = stream_handler(chat_stream, placeholder)
        if st.session_state.rag_enabled:
            show_source(links)

    st.session_state.history.append([prompt, full_response])
    st.session_state.messages.append({"role": "assistant", "content": full_response})