# CM2D_poc / app.py
# Antoine Lelong
# fix: revert back float code
# 42dee0e unverified
import gradio as gr
import fasttext
import re
from huggingface_hub import hf_hub_download
import os
# Hugging Face Hub repository and file name of the fastText model to load.
REPO_ID = "numericite/fasttext_cm2d_classificator"
FILENAME = "fasttext_model_v2.bin"

# Sample French input phrases shown as clickable examples in the interface.
case_examples = [
    "avc",
    "suicide",
    "suicide par arme à feu",
    "suicide par médicaments",
    "défenestration",
    "arrêt cardio respiratoire",
    "hta",
    "démence avancée type alzheimer",
    "oedèmes majeurs",
    "plaie faciale",
    "traumatisme crânien frontal droit",
    "sarm",
]
def preprocess_text(text):
    """Normalise free text before it is handed to the classifier.

    Every run of characters that is neither a word character nor an
    apostrophe (punctuation and any whitespace, including tabs and
    newlines) is replaced by a single space. The result is then stripped
    of surrounding whitespace and lower-cased.

    Args:
        text: Raw input string.

    Returns:
        The cleaned, lower-cased string, guaranteed to be a single line.
    """
    # Collapsing tabs and newlines as well as spaces keeps the output on one
    # line; fastText's predict() refuses text that contains a newline.
    text = re.sub(r"[^\w']+", ' ', text)
    return text.strip().lower()
class FastTextModelTester:
    """Load a fastText model from the Hugging Face Hub and query it.

    The model file identified by ``REPO_ID`` / ``FILENAME`` is resolved with
    ``hf_hub_download`` when the instance is created; the value of the
    ``token`` environment variable is passed as the access token.
    """

    def __init__(self):
        model_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=FILENAME,
            token=os.getenv("token"),
        )
        self.model = fasttext.load_model(model_path)

    def predict(self, text, k=1):
        """Return the model's top ``k`` predictions for ``text`` as display text.

        Args:
            text: Raw input string; it is normalised with ``preprocess_text``
                before classification.
            k: Number of labels to return. Coerced to ``int`` so that a
                whole-number float (e.g. ``3.0`` coming from a numeric UI
                control) is accepted.

        Returns:
            One ``label: probability`` line per prediction, joined with
            newlines; a word-vector summary when the model has no
            ``predict`` method; or an ``Error during prediction: ...``
            message if anything raises.
        """
        try:
            # Models without a ``predict`` method are treated as word
            # embedding models and queried with the raw text.
            if not hasattr(self.model, 'predict'):
                vector = self.model.get_word_vector(text)
                return f"Word vector (first 5 dimensions): {vector[:5]}"

            labels, probabilities = self.model.predict(
                preprocess_text(text), k=int(k)
            )
            return "\n".join(
                f"{label.replace('__label__', '')}: {prob:.4f}"
                for label, prob in zip(labels, probabilities)
            )
        except Exception as e:
            # UI boundary: report the failure as text instead of letting the
            # event handler crash.
            return f"Error during prediction: {str(e)}"
# Create the Gradio interface
def create_interface():
    """Build the Gradio Blocks interface around a freshly loaded model.

    Returns:
        The ``gr.Blocks`` application; the caller decides when to launch it.
    """
    # NOTE(review): the indentation of this function was reconstructed; the
    # button and result box are assumed to sit directly under the Blocks
    # container rather than inside the examples row -- confirm the layout.
    tester = FastTextModelTester()

    with gr.Blocks(title="CM2D POC Classificateur") as demo:
        with gr.Row():
            text_box = gr.Textbox(
                label="Input Text",
                placeholder="Enter text to classify",
            )
            top_k_slider = gr.Slider(
                minimum=1,
                maximum=10,
                value=1,
                step=1,
                label="Number of Labels (k)",
            )

        with gr.Row():
            gr.Examples(examples=case_examples, inputs=[text_box])

        run_button = gr.Button("Predict")
        result_box = gr.Textbox(label="Prediction Result", interactive=False)

        # Clicking the button sends the text and k to the model and shows the
        # formatted predictions in the result box.
        run_button.click(
            fn=tester.predict,
            inputs=[text_box, top_k_slider],
            outputs=result_box,
        )

    return demo
if __name__ == "__main__":
    # Build the interface and start the Gradio server only when this file is
    # executed as a script, not when it is imported.
    app = create_interface()
    app.launch()