# Lang2mol-Diff / app_streamlit.py
# Author: ndhieunguyen
# Commit 7cacf8f: "feat: use gradio"
import torch
import selfies as sf
from transformers import T5EncoderModel
from src.scripts.mytokenizers import Tokenizer
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.transformer_model import TransformerNetModel
import streamlit as st
import spaces
import os
@st.cache_resource
def get_encoder(device):
    """Load the pretrained T5 text encoder used to embed molecule descriptions.

    Decorated with ``st.cache_resource`` so the same encoder object is reused
    across Streamlit reruns for a given ``device`` argument.

    Args:
        device: torch device the encoder is moved to.

    Returns:
        The ``T5EncoderModel`` for ``QizhiPei/biot5-base-text2mol``, placed on
        ``device`` and switched to evaluation mode.
    """
    checkpoint_name = "QizhiPei/biot5-base-text2mol"
    text_encoder = T5EncoderModel.from_pretrained(checkpoint_name)
    text_encoder.to(device)
    text_encoder.eval()
    return text_encoder
@st.cache_resource
def get_tokenizer():
    """Build the project tokenizer once and share it across Streamlit reruns.

    Returns:
        A ``Tokenizer`` instance from ``src.scripts.mytokenizers``.
    """
    shared_tokenizer = Tokenizer()
    return shared_tokenizer
@st.cache_resource
def get_model(device):
    """Build the denoising transformer and load its trained weights.

    The architecture hyper-parameters below must match those of the
    checkpoint; ``load_state_dict`` (strict by default) raises on any
    missing, unexpected or mis-shaped parameter.

    Args:
        device: torch device used both as ``map_location`` when reading the
            checkpoint and as the final placement of the model.

    Returns:
        The ``TransformerNetModel`` on ``device``, in evaluation mode.
    """
    architecture = dict(
        in_channels=32,
        model_channels=128,
        dropout=0.1,
        vocab_size=35073,
        hidden_size=1024,
        num_attention_heads=16,
        num_hidden_layers=12,
    )
    denoiser = TransformerNetModel(**architecture)

    checkpoint_path = os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt")
    state_dict = torch.load(checkpoint_path, map_location=torch.device(device))
    denoiser.load_state_dict(state_dict)

    denoiser.to(device)
    denoiser.eval()
    return denoiser
@st.cache_resource
def get_diffusion():
    """Create the respaced diffusion process used for sampling.

    Starts from a 2000-step ``"sqrt"`` beta schedule and keeps every 10th
    timestep (0, 10, ..., 1990; 200 in total).

    Returns:
        A ``SpacedDiffusion`` configured for the ``"transformer"`` model
        architecture in ``"e2e"`` training mode.
    """
    total_steps = 2000
    stride = 10
    kept_steps = list(range(0, total_steps, stride))
    return SpacedDiffusion(
        use_timesteps=kept_steps,
        betas=gd.get_named_beta_schedule("sqrt", total_steps),
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
        loss_type=gd.LossType.E2E_MSE,
        rescale_timesteps=True,
        model_arch="transformer",
        training_mode="e2e",
    )
@spaces.GPU
def generate(text_input):
    """Generate a molecule from a natural-language description.

    Pipeline:
      1. Tokenize the description (padded/truncated to 256 tokens).
      2. Encode it with the T5 encoder.
      3. Run the diffusion sampler conditioned on that encoding.
      4. Map the samples to per-position logits with ``model.get_logits`` and
         greedily pick the top-1 token at each position.
      5. Decode the tokens to a SELFIES string and convert it with
         ``selfies.decoder``.

    Relies on the module-level ``tokenizer``, ``encoder``, ``model``,
    ``diffusion`` and ``device`` globals created when the script runs.

    Args:
        text_input: free-text molecule description entered by the user.

    Returns:
        The string produced by ``selfies.decoder`` for the generated sequence,
        with tab characters removed.
    """
    # Pure inference: disable autograd so the encoder and logits passes do
    # not build (and keep alive) a computation graph.
    with st.spinner("Please wait..."), torch.no_grad():
        encoded = tokenizer(
            text_input,
            max_length=256,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        # Move the mask once; it is used by the encoder and by the sampler.
        caption_mask = encoded["attention_mask"].to(device)
        caption_state = encoder(
            input_ids=encoded["input_ids"].to(device),
            attention_mask=caption_mask,
        ).last_hidden_state  # already on `device`, produced by the encoder

        # (1, 256, 32): a single sample; 256 matches the tokenizer's
        # max_length and 32 matches the model's in_channels.
        samples = diffusion.p_sample_loop(
            model,
            (1, 256, 32),
            clip_denoised=False,
            denoised_fn=None,
            model_kwargs={},
            top_p=1.0,
            progress=True,
            caption=(caption_state, caption_mask),
        )

        # torch.as_tensor returns a tensor input as-is; torch.tensor() would
        # copy it and emit a UserWarning about copy-constructing from a tensor.
        logits = model.get_logits(torch.as_tensor(samples))
        token_ids = torch.topk(logits, k=1, dim=-1).indices.squeeze(-1)
        decoded = tokenizer.decode(token_ids)

        selfies_string = (
            decoded[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
        )
        return sf.decoder(selfies_string).replace("\t", "")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# These loaders are decorated with @st.cache_resource, so Streamlit reruns of
# this script reuse the already-built objects instead of reloading weights.
tokenizer = get_tokenizer()
encoder = get_encoder(device)
model = get_model(device)
diffusion = get_diffusion()

st.title("Lang2mol-Diff")
text_input = st.text_area("Enter molecule description")
button = st.button("Submit")
if button:
    # Don't spend a full (GPU) sampling run on an empty/whitespace-only prompt.
    if not text_input.strip():
        st.warning("Please enter a molecule description.")
    else:
        result = generate(text_input)
        st.write(result)