from enum import Enum

import streamlit as st
import torch
from tape import ProteinBertModel, TAPETokenizer
from tokenizers import Tokenizer
from transformers import (
    AutoTokenizer,
    BertModel,
    BertTokenizer,
    GPT2LMHeadModel,
    GPT2TokenizerFast,
    T5EncoderModel,
    T5Tokenizer,
)


class ModelType(str, Enum):
    TAPE_BERT = "TapeBert"
    ZymCTRL = "ZymCTRL"
    PROT_BERT = "ProtBert"
    PROT_T5 = "ProtT5"


class Model:
    def __init__(self, name: ModelType, layers: int, heads: int):
        self.name: ModelType = name
        self.layers: int = layers
        self.heads: int = heads


@st.cache
def get_tape_bert() -> tuple[TAPETokenizer, ProteinBertModel]:
    """Load the TAPE ProteinBert model and tokenizer, cached across reruns."""
    tokenizer = TAPETokenizer()
    model = ProteinBertModel.from_pretrained("bert-base", output_attentions=True)
    return tokenizer, model


# Streamlit is not able to hash the tokenizer for ZymCTRL.
# With streamlit 1.19, cache_resource should work without this workaround.
@st.cache(hash_funcs={Tokenizer: lambda _: None})
def get_zymctrl() -> tuple[GPT2TokenizerFast, GPT2LMHeadModel]:
    """Load the ZymCTRL model and tokenizer, moving the model to GPU if available."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("nferruz/ZymCTRL")
    model = GPT2LMHeadModel.from_pretrained("nferruz/ZymCTRL").to(device)
    return tokenizer, model


@st.cache(hash_funcs={BertTokenizer: lambda _: None})
def get_prot_bert() -> tuple[BertTokenizer, BertModel]:
    """Load the ProtBert model and tokenizer, moving the model to GPU if available."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    model = BertModel.from_pretrained("Rostlab/prot_bert").to(device)
    return tokenizer, model


@st.cache
def get_prot_t5() -> tuple[T5Tokenizer, T5EncoderModel]:
    """Load the half-precision ProtT5 encoder and tokenizer, moving the model to GPU if available."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = T5Tokenizer.from_pretrained(
        "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
    )
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(device)
    return tokenizer, model
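

# Usage sketch (not part of the original module): one way a caller might pull
# per-layer attention matrices from the TAPE loader above. The sequence string
# and the position of the attentions in TAPE's forward output are assumptions
# here — with output_attentions=True the attention tensors are expected as the
# last element of the returned tuple, one tensor per layer.
if __name__ == "__main__":
    tokenizer, model = get_tape_bert()
    sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # hypothetical example sequence
    token_ids = torch.tensor([tokenizer.encode(sequence)])
    with torch.no_grad():
        outputs = model(token_ids)
    # Assumed layout: attentions last, each with shape (batch, heads, seq_len, seq_len).
    attentions = outputs[-1]
    print(f"{len(attentions)} layers, attention shape {attentions[0].shape}")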