Spaces:
Running
Running
| import os | |
| import json | |
| from keras.models import load_model | |
| from keras_preprocessing.sequence import pad_sequences | |
| from keras.preprocessing.text import tokenizer_from_json | |
# BASE_DIR is the parent of the directory that contains this file
# (described as the project root by the original author).
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.dirname(_MODULE_DIR)

# On-disk locations of the trained classifier and of its serialized tokenizer.
_DATA_DIR = os.path.join(BASE_DIR, 'data')
MODEL_PATH = os.path.join(_DATA_DIR, 'model', 'multi-classification.h5')
TOKENIZER_PATH = os.path.join(
    _DATA_DIR, 'tokenizer', 'multi-classification-tokenizer.json'
)

# Class labels, looked up by position in the model's prediction vector.
# NOTE(review): the order presumably matches the label order used at
# training time -- confirm against the training code.
CLASS_NAMES = [
    'Economia',
    'Politica',
    'Scienza_e_tecnica',
    'Sport',
    'Storia',
]
# Singleton state: the model and its tokenizer are loaded once and reused.
model = None
tokenizer = None


def load_resources():
    """Load the Keras model and its tokenizer, caching both at module level.

    The first successful call reads the tokenizer JSON from TOKENIZER_PATH,
    loads the model from MODEL_PATH with ``compile=False``, and stores both
    in the module globals ``model`` and ``tokenizer``. Later calls return
    the cached pair instead of reloading (previously a second call returned
    ``(None, None)`` because only ``model`` was remembered).

    Returns:
        A ``(model, tokenizer)`` tuple, or ``(None, None)`` when the model
        file does not exist or anything fails while loading. Failures are
        reported with ``print`` and never raised, so importing this module
        cannot crash the application.
    """
    global model, tokenizer

    # Already loaded: hand back the cached pair.
    if model is not None and tokenizer is not None:
        return model, tokenizer

    if not os.path.exists(MODEL_PATH):
        # Report the missing file instead of failing silently.
        print(
            "Errore caricamento risorse MultiLabel: "
            f"modello non trovato in {MODEL_PATH}"
        )
        return None, None

    try:
        # Explicit UTF-8 so decoding does not depend on the platform default.
        with open(TOKENIZER_PATH, 'r', encoding='utf-8') as f:
            tokenizer_data = json.load(f)
        # NOTE(review): tokenizer_from_json is documented to take a JSON
        # string, so this assumes the file holds the tokenizer's to_json()
        # output stored as a JSON-encoded string -- confirm.
        loaded_tokenizer = tokenizer_from_json(tokenizer_data)
        loaded_model = load_model(MODEL_PATH, compile=False)
    except Exception as e:
        # Best-effort load: report and let callers see (None, None).
        print(f"Errore caricamento risorse MultiLabel: {e}")
        return None, None

    # Publish to the globals only after both loads succeeded, so the cache
    # never holds a model without its tokenizer (or vice versa).
    model, tokenizer = loaded_model, loaded_tokenizer
    return model, tokenizer


# Load once at import time so individual predictions do not pay the cost.
model, tokenizer = load_resources()
def multi_classification(text):
    """Classify *text* and return one score per entry of CLASS_NAMES.

    Args:
        text: Raw text to classify.

    Returns:
        A dict mapping each class label (underscores replaced by spaces) to
        the model's score for it as a ``float``, i.e. the ``{label: score}``
        shape the original author targets for Gradio. On any failure the
        dict instead has a single entry whose key is an error message
        starting with "Errore" and whose value is ``0.0``; no exception
        escapes this function.
    """
    if model is None or tokenizer is None:
        # Use the same {message: 0.0} shape as the failure branch below.
        # The previous {"Errore": "<text>"} payload put a string where a
        # numeric score is expected.
        # NOTE(review): callers that tested for the exact key "Errore"
        # should match on the "Errore" prefix instead.
        return {"Errore: Modello non caricato": 0.0}

    try:
        # Preprocessing is meant to mirror what was done at training time.
        # 1. Tokenization: a single-text batch, so take its only sequence.
        token_ids = tokenizer.texts_to_sequences([text])[0]
        # Defensive: replace any None entry with index 1.
        # NOTE(review): 1 is presumably the out-of-vocabulary index -- confirm.
        cleaned_ids = [
            1 if token_id is None else token_id for token_id in token_ids
        ]

        # 2. Padding (maxlen=200 is stated to match the original training code).
        data_padded = pad_sequences([cleaned_ids], maxlen=200)

        # 3. Prediction: keep the score vector of the single input.
        scores = model.predict(data_padded, verbose=0)[0]

        # 4. Format as {label: score}; float() yields plain Python numbers.
        return {
            CLASS_NAMES[i].replace('_', ' '): float(score)
            for i, score in enumerate(scores)
        }
    except Exception as e:
        # Best-effort for the UI: surface the error as a zero-score label.
        return {f"Errore durante l'analisi: {str(e)}": 0.0}