PhyloGPN is a convolutional neural network that takes encoded DNA sequences as input and outputs rate matrix parameters for Felsenstein's 1981 model (the F81 model, for short). It was trained to maximize the likelihood of columns in the Zoonomia alignment given a phylogenetic tree. The stationary distribution of the substitution process described by the F81 model indicates the relative viability of each allele at any given locus. Because the model therefore assigns a probability distribution over alleles at every site, PhyloGPN is formally a genomic language model. Its per-site outputs and embeddings can be used for transfer learning.
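To make the F81 description concrete, here is a minimal sketch. It assumes that the four log rate parameters are the logarithms of the off-diagonal rates $r_j$, so that $Q_{ij} = r_j$ for $i \neq j$. Under that assumption the stationary distribution is $\pi_j = r_j / \sum_k r_k$, and the transition probabilities have the closed form $P_{ij}(t) = e^{-\beta t}\,\delta_{ij} + (1 - e^{-\beta t})\,\pi_j$ with $\beta = \sum_k r_k$. The likelihood of one alignment column is then computed with Felsenstein's pruning algorithm. The toy tree, branch lengths, and helper names (`f81_params`, `column_likelihood`, and so on) are hypothetical and not part of the PhyloGPN code.

```python
import math

NUCLEOTIDES = "ACGT"


def f81_params(log_rates):
    """Return (beta, pi) from four log rate parameters keyed by nucleotide.

    Assumption: exp(log_rates[j]) is the rate Q[i][j] of replacing any other
    base i by j, i.e. Q[i][j] = beta * pi[j] with beta = sum of the rates.
    """
    rates = {k: math.exp(log_rates[k]) for k in NUCLEOTIDES}
    beta = sum(rates.values())
    pi = {k: r / beta for k, r in rates.items()}  # stationary distribution
    return beta, pi


def f81_transition(beta, pi, t):
    """Closed-form F81 transition probabilities over a branch of length t."""
    decay = math.exp(-beta * t)
    return {
        i: {j: (decay if i == j else 0.0) + (1.0 - decay) * pi[j] for j in NUCLEOTIDES}
        for i in NUCLEOTIDES
    }


def partial_likelihood(node, beta, pi):
    """Felsenstein pruning: L[x] = P(leaves below node | node has base x).

    A leaf is a one-character string (an observed base, or "-" for missing
    data); an internal node is a list of (child, branch_length) pairs.
    """
    if isinstance(node, str):
        return {x: 1.0 if node in ("-", x) else 0.0 for x in NUCLEOTIDES}
    result = {x: 1.0 for x in NUCLEOTIDES}
    for child, t in node:
        child_lik = partial_likelihood(child, beta, pi)
        p = f81_transition(beta, pi, t)
        for x in NUCLEOTIDES:
            result[x] *= sum(p[x][y] * child_lik[y] for y in NUCLEOTIDES)
    return result


def column_likelihood(tree, log_rates):
    """Likelihood of one alignment column, with the root base drawn from pi."""
    beta, pi = f81_params(log_rates)
    root = partial_likelihood(tree, beta, pi)
    return sum(pi[x] * root[x] for x in NUCLEOTIDES)


if __name__ == "__main__":
    # Hypothetical log rates for one site and a toy three-species tree
    log_rates = {"A": 0.5, "C": -1.0, "G": -0.5, "T": 0.0}
    tree = [("A", 0.1), ([("A", 0.05), ("G", 0.2)], 0.1)]
    print(column_likelihood(tree, log_rates))
```

In this sketch the training objective would correspond to maximizing this likelihood over alignment columns, with the log rates predicted by the network; the real tree and branch lengths come from the alignment's phylogeny, not from the toy values above.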

The following Python snippet shows how to obtain embeddings and log rate parameters from PhyloGPN for each site in a batch of sequences. Note that PhyloGPN is designed as a sliding window function: it takes a batch of $b$ encoded sequences of any given length $\ell \geq 481$ as input and yields outputs for the $b \times (\ell - 480)$ central positions.
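The sliding-window arithmetic can be sketched as follows, assuming (as the $\ell - 480$ formula implies) a receptive field of 481 positions centered on the scored site. Output position $j$ then depends on input positions $j, \dots, j + 480$ and describes input position $j + 240$, which is why padding a sequence with 240 tokens on each side yields one output per original base. The helper names are hypothetical, and `toy_sliding_window` only mimics the shape of the computation; it does not call PhyloGPN.

```python
RECEPTIVE_FIELD = 481           # implied by the text: outputs cover ell - 480 positions
RADIUS = RECEPTIVE_FIELD // 2   # 240 flanking positions on each side of a scored site


def output_length(ell, receptive_field=RECEPTIVE_FIELD):
    """Number of central positions that receive an output for input length ell."""
    if ell < receptive_field:
        raise ValueError(f"input length {ell} is shorter than the receptive field {receptive_field}")
    return ell - receptive_field + 1


def input_window(j, receptive_field=RECEPTIVE_FIELD):
    """Input positions that output position j depends on."""
    return range(j, j + receptive_field)


def center_of(j, receptive_field=RECEPTIVE_FIELD):
    """Input position whose site is described by output position j."""
    return j + receptive_field // 2


def toy_sliding_window(seq, receptive_field):
    """Stand-in for the network: return the window each output would see."""
    return [seq[j:j + receptive_field] for j in range(output_length(len(seq), receptive_field))]


if __name__ == "__main__":
    print(output_length(1000))  # 520 outputs for a 1,000-position input
    padded = "N" * RADIUS + "TATAAA" + "N" * RADIUS  # "N" is only a placeholder pad character
    print(output_length(len(padded)))  # 6, one output per original base
    print(toy_sliding_window("ACGTACG", 5))
```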

```python
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "calbors/PhyloGPN"
commit_hash = "8e0a2a4"  # pin the remote code and weights to a specific revision
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True, revision=commit_hash)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True, revision=commit_hash)

# Example data
seqs = [
    "TATAAA",
    "GGCCAATCT",
    "CACGTG",
    "AGGTCACGT",
    "GCCAGCC",
    "GGGGATTTCC"
]

# Output length is input length minus 480 (the receptive field size minus 1),
# so padding 240 tokens on each side gives one output per original base
pad_token = tokenizer.pad_token
pad_size = 481 // 2  # 240


def pad_sequence(seq):
    return pad_token * pad_size + seq + pad_token * pad_size


padded_seqs = [pad_sequence(seq) for seq in seqs]
# padding=True pads shorter sequences to the longest one in the batch
input_tensor = tokenizer(padded_seqs, return_tensors="pt", padding=True)["input_ids"]

with torch.no_grad():
    padded_embeddings = model.get_embeddings(input_tensor)
    padded_logits = model(input_tensor)  # Log rate parameters for the F81 model, keyed by nucleotide

embeddings = []
logits = []

# Keep the first len(seq) outputs, dropping positions that come from batch padding
for i, seq in enumerate(seqs):
    length = len(seq)
    embeddings.append(padded_embeddings[i, :length])
    logits.append({k: padded_logits[k][i, :length] for k in "ACGT"})
```
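As a sketch of how the trimmed outputs might be consumed, assume (as in the F81 parameterization sketched earlier) that the stationary distribution at each site is the normalized exponential, i.e. the softmax, of the four log rate parameters. The hypothetical helpers below turn one sequence's `logits` into per-site allele distributions and mean-pool its `embeddings` into a single vector, one common choice of input for a downstream model in transfer learning. They work on plain Python lists so they run without PyTorch; convert tensors with `.tolist()` first.

```python
import math

NUCLEOTIDES = "ACGT"


def stationary_distribution(site_log_rates):
    """pi_k = exp(r_k) / sum_j exp(r_j), i.e. a softmax over the four log rates."""
    m = max(site_log_rates[k] for k in NUCLEOTIDES)  # shift by the max for numerical stability
    weights = {k: math.exp(site_log_rates[k] - m) for k in NUCLEOTIDES}
    total = sum(weights.values())
    return {k: w / total for k, w in weights.items()}


def per_site_distributions(seq_log_rates):
    """seq_log_rates maps each nucleotide to a list of per-site log rates."""
    length = len(seq_log_rates["A"])
    return [
        stationary_distribution({k: seq_log_rates[k][pos] for k in NUCLEOTIDES})
        for pos in range(length)
    ]


def observed_allele_probs(seq, dists):
    """Stationary probability of the base present at each site (seq must be A/C/G/T only)."""
    return [dist[base] for base, dist in zip(seq, dists)]


def mean_pool(site_embeddings):
    """Average per-site embedding vectors into one fixed-size sequence vector."""
    n = len(site_embeddings)
    dim = len(site_embeddings[0])
    return [sum(vec[d] for vec in site_embeddings) / n for d in range(dim)]


if __name__ == "__main__":
    # Toy values standing in for {k: logits[0][k].tolist() for k in "ACGT"}
    toy_log_rates = {"A": [2.0, 0.0], "C": [0.0, 0.0], "G": [0.0, 1.0], "T": [0.0, 0.0]}
    dists = per_site_distributions(toy_log_rates)
    print(observed_allele_probs("AG", dists))
    # Toy values standing in for embeddings[0].tolist()
    print(mean_pool([[1.0, 2.0], [3.0, 4.0]]))
```

With PyTorch tensors, the same two steps can be written as `torch.softmax(torch.stack([logits[i][k] for k in "ACGT"], dim=-1), dim=-1)` and, assuming `embeddings[i]` has shape (length, embedding dimension), `embeddings[i].mean(dim=0)`.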
The released checkpoint has 83.2M parameters, stored as F32 tensors in Safetensors format.