peterverebics commited on
Commit
7cd5fc7
·
1 Parent(s): 5edda84

add prithivMLmods/Trash-Net

Browse files
Files changed (3) hide show
  1. app.py +36 -36
  2. requirements.txt +2 -11
  3. runtime.txt +1 -1
app.py CHANGED
@@ -1,46 +1,46 @@
1
  import gradio as gr
2
- import tensorflow as tf
 
 
3
  from PIL import Image
4
- import numpy as np
5
- import requests
6
- import os
7
 
8
- # Ensure model folder exists
9
- os.makedirs("model", exist_ok=True)
 
 
10
 
11
- # Download the model from Hugging Face if not already present
12
- model_path = "model/mobnet_model.keras"
13
- if not os.path.exists(model_path):
14
- url = "https://huggingface.co/ahmzakif/TrashNet-Classification/resolve/main/model/mobnet_model.keras"
15
- r = requests.get(url)
16
- with open(model_path, "wb") as f:
17
- f.write(r.content)
18
-
19
- # Load Keras model
20
- model = tf.keras.models.load_model(model_path)
21
-
22
- # TrashNet classes
23
- classes = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
24
-
25
- # Image preprocessing
26
- def predict(image: Image.Image):
27
- image = image.convert("RGB").resize((224, 224))
28
- x = np.array(image, dtype=np.float32) / 255.0
29
- x = np.expand_dims(x, axis=0)
30
 
31
- preds = model.predict(x)[0]
32
- scores = {classes[i]: float(preds[i]) for i in range(len(classes))}
33
- top_class = max(scores, key=scores.get)
 
 
 
 
 
 
34
 
35
- return {"prediction": top_class, "scores": scores}
36
 
37
- # Gradio interface
38
  iface = gr.Interface(
39
- fn=predict,
40
- inputs=gr.Image(type="pil"),
41
- outputs="json",
42
- title="TrashNet Classification API",
43
- description="Upload an image of trash to get its classification."
44
  )
45
 
46
- iface.launch()
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoImageProcessor
3
+ from transformers import SiglipForImageClassification
4
+ from transformers.image_utils import load_image
5
  from PIL import Image
6
+ import torch
 
 
7
 
8
+ # Load model and processor
9
+ model_name = "prithivMLmods/Trash-Net"
10
+ model = SiglipForImageClassification.from_pretrained(model_name)
11
+ processor = AutoImageProcessor.from_pretrained(model_name)
12
 
13
+ def trash_classification(image):
14
+ """Predicts the category of waste material in the image."""
15
+ image = Image.fromarray(image).convert("RGB")
16
+ inputs = processor(images=image, return_tensors="pt")
17
+
18
+ with torch.no_grad():
19
+ outputs = model(**inputs)
20
+ logits = outputs.logits
21
+ probs = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist()
 
 
 
 
 
 
 
 
 
 
22
 
23
+ labels = {
24
+ "0": "cardboard",
25
+ "1": "glass",
26
+ "2": "metal",
27
+ "3": "paper",
28
+ "4": "plastic",
29
+ "5": "trash"
30
+ }
31
+ predictions = {labels[str(i)]: round(probs[i], 3) for i in range(len(probs))}
32
 
33
+ return predictions
34
 
35
+ # Create Gradio interface
36
  iface = gr.Interface(
37
+ fn=trash_classification,
38
+ inputs=gr.Image(type="numpy"),
39
+ outputs=gr.Label(label="Prediction Scores"),
40
+ title="Trash Classification",
41
+ description="Upload an image to classify the type of waste material."
42
  )
43
 
44
+ # Launch the app
45
+ if __name__ == "__main__":
46
+ iface.launch()
requirements.txt CHANGED
@@ -1,13 +1,4 @@
1
  gradio
2
- tqdm
3
- imutils
4
- numpy
5
- pandas
6
  pillow
7
- matplotlib
8
- seaborn
9
- albumentations
10
- opencv-python
11
- tensorflow>=2.15.0
12
- scikit-learn
13
- wandb
 
1
  gradio
2
+ transformers
3
+ torch
 
 
4
  pillow
 
 
 
 
 
 
 
runtime.txt CHANGED
@@ -1 +1 @@
1
- python-3.11.5
 
1
+ python-3.12