import gradio as gr import torch import torchvision.transforms as transforms import torch.nn as nn import numpy as np def predict(model, image, device): # Preprocess the image transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) img_tensor = transform(image).unsqueeze(0) pred_fun = torch.nn.Softmax(dim=1) # preds = [] with torch.set_grad_enabled(False): y = pred_fun(model(img_tensor)) print(y) y = y.cpu().numpy() print(y) y = y[:, 1] # cat:0, dog: 1 print(y) y = y[0] print(y) # preds.append(y) # preds = np.concatenate(preds) return {"tenka ippin":y, "no entry":1-y} # return preds def process_image(input_image): model = torch.load('models/tenichi_noentry.pth') preds = predict(model, input_image, 'cpu') return preds iface = gr.Interface( fn=process_image, inputs=[ gr.Image(type="pil", label="Input Image", height=512), ], outputs=gr.Label(label="Output", show_label=False), description="画像に映っているのが天下一品のロゴなのか、進入禁止標識なのか判別します", examples=[ ["examples/ten20.png"], ["examples/noe33.png"], ], # run_on_click=False, # cache_examples=True ) if __name__ == "__main__": iface.launch() # demo = gr.Interface(fn=greet, inputs="text", outputs="text") # demo.launch()