pragneshbarik committed on
Commit
e506679
1 Parent(s): cebfd3c

added support for Zephyr 7B

Files changed (3)
  1. __pycache__/mistral7b.cpython-310.pyc +0 -0
  2. app.py +18 -9
  3. mistral7b.py +6 -6
__pycache__/mistral7b.cpython-310.pyc CHANGED
Binary files a/__pycache__/mistral7b.cpython-310.pyc and b/__pycache__/mistral7b.cpython-310.pyc differ
 
app.py CHANGED
@@ -1,5 +1,5 @@
 import streamlit as st
-from mistral7b import mistral
+from mistral7b import chat
 import time
 import pandas as pd
 import pinecone
@@ -17,6 +17,11 @@ pinecone.init(
 
 pinecone_index = pinecone.Index('ikigai-chat')
 text_vectorizer = SentenceTransformer('all-distilroberta-v1')
+chat_bots = {
+    "Mistral 7B" : "mistralai/Mistral-7B-Instruct-v0.1",
+    "Zephyr 7B" : "HuggingFaceH4/zephyr-7b-alpha",
+}
+
 
 def gen_augmented_prompt(prompt, top_k) :
     query_vector = text_vectorizer.encode(prompt).tolist()
@@ -41,7 +46,7 @@ def gen_augmented_prompt(prompt, top_k) :
 
 data = {
     "Attribute": ["LLM", "Text Vectorizer", "Vector Database","CPU", "System RAM"],
-    "Information": ["Mistral-7B-Instruct-v0.1 (more models soon)","all-distilroberta-v1", "Hosted Pinecone" ,"2 vCPU", "16 GB"]
+    "Information": ["Mistral-7B-Instruct-v0.1","all-distilroberta-v1", "Hosted Pinecone" ,"2 vCPU", "16 GB"]
 }
 df = pd.DataFrame(data)
 
@@ -82,18 +87,22 @@ if "history" not in st.session_state:
     Let me know if you have any specific questions about Ikigai Labs or our products."""]]
 
 if "top_k" not in st.session_state:
-    st.session_state.top_k = 3
+    st.session_state.top_k = 4
 
 if "repetion_penalty" not in st.session_state :
     st.session_state.repetion_penalty = 1
 
 if "rag_enabled" not in st.session_state :
     st.session_state.rag_enabled = True
+
+if "chat_bot" not in st.session_state :
+    st.session_state.chat_bot = "Mistral 7B"
+
 with st.sidebar:
     st.markdown("# Retrieval Settings")
-    st.session_state.rag_enabled = st.toggle("Activate RAG")
+    st.session_state.rag_enabled = st.toggle("Activate RAG", value=True)
     st.session_state.top_k = st.slider(label="Documents to retrieve",
-                                       min_value=1, max_value=10, value=3, disabled=not st.session_state.rag_enabled)
+                                       min_value=1, max_value=10, value=4, disabled=not st.session_state.rag_enabled)
     st.markdown("---")
     st.markdown("# Model Analytics")
 
@@ -107,8 +116,8 @@ with st.sidebar:
 
     st.markdown("# Model Settings")
 
-    selected_model = st.sidebar.radio(
-        'Select one:', ["Mistral 7B","Llama 7B" ,"GPT 3.5 Turbo", "GPT 4" ])
+    st.session_state.chat_bot = st.sidebar.radio(
+        'Select one:', ["Mistral 7B","Zephyr 7B"])
     st.session_state.temp = st.slider(
         label="Temperature", min_value=0.0, max_value=1.0, step=0.1, value=0.9)
 
@@ -126,7 +135,7 @@ with st.sidebar:
 
 st.image("ikigai.svg")
 st.title("Ikigai Chat")
-st.caption("Maintained and developed by Pragnesh Barik.")
+# st.caption("Maintained and developed by Pragnesh Barik.")
 
 with st.expander("What is Ikigai Chat ?"):
     st.info("""Ikigai Chat is a vector database powered chat agent, it works on the principle of
@@ -153,7 +162,7 @@ if prompt := st.chat_input("Chat with Ikigai Docs..."):
         prompt, links = gen_augmented_prompt(prompt=prompt, top_k=st.session_state.top_k)
 
     with st.spinner("Generating response...") :
-        response = mistral(prompt, st.session_state.history,
+        response = chat(prompt, st.session_state.history, chat_client=chat_bots[st.session_state.chat_bot],
                         temperature=st.session_state.temp, max_new_tokens=st.session_state.max_tokens)
     tock = time.time()
 
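For context, a minimal sketch (not part of the commit) of the model routing this change introduces in app.py: the sidebar radio stores a display label in `st.session_state.chat_bot`, and the `chat_bots` dict resolves that label to a Hugging Face repo id at call time. The labels and repo ids below are copied from the diff; the surrounding code is illustrative only.

```python
# Illustrative sketch of the label -> repo-id routing added in app.py.
# The mapping is copied from the diff above; the rest is hypothetical.
chat_bots = {
    "Mistral 7B": "mistralai/Mistral-7B-Instruct-v0.1",
    "Zephyr 7B": "HuggingFaceH4/zephyr-7b-alpha",
}

selected = "Zephyr 7B"         # in the app this comes from st.session_state.chat_bot
repo_id = chat_bots[selected]  # resolved just before calling chat()
print(repo_id)                 # HuggingFaceH4/zephyr-7b-alpha
```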
mistral7b.py CHANGED
@@ -4,10 +4,6 @@ from dotenv import load_dotenv
 load_dotenv()
 
 API_TOKEN = os.getenv('HF_TOKEN')
-client = InferenceClient(
-    "mistralai/Mistral-7B-Instruct-v0.1",
-    token=API_TOKEN
-)
 
 
 def format_prompt(message, history):
@@ -18,9 +14,13 @@ def format_prompt(message, history):
     prompt += f"[INST] {message} [/INST]"
     return prompt
 
-def mistral(
-    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
+def chat(
+    prompt, history, chat_client="mistralai/Mistral-7B-Instruct-v0.1", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
 ):
+    client = InferenceClient(
+        chat_client,
+        token=API_TOKEN
+    )
     temperature = float(temperature)
     if temperature < 1e-2:
         temperature = 1e-2
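A hedged usage sketch of the refactored entry point: `chat()` now receives the repo id via `chat_client` and constructs the `InferenceClient` inside the call, so supporting a new model means adding one entry to `chat_bots` rather than touching the inference code. The call below assumes HF_TOKEN is set in the environment (mistral7b.py loads it via dotenv) and that `history` uses the `[[user, bot], ...]` pair format consumed by `format_prompt`; the prompt text and sampling values are made up for illustration.

```python
# Hypothetical standalone use of the new chat() entry point; assumes HF_TOKEN is
# exported and the Hugging Face Inference API is reachable.
from mistral7b import chat

history = []  # list of [user_message, bot_response] pairs, as format_prompt expects

response = chat(
    "What does Ikigai Labs do?",
    history,
    chat_client="HuggingFaceH4/zephyr-7b-alpha",  # repo id from the chat_bots mapping
    temperature=0.7,
    max_new_tokens=128,
)
```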