#!/usr/bin/env python3
"""Gradio demo for a ResNet-34 dermatoscopic image classifier.

A torchvision ResNet-34 has its final layer replaced to emit one logit per
label in ``args["dxlabels"]``; weights are loaded from ``args["model_path"]``.
Each uploaded image is scored under a small fixed set of test-time
augmentations and the averaged result is shown as per-label probabilities.
"""
import itertools
import logging
import os

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from torchvision import models, transforms

args = {
    "model_path": "model_last_epoch_34_torchvision0_3_state.ptw",
    "device": torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    "dxlabels": ["akiec", "bcc", "bkl", "df", "mel", "nv", "vasc"],
}
logging.warning(args)


# No specific normalization was performed during training, so the input
# normalization step is the identity.
def normtransform(x):
    """Return *x* unchanged (identity normalization, matching training)."""
    return x


# Load model
logging.warning("Loading model...")
model = torchvision.models.resnet34()
# Replace the default 1000-class head with one output per diagnostic label.
model.fc = torch.nn.Linear(model.fc.in_features, len(args["dxlabels"]))
# map_location lets a checkpoint that was saved from a GPU process load on a
# CPU-only host (and vice versa).
model.load_state_dict(
    torch.load(args["model_path"], map_location=args["device"])
)
model.eval()
model.to(device=args["device"])
# NOTE: grad mode is thread-local, so this only affects the current thread;
# predict() also enters torch.no_grad() because Gradio may invoke it from a
# worker thread.
torch.set_grad_enabled(False)
logging.warning("Model loaded.")

# Test-time augmentation (TTA) grid: every combination of output size,
# horizontal-flip probability (0 = never, 1 = always), rotation angle in
# degrees, and center-crop fraction.
_TTA_SIZES = [224]
_TTA_HFLIPS = [0, 1]
_TTA_ROTATIONS = [0, 90]
_TTA_CROPS = [0.8]


def _make_tta_transform(target_size, hflip, rotation, crop):
    """Return the deterministic preprocessing pipeline for one TTA view."""
    return transforms.Compose([
        # Resize so that a center crop of `target_size` keeps ~`crop` of it.
        transforms.Resize(int(target_size // crop)),
        transforms.CenterCrop(target_size),
        # With p equal to 0 or 1 the "random" flip is deterministic.
        transforms.RandomHorizontalFlip(hflip),
        # A degenerate [r, r] range always rotates by exactly r degrees.
        transforms.RandomRotation([rotation, rotation]),
        transforms.ToTensor(),
        # No colour-constancy step is applied.
        normtransform,
    ])


# The pipelines do not depend on the image, so build them once at import.
_TTA_TRANSFORMS = [
    _make_tta_transform(*combo)
    for combo in itertools.product(
        _TTA_SIZES, _TTA_HFLIPS, _TTA_ROTATIONS, _TTA_CROPS
    )
]


def predict(image: str) -> dict:
    """Classify a dermatoscopic image using test-time augmentation.

    Args:
        image: Path to the uploaded image file (Gradio ``type="filepath"``).

    Returns:
        Mapping from each label in ``args["dxlabels"]`` to its softmax
        probability, computed from the logits averaged over all TTA views.

    Raises:
        ValueError: If no image path was supplied.
    """
    logging.warning(f"Starting predict('{image}') ...")
    if image is None:
        raise ValueError("No image provided.")
    device = args["device"]

    # Decode to RGB and release the file handle; convert() returns a new
    # image, so it stays valid after the source is closed.
    with Image.open(image) as src:
        img = src.convert('RGB')

    prediction_tensor = torch.zeros([1, len(args["dxlabels"])], device=device)
    # Disable autograd locally: grad mode is thread-local and this function
    # may not run on the thread that called set_grad_enabled(False).
    with torch.no_grad():
        for tfm in _TTA_TRANSFORMS:
            test_data = tfm(img).unsqueeze(0).to(device=device)
            # Accumulate raw logits; softmax is applied once after averaging.
            prediction_tensor += model(test_data)
    prediction_tensor /= len(_TTA_TRANSFORMS)

    predictions = F.softmax(prediction_tensor, dim=1)[0].cpu().numpy()
    logging.warning(f"Returning {predictions=}")
    return {label: float(p) for label, p in zip(args["dxlabels"], predictions)}


description = '''Research artifact for multi-class predictions of common dermatologic tumors. This is the model used in the publication [Tschandl P. et al. Nature Medicine 2020](https://www.nature.com/articles/s41591-020-0942-0). For education and research use only. **DO NOT use this to obtain medical advice!** If you have a skin change in question, seek contact to your physician.'''

logging.warning("Starting Gradio interface...")
gr.Interface(
    predict,
    inputs=gr.Image(label="Upload a dermatoscopic image", type="filepath"),
    outputs=gr.Label(num_top_classes=len(args["dxlabels"])),
    title="Dermatoscopic classification",
    description=description,
    allow_flagging="manual",
    examples=[os.path.join(os.path.dirname(__file__), "ISIC_0024306.jpg")],
).launch()