"""Gradio app that evaluates a patch-position regression model.

The user uploads a serialized PyTorch model, an HDF5 file of image patches and
a NumPy ``.npy`` file of ground-truth positions. The app runs the model on
every patch and reports the mean squared position error.
"""

import gradio as gr
import h5py
import numpy as np
import requests
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms

# Side length, in pixels, of each square patch stored in the HDF5 file.
PATCH_SIZE = 15
# Number of patches fed to the model per forward pass.
BATCH_SIZE = 256
# Run on the GPU when one is available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _as_path(uploaded):
    """Return a filesystem path for a Gradio file upload.

    Older Gradio versions pass a temp-file wrapper (path in ``.name``); newer
    versions pass the path string itself.
    """
    return getattr(uploaded, "name", uploaded)


def _load_patches(patch_path, n_expected):
    """Read every patch in *patch_path* into a ``(N, 1, PATCH_SIZE, PATCH_SIZE)`` tensor.

    Patches are read from the entries ``frame_nr1`` ... ``frame_nr<n_frames>``
    (``n_frames`` = number of top-level entries), in that order. Each patch is
    cast to ``uint8`` and converted with ``ToTensor`` (scaled to ``[0, 1]``).

    Returns:
        ``(patches, n_frames)``.

    Raises:
        ValueError: if the file does not contain exactly *n_expected* patches.
    """
    to_tensor = transforms.ToTensor()
    patches = torch.zeros(size=(n_expected, 1, PATCH_SIZE, PATCH_SIZE))
    count = 0
    with h5py.File(patch_path, "r") as patch_h5:
        n_frames = len(patch_h5)
        for frame in range(1, n_frames + 1):
            for patch in patch_h5[f"frame_nr{frame}"]:
                if count >= n_expected:
                    raise ValueError(
                        f"patch file holds more patches than the {n_expected} "
                        "positions provided"
                    )
                # ToTensor turns a 2-D HxW uint8 array into a (1, H, W) float tensor.
                patches[count] = to_tensor(np.array(patch, dtype=np.uint8))
                count += 1
    if count != n_expected:
        raise ValueError(
            f"patch file holds {count} patches but {n_expected} positions were provided"
        )
    return patches, n_frames


def predict(mfile, patch_file, positions_file, expName):
    """Evaluate an uploaded model on uploaded patches and return an error summary.

    Args:
        mfile: Upload of a full pickled ``torch.nn.Module`` (it is called
            directly, so a bare ``state_dict`` will not work).
        patch_file: Upload of the HDF5 patch file (see ``_load_patches``).
        positions_file: Upload of a ``.npy`` array with one row per patch;
            columns 0 and 1 are used as the target (x, y) — assumed layout,
            confirm against the data producer.
        expName: Experiment name, echoed in the result text.

    Returns:
        A human-readable string giving the mean of
        ``(pred_x - x)**2 + (pred_y - y)**2`` over all patches.

    Raises:
        ValueError: if there are no positions, or the patch and position
            counts differ.
    """
    positions = torch.from_numpy(np.load(_as_path(positions_file)))
    if len(positions) == 0:
        raise ValueError("positions file contains no positions")
    patches, n_frames = _load_patches(_as_path(patch_file), len(positions))

    # NOTE(security): torch.load unpickles the uploaded file and can execute
    # arbitrary code; only evaluate model files from trusted sources.
    model = torch.load(_as_path(mfile), map_location=DEVICE)
    model.eval()  # disable dropout / use running batch-norm statistics

    loader = DataLoader(TensorDataset(patches, positions), batch_size=BATCH_SIZE)
    sq_errors = []
    with torch.no_grad():
        for inputs, labels in loader:
            y_pred = model(inputs.to(DEVICE)).cpu().numpy()
            labels = labels.numpy()
            # Squared Euclidean error per patch, kept per batch so a short
            # final batch does not break broadcasting.
            sq_errors.append(
                np.power(y_pred[:, 0] - labels[:, 0], 2)
                + np.power(y_pred[:, 1] - labels[:, 1], 2)
            )
    all_errors = np.concatenate(sq_errors)
    mse = float(np.mean(all_errors))
    return (
        f"{expName}: MSE = {mse:.6f} over {len(all_errors)} patches "
        f"in {n_frames} frames"
    )


interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label="m_file"),
        gr.File(label="patch_file"),
        gr.File(label="positions_file"),
        gr.Textbox(label="expName"),
    ],
    outputs=gr.Textbox(label="output"),
)

if __name__ == "__main__":
    interface.launch()