|
import streamlit as st |
|
import os |
|
import torch |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from robust_detection import utils |
|
from robust_detection.models.rcnn import RCNN |
|
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN, COCO_RCNN |
|
|
|
import pytorch_lightning as pl |
|
from pytorch_lightning.loggers import WandbLogger |
|
from pytorch_lightning.loggers import CSVLogger |
|
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint |
|
from pytorch_lightning.callbacks.early_stopping import EarlyStopping |
|
from pytorch_lightning.callbacks import LearningRateMonitor |
|
from rdkit import Chem |
|
from rdkit.Chem import AllChem |
|
from rdkit import DataStructs |
|
from PIL import Image |
|
|
|
model_cls = RCNN

# Directory holding the saved checkpoints of the atom-level detector.
experiment_path_atoms = "./models/atoms_model/"

# Every entry of the experiment directory, newest first.
# NOTE(review): os.path.getctime is metadata-change time on POSIX, not creation
# time; getmtime may better reflect "most recently written" — confirm intent.
dir_list = sorted(
    (os.path.join(experiment_path_atoms, f) for f in os.listdir(experiment_path_atoms)),
    key=os.path.getctime,
    reverse=True,
)

# Newest checkpoint file. Match the ".ckpt" extension on regular files rather
# than the substring "ckpt" anywhere in the path, and show a readable error in
# the app instead of crashing with a bare IndexError when none exists.
_atom_checkpoints = [f for f in dir_list if os.path.isfile(f) and f.endswith(".ckpt")]
if not _atom_checkpoints:
    st.error(f"No .ckpt checkpoint found in {experiment_path_atoms!r}")
    st.stop()
checkpoint_file_atoms = _atom_checkpoints[0]

# Streamlit re-runs this whole script on every widget interaction, so cache the
# loaded model instead of re-reading the checkpoint each time. Older Streamlit
# versions without st.cache_resource fall back to loading without a cache.
_cache_model = getattr(st, "cache_resource", lambda fn: fn)


@_cache_model
def _load_atom_model(checkpoint_path):
    """Load the atom detector from *checkpoint_path* and set its ROI score threshold."""
    model = model_cls.load_from_checkpoint(checkpoint_path)
    # Presumably the minimum detection score kept by the ROI heads
    # (torchvision-style attribute) — confirm against the RCNN wrapper.
    model.model.roi_heads.score_thresh = 0.65
    return model


model_atom = _load_atom_model(checkpoint_file_atoms)
|
|
|
|
|
st.title("Atom Level Entity Detector")

image_file = st.file_uploader("Upload a chemical structure candidate image",type=['png','jpeg','jpg'])

if image_file is not None:
    # col1 shows the uploaded image; col2 is not used yet
    # (NOTE(review): presumably reserved for detection output — confirm).
    col1, col2 = st.columns(2)

    image = Image.open(image_file)
    col1.image(image, use_column_width=True)

    # Persist the raw upload under ./uploads. The filename is supplied by the
    # client, so keep only its final path component (treating "\" as a
    # separator too) to stop names like "../x.png" escaping the directory,
    # and create the directory on first use instead of failing.
    upload_dir = "uploads"
    os.makedirs(upload_dir, exist_ok=True)
    safe_name = os.path.basename(image_file.name.replace("\\", "/"))
    save_path = os.path.join(upload_dir, safe_name)
    with open(save_path, "wb") as f:
        f.write(image_file.getbuffer())

    st.success("Saved File")

# NOTE(review): looks like leftover Streamlit demo code unrelated to
# detection — confirm whether it should be removed.
x = st.slider('Select a value')

st.write(x, 'squared is', x * x)
|
|