Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Upload app.py
Browse files
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,284 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 2 | 
            +
            """app.ipynb
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            Automatically generated by Colab.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            Original file is located at
         | 
| 7 | 
            +
                https://colab.research.google.com/drive/1y3yISz14Lpsr131OIJCKA77lwbFmEJzB
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import streamlit as st
         | 
| 11 | 
            +
            import os
         | 
| 12 | 
            +
            import joblib
         | 
| 13 | 
            +
            import torch
         | 
| 14 | 
            +
            import numpy as np
         | 
| 15 | 
            +
            import html
         | 
| 16 | 
            +
            from transformers import AutoTokenizer, AutoModel, logging as hf_logging
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            # Hugging Face Transformers ๋ก๊น
 ๋ ๋ฒจ ์ค์  (์ค๋ฅ๋ง ํ์)
         | 
| 19 | 
            +
            hf_logging.set_verbosity_error()
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            # โโโโโโโโโโ ์ค์  (Hugging Face Spaces ํ๊ฒฝ์ ๋ง๊ฒ ์กฐ์ ) โโโโโโโโโโ
         | 
| 22 | 
            +
            MODEL_NAME = "bert-base-uncased"
         | 
| 23 | 
            +
            DEVICE     = "cpu"  # Hugging Face Spaces ๋ฌด๋ฃ ํฐ์ด๋ CPU ์ฌ์ฉ
         | 
| 24 | 
            +
            SAVE_DIR   = "์ ์ฅ์ ์ฅ1" # ์
๋ก๋ํ  ํด๋๋ช
๊ณผ ์ผ์นํด์ผ ํจ
         | 
| 25 | 
            +
            LAYER_ID   = 4      # ์๋ณธ ์ฝ๋์ SeparationScore ์ต๊ณ  ๋ ์ด์ด
         | 
| 26 | 
            +
            SEED       = 0      # ์๋ณธ ์ฝ๋์ SEED ๊ฐ
         | 
| 27 | 
            +
            CLF_NAME   = "linear" # ์๋ณธ ์ฝ๋์ CLF_NAME
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            # โโโโโโโโโโ ๋ชจ๋ธ ๋ก๋ (Streamlit ์บ์ ์ฌ์ฉ์ผ๋ก ์ฑ ์ ์ฒด์์ ํ ๋ฒ๋ง ์คํ) โโโโโโโโโโ
         | 
| 30 | 
            +
            @st.cache_resource
         | 
| 31 | 
            +
            def load_all_models_and_data():
         | 
| 32 | 
            +
                """
         | 
| 33 | 
            +
                LDA, ๋ถ๋ฅ๊ธฐ, ํ ํฌ๋์ด์ , BERT ๋ชจ๋ธ ๋ฐ ๊ด๋ จ ํ๋ ฌ๋ค์ ๋ก๋ํฉ๋๋ค.
         | 
| 34 | 
            +
                Hugging Face Spaces์ ๋ฐฐํฌ ์ ํ์ผ ๊ฒฝ๋ก๊ฐ ์ ํํด์ผ ํฉ๋๋ค.
         | 
| 35 | 
            +
                """
         | 
| 36 | 
            +
                lda_file_path = os.path.join(SAVE_DIR, f"lda_layer{LAYER_ID}_seed{SEED}.pkl")
         | 
| 37 | 
            +
                clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                # ํ์ผ ์กด์ฌ ์ฌ๋ถ ํ์ธ (๋ฐฐํฌ ํ๊ฒฝ ๋๋ฒ๊น
์ฉ)
         | 
| 40 | 
            +
                if not os.path.isdir(SAVE_DIR):
         | 
| 41 | 
            +
                    st.error(f"์ค๋ฅ: ๋ชจ๋ธ ์ ์ฅ ๋๋ ํ ๋ฆฌ '{SAVE_DIR}'๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค. Spaces์ ํด๋๊ฐ ์ฌ๋ฐ๋ฅด๊ฒ ์
๋ก๋๋์๋์ง, ์ด๋ฆ์ด ์ผ์นํ๋์ง ํ์ธํ์ธ์.")
         | 
| 42 | 
            +
                    return None
         | 
| 43 | 
            +
                if not os.path.exists(lda_file_path):
         | 
| 44 | 
            +
                    st.error(f"์ค๋ฅ: LDA ๋ชจ๋ธ ํ์ผ '{lda_file_path}'๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค. ํ์ผ ์ด๋ฆ๊ณผ ๊ฒฝ๋ก๋ฅผ ํ์ธํ์ธ์.")
         | 
| 45 | 
            +
                    return None
         | 
| 46 | 
            +
                if not os.path.exists(clf_file_path):
         | 
| 47 | 
            +
                    st.error(f"์ค๋ฅ: ๋ถ๋ฅ๊ธฐ ๋ชจ๋ธ ํ์ผ '{clf_file_path}'๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค. ํ์ผ ์ด๋ฆ๊ณผ ๊ฒฝ๋ก๋ฅผ ํ์ธํ์ธ์.")
         | 
| 48 | 
            +
                    return None
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                try:
         | 
| 51 | 
            +
                    lda = joblib.load(lda_file_path)
         | 
| 52 | 
            +
                    clf = joblib.load(clf_file_path)
         | 
| 53 | 
            +
                except Exception as e:
         | 
| 54 | 
            +
                    st.error(f"๋ชจ๋ธ ํ์ผ ๋ก๋ ์ค ์ค๋ฅ ๋ฐ์: {e}")
         | 
| 55 | 
            +
                    st.error("ํ์ผ์ด ์์๋์๊ฑฐ๋, joblib ๋ฒ์  ํธํ์ฑ ๋ฌธ์ ๊ฐ ์์ ์ ์์ต๋๋ค.")
         | 
| 56 | 
            +
                    return None
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                if hasattr(clf, "base_estimator"):  # Calibrated Ridge ๊ฒฝ์ฐ
         | 
| 59 | 
            +
                    clf = clf.base_estimator
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                # LDA ํ๋ ฌยทํ๊ท , ๋ถ๋ฅ๊ธฐ ๊ฐ์ค์น๋ฅผ PyTorch Tensor๋ก ๋ณํ
         | 
| 62 | 
            +
                W_tensor   = torch.tensor(lda.scalings_,  dtype=torch.float32, device=DEVICE)
         | 
| 63 | 
            +
                mu_vector  = torch.tensor(lda.xbar_,     dtype=torch.float32, device=DEVICE)
         | 
| 64 | 
            +
                w_p_tensor = torch.tensor(clf.coef_,     dtype=torch.float32, device=DEVICE)
         | 
| 65 | 
            +
                b_p_vector = torch.tensor(clf.intercept_, dtype=torch.float32, device=DEVICE)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                # Hugging Face ํ ํฌ๋์ด์  ๋ฐ BERT ๋ชจ๋ธ ๋ก๋
         | 
| 68 | 
            +
                try:
         | 
| 69 | 
            +
                    tokenizer_obj = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
         | 
| 70 | 
            +
                    model_obj     = AutoModel.from_pretrained(
         | 
| 71 | 
            +
                        MODEL_NAME, output_hidden_states=True
         | 
| 72 | 
            +
                    ).to(DEVICE).eval()
         | 
| 73 | 
            +
                except Exception as e:
         | 
| 74 | 
            +
                    st.error(f"Hugging Face ๋ชจ๋ธ ({MODEL_NAME}) ๋ก๋ ์ค ์ค๋ฅ: {e}")
         | 
| 75 | 
            +
                    st.error("์ธํฐ๋ท ์ฐ๊ฒฐ ๋๋ ๋ชจ๋ธ ์ด๋ฆ์ด ์ฌ๋ฐ๋ฅธ์ง ํ์ธํ์ธ์.")
         | 
| 76 | 
            +
                    return None
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                # ํด๋์ค ์ด๋ฆ ๊ฐ์ ธ์ค๊ธฐ ์๋
         | 
| 79 | 
            +
                class_names = None
         | 
| 80 | 
            +
                if hasattr(lda, 'classes_'): # scikit-learn LDA์ ๊ฒฝ์ฐ
         | 
| 81 | 
            +
                    class_names = lda.classes_
         | 
| 82 | 
            +
                elif hasattr(clf, 'classes_'): # scikit-learn ๋ถ๋ฅ๊ธฐ์ ๊ฒฝ์ฐ
         | 
| 83 | 
            +
                    class_names = clf.classes_
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                return tokenizer_obj, model_obj, W_tensor, mu_vector, w_p_tensor, b_p_vector, class_names
         | 
| 86 | 
            +
             | 
| 87 | 
            +
            # โโโโโโโโโโ ํต์ฌ ๋ถ์ ํจ์ (์๋ณธ ์ฝ๋ ๊ธฐ๋ฐ) โโโโโโโโโโ
         | 
| 88 | 
            +
            def explain_sentence_streamlit(
         | 
| 89 | 
            +
                text: str,
         | 
| 90 | 
            +
                tokenizer, model, W, mu, w_p, b_p, # ๋ก๋๋ ๊ฐ์ฒด๋ค
         | 
| 91 | 
            +
                layer_id_to_use: int, device_to_use: str, # ์ค์ ๊ฐ
         | 
| 92 | 
            +
                top_k_tokens: int = 5
         | 
| 93 | 
            +
            ) -> tuple[str, int, float, list] | None: # ๊ฒฐ๊ณผ ํ์
 ๋ช
์ (์คํจ ์ None)
         | 
| 94 | 
            +
                """
         | 
| 95 | 
            +
                ์
๋ ฅ ๋ฌธ์ฅ์ ์์ธกํ๊ณ  ํ ํฐ ์ค์๋๋ฅผ ๊ณ์ฐํ์ฌ ๊ฒฐ๊ณผ๋ฅผ ๋ฐํํฉ๋๋ค.
         | 
| 96 | 
            +
                """
         | 
| 97 | 
            +
                try:
         | 
| 98 | 
            +
                    # 1) ํ ํฐํ (์ต๋ ๊ธธ์ด ๋ฐ ์๋ฆผ ์ฒ๋ฆฌ ์ถ๊ฐ)
         | 
| 99 | 
            +
                    enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=510, padding=True) # BERT ์ต๋ ๊ธธ์ด 512 ๊ณ ๋ ค, CLS/SEP ๊ณต๊ฐ ํ๋ณด
         | 
| 100 | 
            +
                    input_ids  = enc["input_ids"].to(device_to_use)
         | 
| 101 | 
            +
                    attn_mask  = enc["attention_mask"].to(device_to_use)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    if input_ids.shape[1] == 0: # ์
๋ ฅ์ด ๋๋ฌด ์งง๊ฑฐ๋ ๋ชจ๋ ํํฐ๋ง ๋ ๊ฒฝ์ฐ
         | 
| 104 | 
            +
                         # Streamlit ์ฑ์์๋ ์ฌ์ฉ์์๊ฒ ๊ฒฝ๊ณ ๋ฅผ ํ์ํ  ์ ์์ต๋๋ค.
         | 
| 105 | 
            +
                         # st.warning("ํ ํฐํ ๊ฒฐ๊ณผ ์ ํจํ ํ ํฐ์ด ์์ต๋๋ค. ๋ค๋ฅธ ๋ฌธ์ฅ์ ์๋ํด๏ฟฝ๏ฟฝ๏ฟฝ์ธ์.")
         | 
| 106 | 
            +
                         return None
         | 
| 107 | 
            +
             | 
| 108 | 
            +
             | 
| 109 | 
            +
                    # 2) ์๋ฒ ๋ฉ์ gradient ์ถ์ 
         | 
| 110 | 
            +
                    input_embeds = model.embeddings.word_embeddings(input_ids).clone().detach()
         | 
| 111 | 
            +
                    input_embeds.requires_grad_(True)
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    # 3) Forward pass โ CLS ๋ฒกํฐ ์ถ์ถ
         | 
| 114 | 
            +
                    outputs = model(inputs_embeds=input_embeds,
         | 
| 115 | 
            +
                                    attention_mask=attn_mask, # Attention mask ์ ๋ฌ
         | 
| 116 | 
            +
                                    output_hidden_states=True)
         | 
| 117 | 
            +
                    cls_vec = outputs.hidden_states[layer_id_to_use][:, 0, :]  # (1, 768)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    # 4) LDA ํฌ์ โ ๋ถ๋ฅ logit ๊ณ์ฐ
         | 
| 120 | 
            +
                    z_projected = (cls_vec - mu) @ W          # (1, d)
         | 
| 121 | 
            +
                    logit_output = z_projected @ w_p.T + b_p  # (1, C)
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    probs = torch.softmax(logit_output, dim=1)
         | 
| 124 | 
            +
                    pred_idx = torch.argmax(probs, dim=1).item()
         | 
| 125 | 
            +
                    pred_prob = probs[0, pred_idx].item()
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                    # 5) Gradient ๊ณ์ฐ
         | 
| 128 | 
            +
                    if input_embeds.grad is not None:
         | 
| 129 | 
            +
                        input_embeds.grad.zero_() # ์ด์  ๊ทธ๋๋์ธํธ ์ด๊ธฐํ
         | 
| 130 | 
            +
                    logit_output[0, pred_idx].backward() # ์ ํ๋ ์์ธก ํด๋์ค์ ๋ํ ๊ทธ๋๋์ธํธ ๊ณ์ฐ
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    if input_embeds.grad is None: # backward ํ์๋ grad๊ฐ ์๋ ์์ธ์  ์ํฉ ๋ฐฉ์ง
         | 
| 133 | 
            +
                        # st.error("๊ทธ๋๋์ธํธ๋ฅผ ๊ณ์ฐํ  ์ ์์ต๋๋ค.") # Streamlit ์ฑ ๋ด์์ ์ค๋ฅ ํ์
         | 
| 134 | 
            +
                        return None
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    grads = input_embeds.grad.clone().detach()
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    # 6) Grad ร Input โ ์ค์๋ ์ ์ ๊ณ์ฐ
         | 
| 139 | 
            +
                    scores = (grads * input_embeds.detach()).norm(dim=2).squeeze(0)
         | 
| 140 | 
            +
                    scores_np = scores.cpu().numpy()
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    # ์ ํจํ ์ ์๋ง์ผ๋ก ์ ๊ทํ (NaN/Inf ๋ฐฉ์ง)
         | 
| 143 | 
            +
                    valid_scores = scores_np[np.isfinite(scores_np)]
         | 
| 144 | 
            +
                    if len(valid_scores) > 0 and valid_scores.max() > 0:
         | 
| 145 | 
            +
                        scores_np = scores_np / (valid_scores.max() + 1e-9) # 0~1 ์ ๊ทํ
         | 
| 146 | 
            +
                    else: # ๋ชจ๋  ์ ์๊ฐ 0์ด๊ฑฐ๋ ์ ํจํ์ง ์์ ๊ฒฝ์ฐ
         | 
| 147 | 
            +
                        scores_np = np.zeros_like(scores_np)
         | 
| 148 | 
            +
             | 
| 149 | 
            +
             | 
| 150 | 
            +
                    # 7) HTML ํ์ด๋ผ์ดํธ ์์ฑ
         | 
| 151 | 
            +
                    tokens = tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=False) # ์คํ์
 ํ ํฐ ํฌํจ
         | 
| 152 | 
            +
                    html_tokens_list = []
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    # CLS, SEP, PAD ํ ํฐ ID ํ์ธ
         | 
| 155 | 
            +
                    cls_token_id = tokenizer.cls_token_id
         | 
| 156 | 
            +
                    sep_token_id = tokenizer.sep_token_id
         | 
| 157 | 
            +
                    pad_token_id = tokenizer.pad_token_id
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    for i, tok_str in enumerate(tokens):
         | 
| 160 | 
            +
                        if input_ids[0, i] == pad_token_id: # PAD ํ ํฐ์ ๊ฑด๋๋ฐ๊ธฐ
         | 
| 161 | 
            +
                            continue
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                        clean_tok_str = tok_str.replace("##", "") if "##" not in tok_str else tok_str[2:]
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                        # ์คํ์
 ํ ํฐ์ ๋ค๋ฅธ ์คํ์ผ ์ ์ฉ ๋๋ ์ค์๋ ๊ณ์ฐ์์ ์ ์ธ ๊ฐ๋ฅ
         | 
| 166 | 
            +
                        if input_ids[0, i] == cls_token_id or input_ids[0, i] == sep_token_id:
         | 
| 167 | 
            +
                             html_tokens_list.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
         | 
| 168 | 
            +
                        else:
         | 
| 169 | 
            +
                            score_val = scores_np[i] if i < len(scores_np) else 0 # ์ ์ ๋ฐฐ์ด ๋ฒ์ ํ์ธ
         | 
| 170 | 
            +
                            color = f"rgba(255, 0, 0, {max(0, min(1, score_val)):.2f})" # ์ ์ ๋ฒ์ 0~1๋ก ํด๋ฆฌํ
         | 
| 171 | 
            +
                            html_tokens_list.append(
         | 
| 172 | 
            +
                                f"<span style='background-color:{color}; padding: 1px 2px; margin: 1px; border-radius: 3px; display:inline-block;'>{html.escape(clean_tok_str)}</span>"
         | 
| 173 | 
            +
                            )
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    html_output_str = " ".join(html_tokens_list)
         | 
| 176 | 
            +
                    # ๋ถํ์ํ ๊ณต๋ฐฑ ์ ๋ฆฌ (์: subword ์ฌ์ด ๊ณต๋ฐฑ)
         | 
| 177 | 
            +
                    html_output_str = html_output_str.replace(" ##", "")
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                    # Top-K ์ค์ ํ ํฐ ์ ๋ณด (์คํ์
 ํ ํฐ ๋ฐ PAD ํ ํฐ ์ ์ธ)
         | 
| 180 | 
            +
                    top_tokens_info_list = []
         | 
| 181 | 
            +
                    valid_indices_for_top_k = [
         | 
| 182 | 
            +
                        idx for idx, token_id in enumerate(input_ids[0].tolist())
         | 
| 183 | 
            +
                        if token_id not in [cls_token_id, sep_token_id, pad_token_id] and idx < len(scores_np)
         | 
| 184 | 
            +
                    ]
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    # ์ ์๊ฐ ๋์ ์์ผ๋ก ์ ๋ ฌ
         | 
| 187 | 
            +
                    sorted_valid_indices = sorted(valid_indices_for_top_k, key=lambda idx: -scores_np[idx])
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                    for token_idx in sorted_valid_indices[:top_k_tokens]:
         | 
| 190 | 
            +
                        top_tokens_info_list.append({
         | 
| 191 | 
            +
                            "token": tokens[token_idx],
         | 
| 192 | 
            +
                            "score": f"{scores_np[token_idx]:.3f}"
         | 
| 193 | 
            +
                        })
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                    return html_output_str, pred_idx, pred_prob, top_tokens_info_list
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                except Exception as e:
         | 
| 198 | 
            +
                    # Streamlit ์ฑ ๋ด์์ ์ค๋ฅ๋ฅผ ๋ ์ ํ์ํ๋๋ก ์์ 
         | 
| 199 | 
            +
                    # st.error(f"๋ฌธ์ฅ ๋ถ์ ์ค ์๊ธฐ์น ์์ ์ค๋ฅ ๋ฐ์: {e}")
         | 
| 200 | 
            +
                    # import traceback
         | 
| 201 | 
            +
                    # st.text_area("์ค๋ฅ ์์ธ ์ ๋ณด (๋๋ฒ๊น
์ฉ):", traceback.format_exc(), height=200)
         | 
| 202 | 
            +
                    # print(f"๋ฌธ์ฅ ๋ถ์ ์ค ์๊ธฐ์น ์์ ์ค๋ฅ ๋ฐ์: {e}") # ์ฝ์ ๋ก๊น
 (Spaces ๋ก๊ทธ์์ ํ์ธ ๊ฐ๋ฅ)
         | 
| 203 | 
            +
                    # import traceback
         | 
| 204 | 
            +
                    # print(traceback.format_exc()) # ์ฝ์ ๋ก๊น
         | 
| 205 | 
            +
                    raise # ์ค๋ฅ๋ฅผ ๋ค์ ๋ฐ์์์ผ Streamlit์ด ์ฒ๋ฆฌํ๋๋ก ํ๊ฑฐ๋, ์๋์์ None์ ๋ฐํ
         | 
| 206 | 
            +
                    # return None
         | 
| 207 | 
            +
             | 
| 208 | 
            +
             | 
| 209 | 
            +
            # โโโโโโโโโโ Streamlit UI ๊ตฌ์ฑ โโโโโโโโโโ
         | 
| 210 | 
            +
            st.set_page_config(page_title="๋ฌธ์ฅ ํ ํฐ ์ค์๋ ๋ถ์๊ธฐ", layout="wide")
         | 
| 211 | 
            +
            st.title("๐ ๋ฌธ์ฅ ํ ํฐ ์ค์๋ ๋ถ์๊ธฐ")
         | 
| 212 | 
            +
            st.markdown("BERT์ LDA๋ฅผ ํ์ฉํ์ฌ ๋ฌธ์ฅ ๋ด ๊ฐ ํ ํฐ์ ์ค์๋๋ฅผ ์๊ฐํํฉ๋๋ค.")
         | 
| 213 | 
            +
             | 
| 214 | 
            +
            # ๋ชจ๋ธ ๋ก๋ ์๋
         | 
| 215 | 
            +
            loaded_data_tuple = load_all_models_and_data()
         | 
| 216 | 
            +
             | 
| 217 | 
            +
            if loaded_data_tuple:
         | 
| 218 | 
            +
                tokenizer, model, W, mu, w_p, b_p, class_names = loaded_data_tuple
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                # ์ฌ์ด๋๋ฐ์ ๋ชจ๋ธ ์ ๋ณด ํ์
         | 
| 221 | 
            +
                st.sidebar.header("โ๏ธ ๋ชจ๋ธ ๋ฐ ์ค์  ์ ๋ณด")
         | 
| 222 | 
            +
                st.sidebar.info(f"**BERT ๋ชจ๋ธ:** `{MODEL_NAME}`\n\n"
         | 
| 223 | 
            +
                                f"**์ฌ์ฉ๋ ๋ ์ด์ด ID:** `{LAYER_ID}`\n\n"
         | 
| 224 | 
            +
                                f"**๋ถ๋ฅ๊ธฐ ์ข
๋ฅ:** `{CLF_NAME}` (LDA ํฌ์ ๊ธฐ๋ฐ)\n\n"
         | 
| 225 | 
            +
                                f"**์คํ ์ฅ์น:** `{DEVICE.upper()}`")
         | 
| 226 | 
            +
                if class_names is not None:
         | 
| 227 | 
            +
                    st.sidebar.markdown(f"**์์ธก ๊ฐ๋ฅ ํด๋์ค:** `{', '.join(map(str, class_names))}`")
         | 
| 228 | 
            +
             | 
| 229 | 
            +
             | 
| 230 | 
            +
                # ์ฌ์ฉ์ ์
๋ ฅ
         | 
| 231 | 
            +
                st.subheader("๐ ๋ถ์ํ  ์์ด ๋ฌธ์ฅ์ ์
๋ ฅํ์ธ์:")
         | 
| 232 | 
            +
                user_sentence = st.text_area("๋ฌธ์ฅ ์
๋ ฅ:", "This movie is exceptionally good and I highly recommend it.", height=100)
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                top_k_slider = st.slider("ํ์ํ  Top-K ์ค์ ํ ํฐ ์:", min_value=1, max_value=10, value=5, step=1)
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                if st.button("๋ถ์ ์คํํ๊ธฐ ๐", type="primary"):
         | 
| 237 | 
            +
                    if user_sentence:
         | 
| 238 | 
            +
                        with st.spinner("๋ฌธ์ฅ์ ๋ถ์ํ๊ณ  ์์ต๋๋ค... ์กฐ๊ธ๋ง ๊ธฐ๋ค๋ ค์ฃผ์ธ์...โณ"):
         | 
| 239 | 
            +
                            analysis_results = None
         | 
| 240 | 
            +
                            try:
         | 
| 241 | 
            +
                                analysis_results = explain_sentence_streamlit(
         | 
| 242 | 
            +
                                    user_sentence, tokenizer, model, W, mu, w_p, b_p,
         | 
| 243 | 
            +
                                    LAYER_ID, DEVICE, top_k_tokens=top_k_slider
         | 
| 244 | 
            +
                                )
         | 
| 245 | 
            +
                            except Exception as e: # explain_sentence_streamlit ๋ด๋ถ์์ raise๋ ์ค๋ฅ ์ฒ๋ฆฌ
         | 
| 246 | 
            +
                                st.error(f"๋ถ์ ์ฒ๋ฆฌ ์ค ์ค๋ฅ ๋ฐ์: {e}")
         | 
| 247 | 
            +
                                st.info("์
๋ ฅ ๋ฌธ์ฅ์ด๋ ๋ชจ๋ธ ํธํ์ฑ ๋ฌธ์ ๋ฅผ ํ์ธํด๋ณด์ธ์. ๋ฌธ์ ๊ฐ ์ง์๋๋ฉด ๊ด๋ฆฌ์์๊ฒ ๋ฌธ์ํ์ธ์.")
         | 
| 248 | 
            +
                                # ๋ ์์ธํ ์ค๋ฅ๋ Spaces์ ๋ก๊ทธ์์ ํ์ธ ๊ฐ๋ฅ (print๋ฌธ ์ฌ์ฉ ์)
         | 
| 249 | 
            +
             | 
| 250 | 
            +
             | 
| 251 | 
            +
                        if analysis_results: # ์ฑ๊ณต์ ์ผ๋ก ๊ฒฐ๊ณผ ๋ฐํ ์
         | 
| 252 | 
            +
                            html_viz, predicted_idx, probability, top_k_list = analysis_results
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                            st.markdown("---")
         | 
| 255 | 
            +
                            st.subheader("๐ ๋ถ์ ๊ฒฐ๊ณผ")
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                            predicted_class_label = str(predicted_idx) # ๊ธฐ๋ณธ๊ฐ: ์ธ๋ฑ์ค
         | 
| 258 | 
            +
                            if class_names is not None and 0 <= predicted_idx < len(class_names):
         | 
| 259 | 
            +
                                predicted_class_label = str(class_names[predicted_idx]) # ํด๋์ค ์ด๋ฆ ์ฌ์ฉ
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                            st.success(f"**์์ธก๋ ํด๋์ค:** **`{predicted_class_label}`** (์ ๋ขฐ๋: **{probability:.2f}**)")
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                            st.subheader("๐จ ํ ํฐ๋ณ ์ค์๋ ์๊ฐํ")
         | 
| 264 | 
            +
                            st.markdown(html_viz, unsafe_allow_html=True)
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                            st.subheader(f"โญ Top-{top_k_slider} ์ค์ ํ ํฐ")
         | 
| 267 | 
            +
                            if top_k_list:
         | 
| 268 | 
            +
                                cols = st.columns(len(top_k_list) if len(top_k_list) <=5 else 5 ) # ํ ์ค์ ์ต๋ 5๊ฐ
         | 
| 269 | 
            +
                                for i, item in enumerate(top_k_list):
         | 
| 270 | 
            +
                                    with cols[i % len(cols)]:
         | 
| 271 | 
            +
                                         st.metric(label=item['token'], value=item['score'])
         | 
| 272 | 
            +
                            else:
         | 
| 273 | 
            +
                                st.info("์ค์๋ ๋์ ํ ํฐ์ ์ฐพ์ ์ ์์ต๋๋ค (์คํ์
 ํ ํฐ ๋ฑ ์ ์ธ).")
         | 
| 274 | 
            +
                        # 'analysis_results is None' ์ด๊ณ  ์์ธ์ฒ๋ฆฌ๋ก st.error๊ฐ ์ด๋ฏธ ํ์๋ ๊ฒฝ์ฐ๋ ์ถ๊ฐ ๋ฉ์์ง ๋ถํ์
         | 
| 275 | 
            +
                        elif analysis_results is None and not user_sentence: # ๋ฌธ์ฅ ์
๋ ฅ ์์ด ๋ฒํผ ๋๋ฅธ ๊ฒฝ์ฐ (์ฌ์ค์ ์์์ ์ฒ๋ฆฌ)
         | 
| 276 | 
            +
                            pass # ์ด๋ฏธ st.warning์ผ๋ก ์ฒ๋ฆฌ๋จ
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                    else: # ๋ฌธ์ฅ ์
๋ ฅ ์์ด ๋ฒํผ ๋๋ฅธ ๊ฒฝ์ฐ
         | 
| 279 | 
            +
                        st.warning("๋ถ์ํ  ๋ฌธ์ฅ์ ์
๋ ฅํด์ฃผ์ธ์.")
         | 
| 280 | 
            +
            else:
         | 
| 281 | 
            +
                st.error("๋ชจ๋ธ ๋ก๋ฉ์ ์คํจํ์ฌ ์ ํ๋ฆฌ์ผ์ด์
์ ์์ํ  ์ ์์ต๋๋ค. ์
๋ก๋๋ ํ์ผ๊ณผ ๊ฒฝ๋ก ์ค์ ์ ํ์ธํด์ฃผ์ธ์. Hugging Face Spaces์ 'Logs' ํญ์์ ์์ธ ์ค๋ฅ๋ฅผ ํ์ธํ  ์ ์์ต๋๋ค.")
         | 
| 282 | 
            +
             | 
| 283 | 
            +
            st.markdown("---")
         | 
| 284 | 
            +
            st.markdown("<p style='text-align: center; color: grey;'>BERT ๊ธฐ๋ฐ ๋ฌธ์ฅ ๋ถ์ ๋ฐ๋ชจ</p>", unsafe_allow_html=True)
         |