mussie1212 commited on
Commit
70d5cce
·
verified ·
1 Parent(s): 81ab450

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -0
app.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ # from fastapi.responses import JSONResponse
3
+ # import logging
4
+ # from ultralytics import YOLO
5
+ # import numpy as np
6
+ # import cv2
7
+ # from io import BytesIO
8
+ # from PIL import Image
9
+ # import base64
10
+ # import os
11
+
12
+ # # Setup logging
13
+ # logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
14
+ # logger = logging.getLogger(__name__)
15
+
16
+
17
+
18
+
19
+ # app = FastAPI(title="Car Parts & Damage Detection API")
20
+
21
+
22
+
23
+ # # Log model file presence
24
+ # model_files = ["car_part_detector_model.pt", "damage_general_model.pt"]
25
+ # for model_file in model_files:
26
+ # if os.path.exists(model_file):
27
+ # logger.info(f"Model file found: {model_file}")
28
+ # else:
29
+ # logger.error(f"Model file missing: {model_file}")
30
+
31
+
32
+
33
+ # # Load YOLO models
34
+ # try:
35
+ # logger.info("Loading car part model...")
36
+ # car_part_model = YOLO("car_part_detector_model.pt")
37
+ # logger.info("Car part model loaded successfully")
38
+ # logger.info("Loading damage model...")
39
+ # damage_model = YOLO("damage_general_model.pt")
40
+ # logger.info("Damage model loaded successfully")
41
+ # except Exception as e:
42
+ # logger.error(f"Failed to load models: {str(e)}")
43
+ # raise RuntimeError(f"Failed to load models: {str(e)}")
44
+
45
+
46
+
47
+
48
+ # def image_to_base64(img: np.ndarray) -> str:
49
+ # """Convert numpy image to base64 string."""
50
+ # try:
51
+ # _, buffer = cv2.imencode(".png", img)
52
+ # return base64.b64encode(buffer).decode("utf-8")
53
+ # except Exception as e:
54
+ # logger.error(f"Error encoding image to base64: {str(e)}")
55
+ # raise
56
+
57
+
58
+
59
+
60
+ # @app.post("/predict", summary="Run inference on an image for car parts and damage")
61
+ # async def predict(file: UploadFile = File(...)):
62
+ # """Upload an image and get car parts and damage detection results."""
63
+ # logger.info("Received image upload")
64
+ # try:
65
+ # contents = await file.read()
66
+ # image = Image.open(BytesIO(contents)).convert("RGB")
67
+ # img = np.array(image)
68
+ # logger.info(f"Image loaded: shape={img.shape}")
69
+
70
+ # blank_img = np.full((img.shape[0], img.shape[1], 3), 128, dtype=np.uint8)
71
+ # car_part_img = blank_img.copy()
72
+ # damage_img = blank_img.copy()
73
+ # car_part_text = "Car Parts: No detections"
74
+ # damage_text = "Damage: No detections"
75
+
76
+ # try:
77
+ # logger.info("Running car part detection...")
78
+ # car_part_results = car_part_model(img)[0]
79
+ # if car_part_results.boxes:
80
+ # car_part_img = car_part_results.plot()[..., ::-1]
81
+ # car_part_text = "Car Parts:\n" + "\n".join(
82
+ # f"- {car_part_results.names[int(cls)]} ({conf:.2f})"
83
+ # for conf, cls in zip(car_part_results.boxes.conf, car_part_results.boxes.cls)
84
+ # )
85
+ # logger.info("Car part detection completed")
86
+ # except Exception as e:
87
+ # car_part_text = f"Car Parts: Error: {str(e)}"
88
+ # logger.error(f"Car part detection error: {str(e)}")
89
+
90
+ # try:
91
+ # logger.info("Running damage detection...")
92
+ # damage_results = damage_model(img)[0]
93
+ # if damage_results.boxes:
94
+ # damage_img = damage_results.plot()[..., ::-1]
95
+ # damage_text = "Damage:\n" + "\n".join(
96
+ # f"- {damage_results.names[int(cls)]} ({conf:.2f})"
97
+ # for conf, cls in zip(damage_results.boxes.conf, damage_results.boxes.cls)
98
+ # )
99
+ # logger.info("Damage detection completed")
100
+ # except Exception as e:
101
+ # damage_text = f"Damage: Error: {str(e)}"
102
+ # logger.error(f"Damage detection error: {str(e)}")
103
+
104
+ # car_part_img_base64 = image_to_base64(car_part_img)
105
+ # damage_img_base64 = image_to_base64(damage_img)
106
+ # logger.info("Returning prediction results")
107
+ # return JSONResponse({
108
+ # "car_part_image": car_part_img_base64,
109
+ # "car_part_text": car_part_text,
110
+ # "damage_image": damage_img_base64,
111
+ # "damage_text": damage_text
112
+ # })
113
+ # except Exception as e:
114
+ # logger.error(f"Inference error: {str(e)}")
115
+ # raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
116
+
117
+
118
+
119
+
120
+
121
+
122
+ # @app.get("/", summary="Health check")
123
+ # async def root():
124
+ # """Check if the API is running."""
125
+ # logger.info("Health check accessed")
126
+ # return {"message": "Car Parts & Damage Detection API is running"}
127
+
128
+
129
+
130
+ import gradio as gr
131
+ import numpy as np
132
+ import cv2
133
+ from PIL import Image
134
+ import base64
135
+ from io import BytesIO
136
+ from ultralytics import YOLO
137
+ import logging
138
+ import time
139
+
140
+
141
+
142
+ # Set up logging
143
+ logging.basicConfig(level=logging.INFO)
144
+ logger = logging.getLogger(__name__)
145
+
146
+ # Load damage detection model
147
+ try:
148
+ logger.info("Loading damage model...")
149
+ damage_model = YOLO("damage_general_model.pt")
150
+ logger.info("Damage model loaded successfully")
151
+ except Exception as e:
152
+ logger.error(f"Failed to load damage model: {str(e)}")
153
+ raise RuntimeError(f"Failed to load damage model: {str(e)}")
154
+
155
+ def image_to_base64(img: np.ndarray) -> str:
156
+ """Convert numpy image to base64 string."""
157
+ try:
158
+ _, buffer = cv2.imencode(".png", img)
159
+ return base64.b64encode(buffer).decode("utf-8")
160
+ except Exception as e:
161
+ logger.error(f"Error encoding image to base64: {str(e)}")
162
+ raise
163
+
164
+ def process_images(*images):
165
+ """Process up to 5 images for damage detection."""
166
+ if not any(images):
167
+ return "Please upload at least one image.", []
168
+
169
+ results = []
170
+ timing_info = []
171
+ for idx, img in enumerate(images):
172
+ if img is None:
173
+ continue
174
+ try:
175
+ start_image_time = time.time() # Start timer for individual image
176
+ logger.info(f"Processing image {idx + 1}")
177
+ # Convert Gradio image input (PIL) to numpy
178
+ img_np = np.array(img)
179
+ blank_img = np.full((img_np.shape[0], img_np.shape[1], 3), 128, dtype=np.uint8)
180
+ damage_text = f"Image {idx + 1} - Damage: No detections"
181
+ damage_img = blank_img.copy()
182
+
183
+ # Run damage detection
184
+ logger.info(f"Running damage detection for image {idx + 1}...")
185
+ damage_results = damage_model(img_np)[0]
186
+ if damage_results.boxes:
187
+ damage_img = damage_results.plot()[..., ::-1]
188
+ damage_text = f"Image {idx + 1} - Damage:\n" + "\n".join(
189
+ f"- {damage_results.names[int(cls)]} ({conf:.2f})"
190
+ for conf, cls in zip(damage_results.boxes.conf, damage_results.boxes.cls)
191
+ )
192
+ logger.info(f"Damage detection completed for image {idx + 1}")
193
+
194
+ # Convert result image to PIL for Gradio display
195
+ damage_pil = Image.fromarray(damage_img)
196
+ results.append((damage_pil, damage_text))
197
+ except Exception as e:
198
+ logger.error(f"Error processing image {idx + 1}: {str(e)}")
199
+ results.append((None, f"Image {idx + 1} - Error: {str(e)}"))
200
+
201
+ # Calculate total processing time
202
+ total_time = time.time() - start_image_time
203
+ timing_info.append(f"Total processing time: {total_time:.2f} seconds")
204
+
205
+ return "Damage detection completed.", results, "\n".join(timing_info)
206
+
207
+ # Define Gradio interface
208
+ iface = gr.Interface(
209
+ fn=process_images,
210
+ inputs=[
211
+ gr.Image(type="pil", label="Upload Image 1"),
212
+ gr.Image(type="pil", label="Upload Image 2"),
213
+ gr.Image(type="pil", label="Upload Image 3"),
214
+ gr.Image(type="pil", label="Upload Image 4"),
215
+ gr.Image(type="pil", label="Upload Image 5"),
216
+ ],
217
+ outputs=[
218
+ gr.Textbox(label="Status"),
219
+ gr.Gallery(label="Detected Damage Images and Results", columns=2),
220
+ ],
221
+ title="Car Damage Detection",
222
+ description="Upload up to 5 images to detect car damage. Results will display annotated images and detected damage details.",
223
+ )
224
+
225
+ # Launch the Gradio app
226
+ if __name__ == "__main__":
227
+ iface.launch()