import streamlit as st
import inseq
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import requests
from bs4 import BeautifulSoup
import pandas as pd
import numpy as np
from inseq.models.huggingface_model import HuggingfaceDecoderOnlyModel
import base64
from io import BytesIO
from PIL import Image
import plotly.graph_objects as go
import re
import markdown
from utilities.localization import tr
import faiss
from sentence_transformers import SentenceTransformer, util
from sentence_splitter import SentenceSplitter
import html
from utilities.utils import init_qwen_api
from utilities.feedback_survey import display_attribution_feedback
from thefuzz import process, fuzz
import gc
import time
import sys
from pathlib import Path

# A dictionary to map method names to translation keys.
METHOD_DESC_KEYS = {
    "integrated_gradients": "desc_integrated_gradients",
    "occlusion": "desc_occlusion",
    "saliency": "desc_saliency"
}
# Make the project root importable.
sys.path.append(str(Path(__file__).resolve().parent.parent))

# Configuration for the influence tracer.
INDEX_DIR = os.path.join("influence_tracer", "influence_tracer_data")
INDEX_PATH = os.path.join(INDEX_DIR, "dolma_index_multi.faiss")
MAPPING_PATH = os.path.join(INDEX_DIR, "dolma_mapping_multi.json")
TRACER_MODEL_NAME = 'paraphrase-multilingual-mpnet-base-v2'
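# Expected on-disk layout (inferred from the lookup code in get_influential_docs,
# not verified against the actual files):
#   INDEX_PATH   - a FAISS index whose row i corresponds to mapping key str(i)
#   MAPPING_PATH - JSON of the form
#                  {"0": {"file": ..., "source": ..., "text_snippet": ...}, ...}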

class CachedAttribution:
    # A mock object to mimic inseq's Attribution object for cached results.
    def __init__(self, html_content):
        self.html_content = html_content

    def show(self, display=False, return_html=True):
        return self.html_content

def load_all_attribution_models():
    # Loads all the attribution models.
    try:
        # Set the device to MPS, CUDA, or CPU.
        device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
        # Path to the local model.
        model_path = "./models/OLMo-2-1124-7B"
        hf_token = os.environ.get("HF_TOKEN")
        # Load the tokenizer and model.
        tokenizer = AutoTokenizer.from_pretrained(model_path, token=hf_token, trust_remote_code=True)
        tokenizer.model_max_length = 512
        # Load the model in half precision to save memory.
        base_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            token=hf_token,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )
        # Move the model to the selected device.
        base_model = base_model.to(device)
        # Add missing special tokens if necessary.
        if tokenizer.bos_token is None:
            tokenizer.add_special_tokens({'bos_token': '<s>'})
            base_model.resize_token_embeddings(len(tokenizer))
        # Patch the model config.
        if base_model.config.bos_token_id is None:
            base_model.config.bos_token_id = tokenizer.bos_token_id
        attribution_models = {}
        # Set up the Integrated Gradients model.
        attribution_models["integrated_gradients"] = HuggingfaceDecoderOnlyModel(
            model=base_model,
            tokenizer=tokenizer,
            device=device,
            attribution_method="integrated_gradients",
            attribution_kwargs={"n_steps": 10}
        )
        # Set up the Occlusion model.
        attribution_models["occlusion"] = HuggingfaceDecoderOnlyModel(
            model=base_model,
            tokenizer=tokenizer,
            device=device,
            attribution_method="occlusion"
        )
        # Set up the Saliency model.
        attribution_models["saliency"] = HuggingfaceDecoderOnlyModel(
            model=base_model,
            tokenizer=tokenizer,
            device=device,
            attribution_method="saliency"
        )
        return attribution_models, tokenizer, base_model, device
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None, None, None, None

# The cache decorator is an assumption: the index, mapping, and embedding model
# are heavyweight resources that callers expect to load only once.
@st.cache_resource(show_spinner=False)
def load_influence_tracer_data():
    # Loads the data needed for the influence tracer.
    if not os.path.exists(INDEX_PATH) or not os.path.exists(MAPPING_PATH):
        return None, None, None
    index = faiss.read_index(INDEX_PATH)
    with open(MAPPING_PATH, 'r', encoding='utf-8') as f:
        mapping = json.load(f)
    device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
    model = SentenceTransformer(TRACER_MODEL_NAME, device=device)
    return index, mapping, model

# The cache decorator is an assumption, based on the note in run_analysis that
# this lookup "should hit the Streamlit cache and be fast".
@st.cache_data(show_spinner=False)
def get_influential_docs(text_to_trace: str, lang: str):
    # Finds influential documents from the training data for a given text.
    faiss_index, doc_mapping, tracer_model = load_influence_tracer_data()
    if not faiss_index:
        return []
    # Embed the input text once; the embedding is reused for the FAISS search
    # here and for the per-sentence similarity scoring below.
    query_embedding = tracer_model.encode([text_to_trace], convert_to_numpy=True, normalize_embeddings=True)
    # Search the FAISS index for the top k documents.
    k = 3
    similarities, indices = faiss_index.search(query_embedding.astype('float32'), k)
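    # With normalize_embeddings=True the query vector and (presumably) the
    # indexed vectors are unit length, so the scores FAISS returns here are
    # cosine similarities in [-1, 1]. This assumes the serialized index is an
    # inner-product index such as faiss.IndexFlatIP, which this code does not
    # itself verify. For unit vectors u and v, u . v = cos(theta), so a score
    # of 1.0 means the two texts point in the same direction in embedding space.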
    # Find the most similar sentence in each influential document.
    results = []
    for i in range(k):
        doc_id = str(indices[0][i])
        if doc_id in doc_mapping:
            doc_info = doc_mapping[doc_id]
            file_path = os.path.join("influence_tracer", "dolma_dataset_sample_1.6v", doc_info['file'])
            try:
                full_doc_text = ""
                with open(file_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        try:
                            line_data = json.loads(line)
                            line_text = line_data.get('text', '')
                            # Use fuzzy matching to find the text snippet.
                            if fuzz.partial_ratio(doc_info['text_snippet'], line_text) > 95:
                                full_doc_text = line_text
                                break
                        except json.JSONDecodeError:
                            continue
                # Skip if the document text wasn't found.
                if not full_doc_text:
                    print(f"Warning: Could not find document snippet for doc {doc_id} in {file_path}. Skipping.")
                    continue
                # Find the most similar sentence in the document.
                splitter = SentenceSplitter(language=lang)
                sentences = splitter.split(text=full_doc_text)
                if not sentences:
                    sentences = [full_doc_text]
                # Batch the sentence encoding to avoid memory issues.
                sentence_embeddings = tracer_model.encode(sentences, batch_size=64, show_progress_bar=False, normalize_embeddings=True)
                cos_scores = util.pytorch_cos_sim(query_embedding, sentence_embeddings)[0]
                best_sentence_idx = torch.argmax(cos_scores).item()
                most_similar_sentence = sentences[best_sentence_idx]
                results.append({
                    'id': doc_id,
                    'file': doc_info['file'],
                    'source': doc_info['source'],
                    'text': full_doc_text,
                    'similarity': float(similarities[0][i]),
                    'highlight_sentence': str(most_similar_sentence)
                })
            except (IOError, KeyError) as e:
                print(f"Could not retrieve full text for doc {doc_id}: {e}")
                continue
    return results

# --- Qwen API for Explanations ---

# The cache decorator is an assumption, restored from the "_cached_" naming
# convention used for this and the other API helpers below.
@st.cache_data(show_spinner=False)
def _cached_explain_heatmap(api_config, img_base64, csv_text, structured_prompt):
    # Makes a cached API call to Qwen to get an explanation for a heatmap.
    # Note: csv_text is not used in the body; it only contributes to the cache key.
    headers = {
        "Authorization": f"Bearer {api_config['api_key']}",
        "Content-Type": "application/json"
    }
    content = [{"type": "text", "text": structured_prompt}]
    if img_base64:
        content.append({
            "type": "image_url",
            "image_url": {
                "url": f"data:image/png;base64,{img_base64}"
            }
        })
    data = {
        "model": api_config["model"],
        "messages": [
            {
                "role": "user",
                "content": content
            }
        ],
        "max_tokens": 1200,
        "temperature": 0.2,
        "top_p": 0.95,
        "seed": 42
    }
    response = requests.post(
        f"{api_config['api_endpoint']}/chat/completions",
        headers=headers,
        json=data,
        timeout=300
    )
    # Raise an exception if the API call fails.
    response.raise_for_status()
    result = response.json()
    return result["choices"][0]["message"]["content"]

# The cache decorator is an assumption: the underscore-prefixed parameters follow
# Streamlit's convention for arguments that are excluded from cache hashing.
@st.cache_data(show_spinner=False)
def generate_all_attribution_analyses(_attribution_models, _tokenizer, _base_model, _device, prompt, max_tokens, force_exact_num_tokens=False):
    # Generates text and runs attribution analysis for all methods.
    # Generate the text first.
    inputs = _tokenizer(prompt, return_tensors="pt").to(_device)
    generation_args = {
        'max_new_tokens': max_tokens,
        'do_sample': False
    }
    if force_exact_num_tokens:
        generation_args['min_new_tokens'] = max_tokens
    generated_ids = _base_model.generate(
        inputs.input_ids,
        **generation_args
    )
    generated_text = _tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # Run attribution analysis for all methods.
    all_attributions = {}
    methods = ["integrated_gradients", "occlusion", "saliency"]
    for method in methods:
        attributions = _attribution_models[method].attribute(
            input_texts=prompt,
            generated_texts=generated_text
        )
        all_attributions[method] = attributions
    return generated_text, all_attributions

def explain_heatmap_with_csv_data(api_config, image_buffer, csv_data, context_prompt, generated_text, method_name="Attribution"):
    # Generates an explanation for a heatmap using the Qwen API.
    try:
        # Convert the image to base64.
        img_base64 = None
        if image_buffer:
            image_buffer.seek(0)
            image = Image.open(image_buffer)
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode()
        # Clean the dataframe to handle duplicate labels.
        df_clean = csv_data.copy()
        cols = pd.Series(df_clean.columns)
        if cols.duplicated().any():
            for dup in cols[cols.duplicated()].unique():
                dup_indices = cols[cols == dup].index.values
                new_names = [f"{dup} ({i+1})" for i in range(len(dup_indices))]
                cols[dup_indices] = new_names
            df_clean.columns = cols
        if df_clean.index.has_duplicates:
            counts = {}
            new_index = list(df_clean.index)
            duplicated_indices = df_clean.index[df_clean.index.duplicated(keep=False)]
            for i, idx in enumerate(df_clean.index):
                if idx in duplicated_indices:
                    counts[idx] = counts.get(idx, 0) + 1
                    new_index[i] = f"{idx} ({counts[idx]})"
            df_clean.index = new_index
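        # Example of the renaming above (illustrative): duplicate columns
        # ["the", "cat", "the"] become ["the (1)", "cat", "the (2)"], so every
        # row and column can be addressed unambiguously in the summaries below.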
        # --- Rule-Based Analysis ---
        # Flatten the score matrix; unstack() yields a (generated, input)
        # MultiIndex, which is relabelled "input -> output" below.
        unstacked = df_clean.unstack()
        unstacked.index = unstacked.index.map('{0[1]} -> {0[0]}'.format)
        # Get the top 5 individual scores.
        top_5_individual = unstacked.abs().nlargest(5).sort_index()
        top_individual_text_lines = ["\n### Top 5 Strongest Individual Connections:"]
        for label in top_5_individual.index:
            score = unstacked[label]
            top_individual_text_lines.append(f"- **{label}**: score {score:.2f}")
        # Get the top 5 average input scores.
        avg_input_scores = df_clean.mean(axis=1)
        top_5_average = avg_input_scores.abs().nlargest(5).sort_index()
        top_average_text_lines = ["\n### Top 5 Most Influential Input Tokens (on average over the whole generation):"]
        for input_token in top_5_average.index:
            score = avg_input_scores[input_token]
            top_average_text_lines.append(f"- **'{input_token}'**: average score {score:.2f}")
        # Get the top output token sources.
        top_output_text_lines = []
        if not df_clean.empty:
            avg_output_scores = df_clean.mean(axis=0)
            top_3_output = avg_output_scores.abs().nlargest(min(3, len(df_clean.columns))).sort_index()
            if not top_3_output.empty:
                top_output_text_lines.append("\n### Top 3 Most Influenced Generated Tokens:")
                for output_token in top_3_output.index:
                    # Find which input tokens influenced this output token the most.
                    top_sources_for_output = df_clean[output_token].abs().nlargest(min(2, len(df_clean.index))).sort_index().index.tolist()
                    if top_sources_for_output:
                        top_output_text_lines.append(f"- **'{output_token}'** was most influenced by **'{', '.join(top_sources_for_output)}'**.")
        data_text_for_llm = "\n".join(top_individual_text_lines + top_average_text_lines + top_output_text_lines)
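        # data_text_for_llm now looks like this (illustrative values only):
        #   ### Top 5 Strongest Individual Connections:
        #   - **France -> Paris**: score 0.82
        #   ...
        #   ### Top 5 Most Influential Input Tokens (on average over the whole generation):
        #   - **'France'**: average score 0.41
        #   ...
        #   ### Top 3 Most Influenced Generated Tokens:
        #   - **'Paris'** was most influenced by **'France, of'**.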
        # Get method-specific context from the translation files.
        desc_key = METHOD_DESC_KEYS.get(method_name, "unsupported_method_desc")
        method_context = tr(desc_key)
        # Format the instruction for the LLM.
        instruction_p1 = tr('instruction_part_1_desc').format(method_name=method_name.replace('_', ' ').title())
        # Create the prompt for the LLM.
        structured_prompt = f"""{tr('ai_expert_intro')}
## {tr('analysis_details')}
- **{tr('method_being_used')}** {method_name.replace('_', ' ').title()}
- **{tr('prompt_analyzed')}** "{context_prompt}"
- **{tr('full_generated_text')}** "{generated_text}"
## {tr('method_specific_context')}
{method_context}
## {tr('instructions_for_analysis')}
{tr('instruction_part_1_header')}
{instruction_p1}
{tr('instruction_synthesis_header')}
{tr('instruction_synthesis_desc')}
{tr('instruction_color_coding')}
## {tr('data_section_header')}
{data_text_for_llm}
{tr('begin_analysis_now')}"""
        # Call the cached function to get the explanation.
        explanation = _cached_explain_heatmap(api_config, img_base64, data_text_for_llm, structured_prompt)
        return explanation
    except Exception as e:
        # Catch errors from data prep or the API call.
        st.error(f"Error generating AI explanation: {str(e)}")
        return tr("unable_to_generate_explanation")

# --- Faithfulness Verification ---

# The cache decorator is an assumption, restored from the "_cached_" naming.
@st.cache_data(show_spinner=False)
def _cached_extract_claims_from_explanation(api_config, explanation_text, analysis_method):
    # Makes a cached API call to Qwen to extract claims from an explanation.
    headers = {"Authorization": f"Bearer {api_config['api_key']}", "Content-Type": "application/json"}
    # Dynamically set claim types based on the analysis method.
    claim_types_details = tr("claim_extraction_prompt_types_details")
    claim_extraction_prompt = f"""{tr('claim_extraction_prompt_header')}
{tr('claim_extraction_prompt_instruction')}
{tr('claim_extraction_prompt_context_header').format(analysis_method=analysis_method, context=analysis_method)}
{tr('claim_extraction_prompt_types_header')}
{claim_types_details}
{tr('claim_extraction_prompt_example_header')}
{tr('claim_extraction_prompt_example_explanation')}
{tr('claim_extraction_prompt_example_json')}
{tr('claim_extraction_prompt_analyze_header')}
"{explanation_text}"
{tr('claim_extraction_prompt_instruction_footer')}
"""
    data = {
        "model": api_config["model"],
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": claim_extraction_prompt}]
            }
        ],
        "max_tokens": 1500,
        "temperature": 0.0,  # Set to 0 for deterministic output.
        "seed": 42
    }
    response = requests.post(
        f"{api_config['api_endpoint']}/chat/completions",
        headers=headers,
        json=data,
        timeout=300
    )
    response.raise_for_status()
    claims_text = response.json()["choices"][0]["message"]["content"]
    try:
        # The response may be wrapped in a markdown code block, so extract it first.
        if '```json' in claims_text:
            claims_text = re.search(r'```json\n(.*?)\n```', claims_text, re.DOTALL).group(1)
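        # The pattern unwraps responses of the form (illustrative):
        #     ```json
        #     [{"claim_type": "...", "claim_text": "...", "details": {...}}]
        #     ```
        # leaving just the JSON array for json.loads below. The exact schema is
        # set by the translated prompt; verify_claims only reads back the
        # claim_type, claim_text, and details keys.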
        # Parse the JSON string into a Python list.
        return json.loads(claims_text)
    except (AttributeError, json.JSONDecodeError):
        return []

# The cache decorator is an assumption, restored from the "_cached_" naming.
@st.cache_data(show_spinner=False)
def _cached_verify_token_justification(api_config, analysis_method, input_prompt, generated_text, token, justification):
    # Uses an LLM to verify whether a justification for a token's importance is sound.
    headers = {"Authorization": f"Bearer {api_config['api_key']}", "Content-Type": "application/json"}
    verification_prompt = f"""{tr('justification_verification_prompt_header')}
{tr('justification_verification_prompt_crucial_rule')}
{tr('justification_verification_prompt_token_location')}
{tr('justification_verification_prompt_special_tokens')}
{tr('justification_verification_prompt_evaluating_justifications')}
{tr('justification_verification_prompt_linguistic_context')}
{tr('justification_verification_prompt_collective_reasoning')}
**Analysis Method:** {analysis_method}
**Input Prompt:** "{input_prompt}"
**Generated Text:** "{generated_text}"
**Token in Question:** "{token}"
**Provided Justification:** "{justification}"
{tr('justification_verification_prompt_task_header')}
{tr('justification_verification_prompt_task_instruction')}
{tr('justification_verification_prompt_json_instruction')}
{tr('justification_verification_prompt_footer')}
"""
    data = {
        # Note: this call pins a specific model instead of using api_config["model"].
        "model": "qwen2.5-vl-72b-instruct",
        "messages": [{"role": "user", "content": verification_prompt}],
        "max_tokens": 400,
        "temperature": 0.0,
        "seed": 42,
        "response_format": {"type": "json_object"}
    }
    response = requests.post(
        f"{api_config['api_endpoint']}/chat/completions",
        headers=headers,
        json=data,
        timeout=300
    )
    response.raise_for_status()
    try:
        result_json = response.json()["choices"][0]["message"]["content"]
        return json.loads(result_json)
    except (json.JSONDecodeError, KeyError):
        return {"is_verified": False, "reasoning": "Could not parse the semantic justification result."}

def verify_claims(claims, analysis_data):
    # Verifies the extracted claims against the analysis data.
    # Per-claim scores are computed below from analysis_data['scores_df'],
    # whose rows are input tokens and whose columns are generated tokens.
    verification_results = []
    for claim in claims:
        is_verified = False
        evidence = "Could not be verified."
        details = claim.get('details', {})
        claim_type = claim.get('claim_type')
        try:
            # Clean tokens in the claim's details, as the LLM sometimes includes extra quotes.
            if 'token' in details and isinstance(details['token'], str):
                details['token'] = re.sub(r"^\s*['\"]|['\"]\s*$", '', details['token']).strip()
            if 'tokens' in details and isinstance(details['tokens'], list):
                details['tokens'] = [re.sub(r"^\s*['\"]|['\"]\s*$", '', t).strip() for t in details['tokens']]
            if claim_type == 'attribution_claim':
                tokens_claimed = details.get('tokens', [])
                qualifier = details.get('qualifier', 'significant')  # Default to the lower bar.
                score_type = details.get('score_type', 'peak')
                # Calculate the relevant scores based on the claim's score_type.
                if score_type == 'average':
                    score_series = analysis_data['scores_df'].abs().mean(axis=1)
                    score_name = "average score"
                else:  # peak
                    # Check both influence GIVEN (as input) and RECEIVED (as output);
                    # .get(t, 0.0) handles tokens that are missing from one axis.
                    input_peaks = analysis_data['scores_df'].abs().max(axis=1)
                    output_peaks = analysis_data['scores_df'].abs().max(axis=0)
                    combined_scores = {}
                    all_tokens = set(input_peaks.index) | set(output_peaks.index)
                    for t in all_tokens:
                        s1 = input_peaks.get(t, 0.0)
                        s2 = output_peaks.get(t, 0.0)
                        combined_scores[t] = max(s1, s2)
                    score_series = pd.Series(combined_scores)
                    score_name = "peak score"
                if score_series.empty:
                    evidence = "No attribution data available to verify claim."
                else:
                    all_attributions = sorted(
                        [{'token': token, 'attribution': score} for token, score in score_series.items()],
                        key=lambda x: x['attribution'],
                        reverse=True
                    )
                    max_score = all_attributions[0]['attribution'] if all_attributions else 0
                    if qualifier == 'high':
                        threshold = 0.70 * max_score
                        threshold_name = "high"
                    else:  # 'significant' or default
                        threshold = 0.50 * max_score
                        threshold_name = "significant"
                    token_scores_dict = {item['token'].lower().strip(): item['attribution'] for item in all_attributions}
                    unverified_tokens = []
                    verified_tokens_details = []
                    for token in tokens_claimed:
                        # Robust matching logic: first check for a direct match
                        # for specific claims like ', (1)'.
                        token_lower = token.lower().strip()
                        if token_lower in token_scores_dict:
                            matching_keys = [token_lower]
                        else:
                            # If no direct match, fall back to a generic search for claims like ','.
                            # This finds all instances: ', (1)', ', (2)', etc.
                            matching_keys = [
                                k for k in token_scores_dict.keys()
                                if re.sub(r'\s\(\d+\)$', '', k).strip() == token_lower
                            ]
                        if not matching_keys:
                            unverified_tokens.append(f"'{token}' (not found in analysis)")
                            continue
                        # Check each matching instance against the threshold.
                        for key in matching_keys:
                            actual_score = token_scores_dict.get(key)
                            if abs(actual_score) < threshold:
                                unverified_tokens.append(f"'{key}' ({score_name}: {abs(actual_score):.2f})")
                            else:
                                verified_tokens_details.append(f"'{key}' ({score_name}: {abs(actual_score):.2f})")
                    is_verified = not unverified_tokens
                    if is_verified:
                        evidence = f"Verified. All claimed tokens passed the {threshold_name} threshold (> {threshold:.2f}). Details: {', '.join(verified_tokens_details)}."
                    else:
                        fail_reason = f"the following did not meet the {threshold_name} threshold (> {threshold:.2f}): {', '.join(unverified_tokens)}"
                        if verified_tokens_details:
                            evidence = f"While some tokens passed ({', '.join(verified_tokens_details)}), {fail_reason}."
                        else:
                            evidence = f"The following did not meet the {threshold_name} threshold (> {threshold:.2f}): {', '.join(unverified_tokens)}."
            elif claim_type in ['token_justification_claim', 'token_begruendung_anspruch']:
                token_val = details.get('token') or details.get('tokens')
                if isinstance(token_val, list):
                    token = ", ".join(map(str, token_val))
                else:
                    token = token_val
                justification = details.get('justification') or details.get('begruendung')
                input_prompt = analysis_data.get('prompt', '')
                generated_text = analysis_data.get('generated_text', '')
                if not all([token, justification, input_prompt, generated_text]):
                    evidence = "Missing data for justification verification (token, justification, or prompt)."
                else:
                    api_config = init_qwen_api()
                    if api_config:
                        verification = _cached_verify_token_justification(api_config, analysis_data['method'], input_prompt, generated_text, token, justification)
                        is_verified = verification.get('is_verified', False)
                        evidence = verification.get('reasoning', "Failed to get semantic reasoning for justification.")
                    else:
                        is_verified = False
                        evidence = "API key not configured for semantic verification."
        except Exception as e:
            evidence = f"An error occurred during verification: {str(e)}"
        verification_results.append({
            'claim_text': claim.get('claim_text', 'N/A'),
            'verified': is_verified,
            'evidence': evidence
        })
    return verification_results

# --- End Faithfulness Verification ---

def create_heatmap_visualization(attributions, method_name="Attribution"):
    # Creates a heatmap visualization from attribution scores.
    try:
        # Get the HTML content from the attributions.
        html_content = attributions.show(display=False, return_html=True)
        if not html_content:
            st.error(tr("error_inseq_no_html").format(method_name=method_name))
            return None, None, None, None
        # Parse the HTML to extract the data table.
        soup = BeautifulSoup(html_content, 'html.parser')
        table = soup.find('table')
        if not table:
            st.error(tr("error_no_table_in_html").format(method_name=method_name))
            return None, None, None, None
        # A structured approach to parsing the HTML: headers first.
        header_row_element = table.find('thead')
        if header_row_element:
            headers = [th.get_text(strip=True) for th in header_row_element.find_all('th')[1:]]
        else:
            # Fallback if no <thead> is found.
            first_row = table.find('tr')
            if not first_row:
                st.error(tr("error_table_no_rows").format(method_name=method_name))
                return None, None, None, None
            headers = [th.get_text(strip=True) for th in first_row.find_all('th')[1:]]
        data_rows = []
        row_labels = []
        # Find all <tbody> elements and iterate through their rows.
        table_bodies = table.find_all('tbody')
        if not table_bodies:
            # Fallback if no <tbody> is found.
            all_trs = table.find_all('tr')
            data_trs = all_trs[1:] if len(all_trs) > 1 else []
        else:
            data_trs = []
            for tbody in table_bodies:
                data_trs.extend(tbody.find_all('tr'))
        for tr_element in data_trs:
            all_cells = tr_element.find_all(['th', 'td'])
            if not all_cells or len(all_cells) <= 1:
                continue
            row_labels.append(all_cells[0].get_text(strip=True))
            # Convert text values to float, handling empty strings as 0.
            row_data = []
            for cell in all_cells[1:]:
                text_val = cell.get_text(strip=True)
                # Remove non-breaking spaces.
                clean_text = text_val.replace('\xa0', '').strip()
                if clean_text:
                    try:
                        row_data.append(float(clean_text))
                    except ValueError:
                        # Default to 0 if conversion fails.
                        row_data.append(0.0)
                else:
                    row_data.append(0.0)
            data_rows.append(row_data)
        # Create the dataframe from the parsed data.
        if not data_rows or not data_rows[0]:
            st.error(tr("error_failed_to_parse_rows").format(method_name=method_name))
            return None, None, None, None
        # --- Make token labels unique for duplicates ---
        def make_labels_unique(labels):
            counts = {}
            new_labels = []
            # First, count all occurrences to decide which ones need numbering.
            label_counts = {label: labels.count(label) for label in set(labels)}
            for label in labels:
                if label_counts[label] > 1:
                    counts[label] = counts.get(label, 0) + 1
                    new_labels.append(f"{label} ({counts[label]})")
                else:
                    new_labels.append(label)
            return new_labels
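        # e.g. make_labels_unique(["the", "cat", "the"]) -> ["the (1)", "cat", "the (2)"],
        # so the DataFrame index and columns built below stay unambiguous.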
        unique_row_labels = make_labels_unique(row_labels)
        unique_headers = make_labels_unique(headers)
        parsed_df = pd.DataFrame(data_rows, index=unique_row_labels, columns=unique_headers)
        attribution_scores = parsed_df.values
        # Clean tokens for display.
        clean_headers = parsed_df.columns.tolist()
        clean_row_labels = parsed_df.index.tolist()
        # Use numerical indices for the heatmap axes to handle duplicate labels.
        x_indices = list(range(len(clean_headers)))
        y_indices = list(range(len(clean_row_labels)))
        # Prepare custom data for hover labels.
        custom_data = np.empty(attribution_scores.shape, dtype=object)
        for i in range(len(clean_row_labels)):
            for j in range(len(clean_headers)):
                custom_data[i, j] = (clean_row_labels[i], clean_headers[j])
        fig = go.Figure(data=go.Heatmap(
            z=attribution_scores,
            x=x_indices,
            y=y_indices,
            customdata=custom_data,
            hovertemplate="Input: %{customdata[0]}<br>Generated: %{customdata[1]}<br>Score: %{z:.4f}<extra></extra>",
            colorscale='Plasma',
            hoverongaps=False,
        ))
        fig.update_layout(
            title=tr('heatmap_title').format(method_name=method_name),
            xaxis_title=tr('heatmap_xaxis'),
            yaxis_title=tr('heatmap_yaxis'),
            xaxis=dict(
                tickmode='array',
                tickvals=x_indices,
                ticktext=clean_headers,
                tickangle=45
            ),
            yaxis=dict(
                tickmode='array',
                tickvals=y_indices,
                ticktext=clean_row_labels,
                autorange='reversed'
            ),
            height=max(400, len(clean_row_labels) * 30),
            width=max(600, len(clean_headers) * 50)
        )
        # Save the plot to a buffer.
        buffer = BytesIO()
        try:
            fig.write_image(buffer, format='png', scale=2)
            buffer.seek(0)
        except Exception as e:
            print(f"Warning: Could not generate static image (Kaleido error?): {e}")
            buffer = None
        return fig, html_content, buffer, parsed_df
    except Exception as e:
        st.error(tr("error_creating_heatmap").format(e=str(e)))
        return None, None, None, None

def start_new_analysis(prompt, max_tokens, enable_explanations):
    # Clears old results and starts a new analysis.
    # Clear old results from the session state.
    keys_to_clear = [
        'generated_text',
        'all_attributions'
    ]
    for key in keys_to_clear:
        if key in st.session_state:
            del st.session_state[key]
    # Clear any old cached items.
    for key in list(st.session_state.keys()):
        if key.startswith('influential_docs_'):
            del st.session_state[key]
    # Update the text area with the new prompt.
    st.session_state.attr_prompt = prompt
    # Set parameters for the new analysis.
    st.session_state.run_request = {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "enable_explanations": enable_explanations
    }

def update_cache_with_explanation(prompt, method_name, explanation):
    cache_file = os.path.join("cache", "cached_attribution_results.json")
    if not os.path.exists(cache_file):
        return
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            cached_data = json.load(f)
        if prompt in cached_data:
            if "explanations" not in cached_data[prompt]:
                cached_data[prompt]["explanations"] = {}
            cached_data[prompt]["explanations"][method_name] = explanation
            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump(cached_data, f, ensure_ascii=False, indent=4)
            print(f"Saved explanation for {method_name} to cache.")
    except Exception as e:
        print(f"Failed to update cache with explanation: {e}")


def update_cache_with_faithfulness(prompt, method_name, verification_results):
    cache_file = os.path.join("cache", "cached_attribution_results.json")
    if not os.path.exists(cache_file):
        return
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            cached_data = json.load(f)
        if prompt in cached_data:
            if "faithfulness" not in cached_data[prompt]:
                cached_data[prompt]["faithfulness"] = {}
            cached_data[prompt]["faithfulness"][method_name] = verification_results
            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump(cached_data, f, ensure_ascii=False, indent=4)
            print(f"Saved faithfulness for {method_name} to cache.")
    except Exception as e:
        print(f"Failed to update cache with faithfulness: {e}")

def run_analysis(prompt, max_tokens, enable_explanations, force_exact_num_tokens=False):
    # Runs the full analysis pipeline.
    if not prompt.strip():
        st.warning(tr('please_enter_prompt_warning'))
        return
    # Check for cached results first.
    cache_file = os.path.join("cache", "cached_attribution_results.json")
    if os.path.exists(cache_file):
        with open(cache_file, "r", encoding="utf-8") as f:
            cached_data = json.load(f)
        if prompt in cached_data:
            print("Loading full attribution analysis from cache.")
            cached_result = cached_data[prompt]
            # If influential_docs are missing, update the cache when possible.
            if "influential_docs" not in cached_result:
                try:
                    print(f"Updating cache for '{prompt}' with missing influence docs...")
                    lang = st.session_state.get('lang', 'en')
                    # This call should hit the Streamlit cache and be fast.
                    missing_docs = get_influential_docs(prompt, lang)
                    if missing_docs:
                        cached_result["influential_docs"] = missing_docs
                        # Save the updated cache back to the file.
                        with open(cache_file, "w", encoding="utf-8") as f:
                            json.dump(cached_data, f, ensure_ascii=False, indent=4)
                        print("Cache updated successfully.")
                except Exception as e:
                    print(f"Could not update cache with influence docs: {e}")
            # Populate session state from the comprehensive cache.
            st.session_state.generated_text = cached_result["generated_text"]
            st.session_state.prompt = prompt
            st.session_state.enable_explanations = enable_explanations
            st.session_state.qwen_api_config = init_qwen_api() if enable_explanations else None
            # Reconstruct attribution objects and store explanations/faithfulness.
            reconstructed_attributions = {}
            for method, data in cached_result["html_contents"].items():
                reconstructed_attributions[method] = CachedAttribution(data)
                # Use a consistent key for caching in session state.
                cache_key_base = f"{method}_{cached_result['generated_text']}"
                # Legacy cache format: per-method entries that bundle the HTML
                # with its explanation and faithfulness results.
                if "explanation" in data:
                    st.session_state[f"explanation_{cache_key_base}"] = data["explanation"]
                if "faithfulness_results" in data:
                    st.session_state[f"faithfulness_check_{cache_key_base}"] = data["faithfulness_results"]
                # New structured cache format: top-level "explanations" and
                # "faithfulness" maps keyed by method.
                if "explanations" in cached_result and method in cached_result["explanations"]:
                    st.session_state[f"explanation_{cache_key_base}"] = cached_result["explanations"][method]
                if "faithfulness" in cached_result and method in cached_result["faithfulness"]:
                    st.session_state[f"faithfulness_check_{cache_key_base}"] = cached_result["faithfulness"][method]
            st.session_state.all_attributions = reconstructed_attributions
            # Store the influential docs.
            if "influential_docs" in cached_result:
                # Use a key that the UI part can check for.
                st.session_state.cached_influential_docs = cached_result["influential_docs"]
            st.success(tr('analysis_complete_success'))
            return
    # If not in cache, check whether the models exist before trying to load them.
    model_path = "./models/OLMo-2-1124-7B"
    if not os.path.exists(model_path):
        st.info("This live demo is running in a static environment. Only the pre-cached example prompts are available. Please select an example to view its analysis.")
        return
    # Load the models.
    with st.spinner(tr('loading_models_spinner')):
        attribution_models, tokenizer, base_model, device = load_all_attribution_models()
    if not attribution_models:
        st.error(tr('failed_to_load_models_error'))
        return
    st.session_state.qwen_api_config = init_qwen_api() if enable_explanations else None
    st.session_state.enable_explanations = enable_explanations
    st.session_state.prompt = prompt
    # Generate text and attributions.
    with st.spinner(tr('running_attribution_analysis_spinner')):
        try:
            generated_text, all_attributions = generate_all_attribution_analyses(
                attribution_models,
                tokenizer,
                base_model,
                device,
                prompt,
                max_tokens,
                force_exact_num_tokens=force_exact_num_tokens
            )
        except Exception as e:
            st.error(f"Error in attribution analysis: {str(e)}")
            # Signal failure to the rest of the function.
            generated_text, all_attributions = None, None
    if not generated_text or not all_attributions:
        st.error(tr('failed_to_generate_analysis_error'))
        return
    # Store the results in the session state.
    st.session_state.generated_text = generated_text
    st.session_state.all_attributions = all_attributions
    # --- Save the new result back to the cache ---
    try:
        cache_file = os.path.join("cache", "cached_attribution_results.json")
        os.makedirs("cache", exist_ok=True)
        # Load the existing cache or create a new one.
        if os.path.exists(cache_file):
            with open(cache_file, "r", encoding="utf-8") as f:
                cached_data = json.load(f)
        else:
            cached_data = {}
        # Add the new result.
        html_contents = {method: attr.show(display=False, return_html=True) for method, attr in all_attributions.items()}
        # Also fetch influential docs so they are cached alongside.
        lang = st.session_state.get('lang', 'en')
        docs_to_cache = get_influential_docs(prompt, lang)
        cached_data[prompt] = {
            "generated_text": generated_text,
            "html_contents": html_contents,
            "influential_docs": docs_to_cache
        }
        # Write the cache back to the file.
        with open(cache_file, "w", encoding="utf-8") as f:
            json.dump(cached_data, f, ensure_ascii=False, indent=4)
        print(f"Saved new analysis for '{prompt}' to cache.")
    except Exception as e:
        print(f"Warning: Could not save result to cache file. {e}")
    # Clean up the models to free memory.
    del attribution_models
    del tokenizer
    del base_model
    gc.collect()
    if device == 'mps':
        torch.mps.empty_cache()
    elif device == 'cuda':
        torch.cuda.empty_cache()
    st.success(tr('analysis_complete_success'))

def show_attribution_analysis():
    # Shows the main attribution analysis page.
    # Add some CSS for icons.
    st.markdown('<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap-icons@1.10.5/font/bootstrap-icons.css">', unsafe_allow_html=True)
    st.markdown(f"<h1>{tr('attr_page_title')}</h1>", unsafe_allow_html=True)
    st.markdown(f"{tr('attr_page_desc')}", unsafe_allow_html=True)
    # Check whether the user has requested a new analysis.
    if 'run_request' in st.session_state:
        request = st.session_state.pop('run_request')
        run_analysis(
            prompt=request['prompt'],
            max_tokens=request['max_tokens'],
            enable_explanations=request['enable_explanations']
        )
    # Set up the main layout.
    col1, col2 = st.columns([1, 1])
    with col1:
        st.markdown(f"<h2>{tr('input_header')}</h2>", unsafe_allow_html=True)
        # Get the current language from the session state.
        lang = st.session_state.get('lang', 'en')
        # Example prompts for English and German.
        example_prompts = {
            'en': [
                "The capital of France is",
                "The first person to walk on the moon was",
                "To be or not to be, that is the",
                "Once upon a time, in a land far, far away,",
                "The chemical formula for water is",
                "A stitch in time saves",
                "The opposite of hot is",
                "The main ingredients of a pizza are",
                "She opened the door and saw"
            ],
            'de': [
                "Die Hauptstadt von Frankreich ist",
                "Die erste Person auf dem Mond war",
                "Sein oder Nichtsein, das ist hier die",
                "Es war einmal, in einem weit, weit entfernten Land,",
                "Die chemische Formel für Wasser ist",
                "Was du heute kannst besorgen, das verschiebe nicht auf",
                "Das Gegenteil von heiß ist",
                "Die Hauptzutaten einer Pizza sind",
                "Sie öffnete die Tür und sah"
            ]
        }
        st.markdown('**<i class="bi bi-lightbulb"></i> Example Prompts:**', unsafe_allow_html=True)
        cols = st.columns(3)
        for i, example in enumerate(example_prompts[lang][:9]):
            with cols[i % 3]:
                st.button(
                    example,
                    key=f"example_{i}",
                    use_container_width=True,
                    on_click=start_new_analysis,
                    args=(example, 10, st.session_state.get('enable_explanations', True))
                )
        # Text input area for the user's prompt.
        prompt = st.text_area(
            tr('enter_prompt'),
            value=st.session_state.get('attr_prompt', ""),
            height=100,
            help=tr('enter_prompt_help'),
            placeholder="Sadly no GPU available. Please select an example above.",
            disabled=True
        )
        # Slider for the number of tokens to generate.
        max_tokens = st.slider(
            tr('max_new_tokens_slider'),
            min_value=1,
            max_value=50,
            value=5,
            help=tr('max_new_tokens_slider_help'),
            disabled=True
        )
        # Checkbox to enable or disable AI explanations.
        enable_explanations = st.checkbox(
            tr('enable_ai_explanations'),
            value=True,
            help=tr('enable_ai_explanations_help')
        )
        # Button to start the analysis.
        st.button(
            tr('generate_and_analyze_button'),
            type="primary",
            on_click=start_new_analysis,
            args=(prompt, max_tokens, enable_explanations),
            disabled=True
        )
    with col2:
        st.markdown(f"<h2>{tr('output_header')}</h2>", unsafe_allow_html=True)
        if hasattr(st.session_state, 'generated_text'):
            st.subheader(tr('generated_text_subheader'))
            # Extract the generated part of the text.
            prompt_part = st.session_state.prompt
            full_text = st.session_state.generated_text
            generated_part = full_text
            if full_text.startswith(prompt_part):
                generated_part = full_text[len(prompt_part):].lstrip()
            else:
                # A fallback in case tokenization changed the prompt slightly.
                generated_part = full_text.replace(prompt_part, "", 1).strip()
            # Clean up the generated text for display.
            cleaned_generated_part = re.sub(r'\n{2,}', '\n', generated_part).strip()
            escaped_generated = html.escape(cleaned_generated_part)
            escaped_prompt = html.escape(prompt_part)
            st.markdown(f"""
            <div style="background-color: #2b2b2b; color: #ffffff; padding: 1.2rem; border-radius: 10px; margin: 1rem 0; border: 1px solid #444;">
                <strong>{tr('input_label')}</strong> <span style="color: #60a5fa;">{escaped_prompt}</span><br>
                <strong>{tr('generated_label')}</strong> <span style="font-weight: bold; color: #fca5a5; white-space: pre-wrap;">{escaped_generated}</span>
            </div>
            """, unsafe_allow_html=True)
    # Display the visualizations for each method.
    if hasattr(st.session_state, 'all_attributions'):
        st.header(tr('attribution_analysis_results_header'))
        # Create tabs for each analysis method.
        tab_titles = [
            tr('saliency_tab'),
            tr('attr_tab'),
            tr('occlusion_tab')
        ]
        tabs = st.tabs(tab_titles)
        # Define the order of the methods in the tabs.
        methods = {
            "saliency": {
                "tab": tabs[0],
                "title": tr('saliency_title'),
                "description": tr('saliency_viz_desc')
            },
            "integrated_gradients": {
                "tab": tabs[1],
                "title": tr('attr_title'),
                "description": tr('attr_viz_desc')
            },
            "occlusion": {
                "tab": tabs[2],
                "title": tr('occlusion_title'),
                "description": tr('occlusion_viz_desc')
            }
        }
        # Generate and display the visualization for each method.
        for method_name, method_info in methods.items():
            with method_info["tab"]:
                st.subheader(f"{method_info['title']} Analysis")
                # Generate the heatmap.
                with st.spinner(tr('creating_viz_spinner').format(method_title=method_info['title'])):
                    heatmap_fig, html_content, heatmap_buffer, scores_df = create_heatmap_visualization(
                        st.session_state.all_attributions[method_name],
                        method_name=method_info['title']
                    )
                if heatmap_fig:
                    st.plotly_chart(heatmap_fig, use_container_width=True)
                    # Add an explanation of how to read the heatmap.
                    explanation_html = f"""
                    <div style="background-color: #0E1117; border-radius: 10px; padding: 15px; margin: 10px 0; border: 1px solid #262730;">
                        <h4 style="color: #FAFAFA; margin-bottom: 10px;">{tr('how_to_read_heatmap')}</h4>
                        <ul style="color: #DCDCDC; margin-left: 20px; padding-left: 0;">
                            <li style="margin-bottom: 5px;"><strong>{tr('xaxis_label')}:</strong> {tr('xaxis_desc')}</li>
                            <li style="margin-bottom: 5px;"><strong>{tr('yaxis_label')}:</strong> {tr('yaxis_desc')}</li>
                            <li style="margin-bottom: 5px;"><strong>{tr('color_intensity_label')}:</strong> {tr('color_intensity_desc')}</li>
                            <li style="margin-bottom: 5px;"><strong>{tr('interpretation_label')}:</strong> {tr('interpretation_desc')}</li>
                            <li style="margin-bottom: 5px;"><strong>{tr('special_tokens_label')}:</strong> {tr('special_tokens_desc')}</li>
                        </ul>
                    </div>
                    """
                    st.markdown(explanation_html, unsafe_allow_html=True)
                # Generate an AI explanation for the heatmap.
                if (st.session_state.get('enable_explanations') and
                        st.session_state.get('qwen_api_config') and
                        heatmap_buffer is not None and scores_df is not None):
                    explanation_cache_key = f"explanation_{method_name}_{st.session_state.generated_text}"
                    # Get the explanation from the cache or generate it.
                    if explanation_cache_key not in st.session_state:
                        with st.spinner(tr('generating_ai_explanations_spinner').format(method_title=method_info['title'])):
                            explanation = explain_heatmap_with_csv_data(
                                st.session_state.qwen_api_config,
                                heatmap_buffer,
                                scores_df,
                                st.session_state.prompt,
                                st.session_state.generated_text,
                                method_name
                            )
                            st.session_state[explanation_cache_key] = explanation
                            # Update the cache file.
                            update_cache_with_explanation(st.session_state.prompt, method_name, explanation)
                    explanation = st.session_state.get(explanation_cache_key)
                    if explanation and not explanation.startswith("Error:"):
                        simple_desc = tr(METHOD_DESC_KEYS.get(method_name, "unsupported_method_desc"))
                        st.markdown(f"#### {tr('what_this_method_shows')}")
                        st.markdown(f"""
                        <div style="background-color: #2f3f70; color: #f5f7fb; padding: 1.2rem; border-radius: 12px; margin-bottom: 1rem; box-shadow: 0 12px 24px rgba(47, 63, 112, 0.35);">
                            <p style='font-size: 1.05em; font-weight: 500; margin:0; color: #f5f7fb;'>{simple_desc}</p>
                        </div>
                        """, unsafe_allow_html=True)
                        html_explanation = markdown.markdown(explanation)
                        st.markdown(f"#### {tr('ai_generated_analysis')}")
                        st.markdown(f"""
                        <div style="background-color: #2b2b2b; color: #ffffff; padding: 1.2rem; border-radius: 10px; border-left: 4px solid #dcae36; font-size: 0.9rem; margin-bottom: 1rem;">
                            {html_explanation}
                        </div>
                        """, unsafe_allow_html=True)
                        # Faithfulness check expander.
                        with st.expander(tr('faithfulness_check_expander')):
                            st.markdown(tr('faithfulness_check_explanation_html'), unsafe_allow_html=True)
                            with st.spinner(tr('running_faithfulness_check_spinner')):
                                try:
                                    # Use a cache key to avoid re-running the check unnecessarily.
                                    check_cache_key = f"faithfulness_check_{method_name}_{st.session_state.generated_text}"
                                    if check_cache_key not in st.session_state:
                                        claims = _cached_extract_claims_from_explanation(
                                            st.session_state.qwen_api_config,
                                            explanation,
                                            method_name
                                        )
                                        if claims:
                                            analysis_data = {
                                                'scores_df': scores_df,
                                                'method': method_name,
                                                'prompt': st.session_state.prompt,
                                                'generated_text': st.session_state.generated_text
                                            }
                                            verification_results = verify_claims(claims, analysis_data)
                                            st.session_state[check_cache_key] = verification_results
                                            # Update the cache file.
                                            update_cache_with_faithfulness(st.session_state.prompt, method_name, verification_results)
                                        else:
                                            st.session_state[check_cache_key] = []
                                    verification_results = st.session_state[check_cache_key]
                                    if verification_results:
                                        st.markdown(f"<h6>{tr('faithfulness_check_results_header')}</h6>", unsafe_allow_html=True)
                                        for result in verification_results:
                                            status_text = tr('verified_status') if result['verified'] else tr('contradicted_status')
                                            st.markdown(f"""
                                            <div style="margin-bottom: 1rem; padding: 0.8rem; border-radius: 8px; border-left: 5px solid {'#28a745' if result['verified'] else '#dc3545'}; background-color: #1a1a1a;">
                                                <p style="margin-bottom: 0.3rem;"><strong>{tr('claim_label')}:</strong> <em>"{result['claim_text']}"</em></p>
                                                <p style="margin-bottom: 0.3rem;"><strong>{tr('status_label')}:</strong> {status_text}</p>
                                                <p style="margin-bottom: 0;"><strong>{tr('evidence_label')}:</strong> {result['evidence']}</p>
                                            </div>
                                            """, unsafe_allow_html=True)
                                    else:
                                        st.info(tr('no_verifiable_claims_info'))
                                except Exception as e:
                                    st.error(tr('faithfulness_check_error').format(e=str(e)))
                # Add download buttons for the results.
                st.subheader(tr("download_results_subheader"))
                col1, col2 = st.columns(2)
                with col1:
                    if html_content:
                        st.download_button(
                            label=tr("download_html_button").format(method_title=method_info['title']),
                            data=html_content,
                            file_name=f"{method_name}_analysis.html",
                            mime="text/html",
                            key=f"html_{method_name}"
                        )
                    if scores_df is not None:
                        st.download_button(
                            label=tr("download_csv_button"),
                            data=scores_df.to_csv().encode('utf-8'),
                            file_name=f"{method_name}_scores.csv",
                            mime="text/csv",
                            key=f"csv_raw_{method_name}"
                        )
                with col2:
                    if heatmap_fig:
                        # Guard the PNG export: to_image() needs the Kaleido backend,
                        # which may be unavailable (see the same guard in
                        # create_heatmap_visualization).
                        try:
                            img_bytes = heatmap_fig.to_image(format="png", scale=2)
                            st.download_button(
                                label=tr("download_png_button").format(method_title=method_info['title']),
                                data=img_bytes,
                                file_name=f"{method_name}_heatmap.png",
                                mime="image/png",
                                key=f"png_{method_name}"
                            )
                        except Exception as e:
                            print(f"Warning: Could not render PNG for download: {e}")
    # Display the influence tracer section.
    st.markdown("---")
    st.markdown(f'<h3><i class="bi bi-compass"></i> {tr("influence_tracer_title")}</h3>', unsafe_allow_html=True)
    st.markdown(f"<div style='font-size: 1.1rem;'>{tr('influence_tracer_desc')}</div>", unsafe_allow_html=True)
    # Add a visual explanation of cosine similarity.
    # Get the translated example sentences.
    sentence_a = tr('influence_example_sentence_a')
    sentence_b = tr('influence_example_sentence_b')
    # Create the SVG for the diagram.
    svg_code = f"""
    <svg width="250" height="150" viewBox="0 0 250 150" xmlns="http://www.w3.org/2000/svg">
        <line x1="10" y1="130" x2="240" y2="130" stroke="#555" stroke-width="2"></line>
        <line x1="10" y1="130" x2="10" y2="10" stroke="#555" stroke-width="2"></line>
        <!-- Angle arc and theta label -->
        <path d="M 49 123 A 40 40 0 0 0 42 107" fill="none" stroke="#FFD700" stroke-width="2"></path>
        <text x="50" y="115" font-family="monospace" font-size="12" fill="#FFD700">θ</text>
        <line x1="10" y1="130" x2="150" y2="30" stroke="#87CEEB" stroke-width="3"></line>
        <text x="155" y="25" font-family="monospace" font-size="12" fill="#87CEEB">Vector A</text>
        <text x="155" y="40" font-family="monospace" font-size="10" fill="#aaa">{sentence_a}</text>
        <line x1="10" y1="130" x2="170" y2="100" stroke="#90EE90" stroke-width="3"></line>
        <text x="175" y="95" font-family="monospace" font-size="12" fill="#90EE90">Vector B</text>
        <text x="175" y="110" font-family="monospace" font-size="10" fill="#aaa">{sentence_b}</text>
    </svg>
    """
    # Encode the SVG to base64.
    encoded_svg = base64.b64encode(svg_code.encode("utf-8")).decode("utf-8")
    image_uri = f"data:image/svg+xml;base64,{encoded_svg}"
    # Display the explanation and diagram.
    st.markdown(f"""
    <div style="background-color: #2b2b2b; border-radius: 10px; padding: 1.5rem; margin: 1rem 0; border-left: 4px solid #FFD700;">
        <h4 style="color: #FFD700; margin-top: 0; margin-bottom: 1rem;">{tr('how_influence_is_found_header')}</h4>
        <div>
            <p style="font-size: 1rem;">{tr('how_influence_is_found_desc')}</p>
            <div style="font-family: 'SF Mono', 'Consolas', 'Menlo', monospace; margin-top: 1.5rem; font-size: 0.95em;">
                <p>{tr('influence_step_1_title')}: {tr('influence_step_1_desc')}</p>
                <p>{tr('influence_step_2_title')}: {tr('influence_step_2_desc')}</p>
                <p>{tr('influence_step_3_title')}: {tr('influence_step_3_desc')}</p>
            </div>
        </div>
        <div style="text-align: center; margin-top: 2rem;">
            <img src="{image_uri}" alt="Cosine Similarity Diagram" />
        </div>
    </div>
    """, unsafe_allow_html=True)
    st.write("")
    if hasattr(st.session_state, 'generated_text'):
        # Prefer influential docs that run_analysis loaded from the file cache.
        if 'cached_influential_docs' in st.session_state:
            influential_docs = st.session_state.pop('cached_influential_docs')  # Use once, then remove.
        else:
            with st.spinner(tr('running_influence_trace_spinner')):
                lang = st.session_state.get('lang', 'en')
                influential_docs = get_influential_docs(st.session_state.prompt, lang)
        # Display the results.
        if influential_docs:
            st.markdown(f"#### {tr('top_influential_docs_header').format(num_docs=len(influential_docs))}")
            # A card-style visualization for the influential documents.
            colors = ["#A78BFA", "#7F9CF5", "#6EE7B7", "#FBBF24", "#F472B6"]
            for i, doc in enumerate(influential_docs):
                card_color = colors[i % len(colors)]
                full_text = doc['text']
                highlight_sentence = doc.get('highlight_sentence', '')
                highlighted_html = ""
                lang = st.session_state.get('lang', 'en')
                if highlight_sentence:
                    # Normalize the sentence to be highlighted.
                    normalized_highlight = re.sub(r'\s+', ' ', highlight_sentence).strip()
                    # Use fuzzy matching to find the best match in the document.
                    splitter = SentenceSplitter(language=lang)
                    sentences_in_doc = splitter.split(text=full_text)
                    if sentences_in_doc:
                        best_match, score = process.extractOne(normalized_highlight, sentences_in_doc)
                        start_index = full_text.find(best_match)
                        if start_index != -1:
                            end_index = start_index + len(best_match)
                            # Create a context window around the matched sentence.
                            context_window = 500
                            snippet_start = max(0, start_index - context_window)
                            snippet_end = min(len(full_text), end_index + context_window)
                            # Reconstruct the HTML with the highlighted sentence.
                            before = html.escape(full_text[snippet_start:start_index])
                            highlight = html.escape(best_match)
                            after = html.escape(full_text[end_index:snippet_end])
                            # Add ellipses if the snippet is truncated.
                            start_ellipsis = "... " if snippet_start > 0 else ""
                            end_ellipsis = " ..." if snippet_end < len(full_text) else ""
                            highlighted_html = (
                                f"{start_ellipsis}{before}"
                                f'<mark style="background-color: {card_color}77; color: #DCDCDC; padding: 2px 4px; border-radius: 4px; font-weight: bold;">{highlight}</mark>'
                                f"{after}{end_ellipsis}"
                            )
                # If no highlight was applied, just show the full text.
                if not highlighted_html:
                    highlighted_html = html.escape(full_text)
                st.markdown(f"""
                <div style="border: 1px solid #262730; border-left: 5px solid {card_color}; border-radius: 10px; padding: 1.5rem; margin-bottom: 1.5rem; background-color: #0E1117; box-shadow: 0 4px 8px rgba(0,0,0,0.2);">
                    <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 1rem;">
                        <span style="font-size: 1.1rem; color: #FAFAFA; font-weight: 600;"><i class="bi bi-journal-text"></i> {tr('source_label')}: {doc['source']}</span>
                        <span style="font-size: 1.1rem; color: {card_color}; background-color: {card_color}22; padding: 0.3rem 0.8rem; border-radius: 15px; font-weight: bold;">
                            <i class="bi bi-stars"></i> {tr('similarity_label')}: {doc['similarity']:.3f}
                        </span>
                    </div>
                    <div style="background-color: #1a1a1a; color: #DCDCDC; padding: 1rem; border-radius: 8px; font-family: 'Courier New', Courier, monospace; white-space: pre-wrap; word-wrap: break-word; max-height: 300px; overflow-y: auto;">
                        {highlighted_html.strip()}
                    </div>
                </div>
                """, unsafe_allow_html=True)
        else:
            # Give a helpful message if the index is missing.
            if not os.path.exists(INDEX_PATH) or not os.path.exists(MAPPING_PATH):
                st.warning(tr('influence_index_not_found_warning'))
            else:
                st.info(tr('no_influential_docs_found'))
    else:
        st.info(tr('run_analysis_for_influence_info'))
    # Show the feedback survey in the sidebar.
    # if 'all_attributions' in st.session_state:
    #     display_attribution_feedback()


if __name__ == "__main__":
    show_attribution_analysis()