santit96 commited on
Commit
dd14920
1 Parent(s): 7dea6fc

Stop versioning the model checkpoints, now they are downloaded from huggingface. Add env vars

Browse files
.env.example ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ CLAS_FILENAME = "classifier_model_name.pkl"
2
+ DET_FILENAME = "detector_model_name.pth.tar"
3
+ HF_DET_REPO_NAME = "org/repo_name"
4
+ HF_CLAS_REPO_NAME = "org/repo_name"
.gitignore CHANGED
@@ -3,3 +3,7 @@ __pycache__
3
  *.jpg
4
  *.png
5
  *.jpeg
 
 
 
 
 
3
  *.jpg
4
  *.png
5
  *.jpeg
6
+ *.tar
7
+ *.pkl
8
+ .env
9
+ *.pth
README.md CHANGED
@@ -32,6 +32,17 @@ python -m venv venv-waste-classifier
32
  source venv-waste-classifier/bin/activate # On Windows, use 'venv-waste-classifier\Scripts\activate'
33
  pip install -r requirements.txt
34
  ```
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  ### Running the App
37
 
 
32
  source venv-waste-classifier/bin/activate # On Windows, use 'venv-waste-classifier\Scripts\activate'
33
  pip install -r requirements.txt
34
  ```
35
+ Create a `.env` file and set the following properties:
36
+
37
+ - CLAS_FILENAME --> The name of the waste classificator model checkpoint
38
+ - DET_FILENAME --> The name of the waste detector model checkpoint
39
+ - HF_DET_REPO_NAME --> The huggingface repository name of the detector model
40
+ - For example: [rootstrap-org/waste-detector](https://huggingface.co/rootstrap-org/waste-detector)
41
+ - HF_CLAS_REPO_NAME --> The huggingface repository name of the classifier model
42
+ - For example: [rootstrap-org/waste-classifier](https://huggingface.co/rootstrap-org/waste-classifier)
43
+
44
+ You can ommit setting the last two properties if you download the models manually and put them under the `models` directory.
45
+
46
 
47
  ### Running the App
48
 
app.py CHANGED
@@ -4,14 +4,12 @@ Streamlit app
4
  import sys
5
 
6
  import streamlit as st
 
7
 
8
- from constants import (CLAS_FILEPATH, CLAS_THRESHOLD, CLASSES, DET_FILEPATH,
9
- DET_NAME, DET_THRESHOLD, DEVICE, OUTPUT_IMG_FILEPATH)
10
 
11
  sys.path.append("./efficientdet")
12
 
13
- from PIL import Image
14
-
15
  from efficientdet.efficientdet import plot_results
16
  from trash_detector import detect_trash
17
 
@@ -57,15 +55,7 @@ def render():
57
  with col2:
58
  with st.spinner(text="Classifying the trash..."):
59
  img = Image.open(uploaded_file).convert("RGB")
60
- cls_prob, bboxes_final = detect_trash(
61
- img,
62
- DET_NAME,
63
- DET_FILEPATH,
64
- CLAS_FILEPATH,
65
- DEVICE,
66
- DET_THRESHOLD,
67
- CLAS_THRESHOLD,
68
- )
69
  # plot and save demo image
70
  plot_results(
71
  img, cls_prob, bboxes_final, CLASSES, OUTPUT_IMG_FILEPATH
 
4
  import sys
5
 
6
  import streamlit as st
7
+ from PIL import Image
8
 
9
+ from constants import CLASSES, OUTPUT_IMG_FILEPATH
 
10
 
11
  sys.path.append("./efficientdet")
12
 
 
 
13
  from efficientdet.efficientdet import plot_results
14
  from trash_detector import detect_trash
15
 
 
55
  with col2:
56
  with st.spinner(text="Classifying the trash..."):
57
  img = Image.open(uploaded_file).convert("RGB")
58
+ cls_prob, bboxes_final = detect_trash(img)
 
 
 
 
 
 
 
 
59
  # plot and save demo image
60
  plot_results(
61
  img, cls_prob, bboxes_final, CLASSES, OUTPUT_IMG_FILEPATH
constants.py CHANGED
@@ -1,8 +1,24 @@
1
- CLAS_FILEPATH = "models/resnet50-classifier.pkl"
2
- DET_FILEPATH = "models/efficientdet-d2-detector.pth.tar"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  CLASSES = ["cardboard", "compost", "glass", "metal", "paper", "plastic", "trash"]
4
- DET_NAME = "tf_efficientdet_d2"
5
  CLAS_THRESHOLD = 0.5
 
6
  DET_THRESHOLD = 0.17
7
  DEVICE = "cpu"
8
  OUTPUT_IMG_FILEPATH = "classified_image.jpg"
 
1
+ import os
2
+
3
+ from dotenv import load_dotenv
4
+
5
+ load_dotenv()
6
+
7
+ # Model checkpoints and repo names
8
+ CLAS_FILENAME = os.getenv("CLAS_FILENAME")
9
+ DET_FILENAME = os.getenv("DET_FILENAME")
10
+ HF_CLAS_REPO_NAME = os.getenv("HF_CLAS_REPO_NAME")
11
+ HF_DET_REPO_NAME = os.getenv("HF_DET_REPO_NAME")
12
+
13
+ # Models paths
14
+ MODELS_PATH = "models"
15
+ CLAS_FILEPATH = f"{MODELS_PATH}/{CLAS_FILENAME}"
16
+ DET_FILEPATH = f"{MODELS_PATH}/{DET_FILENAME}"
17
+
18
+ # Other constants
19
  CLASSES = ["cardboard", "compost", "glass", "metal", "paper", "plastic", "trash"]
 
20
  CLAS_THRESHOLD = 0.5
21
+ DET_NAME = "tf_efficientdet_d2"
22
  DET_THRESHOLD = 0.17
23
  DEVICE = "cpu"
24
  OUTPUT_IMG_FILEPATH = "classified_image.jpg"
models/.gitkeep ADDED
File without changes
models/efficientdet-d2-detector.pth.tar DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:499a3f0c75e13669d69be25854e980812e2f6b50e618ba2b2e90b25f193e7fd9
3
- size 97791163
 
 
 
 
models/resnet50-classifier.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8d2c0667090f996cbe4bab8585300528b8896071e70b1edfdbe671015a074e85
3
- size 102980821
 
 
 
 
requirements.txt CHANGED
@@ -2,6 +2,7 @@ albumentations>=0.5.2
2
  efficientnet_pytorch
3
  fastai==2.7.13
4
  funcy==1.15
 
5
  iterative-stratification==0.1.6
6
  matplotlib==3.8.2
7
  numpy==1.26.2
@@ -9,6 +10,7 @@ omegaconf>=2.0
9
  opencv-python==4.8.1.78
10
  opencv-python-headless==4.8.1.78
11
  pycocotools>=2.0.0
 
12
  pytorch_lightning
13
  pyyaml
14
  rembg==2.0.53
 
2
  efficientnet_pytorch
3
  fastai==2.7.13
4
  funcy==1.15
5
+ huggingface_hub
6
  iterative-stratification==0.1.6
7
  matplotlib==3.8.2
8
  numpy==1.26.2
 
10
  opencv-python==4.8.1.78
11
  opencv-python-headless==4.8.1.78
12
  pycocotools>=2.0.0
13
+ python-dotenv
14
  pytorch_lightning
15
  pyyaml
16
  rembg==2.0.53
trash_detector.py CHANGED
@@ -1,29 +1,41 @@
 
 
1
  import numpy as np
2
  import torch
3
  from fastai.vision.all import load_learner
 
4
 
 
 
 
 
5
  from efficientdet.efficientdet import get_transforms, rescale_bboxes, set_model
6
 
7
 
8
- def localize_trash(im, det_name, det_checkpoint, device, prob_threshold):
9
- # detector
10
- detector = set_model(det_name, 1, det_checkpoint, device)
 
 
 
11
  detector.eval()
12
  # mean-std normalize the input image (batch-size: 1)
13
  img = get_transforms(im)
14
  # propagate through the model
15
- outputs = detector(img.to(device))
16
  # keep only predictions above set confidence
17
- bboxes_keep = outputs[0, outputs[0, :, 4] > prob_threshold]
18
  probas = bboxes_keep[:, 4:]
19
  # convert boxes to image scales
20
  bboxes_scaled = rescale_bboxes(bboxes_keep[:, :4], im.size, tuple(img.size()[2:]))
21
  return probas, bboxes_scaled
22
 
23
 
24
- def classify_trash(im, clas_checkpoint, cls_th, probas, bboxes_scaled):
25
- # classifier
26
- classifier = load_learner(clas_checkpoint)
 
 
27
 
28
  bboxes_final = []
29
  cls_prob = []
@@ -32,26 +44,20 @@ def classify_trash(im, clas_checkpoint, cls_th, probas, bboxes_scaled):
32
  outputs = classifier.predict(img)
33
  p[1] = torch.topk(outputs[2], k=1).indices.squeeze(0).item()
34
  p[0] = torch.max(np.trunc(outputs[2] * 100))
35
- if p[0] >= cls_th * 100:
36
  bboxes_final.append((xmin, ymin, xmax, ymax))
37
  cls_prob.append(p)
38
  return cls_prob, bboxes_final
39
 
40
 
41
- def detect_trash(
42
- im, det_name, det_checkpoint, clas_checkpoint, device, prob_threshold, cls_th
43
- ):
44
  # prepare models for evaluation
45
  torch.set_grad_enabled(False)
46
 
47
  # 1) Localize
48
- probas, bboxes_scaled = localize_trash(
49
- im, det_name, det_checkpoint, device, prob_threshold
50
- )
51
 
52
  # 2) Classify
53
- cls_prob, bboxes_final = classify_trash(
54
- im, clas_checkpoint, cls_th, probas, bboxes_scaled
55
- )
56
 
57
  return cls_prob, bboxes_final
 
1
+ import os
2
+
3
  import numpy as np
4
  import torch
5
  from fastai.vision.all import load_learner
6
+ from huggingface_hub import hf_hub_download
7
 
8
+ from constants import (CLAS_FILENAME, CLAS_FILEPATH, CLAS_THRESHOLD,
9
+ DET_FILENAME, DET_FILEPATH, DET_NAME, DET_THRESHOLD,
10
+ DEVICE, HF_CLAS_REPO_NAME, HF_DET_REPO_NAME,
11
+ MODELS_PATH)
12
  from efficientdet.efficientdet import get_transforms, rescale_bboxes, set_model
13
 
14
 
15
+ def localize_trash(im):
16
+ # detector, if checkpoint doesn't exist then download from hf
17
+ if not os.path.exists(DET_FILEPATH):
18
+ hf_hub_download(HF_DET_REPO_NAME, DET_FILENAME, local_dir=MODELS_PATH)
19
+ detector = set_model(DET_NAME, 1, DET_FILEPATH, DEVICE)
20
+
21
  detector.eval()
22
  # mean-std normalize the input image (batch-size: 1)
23
  img = get_transforms(im)
24
  # propagate through the model
25
+ outputs = detector(img.to(DEVICE))
26
  # keep only predictions above set confidence
27
+ bboxes_keep = outputs[0, outputs[0, :, 4] > DET_THRESHOLD]
28
  probas = bboxes_keep[:, 4:]
29
  # convert boxes to image scales
30
  bboxes_scaled = rescale_bboxes(bboxes_keep[:, :4], im.size, tuple(img.size()[2:]))
31
  return probas, bboxes_scaled
32
 
33
 
34
+ def classify_trash(im, probas, bboxes_scaled):
35
+ # classifier, if checkpoint doesn't exist then download from hf
36
+ if not os.path.exists(CLAS_FILEPATH):
37
+ hf_hub_download(HF_CLAS_REPO_NAME, CLAS_FILENAME, local_dir=MODELS_PATH)
38
+ classifier = load_learner(CLAS_FILEPATH)
39
 
40
  bboxes_final = []
41
  cls_prob = []
 
44
  outputs = classifier.predict(img)
45
  p[1] = torch.topk(outputs[2], k=1).indices.squeeze(0).item()
46
  p[0] = torch.max(np.trunc(outputs[2] * 100))
47
+ if p[0] >= CLAS_THRESHOLD * 100:
48
  bboxes_final.append((xmin, ymin, xmax, ymax))
49
  cls_prob.append(p)
50
  return cls_prob, bboxes_final
51
 
52
 
53
+ def detect_trash(img):
 
 
54
  # prepare models for evaluation
55
  torch.set_grad_enabled(False)
56
 
57
  # 1) Localize
58
+ probas, bboxes_scaled = localize_trash(img)
 
 
59
 
60
  # 2) Classify
61
+ cls_prob, bboxes_final = classify_trash(img, probas, bboxes_scaled)
 
 
62
 
63
  return cls_prob, bboxes_final