import gradio as gr
import torch
import joblib
import numpy as np
from itertools import product
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
import io
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
from Bio.Graphics import GenomeDiagram
from Bio.SeqFeature import SeqFeature, FeatureLocation
from reportlab.lib import colors
import pandas as pd
import tempfile
import os
from typing import List, Dict, Tuple, Optional, Any
import seaborn as sns
import shap
###############################################################################
# 1. MODEL DEFINITION
###############################################################################
class VirusClassifier(nn.Module):
    def __init__(self, input_shape: int):
        super(VirusClassifier, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2)
        )

    def forward(self, x):
        return self.network(x)
###############################################################################
# 2. FASTA PARSING & K-MER FEATURE ENGINEERING
###############################################################################
def parse_fasta(text):
    sequences = []
    current_header = None
    current_sequence = []
    for line in text.strip().split('\n'):
        line = line.strip()
        if not line:
            continue
        if line.startswith('>'):
            if current_header:
                sequences.append((current_header, ''.join(current_sequence)))
            current_header = line[1:]
            current_sequence = []
        else:
            current_sequence.append(line.upper())
    if current_header:
        sequences.append((current_header, ''.join(current_sequence)))
    return sequences
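# A small illustrative example (hypothetical input, not part of the app flow):
# parse_fasta handles multi-record input, uppercases bases, and joins wrapped
# sequence lines:
#   parse_fasta(">seq1\nacgt\nACGT\n>seq2\nTTTT")
#   -> [("seq1", "ACGTACGT"), ("seq2", "TTTT")]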
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
    kmer_dict = {km: i for i, km in enumerate(kmers)}
    vec = np.zeros(len(kmers), dtype=np.float32)
    for i in range(len(sequence) - k + 1):
        kmer = sequence[i:i+k]
        if kmer in kmer_dict:
            vec[kmer_dict[kmer]] += 1
    total_kmers = len(sequence) - k + 1
    if total_kmers > 0:
        vec /= total_kmers
    return vec
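# Worked example: with k=4 there are 4**4 = 256 possible k-mers, so the vector
# has 256 entries (matching VirusClassifier(256) below). A sequence of length L
# yields L - k + 1 overlapping k-mers, and k-mers containing non-ACGT characters
# are skipped, so for a clean sequence the normalized entries sum to ~1.0:
#   v = sequence_to_kmer_vector("ACGTACGTAC")  # 7 k-mers; v.sum() ~= 1.0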
###############################################################################
# 3. FEATURE IMPORTANCE (ABLATION) CALCULATION
###############################################################################
def calculate_shap_values(model, x_tensor):
    model.eval()
    with torch.no_grad():
        baseline_output = model(x_tensor)
        baseline_probs = torch.softmax(baseline_output, dim=1)
        baseline_prob = baseline_probs[0, 1].item()  # Prob of 'human'
        shap_values = []
        x_zeroed = x_tensor.clone()
        for i in range(x_tensor.shape[1]):
            original_val = x_zeroed[0, i].item()
            x_zeroed[0, i] = 0.0
            output = model(x_zeroed)
            probs = torch.softmax(output, dim=1)
            prob = probs[0, 1].item()
            shap_values.append(baseline_prob - prob)
            x_zeroed[0, i] = original_val
    return np.array(shap_values), baseline_prob
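# Note: despite the variable names, these are leave-one-out ablation scores
# rather than true SHAP values: each feature is zeroed in turn (one extra
# forward pass per feature, 256 for the 4-mer model) and the resulting drop in
# P(human) is recorded. A positive score therefore means the feature pushes the
# prediction toward 'human'; a negative score pushes it toward 'non-human'.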
###############################################################################
# 4. PER-BASE FEATURE IMPORTANCE AGGREGATION
###############################################################################
def compute_positionwise_scores(sequence, shap_values, k=4):
    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
    kmer_dict = {km: i for i, km in enumerate(kmers)}
    seq_len = len(sequence)
    shap_sums = np.zeros(seq_len, dtype=np.float32)
    coverage = np.zeros(seq_len, dtype=np.float32)
    for i in range(seq_len - k + 1):
        kmer = sequence[i:i+k]
        if kmer in kmer_dict:
            val = shap_values[kmer_dict[kmer]]
            shap_sums[i:i+k] += val
            coverage[i:i+k] += 1
    with np.errstate(divide='ignore', invalid='ignore'):
        shap_means = np.where(coverage > 0, shap_sums / coverage, 0.0)
    return shap_means
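# Each base is covered by up to k overlapping k-mer windows (fewer near the
# sequence ends), so dividing the summed scores by `coverage` gives a mean
# per-base importance. For example, with k=4 the base at position 10 of a long
# sequence is covered by the windows starting at positions 7, 8, 9 and 10.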
###############################################################################
# 5. FIND EXTREME IMPORTANCE REGIONS
###############################################################################
def find_extreme_subregion(shap_means, window_size=500, mode="max"):
    n = len(shap_means)
    if n == 0:
        return (0, 0, 0.0)
    if window_size >= n:
        return (0, n, float(np.mean(shap_means)))
    csum = np.zeros(n + 1, dtype=np.float32)
    csum[1:] = np.cumsum(shap_means)
    best_start = 0
    best_sum = csum[window_size] - csum[0]
    best_avg = best_sum / window_size
    for start in range(1, n - window_size + 1):
        wsum = csum[start + window_size] - csum[start]
        wavg = wsum / window_size
        if mode == "max" and wavg > best_avg:
            best_avg = wavg
            best_start = start
        elif mode == "min" and wavg < best_avg:
            best_avg = wavg
            best_start = start
    return (best_start, best_start + window_size, float(best_avg))
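# Worked example of the prefix-sum trick used above: for
# shap_means = [1, 2, 3, 4] and window_size = 2, csum = [0, 1, 3, 6, 10] and
# the window sums are csum[s+2] - csum[s] = 3, 5, 7, so mode="max" returns
# (2, 4, 3.5). Each window costs O(1), giving O(n) overall instead of O(n*w).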
###############################################################################
# 6. PLOTTING / UTILITIES
###############################################################################
def fig_to_image(fig):
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
    return img


def get_zero_centered_cmap():
    colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
    return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)


def plot_linear_heatmap(shap_means, title="Per-base Feature Importance Heatmap", start=None, end=None):
    if start is not None and end is not None:
        local_shap = shap_means[start:end]
        subtitle = f" (positions {start}-{end})"
    else:
        local_shap = shap_means
        subtitle = ""
    if len(local_shap) == 0:
        local_shap = np.array([0.0])
    heatmap_data = local_shap.reshape(1, -1)
    min_val = np.min(local_shap)
    max_val = np.max(local_shap)
    extent = max(abs(min_val), abs(max_val))
    cmap = get_zero_centered_cmap()
    fig, ax = plt.subplots(figsize=(12, 1.8))
    cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
    cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
    cbar.ax.tick_params(labelsize=8)
    cbar.set_label('Feature Importance', fontsize=9, labelpad=5)
    ax.set_yticks([])
    ax.set_xlabel('Position in Sequence', fontsize=10)
    ax.set_title(f"{title}{subtitle}", pad=10)
    plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
    return fig
def create_importance_bar_plot(shap_values, kmers, top_k=10):
    plt.rcParams.update({'font.size': 10})
    fig = plt.figure(figsize=(10, 5))
    indices = np.argsort(np.abs(shap_values))[-top_k:]
    values = shap_values[indices]
    features = [kmers[i] for i in indices]
    colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
    plt.barh(range(len(values)), values, color=colors)
    plt.yticks(range(len(values)), features)
    plt.xlabel('Feature Importance (impact on model output)')
    plt.title(f'Top {top_k} Most Influential k-mers')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    return fig
def compute_gc_content(sequence):
    if not sequence:
        return 0
    gc_count = sequence.count('G') + sequence.count('C')
    return (gc_count / len(sequence)) * 100.0
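# Example: compute_gc_content("ACGC") -> 75.0, since 3 of the 4 bases are G or C.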
###############################################################################
# 7. MAIN ANALYSIS STEP (Gradio Step 1)
###############################################################################
def create_kmer_shap_csv(kmers, shap_values):
    """Create a CSV file with k-mer importance values and return the filepath"""
    # Create DataFrame with k-mers and importance values
    kmer_df = pd.DataFrame({
        'kmer': kmers,
        'importance_value': shap_values,
        'abs_importance': np.abs(shap_values)
    })
    # Sort by absolute importance value (most influential first)
    kmer_df = kmer_df.sort_values('abs_importance', ascending=False)
    # Drop the abs_importance column used for sorting
    kmer_df = kmer_df[['kmer', 'importance_value']]
    # Save to temporary file
    temp_dir = tempfile.gettempdir()
    temp_path = os.path.join(temp_dir, f"kmer_importance_values_{os.urandom(4).hex()}.csv")
    kmer_df.to_csv(temp_path, index=False)
    return temp_path  # Return only the file path, not a tuple
def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
    if fasta_text.strip():
        text = fasta_text.strip()
    elif file_obj is not None:
        try:
            with open(file_obj, 'r') as f:
                text = f.read()
        except Exception as e:
            return (f"Error reading file: {str(e)}", None, None, None, None, None, None)
    else:
        return ("Please provide a FASTA sequence.", None, None, None, None, None, None)
    sequences = parse_fasta(text)
    if not sequences:
        return ("No valid FASTA sequences found.", None, None, None, None, None, None)
    header, seq = sequences[0]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        # IMPORTANT: adjust how you load your model as needed
        state_dict = torch.load('model.pt', map_location=device)
        model = VirusClassifier(256).to(device)
        model.load_state_dict(state_dict)
        scaler = joblib.load('scaler.pkl')
    except Exception as e:
        return (f"Error loading model/scaler: {str(e)}", None, None, None, None, None, None)
    freq_vector = sequence_to_kmer_vector(seq)
    scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
    x_tensor = torch.FloatTensor(scaled_vector).to(device)
    shap_values, prob_human = calculate_shap_values(model, x_tensor)
    prob_nonhuman = 1.0 - prob_human
    classification = "Human" if prob_human > 0.5 else "Non-human"
    confidence = max(prob_human, prob_nonhuman)
    shap_means = compute_positionwise_scores(seq, shap_values, k=4)
    max_start, max_end, max_avg = find_extreme_subregion(shap_means, window_size, mode="max")
    min_start, min_end, min_avg = find_extreme_subregion(shap_means, window_size, mode="min")
    results_text = (
        f"Sequence: {header}\n"
        f"Length: {len(seq):,} bases\n"
        f"Classification: {classification}\n"
        f"Confidence: {confidence:.3f}\n"
        f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
        f"---\n"
        f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
        f"Start: {max_start}, End: {max_end}, Avg Importance: {max_avg:.4f}\n\n"
        f"**Most Non-Human-Pushing {window_size}-bp Subregion**:\n"
        f"Start: {min_start}, End: {min_end}, Avg Importance: {min_avg:.4f}"
    )
    kmers = [''.join(p) for p in product("ACGT", repeat=4)]
    bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
    bar_img = fig_to_image(bar_fig)
    heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide Feature Importance")
    heatmap_img = fig_to_image(heatmap_fig)
    # Create CSV with k-mer importance values and return the file path
    kmer_shap_csv = create_kmer_shap_csv(kmers, shap_values)
    # State dictionary for subregion analysis
    state_dict_out = {"seq": seq, "shap_means": shap_means}
    return (results_text, bar_img, heatmap_img, state_dict_out, header, None, kmer_shap_csv)
###############################################################################
# 8. SUBREGION ANALYSIS (Gradio Step 2)
###############################################################################
def analyze_subregion(state, header, region_start, region_end):
    if not state or "seq" not in state or "shap_means" not in state:
        return ("No sequence data found. Please run Step 1 first.", None, None, None)
    seq = state["seq"]
    shap_means = state["shap_means"]
    region_start = int(region_start)
    region_end = int(region_end)
    region_start = max(0, min(region_start, len(seq)))
    region_end = max(0, min(region_end, len(seq)))
    if region_end <= region_start:
        return ("Invalid region range. End must be > Start.", None, None, None)
    region_seq = seq[region_start:region_end]
    region_shap = shap_means[region_start:region_end]
    gc_percent = compute_gc_content(region_seq)
    avg_shap = float(np.mean(region_shap))
    positive_fraction = np.mean(region_shap > 0)
    negative_fraction = np.mean(region_shap < 0)
    if avg_shap > 0.05:
        region_classification = "Likely pushing toward human"
    elif avg_shap < -0.05:
        region_classification = "Likely pushing toward non-human"
    else:
        region_classification = "Near neutral (no strong push)"
    region_info = (
        f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
        f"Region length: {len(region_seq)} bases\n"
        f"GC content: {gc_percent:.2f}%\n"
        f"Average importance in region: {avg_shap:.4f}\n"
        f"Fraction with importance > 0 (toward human): {positive_fraction:.2f}\n"
        f"Fraction with importance < 0 (toward non-human): {negative_fraction:.2f}\n"
        f"Subregion interpretation: {region_classification}\n"
    )
    heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion Feature Importance", start=region_start, end=region_end)
    heatmap_img = fig_to_image(heatmap_fig)
    hist_fig = plot_shap_histogram(region_shap, title="Feature Importance Distribution in Subregion")
    hist_img = fig_to_image(hist_fig)
    # For demonstration, returning None for the file download as well
    return (region_info, heatmap_img, hist_img, None)
###############################################################################
# 9. COMPARISON ANALYSIS FUNCTIONS
###############################################################################
def compute_shap_difference(shap1_norm, shap2_norm):
    """Compute the feature importance difference between normalized sequences"""
    return shap2_norm - shap1_norm


def plot_comparative_heatmap(shap_diff, title="Feature Importance Difference Heatmap"):
    """
    Plot heatmap using relative positions (0-100%)
    """
    heatmap_data = shap_diff.reshape(1, -1)
    extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
    fig, ax = plt.subplots(figsize=(12, 1.8))
    cmap = get_zero_centered_cmap()
    cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
    # Create percentage-based x-axis ticks
    num_ticks = 5
    tick_positions = np.linspace(0, shap_diff.shape[0] - 1, num_ticks)
    tick_labels = [f"{int(x * 100)}%" for x in np.linspace(0, 1, num_ticks)]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
    cbar.ax.tick_params(labelsize=8)
    cbar.set_label('Feature Importance Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
    ax.set_yticks([])
    ax.set_xlabel('Relative Position in Sequence', fontsize=10)
    ax.set_title(title, pad=10)
    plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
    return fig
def plot_shap_histogram(shap_array, title="Feature Importance Distribution", num_bins=30):
    """
    Plot histogram of feature importance values with configurable number of bins
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
    ax.axvline(0, color='red', linestyle='--', label='0.0')
    ax.set_xlabel("Feature Importance Value")
    ax.set_ylabel("Count")
    ax.set_title(title)
    ax.legend()
    plt.tight_layout()
    return fig
def calculate_adaptive_parameters(len1, len2):
    """
    Calculate adaptive parameters based on sequence lengths and their difference.
    Returns: (num_points, smooth_window, resolution_factor)
    """
    length_diff = abs(len1 - len2)
    max_length = max(len1, len2)
    min_length = min(len1, len2)
    length_ratio = min_length / max_length
    # Base number of points scales with sequence length
    base_points = min(2000, max(500, max_length // 100))
    # Adjust parameters based on sequence properties
    if length_diff < 500:
        resolution_factor = 2.0
        num_points = min(3000, base_points * 2)
        smooth_window = max(10, length_diff // 50)
    elif length_diff < 5000:
        resolution_factor = 1.5
        num_points = min(2000, base_points * 1.5)
        smooth_window = max(20, length_diff // 100)
    elif length_diff < 50000:
        resolution_factor = 1.0
        num_points = base_points
        smooth_window = max(50, length_diff // 200)
    else:
        resolution_factor = 0.75
        num_points = max(500, base_points // 2)
        smooth_window = max(100, length_diff // 500)
    # Adjust window size based on length ratio
    smooth_window = int(smooth_window * (1 + (1 - length_ratio)))
    return int(num_points), int(smooth_window), resolution_factor
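# Worked example: len1=30000 and len2=29000 give length_diff=1000 (second
# branch), base_points = min(2000, max(500, 300)) = 500, num_points =
# min(2000, 750) = 750 and smooth_window = max(20, 10) = 20; the ratio
# adjustment int(20 * (1 + (1 - 29000/30000))) leaves 20, so the function
# returns (750, 20, 1.5).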
def sliding_window_smooth(values, window_size=50):
    """
    Apply sliding window smoothing with edge handling
    """
    if window_size < 3:
        return values
    # Create a flat window whose weights decay exponentially toward the edges
    window = np.ones(window_size)
    decay = np.exp(-np.linspace(0, 3, window_size // 2))
    window[:window_size // 2] = decay[::-1]  # ramp up from the left edge
    window[-(window_size // 2):] = decay     # ramp down toward the right edge
    window = window / window.sum()
    # Apply convolution
    smoothed = np.convolve(values, window, mode='valid')
    # Handle edges: keep the raw values where the window does not fully fit
    pad_size = len(values) - len(smoothed)
    pad_left = pad_size // 2
    pad_right = pad_size - pad_left
    result = np.zeros_like(values)
    result[pad_left:-pad_right] = smoothed
    result[:pad_left] = values[:pad_left]
    result[-pad_right:] = values[-pad_right:]
    return result
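# The kernel is flat in the middle with exponentially decaying tails and is
# normalized to sum to 1, so the convolution is a weighted moving average that
# down-weights neighbors near the window edges. The first pad_left and last
# pad_right positions, where the window does not fully fit, are copied through
# unsmoothed rather than zero-padded.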
def normalize_shap_lengths(shap1, shap2):
    """
    Normalize and smooth feature importance values with dynamic adaptation
    """
    # Calculate adaptive parameters
    num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
    # Apply initial smoothing
    shap1_smooth = sliding_window_smooth(shap1, smooth_window)
    shap2_smooth = sliding_window_smooth(shap2, smooth_window)
    # Create relative positions and interpolate onto a common grid
    x1 = np.linspace(0, 1, len(shap1_smooth))
    x2 = np.linspace(0, 1, len(shap2_smooth))
    x_norm = np.linspace(0, 1, num_points)
    shap1_interp = np.interp(x_norm, x1, shap1_smooth)
    shap2_interp = np.interp(x_norm, x2, shap2_smooth)
    return shap1_interp, shap2_interp, smooth_window
def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
    """
    Compare two sequences with adaptive parameters and visualization
    """
    try:
        # Analyze first sequence
        res1 = analyze_sequence(file1, top_kmers=10, fasta_text=fasta1, window_size=500)
        if isinstance(res1[0], str) and "Error" in res1[0]:
            return (f"Error in sequence 1: {res1[0]}", None, None, None)
        # Analyze second sequence
        res2 = analyze_sequence(file2, top_kmers=10, fasta_text=fasta2, window_size=500)
        if isinstance(res2[0], str) and "Error" in res2[0]:
            return (f"Error in sequence 2: {res2[0]}", None, None, None)
        # Extract feature importance values and sequence info
        shap1 = res1[3]["shap_means"]
        shap2 = res2[3]["shap_means"]
        # Calculate sequence properties
        len1, len2 = len(shap1), len(shap2)
        length_diff = abs(len1 - len2)
        length_ratio = min(len1, len2) / max(len1, len2)
        # Normalize and compare sequences
        shap1_norm, shap2_norm, smooth_window = normalize_shap_lengths(shap1, shap2)
        shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
        # Calculate adaptive threshold and statistics
        base_threshold = 0.05
        adaptive_threshold = base_threshold * (1 + (1 - length_ratio))
        if length_diff > 50000:
            adaptive_threshold *= 1.5
        # Calculate comparison statistics
        avg_diff = np.mean(shap_diff)
        std_diff = np.std(shap_diff)
        max_diff = np.max(shap_diff)
        min_diff = np.min(shap_diff)
        substantial_diffs = np.abs(shap_diff) > adaptive_threshold
        frac_different = np.mean(substantial_diffs)
        # Extract classifications from the Step 1 result text
        try:
            classification1 = res1[0].split('Classification: ')[1].split('\n')[0].strip()
            classification2 = res2[0].split('Classification: ')[1].split('\n')[0].strip()
        except (IndexError, AttributeError):
            classification1 = "Unknown"
            classification2 = "Unknown"
        # Format output text
        comparison_text = (
            "Sequence Comparison Results:\n"
            f"Sequence 1: {res1[4]}\n"
            f"Length: {len1:,} bases\n"
            f"Classification: {classification1}\n\n"
            f"Sequence 2: {res2[4]}\n"
            f"Length: {len2:,} bases\n"
            f"Classification: {classification2}\n\n"
            "Comparison Parameters:\n"
            f"Length Difference: {length_diff:,} bases\n"
            f"Length Ratio: {length_ratio:.3f}\n"
            f"Smoothing Window: {smooth_window} points\n"
            f"Adaptive Threshold: {adaptive_threshold:.3f}\n\n"
            "Statistics:\n"
            f"Average feature importance difference: {avg_diff:.4f}\n"
            f"Standard deviation: {std_diff:.4f}\n"
            f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
            f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
            f"Fraction with substantial differences: {frac_different:.2%}\n\n"
            "Note: All parameters automatically adjusted based on sequence properties\n\n"
            "Interpretation:\n"
            "- Red regions: Sequence 2 more human-like\n"
            "- Blue regions: Sequence 1 more human-like\n"
            "- White regions: Similar between sequences"
        )
        # Generate visualizations
        heatmap_fig = plot_comparative_heatmap(
            shap_diff,
            title=f"Feature Importance Difference Heatmap (window: {smooth_window})"
        )
        heatmap_img = fig_to_image(heatmap_fig)
        # Create histogram with adaptive bins
        num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
        hist_fig = plot_shap_histogram(
            shap_diff,
            title="Distribution of Feature Importance Differences",
            num_bins=num_bins
        )
        hist_img = fig_to_image(hist_fig)
        # Return 4 outputs (text, image, image, and a file or None for the last)
        return (comparison_text, heatmap_img, hist_img, None)
    except Exception as e:
        error_msg = f"Error during sequence comparison: {str(e)}"
        return (error_msg, None, None, None)
###############################################################################
# 10. GENE FEATURE ANALYSIS
###############################################################################
def parse_gene_features(text: str) -> List[Dict[str, Any]]:
    """Parse gene features from text file in FASTA-like format"""
    genes = []
    current_header = None
    current_sequence = []
    for line in text.strip().split('\n'):
        line = line.strip()
        if not line:
            continue
        if line.startswith('>'):
            if current_header:
                genes.append({
                    'header': current_header,
                    'sequence': ''.join(current_sequence),
                    'metadata': parse_gene_metadata(current_header)
                })
            current_header = line[1:]
            current_sequence = []
        else:
            current_sequence.append(line.upper())
    if current_header:
        genes.append({
            'header': current_header,
            'sequence': ''.join(current_sequence),
            'metadata': parse_gene_metadata(current_header)
        })
    return genes
def parse_gene_metadata(header: str) -> Dict[str, str]:
    """Extract metadata from gene header"""
    metadata = {}
    parts = header.split()
    for part in parts:
        if '[' in part and ']' in part:
            key_value = part[1:-1].split('=', 1)
            if len(key_value) == 2:
                metadata[key_value[0]] = key_value[1]
    return metadata
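# Example: parse_gene_metadata("gene_1 [gene=UL1] [locus_tag=HHV1_UL1] [location=100..500]")
# -> {'gene': 'UL1', 'locus_tag': 'HHV1_UL1', 'location': '100..500'}.
# Caveat: because the header is split on whitespace first, bracketed values
# containing spaces are not captured.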
def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
    """Parse gene location string, handling both forward and complement strands"""
    try:
        # Remove 'complement(' and ')' if present
        clean_loc = location_str.replace('complement(', '').replace(')', '')
        # Split on '..' and convert to integers
        if '..' in clean_loc:
            start, end = map(int, clean_loc.split('..'))
            return start, end
        else:
            return None, None
    except Exception as e:
        print(f"Error parsing location {location_str}: {str(e)}")
        return None, None
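# Example: parse_location("complement(266..13468)") -> (266, 13468); the strand
# information is discarded. Caveat: GenBank-style locations are 1-based and
# inclusive, while shap_means is indexed 0-based, so the slice
# shap_means[start:end] below is shifted by one base at the region start.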
def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
    """Compute statistical measures for gene feature importance values"""
    return {
        'avg_shap': float(np.mean(gene_shap)),
        'median_shap': float(np.median(gene_shap)),
        'std_shap': float(np.std(gene_shap)),
        'max_shap': float(np.max(gene_shap)),
        'min_shap': float(np.min(gene_shap)),
        'pos_fraction': float(np.mean(gene_shap > 0))
    }
def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
    """
    Create a simple genome diagram using PIL, forcing a minimum color intensity
    so that small feature importance values don't appear white.
    """
    # Validate inputs
    if not gene_results or genome_length <= 0:
        img = Image.new('RGB', (800, 100), color='white')
        draw = ImageDraw.Draw(img)
        draw.text((10, 40), "Error: Invalid input data", fill='black')
        return img
    # Ensure all gene coordinates are valid integers
    for gene in gene_results:
        gene['start'] = max(0, int(gene['start']))
        gene['end'] = min(genome_length, int(gene['end']))
        if gene['start'] >= gene['end']:
            print(f"Warning: Invalid coordinates for gene {gene.get('gene_name', '?')}: {gene['start']}-{gene['end']}")
    # Image dimensions
    width = 1500
    height = 600
    margin = 50
    track_height = 40
    # Create image with white background
    img = Image.new('RGB', (width, height), 'white')
    draw = ImageDraw.Draw(img)
    # Try to load font, fall back to default if unavailable
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
        title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
    except Exception:
        font = ImageFont.load_default()
        title_font = ImageFont.load_default()
    # Draw title
    draw.text((margin, margin // 2), "Genome Feature Importance Analysis", fill='black', font=title_font)
    # Draw genome line
    line_y = height // 2
    draw.line([(int(margin), int(line_y)), (int(width - margin), int(line_y))], fill='black', width=2)
    # Calculate scale factor
    scale = float(width - 2 * margin) / float(genome_length)
    # Determine a reasonable step for scale markers
    num_ticks = 10
    if genome_length < num_ticks:
        step = 1
    else:
        step = genome_length // num_ticks
    # Draw scale markers
    for i in range(0, genome_length + 1, step):
        x_coord = margin + i * scale
        draw.line([
            (int(x_coord), int(line_y - 5)),
            (int(x_coord), int(line_y + 5))
        ], fill='black', width=1)
        draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)
    # Sort genes by absolute feature importance value for drawing
    sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
    # Draw genes
    for idx, gene in enumerate(sorted_genes):
        # Calculate position and ensure integers
        start_x = margin + int(gene['start'] * scale)
        end_x = margin + int(gene['end'] * scale)
        # Calculate color based on feature importance value
        avg_shap = gene['avg_shap']
        # Convert importance -> color intensity (0 to 255),
        # then clamp to a minimum intensity so it never ends up plain white
        intensity = int(abs(avg_shap) * 500)
        intensity = max(50, min(255, intensity))  # clamp between 50 and 255
        if avg_shap > 0:
            # Red-ish for positive
            color = (255, 255 - intensity, 255 - intensity)
        else:
            # Blue-ish for negative or zero
            color = (255 - intensity, 255 - intensity, 255)
        # Draw gene rectangle
        draw.rectangle([
            (int(start_x), int(line_y - track_height // 2)),
            (int(end_x), int(line_y + track_height // 2))
        ], fill=color, outline='black')
        # Prepare gene name label
        label = str(gene.get('gene_name', '?'))
        # Fallback for label size
        label_mask = font.getmask(label)
        label_width, label_height = label_mask.size
        # Alternate label positions
        if idx % 2 == 0:
            text_y = line_y - track_height - 15
        else:
            text_y = line_y + track_height + 5
        # Decide whether to rotate text based on space
        gene_width = end_x - start_x
        if gene_width > label_width:
            text_x = start_x + (gene_width - label_width) // 2
            draw.text((int(text_x), int(text_y)), label, fill='black', font=font)
        elif gene_width > 20:
            txt_img = Image.new('RGBA', (label_width, label_height), (255, 255, 255, 0))
            txt_draw = ImageDraw.Draw(txt_img)
            txt_draw.text((0, 0), label, font=font, fill='black')
            rotated_img = txt_img.rotate(90, expand=True)
            img.paste(rotated_img, (int(start_x), int(text_y)), rotated_img)
    # Draw legend
    legend_x = margin
    legend_y = height - margin
    draw.text((int(legend_x), int(legend_y - 60)), "Feature Importance Values:", fill='black', font=font)
    # Draw legend boxes
    box_width = 20
    box_height = 20
    spacing = 15
    # Strong human-like
    draw.rectangle([
        (int(legend_x), int(legend_y - 45)),
        (int(legend_x + box_width), int(legend_y - 45 + box_height))
    ], fill=(255, 0, 0), outline='black')
    draw.text((int(legend_x + box_width + spacing), int(legend_y - 45)),
              "Strong human-like signal", fill='black', font=font)
    # Weak human-like
    draw.rectangle([
        (int(legend_x), int(legend_y - 20)),
        (int(legend_x + box_width), int(legend_y - 20 + box_height))
    ], fill=(255, 200, 200), outline='black')
    draw.text((int(legend_x + box_width + spacing), int(legend_y - 20)),
              "Weak human-like signal", fill='black', font=font)
    # Weak non-human-like
    draw.rectangle([
        (int(legend_x + 250), int(legend_y - 45)),
        (int(legend_x + 250 + box_width), int(legend_y - 45 + box_height))
    ], fill=(200, 200, 255), outline='black')
    draw.text((int(legend_x + 250 + box_width + spacing), int(legend_y - 45)),
              "Weak non-human-like signal", fill='black', font=font)
    # Strong non-human-like
    draw.rectangle([
        (int(legend_x + 250), int(legend_y - 20)),
        (int(legend_x + 250 + box_width), int(legend_y - 20 + box_height))
    ], fill=(0, 0, 255), outline='black')
    draw.text((int(legend_x + 250 + box_width + spacing), int(legend_y - 20)),
              "Strong non-human-like signal", fill='black', font=font)
    return img
def analyze_gene_features(sequence_file: str,
                          features_file: str,
                          fasta_text: str = "",
                          features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
    """Analyze feature importance values for each gene feature"""
    # First analyze whole sequence
    sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
    if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
        return f"Error in sequence analysis: {sequence_results[0]}", None, None
    # Get feature importance values
    shap_means = sequence_results[3]["shap_means"]
    # Parse gene features
    try:
        if features_text.strip():
            genes = parse_gene_features(features_text)
        else:
            with open(features_file, 'r') as f:
                genes = parse_gene_features(f.read())
    except Exception as e:
        return f"Error reading features file: {str(e)}", None, None
    # Analyze each gene
    gene_results = []
    for gene in genes:
        try:
            location = gene['metadata'].get('location', '')
            if not location:
                continue
            start, end = parse_location(location)
            if start is None or end is None:
                continue
            # Get feature importance values for this region
            gene_shap = shap_means[start:end]
            stats = compute_gene_statistics(gene_shap)
            gene_results.append({
                'gene_name': gene['metadata'].get('gene', 'Unknown'),
                'location': location,
                'start': start,
                'end': end,
                'locus_tag': gene['metadata'].get('locus_tag', ''),
                'avg_shap': stats['avg_shap'],
                'median_shap': stats['median_shap'],
                'std_shap': stats['std_shap'],
                'max_shap': stats['max_shap'],
                'min_shap': stats['min_shap'],
                'pos_fraction': stats['pos_fraction'],
                'classification': 'Human' if stats['avg_shap'] > 0 else 'Non-human',
                'confidence': abs(stats['avg_shap'])
            })
        except Exception as e:
            print(f"Error processing gene {gene['metadata'].get('gene', 'Unknown')}: {str(e)}")
            continue
    if not gene_results:
        return "No valid genes could be processed", None, None
    # Sort genes by absolute feature importance value
    sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']), reverse=True)
    # Create results text
    results_text = "Gene Analysis Results:\n\n"
    results_text += f"Total genes analyzed: {len(gene_results)}\n"
    results_text += f"Human-like genes: {sum(1 for g in gene_results if g['classification'] == 'Human')}\n"
    results_text += f"Non-human-like genes: {sum(1 for g in gene_results if g['classification'] == 'Non-human')}\n\n"
    results_text += "Top 10 most distinctive genes:\n"
    for gene in sorted_genes[:10]:
        results_text += (
            f"Gene: {gene['gene_name']}\n"
            f"Location: {gene['location']}\n"
            f"Classification: {gene['classification']} "
            f"(confidence: {gene['confidence']:.4f})\n"
            f"Average Feature Importance: {gene['avg_shap']:.4f}\n\n"
        )
    # Create CSV content
    csv_content = "gene_name,location,avg_importance,median_importance,std_importance,max_importance,min_importance,"
    csv_content += "pos_fraction,classification,confidence,locus_tag\n"
    for gene in gene_results:
        csv_content += (
            f"{gene['gene_name']},{gene['location']},{gene['avg_shap']:.4f},"
            f"{gene['median_shap']:.4f},{gene['std_shap']:.4f},{gene['max_shap']:.4f},"
            f"{gene['min_shap']:.4f},{gene['pos_fraction']:.4f},{gene['classification']},"
            f"{gene['confidence']:.4f},{gene['locus_tag']}\n"
        )
    # Save CSV to temp file
    try:
        temp_dir = tempfile.gettempdir()
        temp_path = os.path.join(temp_dir, f"gene_analysis_{os.urandom(4).hex()}.csv")
        with open(temp_path, 'w') as f:
            f.write(csv_content)
    except Exception as e:
        print(f"Error saving CSV: {str(e)}")
        temp_path = None
    # Create visualization
    try:
        diagram_img = create_simple_genome_diagram(gene_results, len(shap_means))
    except Exception as e:
        print(f"Error creating visualization: {str(e)}")
        # Create error image
        diagram_img = Image.new('RGB', (800, 100), color='white')
        draw = ImageDraw.Draw(diagram_img)
        draw.text((10, 40), f"Error creating visualization: {str(e)}", fill='black')
    return results_text, temp_path, diagram_img
###############################################################################
# 11. DOWNLOAD FUNCTIONS
###############################################################################
def prepare_csv_download(data, filename="analysis_results.csv"):
    """Prepare CSV data for download"""
    if isinstance(data, str):
        return data.encode(), filename
    elif isinstance(data, (list, dict)):
        import csv
        from io import StringIO
        output = StringIO()
        rows = data if isinstance(data, list) else [data]
        writer = csv.DictWriter(output, fieldnames=rows[0].keys())
        writer.writeheader()
        writer.writerows(rows)
        return output.getvalue().encode(), filename
    else:
        raise ValueError("Unsupported data type for CSV download")
###############################################################################
# 12. EXAMPLE DATA
###############################################################################
def load_example_fasta():
    """Load the example.fasta file contents"""
    try:
        with open('example.fasta', 'r') as f:
            example_text = f.read()
        return example_text
    except Exception as e:
        return f">example_sequence\nACGTACGT...\n\n(Note: Could not load example.fasta: {str(e)})"
###############################################################################
# 13. BUILD GRADIO INTERFACE
###############################################################################
css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.download-button {
    margin-top: 10px;
}
"""
with gr.Blocks(css=css) as iface:
    gr.Markdown("""
    # Virus Host Classifier
    **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
    **Step 2**: Explore subregions to see local feature influence, distribution, GC content, etc.
    **Step 3**: Analyze gene features and their contributions.
    **Step 4**: Compare sequences and analyze differences.
    **Color Scale**: Negative values = Blue, Zero = White, Positive values = Red.
    """)
    with gr.Tab("1) Full-Sequence Analysis"):
        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                text_input = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
                # Add example FASTA button in a row
                with gr.Row():
                    example_btn = gr.Button("Load Example FASTA", variant="secondary")
                top_k = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of top k-mers to display")
                win_size = gr.Slider(minimum=100, maximum=5000, value=500, step=100, label="Window size for 'most pushing' subregions")
                analyze_btn = gr.Button("Analyze Sequence", variant="primary")
            with gr.Column(scale=2):
                results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
                kmer_img = gr.Image(label="Top k-mer Importance")
                genome_img = gr.Image(label="Genome-wide Feature Importance Heatmap (Blue=neg, White=0, Red=pos)")
                # File components with the correct type parameter
                download_kmer_shap = gr.File(label="Download k-mer Importance Values (CSV)", visible=True, type="filepath")
                download_results = gr.File(label="Download Results", visible=True, elem_classes="download-button")
        seq_state = gr.State()
        header_state = gr.State()
        # Event handlers
        # Connect the example button
        example_btn.click(
            load_example_fasta,
            inputs=[],
            outputs=[text_input]
        )
        # Connect the analyze button
        analyze_btn.click(
            analyze_sequence,
            inputs=[file_input, top_k, text_input, win_size],
            outputs=[results_box, kmer_img, genome_img, seq_state, header_state, download_results, download_kmer_shap]
        )
    with gr.Tab("2) Subregion Exploration"):
        gr.Markdown("""
        **Subregion Analysis**
        Select start/end positions to view local feature importance, distribution, GC content, etc.
        The heatmap uses the same Blue-White-Red scale.
        """)
        with gr.Row():
            region_start = gr.Number(label="Region Start", value=0)
            region_end = gr.Number(label="Region End", value=500)
        region_btn = gr.Button("Analyze Subregion")
        subregion_info = gr.Textbox(label="Subregion Analysis", lines=7, interactive=False)
        with gr.Row():
            subregion_img = gr.Image(label="Subregion Feature Importance Heatmap (B-W-R)")
            subregion_hist_img = gr.Image(label="Feature Importance Distribution (Histogram)")
        download_subregion = gr.File(label="Download Subregion Analysis", visible=False, elem_classes="download-button")
        region_btn.click(
            analyze_subregion,
            inputs=[seq_state, header_state, region_start, region_end],
            outputs=[subregion_info, subregion_img, subregion_hist_img, download_subregion]
        )
    with gr.Tab("3) Gene Features Analysis"):
        gr.Markdown("""
        **Analyze Gene Features**
        Upload a FASTA file and corresponding gene features file to analyze feature importance values per gene.
        Gene features should be in the format:
        >gene_name [gene=X] [locus_tag=Y] [location=start..end] or [location=complement(start..end)]
        SEQUENCE
        The genome viewer will show genes color-coded by their contribution:
        - Red: Genes pushing toward human origin
        - Blue: Genes pushing toward non-human origin
        - Color intensity indicates strength of signal
        """)
        with gr.Row():
            with gr.Column(scale=1):
                gene_fasta_file = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                gene_fasta_text = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
            with gr.Column(scale=1):
                features_file = gr.File(label="Upload gene features file", file_types=[".txt"], type="filepath")
                features_text = gr.Textbox(label="Or paste gene features", placeholder=">gene_1 [gene=U12]...\nACGT...", lines=5)
        analyze_genes_btn = gr.Button("Analyze Gene Features", variant="primary")
        gene_results = gr.Textbox(label="Gene Analysis Results", lines=12, interactive=False)
        gene_diagram = gr.Image(label="Genome Diagram with Gene Features")
        download_gene_results = gr.File(label="Download Gene Analysis (CSV)", visible=True, type="filepath")
        analyze_genes_btn.click(
            analyze_gene_features,
            inputs=[gene_fasta_file, features_file, gene_fasta_text, features_text],
            outputs=[gene_results, download_gene_results, gene_diagram]
        )
    with gr.Tab("4) Comparative Analysis"):
        gr.Markdown("""
        **Compare Two Sequences**
        Upload or paste two FASTA sequences to compare their feature importance patterns.
        The sequences will be normalized to the same length for comparison.
        **Color Scale**:
        - Red: Sequence 2 more human-like
        - Blue: Sequence 1 more human-like
        - White: No substantial difference
        """)
        with gr.Row():
            with gr.Column(scale=1):
                file_input1 = gr.File(label="Upload first FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                text_input1 = gr.Textbox(label="Or paste first FASTA sequence", placeholder=">sequence1\nACGTACGT...", lines=5)
            with gr.Column(scale=1):
                file_input2 = gr.File(label="Upload second FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                text_input2 = gr.Textbox(label="Or paste second FASTA sequence", placeholder=">sequence2\nACGTACGT...", lines=5)
        compare_btn = gr.Button("Compare Sequences", variant="primary")
        comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
        with gr.Row():
            diff_heatmap = gr.Image(label="Feature Importance Difference Heatmap")
            diff_hist = gr.Image(label="Distribution of Feature Importance Differences")
        download_comparison = gr.File(label="Download Comparison Results", visible=False, elem_classes="download-button")
        compare_btn.click(
            analyze_sequence_comparison,
            inputs=[file_input1, file_input2, text_input1, text_input2],
            outputs=[comparison_text, diff_heatmap, diff_hist, download_comparison]
        )
    gr.Markdown("""
    ### Interface Features
    - **Overall Classification** (human vs non-human) using k-mer frequencies
    - **Feature Importance Analysis** shows which k-mers push classification toward or away from human
    - **White-Centered Gradient**:
      - Negative (blue), 0 (white), Positive (red)
      - Symmetrical color range around 0
    - **Identify Subregions** with strongest push for human or non-human
    - **Gene Feature Analysis**:
      - Analyze individual genes' contributions
      - Interactive genome viewer
      - Gene-level statistics and classification
    - **Sequence Comparison**:
      - Compare two sequences to identify regions of difference
      - Normalized comparison to handle different lengths
      - Statistical summary of differences
    - **Data Export**:
      - Download results as CSV files
      - Download k-mer importance values
      - Save analysis outputs for further processing
    """)

if __name__ == "__main__":
    iface.launch()