Spaces:
Running
Running
import gradio as gr | |
import torch | |
import joblib | |
import numpy as np | |
from itertools import product | |
import torch.nn as nn | |
import matplotlib.pyplot as plt | |
import matplotlib.colors as mcolors | |
from matplotlib.colors import LinearSegmentedColormap | |
import io | |
from io import BytesIO # Import io then BytesIO | |
from PIL import Image, ImageDraw, ImageFont | |
from Bio.Graphics import GenomeDiagram | |
from Bio.SeqFeature import SeqFeature, FeatureLocation | |
from reportlab.lib import colors | |
import pandas as pd | |
import tempfile | |
import os | |
from typing import List, Dict, Tuple, Optional, Any | |
import seaborn as sns | |
############################################################################### | |
# 1. MODEL DEFINITION | |
############################################################################### | |
class VirusClassifier(nn.Module):
    """Feed-forward classifier that maps a feature vector to two class logits.

    The whole layer stack is stored in ``self.network``. Both the attribute
    name and the order of the layers determine the ``state_dict`` keys, so
    they have to stay stable for previously saved weights to load.
    """

    def __init__(self, input_shape: int):
        """Build the network for feature vectors of width ``input_shape``."""
        super(VirusClassifier, self).__init__()
        first_stage = [
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
        ]
        second_stage = [
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
        ]
        output_head = [
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2),
        ]
        self.network = nn.Sequential(*first_stage, *second_stage, *output_head)

    def forward(self, x):
        """Return the raw (pre-softmax) logits for the batch ``x``."""
        logits = self.network(x)
        return logits
############################################################################### | |
# 2. FASTA PARSING & K-MER FEATURE ENGINEERING | |
############################################################################### | |
def parse_fasta(text):
    """Parse FASTA-formatted text into a list of ``(header, sequence)`` tuples.

    Args:
        text: Raw FASTA content. Blank lines and surrounding whitespace
            (including ``\\r`` from Windows line endings) are ignored.

    Returns:
        One ``(header, sequence)`` tuple per record, in file order. The header
        excludes the leading ``>``; the sequence is the upper-cased
        concatenation of the record's lines. Text that appears before the
        first header line is discarded.
    """
    records = []
    header = None
    chunks = []

    for raw_line in text.strip().split('\n'):
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith('>'):
            # Compare against None rather than relying on truthiness: a bare
            # ">" line yields an empty-string header, and such a record must
            # still be emitted instead of being silently dropped.
            if header is not None:
                records.append((header, ''.join(chunks)))
            header = line[1:]
            chunks = []
        elif header is not None:
            chunks.append(line.upper())

    if header is not None:
        records.append((header, ''.join(chunks)))
    return records
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """Return the normalised k-mer frequency vector of ``sequence``.

    The vector has one float32 slot per k-mer over the alphabet ``ACGT``
    (``4**k`` slots, in ``itertools.product`` order). Windows containing any
    other character are not counted, but they still contribute to the
    denominator, which is the total number of length-``k`` windows.
    Sequences shorter than ``k`` give an all-zero vector.
    """
    slot_of = {
        ''.join(letters): slot
        for slot, letters in enumerate(product("ACGT", repeat=k))
    }
    window_count = len(sequence) - k + 1

    vec = np.zeros(len(slot_of), dtype=np.float32)
    for offset in range(window_count):
        slot = slot_of.get(sequence[offset:offset + k])
        if slot is not None:
            vec[slot] += 1

    if window_count > 0:
        vec /= window_count
    return vec
############################################################################### | |
# 3. SHAP-VALUE (ABLATION) CALCULATION | |
############################################################################### | |
def calculate_shap_values(model, x_tensor):
    """Estimate per-feature contributions by zeroing one feature at a time.

    Despite the name, this is a single-feature ablation rather than a true
    Shapley computation: for each feature of the first row of ``x_tensor``
    the value is set to zero, the model is re-evaluated, and the drop in the
    class-1 ('human') softmax probability is recorded.

    Args:
        model: Callable torch module returning logits of shape (batch, classes).
        x_tensor: 2-D input tensor; only row 0 is analysed. It is not modified.

    Returns:
        ``(contributions, baseline_prob)`` where ``contributions`` is a NumPy
        array with one entry per feature (``baseline - ablated``) and
        ``baseline_prob`` is the unablated class-1 probability.
    """
    model.eval()

    def human_probability(batch):
        # Probability assigned to class index 1 for the first row.
        return torch.softmax(model(batch), dim=1)[0, 1].item()

    with torch.no_grad():
        baseline_prob = human_probability(x_tensor)

        probe = x_tensor.clone()  # work on a copy so the caller's tensor is untouched
        contributions = []
        for feature_idx in range(x_tensor.shape[1]):
            saved = probe[0, feature_idx].item()
            probe[0, feature_idx] = 0.0
            contributions.append(baseline_prob - human_probability(probe))
            probe[0, feature_idx] = saved  # restore before the next ablation

    return np.array(contributions), baseline_prob
############################################################################### | |
# 4. PER-BASE SHAP AGGREGATION | |
############################################################################### | |
def compute_positionwise_scores(sequence, shap_values, k=4):
    """Spread per-k-mer scores back onto individual sequence positions.

    Every length-``k`` window whose k-mer consists only of ``ACGT`` adds its
    score (looked up in ``shap_values`` by ``itertools.product`` order) to
    each of the ``k`` positions it covers. The result for a position is the
    mean over all scored windows covering it, or 0.0 if none do.

    Returns:
        Array with one value per character of ``sequence``.
    """
    slot_of = {
        ''.join(letters): slot
        for slot, letters in enumerate(product("ACGT", repeat=k))
    }
    seq_len = len(sequence)
    totals = np.zeros(seq_len, dtype=np.float32)
    hits = np.zeros(seq_len, dtype=np.float32)

    for left in range(seq_len - k + 1):
        right = left + k
        slot = slot_of.get(sequence[left:right])
        if slot is None:
            continue
        totals[left:right] += shap_values[slot]
        hits[left:right] += 1

    # Uncovered positions would divide by zero; np.where masks them to 0.0.
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(hits > 0, totals / hits, 0.0)
############################################################################### | |
# 5. FIND EXTREME SHAP REGIONS | |
############################################################################### | |
def find_extreme_subregion(shap_means, window_size=500, mode="max"):
    """Locate the fixed-width window with the highest or lowest mean score.

    Args:
        shap_means: 1-D sequence of per-position scores.
        window_size: Window width in positions. Coerced to ``int`` (UI
            widgets may deliver floats) and clamped to at least 1.
        mode: ``"max"`` for the highest-mean window, ``"min"`` for the
            lowest. Any other value returns the first window.

    Returns:
        ``(start, end, mean)`` with ``end`` exclusive. For empty input the
        result is ``(0, 0, 0.0)``; if the window is at least as long as the
        input, the whole input is returned with its overall mean. On ties the
        earliest window wins.
    """
    scores = np.asarray(shap_means)
    n = len(scores)
    if n == 0:
        return (0, 0, 0.0)

    window_size = max(1, int(window_size))
    if window_size >= n:
        return (0, n, float(np.mean(scores)))

    # Prefix sums are accumulated in float64: with float32 the running total
    # of a long track loses the low-order bits that distinguish neighbouring
    # windows, which can select the wrong region.
    prefix = np.concatenate(([0.0], np.cumsum(scores, dtype=np.float64)))
    window_avgs = (prefix[window_size:] - prefix[:-window_size]) / window_size

    if mode == "max":
        best_start = int(np.argmax(window_avgs))  # argmax/argmin keep the first tie
    elif mode == "min":
        best_start = int(np.argmin(window_avgs))
    else:
        best_start = 0

    return (best_start, best_start + window_size, float(window_avgs[best_start]))
############################################################################### | |
# 6. PLOTTING / UTILITIES | |
############################################################################### | |
def fig_to_image(fig):
    """Render a matplotlib figure to a PIL image and close the figure.

    The figure is written as a tightly cropped 150-dpi PNG into an in-memory
    buffer, which is then opened with PIL.
    """
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png', bbox_inches='tight', dpi=150)
    plt.close(fig)  # the pixels are already in the buffer; release the figure
    png_buffer.seek(0)
    return Image.open(png_buffer)
def get_zero_centered_cmap():
    """Return a blue -> white -> red colormap with white at the midpoint."""
    anchor_points = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
    return mcolors.LinearSegmentedColormap.from_list("blue_white_red", anchor_points)
def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
    """Draw a one-row heatmap of per-position scores, centred on zero.

    Args:
        shap_means: 1-D array of per-position scores.
        title: Plot title; the plotted range is appended for sub-regions.
        start, end: Optional slice bounds. The slice is applied only when
            both are given; otherwise the whole track is plotted.

    Returns:
        The matplotlib figure (not yet rendered or closed).
    """
    if start is not None and end is not None:
        local_shap = shap_means[start:end]
        subtitle = f" (positions {start}-{end})"
        x_origin = start
    else:
        local_shap = shap_means
        subtitle = ""
        x_origin = 0

    local_shap = np.asarray(local_shap)
    if len(local_shap) == 0:
        local_shap = np.array([0.0])
    heatmap_data = local_shap.reshape(1, -1)

    # Symmetric colour limits keep zero on the white midpoint of the map.
    color_limit = max(abs(np.min(local_shap)), abs(np.max(local_shap)))
    if not color_limit > 0:
        # All-zero (or NaN) data would give vmin == vmax, which matplotlib
        # normalises to 0.0 -- i.e. solid blue instead of neutral white.
        color_limit = 1.0

    cmap = get_zero_centered_cmap()
    fig, ax = plt.subplots(figsize=(12, 1.8))
    cax = ax.imshow(
        heatmap_data,
        aspect='auto',
        cmap=cmap,
        vmin=-color_limit,
        vmax=color_limit,
        # Label the x-axis with real sequence positions, not slice offsets.
        extent=(x_origin, x_origin + heatmap_data.shape[1], 0, 1),
    )

    cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
    cbar.ax.tick_params(labelsize=8)
    cbar.set_label('SHAP Contribution', fontsize=9, labelpad=5)

    ax.set_yticks([])
    ax.set_xlabel('Position in Sequence', fontsize=10)
    ax.set_title(f"{title}{subtitle}", pad=10)
    plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
    return fig
def create_importance_bar_plot(shap_values, kmers, top_k=10):
    """Plot a horizontal bar chart of the ``top_k`` highest-|score| k-mers.

    Args:
        shap_values: 1-D array of per-k-mer scores.
        kmers: Sequence of k-mer labels, index-aligned with ``shap_values``.
        top_k: Number of bars. Coerced to ``int`` (UI sliders may deliver
            floats); values <= 0 give an empty chart.

    Returns:
        The matplotlib figure. Bars are ordered by ascending |score| and the
        y-axis is inverted, so the most influential k-mer is the bottom bar.
        Negative scores are drawn light blue, non-negative ones light red.

    Note:
        Updates the global matplotlib ``font.size`` rcParam as a side effect.
    """
    plt.rcParams.update({'font.size': 10})
    shap_values = np.asarray(shap_values)
    top_k = max(0, int(top_k))

    ranked = np.argsort(np.abs(shap_values))
    # Slice from an explicit start index: ``ranked[-top_k:]`` would return
    # *everything* for top_k == 0 because ``-0`` is just ``0``.
    indices = ranked[len(ranked) - min(top_k, len(ranked)):]
    values = shap_values[indices]
    features = [kmers[i] for i in indices]
    bar_colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]

    fig = plt.figure(figsize=(10, 5))
    bar_positions = range(len(values))
    plt.barh(bar_positions, values, color=bar_colors)
    plt.yticks(bar_positions, features)
    plt.xlabel('SHAP Value (impact on model output)')
    plt.title(f'Top {top_k} Most Influential k-mers')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    return fig
def plot_shap_histogram(shap_array, title="SHAP Distribution in Region", num_bins=30):
    """Return a histogram figure of ``shap_array`` with a dashed marker at zero."""
    fig, axis = plt.subplots(figsize=(6, 4))
    axis.hist(shap_array, bins=num_bins, color='gray', edgecolor='black')
    axis.axvline(0, color='red', linestyle='--', label='0.0')
    axis.set(xlabel="SHAP Value", ylabel="Count", title=title)
    axis.legend()
    fig.tight_layout()
    return fig
def compute_gc_content(sequence):
    """Return the percentage of ``G``/``C`` characters in ``sequence``.

    Matching is case-sensitive (upper-case only). An empty sequence yields 0.
    """
    if len(sequence) == 0:
        return 0
    gc_total = sum(map(sequence.count, ('G', 'C')))
    fraction = gc_total / len(sequence)
    return fraction * 100.0
############################################################################### | |
# 7. MAIN ANALYSIS STEP (Gradio Step 1) | |
############################################################################### | |
def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
    """Classify the first FASTA record and explain the prediction per base.

    Args:
        file_obj: Path of a FASTA file; used only when ``fasta_text`` is empty.
        top_kmers: Number of k-mers shown in the importance bar chart.
        fasta_text: FASTA content pasted directly; takes precedence over
            ``file_obj``. ``None`` is treated as empty.
        window_size: Width (in bases) of the sliding window used to find the
            most human-pushing / non-human-pushing sub-regions.

    Returns:
        A 6-tuple ``(results_text, bar_img, heatmap_img, state, header, None)``.
        ``state`` is ``{"seq": ..., "shap_means": ...}`` for the sub-region
        step. On failure the first element is a message and the remaining
        five are ``None``.
    """
    def failure(message):
        return (message, None, None, None, None, None)

    # --- Resolve the input text -------------------------------------------
    pasted = (fasta_text or "").strip()  # the text box may deliver None
    if pasted:
        text = pasted
    elif file_obj is not None:
        try:
            with open(file_obj, 'r') as f:
                text = f.read()
        except Exception as e:
            return failure(f"Error reading file: {str(e)}")
    else:
        return failure("Please provide a FASTA sequence.")

    sequences = parse_fasta(text)
    if not sequences:
        return failure("No valid FASTA sequences found.")
    header, seq = sequences[0]  # only the first record is analysed

    # UI sliders/number boxes may deliver floats; slicing and array indexing
    # further down require real integers.
    top_kmers = int(top_kmers)
    window_size = int(window_size)

    # --- Load model and scaler --------------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        # IMPORTANT: adjust how you load your model as needed
        # NOTE(review): torch.load and joblib.load both unpickle; only point
        # them at trusted files.
        state_dict = torch.load('model.pt', map_location=device)
        model = VirusClassifier(256).to(device)  # 256 == 4**4 k-mer features
        model.load_state_dict(state_dict)
        scaler = joblib.load('scaler.pkl')
    except Exception as e:
        return failure(f"Error loading model/scaler: {str(e)}")

    # --- Predict and attribute --------------------------------------------
    freq_vector = sequence_to_kmer_vector(seq)
    scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
    x_tensor = torch.FloatTensor(scaled_vector).to(device)

    shap_values, prob_human = calculate_shap_values(model, x_tensor)
    prob_nonhuman = 1.0 - prob_human
    classification = "Human" if prob_human > 0.5 else "Non-human"
    confidence = max(prob_human, prob_nonhuman)

    shap_means = compute_positionwise_scores(seq, shap_values, k=4)
    max_start, max_end, max_avg = find_extreme_subregion(shap_means, window_size, mode="max")
    min_start, min_end, min_avg = find_extreme_subregion(shap_means, window_size, mode="min")

    results_text = (
        f"Sequence: {header}\n"
        f"Length: {len(seq):,} bases\n"
        f"Classification: {classification}\n"
        f"Confidence: {confidence:.3f}\n"
        f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
        f"---\n"
        f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
        f"Start: {max_start}, End: {max_end}, Avg SHAP: {max_avg:.4f}\n\n"
        f"**Most Non-Human–Pushing {window_size}-bp Subregion**:\n"
        f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
    )

    # --- Visualisations ----------------------------------------------------
    kmers = [''.join(p) for p in product("ACGT", repeat=4)]
    bar_img = fig_to_image(create_importance_bar_plot(shap_values, kmers, top_kmers))
    heatmap_img = fig_to_image(plot_linear_heatmap(shap_means, title="Genome-wide SHAP"))

    # The sixth slot is reserved for a downloadable file (e.g. a CSV); none is
    # produced at the moment, so it is always None.
    state_dict_out = {"seq": seq, "shap_means": shap_means}
    return (results_text, bar_img, heatmap_img, state_dict_out, header, None)
############################################################################### | |
# 8. SUBREGION ANALYSIS (Gradio Step 2) | |
############################################################################### | |
def analyze_subregion(state, header, region_start, region_end):
    """Summarise and plot the per-base scores of one sub-region.

    Args:
        state: Dict produced by the whole-sequence step; must contain
            ``"seq"`` and ``"shap_means"``.
        header: Sequence name, used only in the report text.
        region_start, region_end: Requested bounds (end exclusive). They are
            converted to ``int`` and clamped to ``[0, len(seq)]``.

    Returns:
        ``(region_info, heatmap_img, hist_img, None)``; on invalid input the
        first element is a message and the others are ``None``.
    """
    if not state or "seq" not in state or "shap_means" not in state:
        return ("No sequence data found. Please run Step 1 first.", None, None, None)

    seq = state["seq"]
    shap_means = state["shap_means"]

    # An emptied number box delivers None (and NaN/inf cannot be converted),
    # so report bad bounds instead of letting int() raise.
    try:
        region_start = int(region_start)
        region_end = int(region_end)
    except (TypeError, ValueError, OverflowError):
        return ("Invalid region range. Start and End must be numbers.", None, None, None)

    region_start = max(0, min(region_start, len(seq)))
    region_end = max(0, min(region_end, len(seq)))
    if region_end <= region_start:
        return ("Invalid region range. End must be > Start.", None, None, None)

    region_seq = seq[region_start:region_end]
    region_shap = shap_means[region_start:region_end]

    gc_percent = compute_gc_content(region_seq)
    avg_shap = float(np.mean(region_shap))
    positive_fraction = np.mean(region_shap > 0)
    negative_fraction = np.mean(region_shap < 0)

    # +/-0.05 is the cut-off between a "push" and "near neutral".
    if avg_shap > 0.05:
        region_classification = "Likely pushing toward human"
    elif avg_shap < -0.05:
        region_classification = "Likely pushing toward non-human"
    else:
        region_classification = "Near neutral (no strong push)"

    region_info = (
        f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
        f"Region length: {len(region_seq)} bases\n"
        f"GC content: {gc_percent:.2f}%\n"
        f"Average SHAP in region: {avg_shap:.4f}\n"
        f"Fraction with SHAP > 0 (toward human): {positive_fraction:.2f}\n"
        f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
        f"Subregion interpretation: {region_classification}\n"
    )

    heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion SHAP", start=region_start, end=region_end)
    heatmap_img = fig_to_image(heatmap_fig)
    hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
    hist_img = fig_to_image(hist_fig)

    # The fourth slot is reserved for a downloadable file; none is produced.
    return (region_info, heatmap_img, hist_img, None)
############################################################################### | |
# 9. COMPARISON ANALYSIS FUNCTIONS | |
############################################################################### | |
def get_zero_centered_cmap():
    """Build the diverging blue/white/red colormap (white at the centre).

    This is a second definition of the same name in this file; being the
    later one, it is the definition in effect after the module is loaded.
    """
    stops = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
    cmap = mcolors.LinearSegmentedColormap.from_list("blue_white_red", stops)
    return cmap
def compute_shap_difference(shap1_norm, shap2_norm):
    """Return ``shap2_norm - shap1_norm`` (positive where sequence 2 scores higher)."""
    difference = shap2_norm - shap1_norm
    return difference
def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
    """Plot a one-row heatmap of score differences over relative position.

    Args:
        shap_diff: 1-D array of differences (sequence 2 minus sequence 1),
            sampled on a common relative-position grid.
        title: Plot title.

    Returns:
        The matplotlib figure; the x-axis is labelled 0%-100%.
    """
    shap_diff = np.asarray(shap_diff)
    if shap_diff.size == 0:
        shap_diff = np.array([0.0])  # nothing to compare: draw a neutral cell
    heatmap_data = shap_diff.reshape(1, -1)

    color_limit = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
    if not color_limit > 0:
        # Identical inputs give an all-zero difference; vmin == vmax would be
        # normalised to 0.0 and painted solid blue ("sequence 1 more
        # human-like") instead of neutral white.
        color_limit = 1.0

    fig, ax = plt.subplots(figsize=(12, 1.8))
    cmap = get_zero_centered_cmap()
    cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-color_limit, vmax=color_limit)

    # Percentage-based x-axis ticks spanning the first to the last sample.
    num_ticks = 5
    tick_positions = np.linspace(0, shap_diff.shape[0] - 1, num_ticks)
    tick_labels = [f"{int(x*100)}%" for x in np.linspace(0, 1, num_ticks)]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)

    cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
    cbar.ax.tick_params(labelsize=8)
    cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)

    ax.set_yticks([])
    ax.set_xlabel('Relative Position in Sequence', fontsize=10)
    ax.set_title(title, pad=10)
    plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
    return fig
def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
    """
    Build a histogram figure of the given values.

    Bars are semi-transparent gray with black edges, and a dashed red line
    marks zero. ``num_bins`` is passed straight to ``Axes.hist``.
    """
    bar_style = dict(color='gray', edgecolor='black', alpha=0.7)
    zero_marker_style = dict(color='red', linestyle='--', label='0.0')

    fig, axis = plt.subplots(figsize=(6, 4))
    axis.hist(shap_array, bins=num_bins, **bar_style)
    axis.axvline(0, **zero_marker_style)
    axis.set_xlabel("SHAP Value")
    axis.set_ylabel("Count")
    axis.set_title(title)
    axis.legend()
    fig.tight_layout()
    return fig
def calculate_adaptive_parameters(len1, len2):
    """
    Derive comparison parameters from two sequence lengths.

    The absolute length difference selects one of four tiers (< 500,
    < 5,000, < 50,000, otherwise); smaller differences get more output
    points and a narrower smoothing window. The window is then widened in
    proportion to how dissimilar the lengths are.

    Returns: (num_points, smooth_window, resolution_factor) with the first
    two as ints.
    """
    length_diff = abs(len1 - len2)
    max_length = max(len1, len2)
    min_length = min(len1, len2)
    # Two empty sequences are "equally long": avoid 0/0 and use ratio 1.0.
    length_ratio = min_length / max_length if max_length > 0 else 1.0

    # Base number of points scales with sequence length (500..2000).
    base_points = min(2000, max(500, max_length // 100))

    if length_diff < 500:
        resolution_factor = 2.0
        num_points = min(3000, base_points * 2)
        smooth_window = max(10, length_diff // 50)
    elif length_diff < 5000:
        resolution_factor = 1.5
        num_points = min(2000, base_points * 1.5)
        smooth_window = max(20, length_diff // 100)
    elif length_diff < 50000:
        resolution_factor = 1.0
        num_points = base_points
        smooth_window = max(50, length_diff // 200)
    else:
        resolution_factor = 0.75
        num_points = max(500, base_points // 2)
        smooth_window = max(100, length_diff // 500)

    # Widen the window by up to 2x as the length ratio drops towards 0.
    smooth_window = int(smooth_window * (1 + (1 - length_ratio)))
    return int(num_points), int(smooth_window), resolution_factor
def sliding_window_smooth(values, window_size=50):
    """
    Smooth a 1-D signal with a tapered moving-average window.

    Args:
        values: 1-D array-like signal.
        window_size: Kernel width in samples. It is clamped to the signal
            length; if the effective width is below 3 the input is returned
            unchanged.

    Returns:
        Array of the same length as ``values``. Interior samples are the
        weighted window average; samples too close to either end to be
        covered by a full window keep their original values.
    """
    values = np.asarray(values)
    n = len(values)
    # A window longer than the signal makes 'valid' convolution return more
    # samples than the input has, which breaks the edge bookkeeping below.
    window_size = min(int(window_size), n)
    if window_size < 3:
        return values

    # Flat-topped kernel whose weights fall off exponentially towards both
    # ends (1.0 next to the centre down to e**-3 at the outermost samples).
    half = window_size // 2
    taper = np.exp(-np.linspace(0, 3, half))
    window = np.ones(window_size)
    window[:half] = taper[::-1]
    window[-half:] = taper
    window /= window.sum()

    smoothed = np.convolve(values, window, mode='valid')

    # Start from a copy so the uncovered edges keep their raw values, then
    # drop the smoothed interior into place.
    pad_left = (n - len(smoothed)) // 2
    result = values.astype(np.result_type(values.dtype, np.float32))
    result[pad_left:pad_left + len(smoothed)] = smoothed
    return result
def normalize_shap_lengths(shap1, shap2):
    """
    Bring two score tracks of different lengths onto one common grid.

    Each track is smoothed with the adaptively chosen window and then
    linearly interpolated at the same evenly spaced relative positions
    (0..1), so the results can be compared element-wise.

    Returns: (track1_resampled, track2_resampled, smooth_window)
    """
    num_points, smooth_window, _resolution = calculate_adaptive_parameters(len(shap1), len(shap2))
    common_grid = np.linspace(0, 1, num_points)

    def resample(track):
        smoothed = sliding_window_smooth(track, smooth_window)
        relative_positions = np.linspace(0, 1, len(smoothed))
        return np.interp(common_grid, relative_positions, smoothed)

    return resample(shap1), resample(shap2), smooth_window
def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
    """
    Compare the per-base score tracks of two sequences.

    Both inputs are run through ``analyze_sequence``; their tracks are
    smoothed, resampled to a common relative-position grid and subtracted
    (sequence 2 minus sequence 1).

    Args:
        file1, file2: FASTA file paths (used when the matching text is empty).
        fasta1, fasta2: FASTA content given directly.

    Returns:
        ``(comparison_text, heatmap_img, hist_img, None)``. On any failure
        the first element is an error message and the others are ``None``.
    """
    def failure(message):
        return (message, None, None, None)

    try:
        # Analyze both sequences. A failed analysis has no state dict in
        # slot 3; checking that (rather than searching the message for
        # "Error") also catches failures such as "Please provide a FASTA
        # sequence." whose text does not contain that word.
        res1 = analyze_sequence(file1, top_kmers=10, fasta_text=fasta1, window_size=500)
        if res1[3] is None:
            return failure(f"Error in sequence 1: {res1[0]}")
        res2 = analyze_sequence(file2, top_kmers=10, fasta_text=fasta2, window_size=500)
        if res2[3] is None:
            return failure(f"Error in sequence 2: {res2[0]}")

        shap1 = res1[3]["shap_means"]
        shap2 = res2[3]["shap_means"]

        len1, len2 = len(shap1), len(shap2)
        if len1 == 0 or len2 == 0:
            # A header-only record has nothing to resample or compare.
            return failure("Error during sequence comparison: both sequences must contain at least one base")
        length_diff = abs(len1 - len2)
        length_ratio = min(len1, len2) / max(len1, len2)

        # Normalize and compare sequences
        shap1_norm, shap2_norm, smooth_window = normalize_shap_lengths(shap1, shap2)
        shap_diff = compute_shap_difference(shap1_norm, shap2_norm)

        # The "substantial difference" threshold grows as the lengths diverge.
        base_threshold = 0.05
        adaptive_threshold = base_threshold * (1 + (1 - length_ratio))
        if length_diff > 50000:
            adaptive_threshold *= 1.5

        avg_diff = np.mean(shap_diff)
        std_diff = np.std(shap_diff)
        max_diff = np.max(shap_diff)
        min_diff = np.min(shap_diff)
        frac_different = np.mean(np.abs(shap_diff) > adaptive_threshold)

        # The classification is recovered from the report text of each run.
        try:
            classification1 = res1[0].split('Classification: ')[1].split('\n')[0].strip()
            classification2 = res2[0].split('Classification: ')[1].split('\n')[0].strip()
        except (IndexError, AttributeError):
            classification1 = "Unknown"
            classification2 = "Unknown"

        comparison_text = (
            "Sequence Comparison Results:\n"
            f"Sequence 1: {res1[4]}\n"
            f"Length: {len1:,} bases\n"
            f"Classification: {classification1}\n\n"
            f"Sequence 2: {res2[4]}\n"
            f"Length: {len2:,} bases\n"
            f"Classification: {classification2}\n\n"
            "Comparison Parameters:\n"
            f"Length Difference: {length_diff:,} bases\n"
            f"Length Ratio: {length_ratio:.3f}\n"
            f"Smoothing Window: {smooth_window} points\n"
            f"Adaptive Threshold: {adaptive_threshold:.3f}\n\n"
            "Statistics:\n"
            f"Average SHAP difference: {avg_diff:.4f}\n"
            f"Standard deviation: {std_diff:.4f}\n"
            f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
            f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
            f"Fraction with substantial differences: {frac_different:.2%}\n\n"
            "Note: All parameters automatically adjusted based on sequence properties\n\n"
            "Interpretation:\n"
            "- Red regions: Sequence 2 more human-like\n"
            "- Blue regions: Sequence 1 more human-like\n"
            "- White regions: Similar between sequences"
        )

        # Generate visualizations
        heatmap_fig = plot_comparative_heatmap(
            shap_diff,
            title=f"SHAP Difference Heatmap (window: {smooth_window})"
        )
        heatmap_img = fig_to_image(heatmap_fig)

        # Square-root rule for the bin count, kept within 20..50.
        num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
        hist_fig = plot_shap_histogram(
            shap_diff,
            title="Distribution of SHAP Differences",
            num_bins=num_bins
        )
        hist_img = fig_to_image(hist_fig)

        # Fourth slot is reserved for a downloadable file; none is produced.
        return (comparison_text, heatmap_img, hist_img, None)
    except Exception as e:
        # Top-level UI boundary: surface the problem as text instead of raising.
        return failure(f"Error during sequence comparison: {str(e)}")
############################################################################### | |
# 11. GENE FEATURE ANALYSIS | |
############################################################################### | |
import io | |
from io import BytesIO | |
from PIL import Image, ImageDraw, ImageFont | |
import numpy as np | |
import pandas as pd | |
import tempfile | |
import os | |
from typing import List, Dict, Tuple, Optional, Any | |
import matplotlib.pyplot as plt | |
from matplotlib.colors import LinearSegmentedColormap | |
import seaborn as sns | |
def parse_gene_features(text: str) -> List[Dict[str, Any]]:
    """Parse FASTA-like gene feature text into a list of record dicts.

    Each record has the keys ``'header'`` (header line without ``>``),
    ``'sequence'`` (upper-cased, concatenated sequence lines) and
    ``'metadata'`` (the result of ``parse_gene_metadata`` on the header).
    Records whose header is empty are not emitted.
    """
    def build_record(header, chunks):
        return {
            'header': header,
            'sequence': ''.join(chunks),
            'metadata': parse_gene_metadata(header)
        }

    records = []
    header = None
    chunks = []
    for line in map(str.strip, text.strip().split('\n')):
        if not line:
            continue
        if not line.startswith('>'):
            chunks.append(line.upper())
            continue
        # A new header closes the record collected so far.
        if header:
            records.append(build_record(header, chunks))
        header, chunks = line[1:], []

    if header:
        records.append(build_record(header, chunks))
    return records
def parse_gene_metadata(header: str) -> Dict[str, str]:
    """Extract ``[key=value]`` pairs from a gene header line.

    Every bracketed ``key=value`` group in ``header`` becomes one dict entry.
    The split happens at the first ``=``, so values may contain ``=`` as well
    as spaces (e.g. ``[protein=spike glycoprotein]``). Bracket groups without
    ``=`` and text outside brackets are ignored. If a key occurs more than
    once, the last occurrence wins.
    """
    import re

    # Matching whole bracket groups (instead of splitting the header on
    # whitespace) keeps values with spaces intact and copes with groups that
    # are not separated by a space.
    pairs = re.findall(r'\[([^\[\]=]+)=([^\[\]]*)\]', header)
    return {key: value for key, value in pairs}
def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
    """Parse a feature location string into ``(start, end)`` coordinates.

    Handles plain ranges (``266..21555``), ``complement(...)`` wrappers,
    partial-end markers (``<1..>100``) and multi-segment ``join(...)`` /
    ``order(...)`` locations. For multi-segment locations the overall span is
    returned: the smallest segment start and the largest segment end.

    The numbers are returned exactly as written in the string; no coordinate
    conversion is applied.

    Returns:
        ``(start, end)``, or ``(None, None)`` if the string contains no
        ``..`` range or cannot be parsed (a message is printed in the
        latter case).
    """
    try:
        cleaned = location_str
        for wrapper in ('complement(', 'join(', 'order(', ')', '<', '>'):
            cleaned = cleaned.replace(wrapper, '')

        if '..' not in cleaned:
            return None, None

        starts, ends = [], []
        for segment in cleaned.split(','):
            bounds = segment.split('..')
            if len(bounds) == 1:
                # Single-base segment inside a join, e.g. join(1..5,7).
                low = high = int(bounds[0])
            elif len(bounds) == 2:
                low, high = int(bounds[0]), int(bounds[1])
            else:
                raise ValueError(f"malformed range segment {segment!r}")
            starts.append(low)
            ends.append(high)
        return min(starts), max(ends)
    except (AttributeError, TypeError, ValueError) as e:
        print(f"Error parsing location {location_str}: {str(e)}")
        return None, None
def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
    """Summarise a gene's per-base scores as plain Python floats.

    Returns a dict with mean, median, standard deviation, maximum, minimum
    and the fraction of strictly positive values.
    """
    reducers = (
        ('avg_shap', np.mean),
        ('median_shap', np.median),
        ('std_shap', np.std),
        ('max_shap', np.max),
        ('min_shap', np.min),
    )
    stats = {name: float(reducer(gene_shap)) for name, reducer in reducers}
    stats['pos_fraction'] = float(np.mean(gene_shap > 0))
    return stats
def _load_diagram_fonts():
    """Return ``(label_font, title_font)``, using PIL's default font as fallback."""
    try:
        return (
            ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12),
            ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16),
        )
    except (OSError, ImportError):
        # Font file missing, or Pillow built without FreeType support.
        fallback = ImageFont.load_default()
        return fallback, fallback


def _shap_to_color(avg_shap):
    """Map an average score to an RGB colour (red if > 0, otherwise blue)."""
    # Intensity is clamped to 50..255 so small scores never render plain white.
    intensity = max(50, min(255, int(abs(avg_shap) * 500)))
    if avg_shap > 0:
        return (255, 255 - intensity, 255 - intensity)
    return (255 - intensity, 255 - intensity, 255)


def _draw_legend(draw, font, legend_x, legend_y):
    """Draw the four-entry colour legend anchored at ``(legend_x, legend_y)``."""
    box_width = 20
    box_height = 20
    spacing = 15
    # (x offset, distance above legend_y, fill colour, caption)
    entries = [
        (0, 45, (255, 0, 0), "Strong human-like signal"),
        (0, 20, (255, 200, 200), "Weak human-like signal"),
        (250, 45, (200, 200, 255), "Weak non-human-like signal"),
        (250, 20, (0, 0, 255), "Strong non-human-like signal"),
    ]
    draw.text((int(legend_x), int(legend_y - 60)), "SHAP Values:", fill='black', font=font)
    for x_offset, rise, fill, caption in entries:
        left = int(legend_x + x_offset)
        top = int(legend_y - rise)
        draw.rectangle([(left, top), (left + box_width, top + box_height)],
                       fill=fill, outline='black')
        draw.text((left + box_width + spacing, top), caption, fill='black', font=font)


def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
    """
    Create a simple genome diagram using PIL, forcing a minimum color intensity
    so that small SHAP values don't appear white.

    Args:
        gene_results: One dict per gene with at least ``'start'``, ``'end'``
            and ``'avg_shap'``; ``'gene_name'`` is used as the label when
            present. NOTE: ``'start'``/``'end'`` are clamped to
            ``[0, genome_length]`` in place, i.e. the caller's dicts are
            modified.
        genome_length: Genome length in bases; defines the horizontal scale.

    Returns:
        A 1500x600 RGB image, or a small placeholder image with an error
        message when the inputs are empty or invalid.
    """
    # Validate inputs
    if not gene_results or genome_length <= 0:
        img = Image.new('RGB', (800, 100), color='white')
        draw = ImageDraw.Draw(img)
        draw.text((10, 40), "Error: Invalid input data", fill='black')
        return img

    # Clamp coordinates and keep only genes that can actually be drawn.
    # Zero-length or reversed ranges are reported and skipped: a rectangle
    # whose right edge lies left of its left edge is rejected by Pillow.
    drawable_genes = []
    for gene in gene_results:
        gene['start'] = max(0, int(gene['start']))
        gene['end'] = min(genome_length, int(gene['end']))
        if gene['start'] >= gene['end']:
            print(f"Warning: Invalid coordinates for gene {gene.get('gene_name','?')}: {gene['start']}-{gene['end']}")
            continue
        drawable_genes.append(gene)

    # Image dimensions
    width = 1500
    height = 600
    margin = 50
    track_height = 40

    img = Image.new('RGB', (width, height), 'white')
    draw = ImageDraw.Draw(img)
    font, title_font = _load_diagram_fonts()

    # Title and genome backbone
    draw.text((margin, margin // 2), "Genome SHAP Analysis", fill='black', font=title_font or font)
    line_y = height // 2
    draw.line([(margin, line_y), (width - margin, line_y)], fill='black', width=2)

    # Horizontal pixels per base
    scale = float(width - 2 * margin) / float(genome_length)

    # Roughly ten scale markers along the backbone
    num_ticks = 10
    step = 1 if genome_length < num_ticks else genome_length // num_ticks
    for position in range(0, genome_length + 1, step):
        x_coord = int(margin + position * scale)
        draw.line([(x_coord, line_y - 5), (x_coord, line_y + 5)], fill='black', width=1)
        draw.text((x_coord - 20, line_y + 10), f"{position:,}", fill='black', font=font)

    # Draw weak signals first so the strongest genes end up on top.
    sorted_genes = sorted(drawable_genes, key=lambda g: abs(g['avg_shap']))
    for idx, gene in enumerate(sorted_genes):
        start_x = margin + int(gene['start'] * scale)
        end_x = margin + int(gene['end'] * scale)

        draw.rectangle([
            (start_x, line_y - track_height // 2),
            (end_x, line_y + track_height // 2)
        ], fill=_shap_to_color(gene['avg_shap']), outline='black')

        label = str(gene.get('gene_name','?'))
        label_width, label_height = font.getmask(label).size

        # Alternate labels above/below the track to reduce overlap.
        if idx % 2 == 0:
            text_y = line_y - track_height - 15
        else:
            text_y = line_y + track_height + 5

        gene_width = end_x - start_x
        if gene_width > label_width:
            # Enough room: centre the label horizontally over the gene.
            text_x = start_x + (gene_width - label_width) // 2
            draw.text((int(text_x), int(text_y)), label, fill='black', font=font)
        elif gene_width > 20:
            # Narrow gene: render the label rotated by 90 degrees instead.
            txt_img = Image.new('RGBA', (label_width, label_height), (255, 255, 255, 0))
            txt_draw = ImageDraw.Draw(txt_img)
            txt_draw.text((0, 0), label, font=font, fill='black')
            rotated_img = txt_img.rotate(90, expand=True)
            img.paste(rotated_img, (int(start_x), int(text_y)), rotated_img)
        # Genes of 20 px or less get no label at all.

    _draw_legend(draw, font, margin, height - margin)
    return img
def analyze_gene_features(sequence_file: str,
                          features_file: str,
                          fasta_text: str = "",
                          features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
    """Analyze SHAP values for each gene feature.

    Runs ``analyze_sequence`` on the whole sequence, slices the resulting
    ``shap_means`` with every gene's ``start``/``end`` and summarises each
    slice with ``compute_gene_statistics``.

    Args:
        sequence_file: Path to a FASTA file, forwarded to ``analyze_sequence``.
        features_file: Path to a gene features file. Only read when
            ``features_text`` is blank.
        fasta_text: Pasted FASTA text, forwarded to ``analyze_sequence``.
        features_text: Pasted gene features. Takes precedence over
            ``features_file`` when it contains non-whitespace characters.

    Returns:
        A ``(results_text, csv_path, diagram)`` tuple. When the analysis
        cannot proceed, ``results_text`` is an error message and the other
        two items are ``None``. ``csv_path`` is ``None`` if the CSV could not
        be written; ``diagram`` is a placeholder image carrying the error
        text if the genome diagram could not be rendered.
    """
    # First analyze whole sequence
    sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
    if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
        return f"Error in sequence analysis: {sequence_results[0]}", None, None

    # Get SHAP values
    shap_means = sequence_results[3]["shap_means"]

    # Parse gene features (pasted text wins over the uploaded file)
    try:
        if (features_text or "").strip():
            genes = parse_gene_features(features_text)
        elif features_file:
            with open(features_file, 'r') as f:
                genes = parse_gene_features(f.read())
        else:
            return "Error reading features file: no gene features were provided", None, None
    except Exception as e:
        return f"Error reading features file: {str(e)}", None, None

    # Analyze each gene; a bad record is reported and skipped, never fatal
    gene_results = []
    for gene in genes:
        # Resolved before anything can fail so the handler below cannot raise
        gene_name = 'Unknown'
        try:
            metadata = gene['metadata']
            gene_name = metadata.get('gene', 'Unknown')
            location = metadata.get('location', '')
            if not location:
                continue
            start, end = parse_location(location)
            if start is None or end is None:
                continue

            # Get SHAP values for this region
            gene_shap = shap_means[start:end]
            if len(gene_shap) == 0:
                # An empty slice has no values to summarise
                print(f"Skipping gene {gene_name}: location {location} selects no SHAP values")
                continue

            stats = compute_gene_statistics(gene_shap)
            gene_results.append({
                'gene_name': gene_name,
                'location': location,
                'start': start,
                'end': end,
                'locus_tag': metadata.get('locus_tag', ''),
                'avg_shap': stats['avg_shap'],
                'median_shap': stats['median_shap'],
                'std_shap': stats['std_shap'],
                'max_shap': stats['max_shap'],
                'min_shap': stats['min_shap'],
                'pos_fraction': stats['pos_fraction'],
                'classification': 'Human' if stats['avg_shap'] > 0 else 'Non-human',
                'confidence': abs(stats['avg_shap'])
            })
        except Exception as e:
            print(f"Error processing gene {gene_name}: {str(e)}")
            continue

    if not gene_results:
        return "No valid genes could be processed", None, None

    results_text = _format_gene_summary(gene_results)
    temp_path = _write_gene_csv(gene_results)

    # Create visualization
    try:
        diagram_img = create_simple_genome_diagram(gene_results, len(shap_means))
    except Exception as e:
        print(f"Error creating visualization: {str(e)}")
        # Create error image so the caller still receives something to show
        diagram_img = Image.new('RGB', (800, 100), color='white')
        draw = ImageDraw.Draw(diagram_img)
        draw.text((10, 40), f"Error creating visualization: {str(e)}", fill='black')

    return results_text, temp_path, diagram_img


def _format_gene_summary(gene_results):
    """Build the results text: overall counts plus the 10 genes with largest |avg_shap|."""
    human_count = sum(1 for g in gene_results if g['classification'] == 'Human')
    non_human_count = sum(1 for g in gene_results if g['classification'] == 'Non-human')
    # Sort genes by absolute SHAP value
    sorted_genes = sorted(gene_results, key=lambda g: abs(g['avg_shap']), reverse=True)

    parts = [
        "Gene Analysis Results:\n\n",
        f"Total genes analyzed: {len(gene_results)}\n",
        f"Human-like genes: {human_count}\n",
        f"Non-human-like genes: {non_human_count}\n\n",
        "Top 10 most distinctive genes:\n",
    ]
    parts.extend(
        f"Gene: {gene['gene_name']}\n"
        f"Location: {gene['location']}\n"
        f"Classification: {gene['classification']} "
        f"(confidence: {gene['confidence']:.4f})\n"
        f"Average SHAP: {gene['avg_shap']:.4f}\n\n"
        for gene in sorted_genes[:10]
    )
    return "".join(parts)


def _write_gene_csv(gene_results):
    """Write the per-gene table to a new temp CSV; return its path, or None on failure."""
    import csv  # function-scope import: csv is not imported at module level

    float_fields = ('avg_shap', 'median_shap', 'std_shap',
                    'max_shap', 'min_shap', 'pos_fraction')
    header = ('gene_name', 'location', *float_fields,
              'classification', 'confidence', 'locus_tag')
    try:
        # newline='' lets the csv writer own line endings, as the csv docs require
        with tempfile.NamedTemporaryFile(mode='w', newline='', encoding='utf-8',
                                         prefix='gene_analysis_', suffix='.csv',
                                         delete=False) as handle:
            # csv.writer quotes fields containing commas or quotes
            # (e.g. a location like "join(1..5,10..20)") so columns stay aligned
            writer = csv.writer(handle, lineterminator='\n')
            writer.writerow(header)
            for gene in gene_results:
                writer.writerow([
                    gene['gene_name'],
                    gene['location'],
                    *(f"{gene[field]:.4f}" for field in float_fields),
                    gene['classification'],
                    f"{gene['confidence']:.4f}",
                    gene['locus_tag'],
                ])
            return handle.name
    except Exception as e:
        # Best effort: the text and diagram results are still useful without the CSV
        print(f"Error saving CSV: {str(e)}")
        return None
############################################################################### | |
# 12. DOWNLOAD FUNCTIONS | |
############################################################################### | |
def prepare_csv_download(data, filename="analysis_results.csv"):
    """Prepare CSV data for download.

    Args:
        data: Either ready-made CSV text (``str``), a list of row dicts, or a
            single row dict. For dict rows the column order is taken from the
            keys of the first row.
        filename: Name to offer for the download; returned unchanged.

    Returns:
        A ``(payload, filename)`` tuple where ``payload`` is the encoded CSV
        ``bytes``. An empty list yields an empty payload.

    Raises:
        ValueError: If ``data`` is not a ``str``, ``list`` or ``dict``.
    """
    import csv
    from io import StringIO

    if isinstance(data, str):
        return data.encode(), filename

    if isinstance(data, dict):
        # A lone mapping is one row; indexing it with [0] would raise KeyError
        data = [data]
    elif not isinstance(data, list):
        raise ValueError("Unsupported data type for CSV download")

    if not data:
        # No rows means no header can be derived; return an empty file
        return b"", filename

    output = StringIO()
    writer = csv.DictWriter(output, fieldnames=list(data[0].keys()))
    writer.writeheader()
    writer.writerows(data)
    return output.getvalue().encode(), filename
############################################################################### | |
# 13. BUILD GRADIO INTERFACE | |
############################################################################### | |
# Custom stylesheet handed to gr.Blocks(css=...) below: sets the app-wide font
# and the top margin of components tagged with elem_classes="download-button".
css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.download-button {
margin-top: 10px;
}
"""
# Top-level Gradio app: four tabs, each wiring one analysis function to its
# input widgets and output components.
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation was lost. Confirm component placement (in particular which
# container each gr.File download sits in) against the intended layout.
with gr.Blocks(css=css) as iface:
    gr.Markdown("""
# Virus Host Classifier
**Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
**Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.
**Step 3**: Analyze gene features and their contributions.
**Step 4**: Compare sequences and analyze differences.
**Color Scale**: Negative SHAP = Blue, Zero = White, Positive SHAP = Red.
""")
    # Tab 1: whole-sequence analysis via analyze_sequence.
    with gr.Tab("1) Full-Sequence Analysis"):
        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                text_input = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
                top_k = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of top k-mers to display")
                win_size = gr.Slider(minimum=100, maximum=5000, value=500, step=100, label="Window size for 'most pushing' subregions")
                analyze_btn = gr.Button("Analyze Sequence", variant="primary")
            with gr.Column(scale=2):
                results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
                kmer_img = gr.Image(label="Top k-mer SHAP")
                genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")
                download_results = gr.File(label="Download Results", visible=False, elem_classes="download-button")
        # Hold the 4th and 5th outputs of analyze_sequence so the subregion tab
        # can reuse them as inputs.
        seq_state = gr.State()
        header_state = gr.State()
        analyze_btn.click(
            analyze_sequence,
            inputs=[file_input, top_k, text_input, win_size],
            outputs=[results_box, kmer_img, genome_img, seq_state, header_state, download_results]
        )
    # Tab 2: analyze_subregion over a start/end window, fed by the state
    # captured in tab 1.
    with gr.Tab("2) Subregion Exploration"):
        gr.Markdown("""
**Subregion Analysis**
Select start/end positions to view local SHAP signals, distribution, GC content, etc.
The heatmap uses the same Blue-White-Red scale.
""")
        with gr.Row():
            region_start = gr.Number(label="Region Start", value=0)
            region_end = gr.Number(label="Region End", value=500)
            region_btn = gr.Button("Analyze Subregion")
        subregion_info = gr.Textbox(label="Subregion Analysis", lines=7, interactive=False)
        with gr.Row():
            subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
            subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
        download_subregion = gr.File(label="Download Subregion Analysis", visible=False, elem_classes="download-button")
        region_btn.click(
            analyze_subregion,
            inputs=[seq_state, header_state, region_start, region_end],
            outputs=[subregion_info, subregion_img, subregion_hist_img, download_subregion]
        )
    # Tab 3: per-gene statistics via analyze_gene_features. The input order
    # matches that function's (sequence_file, features_file, fasta_text,
    # features_text) signature.
    with gr.Tab("3) Gene Features Analysis"):
        gr.Markdown("""
**Analyze Gene Features**
Upload a FASTA file and corresponding gene features file to analyze SHAP values per gene.
Gene features should be in the format:
```
>gene_name [gene=X] [locus_tag=Y] [location=start..end] or [location=complement(start..end)]
SEQUENCE
```
The genome viewer will show genes color-coded by their contribution:
- Red: Genes pushing toward human origin
- Blue: Genes pushing toward non-human origin
- Color intensity indicates strength of signal
""")
        with gr.Row():
            with gr.Column(scale=1):
                gene_fasta_file = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                gene_fasta_text = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
            with gr.Column(scale=1):
                features_file = gr.File(label="Upload gene features file", file_types=[".txt"], type="filepath")
                features_text = gr.Textbox(label="Or paste gene features", placeholder=">gene_1 [gene=U12]...\nACGT...", lines=5)
        analyze_genes_btn = gr.Button("Analyze Gene Features", variant="primary")
        gene_results = gr.Textbox(label="Gene Analysis Results", lines=12, interactive=False)
        gene_diagram = gr.Image(label="Genome Diagram with Gene Features")
        # Unlike the other download components, this one is created visible.
        download_gene_results = gr.File(label="Download Gene Analysis (CSV)", visible=True)
        analyze_genes_btn.click(
            analyze_gene_features,
            inputs=[gene_fasta_file, features_file, gene_fasta_text, features_text],
            outputs=[gene_results, download_gene_results, gene_diagram]
        )
    # Tab 4: two-sequence comparison via analyze_sequence_comparison.
    with gr.Tab("4) Comparative Analysis"):
        gr.Markdown("""
**Compare Two Sequences**
Upload or paste two FASTA sequences to compare their SHAP patterns.
The sequences will be normalized to the same length for comparison.
**Color Scale**:
- Red: Sequence 2 more human-like
- Blue: Sequence 1 more human-like
- White: No substantial difference
""")
        with gr.Row():
            with gr.Column(scale=1):
                file_input1 = gr.File(label="Upload first FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                text_input1 = gr.Textbox(label="Or paste first FASTA sequence", placeholder=">sequence1\nACGTACGT...", lines=5)
            with gr.Column(scale=1):
                file_input2 = gr.File(label="Upload second FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                text_input2 = gr.Textbox(label="Or paste second FASTA sequence", placeholder=">sequence2\nACGTACGT...", lines=5)
        compare_btn = gr.Button("Compare Sequences", variant="primary")
        comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
        with gr.Row():
            diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
            diff_hist = gr.Image(label="Distribution of SHAP Differences")
        download_comparison = gr.File(label="Download Comparison Results", visible=False, elem_classes="download-button")
        compare_btn.click(
            analyze_sequence_comparison,
            inputs=[file_input1, file_input2, text_input1, text_input2],
            outputs=[comparison_text, diff_heatmap, diff_hist, download_comparison]
        )
    # Feature summary rendered below the tabs.
    gr.Markdown("""
### Interface Features
- **Overall Classification** (human vs non-human) using k-mer frequencies
- **SHAP Analysis** shows which k-mers push classification toward or away from human
- **White-Centered SHAP Gradient**:
- Negative (blue), 0 (white), Positive (red)
- Symmetrical color range around 0
- **Identify Subregions** with strongest push for human or non-human
- **Gene Feature Analysis**:
- Analyze individual genes' contributions
- Interactive genome viewer
- Gene-level statistics and classification
- **Sequence Comparison**:
- Compare two sequences to identify regions of difference
- Normalized comparison to handle different lengths
- Statistical summary of differences
- **Data Export**:
- Download results as CSV files
- Save analysis outputs for further processing
""")
# Start the Gradio server only when this file is run as a script; importing the
# module builds `iface` without launching it.
if __name__ == "__main__":
    iface.launch()