"""Streamlit app: embed amino-acid sequences with an ESM-2 model.

The user picks a model, uploads a JSON file mapping ids to lists of
amino-acid sequences, and downloads the last hidden states as JSON.
"""

import json

import streamlit as st
import torch
from transformers import AutoTokenizer, EsmModel


def embed(aa_seq, tokenizer, model):
    """Return the model's last hidden states for one amino-acid sequence.

    Args:
        aa_seq: Amino-acid sequence passed as-is to ``tokenizer``.
        tokenizer: Callable invoked as ``tokenizer(aa_seq, return_tensors="pt")``.
        model: Callable invoked with the tokenizer output as keyword
            arguments; its result must expose ``last_hidden_state``.

    Returns:
        ``last_hidden_state`` converted to nested Python lists, so the
        result can be serialized with ``json.dumps``.
    """
    inputs = tokenizer(aa_seq, return_tensors="pt")
    # Inference only: skip autograd bookkeeping so no graph is kept in memory.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.detach().numpy().tolist()


def embed_upload_file(upload_dict_dania, tokenizer, model):
    """Embed every sequence of an uploaded mapping and offer the result as JSON.

    Args:
        upload_dict_dania: Mapping of id to an iterable of sequences, e.g.
            ``{'uid1': ['aa', 'aan']}``.
        tokenizer: Tokenizer forwarded to :func:`embed`.
        model: Model forwarded to :func:`embed`.

    Side effects:
        Renders a status text, a progress bar and a download button whose
        payload has the shape
        ``{'uid1': {'aa': [[[0.1298, ...]]], 'aan': [[[0.1298, ...]]]}}``.
    """
    output = {}

    # Placeholder that is overwritten with the id currently being processed.
    latest_iteration = st.empty()
    bar = st.progress(0)

    total = len(upload_dict_dania)
    for idx, (uid, seqs) in enumerate(upload_dict_dania.items()):
        latest_iteration.text(f'Iteration {uid}')
        output[uid] = {seq: embed(seq, tokenizer, model) for seq in seqs}
        # st.progress expects an int in 0-100 or a float in 0.0-1.0; a raw
        # item count is neither proportional nor valid beyond 100 items, so
        # report the completed fraction instead.
        bar.progress((idx + 1) / total)

    json_data = json.dumps(output)
    st.download_button(
        label="Download JSON file",
        data=json_data,
        file_name="esm-2 last hidden states.json",
        mime='application/json',
    )


# Selecting a model.
model_name = st.selectbox(
    'Choose a model',
    ["facebook/esm2_t6_8M_UR50D", "facebook/esm2_t48_15B_UR50D"],
)

# Uploading the AA sequences file.
uploaded_file = st.file_uploader("Upload JSON with AA sequences", type='json')

# Stays None until a file has been uploaded and parsed successfully, so the
# button handler below never touches an undefined name.
data = None
if uploaded_file is not None:
    try:
        data = json.load(uploaded_file)
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        st.error(f'Could not parse the uploaded file as JSON: {err}')

if st.button('Get embedding'):
    if data is None:
        st.warning('Please upload a valid JSON file with AA sequences first.')
    elif not isinstance(data, dict):
        st.error('The uploaded JSON must be an object mapping ids to lists of sequences.')
    else:
        st.write('You selected model:', model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = EsmModel.from_pretrained(model_name)
        embed_upload_file(data, tokenizer, model)
        # NOTE(review): unrelated joke message shown to users; kept verbatim
        # to preserve current output, but it should probably be removed.
        st.write('Also, Dania is not gay')