Ilyas KHIAT commited on
Commit
8df1e9f
1 Parent(s): 25eeaae

ajout et big update

Browse files
.streamlit/.env CHANGED
@@ -1,2 +1,3 @@
1
  API_TOKEN_PERPLEXITYAI = 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
2
  OPENAI_API_KEY = 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
 
 
1
  API_TOKEN_PERPLEXITYAI = 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
2
  OPENAI_API_KEY = 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
3
+ FIRECRAWL_API_KEY = fc-381ecdb1175147aab5d2b48023961491
chat_te.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain_core.messages import AIMessage, HumanMessage
3
+ from langchain_community.chat_models import ChatOpenAI
4
+ from dotenv import load_dotenv
5
+ from langchain_core.output_parsers import StrOutputParser
6
+ from langchain_core.prompts import ChatPromptTemplate
7
+ from download_chart import construct_plot
8
+ from langchain_core.runnables import RunnablePassthrough
9
+ from langchain import hub
10
+ from langchain_core.prompts.prompt import PromptTemplate
11
+ from langchain_community.vectorstores import FAISS
12
+ from langchain_community.embeddings import OpenAIEmbeddings
13
+ load_dotenv()
14
+
15
+ def get_conversation_chain(vectorstore):
16
+ llm = ChatOpenAI(model="gpt-4o",temperature=0.5, max_tokens=2048)
17
+ retriever=vectorstore.as_retriever()
18
+
19
+ prompt = hub.pull("rlm/rag-prompt")
20
+ # Chain
21
+ rag_chain = (
22
+ {"context": retriever , "question": RunnablePassthrough()}
23
+ | prompt
24
+ | llm
25
+ | StrOutputParser()
26
+ )
27
+ return rag_chain
28
+
29
+ def get_response(user_query, chat_history):
30
+
31
+ template = """
32
+ Chat history: {chat_history}
33
+ User question: {user_question}
34
+ """
35
+
36
+ embeddings = OpenAIEmbeddings()
37
+ db = FAISS.load_local("vectorstore_op", embeddings)
38
+
39
+ question = ChatPromptTemplate.from_template(template)
40
+ question = question.format(chat_history=chat_history, user_question=user_query)
41
+
42
+ chain = get_conversation_chain(db)
43
+
44
+ return chain.stream(question)
45
+
46
+ def display_chart():
47
+ if "pp_grouped" not in st.session_state or st.session_state['pp_grouped'] is None or len(st.session_state['pp_grouped']) == 0:
48
+ st.warning("Aucune partie prenante n'a été définie")
49
+ return None
50
+ plot = construct_plot()
51
+ st.plotly_chart(plot)
52
+
53
+
54
+ def display_chat():
55
+ # app config
56
+ st.title("Chatbot")
57
+
58
+ # session state
59
+ if "chat_history" not in st.session_state:
60
+ st.session_state.chat_history = [
61
+ AIMessage(content="Salut, voici votre cartographie des parties prenantes. Que puis-je faire pour vous?"),
62
+ ]
63
+
64
+
65
+ # conversation
66
+ for message in st.session_state.chat_history:
67
+ if isinstance(message, AIMessage):
68
+ with st.chat_message("AI"):
69
+ st.write(message.content)
70
+ if "cartographie des parties prenantes" in message.content:
71
+ display_chart()
72
+ elif isinstance(message, HumanMessage):
73
+ with st.chat_message("Moi"):
74
+ st.write(message.content)
75
+
76
+ # user input
77
+ user_query = st.chat_input("Par ici...")
78
+ if user_query is not None and user_query != "":
79
+ st.session_state.chat_history.append(HumanMessage(content=user_query))
80
+
81
+ with st.chat_message("Moi"):
82
+ st.markdown(user_query)
83
+
84
+ with st.chat_message("AI"):
85
+
86
+ response = st.write_stream(get_response(user_query, st.session_state.chat_history,format_context(st.session_state['pp_grouped'],st.session_state['Nom de la marque'])))
87
+ if "cartographie des parties prenantes" in message.content:
88
+ display_chart()
89
+
90
+ st.session_state.chat_history.append(AIMessage(content=response))
chat_with_pps.py CHANGED
@@ -25,8 +25,6 @@ def format_context(partie_prenante_grouped,marque):
25
  '''
26
  context += segmentation
27
  return context
28
-
29
-
30
 
31
 
32
  def get_response(user_query, chat_history, context):
 
25
  '''
26
  context += segmentation
27
  return context
 
 
28
 
29
 
30
  def get_response(user_query, chat_history, context):
high_chart.py CHANGED
@@ -151,7 +151,8 @@ cd2 = {
151
  "dragSensitivity":0
152
  },
153
  "data":[],
154
- "colorByPoint":True
 
155
  }
156
  ],
157
  "exporting": {
@@ -191,7 +192,7 @@ def test_chart():
191
  # st.session_state['pp_grouped'] = chart
192
 
193
 
194
-
195
  if st.session_state['save']:
196
  st.session_state['save'] = False
197
  st.session_state['pp_grouped'] = chart.copy()
 
151
  "dragSensitivity":0
152
  },
153
  "data":[],
154
+ "colorByPoint":True,
155
+
156
  }
157
  ],
158
  "exporting": {
 
192
  # st.session_state['pp_grouped'] = chart
193
 
194
 
195
+ st.write(chart)
196
  if st.session_state['save']:
197
  st.session_state['save'] = False
198
  st.session_state['pp_grouped'] = chart.copy()
partie_prenante_carte.py CHANGED
@@ -15,13 +15,11 @@ from langchain.llms import HuggingFaceHub
15
  from langchain import hub
16
  from langchain_core.output_parsers import StrOutputParser
17
  from langchain_core.runnables import RunnablePassthrough
18
- from langchain_community.document_loaders import WebBaseLoader
19
  from langchain_core.prompts.prompt import PromptTemplate
20
- import altair as alt
21
  from session import set_partie_prenante
22
  import os
23
  from streamlit_vertical_slider import vertical_slider
24
- from pp_viz import display_viz
25
  from high_chart import test_chart
26
 
27
  load_dotenv()
@@ -35,7 +33,18 @@ def get_docs_from_website(urls):
35
  return docs
36
  except Exception as e:
37
  return None
38
-
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def get_doc_chunks(docs):
41
  # Split the loaded data
@@ -43,17 +52,35 @@ def get_doc_chunks(docs):
43
  # chunk_size=500,
44
  # chunk_overlap=100)
45
 
46
- text_splitter = SemanticChunker(OpenAIEmbeddings())
47
 
48
  docs = text_splitter.split_documents(docs)
49
  return docs
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
 
52
  def get_vectorstore_from_docs(doc_chunks):
53
- embedding = OpenAIEmbeddings(model="text-embedding-3-large")
54
  vectorstore = FAISS.from_documents(documents=doc_chunks, embedding=embedding)
55
  return vectorstore
56
 
 
 
 
 
 
57
  def get_conversation_chain(vectorstore):
58
  llm = ChatOpenAI(model="gpt-4o",temperature=0.5, max_tokens=2048)
59
  retriever=vectorstore.as_retriever()
@@ -107,12 +134,15 @@ def display_list_urls():
107
 
108
  if len(st.session_state.urls) > index:
109
  # Instead of using markdown, use an expander in the first column
110
- with col1.expander(f"URL {index}: {item}"):
111
  pp = st.session_state["parties_prenantes"][index]
112
  st.write(pd.DataFrame(pp, columns=["Partie prenante"]))
113
  else:
114
  emp.empty() # Clear the placeholder if the index exceeds the list
115
 
 
 
 
116
  def display_list_pps():
117
  for index, item in enumerate(st.session_state["pp_grouped"]):
118
  emp = st.empty()
@@ -125,27 +155,24 @@ def display_list_pps():
125
 
126
  if len(st.session_state["pp_grouped"]) > index:
127
  name = st.session_state["pp_grouped"][index]["name"]
128
- col1.markdown(f"{name}")
 
 
129
  else:
130
  emp.empty()
131
 
132
 
133
 
134
- def extract_pp(urls,input_variables):
135
  template_extraction_PP = '''
136
- Objectif : identifiez tout les noms de marques qui sont des parties prenantes de la marque suivante pour développer un marketing de coopération (co-op marketing)
137
-
138
- Le nom de la marque de référence est le suivant : {BRAND_NAME}
139
- Son activité est la suivante : {BRAND_DESCRIPTION}
140
 
141
- TA REPONSE DOIT ETRE SOUS FORME DE LISTE DE NOMS DE MARQUES SANS NUMEROTATION ET SEPARES PAR DES SAUTS DE LIGNE
142
 
143
- SI TU NE TROUVES PAS DE NOM DE MARQUE, REPONDS "444"
144
- '''
145
  #don't forget to add the input variables from the maim function
146
 
147
- docs = get_docs_from_website(urls)
148
-
149
  if docs == None:
150
  return "445"
151
 
@@ -167,9 +194,22 @@ def extract_pp(urls,input_variables):
167
 
168
  #version simple
169
  partie_prenante = response.content.replace("- ","").split('\n')
 
170
 
171
  return partie_prenante
172
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  def format_pp_add_viz(pp):
174
  y = 50
175
  x = 50
@@ -182,11 +222,11 @@ def format_pp_add_viz(pp):
182
  if st.session_state['pp_grouped'][i]['name'] == pp:
183
  return None
184
  else:
185
- st.session_state['pp_grouped'].append({'name':pp, 'x':x,'y':y})
186
 
187
  def add_pp(new_pp, default_value=50):
188
  new_pp = sorted(new_pp)
189
- new_pp = [item.lower().capitalize() for item in new_pp]
190
  st.session_state['parties_prenantes'].append(new_pp)
191
  for pp in new_pp:
192
  format_pp_add_viz(pp)
@@ -198,6 +238,7 @@ def add_pp_input_text():
198
  format_pp_add_viz(new_pp)
199
 
200
  import re
 
201
 
202
  def complete_and_verify_url(partial_url):
203
  # Regex pattern for validating a URL
@@ -232,7 +273,7 @@ def complete_and_verify_url(partial_url):
232
  def display_pp():
233
 
234
  load_dotenv()
235
-
236
  #check if brand name and description are already set
237
  if "Nom de la marque" not in st.session_state:
238
  st.session_state["Nom de la marque"] = ""
@@ -260,6 +301,7 @@ def display_pp():
260
 
261
  url = st.text_input("Ajouter une URL")
262
 
 
263
  #if the user clicks on the button
264
  if st.button("ajouter"):
265
  st.session_state["save"] = True
@@ -271,9 +313,20 @@ def display_pp():
271
  st.error("URL déjà ajoutée")
272
 
273
  else:
274
- docs = get_docs_from_website(url)
 
 
 
 
 
 
 
 
 
 
 
275
  if docs is None:
276
- st.error("Aucune url trouvée ou erreur lors de la récupération du contenu")
277
  else:
278
  # Création de l'expander
279
  with st.expander("Cliquez ici pour éditer et voir le document"):
@@ -286,7 +339,7 @@ def display_pp():
286
 
287
  #handle the extraction
288
  input_variables = {"BRAND_NAME": brand_name, "BRAND_DESCRIPTION": ""}
289
- partie_prenante = extract_pp([url], input_variables)
290
 
291
  if "444" in partie_prenante: #444 is the code for no brand found , chosen
292
  st.error("Aucune partie prenante trouvée")
 
15
  from langchain import hub
16
  from langchain_core.output_parsers import StrOutputParser
17
  from langchain_core.runnables import RunnablePassthrough
18
+ from langchain_community.document_loaders import WebBaseLoader,FireCrawlLoader,PDFLoader
19
  from langchain_core.prompts.prompt import PromptTemplate
 
20
  from session import set_partie_prenante
21
  import os
22
  from streamlit_vertical_slider import vertical_slider
 
23
  from high_chart import test_chart
24
 
25
  load_dotenv()
 
33
  return docs
34
  except Exception as e:
35
  return None
36
+
37
+
38
+ def get_docs_from_website_fc(urls,firecrawl_api_key):
39
+ docs = []
40
+ try:
41
+ for url in urls:
42
+ loader = FireCrawlLoader(api_key=firecrawl_api_key, url = url,mode="scrape")
43
+ docs+=loader.load()
44
+ return docs
45
+ except Exception as e:
46
+ return None
47
+
48
 
49
  def get_doc_chunks(docs):
50
  # Split the loaded data
 
52
  # chunk_size=500,
53
  # chunk_overlap=100)
54
 
55
+ text_splitter = SemanticChunker(OpenAIEmbeddings(model="text-embedding-3-small"))
56
 
57
  docs = text_splitter.split_documents(docs)
58
  return docs
59
+
60
+ def get_doc_chunks_fc(docs):
61
+ # Split the loaded data
62
+ # text_splitter = RecursiveCharacterTextSplitter(
63
+ # chunk_size=500,
64
+ # chunk_overlap=100)
65
+
66
+ text_splitter = SemanticChunker(OpenAIEmbeddings(model="text-embedding-3-small"))
67
+ docs_splitted = []
68
+ for text in docs:
69
+ text_splitted = text_splitter.split_text(text)
70
+ docs_splitted+=text_splitted
71
+ return docs_splitted
72
 
73
 
74
  def get_vectorstore_from_docs(doc_chunks):
75
+ embedding = OpenAIEmbeddings(model="text-embedding-3-small")
76
  vectorstore = FAISS.from_documents(documents=doc_chunks, embedding=embedding)
77
  return vectorstore
78
 
79
+ def get_vectorstore_from_text(texts):
80
+ embedding = OpenAIEmbeddings(model="text-embedding-3-small")
81
+ vectorstore = FAISS.from_texts(texts=texts, embedding=embedding)
82
+ return vectorstore
83
+
84
  def get_conversation_chain(vectorstore):
85
  llm = ChatOpenAI(model="gpt-4o",temperature=0.5, max_tokens=2048)
86
  retriever=vectorstore.as_retriever()
 
134
 
135
  if len(st.session_state.urls) > index:
136
  # Instead of using markdown, use an expander in the first column
137
+ with col1.expander(f"Source {index+1}: {item}"):
138
  pp = st.session_state["parties_prenantes"][index]
139
  st.write(pd.DataFrame(pp, columns=["Partie prenante"]))
140
  else:
141
  emp.empty() # Clear the placeholder if the index exceeds the list
142
 
143
+ def colored_circle(color):
144
+ return f'<span style="display: inline-block; width: 15px; height: 15px; border-radius: 50%; background-color: {color};"></span>'
145
+
146
  def display_list_pps():
147
  for index, item in enumerate(st.session_state["pp_grouped"]):
148
  emp = st.empty()
 
155
 
156
  if len(st.session_state["pp_grouped"]) > index:
157
  name = st.session_state["pp_grouped"][index]["name"]
158
+ col1.markdown(f'<p>{colored_circle(st.session_state["pp_grouped"][index]["color"])} {st.session_state["pp_grouped"][index]["name"]}</p>',
159
+ unsafe_allow_html=True
160
+ )
161
  else:
162
  emp.empty()
163
 
164
 
165
 
166
+ def extract_pp(docs,input_variables):
167
  template_extraction_PP = '''
168
+ Objectif : identifiez tout les parties prenantes de la marque suivante:
 
 
 
169
 
170
+ Le nom de la marque de référence est le suivant : {BRAND_NAME}
171
 
172
+ TA REPONSE DOIT ETRE SOUS FORME DE LISTE DE NOMS DE MARQUES SANS INCLURE LE NOM DE LA MARQUE DE REFERENCE SANS NUMEROTATION ET SEPARES PAR DES RETOURS A LA LIGNE
173
+ '''
174
  #don't forget to add the input variables from the maim function
175
 
 
 
176
  if docs == None:
177
  return "445"
178
 
 
194
 
195
  #version simple
196
  partie_prenante = response.content.replace("- ","").split('\n')
197
+ partie_prenante = [item.strip() for item in partie_prenante]
198
 
199
  return partie_prenante
200
 
201
+ def generate_random_color():
202
+ # Generate random RGB values
203
+ r = random.randint(0, 255)
204
+ g = random.randint(0, 255)
205
+ b = random.randint(0, 255)
206
+
207
+ # Convert RGB to hexadecimal
208
+ color_hex = '#{:02x}{:02x}{:02x}'.format(r, g, b)
209
+
210
+ return color_hex
211
+
212
+
213
  def format_pp_add_viz(pp):
214
  y = 50
215
  x = 50
 
222
  if st.session_state['pp_grouped'][i]['name'] == pp:
223
  return None
224
  else:
225
+ st.session_state['pp_grouped'].append({'name':pp, 'x':x,'y':y, 'color':generate_random_color()})
226
 
227
  def add_pp(new_pp, default_value=50):
228
  new_pp = sorted(new_pp)
229
+ new_pp = [item.lower().capitalize().strip() for item in new_pp]
230
  st.session_state['parties_prenantes'].append(new_pp)
231
  for pp in new_pp:
232
  format_pp_add_viz(pp)
 
238
  format_pp_add_viz(new_pp)
239
 
240
  import re
241
+ import random
242
 
243
  def complete_and_verify_url(partial_url):
244
  # Regex pattern for validating a URL
 
273
  def display_pp():
274
 
275
  load_dotenv()
276
+ fire_crawl_api_key = os.getenv("FIRECRAWL_API_KEY")
277
  #check if brand name and description are already set
278
  if "Nom de la marque" not in st.session_state:
279
  st.session_state["Nom de la marque"] = ""
 
301
 
302
  url = st.text_input("Ajouter une URL")
303
 
304
+ scraping_option = st.radio("Mode", ("Analyse rapide", "Analyse profonde"),horizontal=True)
305
  #if the user clicks on the button
306
  if st.button("ajouter"):
307
  st.session_state["save"] = True
 
313
  st.error("URL déjà ajoutée")
314
 
315
  else:
316
+ if scraping_option == "Analyse profonde":
317
+ with st.spinner("Collecte des données..."):
318
+ docs = get_docs_from_website_fc([url],fire_crawl_api_key)
319
+ if docs is None:
320
+ st.warning("Erreur lors de la collecte des données, 2eme essai avec collecte rapide...")
321
+ with st.spinner("2eme essai, collecte rapide..."):
322
+ docs = get_docs_from_website([url])
323
+
324
+ if scraping_option == "Analyse rapide":
325
+ with st.spinner("Collecte des données..."):
326
+ docs = get_docs_from_website([url])
327
+
328
  if docs is None:
329
+ st.error("Erreur lors de la collecte des données")
330
  else:
331
  # Création de l'expander
332
  with st.expander("Cliquez ici pour éditer et voir le document"):
 
339
 
340
  #handle the extraction
341
  input_variables = {"BRAND_NAME": brand_name, "BRAND_DESCRIPTION": ""}
342
+ partie_prenante = extract_pp(docs, input_variables)
343
 
344
  if "444" in partie_prenante: #444 is the code for no brand found , chosen
345
  st.error("Aucune partie prenante trouvée")
pp_viz.py DELETED
@@ -1,51 +0,0 @@
1
- import streamlit as st
2
- import pandas as pd
3
- import numpy as np
4
- import re
5
-
6
- import altair as alt
7
- from session import get_parties_prenantes
8
- import os
9
- from streamlit_vertical_slider import vertical_slider
10
- from st_draggable_list import DraggableList
11
-
12
- def display_viz():
13
-
14
-
15
- parties_prenantes = get_parties_prenantes()
16
-
17
- if parties_prenantes is None or len(parties_prenantes) == 0:
18
- st.write("aucune partie prenante n'a été définie")
19
- else:
20
- partie_prenante_non_filtre = [item.lower().capitalize() for sublist in parties_prenantes for item in sublist]
21
- partie_prenante = sorted(list(set(partie_prenante_non_filtre)))
22
- pouvoir = [ 50 for _ in range(len(partie_prenante))]
23
-
24
- c = (
25
- alt.Chart(st.session_state['partie_prenante_grouped'])
26
- .mark_circle(size=800)
27
- .encode(x="partie_prenante", y=alt.Y("pouvoir",scale=alt.Scale(domain=[0,100])), color="Code couleur",tooltip=["partie_prenante","pouvoir"])
28
- ).configure_legend(orient='bottom',direction="vertical").properties(height=600)
29
-
30
- number_of_sliders = len(partie_prenante)
31
- st.write("Modifiez le pouvoir des parties prenantes en utilisant les sliders ci-dessous")
32
-
33
-
34
- bar = st.columns(number_of_sliders)
35
- for i in range(number_of_sliders):
36
- with bar[i]:
37
- st.session_state['partie_prenante_grouped']['pouvoir'][i] = vertical_slider(
38
- label=partie_prenante[i],
39
- height=100,
40
- key=partie_prenante[i],
41
- default_value=int(st.session_state['partie_prenante_grouped']['pouvoir'][i]),
42
- thumb_color= "orange", #Optional - Defaults to Streamlit Red
43
- step=1,
44
- min_value=0,
45
- max_value=100,
46
- value_always_visible=False,
47
- )
48
- st.altair_chart(c, use_container_width=True)
49
- # data = [{'id':partie_prenante[i], 'name':partie_prenante[i],'pouvoir':int(df["pouvoir"][i])} for i in range(len(partie_prenante))]
50
- # slist = DraggableList(data)
51
- # st.write(slist)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rag_funcs.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from firecrawl import FireCrawl
2
+
3
+
4
+ def get_docs_from_website_fc(urls):
5
+ app = FireCrawl()
6
+ docs = []
7
+ try:
8
+ for url in urls:
9
+ content = app.scrape_url(url)
10
+ docs.append(content["markdown"])
11
+ return docs
12
+ except Exception as e:
13
+ return None
14
+
requirements.txt CHANGED
@@ -32,4 +32,4 @@ langchain_experimental
32
  streamlit_draggable_list
33
  streamlit-highcharts
34
  pdfkit
35
- kaleido
 
32
  streamlit_draggable_list
33
  streamlit-highcharts
34
  pdfkit
35
+ kaleido
st_hc/frontend/main.js CHANGED
@@ -13,6 +13,7 @@ function onRender(event) {
13
  let points = c.series[0].data.map((p) =>
14
  ({ x: Math.round(p.x),
15
  y: Math.round(p.y),
 
16
  name:p.name} ));
17
  sendValue(points);
18
 
@@ -25,6 +26,7 @@ function onRender(event) {
25
  let points = c.series[0].data.map((p) =>
26
  ({ x: Math.round(p.x),
27
  y: Math.round(p.y),
 
28
  name:p.name} ));
29
 
30
  console.log(points);
 
13
  let points = c.series[0].data.map((p) =>
14
  ({ x: Math.round(p.x),
15
  y: Math.round(p.y),
16
+ color:p.color,
17
  name:p.name} ));
18
  sendValue(points);
19
 
 
26
  let points = c.series[0].data.map((p) =>
27
  ({ x: Math.round(p.x),
28
  y: Math.round(p.y),
29
+ color:p.color,
30
  name:p.name} ));
31
 
32
  console.log(points);
vectorstore_op/index.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b492225278bd4ba23d11fe72fa16f8abd9a023babcc6734901740ba34fd0ba7
3
+ size 106874