|
import os

import gradio as gr
import h5py
import numpy as np
import requests
import torch
from torchvision import transforms
|
|
|
def _as_path(upload):
    """Return something file-opening APIs accept for a Gradio upload value.

    Depending on the Gradio version, a File input may arrive as a path string
    (possibly a str subclass) or as a tempfile-like object whose ``.name`` is
    the path on disk; both forms are handled here.
    """
    if isinstance(upload, (str, bytes, os.PathLike)):
        return upload
    return upload.name


def _load_patches(patch_path, transform):
    """Read every patch from the HDF5 file, frame by frame, into one tensor.

    Frames are looked up by key ``"frame_nr1"`` .. ``"frame_nrN"``, where N is
    the number of top-level entries in the file; each frame is iterated to
    yield its patches. Each patch is converted to ``uint8`` and passed through
    *transform*. Returns the stacked results (``(num_patches, 1, H, W)`` for
    2-D patches under ``ToTensor``).

    Raises:
        ValueError: if the file contains no patches.
    """
    tensors = []
    # Context manager guarantees the HDF5 handle is closed even on error.
    with h5py.File(patch_path, "r") as patch_h5:
        for frame in range(1, len(patch_h5) + 1):
            for patch in patch_h5["frame_nr%s" % frame]:
                tensors.append(transform(np.array(patch, dtype=np.uint8)))
    if not tensors:
        raise ValueError("no patches found in %s" % patch_path)
    return torch.stack(tensors)


def _run_model(model, inputs, device, batch_size=256):
    """Apply *model* to *inputs* in mini-batches; return predictions as NumPy."""
    outputs = []
    with torch.no_grad():
        # Batching bounds peak memory for files with many patches.
        for batch in torch.split(inputs, batch_size):
            outputs.append(model(batch.to(device)).cpu())
    return torch.cat(outputs).numpy()


def predict(mfile, patch_file, positions_file, expName):
    """Evaluate an uploaded model on uploaded patches against known positions.

    Args:
        mfile: uploaded model file, read with ``torch.load``. Presumably a
            whole pickled ``nn.Module`` rather than a state dict — confirm.
        patch_file: uploaded HDF5 file whose entries ``frame_nr1`` ..
            ``frame_nrN`` each iterate over 2-D patches.
        positions_file: uploaded ``.npy`` array with one row per patch, in the
            same frame/patch order. Columns 0 and 1 are compared with the
            model output (presumably x/y coordinates — confirm).
        expName: experiment name, echoed in the returned report.

    Returns:
        A text report with the number of patches and the mean squared
        Euclidean error over the first two output columns.

    Raises:
        ValueError: if the patch file is empty, or the number of patches
            differs from the number of position rows.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    labels = np.load(_as_path(positions_file))
    inputs = _load_patches(_as_path(patch_file), transforms.ToTensor())
    if len(inputs) != len(labels):
        raise ValueError(
            "patch file has %d patches but positions file has %d rows"
            % (len(inputs), len(labels))
        )

    # SECURITY: torch.load unpickles the uploaded file, which can execute
    # arbitrary code if the file is malicious. Only expose this app to
    # trusted users. (On torch >= 2.6 loading a whole pickled model may also
    # require weights_only=False.)
    model = torch.load(_as_path(mfile), map_location=device)
    model.eval()  # disable dropout / use running batch-norm stats for inference
    y_pred = _run_model(model, inputs, device)

    # Per-patch squared Euclidean distance over columns 0 and 1:
    # (dx)^2 + (dy)^2, averaged over all patches.
    sq_err = np.sum((y_pred[:, :2] - labels[:, :2]) ** 2, axis=1)
    mse = float(np.mean(sq_err))
    # NOTE(review): the original also called an undefined plot_error(...) with
    # the experiment name; plotting is omitted and the name is reported instead.
    return "%s: %d patches, MSE = %.6f" % (expName, len(sq_err), mse)
|
|
|
# Gradio UI: three file uploads and a free-text field, listed positionally in
# the same order as predict's parameters; predict's return value is shown as
# text. (gr.Interface expects a component or a list of components, not a dict.)
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label="m_file"),
        gr.File(label="patch_file"),
        gr.File(label="positions_file"),
        gr.Textbox(label="expName"),
    ],
    outputs=gr.Textbox(label="output"),
)

interface.launch()
|
|