---
license: cc-by-nc-4.0
library_name: transformers
tags:
  - biology
pipeline_tag: text-generation
widget:
  - text: <|bos|> <|tag_start|> 00050 <|tag_end|> <|5|>
---

# RFamLlama

The ability to generate specific RNA sequences on demand has significant implications for both scientific research and therapeutic applications. To this end, we introduce RFamLlama, a conditional language model optimized for generating RNA sequences across diverse families. The model was trained on RNA sequences from over 4,000 distinct Rfam families, each sequence annotated with a control tag that denotes its family. We have shown that including these family-specific tags substantially improves the model's zero-shot fitness prediction for RNA molecules. The model also supports conditional generation: given an Rfam ID as the prompt, it generates sequences from that family without any further function-specific fine-tuning. RFamLlama is therefore intended as a broadly applicable tool for zero-shot fitness prediction and generation of RNA sequences, with the potential to explore sequences beyond those produced by natural evolution.
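This card does not say how the zero-shot fitness scores are computed. A common recipe for causal language models, assumed here rather than taken from the authors, scores a sequence by its log-likelihood given the family prompt and compares a mutant with its wild type through the difference of those log-likelihoods. The sketch below shows that arithmetic in plain Python. `build_prompt`, `log_softmax`, `log_likelihood`, and `fitness_score` are hypothetical helpers; the logits are toy numbers, and obtaining real per-token logits from RFamLlama (for example from the `logits` of a `LlamaForCausalLM` forward pass) is left out because the card does not document how nucleotides are tokenized.

```python
import math
import re
from typing import List, Sequence

# Hypothetical helpers: the card does not specify how RFamLlama fitness scores are computed.
RFAM_ACCESSION = re.compile(r"RF\d{5}")


def build_prompt(rfam_id: str) -> str:
    """Family-conditioned prompt in the format of the generation example on this card."""
    if not RFAM_ACCESSION.fullmatch(rfam_id):
        raise ValueError(f"not an Rfam accession: {rfam_id!r}")
    # The family tag is the numeric part of the accession, e.g. RF00005 -> 00005.
    return f"<|bos|> <|tag_start|> {rfam_id[2:]} <|tag_end|> <|5|> "


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary-sized logit vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def log_likelihood(step_logits: Sequence[Sequence[float]], token_ids: Sequence[int]) -> float:
    """Sum of log p(token_t | prompt, tokens before t) over the scored tokens.

    step_logits[t] is the logit vector from the step that predicts token_ids[t],
    i.e. causal-LM logits shifted by one position relative to the input ids.
    """
    if len(step_logits) != len(token_ids):
        raise ValueError("need exactly one logit vector per scored token")
    return sum(log_softmax(logits)[tok] for logits, tok in zip(step_logits, token_ids))


def fitness_score(mutant_ll: float, wild_type_ll: float) -> float:
    """Log-likelihood ratio; positive values mean the model prefers the mutant."""
    return mutant_ll - wild_type_ll


if __name__ == "__main__":
    print(build_prompt("RF00005"))
    # Made-up logits over a 4-token toy vocabulary for 2 scored positions (not model output).
    # The first row is shared because both sequences see the same prompt prefix.
    wild_type_logits = [[2.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.0, 0.0]]
    mutant_logits = [[2.0, 0.0, 0.0, 0.0], [0.0, 0.0, 2.0, 0.0]]
    wild_type = log_likelihood(wild_type_logits, [0, 2])
    mutant = log_likelihood(mutant_logits, [1, 2])  # substitution at the first position
    print(fitness_score(mutant, wild_type))
```

Because each position's logits depend on the tokens before it, a mutant and its wild type need separate forward passes; the toy example reflects this by giving the mutant its own logit matrix.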

## Use RFamLlama-base

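```python
# generation
import torch
from transformers import AutoTokenizer, LlamaForCausalLM, pipeline

model_url = "jinyuan22/RFamLlama-base"

# float16 is intended for GPU; fall back to float32 on CPU, where some half-precision ops are slow or unsupported.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
dtype = torch.float16 if use_cuda else torch.float32

model = LlamaForCausalLM.from_pretrained(model_url, torch_dtype=dtype)
tokenizer = AutoTokenizer.from_pretrained(model_url)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)

# Condition on an Rfam family: the tag is the accession without its "RF" prefix (RF00005 -> 00005).
tag = "RF00005"
txt = f"<|bos|> <|tag_start|> {tag[2:]} <|tag_end|> <|5|> "

# Draw 10 samples; top_p=1.0 and temperature=1.0 sample from the unmodified distribution.
outputs = pipe(
    txt,
    num_return_sequences=10,
    max_new_tokens=300,
    repetition_penalty=1.0,
    top_p=1.0,
    temperature=1.0,
    do_sample=True,
)

# Keep the text between <|5|> and <|3|> (the sequence start/end markers) and print it as FASTA.
for i, output in enumerate(outputs):
    seq = output["generated_text"]
    seq = seq.split("<|5|>")[1].split("<|3|>")[0].strip()
    print(f">{i}\n{seq}")
```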
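The loop above assumes every sample contains `<|5|>` and simply keeps whatever follows it when `<|3|>` is missing. A slightly more defensive post-processing step is sketched below. It assumes, beyond what the card states, that valid outputs use the RNA alphabet A/C/G/U (plus N), that whitespace between decoded tokens can be dropped, and that a sample without `<|3|>` probably stopped at `max_new_tokens`. `extract_sequence`, `to_fasta`, and the accepted alphabet are illustrative choices, not part of the model's API.

```python
from typing import Iterable, Optional

START, END = "<|5|>", "<|3|>"


def extract_sequence(text: str, alphabet: str = "ACGUN", require_end: bool = False) -> Optional[str]:
    """Return the sequence between START and END, or None if the sample looks unusable.

    Assumptions (not from the model card): whitespace between tokens is dropped,
    letters are upper-cased, and only characters in `alphabet` are accepted.
    """
    if START not in text:
        return None
    body = text.split(START, 1)[1]
    if END in body:
        body = body.split(END, 1)[0]
    elif require_end:
        # No end marker: generation probably stopped at max_new_tokens.
        return None
    seq = "".join(body.split()).upper()
    if not seq or any(ch not in alphabet for ch in seq):
        return None
    return seq


def to_fasta(seqs: Iterable[Optional[str]], prefix: str = "") -> str:
    """Format non-None sequences as FASTA records named by their original sample index."""
    records = [f">{prefix}{i}\n{s}" for i, s in enumerate(seqs) if s is not None]
    return "\n".join(records) + ("\n" if records else "")


if __name__ == "__main__":
    # Toy strings in the prompt format above, not real model output.
    samples = [
        "<|bos|> <|tag_start|> 00005 <|tag_end|> <|5|> G C G G A U U U A <|3|>",
        "<|bos|> <|tag_start|> 00005 <|tag_end|> <|5|> GCCXAA",
    ]
    print(to_fasta(extract_sequence(s) for s in samples), end="")
```

With the pipeline output from the example, the same call would be `to_fasta(extract_sequence(o["generated_text"]) for o in outputs)`. Keeping the original sample index in each FASTA header makes it easy to see which samples were filtered out.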