Spaces:
Sleeping
Sleeping
Stop versioning the model checkpoints, now they are downloaded from huggingface. Add env vars
Browse files- .env.example +4 -0
- .gitignore +4 -0
- README.md +11 -0
- app.py +3 -13
- constants.py +19 -3
- models/.gitkeep +0 -0
- models/efficientdet-d2-detector.pth.tar +0 -3
- models/resnet50-classifier.pkl +0 -3
- requirements.txt +2 -0
- trash_detector.py +24 -18
.env.example
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CLAS_FILENAME = "classifier_model_name.pkl"
|
2 |
+
DET_FILENAME = "detector_model_name.pth.tar"
|
3 |
+
HF_DET_REPO_NAME = "org/repo_name"
|
4 |
+
HF_CLAS_REPO_NAME = "org/repo_name"
|
.gitignore
CHANGED
@@ -3,3 +3,7 @@ __pycache__
|
|
3 |
*.jpg
|
4 |
*.png
|
5 |
*.jpeg
|
|
|
|
|
|
|
|
|
|
3 |
*.jpg
|
4 |
*.png
|
5 |
*.jpeg
|
6 |
+
*.tar
|
7 |
+
*.pkl
|
8 |
+
.env
|
9 |
+
*.pth
|
README.md
CHANGED
@@ -32,6 +32,17 @@ python -m venv venv-waste-classifier
|
|
32 |
source venv-waste-classifier/bin/activate # On Windows, use 'venv-waste-classifier\Scripts\activate'
|
33 |
pip install -r requirements.txt
|
34 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
### Running the App
|
37 |
|
|
|
32 |
source venv-waste-classifier/bin/activate # On Windows, use 'venv-waste-classifier\Scripts\activate'
|
33 |
pip install -r requirements.txt
|
34 |
```
|
35 |
+
Create a `.env` file and set the following properties:
|
36 |
+
|
37 |
+
- CLAS_FILENAME --> The name of the waste classificator model checkpoint
|
38 |
+
- DET_FILENAME --> The name of the waste detector model checkpoint
|
39 |
+
- HF_DET_REPO_NAME --> The huggingface repository name of the detector model
|
40 |
+
- For example: [rootstrap-org/waste-detector](https://huggingface.co/rootstrap-org/waste-detector)
|
41 |
+
- HF_CLAS_REPO_NAME --> The huggingface repository name of the classifier model
|
42 |
+
- For example: [rootstrap-org/waste-classifier](https://huggingface.co/rootstrap-org/waste-classifier)
|
43 |
+
|
44 |
+
You can ommit setting the last two properties if you download the models manually and put them under the `models` directory.
|
45 |
+
|
46 |
|
47 |
### Running the App
|
48 |
|
app.py
CHANGED
@@ -4,14 +4,12 @@ Streamlit app
|
|
4 |
import sys
|
5 |
|
6 |
import streamlit as st
|
|
|
7 |
|
8 |
-
from constants import
|
9 |
-
DET_NAME, DET_THRESHOLD, DEVICE, OUTPUT_IMG_FILEPATH)
|
10 |
|
11 |
sys.path.append("./efficientdet")
|
12 |
|
13 |
-
from PIL import Image
|
14 |
-
|
15 |
from efficientdet.efficientdet import plot_results
|
16 |
from trash_detector import detect_trash
|
17 |
|
@@ -57,15 +55,7 @@ def render():
|
|
57 |
with col2:
|
58 |
with st.spinner(text="Classifying the trash..."):
|
59 |
img = Image.open(uploaded_file).convert("RGB")
|
60 |
-
cls_prob, bboxes_final = detect_trash(
|
61 |
-
img,
|
62 |
-
DET_NAME,
|
63 |
-
DET_FILEPATH,
|
64 |
-
CLAS_FILEPATH,
|
65 |
-
DEVICE,
|
66 |
-
DET_THRESHOLD,
|
67 |
-
CLAS_THRESHOLD,
|
68 |
-
)
|
69 |
# plot and save demo image
|
70 |
plot_results(
|
71 |
img, cls_prob, bboxes_final, CLASSES, OUTPUT_IMG_FILEPATH
|
|
|
4 |
import sys
|
5 |
|
6 |
import streamlit as st
|
7 |
+
from PIL import Image
|
8 |
|
9 |
+
from constants import CLASSES, OUTPUT_IMG_FILEPATH
|
|
|
10 |
|
11 |
sys.path.append("./efficientdet")
|
12 |
|
|
|
|
|
13 |
from efficientdet.efficientdet import plot_results
|
14 |
from trash_detector import detect_trash
|
15 |
|
|
|
55 |
with col2:
|
56 |
with st.spinner(text="Classifying the trash..."):
|
57 |
img = Image.open(uploaded_file).convert("RGB")
|
58 |
+
cls_prob, bboxes_final = detect_trash(img)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
# plot and save demo image
|
60 |
plot_results(
|
61 |
img, cls_prob, bboxes_final, CLASSES, OUTPUT_IMG_FILEPATH
|
constants.py
CHANGED
@@ -1,8 +1,24 @@
|
|
1 |
-
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
CLASSES = ["cardboard", "compost", "glass", "metal", "paper", "plastic", "trash"]
|
4 |
-
DET_NAME = "tf_efficientdet_d2"
|
5 |
CLAS_THRESHOLD = 0.5
|
|
|
6 |
DET_THRESHOLD = 0.17
|
7 |
DEVICE = "cpu"
|
8 |
OUTPUT_IMG_FILEPATH = "classified_image.jpg"
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
|
5 |
+
load_dotenv()
|
6 |
+
|
7 |
+
# Model checkpoints and repo names
|
8 |
+
CLAS_FILENAME = os.getenv("CLAS_FILENAME")
|
9 |
+
DET_FILENAME = os.getenv("DET_FILENAME")
|
10 |
+
HF_CLAS_REPO_NAME = os.getenv("HF_CLAS_REPO_NAME")
|
11 |
+
HF_DET_REPO_NAME = os.getenv("HF_DET_REPO_NAME")
|
12 |
+
|
13 |
+
# Models paths
|
14 |
+
MODELS_PATH = "models"
|
15 |
+
CLAS_FILEPATH = f"{MODELS_PATH}/{CLAS_FILENAME}"
|
16 |
+
DET_FILEPATH = f"{MODELS_PATH}/{DET_FILENAME}"
|
17 |
+
|
18 |
+
# Other constants
|
19 |
CLASSES = ["cardboard", "compost", "glass", "metal", "paper", "plastic", "trash"]
|
|
|
20 |
CLAS_THRESHOLD = 0.5
|
21 |
+
DET_NAME = "tf_efficientdet_d2"
|
22 |
DET_THRESHOLD = 0.17
|
23 |
DEVICE = "cpu"
|
24 |
OUTPUT_IMG_FILEPATH = "classified_image.jpg"
|
models/.gitkeep
ADDED
File without changes
|
models/efficientdet-d2-detector.pth.tar
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:499a3f0c75e13669d69be25854e980812e2f6b50e618ba2b2e90b25f193e7fd9
|
3 |
-
size 97791163
|
|
|
|
|
|
|
|
models/resnet50-classifier.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:8d2c0667090f996cbe4bab8585300528b8896071e70b1edfdbe671015a074e85
|
3 |
-
size 102980821
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -2,6 +2,7 @@ albumentations>=0.5.2
|
|
2 |
efficientnet_pytorch
|
3 |
fastai==2.7.13
|
4 |
funcy==1.15
|
|
|
5 |
iterative-stratification==0.1.6
|
6 |
matplotlib==3.8.2
|
7 |
numpy==1.26.2
|
@@ -9,6 +10,7 @@ omegaconf>=2.0
|
|
9 |
opencv-python==4.8.1.78
|
10 |
opencv-python-headless==4.8.1.78
|
11 |
pycocotools>=2.0.0
|
|
|
12 |
pytorch_lightning
|
13 |
pyyaml
|
14 |
rembg==2.0.53
|
|
|
2 |
efficientnet_pytorch
|
3 |
fastai==2.7.13
|
4 |
funcy==1.15
|
5 |
+
huggingface_hub
|
6 |
iterative-stratification==0.1.6
|
7 |
matplotlib==3.8.2
|
8 |
numpy==1.26.2
|
|
|
10 |
opencv-python==4.8.1.78
|
11 |
opencv-python-headless==4.8.1.78
|
12 |
pycocotools>=2.0.0
|
13 |
+
python-dotenv
|
14 |
pytorch_lightning
|
15 |
pyyaml
|
16 |
rembg==2.0.53
|
trash_detector.py
CHANGED
@@ -1,29 +1,41 @@
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
import torch
|
3 |
from fastai.vision.all import load_learner
|
|
|
4 |
|
|
|
|
|
|
|
|
|
5 |
from efficientdet.efficientdet import get_transforms, rescale_bboxes, set_model
|
6 |
|
7 |
|
8 |
-
def localize_trash(im
|
9 |
-
# detector
|
10 |
-
|
|
|
|
|
|
|
11 |
detector.eval()
|
12 |
# mean-std normalize the input image (batch-size: 1)
|
13 |
img = get_transforms(im)
|
14 |
# propagate through the model
|
15 |
-
outputs = detector(img.to(
|
16 |
# keep only predictions above set confidence
|
17 |
-
bboxes_keep = outputs[0, outputs[0, :, 4] >
|
18 |
probas = bboxes_keep[:, 4:]
|
19 |
# convert boxes to image scales
|
20 |
bboxes_scaled = rescale_bboxes(bboxes_keep[:, :4], im.size, tuple(img.size()[2:]))
|
21 |
return probas, bboxes_scaled
|
22 |
|
23 |
|
24 |
-
def classify_trash(im,
|
25 |
-
# classifier
|
26 |
-
|
|
|
|
|
27 |
|
28 |
bboxes_final = []
|
29 |
cls_prob = []
|
@@ -32,26 +44,20 @@ def classify_trash(im, clas_checkpoint, cls_th, probas, bboxes_scaled):
|
|
32 |
outputs = classifier.predict(img)
|
33 |
p[1] = torch.topk(outputs[2], k=1).indices.squeeze(0).item()
|
34 |
p[0] = torch.max(np.trunc(outputs[2] * 100))
|
35 |
-
if p[0] >=
|
36 |
bboxes_final.append((xmin, ymin, xmax, ymax))
|
37 |
cls_prob.append(p)
|
38 |
return cls_prob, bboxes_final
|
39 |
|
40 |
|
41 |
-
def detect_trash(
|
42 |
-
im, det_name, det_checkpoint, clas_checkpoint, device, prob_threshold, cls_th
|
43 |
-
):
|
44 |
# prepare models for evaluation
|
45 |
torch.set_grad_enabled(False)
|
46 |
|
47 |
# 1) Localize
|
48 |
-
probas, bboxes_scaled = localize_trash(
|
49 |
-
im, det_name, det_checkpoint, device, prob_threshold
|
50 |
-
)
|
51 |
|
52 |
# 2) Classify
|
53 |
-
cls_prob, bboxes_final = classify_trash(
|
54 |
-
im, clas_checkpoint, cls_th, probas, bboxes_scaled
|
55 |
-
)
|
56 |
|
57 |
return cls_prob, bboxes_final
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
from fastai.vision.all import load_learner
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
|
8 |
+
from constants import (CLAS_FILENAME, CLAS_FILEPATH, CLAS_THRESHOLD,
|
9 |
+
DET_FILENAME, DET_FILEPATH, DET_NAME, DET_THRESHOLD,
|
10 |
+
DEVICE, HF_CLAS_REPO_NAME, HF_DET_REPO_NAME,
|
11 |
+
MODELS_PATH)
|
12 |
from efficientdet.efficientdet import get_transforms, rescale_bboxes, set_model
|
13 |
|
14 |
|
15 |
+
def localize_trash(im):
|
16 |
+
# detector, if checkpoint doesn't exist then download from hf
|
17 |
+
if not os.path.exists(DET_FILEPATH):
|
18 |
+
hf_hub_download(HF_DET_REPO_NAME, DET_FILENAME, local_dir=MODELS_PATH)
|
19 |
+
detector = set_model(DET_NAME, 1, DET_FILEPATH, DEVICE)
|
20 |
+
|
21 |
detector.eval()
|
22 |
# mean-std normalize the input image (batch-size: 1)
|
23 |
img = get_transforms(im)
|
24 |
# propagate through the model
|
25 |
+
outputs = detector(img.to(DEVICE))
|
26 |
# keep only predictions above set confidence
|
27 |
+
bboxes_keep = outputs[0, outputs[0, :, 4] > DET_THRESHOLD]
|
28 |
probas = bboxes_keep[:, 4:]
|
29 |
# convert boxes to image scales
|
30 |
bboxes_scaled = rescale_bboxes(bboxes_keep[:, :4], im.size, tuple(img.size()[2:]))
|
31 |
return probas, bboxes_scaled
|
32 |
|
33 |
|
34 |
+
def classify_trash(im, probas, bboxes_scaled):
|
35 |
+
# classifier, if checkpoint doesn't exist then download from hf
|
36 |
+
if not os.path.exists(CLAS_FILEPATH):
|
37 |
+
hf_hub_download(HF_CLAS_REPO_NAME, CLAS_FILENAME, local_dir=MODELS_PATH)
|
38 |
+
classifier = load_learner(CLAS_FILEPATH)
|
39 |
|
40 |
bboxes_final = []
|
41 |
cls_prob = []
|
|
|
44 |
outputs = classifier.predict(img)
|
45 |
p[1] = torch.topk(outputs[2], k=1).indices.squeeze(0).item()
|
46 |
p[0] = torch.max(np.trunc(outputs[2] * 100))
|
47 |
+
if p[0] >= CLAS_THRESHOLD * 100:
|
48 |
bboxes_final.append((xmin, ymin, xmax, ymax))
|
49 |
cls_prob.append(p)
|
50 |
return cls_prob, bboxes_final
|
51 |
|
52 |
|
53 |
+
def detect_trash(img):
|
|
|
|
|
54 |
# prepare models for evaluation
|
55 |
torch.set_grad_enabled(False)
|
56 |
|
57 |
# 1) Localize
|
58 |
+
probas, bboxes_scaled = localize_trash(img)
|
|
|
|
|
59 |
|
60 |
# 2) Classify
|
61 |
+
cls_prob, bboxes_final = classify_trash(img, probas, bboxes_scaled)
|
|
|
|
|
62 |
|
63 |
return cls_prob, bboxes_final
|