BraggNN / gradio_demo.py
dennistrujillo's picture
added gradio_demo.py
769a90a
raw
history blame
1.53 kB
import gradio as gr
import h5py
import numpy as np
import requests
import torch
from torchvision import transforms
def _upload_path(upload):
    """Return the filesystem path behind an uploaded file.

    Path strings (and any non-file object) are returned unchanged. A
    file-like object that carries its location in ``.name`` is replaced by
    that path, so the data is always re-opened from disk rather than read
    from a handle whose position or contents are unknown.
    """
    if hasattr(upload, "read") and hasattr(upload, "name"):
        return upload.name
    return upload


def _load_patches(patch_path):
    """Read every patch from the HDF5 file into one ``(N, 1, H, W)`` tensor.

    The file is expected to hold one entry per frame, named ``frame_nr1`` ..
    ``frame_nrK`` where ``K`` is the number of top-level entries; iterating
    an entry must yield the individual patches of that frame.

    Raises:
        ValueError: if the file contains no patches at all.
    """
    to_tensor = transforms.ToTensor()
    patches = []
    # The context manager guarantees the HDF5 handle is released even when
    # a frame is missing or a patch cannot be converted.
    with h5py.File(patch_path, "r") as patch_h5:
        for frame_idx in range(1, len(patch_h5) + 1):
            for patch in patch_h5["frame_nr%s" % frame_idx]:
                # ToTensor turns a 2-D uint8 array into a (1, H, W) float
                # tensor, which already carries the channel axis.
                patches.append(to_tensor(np.array(patch, dtype=np.uint8)))
    if not patches:
        raise ValueError("patch file contains no patches")
    return torch.stack(patches)


def predict(mfile, patch_file, positions_file, expName):
    """Run a saved model on a set of patches and report its position error.

    Args:
        mfile: Saved model (a whole pickled module, as written by
            ``torch.save(model, ...)``), given as a path or uploaded file.
        patch_file: HDF5 file with the patches, one ``frame_nr<i>`` entry
            per frame, given as a path or uploaded file.
        positions_file: ``.npy`` file with one reference position per patch,
            in the same order as the patches are stored. Assumed to have at
            least two columns that correspond to the model's two outputs
            (TODO confirm the column order against the training code).
        expName: Free-text experiment name, echoed in the report.

    Returns:
        A one-line text report with the number of patches and the mean of
        the per-patch squared error over the first two output columns.

    Raises:
        ValueError: if the patch file is empty or the number of patches
            differs from the number of positions.
    """
    positions = np.load(_upload_path(positions_file))
    patches = _load_patches(_upload_path(patch_file))
    if len(patches) != len(positions):
        raise ValueError(
            "patch/position count mismatch: %d patches, %d positions"
            % (len(patches), len(positions))
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only expose this demo to users whose model files are trusted.
    # map_location lets a model saved on a GPU load on a CPU-only host.
    model = torch.load(_upload_path(mfile), map_location=device)
    model.eval()

    with torch.no_grad():
        y_pred = model(patches.to(device)).cpu().numpy()

    labels = np.asarray(positions, dtype=y_pred.dtype)
    squared_error = (
        np.power(y_pred[:, 0] - labels[:, 0], 2)
        + np.power(y_pred[:, 1] - labels[:, 1], 2)
    )
    mse = float(squared_error.mean())
    return f"{expName}: {len(patches)} patches, mean squared error = {mse:.6g}"
# Web UI for predict(): inputs are passed positionally, so the list order
# must match the function's parameter order. gr.Interface expects a
# component (or list of components), not a dict keyed by parameter name;
# the labels below keep the names the original mapping used.
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label="m_file"),
        gr.File(label="patch_file"),
        gr.File(label="positions_file"),
        gr.Textbox(label="expName"),
    ],
    outputs=gr.Textbox(label="output"),
)
interface.launch()