import gradio as gr from datasets import load_dataset from sklearn.model_selection import train_test_split from sklearn.preprocessing import LabelEncoder from sklearn.linear_model import LogisticRegression from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.pipeline import Pipeline import warnings # Filtert die UserWarning von scikit-learn wegen der fehlenden Stichhaltigkeit von Klassen im Klassifikationsbericht heraus warnings.filterwarnings("ignore", category=UserWarning) # 1. Laden und Vorbereiten des Datensatzes (einmalig beim Start) try: dataset = load_dataset("banking77") texts = dataset['train']['text'] + dataset['test']['text'] labels = dataset['train']['label'] + dataset['test']['label'] label_encoder = LabelEncoder() numerical_labels = label_encoder.fit_transform(labels) label_names = label_encoder.classes_ train_texts, test_texts, train_labels, test_labels = train_test_split( texts, numerical_labels, test_size=0.2, random_state=42, stratify=numerical_labels ) print("Datensatz 'banking77' erfolgreich geladen.") except Exception as e: print(f"Fehler beim Laden des Datensatzes: {e}") label_names = ["Fehler beim Laden"] pipeline = None print("Modell wird nicht trainiert, da der Datensatz nicht geladen werden konnte.") # 2. Trainieren des Modells (einmalig beim Start) if 'pipeline' not in locals() or pipeline is None: try: pipeline = Pipeline([ ('tfidf', TfidfVectorizer()), ('classifier', LogisticRegression(solver='liblinear', multi_class='ovr', random_state=42)) ]) pipeline.fit(train_texts, train_labels) print("Modell erfolgreich trainiert.") except Exception as e: print(f"Fehler beim Trainieren des Modells: {e}") pipeline = None print("Modell konnte nicht trainiert werden.") # 3. Funktion für die Vorhersage def predict_intent(text): if pipeline is not None and len(label_names) > 0: try: prediction = pipeline.predict([text])[0] predicted_label = label_names[prediction] # Korrektur: String-Label abrufen probabilities = pipeline.predict_proba([text])[0] confidences = {label_names[i]: f"{probabilities[i]:.2f}" for i in range(len(label_names))} return predicted_label, confidences except Exception as e: return "Fehler bei der Vorhersage", {"Fehler": f"Ein Fehler ist bei der Vorhersage aufgetreten: {e}"} else: return "Fehler", {"Fehler": "Modell nicht geladen oder trainiert."} # 4. Erstellen der Gradio Interface iface = gr.Interface( fn=predict_intent, inputs=gr.Textbox(label="Gib deine Kundenanfrage ein:", placeholder="z.B. Ich habe mein Passwort vergessen."), outputs=[ gr.Label(label="Vorhergesagte Kundenintention:"), gr.JSON(label="Konfidenzwerte:") ], # title und description auf Deutsch title="KI-gestützte Vorhersage von Kundenanfragen", description="Diese Anwendung sagt die Absicht einer Kundenanfrage voraus. Gib eine Anfrage ein, um die vorhergesagte Kategorie und die Konfidenzwerte zu sehen. Das Modell wurde auf dem Datensatz Banking77 trainiert.", examples=[ ["Ich habe mein Passwort vergessen."], ["Wie kann ich Geld überweisen?"], ["Meine Karte ist verloren gegangen."], ["Was ist der aktuelle Zinssatz für ein Sparkonto?"] ], css=""" .container { margin: 0 auto; max-width: 700px; padding: 20px; text-align: center; } .input_output_section { display: flex; flex-direction: column; align-items: center; margin-bottom: 20px; } .label { font-weight: bold; margin-bottom: 5px; color: #4a5568; /* Dunkleres Grau für bessere Lesbarkeit */ } .textbox { border: 1px solid #cbd5e0; /* Etwas hellerer Rahmen */ border-radius: 0.375rem; /* Abgerundete Ecken gemäß Tailwind */ padding: 0.75rem; width: 100%; max-width: 400px; /* Begrenze die Breite des Textfelds */ margin-bottom: 1rem; font-size: 1rem; box-shadow: inset 0 2px 4px rgba(0,0,0,0.06); /* Subtiler Schatten */ transition: border-color 0.2s ease-in-out, box-shadow 0.2s ease-in-out; /* Sanfte Übergänge */ } .textbox:focus { outline: none; border-color: #3182ce; /* Blauer Fokus-Rand */ box-shadow: 0 0 0 3px rgba(66, 153, 225, 0.16); /* Heller Fokus-Schatten */ } .label_output { font-size: 1.25rem; font-weight: 600; color: #2d3748; /* Noch dunkler für die Ausgabe */ margin-bottom: 1.5rem; padding: 0.5rem; border-radius: 0.375rem; background-color: #edf2f7; /* Sehr helles Grau für Hintergrund der Ausgabe */ box-shadow: 0 1px 3px rgba(0,0,0,0.08); /* Sehr schwacher Schatten */ min-width: 200px; /* Mindestbreite für die Ausgabe */ text-align: center; } .json_output { background-color: #f7fafc; /* Noch helleres Grau für JSON */ border: 1px solid #e2e8f0; border-radius: 0.375rem; padding: 1rem; font-family: 'Menlo', monospace; /* Monospace-Schriftart für JSON */ font-size: 0.875rem; line-height: 1.5rem; overflow-x: auto; /* Horizontal scrollbar bei Überlauf */ max-width: 400px; /* Maximale Breite */ margin: 0 auto; /* Zentrieren */ } .examples { margin-top: 2rem; text-align: center; } .example_item { cursor: pointer; padding: 0.5rem 1rem; margin: 0.5rem; background-color: #e2e8f0; /* Hellgrauer Hintergrund für Beispiele */ color: #2d3748; border-radius: 0.375rem; border: 1px solid #f0f4f8; transition: background-color 0.2s ease-in-out, transform 0.1s ease; display: inline-block; /* Damit die Breite automatisch angepasst wird */ font-size: 0.9rem; box-shadow: 0 1px 2px rgba(0,0,0,0.05); } .example_item:hover { background-color: #cbd5e0; /* Dunkleres Grau bei Hover */ transform: translateY(-2px); /* Leichter Hover-Effekt */ border-color: #a0aec0; } """, ) # 5. Starten der Gradio App (wird beim Ausführen des Skripts aktiv) iface.launch(share=True)