import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet18
from transferwee import download

# ImageNet-pretrained ResNet-18 whose final classification layer is swapped
# for a small two-layer head that emits a single logit per image.
model = resnet18(pretrained=True)
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 16),  # in_features is 512 for ResNet-18
    nn.ReLU(),
    nn.Linear(16, 1),
)

# Fetching and loading the fine-tuned checkpoint is currently disabled. While
# the lines below stay commented out, the replacement head above keeps its
# random initialisation.
# download("https://we.tl/t-bbgc3gXROZ", "best.pt") # 1
# download("https://we.tl/t-25s74dahjU", "best.pt") # 4 --> 0.92
#
# checkpoint = torch.load("best.pt", map_location=torch.device('cpu'))
# model.load_state_dict(checkpoint['model_state_dict'])

# Switch to evaluation mode before serving predictions.
model.eval()

# Class index -> class name (0 is "normal", 1 is "risk").
# NOTE(review): nothing in the visible code reads this mapping; predict()
# emits its own "Normal"/"Riesgo" labels -- confirm whether it is still needed.
labels_to_class = dict(enumerate(("normal", "risk")))
def predict(inp):
    """Classify one frame as normal operation or risk.

    Args:
        inp: Image accepted by ``transforms.ToTensor`` (the Gradio interface
            in this file passes a PIL image).

    Returns:
        A one-entry dict mapping the predicted label (``"Normal"`` or
        ``"Riesgo"``) to the confidence in that label, in the format
        ``gr.Label`` consumes.
    """
    # Scores strictly above this value are reported as "Normal"; everything
    # else is reported as "Riesgo".
    normal_threshold = 0.7

    batch = transforms.ToTensor()(inp).unsqueeze(0)  # [1, C, H, W]
    with torch.no_grad():
        # The head emits one logit per image; sigmoid maps it into (0, 1).
        prediction = torch.sigmoid(model(batch)[0])
    normal_score = float(prediction[0])

    # NOTE(review): this function treats a high score as "Normal", while the
    # module-level labels_to_class maps index 1 to "risk". Confirm which class
    # the checkpoint's positive output actually represents.
    if normal_score > normal_threshold:
        confidences = {"Normal": normal_score}
    else:
        # The score measures how "Normal" the frame looks, so the confidence
        # in "Riesgo" is its complement. Reporting the raw score here would
        # show e.g. "Riesgo 10%" for the frames the model is most sure about.
        confidences = {"Riesgo": 1.0 - normal_score}

    print(confidences)
    return confidences


# HTML blurb passed to gr.Interface as the demo description.
# The closing </i> tag must stay in one piece: a "<" followed by a line break
# is not parsed as a tag and would leave the <i> element unclosed.
description = """
<center>
    Este nuestro clasificador de video de uso de grúas horquillas.\n\n

    A partir de un múltiples frames, nuestro modelo compone un <i>vídeo</i>
    con objetivo de este es poder determinar si es que operación de riesgo o no.\n\n

    Nuestro modelo utiliza una red convolucional pre-entrenada y, posteriormente, finetuneada en nuestro conjunto de datos.
</center>
"""

# Sample frames offered as ready-made inputs in the interface.
examples = [
    "videos/normal-1_frame_2.jpg",
    "videos/risk-0_frame_177.jpg",
]

# One PIL image in, one label widget out. No title is set (a
# "Forklikt Risk Detection Demo" title was previously commented out).
demo = gr.Interface(
    predict,
    gr.Image(type="pil"),
    gr.Label(num_top_classes=3),
    description=description,
    examples=examples,
)

# share=False: do not request a public share link.
demo.launch(share=False)