"""Gradio demo: classify forklift-operation frames as normal or risky.

A ResNet-18 backbone (ImageNet weights) with a small single-logit head scores
each uploaded image; the sigmoid score is thresholded and shown in a
``gr.Label``.
"""
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet18
from transferwee import download

# Sigmoid score above which a frame is reported as "Normal"; at or below it the
# frame is reported as "Riesgo" (risk).
NORMAL_THRESHOLD = 0.7

model = resnet18(pretrained=True)
# Replace the 1000-way ImageNet classifier with a single-logit binary head.
model.fc = nn.Sequential(
    nn.Linear(512, 16),
    nn.ReLU(),
    nn.Linear(16, 1),
)

# Download the latest fine-tuned weights.
# NOTE(review): these lines are disabled, so the head above keeps its random
# initialisation and predictions are not meaningful until they are re-enabled.
# download("https://we.tl/t-bbgc3gXROZ", "best.pt")  # 1
# download("https://we.tl/t-25s74dahjU", "best.pt")  # 4 --> 0.92
# checkpoint = torch.load("best.pt", map_location=torch.device("cpu"))
# model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# NOTE(review): not used by predict(), whose rule treats a high score as
# "Normal"; confirm which class the sigmoid output actually corresponds to.
labels_to_class = {
    0: "normal",
    1: "risk",
}

# Built once at import time instead of on every request.
to_tensor = transforms.ToTensor()


def predict(inp):
    """Classify a single frame as "Normal" or "Riesgo".

    Args:
        inp: PIL image supplied by ``gr.Image(type="pil")``.

    Returns:
        A one-entry dict ``{label: confidence}`` for ``gr.Label``. The sigmoid
        score is treated as the "Normal" score: above ``NORMAL_THRESHOLD`` the
        frame is "Normal" with confidence ``score``; otherwise it is "Riesgo"
        with confidence ``1 - score``.
    """
    # ResNet's first conv expects 3 channels; RGBA/greyscale input would break it.
    image = to_tensor(inp.convert("RGB")).unsqueeze(0)  # [1, 3, H, W]
    # NOTE(review): no ImageNet mean/std normalisation is applied here; confirm
    # this matches the preprocessing used during fine-tuning.
    with torch.no_grad():
        score = torch.sigmoid(model(image)).item()

    if score > NORMAL_THRESHOLD:
        confidences = {"Normal": score}
    else:
        # Report the confidence of the class being shown, not the Normal score.
        confidences = {"Riesgo": 1.0 - score}
    print(confidences)
    return confidences


description = """
<center>
Este es nuestro clasificador de video de uso de grúas horquillas.\n\n
A partir de múltiples frames, nuestro modelo compone un <i>vídeo</i> con el objetivo de determinar si se trata de una operación de riesgo o no.\n\n
Nuestro modelo utiliza una red convolucional pre-entrenada y, posteriormente, finetuneada en nuestro conjunto de datos.
</center>
"""

examples = [
    "videos/normal-1_frame_2.jpg",
    "videos/risk-0_frame_177.jpg",
]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    # title="Forklift Risk Detection Demo",
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    demo.launch(share=False)