
SARITA_M


This model is a fine-tuned version of lightonai/RITA_m, trained on SARS-CoV-2 Spike protein sequences from GISAID. It achieves the following results on the evaluation set:

  • Loss: 0.0329
  • Accuracy: 0.9895

The pre-print paper is available here. The code to train and evaluate the model is available on GitHub.

Model   #Params  d_model  layers
Small   85M      768      12
Medium  300M     1024     24
Large   680M     1536     24
XLarge  1.2B     2048     24
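
To confirm these sizes for a downloaded checkpoint, you can print its configuration. This is a minimal sketch; the exact attribute names (e.g. d_model) come from RITA's custom configuration class, so printing the whole object is the safest check:

from transformers import AutoConfig

# RITA/SARITA ships custom model code, hence trust_remote_code
config = AutoConfig.from_pretrained("SimoRancati/SARITA_M.0.1", trust_remote_code=True)

# Printing the config shows the model dimension, layer count, vocabulary size, etc.
print(config)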

Model description

SARITA M is an LLM with 300 million parameters, based on the GPT-3 architecture, designed to generate high-quality synthetic SARS-CoV-2 Spike protein sequences. SARITA is trained via continuous learning on the pre-existing protein language model RITA, using sequences uploaded to GISAID between December 2019 and August 2024.

Intended uses & limitations

This model can be used to generate synthetic Spike protein sequences of the SARS-CoV-2 virus.

Training hyperparameters

The following hyperparameters were used during training (a sketch of an equivalent TrainingArguments configuration follows the list):

  • learning_rate: 5e-05
  • train_batch_size: 48
  • eval_batch_size: 48
  • seed: 42
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • num_epochs: 10
  • mixed_precision_training: Native AMP
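
The hyperparameters above map onto a Hugging Face TrainingArguments object as sketched below. This is an illustrative reconstruction, not the authors' training script (that is on GitHub); the output directory name is a placeholder.

from transformers import TrainingArguments

# A minimal sketch matching the listed hyperparameters;
# "sarita_m_finetune" is a hypothetical output directory.
training_args = TrainingArguments(
    output_dir="sarita_m_finetune",
    learning_rate=5e-5,
    per_device_train_batch_size=48,
    per_device_eval_batch_size=48,
    seed=42,
    adam_beta1=0.9,          # Adam with betas=(0.9, 0.999)
    adam_beta2=0.999,
    adam_epsilon=1e-8,       # epsilon=1e-08
    lr_scheduler_type="linear",
    num_train_epochs=10,
    fp16=True,               # Native AMP mixed precision
)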

Training results

Training Loss  Epoch  Step    Accuracy  Validation Loss
0.0333         1.0    57263   0.9895    0.0334
0.0330         2.0    114526  0.9895    0.0331
0.0328         3.0    171789  0.9895    0.0331
0.0327         4.0    229052  0.9895    0.0330
0.0325         5.0    286315  0.9895    0.0330
0.0324         6.0    343578  0.9896    0.0329
0.0325         7.0    400841  0.9895    0.0329
0.0322         8.0    458104  0.9896    0.0329
0.0321         9.0    515367  0.9896    0.0330
0.0319         10.0   572630  0.9896    0.0331

Framework versions

  • Transformers 4.20.1
  • Pytorch 1.9.0+cu111
  • Datasets 2.18.0
  • Tokenizers 0.12.1
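
To reproduce this environment, the matching package versions can be installed with pip. The PyTorch line assumes the standard wheel index for CUDA 11.1 builds:

pip install transformers==4.20.1 datasets==2.18.0 tokenizers==0.12.1
pip install torch==1.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html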

Usage

Instantiate a model like so:

from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("SimoRancati/SARITA_M.0.1", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("SimoRancati/SARITA_M.0.1")

For generation, use the following code:

import torch

# Check for GPU availability and move the model to GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Prompt: the N-terminal residues of the SARS-CoV-2 Spike protein
start = ['MFVFLVLLPLVSSQ']

for i in range(len(start)):
    # Prepare model inputs and move them to the same device as the model
    model_inputs = tokenizer([start[i]], return_tensors="pt")
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

    # Sample 100 sequences of exactly 701 tokens from the prompt
    generated_ids = model.generate(**model_inputs, min_length=701, max_length=701,
                                   do_sample=True, top_k=950, repetition_penalty=1.2,
                                   num_return_sequences=100, eos_token_id=2)

    # Decode the outputs and remove the spaces the tokenizer inserts
    generated_sequences = []
    for f in range(len(generated_ids)):
        sequence = tokenizer.decode(generated_ids[f], skip_special_tokens=True).replace(' ', '')
        generated_sequences.append(sequence)
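
The generated sequences are plain amino-acid strings, so they can be written to FASTA for downstream analysis. A minimal sketch, with a placeholder file name and record IDs:

# Save the generated sequences in FASTA format
# ("sarita_m_generated.fasta" and the record IDs are placeholders)
with open("sarita_m_generated.fasta", "w") as handle:
    for n, seq in enumerate(generated_sequences):
        handle.write(f">SARITA_M_generated_{n}\n{seq}\n")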

Availability

The SARITA model is public, but downloading it requires approval.
To request access, click the Request Access button and provide a brief explanation of your intended use.

License

The use of this model is restricted to research purposes. Commercial use is not allowed without prior approval.
