aptlm / inference.py
mrowenyao's picture
get rid of argparse
8bc1979 verified
raw
history blame
1.21 kB
from api_prediction import AptaTransPipeline_Dist
import gradio as gr
# Process-wide AptaTrans pipeline, built on first use by _get_pipeline().
_PIPELINE = None


def _get_pipeline():
    """Return the shared AptaTrans pipeline, constructing it on first call.

    The pipeline is created with ``load_best_pt=True``, which loads the
    pretrained model. Building it once and reusing it avoids repeating that
    load for every single prediction.

    NOTE(review): this assumes ``pipeline.inference`` does not depend on
    being run on a freshly constructed (freshly seeded) pipeline, i.e. that
    inference is deterministic once the model is loaded. TODO confirm.
    """
    global _PIPELINE
    if _PIPELINE is None:
        _PIPELINE = AptaTransPipeline_Dist(
            # Optimizer/training settings; presumably unused at inference
            # time -- TODO confirm against AptaTransPipeline_Dist.
            lr=1e-6,
            weight_decay=None,
            epochs=None,
            model_type=None,
            model_version=None,
            model_save_path=None,
            accelerate_save_path=None,
            tensorboard_logdir=None,
            # Model architecture hyperparameters.
            d_model=128,
            d_ff=512,
            n_layers=6,
            n_heads=8,
            dropout=0.1,
            # Loads the pretrained model included in the repo.
            load_best_pt=True,
            device='cuda',
            seed=1004,
        )
    return _PIPELINE


def infer(prot, apta):
    """Score a single protein/aptamer pair with the pretrained AptaTrans model.

    Args:
        prot: Protein sequence, as entered in the UI.
        apta: Aptamer sequence, as entered in the UI.

    Returns:
        The first (and only) element of the scores returned by
        ``pipeline.inference`` for this one pair.
    """
    # inference() takes parallel lists; the third list is a per-pair
    # placeholder (presumably dummy labels -- TODO confirm).
    scores = _get_pipeline().inference([apta], [prot], [0])
    return scores[0]
def comparison(protein, aptamers):
    """Score every aptamer in a newline-separated list against one protein.

    Args:
        protein: Protein sequence text from the first input box. Surrounding
            whitespace, such as a trailing newline, is stripped.
        aptamers: Aptamer sequences, one per line. Each line is stripped and
            blank lines are skipped, so a trailing newline or an empty line
            does not produce a spurious prediction.

    Returns:
        A list with one ``infer`` score per non-blank aptamer line, in input
        order.
    """
    protein = protein.strip()
    # splitlines() also handles "\r\n" endings pasted from other tools.
    candidates = [line.strip() for line in aptamers.splitlines()]
    return [infer(protein, apta) for apta in candidates if apta]
# Web UI: a protein sequence and a newline-separated list of aptamers go in;
# comparison() scores each aptamer and the result is shown as plain text.
protein_box = gr.Textbox(lines=2, placeholder="Protein")
aptamers_box = gr.Textbox(lines=10, placeholder="Aptamers (1 per line)")
iface = gr.Interface(
    fn=comparison,
    inputs=[protein_box, aptamers_box],
    outputs=gr.Textbox(),
)
# Starts the Gradio server as soon as this module is executed.
iface.launch()