Spaces:
Running
Running
| #!/usr/bin/env python3 | |
import io
import math
import os
import re
import zipfile
from collections import Counter
from typing import Dict, List, Optional, Set, Tuple

import matplotlib.pyplot as plt
import pandas as pd
from Bio import SeqIO
from scipy.stats import fisher_exact
from statsmodels.stats.multitest import multipletests
# Filename suffixes that mark a ZIP member as a FASTA file (matched against
# the lower-cased member name).
FA_EXT = (
    ".fasta",
    ".fa",
    ".fas",
    ".fna",
)
def _read_fasta_bytes(name: str, data: bytes) -> List[Tuple[str, str, str]]:
    """Parse in-memory FASTA bytes into ``(name, record id, sequence)`` tuples.

    Args:
        name: Label stored as the first element of every returned tuple.
        data: Raw FASTA content; it is decoded as UTF-8.

    Returns:
        One tuple per parsed record, in parse order.  The sequence is
        upper-cased and stripped of newline / carriage-return characters.
    """
    with io.BytesIO(data) as raw:
        text_stream = io.TextIOWrapper(raw, encoding="utf-8")
        return [
            (
                name,
                str(record.id),
                str(record.seq).upper().replace("\n", "").replace("\r", ""),
            )
            for record in SeqIO.parse(text_stream, "fasta")
        ]
def read_uploaded_fasta_or_zip(uploaded_file) -> List[Tuple[str, str, str]]:
    """Load FASTA records from an uploaded FASTA file or a ZIP of FASTA files.

    Args:
        uploaded_file: Object exposing ``.name`` and ``.read()``, or ``None``.

    Returns:
        ``(source filename, record id, sequence)`` tuples.  ``None`` yields an
        empty list.  For a ``.zip`` upload, directory entries and members
        whose lower-cased name does not end with one of ``FA_EXT`` are
        skipped, and each record is tagged with the member's base filename.
        Any other upload is parsed directly as FASTA.
    """
    if uploaded_file is None:
        return []

    filename = uploaded_file.name
    payload = uploaded_file.read()

    if not filename.lower().endswith(".zip"):
        return _read_fasta_bytes(os.path.basename(filename), payload)

    records: List[Tuple[str, str, str]] = []
    with zipfile.ZipFile(io.BytesIO(payload)) as archive:
        for member in archive.infolist():
            # str.endswith accepts the whole tuple of suffixes at once.
            if member.is_dir() or not member.filename.lower().endswith(FA_EXT):
                continue
            member_bytes = archive.read(member.filename)
            records.extend(
                _read_fasta_bytes(os.path.basename(member.filename), member_bytes)
            )
    return records
def clean_protein(seq: str) -> str:
    """Return *seq* upper-cased with every non-standard residue removed.

    Only the letters ``ACDEFGHIKLMNPQRSTVWY`` survive; any other character
    (gaps, stop symbols, ambiguity codes, whitespace, ...) is dropped.
    """
    upper = seq.upper()
    return re.sub(r"[^ACDEFGHIKLMNPQRSTVWY]", "", upper)
def clean_dna(seq: str) -> str:
    """Return *seq* upper-cased with every character outside ``ACGTUN`` removed.

    ``U`` and ``N`` are kept as-is; no U-to-T conversion is performed.
    """
    upper = seq.upper()
    return re.sub(r"[^ACGTUN]", "", upper)
def get_kmers_noN(sequence: str, k: int) -> List[str]:
    """Return every length-``k`` window of *sequence* that contains no ``"N"``.

    Windows are returned in left-to-right order and duplicates are kept.

    Args:
        sequence: Sequence to slide over; it is used exactly as given (no
            case folding or cleaning is done here).
        k: Window length.

    Returns:
        The list of k-mers.  It is empty when ``k`` is not positive or when
        ``k`` exceeds ``len(sequence)``.
    """
    # A non-positive k has no meaningful k-mers: k == 0 would otherwise emit
    # len(sequence) + 1 empty strings, and a negative k would produce
    # arbitrary slices because of negative-index semantics.
    if k <= 0:
        return []
    windows = (sequence[i:i + k] for i in range(len(sequence) - k + 1))
    return [window for window in windows if "N" not in window]
def parse_k_input(k_input: str, default_single: int) -> List[int]:
    """Parse a user-supplied k specification into a list of integers.

    The specification is a comma-separated list whose items are either a
    single integer (``"7"``) or an inclusive range (``"3-5"``; reversed
    bounds such as ``"5-3"`` are swapped).  Items may be mixed, e.g.
    ``"3-5,7"`` gives ``[3, 4, 5, 7]``.  Whitespace around items and bounds
    is ignored and empty items are skipped.

    Args:
        k_input: Text to parse; ``None`` or blank selects the default.
        default_single: Value returned (as a one-element list) when
            *k_input* is empty.

    Returns:
        The k values in the order written.  Duplicates are not removed.

    Raises:
        ValueError: If an item or a range bound is not an integer.
    """
    text = (k_input or "").strip()
    if not text:
        return [default_single]

    k_values: List[int] = []
    for token in text.split(","):
        token = token.strip()
        if not token:
            continue
        if "-" in token:
            # Handle each comma-separated item on its own so that a range
            # can appear next to other items ("3-5,7").
            low_text, high_text = token.split("-", 1)
            low, high = int(low_text.strip()), int(high_text.strip())
            if low > high:
                low, high = high, low
            k_values.extend(range(low, high + 1))
        else:
            k_values.append(int(token))
    return k_values
def derive_serotype_names_from_sources(known_records: List[Tuple[str, str, str]]) -> Dict[str, str]:
    """Map each record header to a serotype name.

    A record that is the only one from its source is named after the source
    file (base name without extension).  When a source contributed several
    records, each is named after the first whitespace-delimited word of its
    own header.

    Args:
        known_records: ``(source, header, sequence)`` tuples.

    Returns:
        ``{header: serotype name}``.  If the same header occurs more than
        once, the last occurrence determines the stored name.
    """
    records_per_source: Dict[str, int] = {}
    for source, _header, _sequence in known_records:
        records_per_source[source] = records_per_source.get(source, 0) + 1

    def _serotype_for(source: str, header: str) -> str:
        if records_per_source.get(source, 0) == 1:
            stem, _extension = os.path.splitext(os.path.basename(source))
            return stem
        return header.split()[0]

    return {
        header: _serotype_for(source, header)
        for source, header, _sequence in known_records
    }
def compute_unique_kmers_per_serotype(serotype_to_seq: Dict[str, str], is_protein: bool, k_values: List[int]) -> Dict[str, Dict[int, Set[str]]]:
    """Find, for every serotype and k, the k-mers that occur in no other serotype.

    Each sequence is cleaned with ``clean_protein`` or ``clean_dna``
    (depending on *is_protein*) and split into k-mers with
    ``get_kmers_noN``.  A k-mer is "unique" to a serotype when that serotype
    is the only one whose k-mer set contains it.

    Args:
        serotype_to_seq: ``{serotype name: reference sequence}``.
        is_protein: Selects protein cleaning instead of DNA cleaning.
        k_values: k-mer lengths to evaluate.

    Returns:
        ``{serotype: {k: set of k-mers unique to that serotype}}`` with an
        entry for every serotype and every k in *k_values*.
    """
    # NOTE(review): get_kmers_noN drops every window containing the letter
    # "N", which is asparagine in protein mode, so such protein k-mers never
    # take part.  Confirm whether that is intended.
    clean = clean_protein if is_protein else clean_dna
    cleaned = {name: clean(seq) for name, seq in serotype_to_seq.items()}

    unique: Dict[str, Dict[int, Set[str]]] = {name: {} for name in serotype_to_seq}
    for k in k_values:
        kmers_by_serotype = {
            name: set(get_kmers_noN(seq, k)) for name, seq in cleaned.items()
        }

        # Count in how many serotypes each k-mer appears.  The earlier
        # approach subtracted (union of all - own set) from the own set; the
        # two are disjoint, so nothing was ever removed and shared k-mers
        # were wrongly reported as unique.
        serotype_hits = Counter()
        for kmers in kmers_by_serotype.values():
            serotype_hits.update(kmers)

        for name, kmers in kmers_by_serotype.items():
            unique[name][k] = {kmer for kmer in kmers if serotype_hits[kmer] == 1}
    return unique
def classify_unknown_sequences(unknown_records: List[Tuple[str, str, str]], unique_kmers: Dict[str, Dict[int, Set[str]]], is_protein: bool, fdr_alpha: float = 0.05) -> pd.DataFrame:
    """Assign each unknown sequence to the serotype whose k-mers it matches most.

    For every record the sequence is cleaned, split into k-mers for each k
    present in *unique_kmers*, and intersected with each serotype's k-mer
    sets.  The serotype with the highest match count is the prediction
    (ties go to the first serotype in *unique_kmers* order); a record with
    no match at all is labelled ``"NoMatch"``.

    Args:
        unknown_records: ``(source, header, sequence)`` tuples to classify.
        unique_kmers: ``{serotype: {k: set of k-mers}}`` reference sets.
        is_protein: Selects ``clean_protein`` instead of ``clean_dna``.
        fdr_alpha: ``alpha`` passed to the Benjamini-Hochberg correction.

    Returns:
        One row per record with the columns ``Source``, ``Sequence``,
        ``Predicted_serotype``, ``Matches_total``, ``Confidence_by_present``
        (winner's matches / all matches), ``Confidence_by_serotype_vocab``
        (winner's matches / size of the winner's k-mer vocabulary), followed
        by ``Matches_<s>``, ``FisherP_<s>`` and ``FDR_<s>`` for every
        serotype ``<s>``.  The columns are present even when
        *unknown_records* is empty.
    """
    # NOTE(review): get_kmers_noN drops every window containing "N", which is
    # asparagine in protein mode -- confirm that this is intended.
    serotypes = list(unique_kmers)
    k_values = sorted({k for name in serotypes for k in unique_kmers[name]})
    no_kmers: Set[str] = set()

    # k_values is the union over all serotypes, so a serotype may lack some k;
    # treat a missing k as an empty set rather than raising KeyError.
    vocab_by_sero: Dict[str, int] = {
        name: sum(len(unique_kmers[name].get(k, no_kmers)) for k in k_values)
        for name in serotypes
    }
    sum_vocab_all = sum(vocab_by_sero.values())

    # Fix the column layout up front so that an empty input still produces a
    # frame with the expected columns.  dict.fromkeys removes a duplicate
    # name (e.g. a serotype literally called "total") while keeping the same
    # order the per-row dicts are filled in.
    columns = [
        "Source",
        "Sequence",
        "Predicted_serotype",
        "Matches_total",
        "Confidence_by_present",
        "Confidence_by_serotype_vocab",
    ]
    for name in serotypes:
        columns.extend((f"Matches_{name}", f"FisherP_{name}", f"FDR_{name}"))
    columns = list(dict.fromkeys(columns))

    clean = clean_protein if is_protein else clean_dna
    rows = []
    for src, header, seq in unknown_records:
        cleaned = clean(seq)
        unk_kmers: Dict[int, Set[str]] = {
            k: set(get_kmers_noN(cleaned, k)) for k in k_values
        }

        match_counts: Dict[str, int] = {
            name: sum(
                len(unique_kmers[name].get(k, no_kmers) & unk_kmers[k])
                for k in k_values
            )
            for name in serotypes
        }
        total_matches = sum(match_counts.values())

        if total_matches == 0:
            predicted = "NoMatch"
            conf_present = 0.0
            conf_vocab = 0.0
            fisher_p = {name: 1.0 for name in serotypes}
            fdr_map = {name: 1.0 for name in serotypes}
        else:
            predicted = max(match_counts, key=match_counts.get)
            conf_present = match_counts[predicted] / total_matches
            conf_vocab = match_counts[predicted] / max(1, vocab_by_sero[predicted])

            fisher_p = {}
            for name in serotypes:
                # 2x2 table: this serotype's matched / unmatched k-mers
                # versus all other serotypes' matched / unmatched k-mers.
                a = match_counts[name]
                b = vocab_by_sero[name] - a
                c = total_matches - a
                d = (sum_vocab_all - vocab_by_sero[name]) - c
                table = [[max(0, a), max(0, b)], [max(0, c), max(0, d)]]
                _, p = fisher_exact(table, alternative="greater")
                fisher_p[name] = p

            pvals = [fisher_p[name] for name in serotypes]
            _, qvals, _, _ = multipletests(pvals, alpha=fdr_alpha, method="fdr_bh")
            fdr_map = dict(zip(serotypes, qvals))

        row = {
            "Source": src,
            "Sequence": header,
            "Predicted_serotype": predicted,
            "Matches_total": total_matches,
            "Confidence_by_present": conf_present,
            "Confidence_by_serotype_vocab": conf_vocab,
        }
        for name in serotypes:
            row[f"Matches_{name}"] = match_counts[name]
            row[f"FisherP_{name}"] = fisher_p[name]
            row[f"FDR_{name}"] = fdr_map[name]
        rows.append(row)

    return pd.DataFrame(rows, columns=columns)
def plot_counts_by_serotype(simple_df: pd.DataFrame):
    """Draw a bar chart of the number of sequences per predicted serotype.

    Args:
        simple_df: Frame with a ``Predicted_serotype`` column.

    Returns:
        The matplotlib figure (8 x 5 inches) holding the bar chart, with
        bars ordered as returned by ``value_counts``.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    tally = simple_df["Predicted_serotype"].value_counts()
    labels = tally.index.astype(str)
    ax.bar(labels, tally.values)
    ax.set(
        xlabel="Predicted serotype",
        ylabel="Number of sequences",
        title="Predicted serotype counts",
    )
    fig.tight_layout()
    return fig