# Added by saketh11 in commit 6e9b5dc:
# "Add local CodonTransformer modules for custom ColiFormer functionality"
"""
File: CodonData.py
---------------------
Includes helper functions for preprocessing NCBI or Kazusa databases and
preparing the data for training and inference of the CodonTransformer model.
"""
import json
import numbers
import os
import random
from typing import Dict, List, Optional, Tuple, Union

import pandas as pd
import python_codon_tables as pct
from Bio import SeqIO
from Bio.Seq import Seq
from sklearn.utils import shuffle as sk_shuffle
from tqdm import tqdm

from CodonTransformer.CodonUtils import (
    AMBIGUOUS_AMINOACID_MAP,
    AMINO2CODON_TYPE,
    AMINO_ACIDS,
    ORGANISM2ID,
    START_CODONS,
    STOP_CODONS,
    STOP_SYMBOL,
    STOP_SYMBOLS,
    ProteinConfig,
    find_pattern_in_fasta,
    get_taxonomy_id,
    sort_amino2codon_skeleton,
)
def prepare_training_data(
    dataset: Union[str, pd.DataFrame], output_file: str, shuffle: bool = True
) -> None:
    """
    Prepare a JSON dataset for training the CodonTransformer model.

    Input dataset should have columns below:
        - dna: str (DNA sequence)
        - protein: str (Protein sequence)
        - organism: Union[int, str] (ID or Name of the organism)

    The output JSON dataset will have the following format:
        {"idx": 0, "codons": "M_ATG R_AGG L_TTG L_CTA R_CGA __TAG", "organism": 51}
        {"idx": 1, "codons": "M_ATG K_AAG C_TGC F_TTT F_TTC __TAA", "organism": 59}

    Args:
        dataset (Union[str, pd.DataFrame]): Input dataset, either a path to a CSV
            file or a DataFrame. A DataFrame passed in is left unmodified.
        output_file (str): Path to save the output JSON dataset.
        shuffle (bool, optional): Whether to shuffle the dataset before saving.
            Defaults to True.

    Returns:
        None

    Raises:
        ValueError: If a required column is missing, if a row's protein and DNA
            lengths disagree (from get_merged_seq), or if an organism is not in
            ORGANISM2ID (from process_organism).
    """
    if isinstance(dataset, str):
        dataset = pd.read_csv(dataset)

    required_columns = {"dna", "protein", "organism"}
    if not required_columns.issubset(dataset.columns):
        raise ValueError(f"Input dataset must have columns: {required_columns}")

    # Work on a copy of just the needed columns so a caller's DataFrame does not
    # gain a "codons" column or have its "organism" column overwritten.
    dataset = dataset[["dna", "protein", "organism"]].copy()

    # Prepare the dataset for finetuning: one "aa_codon" token per position
    dataset["codons"] = dataset.apply(
        lambda row: get_merged_seq(row["protein"], row["dna"], separator="_"), axis=1
    )

    # Replace organism names (or validate IDs) using ORGANISM2ID
    dataset["organism"] = dataset["organism"].apply(
        lambda org: process_organism(org, ORGANISM2ID)
    )

    # Save the dataset to a JSON Lines file
    dataframe_to_json(dataset[["codons", "organism"]], output_file, shuffle=shuffle)
def dataframe_to_json(df: pd.DataFrame, output_file: str, shuffle: bool = True) -> None:
    """
    Write a DataFrame to a JSON Lines file for CodonTransformer training.

    Every row becomes one line holding a JSON object with the keys "idx" (the
    row's DataFrame index label), "codons" and "organism". When shuffling, the
    rows are reordered but keep their original index labels, so "idx" refers to
    the original position rather than the output line number.

    Args:
        df (pd.DataFrame): Preprocessed data containing 'codons' and 'organism'
            columns.
        output_file (str): Destination path of the JSON Lines file.
        shuffle (bool, optional): Shuffle the rows before writing.
            Defaults to True.

    Returns:
        None

    Raises:
        ValueError: If 'codons' or 'organism' is missing from the DataFrame.
    """
    required_columns = {"codons", "organism"}
    if not required_columns.issubset(df.columns):
        raise ValueError(f"DataFrame must contain columns: {required_columns}")

    print(f"\nStarted writing to {output_file}...")

    records = sk_shuffle(df) if shuffle else df

    with open(output_file, "w") as handle:
        progress = tqdm(
            records.iterrows(),
            total=len(records),
            desc="Writing JSON...",
            unit=" records",
        )
        handle.writelines(
            json.dumps(
                {"idx": label, "codons": row["codons"], "organism": row["organism"]}
            )
            + "\n"
            for label, row in progress
        )

    print(f"\nTotal Entries Saved: {len(records)}, JSON data saved to {output_file}")
def process_organism(organism: Union[str, int], organism_to_id: Dict[str, int]) -> int:
    """
    Process and validate the organism input, converting it to a valid organism ID.

    Accepts either an organism name (str) or an organism ID. IDs may be any
    integral type, e.g. ``numpy.int64`` values taken from a DataFrame column.
    They are returned as a plain Python ``int``. Booleans are rejected even
    though ``bool`` is a subclass of ``int``.

    Args:
        organism (Union[str, int]): Input organism, either as a name (str) or ID (int).
        organism_to_id (Dict[str, int]): Dictionary mapping organism names to their
            corresponding IDs.

    Returns:
        int: The validated organism ID.

    Raises:
        ValueError: If the input is an unknown organism name or ID.
        TypeError: If the input is neither a string nor an integer.
    """
    if isinstance(organism, str):
        if organism not in organism_to_id:
            raise ValueError(f"Invalid organism name: {organism}")
        return organism_to_id[organism]

    # numbers.Integral also covers numpy integer scalars; bool is excluded
    # because True/False are never meaningful organism IDs.
    if isinstance(organism, numbers.Integral) and not isinstance(organism, bool):
        organism_id = int(organism)
        if organism_id not in organism_to_id.values():
            raise ValueError(f"Invalid organism ID: {organism}")
        return organism_id

    raise TypeError(
        f"Organism must be a string or integer, not {type(organism).__name__}"
    )
def preprocess_protein_sequence(protein: str) -> str:
    """
    Preprocess a protein sequence by cleaning, standardizing, and handling
    ambiguous amino acids.

    Steps:
        1. Upper-case the sequence and remove all whitespace.
        2. Handle ambiguous amino acids according to the ProteinConfig setting
           "ambiguous_aminoacid_behavior" ("raise_error",
           "standardize_deterministic" or "standardize_random"). The lookup
           uses AMBIGUOUS_AMINOACID_MAP updated with
           "ambiguous_aminoacid_map_override".
        3. Validate that every symbol is in AMINO_ACIDS or STOP_SYMBOLS.
        4. Replace a trailing "*" with STOP_SYMBOL, and append STOP_SYMBOL if the
           sequence does not already end with it.

    Args:
        protein (str): The input protein sequence.

    Returns:
        str: The preprocessed protein sequence, ending with STOP_SYMBOL.

    Raises:
        ValueError: If the sequence is empty or whitespace-only, contains
            ambiguous amino acids under "raise_error", contains invalid
            characters, or if the configured behavior is invalid.
    """
    # Upper-case and drop every whitespace character (spaces, tabs, "\n" and
    # the "\r" of Windows line endings). `protein or ""` sends None down the
    # same "empty" error path as before.
    protein = "".join((protein or "").upper().split())
    # Checked after cleaning so whitespace-only input is reported as empty
    # instead of failing later with an IndexError on protein[-1].
    if not protein:
        raise ValueError("Protein sequence is empty.")

    # Handle ambiguous amino acids based on the specified behavior
    config = ProteinConfig()
    ambiguous_aminoacid_map_override = config.get("ambiguous_aminoacid_map_override")
    ambiguous_aminoacid_behavior = config.get("ambiguous_aminoacid_behavior")
    ambiguous_aminoacid_map = AMBIGUOUS_AMINOACID_MAP.copy()
    for aminoacid, standard_aminoacids in ambiguous_aminoacid_map_override.items():
        ambiguous_aminoacid_map[aminoacid] = standard_aminoacids

    if ambiguous_aminoacid_behavior == "raise_error":
        if any(aminoacid in ambiguous_aminoacid_map for aminoacid in protein):
            raise ValueError("Ambiguous amino acids found in protein sequence.")
    elif ambiguous_aminoacid_behavior == "standardize_deterministic":
        # Always take the first listed standard amino acid
        protein = "".join(
            ambiguous_aminoacid_map.get(aminoacid, [aminoacid])[0]
            for aminoacid in protein
        )
    elif ambiguous_aminoacid_behavior == "standardize_random":
        protein = "".join(
            random.choice(ambiguous_aminoacid_map.get(aminoacid, [aminoacid]))
            for aminoacid in protein
        )
    else:
        raise ValueError(
            f"Invalid ambiguous_aminoacid_behavior: {ambiguous_aminoacid_behavior}."
        )

    # Check for sequence validity. The set is built once instead of
    # concatenating the two collections for every character. Because every
    # symbol (including the last) is validated here, no separate check of
    # protein[-1] is needed.
    valid_symbols = set(AMINO_ACIDS) | set(STOP_SYMBOLS)
    if any(aminoacid not in valid_symbols for aminoacid in protein):
        raise ValueError("Invalid characters in protein sequence.")

    # Replace '*' at the end of protein with STOP_SYMBOL if present
    if protein[-1] == "*":
        protein = protein[:-1] + STOP_SYMBOL

    # Add stop symbol to end of protein
    if protein[-1] != STOP_SYMBOL:
        protein += STOP_SYMBOL

    return protein
def replace_ambiguous_codons(dna: str) -> str:
    """
    Replace every codon that is not made only of A, T, C and G with "UNK".

    The sequence is upper-cased and read in consecutive 3-character chunks. A
    trailing chunk shorter than 3 characters is also replaced by "UNK".

    Args:
        dna (str): The DNA sequence to process.

    Returns:
        str: The sequence with each ambiguous or incomplete codon replaced by "UNK".
    """
    sequence = dna.upper()

    def _clean(chunk: str) -> str:
        # Keep a chunk only if it is a full codon of unambiguous nucleotides
        is_full_codon = len(chunk) == 3
        is_unambiguous = all(base in "ATCG" for base in chunk)
        return chunk if is_full_codon and is_unambiguous else "UNK"

    return "".join(
        _clean(sequence[start : start + 3]) for start in range(0, len(sequence), 3)
    )
def preprocess_dna_sequence(dna: str) -> str:
    """
    Clean and preprocess a DNA sequence by standardizing it and replacing
    ambiguous codons.

    The sequence is upper-cased and all whitespace is removed. Codons containing
    anything other than A/T/C/G, and a trailing incomplete codon, become "UNK"
    (see replace_ambiguous_codons). If the result does not end with a codon in
    STOP_CODONS, an "UNK" stop codon is appended.

    Args:
        dna (str): The DNA sequence to preprocess.

    Returns:
        str: The cleaned sequence, made of 3-character codons. Returns "" for
        empty, whitespace-only, or None input.
    """
    # Upper-case and remove every whitespace character, including the "\r" of
    # Windows line endings, which would otherwise shift the codon frame.
    # `dna or ""` keeps None mapping to "" as before.
    dna = "".join((dna or "").upper().split())

    # Whitespace-only input is treated as "no DNA" so that callers fall back to
    # "UNK" codons, instead of receiving a lone "UNK" stop codon.
    if not dna:
        return ""

    # Replace codons with ambiguous nucleotides with "UNK"
    dna = replace_ambiguous_codons(dna)

    # Add unknown stop codon to end of DNA sequence if not present
    if dna[-3:] not in STOP_CODONS:
        dna += "UNK"

    return dna
def get_merged_seq(protein: str, dna: str = "", separator: str = "_") -> str:
    """
    Return the merged sequence of protein amino acids and DNA codons in the form
    of tokens separated by space, where each token is composed of an amino acid +
    separator + codon.

    Both inputs are first normalized with preprocess_protein_sequence and
    preprocess_dna_sequence. When no DNA is given, every codon is "UNK".

    Args:
        protein (str): Protein sequence.
        dna (str): DNA sequence.
        separator (str): Separator between amino acid and codon.

    Returns:
        str: Merged sequence.

    Raises:
        ValueError: If DNA is given and its number of codons (stop codon
            included) differs from the protein length (stop symbol included),
            or if preprocessing the protein fails.

    Example:
        >>> get_merged_seq(protein="MAV_", dna="ATGGCTGTGTAA", separator="_")
        'M_ATG A_GCT V_GTG __TAA'

        >>> get_merged_seq(protein="QHH_", dna="", separator="_")
        'Q_UNK H_UNK H_UNK __UNK'
    """
    # Prepare protein and dna sequences
    dna = preprocess_dna_sequence(dna)
    protein = preprocess_protein_sequence(protein)

    # preprocess_dna_sequence only emits whole 3-character codons (real codons
    # or "UNK"), so integer division gives the exact codon count.
    if dna and len(protein) != len(dna) // 3:
        raise ValueError(
            'Length of protein (including stop symbol such as "_") and '
            "the number of codons in DNA sequence (including stop codon) "
            "must be equal."
        )

    # Merge protein and DNA sequences into tokens; str.join avoids quadratic
    # string concatenation.
    codons = (dna[i * 3 : i * 3 + 3] if dna else "UNK" for i in range(len(protein)))
    return " ".join(
        f"{aminoacid}{separator}{codon}" for aminoacid, codon in zip(protein, codons)
    )
def is_correct_seq(dna: str, protein: str, stop_symbol: str = STOP_SYMBOL) -> bool:
    """
    Check if the given DNA and protein pair is correct, that is:
        1. The length of dna is divisible by 3.
        2. dna begins with an initiator codon (one of START_CODONS).
        3. protein ends with stop_symbol.
        4. stop_symbol occurs exactly once in protein, so the only stop is the
           last symbol.
        5. dna consists of exactly the four nucleotides A, T, C and G
           (case-insensitive), each appearing at least once.

    No codon-table-specific logic is applied here. For example, whether TGA is
    a stop codon or tryptophan (Codon Table 3) depends only on how `protein`
    was translated.

    Args:
        dna (str): DNA sequence (a Bio.Seq object is also accepted).
        protein (str): Protein sequence.
        stop_symbol (str): Stop symbol.

    Returns:
        bool: True if the sequence is correct, False otherwise.
    """
    # Normalize once. str() also handles Bio.Seq inputs such as the one passed
    # by get_amino_acid_sequence.
    dna = str(dna).upper()
    return (
        len(dna) % 3 == 0  # DNA length is divisible by 3
        and dna[:3] in START_CODONS  # initiator codon
        # Slicing avoids an IndexError on an empty protein
        and protein[-1:] == stop_symbol  # last protein symbol is the stop symbol
        and protein.count(stop_symbol) == 1  # exactly one stop symbol
        # Exactly A, T, C, G. The old `len(set(dna)) == 4` also accepted sets
        # such as {A, T, G, N}.
        and set(dna) == set("ACGT")
    )
def get_amino_acid_sequence(
    dna: str,
    stop_symbol: str = "_",
    codon_table: int = 1,
    return_correct_seq: bool = False,
) -> Union[str, Tuple[str, bool]]:
    """
    Translate a DNA sequence into a protein sequence with Biopython.

    The whole sequence is translated: translation does not halt at the first
    stop codon, and the input is not required to be a valid CDS. Stop codons are
    rendered as ``stop_symbol``. Surrounding whitespace is stripped from both
    the DNA input and the protein output.

    Args:
        dna (str): DNA sequence.
        stop_symbol (str): Stop symbol.
        codon_table (int): Codon table number.
        return_correct_seq (bool): Whether to also return the is_correct_seq flag.

    Returns:
        Union[str, Tuple[str, bool]]: The protein sequence, or a
        ``(protein, is_correct)`` tuple when ``return_correct_seq`` is True.
    """
    sequence = Seq(dna).strip()

    translated = sequence.translate(
        table=codon_table,  # NCBI codon table to translate with
        stop_symbol=stop_symbol,  # how stop codons are rendered
        to_stop=False,  # keep translating past stop codons
        cds=False,  # no CDS validation of the input
    )
    protein_seq = str(translated).strip()

    if not return_correct_seq:
        return protein_seq
    return protein_seq, is_correct_seq(sequence, protein_seq, stop_symbol)
def read_fasta_file(
    input_file: str,
    save_to_file: Optional[str] = None,
    organism: str = "",
    buffer_size: int = 50000,
) -> pd.DataFrame:
    """
    Read a FASTA file of DNA sequences and convert it to a Pandas DataFrame.

    Each record is translated with the codon table of its organism
    (get_codon_table) and tokenized with get_merged_seq. Optionally, rows are
    also appended to a CSV file in chunks of ``buffer_size`` records.

    Args:
        input_file (str): Path to the input FASTA file.
        save_to_file (Optional[str]): Path of a CSV file to append rows to. If
            None, data is only returned. Rows are appended, so an existing file
            at this path is extended rather than overwritten.
        organism (str): Name of the organism. If empty, it is extracted from
            each record's FASTA description.
        buffer_size (int): Number of records to accumulate before each CSV write.

    Returns:
        pd.DataFrame: One row per FASTA record, with columns "dna", "protein",
        "correct_seq", "organism", "GeneID", "description" and "tokenized".

    Raises:
        FileNotFoundError: If the input file does not exist.
    """
    if not os.path.exists(input_file):
        raise FileNotFoundError(f"Input file not found: {input_file}")

    columns = [
        "dna",
        "protein",
        "correct_seq",
        "organism",
        "GeneID",
        "description",
        "tokenized",
    ]
    rows: List[Dict] = []  # every processed record, for the returned DataFrame
    buffer: List[Dict] = []  # records not yet written to save_to_file

    with open(input_file, "r") as fasta_file:
        for record in tqdm(
            SeqIO.parse(fasta_file, "fasta"),
            desc=f"Processing {organism}",
            unit=" Records",
        ):
            dna = str(record.seq).strip().upper()  # Ensure uppercase DNA sequence

            # Determine the organism from the record if not provided
            current_organism = organism or find_pattern_in_fasta(
                "organism", record.description
            )
            gene_id = find_pattern_in_fasta("GeneID", record.description)

            # Get the appropriate codon table for the organism
            codon_table = get_codon_table(current_organism)

            # Translate DNA to protein sequence
            protein, correct_seq = get_amino_acid_sequence(
                dna,
                stop_symbol=STOP_SYMBOL,
                codon_table=codon_table,
                return_correct_seq=True,
            )
            # Keep the description text before any "[key=value]" annotations
            description = record.description.split("[", 1)[0].strip()
            tokenized = get_merged_seq(protein, dna, separator=STOP_SYMBOL)

            data_row = {
                "dna": dna,
                "protein": protein,
                "correct_seq": correct_seq,
                "organism": current_organism,
                "GeneID": gene_id,
                "description": description,
                "tokenized": tokenized,
            }
            rows.append(data_row)

            # Only buffer rows when they will be written to disk
            if save_to_file:
                buffer.append(data_row)
                if len(buffer) >= buffer_size:
                    write_buffer_to_csv(buffer, save_to_file, columns)
                    buffer = []

    # Write remaining buffer to CSV file
    if save_to_file and buffer:
        write_buffer_to_csv(buffer, save_to_file, columns)

    # Build the DataFrame once. Concatenating one row per record copies all
    # previous rows every time, which is quadratic in the number of records.
    return pd.DataFrame(rows, columns=columns)
def write_buffer_to_csv(buffer: List[Dict], output_path: str, columns: List[str]):
    """
    Append a batch of row dictionaries to a CSV file.

    The header row is written only when the file does not exist yet. The
    batch's own DataFrame index (starting at 0 for every batch) is written as
    the first column.

    Args:
        buffer (List[Dict]): Rows to write, one dict per row.
        output_path (str): CSV file to append to.
        columns (List[str]): Column order of the written rows.
    """
    chunk = pd.DataFrame(buffer, columns=columns)
    is_new_file = not os.path.exists(output_path)
    chunk.to_csv(output_path, mode="a", header=is_new_file, index=True)
def download_codon_frequencies_from_kazusa(
    taxonomy_id: Optional[int] = None,
    organism: Optional[str] = None,
    taxonomy_reference: Optional[str] = None,
    return_original_format: bool = False,
) -> AMINO2CODON_TYPE:
    """
    Return the codon table of the given taxonomy ID from the Kazusa Database.

    Args:
        taxonomy_id (Optional[int]): Taxonomy ID passed to
            ``python_codon_tables.get_codons_table``. Overridden when
            ``taxonomy_reference`` is given.
        organism (Optional[str]): Name of the organism, forwarded to
            ``get_taxonomy_id`` together with ``taxonomy_reference``.
        taxonomy_reference (Optional[str]): Taxonomy reference resolved to a
            taxonomy ID with ``get_taxonomy_id``.
        return_original_format (bool): If True, return the table exactly as
            ``python_codon_tables`` provides it.

    Returns:
        AMINO2CODON_TYPE: A mapping from amino acid to a (codons, frequencies)
        tuple, with "*" renamed to STOP_SYMBOL and sorted by
        sort_amino2codon_skeleton. If ``return_original_format`` is True, the
        original table is returned instead.

    Raises:
        ValueError: If no taxonomy ID is given and none could be resolved from
            ``taxonomy_reference``.
    """
    if taxonomy_reference:
        taxonomy_id = get_taxonomy_id(taxonomy_reference, organism=organism)

    if taxonomy_id is None:
        raise ValueError(
            "A taxonomy ID is required: pass taxonomy_id or a resolvable "
            "taxonomy_reference."
        )

    kazusa_amino2codon = pct.get_codons_table(table_name=taxonomy_id)

    if return_original_format:
        return kazusa_amino2codon

    # Build a new dict rather than popping "*" out of the returned table. That
    # object may be cached and shared by python_codon_tables, and mutating it
    # would break later lookups (a repeat call would hit KeyError on "*").
    amino2codon = {
        (STOP_SYMBOL if aminoacid == "*" else aminoacid): (
            list(codon2freq.keys()),
            list(codon2freq.values()),
        )
        for aminoacid, codon2freq in kazusa_amino2codon.items()
    }

    return sort_amino2codon_skeleton(amino2codon)
def build_amino2codon_skeleton(organism: str) -> AMINO2CODON_TYPE:
    """
    Return an all-zero amino2codon dictionary for get_codon_frequencies.

    All 64 codons are translated with the organism's codon table
    (get_codon_table) and grouped under the symbol they translate to. Stop codons
    go under "_", the default stop symbol of get_amino_acid_sequence. Every
    frequency starts at 0.

    Args:
        organism (str): Name of the organism.

    Returns:
        AMINO2CODON_TYPE: Mapping from amino acid to a (codons, zero counts)
        tuple, sorted with sort_amino2codon_skeleton.
    """
    nucleotides = "ACGT"
    all_codons = [
        first + second + third
        for first in nucleotides
        for second in nucleotides
        for third in nucleotides
    ]
    translated = get_amino_acid_sequence(
        dna="".join(all_codons),
        codon_table=get_codon_table(organism),
        return_correct_seq=False,
    )

    skeleton = {}
    for codon, amino in zip(all_codons, translated):
        codons, counts = skeleton.setdefault(amino, ([], []))
        codons.append(codon)
        counts.append(0)

    # Sort the dictionary and each list of codon frequency alphabetically
    return sort_amino2codon_skeleton(skeleton)
def get_codon_frequencies(
    dna_sequences: List[str],
    protein_sequences: Optional[List[str]] = None,
    organism: Optional[str] = None,
) -> AMINO2CODON_TYPE:
    """
    Return per-amino-acid codon usage frequencies for a collection of sequences.

    For every amino acid, this counts how often each synonymous codon occurs
    across ``dna_sequences`` and normalizes the counts to sum to 1. Amino acids
    that never occur keep all-zero frequencies.

    Args:
        dna_sequences (List[str]): List of DNA sequences. Position i of a protein
            is paired with codon ``dna[3*i:3*i+3]``.
        protein_sequences (Optional[List[str]]): Protein sequences aligned with
            ``dna_sequences``. Ignored when ``organism`` is given.
        organism (Optional[str]): Name of the organism. If given, proteins are
            translated from the DNA with ``get_codon_table(organism)``.

    Returns:
        AMINO2CODON_TYPE: Dictionary mapping each amino acid to a tuple of codons
        and frequencies.

    Raises:
        ValueError: If neither ``protein_sequences`` nor ``organism`` is given,
            or if a codon is missing from the skeleton (e.g. it contains an
            ambiguous nucleotide).
        KeyError: If a protein symbol is missing from the skeleton.
    """
    if organism:
        codon_table = get_codon_table(organism)
        protein_sequences = [
            get_amino_acid_sequence(
                dna, codon_table=codon_table, return_correct_seq=False
            )
            for dna in dna_sequences
        ]
    elif protein_sequences is None:
        raise ValueError(
            "protein_sequences must be provided when organism is not given."
        )

    amino2codon = build_amino2codon_skeleton(organism)

    # Count the occurrences of each codon for each amino acid
    for dna, protein in zip(dna_sequences, protein_sequences):
        for i, amino in enumerate(protein):
            # Upper-case so lowercase DNA matches the upper-case skeleton codons
            codon = dna[i * 3 : (i + 1) * 3].upper()
            codon_loc = amino2codon[amino][0].index(codon)
            amino2codon[amino][1][codon_loc] += 1

    # Normalize codon frequencies per amino acid so they sum to 1. The tiny
    # epsilon avoids division by zero for unseen amino acids (their counts stay
    # 0). The total is computed once per amino acid, not once per codon.
    normalized = {}
    for amino, (codons, counts) in amino2codon.items():
        total = sum(counts) + 1e-100
        normalized[amino] = (codons, [count / total for count in counts])
    return normalized
def get_organism_to_codon_frequencies(
    dataset: pd.DataFrame, organisms: List[str]
) -> Dict[str, AMINO2CODON_TYPE]:
    """
    Compute the codon usage frequency distribution of each requested organism.

    Rows of ``dataset`` are selected by exact match on its "organism" column.
    Their "dna" and "protein" columns are passed to get_codon_frequencies
    without an organism, so the stored protein sequences are used as-is.

    Args:
        dataset (pd.DataFrame): DataFrame with "organism", "dna" and "protein"
            columns.
        organisms (List[str]): Organisms to compute frequencies for.

    Returns:
        Dict[str, AMINO2CODON_TYPE]: Mapping from each organism to its codon
        frequency distribution.
    """
    result: Dict[str, AMINO2CODON_TYPE] = {}
    progress = tqdm(organisms, desc="Calculating Codon Frequencies: ", unit="Organism")
    for name in progress:
        subset = dataset.loc[dataset["organism"] == name]
        result[name] = get_codon_frequencies(
            subset["dna"].to_list(), subset["protein"].to_list()
        )
    return result
# Model organisms for which the common codon table (NCBI Table 1) is used.
_TABLE_1_ORGANISMS = (
    "Arabidopsis thaliana",
    "Caenorhabditis elegans",
    "Chlamydomonas reinhardtii",
    "Saccharomyces cerevisiae",
    "Danio rerio",
    "Drosophila melanogaster",
    "Homo sapiens",
    "Mus musculus",
    "Nicotiana tabacum",
    "Solanum tuberosum",
    "Solanum lycopersicum",
    "Oryza sativa",
    "Glycine max",
    "Zea mays",
)


def get_codon_table(organism: str) -> int:
    """
    Return the NCBI codon table number to use for a given organism.

    Organisms listed in ``_TABLE_1_ORGANISMS`` use Table 1. Every other name
    uses Table 11. This includes the chloroplast entries ("Chlamydomonas
    reinhardtii chloroplast", "Nicotiana tabacum chloroplast") and the default
    for other bacteria and archaea.

    Args:
        organism (str): Name of the organism.

    Returns:
        int: Codon table number.
    """
    return 1 if organism in _TABLE_1_ORGANISMS else 11