Ashoka74 committed on
Commit
9db138f
1 Parent(s): 274c822

Update app_3.py

Files changed (1): app_3.py +231 -0
app_3.py CHANGED
@@ -1303,6 +1303,237 @@ def process_image(input_image, input_text):
 
         return annotated_frame, image, gr.update(visible=False), gr.update(visible=False)
     return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
+
+
+@spaces.GPU(duration=60)
+@torch.inference_mode()
+def process_image(input_image, input_text):
+    """Main processing function for the Gradio interface"""
+
+    if isinstance(input_image, Image.Image):
+        input_image = np.array(input_image)
+
+    # Initialize configs
+    API_TOKEN = "9c8c865e10ec1821bea79d9fa9dc8720"
+    SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
+    SAM2_MODEL_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs/sam2_hiera_l.yaml")
+    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+    OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
+    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+    HEIGHT = 768
+    WIDTH = 768
+
+    # Initialize DDS client
+    config = Config(API_TOKEN)
+    client = Client(config)
+
+    # Build class mappings from the period-separated text prompt
+    classes = [x.strip().lower() for x in input_text.split('.') if x]
+    class_name_to_id = {name: id for id, name in enumerate(classes)}
+    class_id_to_name = {id: name for name, id in class_name_to_id.items()}
+
+    # Save input image to a temp file and upload it to get a URL
+    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
+        cv2.imwrite(tmpfile.name, input_image)
+        image_url = client.upload_file(tmpfile.name)
+    os.remove(tmpfile.name)
+
+    # Accumulators for detection results
+    input_boxes = []
+    masks = []
+    confidences = []
+    class_names = []
+    class_ids = []
+
+    if len(input_text) == 0:
+        task = DinoxTask(
+            image_url=image_url,
+            prompts=[TextPrompt(text="<prompt_free>")],
+            # targets=[DetectionTarget.BBox, DetectionTarget.Mask]
+        )
+
+        client.run_task(task)
+        predictions = task.result.objects
+        classes = [pred.category for pred in predictions]
+        classes = list(set(classes))
+        class_name_to_id = {name: id for id, name in enumerate(classes)}
+        class_id_to_name = {id: name for name, id in class_name_to_id.items()}
+
+        for idx, obj in enumerate(predictions):
+            input_boxes.append(obj.bbox)
+            masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size))  # decode RLE mask to np.array via the DDS API
+            confidences.append(obj.score)
+            cls_name = obj.category.lower().strip()
+            class_names.append(cls_name)
+            class_ids.append(class_name_to_id[cls_name])
+
+        boxes = np.array(input_boxes)
+        masks = np.array(masks)
+        class_ids = np.array(class_ids)
+        labels = [
+            f"{class_name} {confidence:.2f}"
+            for class_name, confidence
+            in zip(class_names, confidences)
+        ]
+        detections = sv.Detections(
+            xyxy=boxes,
+            mask=masks.astype(bool),
+            class_id=class_ids
+        )
+
+        box_annotator = sv.BoxAnnotator()
+        label_annotator = sv.LabelAnnotator()
+        mask_annotator = sv.MaskAnnotator()
+
+        annotated_frame = input_image.copy()
+        annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
+        annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+        annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
+
+        # Create a semi-transparent cutout around the first detected object
+        if len(detections) > 0:
+            # Get first mask
+            first_mask = detections.mask[0]
+
+            # Get original RGB image
+            img = input_image.copy()
+            H, W, C = img.shape
+
+            # Build alpha channel: opaque foreground, semi-transparent background
+            alpha = np.zeros((H, W, 1), dtype=np.uint8)
+            alpha[~first_mask] = 128  # Set semi-transparency for background
+            alpha[first_mask] = 255  # Make the foreground opaque
+
+            rgba = np.dstack((img, alpha)).astype(np.uint8)
+
+            # Get the bounding box of the non-transparent region
+            y, x, _ = np.where(alpha > 0)
+            y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
+            x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
+
+            image_center = rgba[y0:y1, x0:x1]
+            # Resize the longer side to 0.9 * target size
+            H, W, _ = image_center.shape
+            if H > W:
+                W = int(W * (HEIGHT * 0.9) / H)
+                H = int(HEIGHT * 0.9)
+            else:
+                H = int(H * (WIDTH * 0.9) / W)
+                W = int(WIDTH * 0.9)
+
+            image_center = np.array(Image.fromarray(image_center).resize((W, H), Image.LANCZOS))
+            # Pad to HEIGHT x WIDTH, then composite over mid-gray using alpha
+            start_h = (HEIGHT - H) // 2
+            start_w = (WIDTH - W) // 2
+            image = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
+            image[start_h : start_h + H, start_w : start_w + W] = image_center
+            image = image.astype(np.float32) / 255.0
+            image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
+            image = (image * 255).clip(0, 255).astype(np.uint8)
+            image = Image.fromarray(image)
+
+            return annotated_frame, image, gr.update(visible=False), gr.update(visible=False)
+        return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
+    else:
+        # Run DINO-X detection
+        task = DinoxTask(
+            image_url=image_url,
+            prompts=[TextPrompt(text=input_text)],
+            targets=[DetectionTarget.BBox, DetectionTarget.Mask]
+        )
+
+        client.run_task(task)
+        result = task.result
+        objects = result.objects
+
+        predictions = task.result.objects
+        classes = [x.strip().lower() for x in input_text.split('.') if x]
+        class_name_to_id = {name: id for id, name in enumerate(classes)}
+        class_id_to_name = {id: name for name, id in class_name_to_id.items()}
+
+        boxes = []
+        masks = []
+        confidences = []
+        class_names = []
+        class_ids = []
+
+        for idx, obj in enumerate(predictions):
+            boxes.append(obj.bbox)
+            masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size))  # decode RLE mask to np.array via the DDS API
+            confidences.append(obj.score)
+            cls_name = obj.category.lower().strip()
+            class_names.append(cls_name)
+            class_ids.append(class_name_to_id[cls_name])
+
+        boxes = np.array(boxes)
+        masks = np.array(masks)
+        class_ids = np.array(class_ids)
+        labels = [
+            f"{class_name} {confidence:.2f}"
+            for class_name, confidence
+            in zip(class_names, confidences)
+        ]
+
+        detections = sv.Detections(
+            xyxy=boxes,
+            mask=masks.astype(bool),
+            class_id=class_ids,
+        )
+
+        box_annotator = sv.BoxAnnotator()
+        label_annotator = sv.LabelAnnotator()
+        mask_annotator = sv.MaskAnnotator()
+
+        annotated_frame = input_image.copy()
+        annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
+        annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+        annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
+
+        # Create a semi-transparent cutout around the first detected object
+        if len(detections) > 0:
+            # Get first mask
+            first_mask = detections.mask[0]
+
+            # Get original RGB image
+            img = input_image.copy()
+            H, W, C = img.shape
+
+            # Build alpha channel: opaque foreground, semi-transparent background
+            alpha = np.zeros((H, W, 1), dtype=np.uint8)
+            alpha[~first_mask] = 128  # Set semi-transparency for background
+            alpha[first_mask] = 255  # Make the foreground opaque
+
+            rgba = np.dstack((img, alpha)).astype(np.uint8)
+
+            # Get the bounding box of the non-transparent region
+            y, x, _ = np.where(alpha > 0)
+            y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
+            x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
+
+            image_center = rgba[y0:y1, x0:x1]
+            # Resize the longer side to 0.9 * target size
+            H, W, _ = image_center.shape
+            if H > W:
+                W = int(W * (HEIGHT * 0.9) / H)
+                H = int(HEIGHT * 0.9)
+            else:
+                H = int(H * (WIDTH * 0.9) / W)
+                W = int(WIDTH * 0.9)
+
+            image_center = np.array(Image.fromarray(image_center).resize((W, H), Image.LANCZOS))
+            # Pad to HEIGHT x WIDTH, then composite over mid-gray using alpha
+            start_h = (HEIGHT - H) // 2
+            start_w = (WIDTH - W) // 2
+            image = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
+            image[start_h : start_h + H, start_w : start_w + W] = image_center
+            image = image.astype(np.float32) / 255.0
+            image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
+            image = (image * 255).clip(0, 255).astype(np.uint8)
+            image = Image.fromarray(image)
+
+            return annotated_frame, image, gr.update(visible=False), gr.update(visible=False)
+        return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
 
 
 block = gr.Blocks().queue()
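Review note: the new process_image returns four values (the annotated frame, the object cutout, and two gr.update(visible=False) updates), so it has to be bound to four output components on the gr.Blocks app built below this hunk. A minimal wiring sketch follows; the component names (input_img, prompt, annotated, cutout, hint_a, hint_b) are assumptions for illustration and do not appear in this diff.

    # Sketch only: component names are hypothetical; process_image is the
    # function added in this commit.
    import gradio as gr

    block = gr.Blocks().queue()
    with block:
        input_img = gr.Image(type="numpy", label="Input image")
        prompt = gr.Textbox(label="Classes, period-separated (leave empty for prompt-free detection)")
        run_btn = gr.Button("Detect")
        annotated = gr.Image(label="Annotated frame")
        cutout = gr.Image(label="First-object cutout")
        hint_a = gr.Markdown("Upload an image to begin.")  # hidden by gr.update(visible=False)
        hint_b = gr.Markdown("Results appear here.")       # hidden by gr.update(visible=False)

        # Four return values -> four output components, matching the new function.
        run_btn.click(fn=process_image, inputs=[input_img, prompt],
                      outputs=[annotated, cutout, hint_a, hint_b])

    block.launch()

Separately, the commit hardcodes the DDS API_TOKEN in the function body; reading it from an environment variable instead (e.g. os.getenv("DDS_API_TOKEN"), variable name assumed) would keep the credential out of the repository.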