| | import requests |
| | import os |
| | import ast |
| | import streamlit as st |
| | import pandas as pd |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import plotly.express as px |
| | from wordcloud import WordCloud |
| | from collections import Counter |
| | import torch |
| | from transformers import AutoTokenizer |
| | import joblib |
| | from model import MultiLabelDeberta |
| | from huggingface_hub import hf_hub_download |
| | from datasets import load_dataset |
| |
|
| |
|
# Page-level configuration; issued before any other Streamlit call in this script.
st.set_page_config(page_title="Tag Predictor", layout="wide")

# Style overrides that enlarge the default font sizes of text areas, headings,
# slider labels, buttons, alerts and markdown paragraphs.
_CUSTOM_CSS = """
<style>
textarea {
font-size: 18px !important;
}

.markdown-text-container h1 {
font-size: 34px !important;
}

.markdown-text-container h2 {
font-size: 28px !important;
}

.markdown-text-container h3 {
font-size: 24px !important;
}

.stSlider .css-1y4p8pa, .stSlider .css-1cpxqw2 {
font-size: 18px !important;
}

.stButton > button {
font-size: 18px !important;
}

.stAlert {
font-size: 18px !important;
}

.stMarkdown p {
font-size: 18px !important;
}
</style>
"""

# unsafe_allow_html is required so the <style> element is emitted as HTML.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
| |
|
| | |
| |
|
# Hugging Face Hub repository that hosts the trained artefacts.
REPO_ID = "Framby/deberta_multilabel"

# Resolve the artefacts to paths usable by joblib.load / torch.load.
# The label binarizer is deserialized inside load_model_and_tokenizer(), which
# is cached, so it is deliberately not loaded a second time here.
mlb_path = hf_hub_download(repo_id=REPO_ID, filename="mlb.joblib")
deberta_path = hf_hub_download(
    repo_id=REPO_ID, filename="deberta_multilabel.pt")
| |
|
| |
|
@st.cache_resource
def load_model_and_tokenizer():
    """Build the inference objects used by the app.

    Reads the label binarizer from ``mlb_path``, instantiates
    ``MultiLabelDeberta`` with one output per binarizer class, restores its
    weights from ``deberta_path`` on CPU and switches it to eval mode.

    Returns:
        A ``(model, tokenizer, mlb)`` tuple.
    """
    binarizer = joblib.load(mlb_path)
    label_count = len(binarizer.classes_)

    net = MultiLabelDeberta(num_labels=label_count)
    # NOTE(review): weights_only=False allows full unpickling of the
    # checkpoint; confirm whether the file loads with weights_only=True.
    state_dict = torch.load(
        deberta_path,
        map_location="cpu",
        weights_only=False,
    )
    net.load_state_dict(state_dict)
    net.eval()

    tok = AutoTokenizer.from_pretrained(
        "microsoft/deberta-v3-base",
        use_fast=False,
    )
    return net, tok, binarizer
| |
|
| |
|
# Module-level handles to the cached model, tokenizer and label binarizer.
(model, tokenizer, mlb) = load_model_and_tokenizer()
| |
|
| | |
| |
|
@st.cache_data
def load_data():
    """Load the ``train`` split of the ``Framby/SOF_full`` dataset.

    Returns:
        A ``(X, Y)`` pair of pandas Series built from the ``text_clean`` and
        ``Tags`` columns respectively.
    """
    train_split = load_dataset("Framby/SOF_full")['train']
    texts, tags = (
        pd.Series(train_split[column]) for column in ('text_clean', 'Tags')
    )
    return texts, tags
| |
|
| |
|
# X holds the cleaned question texts, Y the matching tag entries.
(X, Y) = load_data()
| |
|
| | |
| |
|
def predict_tags(text: str, threshold: float = 0.5):
    """Predict the tags whose sigmoid probability reaches ``threshold``.

    Args:
        text: Question text to classify; it is truncated/padded to 512 tokens.
        threshold: Minimum probability for a label to be kept.

    Returns:
        The first (and only) row returned by ``mlb.inverse_transform`` —
        presumably a tuple of tag names if ``mlb`` is a scikit-learn
        ``MultiLabelBinarizer``; verify against the stored artefact.
    """
    inputs = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=512,
        padding='max_length',
    )
    # The model is called without segment ids, so drop them if present.
    inputs.pop('token_type_ids', None)
    with torch.no_grad():
        outputs = model(**inputs)
        # Reshape to an explicit (1, n_labels) row. A bare squeeze() would
        # also remove the label axis when there is a single class, leaving a
        # 0-d array that inverse_transform cannot consume.
        probs = torch.sigmoid(outputs).cpu().numpy().reshape(1, -1)
    binary_preds = (probs >= threshold).astype(int)
    return mlb.inverse_transform(binary_preds)[0]
| |
|
| |
|
| | |
| |
|
st.title("Prédicteur de Tags StackOverflow")

st.markdown("## 1. Analyse des données textuelles")


# Streamlit re-executes this script on every widget interaction, so the
# corpus-wide statistics below are cached instead of being recomputed each
# time. Their data arguments start with an underscore so Streamlit does not
# hash the whole Series to build the cache key.
@st.cache_data
def _question_lengths(_texts):
    """Return a DataFrame with one 'length' column: word count per question."""
    return pd.DataFrame({'length': _texts.apply(lambda t: len(t.split()))})


@st.cache_data
def _most_common_tags(_raw_tags, limit=20):
    """Parse each tag-list literal and return the `limit` most frequent tags."""
    counts = Counter(
        tag
        for tag_list in _raw_tags.apply(ast.literal_eval)
        for tag in tag_list
    )
    return pd.DataFrame(counts.most_common(limit), columns=['Tag', 'Nombre'])


@st.cache_data
def _wordcloud_image(_texts):
    """Render the word cloud of all questions and return it as an image array."""
    cloud = WordCloud(width=800, height=300, background_color='white')
    return cloud.generate(" ".join(_texts)).to_array()


col1, col2 = st.columns(2)

with col1:
    st.markdown("### Questions")
    df_lengths = _question_lengths(X)
    fig = px.histogram(
        df_lengths,
        x='length',
        nbins=30,
        title="Distribution de la longueur des questions",
    )
    st.plotly_chart(fig, use_container_width=True)

with col2:
    st.markdown("### Tags")
    most_common_tags = _most_common_tags(Y)
    fig2 = px.bar(
        most_common_tags,
        x='Tag',
        y='Nombre',
        title="20 tags les plus fréquents",
    )
    st.plotly_chart(fig2, use_container_width=True)

st.markdown("### Nuage de mots")
fig_wc, ax = plt.subplots(figsize=(10, 4))
ax.imshow(_wordcloud_image(X), interpolation='bilinear')
ax.axis("off")
st.pyplot(fig_wc)
# pyplot keeps a reference to every figure it creates; close this one so
# figures do not pile up in memory across reruns.
plt.close(fig_wc)
| |
|
st.markdown("---")
st.markdown("## 2. Prédiction des tags")

# User inputs: the question text and the per-label probability cut-off.
input_text = st.text_area("Entrez une question StackOverflow", height=150)
threshold = st.slider(
    "Seuil de probabilité",
    min_value=0.1,
    max_value=0.9,
    value=0.5,
    step=0.05,
)

if st.button("Prédire les tags"):
    if not input_text.strip():
        # Nothing but whitespace was entered.
        st.warning("Veuillez entrer une question.")
    else:
        tags = predict_tags(input_text, threshold)
        if not tags:
            st.warning("Aucun tag trouvé pour le seuil sélectionné.")
        else:
            st.success("Tags prédits :")
            st.write(", ".join(tags))
| |
|