import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

from model import *
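# Note: `model.py` (not shown here) must define the network class that was
# pickled into `model/fashion.mnist.base.pt`; `torch.load` on a full model
# needs that class importable, which is why the wildcard import above is
# required. The only method used below is `predictions`, assumed to return a
# {label: confidence} mapping that `gr.Label` can render.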

title = "Garment Classifier"
description = "Trained on the Fashion-MNIST dataset (28x28 grayscale images). The model expects images that contain a single garment article, as in the examples."
inputs = gr.Image()
outputs = gr.Label()
examples = "examples"  # directory of sample images

model = torch.load("model/fashion.mnist.base.pt", map_location=torch.device("cpu"))
model.eval()  # inference mode: fixes dropout/batch-norm behavior

# Input images must be converted to the Fashion-MNIST format
# (28x28, single channel; see https://arxiv.org/abs/1708.07747)
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),  # tensor of shape (1, 28, 28), values in [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # map [0, 1] to [-1, 1]
        transforms.Lambda(lambda x: 1.0 - x),  # invert colors: Fashion-MNIST garments are light on dark
    ]
)


def predict(img):
    img = transform(Image.fromarray(img))
    with torch.no_grad():  # inference only, no autograd graph needed
        predictions = model.predictions(img)
    return predictions
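
# Illustrative sanity check (the dummy image below is an assumption, not app
# input): the transform should produce a (1, 28, 28) tensor for any RGB image.
import numpy as np  # numpy is already a gradio dependency

assert transform(Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))).shape == (1, 28, 28)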


with gr.Blocks() as demo:
    with gr.Tab("Garment Prediction"):
        gr.Interface(
            fn=predict,
            inputs=inputs,
            outputs=outputs,
            examples=examples,
            title=title,
            description=description,
        ).queue(default_concurrency_limit=5)

demo.launch()
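
# To run locally (assuming `model/fashion.mnist.base.pt` and an `examples/`
# directory with sample garment images are present): `python app.py`, then
# open the printed URL (Gradio serves on http://127.0.0.1:7860 by default).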