# deepamr-api / src/ml/feature_extractor.py
# Provenance: deployed by hossainlab, commit 3255634 ("Deploy DeepAMR API backend")
# src/ml/feature_extractor.py
import re
from typing import Dict, List, Optional

import numpy as np
import torch

from Bio import SeqIO
from transformers import AutoTokenizer, AutoModel
class ProteinFeatureExtractor:
    """Extract features from protein sequences using ESM-2."""

    def __init__(self, model_path="models/pretrained/esm2"):
        """Load the ESM-2 tokenizer and encoder from *model_path*.

        Uses CUDA when available, otherwise CPU. The model is switched to
        eval mode because this class only runs inference.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path).to(self.device)
        self.model.eval()

    def extract_proteins_from_genome(self, genome_sequence: str) -> List[str]:
        """Extract protein sequences from a genome with a naive ORF scan.

        Simple ORF finder (for demo -- use Prodigal in production): every
        ATG is treated as a potential start, then the sequence is walked
        in-frame to the first stop codon. ORFs shorter than 300 nt
        (100 amino acids) are discarded.

        Returns:
            Up to the first 50 translated proteins.
        """
        proteins = []
        start_codons = ['ATG']
        stop_codons = ['TAA', 'TAG', 'TGA']
        # Upper bound is len-2 so the last complete codon position is
        # inspected (the previous len-3 bound was off by one).
        for i in range(len(genome_sequence) - 2):
            codon = genome_sequence[i:i+3]
            if codon in start_codons:
                # Walk in-frame until the first stop codon. Bound len-2
                # keeps a stop codon ending exactly at the sequence end
                # reachable (the previous len-3 bound skipped it).
                for j in range(i + 3, len(genome_sequence) - 2, 3):
                    stop_codon = genome_sequence[j:j+3]
                    if stop_codon in stop_codons:
                        orf = genome_sequence[i:j+3]
                        if len(orf) >= 300:  # minimum 100 amino acids
                            protein = self.translate_dna_to_protein(orf)
                            if protein:
                                proteins.append(protein)
                        break  # only the first in-frame stop terminates this ORF
        return proteins[:50]  # cap at 50 proteins to avoid too much data

    def translate_dna_to_protein(self, dna_seq: str) -> Optional[str]:
        """Translate a DNA sequence to protein using the standard codon table.

        Translation stops at the first stop codon; codons not in the table
        (ambiguous bases, partial codons) are skipped. Returns None when no
        residue is produced (annotation corrected from `str`: callers
        already rely on the falsy return).
        """
        codon_table = {
            'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
            'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
            'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
            'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
            'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
            'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
            'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
            'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
            'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
            'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
            'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
            'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
            'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
            'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
            'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
            'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G',
        }
        protein = []
        for i in range(0, len(dna_seq) - 2, 3):
            codon = dna_seq[i:i+3].upper()
            if codon in codon_table:
                aa = codon_table[codon]
                if aa == '*':  # stop codon terminates translation
                    break
                protein.append(aa)
        return ''.join(protein) if protein else None

    def get_protein_embedding(self, protein_seq: str) -> np.ndarray:
        """Return the mean-pooled ESM-2 embedding for one protein sequence."""
        # Truncate if too long (ESM-2 has max length ~1000 residues).
        if len(protein_seq) > 1000:
            protein_seq = protein_seq[:1000]
        inputs = self.tokenizer(protein_seq, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():  # inference only; no gradients needed
            outputs = self.model(**inputs)
        # Mean pooling over the sequence-length dimension.
        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        return embeddings.squeeze()

    def extract_genome_features(self, genome_path: str) -> np.ndarray:
        """Extract a single embedding vector for the genome FASTA at *genome_path*.

        All FASTA records are concatenated, ORFs are found and translated,
        and the embeddings of up to 20 proteins are mean-pooled into one
        vector. Returns a zero vector when nothing can be embedded.
        """
        # Concatenate all contigs/records into one sequence.
        genome_seq = "".join(str(record.seq) for record in SeqIO.parse(genome_path, "fasta"))
        proteins = self.extract_proteins_from_genome(genome_seq)
        print(f"Extracted {len(proteins)} proteins from genome")
        if not proteins:
            # 320 presumably matches the hidden size of the loaded ESM-2
            # checkpoint (esm2_t6_8M family) -- confirm against the model.
            return np.zeros(320)
        embeddings = []
        for protein in proteins[:20]:  # cap at 20 proteins for speed
            try:
                embeddings.append(self.get_protein_embedding(protein))
            except Exception as e:
                # Best-effort: skip proteins the model cannot process.
                print(f"Error processing protein: {e}")
        if not embeddings:
            return np.zeros(320)
        # Aggregate per-protein embeddings by mean pooling.
        return np.mean(embeddings, axis=0)
class AMRGeneDetector:
    """Detect known AMR genes using the CARD database."""

    def __init__(self, card_db_path="data/external/card"):
        # Load the reference gene catalogue once at construction time.
        self.card_sequences = self.load_card_database(card_db_path)

    def load_card_database(self, card_path):
        """Load CARD AMR gene sequences from the protein-homolog FASTA file.

        A missing database file is tolerated: an empty catalogue is
        returned so detection simply finds nothing.
        """
        fasta_path = f"{card_path}/nucleotide_fasta_protein_homolog_model.fasta"
        card_genes = {}
        try:
            for record in SeqIO.parse(fasta_path, "fasta"):
                # Pull gene name and antibiotic class out of the header.
                meta = self.parse_card_header(record.description)
                card_genes[record.id] = {
                    'sequence': str(record.seq),
                    'gene_name': meta['gene_name'],
                    'drug_class': meta['drug_class'],
                }
        except FileNotFoundError:
            print(f"CARD database not found at {fasta_path}")
            return {}
        print(f"Loaded {len(card_genes)} AMR genes from CARD")
        return card_genes

    def parse_card_header(self, header: str) -> Dict:
        """Parse a CARD FASTA header, e.g. "ARO:3000026|mecA [Staphylococcus aureus]"."""
        fields = header.split('|')
        if len(fields) > 1:
            # Gene name sits after the first '|', before the organism bracket.
            gene_name = fields[1].split('[')[0].strip()
        else:
            gene_name = "unknown"
        return {
            'gene_name': gene_name,
            'drug_class': 'beta-lactam'  # Simplified for now
        }

    def detect_amr_genes(self, genome_sequence: str) -> List[Dict]:
        """Detect AMR genes in a genome via exact substring matching.

        Simplified demo approach -- in production use BLAST, diamond,
        or MMseqs2 for proper alignment-based detection.
        """
        return [
            {
                'gene_id': gene_id,
                'gene_name': info['gene_name'],
                'drug_class': info['drug_class'],
            }
            for gene_id, info in self.card_sequences.items()
            if info['sequence'] in genome_sequence
        ]
class CombinedFeatureExtractor:
    """Combine protein embeddings and AMR gene detection into one feature set."""

    def __init__(self):
        self.protein_extractor = ProteinFeatureExtractor()
        self.gene_detector = AMRGeneDetector()

    def extract_features(self, genome_path: str) -> Dict:
        """Extract all features from the genome FASTA at *genome_path*.

        Returns a dict with the combined feature vector (ESM-2 protein
        embedding + binary gene presence/absence), the detected genes,
        and the total feature dimension.
        """
        # Protein embeddings (320-dim from ESM-2).
        protein_features = self.protein_extractor.extract_genome_features(genome_path)
        # Concatenate all FASTA records into one sequence for gene scanning.
        genome_seq = "".join(str(record.seq) for record in SeqIO.parse(genome_path, "fasta"))
        # Known AMR gene detection against the CARD catalogue.
        detected_genes = self.gene_detector.detect_amr_genes(genome_seq)
        # Binary presence/absence vector over the fixed gene panel.
        gene_features = self.create_gene_feature_vector(detected_genes)
        combined = np.concatenate([protein_features, gene_features])
        return {
            'features': combined,
            'detected_genes': detected_genes,
            'feature_dim': len(combined),
        }

    def create_gene_feature_vector(self, detected_genes: List[Dict], num_genes=50) -> np.ndarray:
        """Build a binary presence/absence vector over a panel of key AMR genes.

        Each slot i is 1 when the i-th panel gene name occurs as a
        substring of any detected gene name; unused slots stay 0.
        """
        # Panel of important AMR gene name markers (vector is zero-padded
        # out to num_genes slots).
        important_genes = [
            'mecA', 'vanA', 'blaCTX-M', 'blaKPC', 'blaNDM', 'blaOXA',
            'ermB', 'tetM', 'aac', 'aph', 'sul1', 'sul2', 'dfrA'
        ]
        detected_names = [entry['gene_name'] for entry in detected_genes]
        vector = np.zeros(num_genes)
        for slot, marker in enumerate(important_genes[:num_genes]):
            present = any(marker in name for name in detected_names)
            if present:
                vector[slot] = 1
        return vector