|
|
|
import os |
|
import torch |
|
from basic import ResNet18 |
|
import gradio as gr |
|
import numpy as np |
|
from torchvision.transforms import transforms |
|
from transformers import AutoModelForSequenceClassification |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the classification model once at module import; `image_classifier`
# reuses this single instance for every request.
# NOTE(review): the argument points at a single checkpoint file inside a Hub
# repo ("…/ckpt/epoch=49-step=1750.ckpt", which looks like a PyTorch Lightning
# checkpoint). `from_pretrained` normally expects a Hub repo id or a local
# model directory, so this call likely fails. Also, `ResNet18` is imported
# from `basic` but never used, and the preprocessing emits 1-channel images,
# which is unusual for a sequence-classification model. The intended model is
# presumably `ResNet18` restored from this checkpoint. TODO confirm and fix.
model = AutoModelForSequenceClassification.from_pretrained('hasanah10105/breast-cancer-classification/ckpt/epoch=49-step=1750.ckpt')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Output labels, in sorted (alphabetical) order. `image_classifier` pairs them
# positionally with the model's logits, so this order presumably has to match
# the label indices used at training time. TODO confirm.
class_names = sorted(['benign', 'malignant', 'normal'])

# Deterministic inference-time preprocessing.
# RandomRotation was removed from this pipeline: random augmentation belongs in
# training only. Applying it here made the prediction for the very same image
# vary from call to call.
transformation_pipeline = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # mean=0 / std=1 is an identity normalization. It is kept as-is,
    # presumably to mirror the training pipeline. TODO confirm.
    transforms.Normalize(mean=[0], std=[1]),
])
|
|
|
def preprocess_image(image: np.ndarray):
    """Turn a raw callback image into a batched tensor ready for the model.

    Note that the input image is in RGB mode.

    Parameters
    ----------
    image: np.ndarray
        Input image from callback.

    Returns
    -------
    torch.Tensor
        The output of the module-level ``transformation_pipeline`` with a
        leading batch dimension of size 1 added.
    """
    # Run the shared transform chain, then add the batch axis in front.
    return transformation_pipeline(image).unsqueeze(0)
|
|
|
def image_classifier(inp):
    """Classify an image and return per-class probabilities.

    Parameters
    ----------
    inp: Optional[np.ndarray] = None
        Input image from the Gradio callback, or ``None`` when no image
        was supplied.

    Returns
    -------
    Dict[str, float]
        A mapping from each name in ``class_names`` to its softmax probability.

    Raises
    ------
    gr.Error
        If no image was provided. Gradio shows the message in the UI.
    """
    if inp is None:
        # gr.Error is an exception: it must be *raised* for Gradio to display
        # it. Returning it (previously as a one-element set) produced an
        # invalid value for the "label" output component.
        raise gr.Error("Please upload an image.")

    image = preprocess_image(inp).to(dtype=torch.float32)

    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        output = model(image)

    # Softmax over the class dimension. Take the single item in the batch.
    probabilities = torch.nn.functional.softmax(output.logits, dim=1)
    scores = probabilities[0].tolist()

    # Pair labels with scores positionally. This relies on class_names
    # matching the model's output order.
    return dict(zip(class_names, scores))
|
|
|
# Build the Gradio UI: an uploaded image is passed to `image_classifier`,
# and its returned {class name: probability} dict is rendered by a "label"
# component. Launch starts the app immediately at module execution.
iface = gr.Interface(
    fn=image_classifier,
    inputs="image",
    outputs="label",
)
iface.launch()