# Source provenance: Hugging Face Space "app.py" (user monsimas), revision 994abde, ~22.1 kB.
import gradio as gr
import pandas as pd
import numpy as np
from datetime import datetime
import os
from mistralai import Mistral, UserMessage
from sklearn.cluster import KMeans
import plotly.express as px
import umap
from collections import defaultdict
import umap.umap_ as umap
import csv
from tempfile import NamedTemporaryFile
import io
# Read the Mistral API key from the environment and fail fast at import time
# if it is missing (the app cannot work without it).
MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY')
if not MISTRAL_API_KEY:
    raise EnvironmentError("La variable d'environnement MISTRAL_API_KEY n'est pas définie")
# Module-level client. NOTE(review): the Gradio handlers below obtain their
# client via MistralClient.get_instance() instead; this instance appears
# unused within this file — confirm before removing.
mistral_client = Mistral(api_key=MISTRAL_API_KEY)
class MistralClient:
    """Lazily-initialized singleton wrapper around the Mistral SDK client.

    Keeps one shared ``Mistral`` client for the whole process so the SDK is
    instantiated at most once.
    """

    _instance = None  # singleton marker
    _client = None    # the shared Mistral SDK client

    @classmethod
    def initialize(cls):
        """Create the shared client from MISTRAL_API_KEY and return it.

        Raises:
            EnvironmentError: if MISTRAL_API_KEY is not set.
        """
        if cls._client is None or cls._instance is None:
            key = os.getenv('MISTRAL_API_KEY')
            if not key:
                raise EnvironmentError("La variable d'environnement MISTRAL_API_KEY n'est pas définie")
            cls._instance = cls()
            cls._client = Mistral(api_key=key)
        return cls._client

    @classmethod
    def get_instance(cls):
        """Return the shared client, creating it on first use."""
        return cls.initialize() if cls._client is None else cls._client
def process_uploaded_file(file_path):
    """Load an uploaded CSV/XLSX file and normalize it for analysis.

    Accepts either a filesystem path (str / os.PathLike) — which is what
    ``gr.File(type="filepath")`` passes — or a file-like object exposing a
    ``.name`` attribute. (The previous version called ``file_path.name``
    unconditionally and crashed on plain path strings.)

    Returns:
        DataFrame guaranteed to contain the columns 'combined_text',
        'title', 'body', 'reference', 'endorsements' and 'category'.

    Raises:
        ValueError: for unsupported formats, files with no columns, or any
            underlying read error (wrapped with context).
    """
    try:
        # Support both plain paths and file-like objects with a .name attribute.
        path_str = getattr(file_path, 'name', file_path)
        file_ext = os.path.splitext(str(path_str))[1].lower()

        # Read file based on extension.
        if file_ext == '.csv':
            df = pd.read_csv(file_path)
        elif file_ext in ['.xlsx', '.xls']:
            df = pd.read_excel(file_path)
        else:
            raise ValueError("Unsupported file format. Please upload a CSV or Excel file.")

        if df.shape[1] == 0:
            raise ValueError("The uploaded file has no columns.")

        # The first column is treated as the free text to analyze.
        first_col = df.columns[0]
        df['combined_text'] = df[first_col].fillna("").astype(str)

        # Fill in any analysis columns the upload did not provide.
        if 'title' not in df.columns:
            df['title'] = df['combined_text'].str.slice(0, 100)  # first 100 chars as title
        if 'body' not in df.columns:
            df['body'] = df['combined_text']
        if 'reference' not in df.columns:
            df['reference'] = df.index.astype(str)
        if 'endorsements' not in df.columns:
            df['endorsements'] = 0
        if 'category' not in df.columns:
            df['category'] = ""
        return df
    except Exception as e:
        # Wrap every failure so callers only have to handle ValueError.
        raise ValueError(f"Error processing file: {str(e)}")
def get_embeddings_batch(texts: list[str], client: Mistral, batch_size: int = 32) -> list[list[float]]:
    """Get embeddings for a list of texts, batching and retrying API calls.

    Args:
        texts: Texts to embed.
        client: Mistral client instance (must expose ``embeddings.create``).
        batch_size: Number of texts sent per API call.

    Returns:
        One embedding (list of floats) per input text, in order. Batches
        that still fail after all retries are zero-filled so the output
        length always matches ``len(texts)``.
    """
    import time  # kept function-scoped; only needed for retry back-off

    EMBED_DIM = 1024  # dimension of mistral-embed vectors, used for zero-fill on failure
    all_embeddings = []
    for start in range(0, len(texts), batch_size):
        # Normalize whitespace — stray newlines can degrade embedding quality.
        batch = [text.replace("\n", " ").strip() for text in texts[start:start + batch_size]]
        max_retries = 3
        for attempt in range(max_retries):
            try:
                response = client.embeddings.create(
                    model="mistral-embed",
                    inputs=batch
                )
                all_embeddings.extend(data.embedding for data in response.data)
                break
            except Exception as e:
                if attempt == max_retries - 1:
                    # Give up on this batch but keep output aligned with input.
                    print(f"Failed to get embeddings after {max_retries} attempts: {str(e)}")
                    all_embeddings.extend([[0] * EMBED_DIM] * len(batch))
                else:
                    time.sleep(1)  # brief pause before retrying
    return all_embeddings
def generate_embeddings(df: pd.DataFrame, client: Mistral) -> pd.DataFrame:
    """Attach a Mistral embedding to every proposal in *df*.

    Embeds the 'combined_text' column in batches via
    ``get_embeddings_batch`` and stores the vectors in a new 'embedding'
    column.

    Args:
        df: DataFrame containing proposals (must have 'combined_text').
        client: Mistral client instance.

    Returns:
        The same DataFrame with an added 'embedding' column.
    """
    df['embedding'] = get_embeddings_batch(df['combined_text'].tolist(), client)
    return df
def perform_clustering(df, n_clusters):
    """Run k-means over the stacked embedding vectors.

    Adds a 'Cluster' label column to *df* in place.

    Returns:
        (df, matrix) where matrix is the (n_samples, dim) float array the
        model was fitted on — reused later for the UMAP projection.
    """
    feature_matrix = np.vstack(df.embedding.values)
    model = KMeans(n_clusters=n_clusters, init="k-means++", random_state=42)
    # fit_predict == fit(...).labels_ for KMeans.
    df["Cluster"] = model.fit_predict(feature_matrix)
    return df, feature_matrix
def analyze_clusters(df, client):
    """Name each cluster via the LLM and compute per-cluster statistics.

    Args:
        df: DataFrame with 'Cluster', 'title', 'body', 'endorsements',
            'category' columns.
        client: Mistral client instance.

    Returns:
        (cluster_names, cluster_analyses):
          cluster_names: {cluster_id: short theme title}
          cluster_analyses: {cluster_id: {'theme', 'size',
              'avg_endorsements', 'categories'}}

    On an API failure the cluster falls back to the generic name
    "Cluster <id>"; unlike the previous version, its statistics entry is
    still populated, so both dicts always cover every cluster id.
    """
    cluster_names = {}
    cluster_analyses = {}
    for cluster_id in sorted(df.Cluster.unique()):
        cluster_df = df[df.Cluster == cluster_id]
        # Show the model at most 5 representative proposals.
        sample_proposals = cluster_df.sample(min(5, len(cluster_df)))
        proposal_texts = "\n".join([
            f"Proposal {i+1}: {row['title']} - {row['body'][:200]}..."
            for i, row in enumerate(sample_proposals.to_dict('records'))
        ])
        messages = [
            UserMessage(content=f"Quel est le thème principal de ces propositions ? Donne un titre court (3-4 mots max) pour cet ensemble de propositions. Ne retourne rien d'autre - juste le titre.\n\n{proposal_texts}\n\nTitre:")
        ]
        try:
            response = client.chat.complete(
                model="mistral-large-latest",
                messages=messages,
                temperature=0,  # deterministic naming
            )
            theme = response.choices[0].message.content.strip()
        except Exception as e:
            # Log instead of silently swallowing, then fall back to a generic name.
            print(f"Cluster naming failed for cluster {cluster_id}: {str(e)}")
            theme = f"Cluster {cluster_id}"
        cluster_names[cluster_id] = theme
        # Statistics are recorded even when naming fails (bug fix: the old
        # code skipped this entry on exception).
        cluster_analyses[cluster_id] = {
            'theme': theme,
            'size': len(cluster_df),
            'avg_endorsements': cluster_df.endorsements.mean(),
            'categories': cluster_df.category.value_counts().to_dict()
        }
    return cluster_names, cluster_analyses
def create_interactive_cluster_viz(df, matrix, cluster_names):
    """Create an interactive 2-D cluster visualization (UMAP + Plotly).

    Args:
        df: DataFrame with 'Cluster', 'title' and 'reference' columns.
        matrix: (n_samples, dim) embedding array, row-aligned with *df*.
        cluster_names: {cluster_id: display name} mapping.

    Returns:
        A plotly Figure with one point per proposal, colored by cluster.
    """
    # Project embeddings to 2-D; cosine metric suits high-dimensional
    # text embeddings better than euclidean.
    reducer = umap.UMAP(
        random_state=42,
        n_neighbors=15,
        min_dist=0.1,
        metric='cosine'
    )
    embedding = reducer.fit_transform(matrix)

    plot_df = pd.DataFrame({
        'UMAP1': embedding[:, 0],
        'UMAP2': embedding[:, 1],
        'Cluster': [cluster_names[c] for c in df['Cluster']],
        'Title': df['title'],
        'Reference': df['reference']
    })

    # Bug fix: custom_data lists the columns in the exact order the
    # hovertemplate below indexes them. The previous hover_data-based
    # version put only Title/Reference in customdata, so the hover showed
    # Title under the "Cluster" label and left %{customdata[2]} unresolved.
    fig = px.scatter(
        plot_df,
        x='UMAP1',
        y='UMAP2',
        color='Cluster',
        custom_data=['Cluster', 'Title', 'Reference'],
        title='Visualisation des clusters',
        template='plotly_white'
    )
    fig.update_traces(
        marker=dict(size=8),
        hovertemplate="<br>".join([
            "Cluster: %{customdata[0]}",
            "Title: %{customdata[1]}",
            "Reference: %{customdata[2]}"
        ])
    )
    fig.update_layout(
        width=1200,
        height=800,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99
        )
    )
    return fig
def ensure_unique_cluster_names(cluster_names, client):
    """Return a copy of *cluster_names* where every display name is unique.

    The first occurrence of each name is kept as-is. For later duplicates
    the LLM is asked for a more specific alternative; if that alternative
    also collides — or the API call fails — a "(Group N)" suffix is
    appended instead.
    """
    used = set()
    unique_names = {}
    for cid, original_name in cluster_names.items():
        if original_name not in used:
            unique_names[cid] = original_name
            used.add(original_name)
            continue
        # Duplicate name: ask the model for a more specific replacement.
        prompt = f"""Le nom '{original_name}' est déjà utilisé pour un autre cluster.
Donne un nom plus spécifique à ce cluster. Retourne seulement le nouveau nom. Quelques mots maximum."""
        messages = [UserMessage(content=prompt)]
        fallback = f"{original_name} (Group {cid + 1})"
        try:
            response = client.chat.complete(
                model="mistral-large-latest",
                messages=messages,
                temperature=0.7
            )
            candidate = response.choices[0].message.content.strip()
            # Still colliding? fall back to a numbered suffix.
            unique_names[cid] = fallback if candidate in used else candidate
        except Exception:
            unique_names[cid] = fallback
        used.add(unique_names[cid])
    return unique_names
def generate_trend_analysis(df, cluster_names, client):
    """Produce a citizen-readable narrative summary of the clusters via the LLM.

    Samples up to 5 proposals per cluster for context, builds a structured
    French prompt, and returns the model's analysis text (or an error
    string on failure).
    """
    # Collect per-cluster context: display name, size, and a few samples.
    cluster_info = {}
    for cluster_id in df['Cluster'].unique():
        cluster_df = df[df['Cluster'] == cluster_id]
        samples = cluster_df.sample(min(5, len(cluster_df)))
        cluster_info[cluster_id] = {
            'name': cluster_names[cluster_id],
            'size': len(cluster_df),
            'samples': [
                {'title': row['title'], 'body': row['body'][:200]}
                for _, row in samples.iterrows()
            ]
        }
    # Assemble the prompt from parts instead of repeated concatenation.
    parts = ["""Analysez ces groupes de propositions et rédigez un résumé éclairant pour les citoyens et les décideurs.
À partir de ces données, concentrez-vous sur les thèmes principaux, les tendances et les implications potentielles pour les politiques publiques.
Utilisez un langage clair et accessible en évitant les termes techniques.
Structurez l'analyse comme suit :
1. Tendances générales et thèmes principaux
2. Observations notables pour chaque groupe
3. Points clés à retenir pour les décideurs
Groupes :
"""]
    for info in cluster_info.values():
        parts.append(f"\n{info['name']} ({info['size']} proposals)\n")
        parts.append("Sample proposals:\n")
        parts.extend(f"- {sample['title']}\n" for sample in info['samples'])
    messages = [UserMessage(content="".join(parts))]
    try:
        response = client.chat.complete(
            model="mistral-large-latest",
            messages=messages,
            temperature=0.7,
            max_tokens=10000
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error generating analysis: {str(e)}"
def generate_categorization_analysis(results_df, client):
    """Ask the LLM for a short commentary on the category distribution.

    Args:
        results_df: DataFrame with a 'categories' column (comma-joined names).
        client: Mistral client instance.

    Returns:
        The model's analysis text, or an error string on failure.
    """
    distribution = results_df['categories'].value_counts()
    analysis_prompt = f"""Analyze these categorization results and provide insights:
Category Distribution:
{distribution.to_string()}
Provide a brief analysis of the distribution and any notable patterns."""
    messages = [UserMessage(content=analysis_prompt)]
    try:
        response = client.chat.complete(
            model="mistral-large-latest",
            messages=messages,
            temperature=0.7,
            max_tokens=10000
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error generating analysis: {str(e)}"
def suggest_new_categories(uncategorized_proposals, client):
    """Ask the LLM to propose 3-5 fresh categories for uncategorized rows.

    Samples up to 10 proposals for context, keeps only suggestions that
    are 2-4 words long, and returns at most five of them. Returns [] for
    empty input or on API failure.
    """
    if len(uncategorized_proposals) == 0:
        return []
    sample_count = min(10, len(uncategorized_proposals))
    sampled = uncategorized_proposals.sample(n=sample_count)
    prompt_parts = ["""Analysez ces propositions non catégorisées et suggérez 3 à 5 nouvelles catégories distinctes.
Les catégories doivent être :
- Suffisamment spécifiques pour être pertinentes
- Assez larges pour s'appliquer à plusieurs propositions
- Clairement distinctes les unes des autres
- Décrites en 2 à 4 mots chacune
Ne retournez que les noms des catégories, séparés par des virgules.
Propositions :
"""]
    for _, row in sampled.iterrows():
        prompt_parts.append(f"Title: {row['title']}\n")
        prompt_parts.append(f"Description: {row['body'][:200]}...\n\n")
    messages = [UserMessage(content="".join(prompt_parts))]
    try:
        response = client.chat.complete(
            model="mistral-large-latest",
            messages=messages,
            temperature=0.3
        )
        cleaned = []
        for raw in response.choices[0].message.content.split(','):
            name = raw.strip()
            if 2 <= len(name.split()) <= 4:  # keep only well-formed 2-4 word names
                cleaned.append(name)
        return cleaned[:5]  # cap at five suggestions
    except Exception as e:
        print(f"Error suggesting categories: {str(e)}")
        return []
def categorize_proposal_multi(proposal, categories, client):
    """Assign one or more of *categories* to a single proposal via the LLM.

    Only category names echoed back verbatim by the model are kept; when
    none match — or the API call fails — ["Uncategorized"] is returned so
    callers always receive a non-empty list.
    """
    prompt = f"""Voici les catégories disponibles :
{', '.join(categories)}
Pour cette proposition :
Titre : {proposal['title']}
Description : {proposal['body'][:500]}
Attribuez TOUTES les catégories pertinentes. Gardez à l'esprit que :
- Une proposition peut appartenir à plusieurs catégories si elles sont vraiment pertinentes
- N'attribuez que les catégories qui correspondent fortement au contenu de la proposition
- Si aucune catégorie ne convient bien, retournez "Non catégorisé"
Ne retournez que les noms des catégories séparés par des virgules, rien d'autre."""
    messages = [UserMessage(content=prompt)]
    try:
        response = client.chat.complete(
            model="mistral-large-latest",
            messages=messages,
            temperature=0.3
        )
        matched = []
        for part in response.choices[0].message.content.split(','):
            candidate = part.strip()
            if candidate in categories:  # keep only exact known category names
                matched.append(candidate)
        return matched if matched else ["Uncategorized"]
    except Exception as e:
        print(f"Error categorizing proposal: {str(e)}")
        return ["Uncategorized"]
def create_interface():
    """Assemble the Gradio UI: an Exploration tab (clustering + trend
    analysis) and a Categorization tab (LLM-based labeling).

    Returns:
        gr.Blocks: the configured application, ready for .launch().
    """
    with gr.Blocks() as app:
        gr.Markdown("# IA data scientist pour analyse de contenu")
        with gr.Tabs():
            # Exploration Tab
            with gr.Tab("Exploration"):
                with gr.Row():
                    with gr.Column():
                        # type="filepath": the handler receives a path string.
                        file_upload_explore = gr.File(
                            label="Téléverser un fichier CSV ou Excel",
                            file_types=[".csv", ".xlsx", ".xls"],
                            type="filepath"
                        )
                        n_clusters = gr.Slider(
                            minimum=2,
                            maximum=20,
                            value=4,
                            step=1,
                            label="Nombre de clusters"
                        )
                        analyze_btn = gr.Button("Analyser les données")
                with gr.Row():
                    cluster_plot = gr.Plot(label="Visualisation interactive")
                with gr.Row():
                    analysis_text = gr.Markdown(label="Analyse des tendances")
            # Categorization Tab
            with gr.Tab("Categorization"):
                with gr.Row():
                    with gr.Column():
                        file_upload_cat = gr.File(
                            label="Téléverser un fichier CSV ou Excel",
                            file_types=[".csv", ".xlsx", ".xls"],
                            type="filepath"
                        )
                        allow_multiple = gr.Checkbox(
                            label="Permettre plusieurs catégories par entrée",
                            value=False
                        )
                        suggest_categories = gr.Checkbox(
                            label="Proposer de nouvelles catégories pour les entrées non catégorisées",
                            value=False
                        )
                        categories_input = gr.Textbox(
                            label="Les catégories (séparées par des virgules)",
                            placeholder="Par exemple: Transports, Parcs, Education, Logement",
                            lines=3
                        )
                        categorize_btn = gr.Button("Catégoriser")
                with gr.Row():
                    # Hidden until the handler actually has suggestions to show.
                    suggested_categories = gr.CheckboxGroup(
                        label="Catégories suggérées",
                        choices=[],
                        visible=False
                    )
                with gr.Row():
                    results_table = gr.DataFrame(label="Résultats de la catégorisation")
                with gr.Row():
                    category_analysis = gr.Markdown(label="Analyse de la catégorisation")

        def exploration_handler(file_path, n_clusters):
            """Cluster the uploaded file; returns (plot, analysis markdown).

            Errors are returned as (None, message) so the UI degrades
            gracefully instead of raising into Gradio.
            """
            try:
                client = MistralClient.get_instance()
                # Process uploaded file
                df = process_uploaded_file(file_path)
                # Generate embeddings and clustering
                df = generate_embeddings(df, client)
                df, matrix = perform_clustering(df, n_clusters)
                # Get cluster names and ensure uniqueness
                cluster_names, _ = analyze_clusters(df, client)
                cluster_names = ensure_unique_cluster_names(cluster_names, client)
                # Create visualization
                plot = create_interactive_cluster_viz(df, matrix, cluster_names)
                # Generate analysis
                analysis = generate_trend_analysis(df, cluster_names, client)
                return plot, analysis
            except EnvironmentError as e:
                return None, f"Erreur : {str(e)}"
            except Exception as e:
                return None, f"Une erreur est survenue : {str(e)}"

        def categorization_handler(file_path, categories, allow_multiple, suggest_categories):
            """Label every row of the uploaded file with the given categories.

            Returns (results DataFrame, CheckboxGroup update carrying any
            newly suggested categories, analysis markdown). On error,
            returns an empty table and the error message instead of raising.
            """
            try:
                client = MistralClient.get_instance()
                # Process uploaded file
                df = process_uploaded_file(file_path)
                # Convert categories string to list
                category_list = [cat.strip() for cat in categories.split(',')]
                # Initialize results
                results = []
                uncategorized = []
                # Process each entry
                for _, row in df.iterrows():
                    if allow_multiple:
                        assigned_cats = categorize_proposal_multi(row, category_list, client)
                    else:
                        # Single-category mode: keep only the first match.
                        assigned_cats = categorize_proposal_multi(row, category_list, client)[:1]
                    if (not assigned_cats or assigned_cats == ["Uncategorized"]) and suggest_categories:
                        uncategorized.append(row)
                    results.append({
                        'id': row.get('id', str(row.name)),  # Use index if no id
                        'title': row['title'],
                        'categories': ', '.join(assigned_cats) if assigned_cats else 'Non catégorisé'
                    })
                results_df = pd.DataFrame(results)
                # Generate suggested categories if needed
                new_categories = []
                if suggest_categories and uncategorized:
                    uncategorized_df = pd.DataFrame(uncategorized)
                    if not uncategorized_df.empty:
                        new_categories = suggest_new_categories(uncategorized_df, client)
                # Generate analysis
                analysis = generate_categorization_analysis(results_df, client)
                return (
                    results_df,
                    gr.CheckboxGroup(choices=new_categories, visible=bool(new_categories)),
                    analysis
                )
            except EnvironmentError as e:
                return pd.DataFrame(), gr.CheckboxGroup(visible=False), f"Erreur : {str(e)}"
            except Exception as e:
                return pd.DataFrame(), gr.CheckboxGroup(visible=False), f"Une erreur est survenue : {str(e)}"

        # Wire buttons to handlers (must be registered inside the Blocks context).
        analyze_btn.click(
            exploration_handler,
            inputs=[file_upload_explore, n_clusters],
            outputs=[cluster_plot, analysis_text]
        )
        categorize_btn.click(
            categorization_handler,
            inputs=[
                file_upload_cat,
                categories_input,
                allow_multiple,
                suggest_categories
            ],
            outputs=[results_table, suggested_categories, category_analysis]
        )
    return app
if __name__ == "__main__":
    # Build the UI and start the Gradio server.
    demo_app = create_interface()
    demo_app.launch()