wuming233 committed on
Commit
52cf196
1 Parent(s): 1a93700

first commit

Browse files
.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ vit_b_16_dout0.3_10epochs.pth filter=lfs diff=lfs merge=lfs -text
36
+ examples/metal.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Garbage Sense
3
- emoji: 📈
4
  colorFrom: purple
5
  colorTo: yellow
6
  sdk: gradio
@@ -10,4 +10,4 @@ pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Garbage Sense
3
+ emoji: 👓
4
  colorFrom: purple
5
  colorTo: yellow
6
  sdk: gradio
 
10
  license: mit
11
  ---
12
 
13
+ A vision transformer trained to classify garbage into 6 categories, following DeiT’s training recipe.
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import gradio as gr
4
+
5
+ from model import create_vit
6
+ from timeit import default_timer as timer
7
+ from typing import Tuple, Dict
8
+
9
# Label for each classifier output: index i of the model output is reported
# under class_names[i].
class_names = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

# create_vit returns the model (with a head sized to the label set) together
# with the preprocessing transform applied to every input image.
vit, vit_transform = create_vit(output_classes=len(class_names))

# map_location pins every tensor stored in the checkpoint to the CPU, so the
# load also succeeds on hosts without CUDA when the weights were saved from a
# GPU process. The model itself is never moved off the CPU in this app.
vit.load_state_dict(
    torch.load(
        f="vit_b_16_dout0.3_10epochs.pth",
        map_location=torch.device("cpu"),
    )
)
14
+
15
def predict(img) -> Tuple[Dict, float]:
    """Classify one image and report how long the prediction took.

    Args:
        img: Image accepted by ``vit_transform``.

    Returns:
        A pair ``(scores, seconds)``: ``scores`` maps every entry of
        ``class_names`` to its softmax probability, and ``seconds`` is the
        elapsed time rounded to 5 decimal places.
    """
    started = timer()

    # Preprocess the image and add a leading batch dimension of size 1.
    batch = vit_transform(img).unsqueeze(0)

    # Evaluation mode keeps dropout inactive while predicting.
    vit.eval()
    with torch.inference_mode():
        logits = vit(batch)
        probabilities = torch.softmax(logits, dim=1)

    # Only one image was submitted, so read the first (and only) row.
    first_row = probabilities[0]
    scores = {
        label: float(first_row[idx]) for idx, label in enumerate(class_names)
    }

    elapsed = round(timer() - started, 5)
    return scores, elapsed
26
+
27
# Text shown on the demo page.
title = "Garbage Sense"
description = "A vision transformer trained to classify garbage into 6 categories on [trashnet](https://github.com/garythung/trashnet)."
article = ""

# One single-element row per file found in the examples directory.
example_list = [[f"examples/{file_name}"] for file_name in os.listdir("examples")]

# The input component hands predict a PIL image.
image_input = gr.Image(type="pil")

# The two outputs line up with the (probabilities, seconds) pair that
# predict returns.
prediction_outputs = [
    gr.Label(num_top_classes=6, label="Predictions"),
    gr.Number(label="Prediction time (s)"),
]

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=prediction_outputs,
    examples=example_list,
    title=title,
    description=description,
)

demo.launch()
examples/cardboard.jpg ADDED
examples/glass.jpg ADDED
examples/metal.jpg ADDED

Git LFS Details

  • SHA256: 8fc8d4b106d580481df6b5d68c39dc42eab42bb29607c765d28a4e9ecf9adbed
  • Pointer size: 132 Bytes
  • Size of remote file: 1.48 MB
examples/paper.jpg ADDED
examples/plastic.jpg ADDED
model.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms, models
3
+
4
def create_vit(output_classes: int = 6, seed: int = 233):
    """Build a ViT-B/16 classifier with a frozen backbone and a fresh head.

    Args:
        output_classes: Number of output units of the classification head.
        seed: Value passed to ``torch.manual_seed`` right before the head is
            created, so the head's initial weights are reproducible. Note
            that this reseeds the global torch RNG.

    Returns:
        A ``(model, transform)`` tuple: the ``vit_b_16`` model whose ``heads``
        were replaced by dropout followed by a linear layer, and the
        preprocessing pipeline (resize to 256x256, center-crop to 224,
        tensor conversion, per-channel normalization).
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # No pretrained weights are requested here; the caller is expected to
    # load a state dict into the returned model.
    model = models.vit_b_16()

    # Freeze the backbone. The attribute autograd reads is ``requires_grad``;
    # assigning to any other name merely attaches an unused Python attribute
    # and leaves every parameter trainable.
    for param in model.parameters():
        param.requires_grad = False

    torch.manual_seed(seed)

    # The head is created after the freeze loop, so its parameters keep the
    # default requires_grad=True and remain trainable.
    model.heads = torch.nn.Sequential(
        torch.nn.Dropout(0.3),
        torch.nn.Linear(in_features=768, out_features=output_classes),
    )

    return model, transform
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch>=1.12.0
2
+ torchvision>=0.13.0
vit_b_16_dout0.3_10epochs.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc0379ab3c2dc54dff64813afd4e9507a197d34e0dc02bc5a5b944004b4b2b2e
3
+ size 343275877