# FtsI_Classifier / kmer_predict.py (uploaded by Muhamed-Kheir, commit 137d7d8)
#!/usr/bin/env python3
"""
K-mer-based group prediction for unknown sequences.
Inputs:
- Unknown sequences: a FASTA file or a directory of FASTA files
- Unique k-mers: either
* a directory containing unique_k{k}_{group}.tsv/.txt files (from script #1), OR
* a ZIP file containing those files
Modes:
- fast: exact substring matching only (very fast)
- full: alignment-based matching (slower, more tolerant) + Fisher + FDR
Outputs:
- predictions_by_alignment.csv
- predicted_results_summary.png
Example:
python kmer_predict.py \
--unknown unknown_fastas/ \
--kmer-input kmer_results.zip \
--outdir pred_out \
--seqtype dna \
--mode fast
"""
from __future__ import annotations
import argparse
import os
import re
import shutil
import tempfile
import zipfile
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy.stats import fisher_exact
from statsmodels.stats.multitest import multipletests
from Bio import Align
from Bio.Align import substitution_matrices
FASTA_EXTS = (".fasta", ".fa", ".fas", ".fna")
KMER_FILE_EXTS = (".tsv", ".txt")
DEFAULT_GROUP_REGEX = r"unique_k\d+_(.+)\.(tsv|txt)$"
BLOSUM62 = substitution_matrices.load("BLOSUM62")
# ----------------------------
# FASTA utilities
# ----------------------------
def read_fasta(filepath: str) -> Tuple[List[str], List[str]]:
    """Parse a FASTA file into parallel lists of headers and sequences.

    Headers are returned without the leading '>' and stripped; sequence
    lines are uppercased and concatenated per record. Blank lines are
    ignored.
    """
    headers: List[str] = []
    seqs: List[str] = []
    chunks: List[str] = []
    with open(filepath, "r", encoding="utf-8") as handle:
        for raw in handle:
            raw = raw.rstrip("\n")
            if not raw:
                continue
            if raw.startswith(">"):
                # Flush the sequence accumulated for the previous record.
                if chunks:
                    seqs.append("".join(chunks))
                    chunks = []
                headers.append(raw[1:].strip())
            else:
                chunks.append(raw.strip().upper())
    if chunks:
        seqs.append("".join(chunks))
    return headers, seqs
def clean_protein(seq: str) -> str:
    """Uppercase *seq* and keep only the 20 standard amino-acid letters."""
    return "".join(ch for ch in seq.upper() if ch in "ACDEFGHIKLMNPQRSTVWY")
def clean_dna(seq: str) -> str:
    """Uppercase *seq*, keeping only A/C/G/T plus U (RNA) and N (ambiguous)."""
    return "".join(ch for ch in seq.upper() if ch in "ACGTUN")
def iter_unknown_sequences(unknown: str, is_protein: bool) -> List[Tuple[str, str, str]]:
    """
    Collect (source_file, header, cleaned_seq) triples from *unknown*,
    which may be a single FASTA file or a directory of FASTA files.
    Sequences that become empty after cleaning are dropped.
    """
    if os.path.isdir(unknown):
        candidates = [
            os.path.join(unknown, name)
            for name in os.listdir(unknown)
            if name.lower().endswith(FASTA_EXTS)
        ]
    else:
        candidates = [unknown]
    cleaner = clean_protein if is_protein else clean_dna
    records: List[Tuple[str, str, str]] = []
    # Deterministic order: sort the surviving file paths.
    for path in sorted(p for p in candidates if os.path.isfile(p)):
        headers, raw_seqs = read_fasta(path)
        for header, raw in zip(headers, raw_seqs):
            cleaned = cleaner(raw)
            if cleaned:
                records.append((path, header, cleaned))
    return records
# ----------------------------
# ZIP utilities (safe extract)
# ----------------------------
def safe_extract_zip(zip_path: str, dst_dir: str) -> None:
    """
    Extract *zip_path* into *dst_dir*, skipping any member whose resolved
    path would escape *dst_dir* (zip-slip protection).

    Bug fixed: the containment check previously compared a (possibly
    relative) normalized target path against an absolute base, so
    extracting into a relative *dst_dir* silently skipped every member.
    Both sides are now resolved to absolute paths before comparison.
    """
    base = os.path.abspath(dst_dir)
    with zipfile.ZipFile(zip_path, "r") as z:
        for member in z.infolist():
            if member.is_dir():
                continue
            # Resolve against the absolute base so the prefix test is
            # sound for relative and absolute dst_dir alike.
            target = os.path.abspath(os.path.join(base, member.filename))
            if not target.startswith(base + os.sep):
                continue  # suspicious path (e.g. "../evil") -> skip
            os.makedirs(os.path.dirname(target), exist_ok=True)
            with z.open(member) as src, open(target, "wb") as out:
                shutil.copyfileobj(src, out)
# ----------------------------
# Load unique kmers
# ----------------------------
@dataclass
class KmerDB:
    """In-memory k-mer database loaded from unique_k* files."""
    # Mapping: group name -> ordered, de-duplicated list of k-mer strings.
    group_kmers: Dict[str, List[str]]
    # Directory the k-mer files were read from (a tempdir when input was a ZIP).
    source_dir: str
def parse_group_from_filename(fname: str, group_regex: str) -> str:
    """Extract a group name from a k-mer filename.

    Uses the first capture group of *group_regex* (case-insensitive);
    falls back to the filename without its extension when the regex
    does not match.
    """
    match = re.search(group_regex, fname, re.IGNORECASE)
    if match is None:
        # No regex hit: use the bare file stem as the group name.
        return os.path.splitext(fname)[0]
    return match.group(1)
def load_unique_kmers_from_dir(
    kmer_dir: str,
    is_protein: bool,
    group_regex: str = DEFAULT_GROUP_REGEX,
) -> KmerDB:
    """
    Build a KmerDB from files named like unique_k15_group1.tsv or
    unique_k20_groupA.txt inside *kmer_dir*.

    Only the first whitespace-separated token of each line is kept;
    blank lines, '#' comments and header-ish lines (starting with
    "kmer" or "total") are skipped. K-mers are de-duplicated while
    preserving first-seen order, and groups left empty are dropped.

    Raises FileNotFoundError when no usable k-mer files are present.
    """
    cleaner = clean_protein if is_protein else clean_dna
    group_kmers: Dict[str, List[str]] = {}
    for fname in sorted(os.listdir(kmer_dir)):
        if not fname.lower().endswith(KMER_FILE_EXTS):
            continue
        fpath = os.path.join(kmer_dir, fname)
        if not os.path.isfile(fpath):
            continue
        group = parse_group_from_filename(fname, group_regex).strip()
        bucket = group_kmers.setdefault(group, [])
        with open(fpath, "r", encoding="utf-8") as fh:
            for line in fh:
                line = line.strip()
                if (not line) or line.startswith("#"):
                    continue
                if line.lower().startswith(("kmer", "total")):
                    continue
                token = cleaner(line.split()[0].upper())
                if token:
                    bucket.append(token)
    # De-duplicate (first occurrence wins) and discard empty groups.
    group_kmers = {
        g: list(dict.fromkeys(kmers))
        for g, kmers in group_kmers.items()
        if kmers
    }
    if not group_kmers:
        raise FileNotFoundError(f"No k-mer files found in: {kmer_dir}")
    return KmerDB(group_kmers=group_kmers, source_dir=kmer_dir)
def load_unique_kmers(kmer_input: str, is_protein: bool, group_regex: str) -> KmerDB:
    """
    Load a KmerDB from *kmer_input*: either a directory of k-mer files
    or a .zip archive containing them.

    NOTE(review): the temp directory created for ZIP input is never
    removed; extracted files live for the process lifetime. Files nested
    in sub-folders of the archive are not scanned (only the extraction
    root is) — confirm archives are flat.
    """
    if os.path.isdir(kmer_input):
        return load_unique_kmers_from_dir(kmer_input, is_protein, group_regex=group_regex)
    is_zip = os.path.isfile(kmer_input) and kmer_input.lower().endswith(".zip")
    if not is_zip:
        raise FileNotFoundError(f"--kmer-input must be a directory or a .zip file: {kmer_input}")
    workdir = tempfile.mkdtemp(prefix="kmerdb_")
    safe_extract_zip(kmer_input, workdir)
    return load_unique_kmers_from_dir(workdir, is_protein, group_regex=group_regex)
# ----------------------------
# Matching
# ----------------------------
def align_kmer_to_seq(
    kmer: str,
    seq: str,
    is_protein: bool,
    identity_threshold: float = 0.9,
    min_coverage: float = 0.8,
    gap_open: float = -10,
    gap_extend: float = -0.5,
    nuc_match: float = 2,
    nuc_mismatch: float = -1,
    nuc_gap_open: float = -5,
    nuc_gap_extend: float = -1,
) -> bool:
    """
    Return True when *kmer* is judged to match *seq*.

    Fast paths: with identity_threshold == min_coverage == 1.0 (the
    settings used by "fast" mode) or for k-mers of length <= 3, a plain
    exact substring test is used. Otherwise a Biopython pairwise
    alignment is performed (BLOSUM62 scoring for protein; simple
    match/mismatch/gap scores for nucleotides) and the k-mer counts as
    a match when both the identity over aligned columns and the
    coverage of the k-mer length meet the thresholds.

    NOTE(review): Align.PairwiseAligner defaults to global alignment
    mode; for a short k-mer against a long sequence a local mode may be
    intended — confirm against the upstream pipeline.
    """
    if not kmer or not seq:
        return False
    # Fast exact substring path
    if identity_threshold == 1.0 and min_coverage == 1.0:
        return kmer in seq
    if len(kmer) <= 3:
        # Too short to align meaningfully; require an exact hit.
        return kmer in seq
    try:
        aligner = Align.PairwiseAligner()
        if is_protein:
            aligner.substitution_matrix = BLOSUM62
            aligner.open_gap_score = gap_open
            aligner.extend_gap_score = gap_extend
        else:
            aligner.match_score = nuc_match
            aligner.mismatch_score = nuc_mismatch
            aligner.open_gap_score = nuc_gap_open
            aligner.extend_gap_score = nuc_gap_extend
        alns = aligner.align(kmer, seq)
        if not alns:
            return False
        aln = alns[0]  # best-scoring alignment
        # aligned[0]/aligned[1]: per-segment (start, end) spans in query/target.
        aligned_query = aln.aligned[0]
        aligned_target = aln.aligned[1]
        aligned_len = sum(e - s for s, e in aligned_query)
        if aligned_len == 0:
            return False
        matches = 0
        # Count identical residues column-by-column within each aligned segment.
        for (qs, qe), (ts, te) in zip(aligned_query, aligned_target):
            subseq_q = kmer[qs:qe]
            subseq_t = seq[ts:te]
            matches += sum(1 for a, b in zip(subseq_q, subseq_t) if a == b)
        identity = matches / aligned_len
        coverage = aligned_len / len(kmer)
        return (identity >= identity_threshold) and (coverage >= min_coverage)
    except Exception:
        # NOTE(review): broad catch silently converts alignment errors
        # into "no match"; consider logging for diagnosability.
        return False
# ----------------------------
# Prediction core
# ----------------------------
def predict(
    unknown: str,
    kmer_input: str,
    output_dir: str,
    seqtype: str,
    mode: str,
    identity_threshold: float,
    min_coverage: float,
    fdr_alpha: float,
    group_regex: str,
) -> pd.DataFrame:
    """
    Classify each unknown sequence into the k-mer group with the most hits.

    For every sequence, counts how many of each group's unique k-mers
    match it (exact substring in "fast" mode; alignment-based in "full"
    mode). In "full" mode a one-sided Fisher exact test per group plus
    Benjamini-Hochberg FDR correction over all p-values is added.
    Writes predictions_by_alignment.csv and a summary PNG to
    *output_dir* and returns the per-sequence results DataFrame.

    Raises ValueError for an unknown mode; FileNotFoundError when no
    input sequences or k-mer files are found.
    """
    is_protein = (seqtype.lower() == "protein")
    mode = mode.lower().strip()
    if mode not in {"fast", "full"}:
        raise ValueError("--mode must be 'fast' or 'full'")
    # Load kmers (dir or zip)
    db = load_unique_kmers(kmer_input, is_protein=is_protein, group_regex=group_regex)
    group_kmers = db.group_kmers
    print(f"Loaded k-mer counts: { {g: len(group_kmers[g]) for g in group_kmers} }")
    # Unknown sequences
    seq_index = iter_unknown_sequences(unknown, is_protein=is_protein)
    if not seq_index:
        raise FileNotFoundError("No sequences found in --unknown (file/dir).")
    # Mode parameters
    if mode == "fast":
        # identity/coverage of 1.0 trigger the exact-substring fast path
        # inside align_kmer_to_seq.
        eff_identity = 1.0
        eff_coverage = 1.0
        compute_stats = False
    else:
        eff_identity = float(identity_threshold)
        eff_coverage = float(min_coverage)
        compute_stats = True
    results: List[dict] = []
    total_seqs = len(seq_index)
    for i, (srcfile, header, seq) in enumerate(seq_index, start=1):
        print(f"Processing sequence {i}/{total_seqs} ({os.path.basename(srcfile)})")
        # Per-group hit counters for this sequence.
        group_found_counts = {g: 0 for g in group_kmers}
        total_found = 0
        for g, kmers in group_kmers.items():
            for kmer in kmers:
                if align_kmer_to_seq(
                    kmer, seq, is_protein=is_protein,
                    identity_threshold=eff_identity,
                    min_coverage=eff_coverage,
                ):
                    group_found_counts[g] += 1
                    total_found += 1
        # Winner = group with the most matching k-mers (ties resolve by dict order).
        predicted = max(group_found_counts, key=group_found_counts.get)
        # Share of all matched k-mers that belong to the winning group.
        conf_present = (group_found_counts[predicted] / total_found) if total_found else 0.0
        # Fraction of the winning group's vocabulary that matched.
        conf_vocab = group_found_counts[predicted] / max(1, len(group_kmers[predicted]))
        row = {
            "Source_file": os.path.basename(srcfile),
            "Sequence": header,
            "Predicted_group": predicted,
            "Matches_total": total_found,
            **{f"Matches_{g}": group_found_counts[g] for g in group_kmers},
            "Confidence_by_present": conf_present,
            "Confidence_by_group_vocab": conf_vocab,
        }
        if compute_stats:
            fisher_p = {}
            # total vocabulary size of "other groups" for contingency table
            other_vocab_total = {g: sum(len(group_kmers[og]) for og in group_kmers if og != g) for g in group_kmers}
            for g in group_kmers:
                # 2x2 table: rows = (group g, other groups),
                # cols = (matched, not matched among that vocabulary).
                a = group_found_counts[g]
                b = len(group_kmers[g]) - a
                c = total_found - a
                d = other_vocab_total[g] - c
                if d < 0:
                    # Defensive clamp: Fisher's test requires non-negative cells.
                    d = 0
                table = [[a, b], [c, d]]
                _, p = fisher_exact(table, alternative="greater")
                fisher_p[g] = p
            row.update({f"FisherP_{g}": fisher_p[g] for g in group_kmers})
        results.append(row)
    df = pd.DataFrame(results)
    # FDR correction (full mode)
    if mode == "full":
        fisher_cols = [c for c in df.columns if c.startswith("FisherP_")]
        if fisher_cols:
            # Benjamini-Hochberg over the flattened sequence-by-group p-value matrix.
            all_pvals = df[fisher_cols].values.flatten()
            _, qvals, _, _ = multipletests(all_pvals, alpha=float(fdr_alpha), method="fdr_bh")
            qval_matrix = qvals.reshape(df[fisher_cols].shape)
            for j, col in enumerate(fisher_cols):
                grp = col.split("_", 1)[1]
                df[f"FDR_{grp}"] = qval_matrix[:, j]
    # Save
    os.makedirs(output_dir, exist_ok=True)
    out_csv = os.path.join(output_dir, "predictions_by_alignment.csv")
    df.to_csv(out_csv, index=False)
    print(f"Saved predictions to {out_csv}")
    # Plot
    save_summary_plot(df, output_dir)
    return df
def save_summary_plot(df: pd.DataFrame, output_dir: str) -> None:
    """
    Write predicted_results_summary.png into *output_dir*:
    a bar chart of predicted-group counts (left panel) and a per-group
    boxplot of Confidence_by_present (right panel).
    """
    fig, (ax_counts, ax_conf) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: how many sequences landed in each predicted group.
    counts = df["Predicted_group"].value_counts()
    ax_counts.bar(counts.index.astype(str), counts.values)
    ax_counts.set_xlabel("Predicted Group")
    ax_counts.set_ylabel("Number of Sequences")
    ax_counts.set_title("Predicted Group Counts")
    ax_counts.tick_params(axis="x", rotation=45)

    # Right panel: confidence distribution per predicted group.
    groups = sorted(df["Predicted_group"].unique().tolist())
    per_group = [
        df.loc[df["Predicted_group"] == g, "Confidence_by_present"].values
        for g in groups
    ]
    ax_conf.boxplot(per_group, labels=groups, showfliers=False)
    ax_conf.set_title("Prediction Confidence (by Present)")
    ax_conf.set_xlabel("Predicted Group")
    ax_conf.set_ylabel("Confidence")
    ax_conf.tick_params(axis="x", rotation=45)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "predicted_results_summary.png"), dpi=300)
    plt.close(fig)
# ----------------------------
# CLI
# ----------------------------
def build_parser() -> argparse.ArgumentParser:
    """Construct the command-line parser for the prediction tool."""
    parser = argparse.ArgumentParser(
        description="Predict group membership of unknown sequences using unique k-mers.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Required inputs.
    parser.add_argument("--unknown", required=True, help="Unknown FASTA file OR directory of FASTA files.")
    parser.add_argument("--kmer-input", required=True, help="Directory of unique_k*.tsv/txt OR a ZIP containing them.")
    # Optional knobs with sensible defaults.
    parser.add_argument("--outdir", default="prediction_results", help="Output directory.")
    parser.add_argument("--seqtype", choices=["dna", "protein"], default="dna", help="Sequence type.")
    parser.add_argument("--mode", choices=["fast", "full"], default="fast", help="fast=substring only; full=alignment+Fisher+FDR.")
    parser.add_argument("--identity", type=float, default=0.9, help="Alignment identity threshold (full mode only).")
    parser.add_argument("--coverage", type=float, default=0.8, help="Alignment coverage threshold (full mode only).")
    parser.add_argument("--fdr", type=float, default=0.05, help="FDR alpha (full mode only).")
    parser.add_argument(
        "--group-regex",
        default=DEFAULT_GROUP_REGEX,
        help="Regex to extract group name from k-mer filenames (1st capture group = group).",
    )
    return parser
def main() -> None:
    """CLI entry point: parse arguments, validate --unknown, run prediction."""
    args = build_parser().parse_args()
    # Validate unknown
    # (kmer_input existence is validated later inside load_unique_kmers)
    if not os.path.exists(args.unknown):
        raise FileNotFoundError(f"--unknown not found: {args.unknown}")
    # Run
    predict(
        unknown=args.unknown,
        kmer_input=args.kmer_input,
        output_dir=args.outdir,
        seqtype=args.seqtype,
        mode=args.mode,
        identity_threshold=args.identity,
        min_coverage=args.coverage,
        fdr_alpha=args.fdr,
        group_regex=args.group_regex,
    )
if __name__ == "__main__":
    main()