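"""Streamlit demo app: train and compare classical and transformer text
classifiers on a JSONL dataset (binary / multiclass / multilabel), with
SHAP, attention-map and Captum-based interpretation tabs."""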
import json
import warnings
from typing import Any, Dict, List, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# set_page_config must be the first Streamlit call, so it runs before the
# project modules below (which may touch Streamlit on import).
st.set_page_config(
    page_title="Text Classifiers",
    layout="wide",
    initial_sidebar_state="expanded"
)

from text_preprocessing import (
    preprocess_text, get_contextual_embeddings, TextVectorizer
)
from classical_classifiers import (
    get_logistic_regression, get_svm_linear, get_random_forest,
    get_gradient_boosting, get_voting_classifier
)
from neural_classifiers import get_transformer_classifier
from model_evaluation import evaluate_model
from model_interpretation import (
    get_linear_feature_importance,
    analyze_errors,
    get_transformer_attention,
    visualize_attention_weights,
    get_token_importance_captum,
    plot_token_importance
)

warnings.filterwarnings("ignore")
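# Initialise every session-state slot up front so the rest of the app can
# read them without existence checks.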
_SESSION_DEFAULTS = {
    'models': {}, 'results': {},
    'dataset': None, 'task_type': None, 'preprocessed': None,
    'X': None, 'y': None, 'feature_names': None,
    'vectorizer': None, 'vectorizer_type': None,
    'X_test': None, 'y_test': None, 'test_texts': None,
    'label_encoder': None,
    # id2label/label2id are filled in once a dataset is loaded.
    'id2label': None, 'label2id': None,
    'rubert_model': None, 'rubert_tokenizer': None, 'rubert_trained': False,
}
for _key, _value in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _value
st.sidebar.title("Setup")
st.sidebar.subheader("1. Upload Dataset (JSONL)")
uploaded_file = st.sidebar.file_uploader("Upload .jsonl file", type=["jsonl"])
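# Expected record shapes (illustrative examples inferred from the label
# detection below; the field values here are hypothetical):
#   {"text": "...", "sentiment": 1}                  -> binary
#   {"text": "...", "category": "Спорт"}             -> multiclass
#   {"text": "...", "tags": ["Политика", "Спорт"]}   -> multilabel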
if uploaded_file:
    try:
        raw_data = []
        lines = uploaded_file.getvalue().decode("utf-8").splitlines()
        for line in lines:
            if line.strip():
                raw_data.append(json.loads(line))
        st.session_state.dataset = raw_data
        first = raw_data[0]
        if 'sentiment' in first:
            st.session_state.task_type = "binary"
            labels = [item['sentiment'] for item in raw_data]
        elif 'category' in first:
            st.session_state.task_type = "multiclass"
            labels = [item['category'] for item in raw_data]
        elif 'tags' in first:
            st.session_state.task_type = "multilabel"
            labels = [item['tags'] for item in raw_data]
        else:
            st.sidebar.error("No label field found")
            st.session_state.task_type = None
            st.session_state.dataset = None
        if st.session_state.task_type:
            st.sidebar.success(f"Loaded {len(raw_data)} samples. Task: {st.session_state.task_type}")
            if st.session_state.task_type == "binary":
                id2label = {0: "Negative", 1: "Positive"}
                label2id = {"Negative": 0, "Positive": 1}
            elif st.session_state.task_type == "multiclass":
                id2label = {0: "Политика", 1: "Экономика", 2: "Спорт", 3: "Культура"}
                label2id = {"Политика": 0, "Экономика": 1, "Спорт": 2, "Культура": 3}
            else:
                id2label = None
                label2id = None
            st.session_state.id2label = id2label
            st.session_state.label2id = label2id
    except Exception as e:
        st.sidebar.error(f"Failed to parse JSONL: {e}")
        st.session_state.dataset = None
if st.session_state.dataset is not None:
    st.sidebar.subheader("2. Preprocess Text")
    lang = st.sidebar.selectbox("Language", ["ru", "en"], index=0)
    st.session_state.preprocess_lang = lang
    if st.sidebar.button("Run Preprocessing"):
        with st.spinner("Preprocessing..."):
            texts = [item['text'] for item in st.session_state.dataset]
            preprocessed = [preprocess_text(text, lang=lang, remove_stopwords=False) for text in texts]
            st.session_state.preprocessed = preprocessed
        st.sidebar.success("Preprocessing done!")
if st.session_state.preprocessed is not None:
    st.sidebar.subheader("3. Vectorization (Classical)")
    vectorizer_type = st.sidebar.selectbox("Method", ["TF-IDF", "RuBERT Embeddings"])
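    # TF-IDF gives sparse, directly interpretable word features; the RuBERT
    # option (get_contextual_embeddings is assumed to pool hidden states into
    # one dense vector per text) gives 768-dim sentence embeddings.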
    if st.sidebar.button("Vectorize"):
        with st.spinner("Vectorizing..."):
            # Both vectorizers expect plain strings, so join token lists first.
            if not isinstance(st.session_state.preprocessed[0], str):
                st.session_state.preprocessed = [
                    ' '.join(tokens) for tokens in st.session_state.preprocessed
                ]
            if vectorizer_type == "TF-IDF":
                vectorizer = TextVectorizer()
                st.sidebar.write("Using max_features=5000")
                X = vectorizer.tfidf(st.session_state.preprocessed, max_features=5000)
                st.sidebar.write(f"X shape: {X.shape}")
                st.session_state.vectorizer = vectorizer
                st.session_state.feature_names = vectorizer.tfidf_vectorizer.get_feature_names_out()
            else:
                X = []
                for text in st.session_state.preprocessed:
                    emb = get_contextual_embeddings([text], model_name="DeepPavlov/rubert-base-cased")
                    X.append(emb[0])
                X = np.array(X)
                st.session_state.vectorizer = None
                st.session_state.feature_names = None
            st.session_state.X = X
            st.session_state.vectorizer_type = vectorizer_type
            if st.session_state.task_type == "binary":
                y = np.array([item['sentiment'] for item in st.session_state.dataset])
            elif st.session_state.task_type == "multiclass":
                y = np.array([item['category'] for item in st.session_state.dataset])
            else:
                y = [item['tags'] for item in st.session_state.dataset]
            st.session_state.y = y
        st.sidebar.success("Vectorization complete!")
if st.session_state.X is not None:
    st.sidebar.subheader("4. Train Classical Models")
    model_options = ["Logistic Regression", "SVM", "Random Forest", "XGBoost", "Voting"]
    selected_models = st.sidebar.multiselect("Models", model_options)
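    # Training below uses an 80/20 hold-out split, stratified where the label
    # type allows it; the factory helpers from classical_classifiers are
    # assumed to return unfitted scikit-learn-compatible estimators.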
    if st.sidebar.button("Train Classical Models"):
        from sklearn.model_selection import train_test_split
        from sklearn.preprocessing import LabelEncoder
        X = st.session_state.X
        y = st.session_state.y
        if st.session_state.task_type == "multilabel":
            # Lists of tags cannot be stratified, so use a simple sequential split.
            split_idx = int(0.8 * len(X))
            X_train, X_test = X[:split_idx], X[split_idx:]
            y_train, y_test = y[:split_idx], y[split_idx:]
            test_texts = [item['text'] for item in st.session_state.dataset[split_idx:]]
        else:
            if st.session_state.task_type == "multiclass":
                le = LabelEncoder()
                y_for_split = le.fit_transform(y)
                st.session_state.label_encoder = le
            else:
                y_for_split = y
            indices = np.arange(len(X))
            X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
                X, y_for_split, indices, test_size=0.2,
                stratify=y_for_split, random_state=42
            )
            test_texts = [st.session_state.dataset[i]['text'] for i in idx_test]
            if st.session_state.task_type == "multiclass":
                y_train = le.inverse_transform(y_train)
                y_test = le.inverse_transform(y_test)
        st.session_state.X_test = X_test
        st.session_state.y_test = y_test
        st.session_state.test_texts = test_texts
        model_factories = {
            "Logistic Regression": get_logistic_regression,
            "SVM": get_svm_linear,
            "Random Forest": get_random_forest,
            "XGBoost": lambda: get_gradient_boosting("xgb", n_estimators=100),
            "Voting": get_voting_classifier,
        }
        for name in selected_models:
            try:
                with st.spinner(f"Training {name}..."):
                    model = model_factories[name]()
                    model.fit(X_train, y_train)
                    st.session_state.models[name] = model
                    if st.session_state.task_type != "multilabel":
                        metrics = evaluate_model(model, X_test, y_test)
                        st.session_state.results[name] = metrics
            except Exception as e:
                st.sidebar.error(f"Failed to train {name}: {e}")
                continue
        st.sidebar.success("Classical models trained!")
if st.session_state.dataset is not None and st.session_state.task_type in ["binary", "multiclass"]:
    st.sidebar.subheader("5. Train RuBERT (Transformer)")
    if st.sidebar.button("Train RuBERT"):
        with st.spinner("Loading RuBERT..."):
            try:
                from transformers import AutoConfig
                num_labels = 2 if st.session_state.task_type == "binary" else 4
                model_name = "DeepPavlov/rubert-base-cased"
                config = AutoConfig.from_pretrained(
                    model_name,
                    num_labels=num_labels,
                    id2label=st.session_state.id2label,
                    label2id=st.session_state.label2id
                )
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                # Note: this loads the pretrained encoder with a freshly
                # initialised classification head; no fine-tuning happens here,
                # so predictions are only meaningful once the head is trained.
                model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
                st.session_state.rubert_model = model
                st.session_state.rubert_tokenizer = tokenizer
                st.session_state.rubert_trained = True
                st.sidebar.success("RuBERT loaded with correct labels!")
            except Exception as e:
                st.sidebar.error(f"RuBERT loading failed: {e}")
                st.exception(e)
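# Main panel: the four tabs mirror the workflow -- classify new text,
# interpret individual predictions, compare metrics, and inspect errors.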
st.title("Text Classifiers")
tab1, tab2, tab3, tab4 = st.tabs([
    "Classify",
    "Interpret",
    "Compare",
    "Error Analysis"
])
with tab1:
    st.subheader("Classify New Text")
    input_text = st.text_area("Enter text", "Сегодня прошёл важный матч по хоккею.")
    if st.button("Classify"):
        cols = st.columns(2)
        with cols[0]:
            st.markdown("### Classical Models")
            if not st.session_state.models:
                st.info("No classical models trained")
            else:
                tokens = preprocess_text(input_text, lang=st.session_state.get('preprocess_lang', 'ru'), remove_stopwords=False)
                preprocessed = tokens if isinstance(tokens, str) else " ".join(tokens)
                if st.session_state.vectorizer_type == "TF-IDF":
                    X_input = st.session_state.vectorizer.tfidf_vectorizer.transform([preprocessed]).toarray()
                else:
                    X_input = get_contextual_embeddings([preprocessed], model_name="DeepPavlov/rubert-base-cased")
                for name, model in st.session_state.models.items():
                    pred = model.predict(X_input)[0]
                    st.write(f"**{name}**: {pred}")
                    if hasattr(model, "predict_proba"):
                        proba = model.predict_proba(X_input)[0]
                        st.write(f"Probabilities: {dict(zip(model.classes_, proba))}")
        with cols[1]:
            st.markdown("### RuBERT")
            if not st.session_state.rubert_trained:
                st.info("Train RuBERT in sidebar")
            else:
                try:
                    pipe = pipeline(
                        "text-classification",
                        model=st.session_state.rubert_model,
                        tokenizer=st.session_state.rubert_tokenizer,
                        device=-1  # CPU
                    )
                    result = pipe(input_text)
                    label = result[0]['label']
                    confidence = result[0]['score']
                    # Map generic "LABEL_i" names back to readable labels if needed.
                    if label.startswith("LABEL_") and st.session_state.id2label:
                        label_id = int(label.replace("LABEL_", ""))
                        readable_label = st.session_state.id2label.get(label_id, label)
                    else:
                        readable_label = label
                    st.write(f"**Prediction**: {readable_label}")
                    st.write(f"**Confidence**: {confidence:.3f}")
                except Exception as e:
                    st.error(f"RuBERT inference failed: {e}")
with tab2:
    subtab1, subtab2, subtab3 = st.tabs(["SHAP / LIME", "Attention Map", "Captum Heatmap"])
    with subtab1:
        st.subheader("SHAP: Local Explanation for One Text")
        if not st.session_state.models:
            st.info("Train a classical model first")
        else:
            model_name = st.selectbox("Model", list(st.session_state.models.keys()), key="shap_model")
            text_for_explain = st.text_area("Text to explain", "Прекрасная новость о росте экономики!", key="shap_text")
            top_k = st.slider("Top features to show", 5, 30, 15)
            if st.button("Explain with SHAP"):
                try:
                    model = st.session_state.models[model_name]
                    tokens = preprocess_text(text_for_explain, lang=st.session_state.get('preprocess_lang', 'ru'), remove_stopwords=False)
                    preprocessed = tokens if isinstance(tokens, str) else " ".join(tokens)
                    if st.session_state.vectorizer_type == "TF-IDF":
                        X_input = st.session_state.vectorizer.tfidf_vectorizer.transform([preprocessed]).toarray()
                        feature_names = st.session_state.feature_names
                    else:
                        X_input = get_contextual_embeddings([preprocessed], model_name="DeepPavlov/rubert-base-cased")
                        feature_names = [f"emb_{i}" for i in range(X_input.shape[1])]
                    background = st.session_state.X[:100]
                    # TreeExplainer is fast and exact for tree ensembles; fall back
                    # to the model-agnostic KernelExplainer otherwise. (Checking the
                    # class name is more reliable than str(type(model)), which never
                    # contained "tree" for sklearn forests or XGBoost.)
                    if any(k in type(model).__name__.lower() for k in ("forest", "tree", "xgb", "boost")):
                        explainer = shap.TreeExplainer(model)
                        shap_values = explainer.shap_values(X_input)
                    else:
                        explainer = shap.KernelExplainer(model.predict_proba, background)
                        shap_values = explainer.shap_values(X_input, nsamples=200)
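                    # shap_values arrives in different layouts depending on the
                    # SHAP version and explainer: a per-class list (older API),
                    # a 2-D (samples, features) or (classes, features) array, or
                    # a 3-D (samples, features, classes) array. Normalise each
                    # case to a single 1-D vector for the predicted class.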
                    if isinstance(shap_values, list):
                        probs = model.predict_proba(X_input)[0]
                        target_class = int(np.argmax(probs))
                        single_shap = shap_values[target_class][0]
                        expected_val = explainer.expected_value[target_class]
                    else:
                        sv = shap_values
                        if sv.ndim == 1:
                            single_shap = sv
                            expected_val = explainer.expected_value
                        elif sv.ndim == 2:
                            if sv.shape[0] == 1:
                                single_shap = sv[0]
                                expected_val = explainer.expected_value
                            elif sv.shape[1] == X_input.shape[1]:
                                # (classes, features) layout: take the predicted
                                # class's row, not a feature column.
                                probs = model.predict_proba(X_input)[0]
                                target_class = int(np.argmax(probs))
                                single_shap = sv[target_class]
                                expected_val = explainer.expected_value[target_class] if isinstance(
                                    explainer.expected_value, (list, np.ndarray)) else explainer.expected_value
                            else:
                                single_shap = sv[0]
                                expected_val = explainer.expected_value
                        elif sv.ndim == 3:
                            if sv.shape[0] != 1:
                                raise ValueError("SHAP explanation for more than one sample not supported")
                            probs = model.predict_proba(X_input)[0]
                            target_class = int(np.argmax(probs))
                            single_shap = sv[0, :, target_class]
                            if isinstance(explainer.expected_value, (list, np.ndarray)) and len(
                                    explainer.expected_value) == sv.shape[2]:
                                expected_val = explainer.expected_value[target_class]
                            else:
                                expected_val = explainer.expected_value
                        else:
                            raise ValueError(f"Unsupported SHAP shape: {sv.shape}")
                    single_shap = np.array(single_shap).flatten()
                    if single_shap.shape[0] != X_input.shape[1]:
                        raise ValueError(
                            f"SHAP vector length {single_shap.shape[0]} != input features {X_input.shape[1]}")
                    explanation = None
                    if st.session_state.vectorizer_type == "TF-IDF":
                        # Only show words that actually occur in this text.
                        text_vector = X_input[0]
                        nonzero_indices = np.where(text_vector != 0)[0]
                        if len(nonzero_indices) == 0:
                            st.warning("No known words from training vocabulary found in this text.")
                        else:
                            explanation = shap.Explanation(
                                values=single_shap[nonzero_indices],
                                base_values=expected_val,
                                data=text_vector[nonzero_indices],
                                feature_names=[st.session_state.feature_names[i] for i in nonzero_indices]
                            )
                    else:
                        explanation = shap.Explanation(
                            values=single_shap,
                            base_values=expected_val,
                            data=X_input[0],
                            feature_names=feature_names
                        )
                    if explanation is not None:
                        plt.figure(figsize=(10, max(4, min(8, top_k * 0.3))))
                        shap.plots.waterfall(explanation, max_display=top_k, show=False)
                        st.pyplot(plt.gcf())
                        plt.close()
                except Exception as e:
                    st.error(f"SHAP error: {e}")
                    st.exception(e)
    with subtab2:
        st.subheader("Transformer Attention Map")
        if not st.session_state.rubert_trained:
            st.info("Train RuBERT first")
        else:
            text_att = st.text_area("Text for attention", "Матч завершился победой ЦСКА", key="att_text")
            # rubert-base-cased has 12 layers and 12 heads, hence the 0-11 ranges.
            layer = st.slider("Layer", 0, 11, 6)
            head = st.slider("Head", 0, 11, 0)
            if st.button("Visualize Attention"):
                try:
                    tokens, attn = get_transformer_attention(
                        st.session_state.rubert_model,
                        st.session_state.rubert_tokenizer,
                        text_att,
                        device="cpu"
                    )
                    # attn is assumed to be (layers, heads, seq_len, seq_len).
                    weights = attn[layer, head, :len(tokens), :len(tokens)]
                    fig, ax = plt.subplots(figsize=(10, 4))
                    sns.heatmap(
                        weights,
                        xticklabels=tokens,
                        yticklabels=tokens,
                        cmap="viridis",
                        ax=ax
                    )
                    plt.xticks(rotation=45, ha="right")
                    plt.yticks(rotation=0)
                    plt.title(f"Attention: Layer {layer}, Head {head}")
                    st.pyplot(fig)
                    plt.close(fig)
                except Exception as e:
                    st.error(f"Attention failed: {e}")
                    st.exception(e)
    with subtab3:
        st.subheader("Token Importance (Captum)")
        if not st.session_state.rubert_trained:
            st.info("Train RuBERT first")
        else:
            text_captum = st.text_area("Text for Captum", "Это очень плохая новость для политики", key="captum_text")
            # get_token_importance_captum is assumed to apply Integrated
            # Gradients internally; no attribution method is selected here.
            if st.button("Compute Token Importance"):
                try:
                    tokens, importance = get_token_importance_captum(
                        st.session_state.rubert_model,
                        st.session_state.rubert_tokenizer,
                        text_captum,
                        device="cpu"
                    )
                    # Drop special tokens, then keep the 15 tokens with the
                    # largest absolute attribution.
                    valid = [(t, imp) for t, imp in zip(tokens, importance) if t not in ["[CLS]", "[SEP]", "[PAD]"]]
                    if valid:
                        tokens_clean, imp_clean = zip(*valid)
                        indices = np.argsort(np.abs(imp_clean))[-15:][::-1]
                        tokens_top = [tokens_clean[i] for i in indices]
                        imp_top = [imp_clean[i] for i in indices]
                        fig, ax = plt.subplots(figsize=(8, 6))
                        # Green bars push toward the predicted class, red away from it.
                        colors = ["red" if x < 0 else "green" for x in imp_top]
                        ax.barh(range(len(imp_top)), imp_top, color=colors)
                        ax.set_yticks(range(len(imp_top)))
                        ax.set_yticklabels(tokens_top)
                        ax.invert_yaxis()
                        ax.set_xlabel("Attribution Score")
                        ax.set_title("Token Importance")
                        st.pyplot(fig)
                        plt.close(fig)
                    else:
                        st.warning("No valid tokens")
                except Exception as e:
                    st.error(f"Captum failed: {e}")
                    st.exception(e)
with tab3:
    st.subheader("Model Comparison")
    if st.session_state.results:
        df = pd.DataFrame(st.session_state.results).T
        st.dataframe(df)
    else:
        st.info("Train models to see metrics")
with tab4:
    st.subheader("Error Analysis")
    if st.session_state.X_test is None:
        st.info("Train models first")
    else:
        model_name = st.selectbox("Model for error analysis", list(st.session_state.models.keys()), key="err_model")
        if st.button("Analyze Errors"):
            model = st.session_state.models[model_name]
            y_pred = model.predict(st.session_state.X_test)
            # analyze_errors is assumed to return a DataFrame of misclassified
            # rows with at least 'text', 'true_label' and 'pred_label' columns.
            errors = analyze_errors(
                st.session_state.y_test,
                y_pred,
                st.session_state.test_texts
            )
            st.dataframe(errors[['text', 'true_label', 'pred_label']].head(20))
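# Launch locally (assuming this file is saved as app.py):
#     streamlit run app.py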