dipraj committed on
Commit f16ba4c • 1 Parent(s): b6104f3

Update app.py

Files changed (1)
  1. app.py +118 -52
app.py CHANGED
@@ -1,62 +1,128 @@
import streamlit as st
import os
- from langchain_groq import ChatGroq
- from langchain_community.document_loaders import WebBaseLoader
- from langchain_community.embeddings import OllamaEmbeddings
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain.chains.combine_documents import create_stuff_documents_chain
- from langchain_core.prompts import ChatPromptTemplate
- from langchain.chains import create_retrieval_chain
- from langchain_community.vectorstores import FAISS
import time

from dotenv import load_dotenv
load_dotenv()

## load the Groq API key
groq_api_key = os.environ['GROQ_API_KEY']

- if "vector" not in st.session_state:
-     st.session_state.embeddings = OllamaEmbeddings()
-     st.session_state.loader = WebBaseLoader("https://docs.smith.langchain.com/")
-     st.session_state.docs = st.session_state.loader.load()
-
-     st.session_state.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-     st.session_state.final_documents = st.session_state.text_splitter.split_documents(st.session_state.docs[:50])
-     st.session_state.vectors = FAISS.from_documents(st.session_state.final_documents, st.session_state.embeddings)
-
- st.title("ChatGroq Demo")
- llm = ChatGroq(groq_api_key=groq_api_key,
-                model_name="mixtral-8x7b-32768")
-
- prompt = ChatPromptTemplate.from_template(
-     """
-     Answer the questions based on the provided context only.
-     Please provide the most accurate response based on the question.
-     <context>
-     {context}
-     </context>
-     Questions: {input}
-     """
- )
- document_chain = create_stuff_documents_chain(llm, prompt)
- retriever = st.session_state.vectors.as_retriever()
- retrieval_chain = create_retrieval_chain(retriever, document_chain)
-
- prompt = st.text_input("Input your prompt here")
-
- if prompt:
-     start = time.process_time()
-     response = retrieval_chain.invoke({"input": prompt})
-     print("Response time:", time.process_time() - start)
-     st.write(response['answer'])
-
-     # With a Streamlit expander
-     with st.expander("Document Similarity Search"):
-         # Find the relevant chunks
-         for i, doc in enumerate(response["context"]):
-             st.write(doc.page_content)
-             st.write("--------------------------------")
 
 
import streamlit as st
+ from typing import Generator
+ from groq import Groq
import os
import time
+
+ st.set_page_config(page_icon="💬", layout="wide",
+                    page_title="Groq Goes Brrrrrrrr...")

from dotenv import load_dotenv
+
load_dotenv()

## load the Groq API key
groq_api_key = os.environ['GROQ_API_KEY']
+
+
+ def icon(emoji: str):
+     """Shows an emoji as a Notion-style page icon."""
+     st.write(
+         f'<span style="font-size: 78px; line-height: 1">{emoji}</span>',
+         unsafe_allow_html=True,
+     )
+
+
+ icon("🏎️")
+
+ st.subheader("Groq Chat Streamlit App", divider="rainbow", anchor=False)
+
+ client = Groq(api_key=groq_api_key)
+
+ # Initialize chat history and selected model
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ if "selected_model" not in st.session_state:
+     st.session_state.selected_model = None
+
+ # Define model details
+ models = {
+     "gemma-7b-it": {"name": "Gemma-7b-it", "tokens": 8192, "developer": "Google"},
+     "llama2-70b-4096": {"name": "LLaMA2-70b-chat", "tokens": 4096, "developer": "Meta"},
+     "llama3-70b-8192": {"name": "LLaMA3-70b-8192", "tokens": 8192, "developer": "Meta"},
+     "llama3-8b-8192": {"name": "LLaMA3-8b-8192", "tokens": 8192, "developer": "Meta"},
+     "mixtral-8x7b-32768": {"name": "Mixtral-8x7b-Instruct-v0.1", "tokens": 32768, "developer": "Mistral"},
+ }
+
+ # Layout for model selection and max_tokens slider
+ col1, col2 = st.columns(2)
+
+ with col1:
+     model_option = st.selectbox(
+         "Choose a model:",
+         options=list(models.keys()),
+         format_func=lambda x: models[x]["name"],
+         index=4  # Default to mixtral
+     )
+
+ # Detect model change and clear chat history if model has changed
+ if st.session_state.selected_model != model_option:
+     st.session_state.messages = []
+     st.session_state.selected_model = model_option
+
+ max_tokens_range = models[model_option]["tokens"]
+
+ with col2:
+     # Adjust max_tokens slider dynamically based on the selected model
+     max_tokens = st.slider(
+         "Max Tokens:",
+         min_value=512,  # Minimum value to allow some flexibility
+         max_value=max_tokens_range,
+         # Default to the model's maximum, capped at 32768
+         value=min(32768, max_tokens_range),
+         step=512,
+         help=f"Adjust the maximum number of tokens for the model's response. Max for selected model: {max_tokens_range}"
+     )
+
+ # Display chat messages from history on app rerun
+ for message in st.session_state.messages:
+     avatar = '🤖' if message["role"] == "assistant" else '👨‍💻'
+     with st.chat_message(message["role"], avatar=avatar):
+         st.markdown(message["content"])
+
+
+ def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
+     """Yield chat response content from the streamed Groq API response."""
+     for chunk in chat_completion:
+         if chunk.choices[0].delta.content:
+             yield chunk.choices[0].delta.content
+
+
91
+ if prompt := st.chat_input("Enter your prompt here..."):
92
+ st.session_state.messages.append({"role": "user", "content": prompt})
93
+
94
+ with st.chat_message("user", avatar='πŸ‘¨β€πŸ’»'):
95
+ st.markdown(prompt)
96
+ # Fetch response from Groq API
97
+ try:
98
+ chat_completion = client.chat.completions.create(
99
+ model=model_option,
100
+ messages=[
101
+ {
102
+ "role": m["role"],
103
+ "content": m["content"]
104
+ }
105
+ for m in st.session_state.messages
106
+ ],
107
+ max_tokens=max_tokens,
108
+ stream=False
109
+ )
110
+ full_response = chat_completion.choices[0].message.content
111
+
112
+ # Use the generator function with st.write_stream
113
+ with st.chat_message("assistant", avatar="πŸ€–"):
114
+ st.write(full_response)
115
+ for chunk in chat_completion:
116
+ if chunk.choices[0].message.content:
117
+ st.text(chunk.choices[0].message.content)
118
+
119
 
120
+ # Update message content dynamically using a loop with sleep
121
+ for i in range(1, len(full_response) // 100 + 1):
122
+ st.write(full_response[i * 100 : (i + 1) * 100])
123
+ time.sleep(0.1)
124
+
125
+ except Exception as e:
126
+ st.error(e, icon="🚨")
127
+ # Set full_response to a default value or handle the error accordingly
128
+ full_response = None
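
To try the updated app locally, a minimal sketch (assumes the file is saved as app.py and that a Groq API key is at hand; the PyPI package names streamlit, groq, and python-dotenv are the standard ones):

    # .env -- read by load_dotenv() at startup
    GROQ_API_KEY=your-key-here

    # install dependencies and launch the Streamlit app
    pip install streamlit groq python-dotenv
    streamlit run app.py

Because each assistant reply is appended to st.session_state.messages, the conversation persists across Streamlit's rerun-on-interaction cycle; switching models in the selectbox deliberately clears that history.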