# FBDF / backend / base.py
# Firas HADJ KACEM
# created the interface
# 5c7385e
#Base classes and utilities for TokenSHAP
# SPDX-FileCopyrightText: 2023-2024 The TokenSHAP Authors
import numpy as np
import pandas as pd
import re
import random
from typing import List, Dict, Optional, Tuple, Any
from tqdm.auto import tqdm
class ModelBase:
    """Base model interface.

    Concrete models subclass this and override :meth:`generate`.
    """

    def generate(self, **kwargs) -> str:
        """Generate a response for the given input.

        Args:
            **kwargs: Model-specific keyword arguments. ``BaseSHAP`` calls this
                with ``prompt=<content>`` for the baseline and with the
                keyword arguments built by ``_prepare_combination_args`` for
                each combination.

        Returns:
            The model response.
            NOTE(review): annotated as ``str``, but
            ``BaseSHAP._calculate_baseline`` indexes the result with
            ``['label']`` -- confirm the intended return type.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError
class BaseSHAP:
    """Base class for SHAP value calculation with Monte Carlo sampling.

    This class implements the content-agnostic parts of the pipeline:
    building the set of sample combinations to evaluate and querying the model
    for each of them (with caching).  Subclasses provide the hooks that raise
    ``NotImplementedError`` here: ``_prepare_generate_args``, ``_get_samples``,
    ``_prepare_combination_args``, ``_get_combination_key`` and ``analyze``.
    """

    def __init__(self, model: "ModelBase", debug: bool = False):
        """Initialize BaseSHAP.

        Args:
            model: Model to analyze.
            debug: Enable debug output (progress messages via ``print``).
        """
        self.model = model
        # Cache for model responses, keyed by the combination key.
        self.cache = {}
        # ``_get_result_per_combination`` accesses the cache as ``_cache``;
        # bind both names to the same dict so either spelling works.
        self._cache = self.cache
        self.debug = debug

    def _calculate_baseline(self, content: str) -> Dict[str, Any]:
        """Calculate the baseline model response for the full content.

        Args:
            content: Full content, passed to the model as ``prompt``.  It
                should already carry any prefix/suffix the model needs.

        Returns:
            Whatever ``self.model.generate(prompt=content)`` returns.
        """
        baseline = self.model.generate(prompt=content)
        if self.debug:
            # Debug output must never abort the analysis: fall back to the raw
            # response when the model does not return a dict with a 'label'.
            if isinstance(baseline, dict) and "label" in baseline:
                shown = baseline["label"]
            else:
                shown = baseline
            print(f"Baseline prediction: {shown}")
        return baseline

    def _prepare_generate_args(self, content: str, **kwargs) -> Dict:
        """Prepare arguments for ``model.generate()``.

        Raises:
            NotImplementedError: Always; subclasses must override this hook.
        """
        raise NotImplementedError

    def _get_samples(self, content: str) -> List[str]:
        """Split content into the samples to combine.

        Raises:
            NotImplementedError: Always; subclasses must override this hook.
        """
        raise NotImplementedError

    def _prepare_combination_args(self, combination: List[str], original_content: str) -> Dict:
        """Prepare the ``model.generate()`` keyword arguments for a combination.

        Raises:
            NotImplementedError: Always; subclasses must override this hook.
        """
        raise NotImplementedError

    def _get_combination_key(self, combination: List[str], indexes: Tuple[int, ...]) -> str:
        """Get the cache key identifying a combination.

        Raises:
            NotImplementedError: Always; subclasses must override this hook.
        """
        raise NotImplementedError

    def _get_all_combinations(
        self,
        samples: List[str],
        sampling_ratio: float = 0.0,
        max_combinations: Optional[int] = None,
    ) -> Dict[str, Tuple[List[str], Tuple[int, ...]]]:
        """Build the combinations of samples to evaluate, with their indices.

        The ``n`` leave-one-out combinations (keys ``omit_<i>``) are always
        included.  When ``sampling_ratio`` is positive, additional random
        proper, non-empty subsets (keys ``random_<i,j,...>``) are drawn until
        the requested count is reached, the distinct subsets are exhausted, or
        the attempt budget runs out.

        Args:
            samples: List of samples (e.g., tokens).
            sampling_ratio: Ratio of the ``2**n - 1`` non-empty combinations to
                sample.  Values ``<= 0`` return only the leave-one-out set.
            max_combinations: Optional upper bound on the sampled count.  The
                leave-one-out combinations are returned even if they exceed it.

        Returns:
            Dictionary mapping combination keys to ``(combination, indices)``
            tuples, where ``indices`` are the sorted positions kept from
            ``samples``.
        """
        n = len(samples)

        # Always include combinations that exclude exactly one sample.
        essential_combinations = {}
        for omitted in range(n):
            kept = tuple(j for j in range(n) if j != omitted)
            essential_combinations[f"omit_{omitted}"] = ([samples[j] for j in kept], kept)

        # With fewer than two samples there is no proper non-empty subset left
        # to draw (and randint(1, n - 1) would be given an empty range).
        if sampling_ratio <= 0 or n < 2:
            return essential_combinations

        total_combinations = 2 ** n - 1  # All non-empty combinations
        sample_count = int(total_combinations * sampling_ratio)
        if max_combinations is not None:
            sample_count = min(sample_count, max_combinations)
        if sample_count <= len(essential_combinations):
            return essential_combinations

        all_combinations = dict(essential_combinations)
        # De-duplicate on the index tuple, not the key string: a random draw of
        # size n - 1 selects the same subset as one of the "omit_<i>" entries.
        seen_indices = {kept for _, kept in essential_combinations.values()}

        # Draws have sizes 1..n-1, i.e. 2**n - 2 distinct subsets, of which n
        # are already present.  Never ask for more than can exist.
        available = 2 ** n - 2 - n
        additional_needed = min(sample_count - len(essential_combinations), available)

        combinations_added = 0
        max_attempts = additional_needed * 10  # Bound the rejection sampling
        attempts = 0
        while combinations_added < additional_needed and attempts < max_attempts:
            attempts += 1
            subset_size = random.randint(1, n - 1)  # At least 1, at most n-1
            indices = tuple(sorted(random.sample(range(n), subset_size)))
            if indices in seen_indices:
                continue
            seen_indices.add(indices)
            key = f"random_{','.join(str(i) for i in indices)}"
            all_combinations[key] = ([samples[i] for i in indices], indices)
            combinations_added += 1

        if self.debug and combinations_added < additional_needed:
            print(f"Warning: Reached max attempts ({max_attempts}) when generating combinations")
        return all_combinations

    def _get_result_per_combination(
        self,
        content: str,
        sampling_ratio: float = 0.0,
        max_combinations: Optional[int] = None,
    ) -> Dict[str, Dict[str, Any]]:
        """Get model responses for combinations of content.

        Args:
            content: Original content.
            sampling_ratio: Ratio of combinations to sample.
            max_combinations: Maximum number of combinations.

        Returns:
            Dictionary mapping combination keys to a dict with the entries
            ``"combination"``, ``"indices"`` and ``"response"``.
        """
        samples = self._get_samples(content)
        if self.debug:
            print(f"Found {len(samples)} samples in content")

        combinations = self._get_all_combinations(samples, sampling_ratio, max_combinations)
        if self.debug:
            print(f"Generated {len(combinations)} combinations")

        results = {}
        for key, (combination, indices) in tqdm(combinations.items(), desc="Processing combinations"):
            comb_args = self._prepare_combination_args(combination, content)
            comb_key = self._get_combination_key(combination, indices)

            # Reuse a cached response when this combination was already
            # evaluated; otherwise query the model and remember the answer.
            if comb_key in self._cache:
                response = self._cache[comb_key]
            else:
                response = self.model.generate(**comb_args)
                self._cache[comb_key] = response

            results[key] = {
                "combination": combination,
                "indices": indices,
                "response": response
            }
        return results

    def analyze(self, content: str, sampling_ratio: float = 0.0, max_combinations: Optional[int] = None) -> pd.DataFrame:
        """Analyze importance in content.

        Args:
            content: Content to analyze.
            sampling_ratio: Ratio of combinations to sample.
            max_combinations: Maximum number of combinations.

        Returns:
            DataFrame with analysis results.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError