Axelottle commited on
Commit
5ae65da
1 Parent(s): 04fd117

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. .DS_Store +0 -0
  2. __pycache__/app.cpython-39.pyc +0 -0
  3. app.py +77 -48
  4. requirements.txt +1 -2
.DS_Store ADDED
Binary file (8.2 kB). View file
 
__pycache__/app.cpython-39.pyc ADDED
Binary file (5.49 kB). View file
 
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import gradio as gr
2
- import cv2 as cv2
3
  import pandas as pd
4
- from sahi.prediction import ObjectPrediction
5
- from sahi.utils.cv import visualize_object_predictions, read_image
6
- from ultralyticsplus import YOLO
 
7
 
8
  # Gradio Theme
9
  theme = gr.themes.Soft(
@@ -16,6 +16,7 @@ theme = gr.themes.Soft(
16
  background_fill_primary='*neutral_100',
17
  )
18
 
 
19
  # Bread Prices
20
  bread_types = {
21
  "baguette": {"name": "Baguette", "price": 108},
@@ -35,41 +36,60 @@ bread_types = {
35
  "whole-grain-bread": {"name": "Whole Grain Bread", "price": 10},
36
  }
37
 
 
38
  # Instantiate the model
39
- model = YOLO('best.pt')
40
 
41
- # Bread Prediction function
42
- def detect_bread(image):
43
- results = model.predict(image, conf=0.4)
44
- result = results[0]
45
- object_prediction_list = []
 
 
 
46
 
47
- # Image Output function
48
- for box in result.boxes:
49
- bbox = box.xyxy[0].tolist()
50
- score = round(box.conf[0].item(), 2)
51
- category_id = box.cls[0]
52
- category_name = result.names[box.cls[0].item()]
53
-
54
- object_prediction = ObjectPrediction(
55
- bbox=bbox,
56
- category_id=int(category_id),
57
- score=score,
58
- category_name=category_name,
59
- )
60
- object_prediction_list.append(object_prediction)
61
 
62
- image = read_image(image)
63
- output_image = visualize_object_predictions(image=image, object_prediction_list=object_prediction_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- # Receipt Output function
66
  detected_classes = []
 
 
 
 
 
 
 
 
 
 
 
 
67
  detected_items = []
68
  counts = {} # Dictionary to store bread type counts
69
-
70
- for cls in result.boxes.cls: # Stores all detected classes in the list
71
- detected_classes.append(result.names[int(cls)])
72
-
73
  for item_class in detected_classes: # Counts the quantity of each class
74
  counts[item_class] = counts.get(item_class, 0) + 1
75
 
@@ -103,18 +123,27 @@ def detect_bread(image):
103
 
104
  df = pd.DataFrame(data)
105
 
106
- return output_image['image'], df
 
107
 
108
  # Export to CSV function
109
  def export_csv(df):
110
  df.to_csv("receipt.csv", index=False)
111
  return gr.File.update(value="receipt.csv", visible=True)
112
 
 
113
  # Export to JSON function
114
  def export_json(df):
115
  df.to_json("receipt.json")
116
  return gr.File.update(value="receipt.json", visible=True)
117
 
 
 
 
 
 
 
 
118
  # Gradio Interface
119
  with gr.Blocks(theme=theme) as demo:
120
 
@@ -124,13 +153,12 @@ with gr.Blocks(theme=theme) as demo:
124
 
125
  with gr.Row():
126
  with gr.Column():
127
- fn = detect_bread
128
- img_input = gr.Image(type="filepath", label="Input Image")
129
- #img_input = gr.Files(file_types=["filepath"], label="Input Image")
130
- detect_btn = gr.Button(variant="primary", value="Detect")
131
-
132
  with gr.Column():
133
- img_output = gr.Image(type="filepath", label='Output Image')
134
  receipt_output = gr.Dataframe(
135
  headers=["Item", "Quantity", "Price", "Amount"],
136
  datatype=["str", "number", "number", "number"],
@@ -138,21 +166,22 @@ with gr.Blocks(theme=theme) as demo:
138
  interactive=False,
139
  )
140
  with gr.Row():
141
- clear_btn = gr.ClearButton([img_input, img_output, receipt_output])
142
- export_btn = gr.Button(variant="primary", value="Export as CSV")
143
  export_json_btn = gr.Button(variant="primary", value="Export as JSON")
144
  with gr.Row():
145
- csv = gr.File(interactive=False, visible=False)
146
-
147
- with gr.Row():
148
- gr.Examples(
149
- examples = ["examples/bonete.jpg", "examples/pandesal.jpg", "examples/croissant_baguette.jpg", "examples/slices.png"],
150
- inputs = img_input,
151
- )
152
 
 
 
153
  detect_btn.click(detect_bread, inputs=img_input, outputs=[img_output, receipt_output])
154
- export_btn.click(export_csv, receipt_output, csv)
155
  export_json_btn.click(export_json, receipt_output, csv)
 
 
 
 
156
 
157
  demo.queue()
158
  demo.launch()
 
1
  import gradio as gr
 
2
  import pandas as pd
3
+ import os, shutil
4
+ from PIL import Image
5
+ from ultralyticsplus import YOLO
6
+
7
 
8
  # Gradio Theme
9
  theme = gr.themes.Soft(
 
16
  background_fill_primary='*neutral_100',
17
  )
18
 
19
+
20
  # Bread Prices
21
  bread_types = {
22
  "baguette": {"name": "Baguette", "price": 108},
 
36
  "whole-grain-bread": {"name": "Whole Grain Bread", "price": 10},
37
  }
38
 
39
+
40
  # Instantiate the model
41
+ model = YOLO("best.pt")
42
 
43
+
44
+ # Converts image input into a list
45
+ def preprocess_image(image):
46
+ img_list = []
47
+
48
+ for im in image:
49
+ image = Image.open(im.name)
50
+ img_list.append(image)
51
 
52
+ return img_list
53
+
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ # Gets all output images
56
+ def get_predictions(directory):
57
+ allowed_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
58
+ return [
59
+ os.path.join(root, file)
60
+ for root, _, files in os.walk(directory)
61
+ for file in files
62
+ if file.lower().endswith(allowed_extensions)
63
+ ]
64
+
65
+
66
+ # Clear output from previous detection
67
+ def clear_output():
68
+ shutil.rmtree('output/', ignore_errors=True)
69
+
70
+
71
+ # Bread Prediction function
72
+ def detect_bread(image):
73
+ clear_output()
74
+ image_list = preprocess_image(image)
75
+ results = model.predict(image_list, conf=0.4, save=True, hide_conf=True, project = "output", name = "results")
76
 
 
77
  detected_classes = []
78
+
79
+ for result in results:
80
+ for cls in result.boxes.cls: # Stores all detected classes in the list
81
+ detected_classes.append(result.names[int(cls)])
82
+
83
+ receipt = generate_receipt(detected_classes)
84
+
85
+ return get_predictions(f'output/results'), receipt
86
+
87
+
88
+ # Generate Receipt function
89
+ def generate_receipt(detected_classes):
90
  detected_items = []
91
  counts = {} # Dictionary to store bread type counts
92
+
 
 
 
93
  for item_class in detected_classes: # Counts the quantity of each class
94
  counts[item_class] = counts.get(item_class, 0) + 1
95
 
 
123
 
124
  df = pd.DataFrame(data)
125
 
126
+ return df
127
+
128
 
129
  # Export to CSV function
130
  def export_csv(df):
131
  df.to_csv("receipt.csv", index=False)
132
  return gr.File.update(value="receipt.csv", visible=True)
133
 
134
+
135
  # Export to JSON function
136
  def export_json(df):
137
  df.to_json("receipt.json")
138
  return gr.File.update(value="receipt.json", visible=True)
139
 
140
+
141
+ # Select image from Files
142
+ def preview(files, sd: gr.SelectData):
143
+ prev = files[sd.index].name
144
+ return gr.Image.update(value=prev, visible=True)
145
+
146
+
147
  # Gradio Interface
148
  with gr.Blocks(theme=theme) as demo:
149
 
 
153
 
154
  with gr.Row():
155
  with gr.Column():
156
+ fn = detect_bread
157
+ img_input = gr.Files(file_types=["image"], label="Input Image")
158
+ img_preview = gr.Image(label="Preview Image", interactive=False, visible=False)
159
+ detect_btn = gr.Button(variant="primary", value="Detect")
 
160
  with gr.Column():
161
+ img_output = gr.Gallery(label='Output Image')
162
  receipt_output = gr.Dataframe(
163
  headers=["Item", "Quantity", "Price", "Amount"],
164
  datatype=["str", "number", "number", "number"],
 
166
  interactive=False,
167
  )
168
  with gr.Row():
169
+ clear_btn = gr.ClearButton()
170
+ export_csv_btn = gr.Button(variant="primary", value="Export as CSV")
171
  export_json_btn = gr.Button(variant="primary", value="Export as JSON")
172
  with gr.Row():
173
+ csv = gr.File(interactive=False, visible=False)
174
+
 
 
 
 
 
175
 
176
+ # Gradio Buttons
177
+ img_input.select(preview, img_input, img_preview)
178
  detect_btn.click(detect_bread, inputs=img_input, outputs=[img_output, receipt_output])
179
+ export_csv_btn.click(export_csv, receipt_output, csv)
180
  export_json_btn.click(export_json, receipt_output, csv)
181
+ clear_btn.click(lambda: [None, None, None, gr.File.update(visible=False), gr.Image.update(visible=False)],
182
+ outputs=[img_input, img_output, receipt_output, csv, img_preview]
183
+ )
184
+
185
 
186
  demo.queue()
187
  demo.launch()
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  gradio==3.40.1
2
- opencv_python==4.8.0.74
3
  pandas==2.0.3
4
- sahi==0.11.14
5
  ultralyticsplus==0.0.28
 
1
  gradio==3.40.1
 
2
  pandas==2.0.3
3
+ Pillow==10.0.0
4
  ultralyticsplus==0.0.28