Update app_3.py
app_3.py
CHANGED
@@ -1303,6 +1303,237 @@ def process_image(input_image, input_text):
 
         return annotated_frame, image, gr.update(visible=False), gr.update(visible=False)
     return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
+
+
+@spaces.GPU(duration=60)
+@torch.inference_mode
+def process_image(input_image, input_text):
+    """Main processing function for the Gradio interface"""
+
+    if isinstance(input_image, Image.Image):
+        input_image = np.array(input_image)
+
+    # Initialize configs
+    API_TOKEN = "9c8c865e10ec1821bea79d9fa9dc8720"
+    SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
+    SAM2_MODEL_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs/sam2_hiera_l.yaml")
+    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+    OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
+    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+    HEIGHT = 768
+    WIDTH = 768
+
+    # Initialize DDS client
+    config = Config(API_TOKEN)
+    client = Client(config)
+
+    # Process classes from text prompt
+    classes = [x.strip().lower() for x in input_text.split('.') if x]
+    class_name_to_id = {name: id for id, name in enumerate(classes)}
+    class_id_to_name = {id: name for name, id in class_name_to_id.items()}
+
+    # Save input image to temp file and get URL
+    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
+        cv2.imwrite(tmpfile.name, input_image)
+        image_url = client.upload_file(tmpfile.name)
+        os.remove(tmpfile.name)
+
+    # Process detection results
+    input_boxes = []
+    masks = []
+    confidences = []
+    class_names = []
+    class_ids = []
+
+    if len(input_text) == 0:
+        task = DinoxTask(
+            image_url=image_url,
+            prompts=[TextPrompt(text="<prompt_free>")],
+            # targets=[DetectionTarget.BBox, DetectionTarget.Mask]
+        )
+
+        client.run_task(task)
+        predictions = task.result.objects
+        classes = [pred.category for pred in predictions]
+        classes = list(set(classes))
+        class_name_to_id = {name: id for id, name in enumerate(classes)}
+        class_id_to_name = {id: name for name, id in class_name_to_id.items()}
+
+        for idx, obj in enumerate(predictions):
+            input_boxes.append(obj.bbox)
+            masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size))  # convert mask to np.array using DDS API
+            confidences.append(obj.score)
+            cls_name = obj.category.lower().strip()
+            class_names.append(cls_name)
+            class_ids.append(class_name_to_id[cls_name])
+
+        boxes = np.array(input_boxes)
+        masks = np.array(masks)
+        class_ids = np.array(class_ids)
+        labels = [
+            f"{class_name} {confidence:.2f}"
+            for class_name, confidence
+            in zip(class_names, confidences)
+        ]
+        detections = sv.Detections(
+            xyxy=boxes,
+            mask=masks.astype(bool),
+            class_id=class_ids
+        )
+
+        box_annotator = sv.BoxAnnotator()
+        label_annotator = sv.LabelAnnotator()
+        mask_annotator = sv.MaskAnnotator()
+
+        annotated_frame = input_image.copy()
+        annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
+        annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+        annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
+
+        # Create transparent mask for first detected object
+        if len(detections) > 0:
+            # Get first mask
+            first_mask = detections.mask[0]
+
+            # Get original RGB image
+            img = input_image.copy()
+            H, W, C = img.shape
+
+            # Create RGBA image with default 255 alpha
+            alpha = np.zeros((H, W, 1), dtype=np.uint8)
+            alpha[~first_mask] = 128  # Set semi-transparency for background
+            alpha[first_mask] = 255  # Make the foreground opaque
+
+            rgba = np.dstack((img, alpha)).astype(np.uint8)
+
+            # get the bounding box of alpha
+            y, x = np.where(alpha > 0)
+            y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
+            x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
+
+            image_center = rgba[y0:y1, x0:x1]
+            # resize the longer side to H * 0.9
+            H, W, _ = image_center.shape
+            if H > W:
+                W = int(W * (HEIGHT * 0.9) / H)
+                H = int(HEIGHT * 0.9)
+            else:
+                H = int(H * (WIDTH * 0.9) / W)
+                W = int(WIDTH * 0.9)
+
+            image_center = np.array(Image.fromarray(image_center).resize((W, H), Image.LANCZOS))
+            # pad to H, W
+            start_h = (HEIGHT - H) // 2
+            start_w = (WIDTH - W) // 2
+            image = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
+            image[start_h : start_h + H, start_w : start_w + W] = image_center
+            image = image.astype(np.float32) / 255.0
+            image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
+            image = (image * 255).clip(0, 255).astype(np.uint8)
+            image = Image.fromarray(image)
+
+            return annotated_frame, image, gr.update(visible=False), gr.update(visible=False)
+        return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
+    else:
+        # Run DINO-X detection
+        task = DinoxTask(
+            image_url=image_url,
+            prompts=[TextPrompt(text=input_text)],
+            targets=[DetectionTarget.BBox, DetectionTarget.Mask]
+        )
+
+        client.run_task(task)
+        result = task.result
+        objects = result.objects
+
+        predictions = task.result.objects
+        classes = [x.strip().lower() for x in input_text.split('.') if x]
+        class_name_to_id = {name: id for id, name in enumerate(classes)}
+        class_id_to_name = {id: name for name, id in class_name_to_id.items()}
+
+        boxes = []
+        masks = []
+        confidences = []
+        class_names = []
+        class_ids = []
+
+        for idx, obj in enumerate(predictions):
+            boxes.append(obj.bbox)
+            masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size))  # convert mask to np.array using DDS API
+            confidences.append(obj.score)
+            cls_name = obj.category.lower().strip()
+            class_names.append(cls_name)
+            class_ids.append(class_name_to_id[cls_name])
+
+        boxes = np.array(boxes)
+        masks = np.array(masks)
+        class_ids = np.array(class_ids)
+        labels = [
+            f"{class_name} {confidence:.2f}"
+            for class_name, confidence
+            in zip(class_names, confidences)
+        ]
+
+        detections = sv.Detections(
+            xyxy=boxes,
+            mask=masks.astype(bool),
+            class_id=class_ids,
+        )
+
+        box_annotator = sv.BoxAnnotator()
+        label_annotator = sv.LabelAnnotator()
+        mask_annotator = sv.MaskAnnotator()
+
+        annotated_frame = input_image.copy()
+        annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
+        annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+        annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
+
+        # Create transparent mask for first detected object
+        if len(detections) > 0:
+            # Get first mask
+            first_mask = detections.mask[0]
+
+            # Get original RGB image
+            img = input_image.copy()
+            H, W, C = img.shape
+
+            # Create RGBA image with default 255 alpha
+            alpha = np.zeros((H, W, 1), dtype=np.uint8)
+            alpha[~first_mask] = 128  # Set semi-transparency for background
+            alpha[first_mask] = 255  # Make the foreground opaque
+
+            rgba = np.dstack((img, alpha)).astype(np.uint8)
+
+            # get the bounding box of alpha
+            y, x = np.where(alpha > 0)
+            y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
+            x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
+
+            image_center = rgba[y0:y1, x0:x1]
+            # resize the longer side to H * 0.9
+            H, W, _ = image_center.shape
+            if H > W:
+                W = int(W * (HEIGHT * 0.9) / H)
+                H = int(HEIGHT * 0.9)
+            else:
+                H = int(H * (WIDTH * 0.9) / W)
+                W = int(WIDTH * 0.9)
+
+            image_center = np.array(Image.fromarray(image_center).resize((W, H), Image.LANCZOS))
+            # pad to H, W
+            start_h = (HEIGHT - H) // 2
+            start_w = (WIDTH - W) // 2
+            image = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
+            image[start_h : start_h + H, start_w : start_w + W] = image_center
+            image = image.astype(np.float32) / 255.0
+            image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
+            image = (image * 255).clip(0, 255).astype(np.uint8)
+            image = Image.fromarray(image)
+
+            return annotated_frame, image, gr.update(visible=False), gr.update(visible=False)
+        return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
 
 
 block = gr.Blocks().queue()
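
The new process_image function returns four values: the annotated frame, the recentered RGBA crop of the first detection, and two gr.update(visible=False) objects. It is therefore presumably wired to four output components inside the Gradio Blocks UI that begins at the block = gr.Blocks().queue() line shown as trailing context. Below is a minimal wiring sketch; the component names, labels, and layout are assumptions for illustration and are not part of this commit.

    # Hypothetical wiring sketch -- component names below are assumptions, not from app_3.py.
    with block:
        with gr.Row():
            input_image = gr.Image(label="Input Image", type="numpy")
            annotated_output = gr.Image(label="Annotated Result")
            cropped_output = gr.Image(label="First Object (recentered crop)")
        text_prompt = gr.Textbox(
            label="Text Prompt",
            placeholder="e.g. 'car . person .' (leave empty for prompt-free detection)",
        )
        extra_a = gr.Textbox(visible=False)  # receives gr.update(visible=False)
        extra_b = gr.Textbox(visible=False)  # receives gr.update(visible=False)
        run_button = gr.Button("Detect")
        run_button.click(
            fn=process_image,
            inputs=[input_image, text_prompt],
            outputs=[annotated_output, cropped_output, extra_a, extra_b],
        )

    block.launch()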