Spaces:
Running
Running
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import re | |
import random | |
import streamlit as st | |
from dotenv import load_dotenv | |
from PyPDF2 import PdfReader | |
from langchain.text_splitter import CharacterTextSplitter,RecursiveCharacterTextSplitter | |
from langchain_experimental.text_splitter import SemanticChunker | |
from langchain_community.embeddings import OpenAIEmbeddings | |
from langchain_community.vectorstores import FAISS | |
from langchain_community.chat_models import ChatOpenAI | |
from langchain.llms import HuggingFaceHub | |
from langchain import hub | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain_community.document_loaders import WebBaseLoader,FireCrawlLoader | |
from langchain_core.prompts.prompt import PromptTemplate | |
from session import set_partie_prenante | |
import os | |
from streamlit_vertical_slider import vertical_slider | |
from high_chart import test_chart | |
from chat_with_pps import get_response | |
# Load variables from a local .env file into the process environment at import
# time. display_pp reads FIRECRAWL_API_KEY from the environment; the OpenAI-backed
# LangChain classes presumably take their API key from it too (confirm).
load_dotenv()
def get_docs_from_website(urls):
    """Scrape web page(s) with LangChain's ``WebBaseLoader``.

    Used for the quick ("Analyse rapide") mode and as the fallback of the deep mode.

    Args:
        urls: URL or list of URLs handed to ``WebBaseLoader``; callers in this file
            pass a one-element list.

    Returns:
        The loaded documents, or ``None`` when loading raised an exception. Callers
        treat ``None`` as a scraping failure.
    """
    import logging

    # Browser-like User-Agent, presumably so sites that filter non-browser clients respond.
    loader = WebBaseLoader(urls, header_template={
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36',
    })
    try:
        return loader.load()
    except Exception:
        # Best-effort: still return None so the caller can report or fall back,
        # but record the cause instead of discarding it silently.
        logging.getLogger(__name__).exception("Web scraping failed for %s", urls)
        return None
def get_docs_from_website_fc(urls, firecrawl_api_key):
    """Scrape each URL with FireCrawl in "scrape" mode (the "Analyse profonde" mode).

    Args:
        urls: iterable of URLs, each loaded with its own ``FireCrawlLoader``.
        firecrawl_api_key: FireCrawl API key.

    Returns:
        All loaded documents concatenated in URL order, or ``None`` as soon as any
        URL fails. The result is all-or-nothing, so the caller can fall back to the
        quick loader.
    """
    import logging

    docs = []
    try:
        for url in urls:
            loader = FireCrawlLoader(api_key=firecrawl_api_key, url=url, mode="scrape")
            docs += loader.load()
    except Exception:
        # Keep the None contract but log why FireCrawl failed.
        logging.getLogger(__name__).exception("FireCrawl scraping failed for %s", urls)
        return None
    return docs
def get_doc_chunks(docs):
    """Split LangChain documents into semantically coherent chunks.

    Chunk boundaries are chosen by ``SemanticChunker`` using OpenAI
    ``text-embedding-3-small`` embeddings. A fixed-size
    ``RecursiveCharacterTextSplitter`` (500 chars, 100 overlap) was considered as
    an alternative.

    Args:
        docs: documents to split.

    Returns:
        The list of chunk documents produced by ``split_documents``.
    """
    chunker = SemanticChunker(OpenAIEmbeddings(model="text-embedding-3-small"))
    return chunker.split_documents(docs)
def get_doc_chunks_fc(docs):
    """Semantically split several raw texts and flatten the chunks into one list.

    Each item of ``docs`` is passed to ``SemanticChunker.split_text``, so items are
    expected to be plain strings rather than Document objects.

    Returns:
        All chunk strings, in input order.
    """
    chunker = SemanticChunker(OpenAIEmbeddings(model="text-embedding-3-small"))
    return [piece for raw_text in docs for piece in chunker.split_text(raw_text)]
def get_vectorstore_from_docs(doc_chunks):
    """Embed document chunks with ``text-embedding-3-small`` and index them in a new FAISS store."""
    return FAISS.from_documents(
        documents=doc_chunks,
        embedding=OpenAIEmbeddings(model="text-embedding-3-small"),
    )
def get_vectorstore_from_text(texts):
    """Embed raw strings with ``text-embedding-3-small`` and index them in a new FAISS store."""
    return FAISS.from_texts(
        texts=texts,
        embedding=OpenAIEmbeddings(model="text-embedding-3-small"),
    )
def get_conversation_chain(vectorstore):
    """Build a retrieval-augmented generation (RAG) chain over ``vectorstore``.

    The chain works in three steps:
      1. It takes the question string as input.
      2. It retrieves relevant chunks through ``vectorstore.as_retriever()``.
      3. It fills the ``rlm/rag-prompt`` hub prompt with the question and the
         retrieved text, then calls ``gpt-4o``.

    ``chain.invoke(question)`` returns the chat model's message; callers read ``.content``.
    """
    llm = ChatOpenAI(model="gpt-4o", temperature=0.5, max_tokens=2048)
    retriever = vectorstore.as_retriever()
    prompt = hub.pull("rlm/rag-prompt")

    def _format_docs(docs):
        # Join the retrieved chunks' text. Passing the raw Document list would
        # stringify it as Python reprs (including metadata) inside the prompt.
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = (
        {"context": retriever | _format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
    )
    return rag_chain
def fill_promptQ_template(input_variables, template):
    """Render the *question* string that is later sent through the RAG chain.

    This is not the RAG prompt itself (``rlm/rag-prompt``). It only fills
    ``template`` with the brand variables to produce the question text.

    Args:
        input_variables: mapping providing ``"BRAND_NAME"`` and ``"BRAND_DESCRIPTION"``.
        template: PromptTemplate string that may reference ``{BRAND_NAME}`` and
            ``{BRAND_DESCRIPTION}``.

    Returns:
        The formatted question string.

    Raises:
        KeyError: if either key is missing from ``input_variables``.
    """
    variable_names = ("BRAND_NAME", "BRAND_DESCRIPTION")
    question_template = PromptTemplate(input_variables=list(variable_names), template=template)
    values = {name: input_variables[name] for name in variable_names}
    return question_template.format(**values)
def text_to_list(text):
    """Parse lines like ``"- <name words> <score>"`` into ``[name, digits]`` pairs.

    Parsing works as follows:
      1. Every ``"- "`` in ``text`` is removed.
      2. Each non-blank line is split on whitespace.
      3. The last token is the score, with all non-digit characters removed
         (``"85%"`` -> ``"85"``).
      4. The remaining tokens, joined by single spaces, form the name.

    Blank lines are skipped. They previously raised IndexError, e.g. for the
    trailing newline an LLM answer usually ends with.

    >>> text_to_list("- Nike 85%\\n- Adidas 70%\\n")
    [['Nike', '85'], ['Adidas', '70']]
    """
    items = []
    for line in text.replace("- ", "").split("\n"):
        words = line.split()
        if not words:
            continue
        name, score = " ".join(words[:-1]), words[-1]
        items.append([name, re.sub(r"\D", "", score)])
    return items
def delete_pp(pps):
    """Remove stakeholders from ``st.session_state['pp_grouped']`` in place.

    For each name in ``pps`` (in order), only the first entry whose ``'name'``
    equals it is deleted. Names with no matching entry are ignored.
    """
    grouped = st.session_state['pp_grouped']
    for target in pps:
        match_index = next(
            (pos for pos, entry in enumerate(grouped) if entry['name'] == target),
            None,
        )
        if match_index is not None:
            del grouped[match_index]
def display_list_urls():
    """Render each analysed URL with a delete button and an expander listing its stakeholders.

    ``st.session_state["urls"]`` and ``st.session_state["parties_prenantes"]`` are
    parallel lists: entry ``i`` of the latter holds the stakeholders extracted from
    URL ``i``.

    Clicking "❌" does the following, then reruns the script:
      - removes the URL and its stakeholder list;
      - removes the grouped chart entries (``pp_grouped``) for those stakeholders
        that no other URL still lists.
    """
    for index, item in enumerate(st.session_state["urls"]):
        emp = st.empty()  # placeholder holding this URL's row
        col1, col2 = emp.columns([7, 3])  # 70% label/expander, 30% delete button
        if col2.button("❌", key=f"but{index}"):
            removed = st.session_state["parties_prenantes"][index]
            # pp_grouped keeps each stakeholder once across all sources, so only drop
            # the names that no other source still references.
            still_listed = {
                name
                for other_index, names in enumerate(st.session_state["parties_prenantes"])
                if other_index != index
                for name in names
            }
            delete_pp([name for name in removed if name not in still_listed])
            del st.session_state["urls"][index]
            del st.session_state["parties_prenantes"][index]
            # st.rerun() replaces the deprecated st.experimental_rerun() and halts this run.
            st.rerun()
        if len(st.session_state["urls"]) > index:
            with col1.expander(f"Source {index+1}: {item}"):
                pp = st.session_state["parties_prenantes"][index]
                st.write(pd.DataFrame(pp, columns=["Partie prenante"]))
        else:
            emp.empty()  # clear the placeholder if the index exceeds the list
def colored_circle(color):
    """Return an inline HTML ``<span>`` rendered as a 15 px circle filled with ``color``.

    ``color`` is inserted verbatim as the CSS ``background-color`` value (e.g. ``"#1a2b3c"``).
    """
    css = (
        "display: inline-block; width: 15px; height: 15px; "
        "border-radius: 50%; "
        f"background-color: {color};"
    )
    return f'<span style="{css}"></span>'
def display_list_pps():
    """Render every grouped stakeholder with its colour dot and a "❌" delete button.

    Clicking "❌" removes the entry from ``st.session_state["pp_grouped"]`` and reruns
    the script.
    """
    import html

    for index, entry in enumerate(st.session_state["pp_grouped"]):
        emp = st.empty()
        col1, col2 = emp.columns([7, 3])
        if col2.button("❌", key=f"butp{index}"):
            del st.session_state["pp_grouped"][index]
            # st.rerun() replaces the deprecated st.experimental_rerun() and halts this run.
            st.rerun()
        if len(st.session_state["pp_grouped"]) > index:
            # Names come from LLM output on scraped pages or free user input and are
            # rendered with unsafe_allow_html=True, so escape them first.
            safe_name = html.escape(entry["name"])
            col1.markdown(
                f'<p>{colored_circle(entry["color"])} {safe_name}</p>',
                unsafe_allow_html=True,
            )
        else:
            emp.empty()
def extract_pp(docs, input_variables):
    """Ask the RAG chain to list the stakeholders ("parties prenantes") of a brand.

    Args:
        docs: documents scraped from the brand's website, or ``None``.
        input_variables: mapping with ``"BRAND_NAME"`` and ``"BRAND_DESCRIPTION"``,
            used to fill the question template.

    Returns:
        One of:
          - ``"445"`` when no documents were given.
          - ``"444"`` when the model's answer contains "ne sais pas" ("don't know").
          - Otherwise, the list of stakeholder names: one per non-blank answer line,
            with ``"- "`` bullets removed and whitespace stripped.
    """
    template_extraction_PP = """
Objectif : Identifiez toutes les parties prenantes de la marque suivante :
Le nom de la marque de référence est le suivant : {BRAND_NAME}
TA RÉPONSE DOIT ÊTRE SOUS FORME DE LISTE DE NOMS DE MARQUES, CHAQUE NOM SUR UNE LIGNE SÉPARÉE.
"""
    # Nothing scraped (None or an empty list): there is nothing to index, and building
    # a FAISS store from zero chunks is expected to fail.
    if not docs:
        return "445"
    text_chunks = get_doc_chunks(docs)
    vectorstore = get_vectorstore_from_docs(text_chunks)
    chain = get_conversation_chain(vectorstore)
    question = fill_promptQ_template(input_variables, template_extraction_PP)
    response = chain.invoke(question)
    # The question is in French, so an "I don't know" answer presumably contains "ne sais pas".
    if "ne sais pas" in response.content:
        return "444"
    # A richer variant (names with similarity percentages) could be parsed with text_to_list.
    lines = response.content.replace("- ", "").split("\n")
    # Drop blank lines so no empty-named stakeholder is produced.
    return [name for name in (line.strip() for line in lines) if name]
def generate_random_color():
    """Return a random colour as a lowercase ``"#rrggbb"`` hex string.

    Draws one ``random.randint(0, 255)`` value per channel, in red, green, blue order.
    """
    channel_values = (random.randint(0, 255) for _ in range(3))
    return "#" + "".join(f"{value:02x}" for value in channel_values)
def format_pp_add_viz(pp):
    """Add stakeholder ``pp`` to ``st.session_state['pp_grouped']`` unless already present.

    The new entry gets a random colour and the first free grid slot. The search
    starts at (x=50, y=50) and steps y by 5 up to 95; past 95, y wraps to 50 and x
    moves right by 5.

    A slot counts as taken if *any* existing entry sits on it. The previous
    single-pass scan only noticed collisions with entries that came later in the
    list, so it could stack points.

    Returns:
        None in every case.
    """
    grouped = st.session_state['pp_grouped']
    if any(entry['name'] == pp for entry in grouped):
        return None
    occupied = {(entry['x'], entry['y']) for entry in grouped}
    x, y = 50, 50
    while (x, y) in occupied:
        y += 5
        if y > 95:
            y = 50
            x += 5
    grouped.append({'name': pp, 'x': x, 'y': y, 'color': generate_random_color()})
    return None
def add_pp(new_pp, default_value=50):
    """Register the stakeholders extracted from one source.

    Processing steps:
      1. Names are normalized: surrounding whitespace is removed, the first letter
         is upper-cased and the rest lower-cased.
      2. Blanks and duplicates are dropped.
      3. The result is sorted.
      4. The list is appended to ``st.session_state['parties_prenantes']`` (one list
         per source).
      5. Each name is added to the chart via ``format_pp_add_viz``.

    Args:
        new_pp: iterable of stakeholder names.
        default_value: unused; kept for backward compatibility.
    """
    # Strip *before* capitalize(): capitalize() only upper-cases the first character,
    # which for " nike" would be the space, leaving the name lower-case.
    cleaned = {name.strip().capitalize() for name in new_pp}
    cleaned.discard("")
    # Sort after normalizing so the stored list is really in alphabetical order.
    normalized = sorted(cleaned)
    st.session_state['parties_prenantes'].append(normalized)
    for pp in normalized:
        format_pp_add_viz(pp)
def add_pp_input_text():
    """Render a text input and an "Ajouter" button to add one stakeholder manually."""
    new_pp = st.text_input("Ajouter une partie prenante")
    if st.button("Ajouter", key="add_single_pp"):
        name = new_pp.strip()
        # Ignore blank submissions instead of adding an empty stakeholder to the chart.
        if name:
            format_pp_add_viz(name)
# Accepts http(s)/ftp(s) URLs with the following parts:
#   - host: a dotted domain name, "localhost", or a dotted-quad IPv4 address;
#   - an optional port;
#   - an optional path, query or fragment.
# Compiled once at import time.
_URL_REGEX = re.compile(
    r'^(?:http|ftp)s?://'  # http://, https://, ftp:// or ftps://
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,8}\.?|'  # domain name
    r'localhost|'  # or localhost
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # or IPv4 address
    r'(?::\d+)?'  # optional port
    r'(?:[/?#][^\s]*)?$',  # optional path, query, or fragment
    re.IGNORECASE,
)


def complete_and_verify_url(partial_url):
    """Complete a user-typed URL and check that it looks valid.

    Completion rules:
      - Surrounding whitespace is stripped.
      - URLs starting with ``http://`` or ``https://`` (any case) are kept as-is.
      - URLs starting with ``www.`` get ``https://`` prepended.
      - Anything else gets ``https://www.`` prepended.

    Returns:
        ``(is_valid, completed_url)``. The completed URL is returned even when invalid.
    """
    url = partial_url.strip()
    lowered = url.lower()
    if lowered.startswith(('http://', 'https://')):
        complete_url = url
    elif lowered.startswith('www.'):
        complete_url = 'https://' + url
    else:
        complete_url = 'https://www.' + url
    return (_URL_REGEX.match(complete_url) is not None, complete_url)
def show_conseil_ia():
    """Stream an AI suggestion of the key stakeholders for a CSR approach.

    Steps:
      1. Shows the (French) question in bold.
      2. Streams the answer of ``get_response`` (from ``chat_with_pps``), using the
         page text of the first document in ``st.session_state["latest_doc"]`` as
         context. The second argument is passed as an empty string; its meaning is
         defined by ``get_response``.
      3. Shows a warning inviting the user to enter another URL.
    """
    question = "Prenant compte les données de l'entreprise (activité, produits, services ...), quelles sont les principales parties prenantes à animer pour une démarche RSE réussie ?"
    st.markdown(f"**{question}**")
    page_text = st.session_state["latest_doc"][0].page_content
    st.write_stream(get_response(question, "", page_text))
    st.warning("Quittez et saisissez une autre URL")
def _init_session_state():
    """Create every session-state key this page uses, keeping existing values."""
    defaults = {
        "Nom de la marque": "",
        "urls": [],
        "parties_prenantes": [],
        # Used for the stakeholder plot/map: all stakeholders grouped without duplicates.
        "pp_grouped": [],
        "latest_doc": "",
        "not_pp": "",
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _collect_docs(url, scraping_option, fire_crawl_api_key):
    """Scrape ``url`` in the selected mode; return the documents or None on failure.

    "Analyse profonde" uses FireCrawl and falls back to the quick loader if it fails.
    """
    if scraping_option == "Analyse profonde":
        with st.spinner("Collecte des données..."):
            docs = get_docs_from_website_fc([url], fire_crawl_api_key)
        if docs is None:
            st.warning("Erreur lors de la collecte des données, 2eme essai avec collecte rapide...")
            with st.spinner("2eme essai, collecte rapide..."):
                docs = get_docs_from_website([url])
        return docs
    with st.spinner("Collecte des données..."):
        return get_docs_from_website([url])


def _handle_new_url(url, scraping_option, brand_name, fire_crawl_api_key):
    """Validate ``url``, scrape it and store the extracted stakeholders in session state."""
    st.session_state["not_pp"] = ""
    is_valid, url = complete_and_verify_url(url)
    if not is_valid:
        st.error("URL invalide")
        return
    if url in st.session_state["urls"]:
        st.error("URL déjà ajoutée")
        return
    docs = _collect_docs(url, scraping_option, fire_crawl_api_key)
    if docs is None:
        st.error("Erreur lors de la collecte des données, URL invalide")
        st.session_state["latest_doc"] = ""
        return
    st.session_state["latest_doc"] = docs
    with st.spinner("Processing..."):
        input_variables = {"BRAND_NAME": brand_name, "BRAND_DESCRIPTION": ""}
        partie_prenante = extract_pp(docs, input_variables)
    # extract_pp returns one of:
    #   - the code "444" when the model could not identify any stakeholder;
    #   - the code "445" when it was given no documents;
    #   - otherwise, a list of stakeholder names.
    if "444" in partie_prenante:
        st.session_state["not_pp"] = "444"
    elif "445" in partie_prenante:
        st.error("Aucun site web trouvé avec l'url donnée")
    else:
        st.session_state["urls"].append(url)
        add_pp(sorted(partie_prenante))


def _show_latest_doc():
    """Show the most recently scraped document in an editable (not persisted) text area."""
    docs = st.session_state["latest_doc"]
    # latest_doc is "" when nothing was scraped. An empty document list is skipped
    # too, since it would raise IndexError on docs[0].
    if not docs:
        return
    with st.expander("Cliquez ici pour éditer et voir le document"):
        # Collapse runs of blank lines into a single blank line for display.
        cleaned_text = re.sub(r'\n\n+', '\n\n', docs[0].page_content.strip())
        st.text_area("Modifier le texte ci-dessous:", value=cleaned_text, height=300)
        # NOTE: the edited text is not stored anywhere; this button only shows a confirmation.
        if st.button('Sauvegarder', key="save_doc_fake"):
            st.success("Texte sauvegardé avec succès!")


def display_pp():
    """Render the brand-stakeholder ("parties prenantes") page.

    In website mode, the page works as follows:
      - The user enters a brand name and a URL.
      - The page is scraped: quick mode uses WebBaseLoader; deep mode uses
        FireCrawl with a quick-mode fallback.
      - The brand's stakeholders are extracted with the RAG chain and stored in
        ``st.session_state``.

    The list of analysed URLs, the grouped stakeholder list and the chart are
    always rendered.
    """
    load_dotenv()
    fire_crawl_api_key = os.getenv("FIRECRAWL_API_KEY")
    _init_session_state()

    st.header("Parties prenantes de la marque")
    brand_name = st.text_input("Nom de la marque", st.session_state["Nom de la marque"])
    st.session_state["Nom de la marque"] = brand_name

    option = st.radio("Source", ("A partir de votre site web", "A partir de vos documents entreprise"))
    if option == "A partir de votre site web":
        url = st.text_input("Ajouter une URL")
        captions = ["L’IA prend en compte uniquement les textes contenus dans les pages web analysées", "L’IA prend en compte les textes, les images et les liens URL contenus dans les pages web analysées"]
        scraping_option = st.radio("Mode", ("Analyse rapide", "Analyse profonde"), horizontal=True, captions=captions)
        if st.button("ajouter", key="add_pp"):
            _handle_new_url(url, scraping_option, brand_name, fire_crawl_api_key)
        if st.session_state["not_pp"] == "444":
            st.warning("Aucune partie prenante n'est identifiable sur l'URL fournie. Fournissez une autre URL ou bien cliquez sur le bouton ci-dessous pour obtenir un conseil IA")
            if st.button("Conseil IA"):
                show_conseil_ia()
        _show_latest_doc()

    display_list_urls()
    with st.expander("Liste des parties prenantes"):
        add_pp_input_text()
        display_list_pps()
    test_chart()