pradanaadn commited on
Commit
869b2b3
·
1 Parent(s): 14a6bf2

feat: create ui using gradio

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .venv
2
+ .gradio
3
+ uv.lock
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
examples/cardbox.jpeg ADDED
examples/glass.jpeg ADDED
examples/plastic.png ADDED
main.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+ from torchvision.models import mobilenet_v3_large
6
+ from torchvision.transforms import v2
7
+ from PIL import Image
8
+ import os
9
+
10
+
11
+ class TrashMobileNet(nn.Module, PyTorchModelHubMixin):
12
+ def __init__(self, num_classes=6):
13
+ super(TrashMobileNet, self).__init__()
14
+ self.model = mobilenet_v3_large(weights="DEFAULT")
15
+ for param in self.model.parameters():
16
+ param.requires_grad = False
17
+ num_features = self.model.classifier[-1].in_features
18
+ self.model.classifier[-1] = nn.Linear(num_features, num_classes)
19
+ for param in self.model.classifier[-1].parameters():
20
+ param.requires_grad = True
21
+
22
+ def forward(self, x):
23
+ x = self.model(x)
24
+ return x
25
+
26
+
27
+ # Load the model from Hugging Face Hub
28
+ model_name = "pradanaadn/trash-clasification"
29
+ model = TrashMobileNet.from_pretrained(model_name)
30
+ model.eval()
31
+
32
+ # Define the image transformations
33
+ transform = v2.Compose([
34
+ v2.Resize((224, 224)),
35
+ v2.ToImage(),
36
+ v2.ToDtype(torch.float32, scale=True),
37
+ ])
38
+
39
+
40
+ def predict(image):
41
+ """
42
+ Prediction function that takes a Gradio image input and returns class probabilities
43
+ """
44
+ labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
45
+
46
+ # Convert Gradio image to PIL Image if it's not already
47
+ if not isinstance(image, Image.Image):
48
+ image = Image.fromarray(image)
49
+
50
+ # Apply transformations and add batch dimension
51
+ image_tensor = transform(image)
52
+ image_tensor = image_tensor.unsqueeze(0)
53
+
54
+ # Get model predictions
55
+ with torch.no_grad():
56
+ outputs = model(image_tensor)
57
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
58
+ probabilities = probabilities[0].tolist()
59
+
60
+ # Create dictionary of label-probability pairs
61
+ return {label: float(prob) for label, prob in zip(labels, probabilities)}
62
+
63
+
64
+
65
+
66
+ # Create example images if they don't exist (you would need to provide these images)
67
+ examples = [
68
+ ["examples/cardbox.jpeg", "A cardboard box"],
69
+ ["examples/glass.jpeg", "A glass bottle"],
70
+ ["examples/plastic.png", "Mixed trash"]
71
+ ]
72
+
73
+
74
+ with gr.Blocks() as iface:
75
+
76
+
77
+ with gr.Row():
78
+ with gr.Column():
79
+ input_image = gr.Image(
80
+ label="Upload Image",
81
+ type="pil",
82
+ elem_id="image_upload"
83
+ )
84
+ submit_btn = gr.Button("Classify", variant="primary")
85
+
86
+ with gr.Column():
87
+ output_label = gr.Label(
88
+ label="Classification Results",
89
+ num_top_classes=6
90
+ )
91
+
92
+ gr.Markdown("### Example Images")
93
+ gr.Examples(
94
+ examples=examples,
95
+ inputs=input_image,
96
+ outputs=output_label,
97
+ fn=predict,
98
+ cache_examples=True
99
+ )
100
+
101
+ submit_btn.click(
102
+ fn=predict,
103
+ inputs=input_image,
104
+ outputs=output_label
105
+ )
106
+
107
+
108
+ # Launch the interface
109
+ iface.launch(share=True)
pyproject.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "trash-classification"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "gradio>=5.3.0",
9
+ "huggingface-hub>=0.27.0",
10
+ "torch>=2.5.1",
11
+ "torchvision>=0.20.1",
12
+ "transformers>=4.47.1",
13
+ ]
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile pyproject.toml --no-deps -o requirements.txt
3
+ gradio==5.9.1
4
+ # via trash-classification (pyproject.toml)
5
+ huggingface-hub==0.27.0
6
+ # via trash-classification (pyproject.toml)
7
+ torch==2.5.1
8
+ # via trash-classification (pyproject.toml)
9
+ torchvision==0.20.1
10
+ # via trash-classification (pyproject.toml)
11
+ transformers==4.47.1
12
+ # via trash-classification (pyproject.toml)