TEM-52 ESM-2 (LoRA) Masked Language Models

ESM-2 650M masked language models, fine-tuned with LoRA adapters, that predict beneficial amino-acid substitutions at six site-saturation mutagenesis (SSM) positions of the TEM-52 β-lactamase: Y103, N168, V214, A235, E237, and R241 (non-Ambler numbering).

One checkpoint is provided per antibiotic substrate, each in its own subfolder:

Subfolder    Substrate
---------    -----------
amp_100      ampicillin
caz_10000    ceftazidime
cet_15       cephalothin
ctx_1125     cefotaxime

Each subfolder contains pytorch_model.bin (full model state dict: ESM-2 650M + LoRA) and the tokenizer files.

Usage

Clone the code repository ajoujcb/TEM_ESM and use predict.py, which downloads the weights from this Hub repo automatically:

python predict.py --hf_repo AjouJCB/TEM_ESM --substrate caz_10000 -p 103 --top_k 5

Or from Python:

from predict import (load_model, resolve_weights_dir,
                     build_masked_sequence, predict_masked_tokens)

weights_dir = resolve_weights_dir(hf_repo="AjouJCB/TEM_ESM", substrate="caz_10000")
model, tokenizer = load_model(weights_dir, device="cuda")
sequence = build_masked_sequence(103)          # wild-type TEM-52 with Y103 masked
for hit in predict_masked_tokens(model, tokenizer, sequence, top_k=5):
    print(hit)
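
To make the masking step concrete, here is a minimal sketch of what replacing one residue with the ESM-2 mask token looks like. It is not the repository's build_masked_sequence: the helper name mask_position, the toy sequence, and the assumption that positions are 1-based string indices are all illustrative. The card only states that the numbering is non-Ambler, not how it maps onto the stored sequence.

```python
MASK_TOKEN = "<mask>"  # mask token used by the ESM-2 tokenizer


def mask_position(sequence, position, expected=None):
    """Replace the residue at a 1-based position with the mask token.

    Hypothetical helper for illustration only. `expected` optionally
    names the wild-type residue so that an off-by-one error in the
    numbering is caught instead of silently masking the wrong site.
    """
    index = position - 1  # assumption: positions are 1-based
    if not 0 <= index < len(sequence):
        raise ValueError(f"position {position} is outside the sequence")
    if expected is not None and sequence[index] != expected:
        raise ValueError(
            f"expected {expected} at position {position}, "
            f"found {sequence[index]}"
        )
    return sequence[:index] + MASK_TOKEN + sequence[index + 1:]


# Toy placeholder sequence, NOT the TEM-52 sequence.
toy_sequence = "ACDEFGHIKL"
masked = mask_position(toy_sequence, 4, expected="E")
print(masked)  # ACD<mask>FGHIKL
```

Passing the expected wild-type residue is a cheap guard when working with a non-standard numbering scheme, since a shifted index would otherwise go unnoticed.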

The LoRA configuration needed to rebuild the model before loading the state dict is r=4, lora_alpha=8, target_modules=["query","key","value","out"], applied to all 33 transformer blocks. See the code repository for details.
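
As a rough sanity check on the size of these adapters, the arithmetic below estimates the number of LoRA parameters implied by that configuration. It is a sketch under stated assumptions rather than a figure taken from the repository: it assumes the standard LoRA formulation (a pair of matrices of shapes r × d_in and d_out × r per adapted layer, scaled by lora_alpha / r), no LoRA bias terms, and that all four target projections are square 1280 × 1280 layers, 1280 being the hidden width of ESM-2 650M.

```python
HIDDEN = 1280      # hidden width of ESM-2 650M
N_BLOCKS = 33      # transformer blocks, as stated above
TARGETS = ["query", "key", "value", "out"]
R = 4
LORA_ALPHA = 8


def lora_params_per_matrix(d_in, d_out, r):
    """Parameters in one LoRA pair: A is (r x d_in), B is (d_out x r)."""
    return r * d_in + d_out * r


# Assumption: every target projection maps HIDDEN -> HIDDEN.
per_block = len(TARGETS) * lora_params_per_matrix(HIDDEN, HIDDEN, R)
total = N_BLOCKS * per_block
scaling = LORA_ALPHA / R  # standard LoRA scaling factor

print(per_block)  # 40960
print(total)      # 1351680
print(scaling)    # 2.0
```

Under these assumptions the adapters add roughly 1.35 million parameters, a small fraction of the 650M-parameter base model.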

Note: ESM-2 uses rotary position embeddings, so the unused position_embeddings.weight is absent from these checkpoints; load with strict=False (handled automatically by predict.py).
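
Because strict=False suppresses all key-mismatch errors, not only the expected one, it can be worth confirming after a manual load that nothing else went missing. PyTorch's load_state_dict returns an object with missing_keys and unexpected_keys lists; the sketch below checks such lists against the single omission described above. The helper name and the example key names are illustrative assumptions, not part of predict.py.

```python
EXPECTED_MISSING_SUFFIX = "position_embeddings.weight"


def unexpected_load_problems(missing_keys, unexpected_keys):
    """List every key mismatch other than the known, harmless one.

    Hypothetical helper: feed it the `missing_keys` and
    `unexpected_keys` lists returned by `load_state_dict(strict=False)`.
    An empty result means only the unused position embedding was absent.
    """
    problems = [
        f"missing: {key}"
        for key in missing_keys
        if not key.endswith(EXPECTED_MISSING_SUFFIX)
    ]
    problems += [f"unexpected: {key}" for key in unexpected_keys]
    return problems


# Hand-written key lists for illustration (names are assumptions).
clean = unexpected_load_problems(
    ["esm.embeddings.position_embeddings.weight"], []
)
broken = unexpected_load_problems(
    [
        "esm.embeddings.position_embeddings.weight",
        "esm.encoder.layer.0.attention.self.query.weight",
    ],
    [],
)
print(clean)   # []
print(broken)  # ['missing: esm.encoder.layer.0.attention.self.query.weight']
```

Matching on the key suffix rather than the full name keeps the check independent of whatever prefix the wrapped model adds to its parameter names.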
