Spaces:
Running
on
Zero
Running
on
Zero
Upload core files for paper 2510.18876
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +35 -0
- CLAUDE.md +304 -0
- GRADIO_APP_SUMMARY.md +180 -0
- LICENSE +201 -0
- README.md +49 -5
- README_original.md +208 -0
- app.py +442 -0
- demo/gar_relationship.py +143 -0
- demo/gar_with_mask.py +132 -0
- demo/gar_with_sam.py +272 -0
- demo/gradio/.gradio/certificate.pem +31 -0
- demo/gradio/README.md +11 -0
- demo/gradio/app.py +267 -0
- demo/gradio/frontend/README.md +126 -0
- demo/gradio/frontend/configs/webpack/common.js +85 -0
- demo/gradio/frontend/configs/webpack/dev.js +25 -0
- demo/gradio/frontend/configs/webpack/prod.js +22 -0
- demo/gradio/frontend/package.json +64 -0
- demo/gradio/frontend/postcss.config.js +10 -0
- demo/gradio/frontend/src/App.tsx +306 -0
- demo/gradio/frontend/src/components/ErrorModal.tsx +32 -0
- demo/gradio/frontend/src/components/LoadingOverlay.tsx +30 -0
- demo/gradio/frontend/src/components/QueueStatusIndicator.tsx +29 -0
- demo/gradio/frontend/src/components/Stage.tsx +343 -0
- demo/gradio/frontend/src/components/Tool.tsx +182 -0
- demo/gradio/frontend/src/components/helpers/Interfaces.tsx +47 -0
- demo/gradio/frontend/src/components/helpers/imageUtils.tsx +21 -0
- demo/gradio/frontend/src/components/helpers/maskUtils.tsx +65 -0
- demo/gradio/frontend/src/components/helpers/onnxModelAPI.tsx +71 -0
- demo/gradio/frontend/src/components/helpers/scaleHelper.tsx +18 -0
- demo/gradio/frontend/src/components/hooks/context.tsx +35 -0
- demo/gradio/frontend/src/components/hooks/createContext.tsx +35 -0
- demo/gradio/frontend/src/index.tsx +17 -0
- demo/gradio/frontend/src/services/maskApi.tsx +211 -0
- demo/gradio/frontend/tailwind.config.js +12 -0
- demo/gradio/frontend/tsconfig.json +24 -0
- demo/gradio/frontend/yarn.lock +0 -0
- demo/gradio/requirements.txt +15 -0
- evaluation/DLC-Bench/annotations/annotations.json +0 -0
- evaluation/DLC-Bench/annotations/class_names.json +102 -0
- evaluation/DLC-Bench/annotations/qa.json +0 -0
- evaluation/DLC-Bench/eval_gpt_with_image.py +483 -0
- evaluation/DLC-Bench/eval_llama_without_image.py +503 -0
- evaluation/DLC-Bench/inference.py +173 -0
- evaluation/DLC-Bench/model_outputs/gar_1b.json +102 -0
- evaluation/DLC-Bench/model_outputs/gar_1b_eval.json +0 -0
- evaluation/DLC-Bench/model_outputs/gar_1b_eval_gpt.json +0 -0
- evaluation/DLC-Bench/model_outputs/gar_8b.json +102 -0
- evaluation/DLC-Bench/model_outputs/gar_8b_eval.json +0 -0
- evaluation/DLC-Bench/model_outputs/gar_8b_eval_gpt.json +0 -0
.gitignore
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__
|
| 3 |
+
*.pyc
|
| 4 |
+
*.egg-info
|
| 5 |
+
|
| 6 |
+
# Log
|
| 7 |
+
*.log
|
| 8 |
+
*.log.*
|
| 9 |
+
# *.json
|
| 10 |
+
# *.jsonl
|
| 11 |
+
|
| 12 |
+
# Data
|
| 13 |
+
!**/alpaca-data-conversation.json
|
| 14 |
+
|
| 15 |
+
# Editor
|
| 16 |
+
.idea
|
| 17 |
+
*.swp
|
| 18 |
+
|
| 19 |
+
# Other
|
| 20 |
+
.DS_Store
|
| 21 |
+
wandb
|
| 22 |
+
# output
|
| 23 |
+
|
| 24 |
+
checkpoints
|
| 25 |
+
ckpts*
|
| 26 |
+
pretrained*
|
| 27 |
+
|
| 28 |
+
.ipynb_checkpoints
|
| 29 |
+
*.ipynb
|
| 30 |
+
|
| 31 |
+
# DevContainer
|
| 32 |
+
!.devcontainer/*
|
| 33 |
+
|
| 34 |
+
# Demo
|
| 35 |
+
serve_images/
|
CLAUDE.md
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CLAUDE.md
|
| 2 |
+
|
| 3 |
+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
| 4 |
+
|
| 5 |
+
## Project Overview
|
| 6 |
+
|
| 7 |
+
**Grasp Any Region (GAR)** is a research project for region-level multimodal understanding in vision-language models. It enables:
|
| 8 |
+
|
| 9 |
+
1. **Single Region Understanding**: Detailed description of specific image/video regions via points/boxes/scribbles/masks
|
| 10 |
+
2. **Multi-Region Reasoning**: Complex relationship modeling and reasoning across multiple regions simultaneously
|
| 11 |
+
3. **Advanced Compositional Reasoning**: Active dialogue about regions rather than passive description
|
| 12 |
+
|
| 13 |
+
The model is built on top of Facebook's Perception-LM architecture and uses xTuner training framework with PyTorch distributed training.
|
| 14 |
+
|
| 15 |
+
## Architecture
|
| 16 |
+
|
| 17 |
+
### Core Components
|
| 18 |
+
|
| 19 |
+
**Model Architecture** (`projects/grasp_any_region/models/grasp_any_region.py:GraspAnyRegion`):
|
| 20 |
+
- Wraps `PerceptionLMForConditionalGeneration` from HuggingFace
|
| 21 |
+
- Key innovation: **RoI-aligned feature replay technique** using `torchvision.ops.roi_align`
|
| 22 |
+
- Adds `mask_patch_embedding` layer (Conv2d) for region mask encoding
|
| 23 |
+
- Supports 15 visual prompt tokens (`<Prompt0>` through `<Prompt14>`) plus `<NO_Prompt>`
|
| 24 |
+
- Forward pass implements feature replay mechanism at grasp_any_region.py:291-377
|
| 25 |
+
|
| 26 |
+
**Visual Prompt System**:
|
| 27 |
+
- Masks are encoded with prompt IDs (0-14) where each ID represents a different region
|
| 28 |
+
- Special value (15 = `<NO_Prompt>`) indicates background/non-region areas
|
| 29 |
+
- RoI features are extracted using bounding boxes and replayed into the sequence at crop token positions
|
| 30 |
+
|
| 31 |
+
**Training Pipeline**:
|
| 32 |
+
- Uses xTuner framework (built on MMEngine)
|
| 33 |
+
- Dataset: Arrow format with three subsets (Seed, Fine-Grained, Relation)
|
| 34 |
+
- Custom collate function handles variable-length sequences and multi-region inputs
|
| 35 |
+
- Flash Attention 2 required for efficiency
|
| 36 |
+
|
| 37 |
+
### Directory Structure
|
| 38 |
+
|
| 39 |
+
```
|
| 40 |
+
projects/grasp_any_region/ # Main model code
|
| 41 |
+
├── configs/ # Training configs (gar_1b.py, gar_8b.py)
|
| 42 |
+
├── models/
|
| 43 |
+
│ ├── grasp_any_region.py # Main model wrapper
|
| 44 |
+
│ └── modeling/ # Custom PerceptionLM implementations
|
| 45 |
+
├── datasets/ # Dataset and data loading
|
| 46 |
+
└── hf_models/ # HuggingFace conversion utilities
|
| 47 |
+
|
| 48 |
+
demo/ # Inference demos
|
| 49 |
+
├── gar_with_mask.py # Direct mask input
|
| 50 |
+
├── gar_with_sam.py # SAM-based region selection
|
| 51 |
+
├── gar_relationship.py # Multi-region reasoning
|
| 52 |
+
└── gradio/ # Web demo
|
| 53 |
+
|
| 54 |
+
evaluation/ # Benchmarks
|
| 55 |
+
├── GAR-Bench/ # Custom benchmark (Caption-Simple, Caption-Detailed, VQA)
|
| 56 |
+
├── DLC-Bench/ # Detailed localized captioning
|
| 57 |
+
├── Ferret-Bench/ # Region description
|
| 58 |
+
└── MDVP-Bench/ # Multi-domain visual perception
|
| 59 |
+
|
| 60 |
+
tools/
|
| 61 |
+
├── train.py # Training entry point
|
| 62 |
+
├── test.py # Testing entry point
|
| 63 |
+
└── dist.sh # Distributed training launcher
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
## Common Commands
|
| 67 |
+
|
| 68 |
+
### Environment Setup
|
| 69 |
+
|
| 70 |
+
```bash
|
| 71 |
+
# Create environment (requires Python 3.11.2)
|
| 72 |
+
conda create -n gar python=3.11.2 -y
|
| 73 |
+
conda activate gar
|
| 74 |
+
|
| 75 |
+
# Install dependencies
|
| 76 |
+
pip3 install xtuner==0.2.0rc0
|
| 77 |
+
pip3 install -r requirements.txt
|
| 78 |
+
pip3 install flash-attn==2.7.4.post1 --no-build-isolation -v
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
### Training
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
# Single-node distributed training (8 GPUs)
|
| 85 |
+
bash tools/dist.sh train projects/grasp_any_region/configs/gar_1b.py 8
|
| 86 |
+
|
| 87 |
+
# The dist.sh script uses torchrun with:
|
| 88 |
+
# - Configurable MASTER_ADDR, PORT, NNODES, NODE_RANK
|
| 89 |
+
# - DeepSpeed Zero2 by default (set DEEPSPEED env var to override)
|
| 90 |
+
# - 5-hour timeout (TORCHELASTIC_TIMEOUT=18000)
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
**Config Files**:
|
| 94 |
+
- `projects/grasp_any_region/configs/gar_1b.py` - 1B model
|
| 95 |
+
- `projects/grasp_any_region/configs/gar_8b.py` - 8B model
|
| 96 |
+
|
| 97 |
+
Key training settings (gar_1b.py):
|
| 98 |
+
- Base model: `facebook/Perception-LM-1B`
|
| 99 |
+
- Batch size: 1 per device × 2 accumulation × 32 GPUs = 64 global
|
| 100 |
+
- Learning rate: 1e-5 (AdamW), warmup: 3%, cosine annealing
|
| 101 |
+
- Max length: 16384 tokens
|
| 102 |
+
- Saves every 5000 steps, keeps last 2 checkpoints
|
| 103 |
+
|
| 104 |
+
### Dataset Preparation
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
# Download dataset from HuggingFace
|
| 108 |
+
hf download HaochenWang/Grasp-Any-Region-Dataset --local-dir data --repo-type dataset
|
| 109 |
+
|
| 110 |
+
# Expected structure:
|
| 111 |
+
# data/
|
| 112 |
+
# ├── Seed-Dataset/data-*.arrow
|
| 113 |
+
# ├── Fine-Grained-Dataset/data-*.arrow
|
| 114 |
+
# └── Relation-Dataset/data-*.arrow
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
### Inference Demos
|
| 118 |
+
|
| 119 |
+
**Single Region with Mask**:
|
| 120 |
+
```bash
|
| 121 |
+
torchrun --nproc-per-node=1 --master-port=8119 \
|
| 122 |
+
demo/gar_with_mask.py \
|
| 123 |
+
--image_path assets/demo_image_1.png \
|
| 124 |
+
--mask_path assets/demo_mask_1.png
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
**Single Region with SAM** (points or box):
|
| 128 |
+
```bash
|
| 129 |
+
# Using points
|
| 130 |
+
torchrun --nproc-per-node=1 --master-port=8119 \
|
| 131 |
+
demo/gar_with_sam.py \
|
| 132 |
+
--image_path assets/demo_image_2.jpg \
|
| 133 |
+
--points '[[1172, 812], [1572, 800]]'
|
| 134 |
+
|
| 135 |
+
# Using bounding box
|
| 136 |
+
torchrun --nproc-per-node=1 --master-port=8119 \
|
| 137 |
+
demo/gar_with_sam.py \
|
| 138 |
+
--image_path assets/demo_image_2.jpg \
|
| 139 |
+
--box '[800, 500, 1800, 1000]' \
|
| 140 |
+
--use_box
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
**Multi-Region Relationship**:
|
| 144 |
+
```bash
|
| 145 |
+
torchrun --nproc-per-node=1 --master-port=8119 \
|
| 146 |
+
demo/gar_relationship.py \
|
| 147 |
+
--image_path assets/demo_image_3.png \
|
| 148 |
+
--mask_paths "['assets/demo_mask_3_0.png', 'assets/demo_mask_3_1.png', 'assets/demo_mask_3_2.png']" \
|
| 149 |
+
--question_str 'Question: What is the relationship between <Prompt0>, <Prompt1>, and <Prompt2>?'
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
**Gradio Demo**:
|
| 153 |
+
```bash
|
| 154 |
+
cd demo/gradio
|
| 155 |
+
pip install -r requirements.txt
|
| 156 |
+
python app.py
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
### Evaluation
|
| 160 |
+
|
| 161 |
+
All evaluation scripts follow the same pattern: inference → evaluation with LLM judge (GPT-4o or Llama).
|
| 162 |
+
|
| 163 |
+
**GARBench-Caption-Simple**:
|
| 164 |
+
```bash
|
| 165 |
+
# Inference
|
| 166 |
+
torchrun --nproc-per-node=1 --master-port=9811 \
|
| 167 |
+
evaluation/GAR-Bench/inference.py \
|
| 168 |
+
--model_name_or_path HaochenWang/GAR-8B \
|
| 169 |
+
--anno_file evaluation/GAR-Bench/annotations/GAR-Bench-Caption-Simple.json \
|
| 170 |
+
--mode simple \
|
| 171 |
+
--cache_name my_test \
|
| 172 |
+
--data_type bf16 \
|
| 173 |
+
--seed 42
|
| 174 |
+
|
| 175 |
+
# Evaluation (requires Azure OpenAI)
|
| 176 |
+
export AZURE_OPENAI_ENDPOINT=YOUR_ENDPOINT
|
| 177 |
+
export AZURE_OPENAI_KEY=YOUR_KEY
|
| 178 |
+
python3 evaluation/GAR-Bench/eval_simple.py \
|
| 179 |
+
--pred evaluation/GAR-Bench/model_outputs/my_test_simple.json
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
**GARBench-VQA** (multi-region reasoning):
|
| 183 |
+
```bash
|
| 184 |
+
torchrun --nproc-per-node=1 --master-port=9811 \
|
| 185 |
+
evaluation/GAR-Bench/inference.py \
|
| 186 |
+
--model_name_or_path HaochenWang/GAR-8B \
|
| 187 |
+
--anno_file evaluation/GAR-Bench/annotations/GAR-Bench-VQA.json \
|
| 188 |
+
--mode vqa \
|
| 189 |
+
--cache_name my_test \
|
| 190 |
+
--data_type bf16
|
| 191 |
+
# VQA evaluation is automatic (no LLM judge)
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
**DLC-Bench** (detailed localized captioning):
|
| 195 |
+
```bash
|
| 196 |
+
# Download images first
|
| 197 |
+
cd evaluation/DLC-Bench/annotations
|
| 198 |
+
hf download nvidia/DLC-Bench --repo-type dataset --include "images/*" --local-dir ./
|
| 199 |
+
cd ../../..
|
| 200 |
+
|
| 201 |
+
# Inference
|
| 202 |
+
torchrun --nproc-per-node=1 --master-port=8841 \
|
| 203 |
+
evaluation/DLC-Bench/inference.py \
|
| 204 |
+
--model_name_or_path HaochenWang/GAR-8B \
|
| 205 |
+
--cache_name my_test \
|
| 206 |
+
--data_type bf16
|
| 207 |
+
|
| 208 |
+
# Evaluation with GPT-4o
|
| 209 |
+
python3 evaluation/DLC-Bench/eval_gpt_with_image.py \
|
| 210 |
+
--pred evaluation/DLC-Bench/model_outputs/my_test.json
|
| 211 |
+
|
| 212 |
+
# Alternative: Evaluation with Llama3.1-8B (requires vLLM server)
|
| 213 |
+
bash evaluation/DLC-Bench/serve_judge.sh # in one terminal
|
| 214 |
+
python3 evaluation/DLC-Bench/eval_llama_without_image.py \
|
| 215 |
+
--pred evaluation/DLC-Bench/model_outputs/my_test.json \
|
| 216 |
+
--base_url http://localhost:8007/v1
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+
### Model Conversion
|
| 220 |
+
|
| 221 |
+
```bash
|
| 222 |
+
# Convert trained checkpoint to HuggingFace format
|
| 223 |
+
python3 projects/grasp_any_region/hf_models/convert_to_hf.py \
|
| 224 |
+
projects/grasp_any_region/configs/gar_1b.py \
|
| 225 |
+
--pth-model PATH_TO_PTH_MODEL \
|
| 226 |
+
--save-path PATH_TO_SAVE_FOLDER
|
| 227 |
+
|
| 228 |
+
# Note: Manually copy required .py files to save folder after conversion
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
## Key Implementation Details
|
| 232 |
+
|
| 233 |
+
### RoI Feature Replay Mechanism
|
| 234 |
+
|
| 235 |
+
The core innovation is at `grasp_any_region.py:291-377`:
|
| 236 |
+
|
| 237 |
+
1. Image features are extracted as tiles (16×16 patches per tile)
|
| 238 |
+
2. Tiles are merged into full-resolution feature map
|
| 239 |
+
3. For each `<PromptN>` token in input:
|
| 240 |
+
- Extract RoI bounding box from `data["bboxes"]`
|
| 241 |
+
- Apply `torchvision.ops.roi_align` to extract 16×16 features
|
| 242 |
+
- Replace prompt tokens in sequence with RoI features
|
| 243 |
+
4. This allows attending to region-specific features with global context
|
| 244 |
+
|
| 245 |
+
### Mask Encoding
|
| 246 |
+
|
| 247 |
+
Masks are provided as 3-channel images where pixel values encode prompt IDs:
|
| 248 |
+
- Values 0-14: Different region prompts
|
| 249 |
+
- Value 15 (or `prompt_numbers`): Background (no prompt)
|
| 250 |
+
- `mask_patch_embedding` (Conv2d) encodes binary masks into feature space
|
| 251 |
+
- Masks are processed at patch level matching vision encoder stride
|
| 252 |
+
|
| 253 |
+
### Data Format
|
| 254 |
+
|
| 255 |
+
Dataset uses Arrow format with fields:
|
| 256 |
+
- `pixel_values`: (num_tiles, 3, H, W) image tiles
|
| 257 |
+
- `input_ids`: Token sequence with special image/prompt tokens
|
| 258 |
+
- `labels`: Target sequence (-100 for non-loss positions)
|
| 259 |
+
- `global_mask_values`: Region masks with prompt IDs
|
| 260 |
+
- `aspect_ratios`: (ncw, nch) tile arrangement
|
| 261 |
+
- `bboxes`: Dict mapping crop tokens to normalized bbox coordinates
|
| 262 |
+
|
| 263 |
+
### Special Tokens
|
| 264 |
+
|
| 265 |
+
The model extends base tokenizer with:
|
| 266 |
+
- `<Prompt0>` through `<Prompt14>`: Region identifiers in text
|
| 267 |
+
- `<NO_Prompt>`: Background/non-region marker
|
| 268 |
+
- `<|reserved_special_token_{pid+2}|>`: Internal crop tokens for feature replay
|
| 269 |
+
|
| 270 |
+
## Important Notes
|
| 271 |
+
|
| 272 |
+
- **Flash Attention 2 is required** - training will fail without it
|
| 273 |
+
- **Python 3.11.2 specifically** - later versions may have compatibility issues
|
| 274 |
+
- **Single batch size only** - code asserts `batch_size=1` at grasp_any_region.py:270
|
| 275 |
+
- **Distributed training required** - single-GPU training not well supported
|
| 276 |
+
- **DeepSpeed Zero2** - default optimization for memory efficiency
|
| 277 |
+
- **torchrun vs torch.distributed.launch** - dist.sh tries torchrun first, falls back to launch
|
| 278 |
+
- **xTuner framework** - all training uses xTuner's runner, not native PyTorch
|
| 279 |
+
- **Evaluation randomness** - LLM judges have variance even with temperature=0
|
| 280 |
+
|
| 281 |
+
## HuggingFace Models
|
| 282 |
+
|
| 283 |
+
Pre-trained models available:
|
| 284 |
+
- `HaochenWang/GAR-1B` - 1 billion parameter model
|
| 285 |
+
- `HaochenWang/GAR-8B` - 8 billion parameter model
|
| 286 |
+
|
| 287 |
+
Base architecture:
|
| 288 |
+
- `facebook/Perception-LM-1B` - Base vision-language model
|
| 289 |
+
- `facebook/Perception-LM-8B` - Larger variant
|
| 290 |
+
|
| 291 |
+
## Citation
|
| 292 |
+
|
| 293 |
+
```bibtex
|
| 294 |
+
@article{wang2025grasp,
|
| 295 |
+
title={Grasp Any Region: Prompting MLLM to Understand the Dense World},
|
| 296 |
+
author={Haochen Wang and Yuhao Wang and Tao Zhang and Yikang Zhou and Yanwei Li and Jiacong Wang and Ye Tian and Jiahao Meng and Zilong Huang and Guangcan Mai and Anran Wang and Yunhai Tong and Zhuochen Wang and Xiangtai Li and Zhaoxiang Zhang},
|
| 297 |
+
journal={arXiv preprint arXiv:2510.18876},
|
| 298 |
+
year={2025}
|
| 299 |
+
}
|
| 300 |
+
```
|
| 301 |
+
|
| 302 |
+
## License
|
| 303 |
+
|
| 304 |
+
Apache-2.0 License
|
GRADIO_APP_SUMMARY.md
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Gradio App Summary for Grasp Any Region (GAR)
|
| 2 |
+
|
| 3 |
+
## ✅ Completion Status
|
| 4 |
+
|
| 5 |
+
Successfully created a comprehensive Gradio demo for the Grasp Any Region (GAR) project.
|
| 6 |
+
|
| 7 |
+
## 📁 Files Created/Modified
|
| 8 |
+
|
| 9 |
+
### 1. **app.py** (NEW)
|
| 10 |
+
- Complete Gradio interface with 3 tabs:
|
| 11 |
+
- **Points → Describe**: Interactive point-based segmentation with SAM
|
| 12 |
+
- **Box → Describe**: Bounding box-based segmentation
|
| 13 |
+
- **Mask → Describe**: Direct mask upload for region description
|
| 14 |
+
- Features:
|
| 15 |
+
- ZeroGPU integration with `@spaces.GPU` decorator
|
| 16 |
+
- Proper import order (spaces first, then CUDA packages)
|
| 17 |
+
- SAM (Segment Anything Model) integration for interactive segmentation
|
| 18 |
+
- GAR-1B model for detailed region descriptions
|
| 19 |
+
- Visualization with contours and input annotations
|
| 20 |
+
- Example images and clear instructions
|
| 21 |
+
- Error handling and status messages
|
| 22 |
+
|
| 23 |
+
### 2. **requirements.txt** (UPDATED)
|
| 24 |
+
- Gradio 5.49.1 (required version)
|
| 25 |
+
- httpx version fixed to >=0.24.1,<1.0 (Gradio compatibility)
|
| 26 |
+
- PyTorch 2.8.0 (pinned for FlashAttention compatibility)
|
| 27 |
+
- FlashAttention 2.8.3 prebuilt wheel (PyTorch 2.8, Python 3.10, CUDA 12, abiFALSE)
|
| 28 |
+
- spaces==0.30.4 for ZeroGPU
|
| 29 |
+
- All original dependencies preserved
|
| 30 |
+
- Segment Anything from GitHub
|
| 31 |
+
- Vision libraries (opencv-python, pillow, pycocotools)
|
| 32 |
+
- Transformers 4.56.2 and supporting ML libraries
|
| 33 |
+
|
| 34 |
+
## 🎯 Key Features
|
| 35 |
+
|
| 36 |
+
1. **Three Interaction Modes**:
|
| 37 |
+
- Points: Click or enter coordinates to segment regions
|
| 38 |
+
- Box: Draw or enter bounding boxes
|
| 39 |
+
- Mask: Upload pre-made masks directly
|
| 40 |
+
|
| 41 |
+
2. **Model Integration**:
|
| 42 |
+
- GAR-1B for region understanding (1 billion parameters)
|
| 43 |
+
- SAM ViT-Huge for automatic segmentation
|
| 44 |
+
- Both models loaded once at startup for efficiency
|
| 45 |
+
|
| 46 |
+
3. **ZeroGPU Optimization**:
|
| 47 |
+
- Proper `@spaces.GPU(duration=120)` decorator usage
|
| 48 |
+
- 2-minute GPU allocation per function call
|
| 49 |
+
- NVIDIA H200 with 70GB VRAM available
|
| 50 |
+
- Critical import order: `spaces` imported before torch
|
| 51 |
+
|
| 52 |
+
4. **User Experience**:
|
| 53 |
+
- Clear step-by-step instructions
|
| 54 |
+
- Example images included
|
| 55 |
+
- Real-time visualization with overlays
|
| 56 |
+
- Comprehensive error handling
|
| 57 |
+
- Professional UI with Gradio 5.x Soft theme
|
| 58 |
+
|
| 59 |
+
## 🔧 Technical Details
|
| 60 |
+
|
| 61 |
+
### Import Order (CRITICAL)
|
| 62 |
+
```python
|
| 63 |
+
# 🚨 spaces MUST be imported FIRST
|
| 64 |
+
import spaces
|
| 65 |
+
|
| 66 |
+
# Then import CUDA packages
|
| 67 |
+
import torch
|
| 68 |
+
from transformers import AutoModel, AutoProcessor
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
This prevents the "CUDA has been initialized" error.
|
| 72 |
+
|
| 73 |
+
### FlashAttention Configuration
|
| 74 |
+
- Using prebuilt wheel for PyTorch 2.8.0
|
| 75 |
+
- Python 3.10 (cp310)
|
| 76 |
+
- CUDA 12 (cu12)
|
| 77 |
+
- abiFALSE (REQUIRED - never use abiTRUE)
|
| 78 |
+
- URL: https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
| 79 |
+
|
| 80 |
+
### Model Loading Strategy
|
| 81 |
+
- Models loaded once at startup (outside decorated functions)
|
| 82 |
+
- Moved to CUDA device after loading
|
| 83 |
+
- GPU-decorated functions only handle inference
|
| 84 |
+
- Efficient memory usage
|
| 85 |
+
|
| 86 |
+
## 📋 Dependencies Highlights
|
| 87 |
+
|
| 88 |
+
**Core:**
|
| 89 |
+
- gradio==5.49.1
|
| 90 |
+
- torch==2.8.0
|
| 91 |
+
- spaces==0.30.4
|
| 92 |
+
- flash-attn (prebuilt wheel)
|
| 93 |
+
|
| 94 |
+
**AI/ML:**
|
| 95 |
+
- transformers==4.56.2
|
| 96 |
+
- accelerate>=0.28.0
|
| 97 |
+
- timm==1.0.19
|
| 98 |
+
- peft==0.15.2
|
| 99 |
+
|
| 100 |
+
**Vision:**
|
| 101 |
+
- opencv-python
|
| 102 |
+
- pillow>=9.4.0
|
| 103 |
+
- segment-anything (from GitHub)
|
| 104 |
+
- pycocotools
|
| 105 |
+
|
| 106 |
+
## 🎨 UI Structure
|
| 107 |
+
|
| 108 |
+
```
|
| 109 |
+
Grasp Any Region (GAR) Demo
|
| 110 |
+
├── Introduction & Links
|
| 111 |
+
├── Tab 1: Points → Describe
|
| 112 |
+
│ ├── Image upload + points input
|
| 113 |
+
│ ├── Generate Mask button
|
| 114 |
+
│ ├── Describe Region button
|
| 115 |
+
│ └── Outputs: mask, visualization, description
|
| 116 |
+
├── Tab 2: Box → Describe
|
| 117 |
+
│ ├── Image upload + box input
|
| 118 |
+
│ ├── Generate Mask button
|
| 119 |
+
│ ├── Describe Region button
|
| 120 |
+
│ └── Outputs: mask, visualization, description
|
| 121 |
+
├── Tab 3: Mask → Describe
|
| 122 |
+
│ ├── Image upload + mask upload
|
| 123 |
+
│ ├── Describe Region button
|
| 124 |
+
│ └── Outputs: visualization, description
|
| 125 |
+
└── Documentation & Citation
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
## 🚀 How to Run
|
| 129 |
+
|
| 130 |
+
```bash
|
| 131 |
+
# Install dependencies
|
| 132 |
+
pip install -r requirements.txt
|
| 133 |
+
|
| 134 |
+
# Run the app
|
| 135 |
+
python app.py
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
The app will automatically:
|
| 139 |
+
1. Load GAR-1B and SAM models
|
| 140 |
+
2. Launch Gradio interface
|
| 141 |
+
3. Allocate GPU on-demand with ZeroGPU
|
| 142 |
+
|
| 143 |
+
## 📊 Expected Performance
|
| 144 |
+
|
| 145 |
+
- **Model**: GAR-1B (lightweight, fast inference)
|
| 146 |
+
- **GPU**: NVIDIA H200, 70GB VRAM
|
| 147 |
+
- **Inference Time**: ~10-30 seconds per region (depending on complexity)
|
| 148 |
+
- **Max New Tokens**: 1024 (configurable)
|
| 149 |
+
|
| 150 |
+
## ⚠️ Important Notes
|
| 151 |
+
|
| 152 |
+
1. **Import Order**: Always import `spaces` before torch/CUDA packages
|
| 153 |
+
2. **Python Version**: Requires Python 3.10 (for FlashAttention wheel)
|
| 154 |
+
3. **FlashAttention**: Uses prebuilt wheel (no compilation needed)
|
| 155 |
+
4. **Asset Files**: Demo expects images in `assets/` directory
|
| 156 |
+
5. **SingleRegionCaptionDataset**: Required from evaluation module
|
| 157 |
+
|
| 158 |
+
## 🔗 References
|
| 159 |
+
|
| 160 |
+
- **Paper**: https://arxiv.org/abs/2510.18876
|
| 161 |
+
- **GitHub**: https://github.com/Haochen-Wang409/Grasp-Any-Region
|
| 162 |
+
- **Model**: https://huggingface.co/HaochenWang/GAR-1B
|
| 163 |
+
- **SAM**: https://github.com/facebookresearch/segment-anything
|
| 164 |
+
|
| 165 |
+
## 📝 Citation
|
| 166 |
+
|
| 167 |
+
```bibtex
|
| 168 |
+
@article{wang2025grasp,
|
| 169 |
+
title={Grasp Any Region: Prompting MLLM to Understand the Dense World},
|
| 170 |
+
author={Haochen Wang et al.},
|
| 171 |
+
journal={arXiv preprint arXiv:2510.18876},
|
| 172 |
+
year={2025}
|
| 173 |
+
}
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
---
|
| 177 |
+
|
| 178 |
+
**Created**: 2025-10-25
|
| 179 |
+
**Status**: ✅ Ready for deployment
|
| 180 |
+
**Hardware**: zerogpu (NVIDIA H200, 70GB VRAM)
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,12 +1,56 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.49.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: "Grasp-Any-Region"
|
| 3 |
+
emoji: 🤖
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.49.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
short_description: "Manual Entry: https://huggingface.co/papers/2510.18876"
|
| 11 |
+
hardware: zerogpu
|
| 12 |
+
tags:
|
| 13 |
+
- research
|
| 14 |
+
- paper
|
| 15 |
+
- code
|
| 16 |
+
- cheatcode
|
| 17 |
+
license: mit
|
| 18 |
---
|
| 19 |
|
| 20 |
+
# Grasp-Any-Region
|
| 21 |
+
|
| 22 |
+
**Automated upload by CheatCode** 🚀
|
| 23 |
+
|
| 24 |
+
## 📄 Paper Information
|
| 25 |
+
|
| 26 |
+
- **Paper ID**: 2510.18876
|
| 27 |
+
- **Title**: Manual Entry: https://huggingface.co/papers/2510.18876
|
| 28 |
+
- **Original Repository**: [https://github.com/Haochen-Wang409/Grasp-Any-Region](https://github.com/Haochen-Wang409/Grasp-Any-Region)
|
| 29 |
+
|
| 30 |
+
## 🛠️ Repository Information
|
| 31 |
+
|
| 32 |
+
- **Languages**: JavaScript, Python, Shell, TypeScript
|
| 33 |
+
- **Gradio App**: ✅ Generated by CheatCode
|
| 34 |
+
|
| 35 |
+
## 🤖 About CheatCode
|
| 36 |
+
|
| 37 |
+
This Space was automatically created by [CheatCode](https://github.com/jbilcke-hf/CheatCode),
|
| 38 |
+
an AI-powered tool that:
|
| 39 |
+
|
| 40 |
+
1. Discovers research papers from HuggingFace
|
| 41 |
+
2. Extracts and analyzes linked repositories
|
| 42 |
+
3. Generates Gradio demo applications
|
| 43 |
+
4. Uploads everything to HuggingFace Spaces
|
| 44 |
+
|
| 45 |
+
## 📝 Usage
|
| 46 |
+
|
| 47 |
+
This Space includes a Gradio app that was automatically generated from the repository code.
|
| 48 |
+
|
| 49 |
+
## ⚠️ Disclaimer
|
| 50 |
+
|
| 51 |
+
This is an automated upload. The code comes from the original repository and may require
|
| 52 |
+
additional configuration or dependencies to run properly.
|
| 53 |
+
|
| 54 |
+
## 📜 License
|
| 55 |
+
|
| 56 |
+
Please refer to the original repository for licensing information: https://github.com/Haochen-Wang409/Grasp-Any-Region
|
README_original.md
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Grasp Any Region - Region-Level Visual Understanding
|
| 3 |
+
emoji: 🎯
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.49.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
short_description: A multimodal model for precise region-level understanding and reasoning in images and videos
|
| 11 |
+
hardware: zerogpu
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# Grasp Any Region: Towards Precise, Contextual Pixel Understanding for Multimodal LLMs
|
| 15 |
+
|
| 16 |
+
by
|
| 17 |
+
[Haochen Wang](https://haochen-wang409.github.io),
|
| 18 |
+
Yuhao Wang,
|
| 19 |
+
[Tao Zhang](https://scholar.google.com/citations?user=3xu4a5oAAAAJ),
|
| 20 |
+
[Yikang Zhou](https://scholar.google.com/citations?user=dZikW2YAAAAJ),
|
| 21 |
+
[Yanwei Li](https://yanwei-li.com/),
|
| 22 |
+
[Jiacong Wang](https://scholar.google.com/citations?user=rzYgLkgAAAAJ),
|
| 23 |
+
[Ye Tian](https://scholar.google.com/citations?user=vUY_PIUAAAAJ),
|
| 24 |
+
[Jiahao Meng](https://scholar.google.com/citations?user=NJfjvfIAAAAJ),
|
| 25 |
+
[Zilong Huang](https://speedinghzl.github.io/),
|
| 26 |
+
[Guangcan Mai](https://scholar.google.com/citations?user=739cUNMAAAAJ),
|
| 27 |
+
[Anran Wang](https://sites.google.com/view/anranwang/home),
|
| 28 |
+
[Yunhai Tong](https://scholar.google.com/citations?user=T4gqdPkAAAAJ),
|
| 29 |
+
Zhuochen Wang,
|
| 30 |
+
[Xiangtai Li](https://lxtgh.github.io/), and
|
| 31 |
+
[Zhaoxiang Zhang](https://scholar.google.com/citations?user=qxWfV6cAAAAJ).
|
| 32 |
+
|
| 33 |
+
[[Paper](https://arxiv.org/abs/2510.18876)] | [[HuggingFace](https://huggingface.co/collections/HaochenWang/grasp-any-region-68f7433671030d6ea682f692)] | [[Citation](#citation)]
|
| 34 |
+
|
| 35 |
+
**TL; DR**: Our Grasp Any Region (GAR) supports both (1) describing a *single* region of an image or a video in the form of points/boxes/scribbles/masks in detail and (2) understanding *multiple* regions such as modeling interactions and performing complex reasoning. We also release a new benchmark, GARBench, to evaluate models on advanced region-level understanding tasks.
|
| 36 |
+
|
| 37 |
+

|
| 38 |
+
|
| 39 |
+
> **Abstract.** While Multimodal Large Language Models (MLLMs) excel at holistic understanding, they struggle
|
| 40 |
+
> in capturing the dense world with complex scenes, requiring fine-grained analysis of intricate
|
| 41 |
+
> details and object inter-relationships. Region-level MLLMs have been a promising step. However,
|
| 42 |
+
> previous attempts are generally optimized to understand given regions in isolation, neglecting
|
| 43 |
+
> crucial global contexts. To address this, we introduce Grasp Any Region (GAR) for comprehensive
|
| 44 |
+
> region-level visual understanding. Empowered by an effective RoI-aligned feature replay
|
| 45 |
+
> technique, GAR supports (1) precise perception by leveraging necessary global contexts, and (2)
|
| 46 |
+
> modeling interactions between multiple prompts. Together, it then naturally achieves (3) advanced
|
| 47 |
+
> compositional reasoning to answer specific free-form questions about any region, shifting the
|
| 48 |
+
> paradigm from passive description to active dialogue. Moreover, we construct GARBench, which
|
| 49 |
+
> not only provides a more accurate evaluation of single-region comprehension, but also, more
|
| 50 |
+
> importantly, measures interactions and complex reasoning across multiple regions. Extensive
|
| 51 |
+
> experiments have demonstrated that GAR-1B not only maintains the state-of-the-art captioning
|
| 52 |
+
> capabilities, e.g., outperforming DAM-3B +4.5 on DLC-Bench, but also excels at modeling rela-
|
| 53 |
+
> tionships between multiple prompts with advanced comprehension capabilities, even surpassing
|
| 54 |
+
> InternVL3-78B on GARBench-VQA. More importantly, our zero-shot GAR-8B even outperforms
|
| 55 |
+
> in-domain VideoRefer-7B on VideoRefer-BenchQ, indicating its strong capabilities can be easily
|
| 56 |
+
> transferred to videos.
|
| 57 |
+
|
| 58 |
+
# Installation
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
conda create -n gar python=3.11.2 -y
|
| 62 |
+
conda activate gar
|
| 63 |
+
|
| 64 |
+
pip3 install xtuner==0.2.0rc0
|
| 65 |
+
pip3 install -r requirements.txt
|
| 66 |
+
pip3 install flash-attn==2.7.4.post1 --no-build-isolation -v
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
# Demos
|
| 70 |
+
|
| 71 |
+
## Gradio Demo
|
| 72 |
+
|
| 73 |
+
Please refer to [`demo/gradio/README.md`](demo/gradio/README.md) for serving an online captioning demo using gradio.
|
| 74 |
+
|
| 75 |
+
## Examples
|
| 76 |
+
|
| 77 |
+
### Detailed Localized Image Descriptions with Masks
|
| 78 |
+
|
| 79 |
+
- [`demo/gar_with_mask.py`](demo/gar_with_mask.py) - Command-line tool for processing single images, allowing users to specify specify the region-of-interest using its segmentation mask.
|
| 80 |
+
|
| 81 |
+
<details>
|
| 82 |
+
<summary>Expand to see example commands</summary>
|
| 83 |
+
|
| 84 |
+
<img src="assets/1_output_visualization.png" width="400">
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
torchrun --nproc-per-node=1 --master-port=8119 demo/gar_with_mask.py --image_path assets/demo_image_1.png --mask_path assets/demo_mask_1.png
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
**Input instruction:** Describe the masked region in detail.
|
| 91 |
+
|
| 92 |
+
**Output answer:** A bright green, **frog-shaped slipper** with a smooth, rounded body and a wide, open mouth. The slipper has a small, raised bump on the top of its head, resembling a frog's eye.
|
| 93 |
+
|
| 94 |
+
</details>
|
| 95 |
+
|
| 96 |
+
### Detailed Localized Image Descriptions with SAM
|
| 97 |
+
|
| 98 |
+
- [`demo/gar_with_sam.py`](demo/gar_with_sam.py) - Command-line tool for processing single images using SAM v1, allowing users to specify points or bounding boxes for mask generation
|
| 99 |
+
|
| 100 |
+
<details>
|
| 101 |
+
<summary>Expand to see example commands</summary>
|
| 102 |
+
|
| 103 |
+
<img src="assets/2_output_visualization.png" width="400">
|
| 104 |
+
|
| 105 |
+
```bash
|
| 106 |
+
# You can use it with points or a bounding box for the region of interest.
|
| 107 |
+
# SAM is used to turn points or a bounding box into a mask.
|
| 108 |
+
# You can also use mask directly, see `demo/gar_with_mask.py`.
|
| 109 |
+
torchrun --nproc-per-node=1 --master-port=8119 demo/gar_with_sam.py --image_path assets/demo_image_2.jpg --points '[[1172, 812], [1572, 800]]' --output_image_path output_visualization.png
|
| 110 |
+
torchrun --nproc-per-node=1 --master-port=8119 demo/gar_with_sam.py --image_path assets/demo_image_2.jpg --box '[800, 500, 1800, 1000]' --use_box --output_image_path output_visualization.png
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
**Input instruction:** Describe the masked region in detail.
|
| 114 |
+
|
| 115 |
+
**Output answer:** A medium-sized, short-haired dog with a predominantly tan coat featuring white markings on its face, chest, and paws. The dog has a white stripe running down the center of its face, extending from the forehead to the nose. Its ears are large, pointed, and stand erect. The dog is wearing a red collar with a visible tag. Its mouth is open, revealing its tongue and teeth, and it appears to be in mid-leap with its front legs extended forward and hind legs stretched out behind.
|
| 116 |
+
|
| 117 |
+
</details>
|
| 118 |
+
|
| 119 |
+
### Modeling Complex Relationship between Multiple Regions
|
| 120 |
+
|
| 121 |
+
- [`demo/gar_relationship.py`](demo/gar_relationship.py) - Command-line tool for processing single images with multiple regions-of-interest, allowing users to specify specify the region-of-interest using its segmentation mask
|
| 122 |
+
|
| 123 |
+
<details>
|
| 124 |
+
<summary>Expand to see example commands</summary>
|
| 125 |
+
|
| 126 |
+
<img src="assets/3_output_visualization.png" width="400">
|
| 127 |
+
|
| 128 |
+
```bash
|
| 129 |
+
torchrun --nproc-per-node=1 --master-port=8119 demo/gar_relationship.py --image_path assets/demo_image_3.png --mask_paths "['assets/demo_mask_3_0.png', 'assets/demo_mask_3_1.png', 'assets/demo_mask_3_2.png']" --question_str 'Question: What is the relationship between <Prompt0>, <Prompt1>, and <Prompt2>?\nOptions:\nA. <Prompt0> is using <Prompt2> to point at <Prompt1>\nB. <Prompt0> has already hit <Prompt1> with <Prompt2>\nC. <Prompt0> is swinging <Prompt2> and is about to hit <Prompt1>\nD. <Prompt0> is holding <Prompt2> while looking away from <Prompt1>'
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
**Input instruction:**
|
| 133 |
+
|
| 134 |
+
```
|
| 135 |
+
Question: What is the relationship between <Prompt0>, <Prompt1>, and <Prompt2>?
|
| 136 |
+
Options:
|
| 137 |
+
A. <Prompt0> is using <Prompt2> to point at <Prompt1>
|
| 138 |
+
B. <Prompt0> has already hit <Prompt1> with <Prompt2>
|
| 139 |
+
C. <Prompt0> is swinging <Prompt2> and is about to hit <Prompt1>
|
| 140 |
+
D. <Prompt0> is holding <Prompt2> while looking away from <Prompt1>
|
| 141 |
+
Answer with the correct option's letter directly.
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
**Output answer:** C
|
| 145 |
+
|
| 146 |
+
Note that `<Prompt0>`, `<Prompt1>`, and `<Prompt2>` are illustrated in <span style="color:#C00000;">red</span>, <span style="color:#00B050;">green</span>, and <span style="color:#0000FF;">blue</span>, respectively.
|
| 147 |
+
|
| 148 |
+
</details>
|
| 149 |
+
|
| 150 |
+
# Training
|
| 151 |
+
|
| 152 |
+
**1. Dataset Preparation**
|
| 153 |
+
|
| 154 |
+
First, download the dataset:
|
| 155 |
+
|
| 156 |
+
`hf download HaochenWang/Grasp-Any-Region-Dataset --local-dir data --repo-type dataset`
|
| 157 |
+
|
| 158 |
+
The overall data structure should be:
|
| 159 |
+
```sh
|
| 160 |
+
data
|
| 161 |
+
├── Fine-Grained-Dataset
|
| 162 |
+
│ └── data-*-of-*.arrow
|
| 163 |
+
├── Relation-Dataset
|
| 164 |
+
│ └── data-*-of-*.arrow
|
| 165 |
+
└── Seed-Dataset
|
| 166 |
+
└── data-*-of-*.arrow
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
**2. Launch Training**
|
| 170 |
+
|
| 171 |
+
Next, run the following script to train using 8 GPUS:
|
| 172 |
+
|
| 173 |
+
`bash tools/dist.sh train projects/grasp_any_region/configs/gar_1b.py 8`
|
| 174 |
+
|
| 175 |
+
**3. Convert to HuggingFace Format**
|
| 176 |
+
|
| 177 |
+
```python3 projects/grasp_any_region/hf_models/convert_to_hf.py projects/grasp_any_region/configs/gar_1b.py --pth-model PATH_TO_PTH_MODEL --save-path PATH_TO_SAVE_FOLDER```
|
| 178 |
+
|
| 179 |
+
Note that this script only convert the checkpoint and some `*.py` files requires manually copy to `${PATH_TO_SAVE_FOLDER}`.
|
| 180 |
+
|
| 181 |
+
# Evaluation
|
| 182 |
+
|
| 183 |
+
Please refer to [`evaluation/EVALUATION.md`](evaluation/EVALUATION.md).
|
| 184 |
+
|
| 185 |
+
# License
|
| 186 |
+
|
| 187 |
+
This project is licensed under the [Apache-2.0 License](LICENSE).
|
| 188 |
+
|
| 189 |
+
# Citation
|
| 190 |
+
|
| 191 |
+
If you use our work or our implementation in this repo, or find them helpful, please consider giving a citation in the following format.
|
| 192 |
+
|
| 193 |
+
```
|
| 194 |
+
@article{wang2025grasp,
|
| 195 |
+
title={Grasp Any Region: Prompting MLLM to Understand the Dense World},
|
| 196 |
+
author={Haochen Wang and Yuhao Wang and Tao Zhang and Yikang Zhou and Yanwei Li and Jiacong Wang and Ye Tian and Jiahao Meng and Zilong Huang and Guangcan Mai and Anran Wang and Yunhai Tong and Zhuochen Wang and Xiangtai Li and Zhaoxiang Zhang},
|
| 197 |
+
journal={arXiv preprint arXiv:2510.18876},
|
| 198 |
+
year={2025}
|
| 199 |
+
}
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
# Acknowledgements
|
| 203 |
+
|
| 204 |
+
We would like to thank the following projects for their contributions to this work:
|
| 205 |
+
|
| 206 |
+
- [SAM](https://github.com/facebookresearch/segment-anything)
|
| 207 |
+
- [DAM](https://github.com/NVlabs/describe-anything)
|
| 208 |
+
- [Sa2VA](https://github.com/bytedance/Sa2VA)
|
app.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# *************************************************************************
|
| 2 |
+
# Grasp Any Region (GAR) - Gradio Demo
|
| 3 |
+
# Region-level Multimodal Understanding for Vision-Language Models
|
| 4 |
+
# *************************************************************************
|
| 5 |
+
|
| 6 |
+
# 🚨 CRITICAL: Import spaces FIRST before any CUDA-related packages
|
| 7 |
+
import spaces
|
| 8 |
+
|
| 9 |
+
# Now import CUDA-related packages
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import gradio as gr
|
| 14 |
+
from transformers import (
|
| 15 |
+
AutoModel,
|
| 16 |
+
AutoProcessor,
|
| 17 |
+
GenerationConfig,
|
| 18 |
+
SamModel,
|
| 19 |
+
SamProcessor,
|
| 20 |
+
)
|
| 21 |
+
import cv2
|
| 22 |
+
import sys
|
| 23 |
+
import os
|
| 24 |
+
|
| 25 |
+
# Add project root to path for imports
|
| 26 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
from evaluation.eval_dataset import SingleRegionCaptionDataset
|
| 30 |
+
except ImportError:
|
| 31 |
+
print("Warning: Could not import SingleRegionCaptionDataset. Using simplified version.")
|
| 32 |
+
SingleRegionCaptionDataset = None
|
| 33 |
+
|
| 34 |
+
# Initialize device
|
| 35 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 36 |
+
|
| 37 |
+
# Global model variables (loaded once)
|
| 38 |
+
gar_model = None
|
| 39 |
+
gar_processor = None
|
| 40 |
+
sam_model = None
|
| 41 |
+
sam_processor = None
|
| 42 |
+
|
| 43 |
+
def load_models():
|
| 44 |
+
"""Load models once at startup"""
|
| 45 |
+
global gar_model, gar_processor, sam_model, sam_processor
|
| 46 |
+
|
| 47 |
+
if gar_model is None:
|
| 48 |
+
print("Loading GAR model...")
|
| 49 |
+
model_path = "HaochenWang/GAR-1B"
|
| 50 |
+
gar_model = AutoModel.from_pretrained(
|
| 51 |
+
model_path,
|
| 52 |
+
trust_remote_code=True,
|
| 53 |
+
torch_dtype=torch.bfloat16,
|
| 54 |
+
device_map="auto",
|
| 55 |
+
).eval()
|
| 56 |
+
|
| 57 |
+
gar_processor = AutoProcessor.from_pretrained(
|
| 58 |
+
model_path,
|
| 59 |
+
trust_remote_code=True,
|
| 60 |
+
)
|
| 61 |
+
print("GAR model loaded successfully!")
|
| 62 |
+
|
| 63 |
+
if sam_model is None:
|
| 64 |
+
print("Loading SAM model...")
|
| 65 |
+
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
|
| 66 |
+
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
| 67 |
+
print("SAM model loaded successfully!")
|
| 68 |
+
|
| 69 |
+
@spaces.GPU(duration=120)
|
| 70 |
+
def generate_mask_from_points(image, points_str):
|
| 71 |
+
"""Generate mask using SAM from point coordinates"""
|
| 72 |
+
try:
|
| 73 |
+
load_models()
|
| 74 |
+
|
| 75 |
+
if not points_str or points_str.strip() == "":
|
| 76 |
+
return None, "Please provide points in format: x1,y1;x2,y2"
|
| 77 |
+
|
| 78 |
+
# Parse points
|
| 79 |
+
points = []
|
| 80 |
+
labels = []
|
| 81 |
+
for point in points_str.split(';'):
|
| 82 |
+
point = point.strip()
|
| 83 |
+
if point:
|
| 84 |
+
x, y = map(float, point.split(','))
|
| 85 |
+
points.append([x, y])
|
| 86 |
+
labels.append(1) # Foreground point
|
| 87 |
+
|
| 88 |
+
if not points:
|
| 89 |
+
return None, "No valid points provided"
|
| 90 |
+
|
| 91 |
+
# Apply SAM
|
| 92 |
+
inputs = sam_processor(
|
| 93 |
+
image,
|
| 94 |
+
input_points=[points],
|
| 95 |
+
input_labels=[labels],
|
| 96 |
+
return_tensors="pt",
|
| 97 |
+
).to(device)
|
| 98 |
+
|
| 99 |
+
with torch.no_grad():
|
| 100 |
+
outputs = sam_model(**inputs)
|
| 101 |
+
|
| 102 |
+
masks = sam_processor.image_processor.post_process_masks(
|
| 103 |
+
outputs.pred_masks.cpu(),
|
| 104 |
+
inputs["original_sizes"].cpu(),
|
| 105 |
+
inputs["reshaped_input_sizes"].cpu(),
|
| 106 |
+
)[0][0]
|
| 107 |
+
|
| 108 |
+
scores = outputs.iou_scores[0, 0]
|
| 109 |
+
mask_selection_index = scores.argmax()
|
| 110 |
+
mask_np = masks[mask_selection_index].numpy()
|
| 111 |
+
|
| 112 |
+
# Visualize mask
|
| 113 |
+
mask_img = (mask_np * 255).astype(np.uint8)
|
| 114 |
+
|
| 115 |
+
return Image.fromarray(mask_img), "Mask generated successfully!"
|
| 116 |
+
|
| 117 |
+
except Exception as e:
|
| 118 |
+
return None, f"Error generating mask: {str(e)}"
|
| 119 |
+
|
| 120 |
+
@spaces.GPU(duration=120)
|
| 121 |
+
def generate_mask_from_box(image, box_str):
|
| 122 |
+
"""Generate mask using SAM from bounding box"""
|
| 123 |
+
try:
|
| 124 |
+
load_models()
|
| 125 |
+
|
| 126 |
+
if not box_str or box_str.strip() == "":
|
| 127 |
+
return None, "Please provide box in format: x1,y1,x2,y2"
|
| 128 |
+
|
| 129 |
+
# Parse box
|
| 130 |
+
box = list(map(float, box_str.split(',')))
|
| 131 |
+
if len(box) != 4:
|
| 132 |
+
return None, "Box must have 4 coordinates: x1,y1,x2,y2"
|
| 133 |
+
|
| 134 |
+
# Apply SAM
|
| 135 |
+
inputs = sam_processor(
|
| 136 |
+
image,
|
| 137 |
+
input_boxes=[[box]],
|
| 138 |
+
return_tensors="pt",
|
| 139 |
+
).to(device)
|
| 140 |
+
|
| 141 |
+
with torch.no_grad():
|
| 142 |
+
outputs = sam_model(**inputs)
|
| 143 |
+
|
| 144 |
+
masks = sam_processor.image_processor.post_process_masks(
|
| 145 |
+
outputs.pred_masks.cpu(),
|
| 146 |
+
inputs["original_sizes"].cpu(),
|
| 147 |
+
inputs["reshaped_input_sizes"].cpu(),
|
| 148 |
+
)[0][0]
|
| 149 |
+
|
| 150 |
+
scores = outputs.iou_scores[0, 0]
|
| 151 |
+
mask_selection_index = scores.argmax()
|
| 152 |
+
mask_np = masks[mask_selection_index].numpy()
|
| 153 |
+
|
| 154 |
+
# Visualize mask
|
| 155 |
+
mask_img = (mask_np * 255).astype(np.uint8)
|
| 156 |
+
|
| 157 |
+
return Image.fromarray(mask_img), "Mask generated successfully!"
|
| 158 |
+
|
| 159 |
+
except Exception as e:
|
| 160 |
+
return None, f"Error generating mask: {str(e)}"
|
| 161 |
+
|
| 162 |
+
@spaces.GPU(duration=120)
|
| 163 |
+
def describe_region(image, mask):
|
| 164 |
+
"""Generate description for a region defined by a mask"""
|
| 165 |
+
try:
|
| 166 |
+
load_models()
|
| 167 |
+
|
| 168 |
+
if image is None:
|
| 169 |
+
return "Please provide an image"
|
| 170 |
+
|
| 171 |
+
if mask is None:
|
| 172 |
+
return "Please provide a mask (upload or generate using SAM)"
|
| 173 |
+
|
| 174 |
+
# Convert mask to numpy
|
| 175 |
+
if isinstance(mask, Image.Image):
|
| 176 |
+
mask_np = np.array(mask.convert("L"))
|
| 177 |
+
else:
|
| 178 |
+
mask_np = np.array(mask)
|
| 179 |
+
|
| 180 |
+
# Ensure mask is binary
|
| 181 |
+
mask_np = (mask_np > 127).astype(np.uint8)
|
| 182 |
+
|
| 183 |
+
# Prepare data
|
| 184 |
+
prompt_number = gar_model.config.prompt_numbers
|
| 185 |
+
prompt_tokens = [f"<Prompt{i_p}>" for i_p in range(prompt_number)] + ["<NO_Prompt>"]
|
| 186 |
+
|
| 187 |
+
if SingleRegionCaptionDataset is not None:
|
| 188 |
+
dataset = SingleRegionCaptionDataset(
|
| 189 |
+
image=image,
|
| 190 |
+
mask=mask_np,
|
| 191 |
+
processor=gar_processor,
|
| 192 |
+
prompt_number=prompt_number,
|
| 193 |
+
visual_prompt_tokens=prompt_tokens,
|
| 194 |
+
data_dtype=torch.bfloat16,
|
| 195 |
+
)
|
| 196 |
+
data_sample = dataset[0]
|
| 197 |
+
else:
|
| 198 |
+
# Simplified processing if dataset class not available
|
| 199 |
+
# This is a fallback - the actual implementation requires SingleRegionCaptionDataset
|
| 200 |
+
return "Error: SingleRegionCaptionDataset not available. Please check installation."
|
| 201 |
+
|
| 202 |
+
# Generate description
|
| 203 |
+
with torch.no_grad():
|
| 204 |
+
generate_ids = gar_model.generate(
|
| 205 |
+
**data_sample,
|
| 206 |
+
generation_config=GenerationConfig(
|
| 207 |
+
max_new_tokens=1024,
|
| 208 |
+
do_sample=False,
|
| 209 |
+
eos_token_id=gar_processor.tokenizer.eos_token_id,
|
| 210 |
+
pad_token_id=gar_processor.tokenizer.pad_token_id,
|
| 211 |
+
),
|
| 212 |
+
return_dict=True,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
output_caption = gar_processor.tokenizer.decode(
|
| 216 |
+
generate_ids.sequences[0], skip_special_tokens=True
|
| 217 |
+
).strip()
|
| 218 |
+
|
| 219 |
+
return output_caption
|
| 220 |
+
|
| 221 |
+
except Exception as e:
|
| 222 |
+
return f"Error generating description: {str(e)}"
|
| 223 |
+
|
| 224 |
+
def create_visualization(image, mask, points_str=None, box_str=None):
|
| 225 |
+
"""Create visualization with mask overlay"""
|
| 226 |
+
try:
|
| 227 |
+
if image is None or mask is None:
|
| 228 |
+
return None
|
| 229 |
+
|
| 230 |
+
img_np = np.array(image).astype(float) / 255.0
|
| 231 |
+
if isinstance(mask, Image.Image):
|
| 232 |
+
mask_np = np.array(mask.convert("L")) > 127
|
| 233 |
+
else:
|
| 234 |
+
mask_np = np.array(mask) > 127
|
| 235 |
+
|
| 236 |
+
# Draw contour
|
| 237 |
+
mask_uint8 = mask_np.astype(np.uint8) * 255
|
| 238 |
+
contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 239 |
+
img_vis = img_np.copy()
|
| 240 |
+
cv2.drawContours(img_vis, contours, -1, (1.0, 1.0, 0.0), thickness=3)
|
| 241 |
+
|
| 242 |
+
# Draw points if provided
|
| 243 |
+
if points_str:
|
| 244 |
+
for point in points_str.split(';'):
|
| 245 |
+
point = point.strip()
|
| 246 |
+
if point:
|
| 247 |
+
x, y = map(float, point.split(','))
|
| 248 |
+
cv2.circle(img_vis, (int(x), int(y)), radius=8, color=(1.0, 0.0, 0.0), thickness=-1)
|
| 249 |
+
cv2.circle(img_vis, (int(x), int(y)), radius=8, color=(1.0, 1.0, 1.0), thickness=2)
|
| 250 |
+
|
| 251 |
+
# Draw box if provided
|
| 252 |
+
if box_str:
|
| 253 |
+
coords = list(map(float, box_str.split(',')))
|
| 254 |
+
if len(coords) == 4:
|
| 255 |
+
x1, y1, x2, y2 = map(int, coords)
|
| 256 |
+
cv2.rectangle(img_vis, (x1, y1), (x2, y2), color=(1.0, 1.0, 1.0), thickness=3)
|
| 257 |
+
cv2.rectangle(img_vis, (x1, y1), (x2, y2), color=(1.0, 0.0, 0.0), thickness=1)
|
| 258 |
+
|
| 259 |
+
img_pil = Image.fromarray((img_vis * 255.0).astype(np.uint8))
|
| 260 |
+
return img_pil
|
| 261 |
+
|
| 262 |
+
except Exception as e:
|
| 263 |
+
print(f"Error creating visualization: {str(e)}")
|
| 264 |
+
return None
|
| 265 |
+
|
| 266 |
+
# Create Gradio interface
|
| 267 |
+
with gr.Blocks(title="Grasp Any Region (GAR) Demo", theme=gr.themes.Soft()) as demo:
|
| 268 |
+
gr.Markdown("""
|
| 269 |
+
# 🎯 Grasp Any Region (GAR)
|
| 270 |
+
|
| 271 |
+
**Region-level Multimodal Understanding for Vision-Language Models**
|
| 272 |
+
|
| 273 |
+
This demo showcases GAR's ability to understand and describe specific regions in images:
|
| 274 |
+
- 🎨 **Single Region Understanding**: Describe specific areas using points, boxes, or masks
|
| 275 |
+
- 🔍 **SAM Integration**: Generate masks interactively using Segment Anything Model
|
| 276 |
+
- 💡 **Detailed Descriptions**: Get comprehensive descriptions of any region
|
| 277 |
+
|
| 278 |
+
Built on top of Perception-LM with RoI-aligned feature replay technique.
|
| 279 |
+
|
| 280 |
+
📄 [Paper](https://arxiv.org/abs/2510.18876) | 💻 [GitHub](https://github.com/Haochen-Wang409/Grasp-Any-Region) | 🤗 [Model](https://huggingface.co/HaochenWang/GAR-1B)
|
| 281 |
+
""")
|
| 282 |
+
|
| 283 |
+
with gr.Tabs():
|
| 284 |
+
# Tab 1: Points-based segmentation
|
| 285 |
+
with gr.Tab("🎯 Points → Describe"):
|
| 286 |
+
gr.Markdown("### Click points on the image or enter coordinates to segment and describe a region")
|
| 287 |
+
with gr.Row():
|
| 288 |
+
with gr.Column():
|
| 289 |
+
img_points = gr.Image(label="Input Image", type="pil")
|
| 290 |
+
points_input = gr.Textbox(
|
| 291 |
+
label="Points (format: x1,y1;x2,y2;...)",
|
| 292 |
+
placeholder="e.g., 1172,812;1572,800",
|
| 293 |
+
value="1172,812;1572,800"
|
| 294 |
+
)
|
| 295 |
+
with gr.Row():
|
| 296 |
+
gen_mask_points_btn = gr.Button("Generate Mask", variant="primary")
|
| 297 |
+
describe_points_btn = gr.Button("Describe Region", variant="secondary")
|
| 298 |
+
|
| 299 |
+
with gr.Column():
|
| 300 |
+
mask_points = gr.Image(label="Generated Mask", type="pil")
|
| 301 |
+
vis_points = gr.Image(label="Visualization")
|
| 302 |
+
desc_points = gr.Textbox(label="Region Description", lines=5)
|
| 303 |
+
|
| 304 |
+
points_status = gr.Textbox(label="Status", visible=False)
|
| 305 |
+
|
| 306 |
+
gen_mask_points_btn.click(
|
| 307 |
+
fn=generate_mask_from_points,
|
| 308 |
+
inputs=[img_points, points_input],
|
| 309 |
+
outputs=[mask_points, points_status]
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
describe_points_btn.click(
|
| 313 |
+
fn=describe_region,
|
| 314 |
+
inputs=[img_points, mask_points],
|
| 315 |
+
outputs=desc_points
|
| 316 |
+
).then(
|
| 317 |
+
fn=create_visualization,
|
| 318 |
+
inputs=[img_points, mask_points, points_input, gr.Textbox(visible=False)],
|
| 319 |
+
outputs=vis_points
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
gr.Examples(
|
| 323 |
+
examples=[
|
| 324 |
+
["assets/demo_image_2.jpg", "1172,812;1572,800"],
|
| 325 |
+
],
|
| 326 |
+
inputs=[img_points, points_input],
|
| 327 |
+
label="Example Images"
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
# Tab 2: Box-based segmentation
|
| 331 |
+
with gr.Tab("📦 Box → Describe"):
|
| 332 |
+
gr.Markdown("### Draw a bounding box or enter coordinates to segment and describe a region")
|
| 333 |
+
with gr.Row():
|
| 334 |
+
with gr.Column():
|
| 335 |
+
img_box = gr.Image(label="Input Image", type="pil")
|
| 336 |
+
box_input = gr.Textbox(
|
| 337 |
+
label="Bounding Box (format: x1,y1,x2,y2)",
|
| 338 |
+
placeholder="e.g., 800,500,1800,1000",
|
| 339 |
+
value="800,500,1800,1000"
|
| 340 |
+
)
|
| 341 |
+
with gr.Row():
|
| 342 |
+
gen_mask_box_btn = gr.Button("Generate Mask", variant="primary")
|
| 343 |
+
describe_box_btn = gr.Button("Describe Region", variant="secondary")
|
| 344 |
+
|
| 345 |
+
with gr.Column():
|
| 346 |
+
mask_box = gr.Image(label="Generated Mask", type="pil")
|
| 347 |
+
vis_box = gr.Image(label="Visualization")
|
| 348 |
+
desc_box = gr.Textbox(label="Region Description", lines=5)
|
| 349 |
+
|
| 350 |
+
box_status = gr.Textbox(label="Status", visible=False)
|
| 351 |
+
|
| 352 |
+
gen_mask_box_btn.click(
|
| 353 |
+
fn=generate_mask_from_box,
|
| 354 |
+
inputs=[img_box, box_input],
|
| 355 |
+
outputs=[mask_box, box_status]
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
describe_box_btn.click(
|
| 359 |
+
fn=describe_region,
|
| 360 |
+
inputs=[img_box, mask_box],
|
| 361 |
+
outputs=desc_box
|
| 362 |
+
).then(
|
| 363 |
+
fn=create_visualization,
|
| 364 |
+
inputs=[img_box, mask_box, gr.Textbox(visible=False), box_input],
|
| 365 |
+
outputs=vis_box
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
gr.Examples(
|
| 369 |
+
examples=[
|
| 370 |
+
["assets/demo_image_2.jpg", "800,500,1800,1000"],
|
| 371 |
+
],
|
| 372 |
+
inputs=[img_box, box_input],
|
| 373 |
+
label="Example Images"
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
# Tab 3: Direct mask upload
|
| 377 |
+
with gr.Tab("🎭 Mask → Describe"):
|
| 378 |
+
gr.Markdown("### Upload a pre-made mask to describe a region")
|
| 379 |
+
with gr.Row():
|
| 380 |
+
with gr.Column():
|
| 381 |
+
img_mask = gr.Image(label="Input Image", type="pil")
|
| 382 |
+
mask_upload = gr.Image(label="Upload Mask", type="pil")
|
| 383 |
+
describe_mask_btn = gr.Button("Describe Region", variant="primary")
|
| 384 |
+
|
| 385 |
+
with gr.Column():
|
| 386 |
+
vis_mask = gr.Image(label="Visualization")
|
| 387 |
+
desc_mask = gr.Textbox(label="Region Description", lines=5)
|
| 388 |
+
|
| 389 |
+
describe_mask_btn.click(
|
| 390 |
+
fn=describe_region,
|
| 391 |
+
inputs=[img_mask, mask_upload],
|
| 392 |
+
outputs=desc_mask
|
| 393 |
+
).then(
|
| 394 |
+
fn=create_visualization,
|
| 395 |
+
inputs=[img_mask, mask_upload, gr.Textbox(visible=False), gr.Textbox(visible=False)],
|
| 396 |
+
outputs=vis_mask
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
gr.Examples(
|
| 400 |
+
examples=[
|
| 401 |
+
["assets/demo_image_1.png", "assets/demo_mask_1.png"],
|
| 402 |
+
],
|
| 403 |
+
inputs=[img_mask, mask_upload],
|
| 404 |
+
label="Example Images"
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
gr.Markdown("""
|
| 408 |
+
---
|
| 409 |
+
### 📖 How to Use:
|
| 410 |
+
|
| 411 |
+
1. **Points → Describe**: Click or enter point coordinates, generate mask, then describe
|
| 412 |
+
2. **Box → Describe**: Draw or enter a bounding box, generate mask, then describe
|
| 413 |
+
3. **Mask → Describe**: Upload a pre-made mask directly and describe
|
| 414 |
+
|
| 415 |
+
### 🔧 Technical Details:
|
| 416 |
+
|
| 417 |
+
- **Model**: GAR-1B (1 billion parameters)
|
| 418 |
+
- **Base**: Facebook Perception-LM with RoI-aligned feature replay
|
| 419 |
+
- **Segmentation**: Segment Anything Model (SAM ViT-Huge)
|
| 420 |
+
- **Hardware**: Powered by ZeroGPU (NVIDIA H200, 70GB VRAM)
|
| 421 |
+
|
| 422 |
+
### 📚 Citation:
|
| 423 |
+
|
| 424 |
+
```bibtex
|
| 425 |
+
@article{wang2025grasp,
|
| 426 |
+
title={Grasp Any Region: Prompting MLLM to Understand the Dense World},
|
| 427 |
+
author={Haochen Wang et al.},
|
| 428 |
+
journal={arXiv preprint arXiv:2510.18876},
|
| 429 |
+
year={2025}
|
| 430 |
+
}
|
| 431 |
+
```
|
| 432 |
+
""")
|
| 433 |
+
|
| 434 |
+
# Load models on startup
|
| 435 |
+
try:
|
| 436 |
+
load_models()
|
| 437 |
+
except Exception as e:
|
| 438 |
+
print(f"Warning: Could not pre-load models: {e}")
|
| 439 |
+
print("Models will be loaded on first use.")
|
| 440 |
+
|
| 441 |
+
if __name__ == "__main__":
|
| 442 |
+
demo.launch()
|
demo/gar_relationship.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License")
|
| 4 |
+
# Grasp Any Region Project
|
| 5 |
+
# Written by Haochen Wang
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import ast
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from transformers import AutoModel, AutoProcessor, GenerationConfig
|
| 15 |
+
|
| 16 |
+
from evaluation.eval_dataset import MultiRegionDataset
|
| 17 |
+
|
| 18 |
+
TORCH_DTYPE_MAP = dict(fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def parse_args():
|
| 22 |
+
parser = argparse.ArgumentParser(
|
| 23 |
+
description="Inference of Grasp Any Region models on DLC-Bench."
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--model_name_or_path",
|
| 28 |
+
help="HF model name or path",
|
| 29 |
+
default="HaochenWang/GAR-8B",
|
| 30 |
+
)
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--image_path",
|
| 33 |
+
help="image path",
|
| 34 |
+
required=True,
|
| 35 |
+
)
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--mask_paths",
|
| 38 |
+
help="mask path",
|
| 39 |
+
required=True,
|
| 40 |
+
)
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--question_str",
|
| 43 |
+
help="input instructions",
|
| 44 |
+
required=True,
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--data_type",
|
| 48 |
+
help="data dtype",
|
| 49 |
+
type=str,
|
| 50 |
+
choices=["fp16", "bf16", "fp32"],
|
| 51 |
+
default="bf16",
|
| 52 |
+
)
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"--seed",
|
| 55 |
+
type=int,
|
| 56 |
+
default=0,
|
| 57 |
+
help="Random seed for reproducible text generation",
|
| 58 |
+
)
|
| 59 |
+
args = parser.parse_args()
|
| 60 |
+
return args
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def select_ann(coco, img_id, area_min=None, area_max=None):
|
| 64 |
+
cat_ids = coco.getCatIds()
|
| 65 |
+
ann_ids = coco.getAnnIds(imgIds=[img_id], catIds=cat_ids, iscrowd=None)
|
| 66 |
+
|
| 67 |
+
if area_min is not None:
|
| 68 |
+
ann_ids = [
|
| 69 |
+
ann_id for ann_id in ann_ids if coco.anns[ann_id]["area"] >= area_min
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
if area_max is not None:
|
| 73 |
+
ann_ids = [
|
| 74 |
+
ann_id for ann_id in ann_ids if coco.anns[ann_id]["area"] <= area_max
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
return ann_ids
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def main():
|
| 81 |
+
args = parse_args()
|
| 82 |
+
data_dtype = TORCH_DTYPE_MAP[args.data_type]
|
| 83 |
+
torch.manual_seed(args.seed)
|
| 84 |
+
|
| 85 |
+
# init ditribution for dispatch_modules in LLM
|
| 86 |
+
torch.cuda.set_device(0)
|
| 87 |
+
torch.distributed.init_process_group(backend="nccl")
|
| 88 |
+
|
| 89 |
+
# build HF model
|
| 90 |
+
model = AutoModel.from_pretrained(
|
| 91 |
+
args.model_name_or_path,
|
| 92 |
+
trust_remote_code=True,
|
| 93 |
+
torch_dtype=data_dtype,
|
| 94 |
+
device_map="cuda:0",
|
| 95 |
+
).eval()
|
| 96 |
+
|
| 97 |
+
processor = AutoProcessor.from_pretrained(
|
| 98 |
+
args.model_name_or_path,
|
| 99 |
+
trust_remote_code=True,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
img = Image.open(args.image_path)
|
| 103 |
+
masks = []
|
| 104 |
+
for mask_path in ast.literal_eval(args.mask_paths):
|
| 105 |
+
mask = np.array(Image.open(mask_path).convert("L")).astype(bool)
|
| 106 |
+
masks.append(mask)
|
| 107 |
+
|
| 108 |
+
prompt_number = model.config.prompt_numbers
|
| 109 |
+
prompt_tokens = [f"<Prompt{i_p}>" for i_p in range(prompt_number)] + ["<NO_Prompt>"]
|
| 110 |
+
dataset = MultiRegionDataset(
|
| 111 |
+
image=img,
|
| 112 |
+
masks=masks,
|
| 113 |
+
question_str=args.question_str
|
| 114 |
+
+ "\nAnswer with the correct option's letter directly.",
|
| 115 |
+
processor=processor,
|
| 116 |
+
prompt_number=prompt_number,
|
| 117 |
+
visual_prompt_tokens=prompt_tokens,
|
| 118 |
+
data_dtype=data_dtype,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
data_sample = dataset[0]
|
| 122 |
+
|
| 123 |
+
with torch.no_grad():
|
| 124 |
+
generate_ids = model.generate(
|
| 125 |
+
**data_sample,
|
| 126 |
+
generation_config=GenerationConfig(
|
| 127 |
+
max_new_tokens=1024,
|
| 128 |
+
do_sample=False,
|
| 129 |
+
eos_token_id=processor.tokenizer.eos_token_id,
|
| 130 |
+
pad_token_id=processor.tokenizer.pad_token_id,
|
| 131 |
+
),
|
| 132 |
+
return_dict=True,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
outputs = processor.tokenizer.decode(
|
| 136 |
+
generate_ids.sequences[0], skip_special_tokens=True
|
| 137 |
+
).strip()
|
| 138 |
+
|
| 139 |
+
print(outputs) # Print model output for this image
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
if __name__ == "__main__":
|
| 143 |
+
main()
|
demo/gar_with_mask.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License")
|
| 4 |
+
# Grasp Any Region Project
|
| 5 |
+
# Written by Haochen Wang
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from transformers import AutoModel, AutoProcessor, GenerationConfig
|
| 14 |
+
|
| 15 |
+
from evaluation.eval_dataset import SingleRegionCaptionDataset
|
| 16 |
+
|
| 17 |
+
TORCH_DTYPE_MAP = dict(fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def parse_args():
|
| 21 |
+
parser = argparse.ArgumentParser(
|
| 22 |
+
description="Inference demo of Grasp Any Region models."
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--model_name_or_path",
|
| 27 |
+
help="HF model name or path",
|
| 28 |
+
default="HaochenWang/GAR-8B",
|
| 29 |
+
)
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"--image_path",
|
| 32 |
+
help="image path",
|
| 33 |
+
required=True,
|
| 34 |
+
)
|
| 35 |
+
parser.add_argument(
|
| 36 |
+
"--mask_path",
|
| 37 |
+
help="mask path",
|
| 38 |
+
required=True,
|
| 39 |
+
)
|
| 40 |
+
parser.add_argument(
|
| 41 |
+
"--data_type",
|
| 42 |
+
help="data dtype",
|
| 43 |
+
type=str,
|
| 44 |
+
choices=["fp16", "bf16", "fp32"],
|
| 45 |
+
default="bf16",
|
| 46 |
+
)
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
"--seed",
|
| 49 |
+
type=int,
|
| 50 |
+
default=0,
|
| 51 |
+
help="Random seed for reproducible text generation",
|
| 52 |
+
)
|
| 53 |
+
args = parser.parse_args()
|
| 54 |
+
return args
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def select_ann(coco, img_id, area_min=None, area_max=None):
|
| 58 |
+
cat_ids = coco.getCatIds()
|
| 59 |
+
ann_ids = coco.getAnnIds(imgIds=[img_id], catIds=cat_ids, iscrowd=None)
|
| 60 |
+
|
| 61 |
+
if area_min is not None:
|
| 62 |
+
ann_ids = [
|
| 63 |
+
ann_id for ann_id in ann_ids if coco.anns[ann_id]["area"] >= area_min
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
if area_max is not None:
|
| 67 |
+
ann_ids = [
|
| 68 |
+
ann_id for ann_id in ann_ids if coco.anns[ann_id]["area"] <= area_max
|
| 69 |
+
]
|
| 70 |
+
|
| 71 |
+
return ann_ids
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def main():
|
| 75 |
+
args = parse_args()
|
| 76 |
+
data_dtype = TORCH_DTYPE_MAP[args.data_type]
|
| 77 |
+
torch.manual_seed(args.seed)
|
| 78 |
+
|
| 79 |
+
# init ditribution for dispatch_modules in LLM
|
| 80 |
+
torch.cuda.set_device(0)
|
| 81 |
+
torch.distributed.init_process_group(backend="nccl")
|
| 82 |
+
|
| 83 |
+
# build HF model
|
| 84 |
+
model = AutoModel.from_pretrained(
|
| 85 |
+
args.model_name_or_path,
|
| 86 |
+
trust_remote_code=True,
|
| 87 |
+
torch_dtype=data_dtype,
|
| 88 |
+
device_map="cuda:0",
|
| 89 |
+
).eval()
|
| 90 |
+
|
| 91 |
+
processor = AutoProcessor.from_pretrained(
|
| 92 |
+
args.model_name_or_path,
|
| 93 |
+
trust_remote_code=True,
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
img = Image.open(args.image_path)
|
| 97 |
+
mask = np.array(Image.open(args.mask_path).convert("L")).astype(bool)
|
| 98 |
+
|
| 99 |
+
prompt_number = model.config.prompt_numbers
|
| 100 |
+
prompt_tokens = [f"<Prompt{i_p}>" for i_p in range(prompt_number)] + ["<NO_Prompt>"]
|
| 101 |
+
dataset = SingleRegionCaptionDataset(
|
| 102 |
+
image=img,
|
| 103 |
+
mask=mask,
|
| 104 |
+
processor=processor,
|
| 105 |
+
prompt_number=prompt_number,
|
| 106 |
+
visual_prompt_tokens=prompt_tokens,
|
| 107 |
+
data_dtype=data_dtype,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
data_sample = dataset[0]
|
| 111 |
+
|
| 112 |
+
with torch.no_grad():
|
| 113 |
+
generate_ids = model.generate(
|
| 114 |
+
**data_sample,
|
| 115 |
+
generation_config=GenerationConfig(
|
| 116 |
+
max_new_tokens=1024,
|
| 117 |
+
do_sample=False,
|
| 118 |
+
eos_token_id=processor.tokenizer.eos_token_id,
|
| 119 |
+
pad_token_id=processor.tokenizer.pad_token_id,
|
| 120 |
+
),
|
| 121 |
+
return_dict=True,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
outputs = processor.tokenizer.decode(
|
| 125 |
+
generate_ids.sequences[0], skip_special_tokens=True
|
| 126 |
+
).strip()
|
| 127 |
+
|
| 128 |
+
print(outputs) # Print model output for this image
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
if __name__ == "__main__":
|
| 132 |
+
main()
|
demo/gar_with_sam.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# *************************************************************************
|
| 2 |
+
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
|
| 3 |
+
# difications”). All Bytedance Inc.'s Modifications are Copyright (2025) B-
|
| 4 |
+
# ytedance Inc..
|
| 5 |
+
# *************************************************************************
|
| 6 |
+
|
| 7 |
+
# Adapted from https://github.com/NVlabs/describe-anything/blob/main/examples/dam_with_sam.py
|
| 8 |
+
|
| 9 |
+
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 10 |
+
#
|
| 11 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 12 |
+
# you may not use this file except in compliance with the License.
|
| 13 |
+
# You may obtain a copy of the License at
|
| 14 |
+
#
|
| 15 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 16 |
+
#
|
| 17 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 18 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 19 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 20 |
+
# See the License for the specific language governing permissions and
|
| 21 |
+
# limitations under the License.
|
| 22 |
+
#
|
| 23 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import ast
|
| 27 |
+
|
| 28 |
+
import cv2
|
| 29 |
+
import numpy as np
|
| 30 |
+
import torch
|
| 31 |
+
from PIL import Image
|
| 32 |
+
from transformers import (
|
| 33 |
+
AutoModel,
|
| 34 |
+
AutoProcessor,
|
| 35 |
+
GenerationConfig,
|
| 36 |
+
SamModel,
|
| 37 |
+
SamProcessor,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
from evaluation.eval_dataset import SingleRegionCaptionDataset
|
| 41 |
+
|
| 42 |
+
TORCH_DTYPE_MAP = dict(fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def apply_sam(image, input_points=None, input_boxes=None, input_labels=None):
|
| 46 |
+
inputs = sam_processor(
|
| 47 |
+
image,
|
| 48 |
+
input_points=input_points,
|
| 49 |
+
input_boxes=input_boxes,
|
| 50 |
+
input_labels=input_labels,
|
| 51 |
+
return_tensors="pt",
|
| 52 |
+
).to(device)
|
| 53 |
+
|
| 54 |
+
with torch.no_grad():
|
| 55 |
+
outputs = sam_model(**inputs)
|
| 56 |
+
|
| 57 |
+
masks = sam_processor.image_processor.post_process_masks(
|
| 58 |
+
outputs.pred_masks.cpu(),
|
| 59 |
+
inputs["original_sizes"].cpu(),
|
| 60 |
+
inputs["reshaped_input_sizes"].cpu(),
|
| 61 |
+
)[0][0]
|
| 62 |
+
scores = outputs.iou_scores[0, 0]
|
| 63 |
+
|
| 64 |
+
mask_selection_index = scores.argmax()
|
| 65 |
+
|
| 66 |
+
mask_np = masks[mask_selection_index].numpy()
|
| 67 |
+
|
| 68 |
+
return mask_np
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def add_contour(img, mask, input_points=None, input_boxes=None):
|
| 72 |
+
img = img.copy()
|
| 73 |
+
|
| 74 |
+
# Draw contour
|
| 75 |
+
mask = mask.astype(np.uint8) * 255
|
| 76 |
+
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 77 |
+
cv2.drawContours(img, contours, -1, (1.0, 1.0, 1.0), thickness=6)
|
| 78 |
+
|
| 79 |
+
# Draw points if provided
|
| 80 |
+
if input_points is not None:
|
| 81 |
+
for points in input_points: # Handle batch of points
|
| 82 |
+
for x, y in points:
|
| 83 |
+
# Draw a filled circle for each point
|
| 84 |
+
cv2.circle(
|
| 85 |
+
img,
|
| 86 |
+
(int(x), int(y)),
|
| 87 |
+
radius=10,
|
| 88 |
+
color=(1.0, 0.0, 0.0),
|
| 89 |
+
thickness=-1,
|
| 90 |
+
)
|
| 91 |
+
# Draw a white border around the circle
|
| 92 |
+
cv2.circle(
|
| 93 |
+
img, (int(x), int(y)), radius=10, color=(1.0, 1.0, 1.0), thickness=2
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# Draw boxes if provided
|
| 97 |
+
if input_boxes is not None:
|
| 98 |
+
for box_batch in input_boxes: # Handle batch of boxes
|
| 99 |
+
for box in box_batch: # Iterate through boxes in the batch
|
| 100 |
+
x1, y1, x2, y2 = map(int, box)
|
| 101 |
+
# Draw rectangle with white color
|
| 102 |
+
cv2.rectangle(
|
| 103 |
+
img, (x1, y1), (x2, y2), color=(1.0, 1.0, 1.0), thickness=4
|
| 104 |
+
)
|
| 105 |
+
# Draw inner rectangle with red color
|
| 106 |
+
cv2.rectangle(
|
| 107 |
+
img, (x1, y1), (x2, y2), color=(1.0, 0.0, 0.0), thickness=2
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
return img
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def denormalize_coordinates(coords, image_size, is_box=False):
|
| 114 |
+
"""Convert normalized coordinates (0-1) to pixel coordinates."""
|
| 115 |
+
width, height = image_size
|
| 116 |
+
if is_box:
|
| 117 |
+
# For boxes: [x1, y1, x2, y2]
|
| 118 |
+
x1, y1, x2, y2 = coords
|
| 119 |
+
return [int(x1 * width), int(y1 * height), int(x2 * width), int(y2 * height)]
|
| 120 |
+
else:
|
| 121 |
+
# For points: [x, y]
|
| 122 |
+
x, y = coords
|
| 123 |
+
return [int(x * width), int(y * height)]
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def print_streaming(text):
|
| 127 |
+
"""Helper function to print streaming text with flush"""
|
| 128 |
+
print(text, end="", flush=True)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
if __name__ == "__main__":
|
| 132 |
+
parser = argparse.ArgumentParser(
|
| 133 |
+
description="Detailed Localized Image Descriptions with SAM"
|
| 134 |
+
)
|
| 135 |
+
parser.add_argument(
|
| 136 |
+
"--model_name_or_path",
|
| 137 |
+
help="HF model name or path",
|
| 138 |
+
default="HaochenWang/GAR-8B",
|
| 139 |
+
)
|
| 140 |
+
parser.add_argument(
|
| 141 |
+
"--image_path", type=str, required=True, help="Path to the image file"
|
| 142 |
+
)
|
| 143 |
+
parser.add_argument(
|
| 144 |
+
"--points",
|
| 145 |
+
type=str,
|
| 146 |
+
default="[[1172, 812], [1572, 800]]",
|
| 147 |
+
help="List of points for SAM input",
|
| 148 |
+
)
|
| 149 |
+
parser.add_argument(
|
| 150 |
+
"--box",
|
| 151 |
+
type=str,
|
| 152 |
+
default="[773, 518, 1172, 812]",
|
| 153 |
+
help="Bounding box for SAM input (x1, y1, x2, y2)",
|
| 154 |
+
)
|
| 155 |
+
parser.add_argument(
|
| 156 |
+
"--use_box",
|
| 157 |
+
action="store_true",
|
| 158 |
+
help="Use box instead of points for SAM input (default: use points)",
|
| 159 |
+
)
|
| 160 |
+
parser.add_argument(
|
| 161 |
+
"--normalized_coords",
|
| 162 |
+
action="store_true",
|
| 163 |
+
help="Interpret coordinates as normalized (0-1) values",
|
| 164 |
+
)
|
| 165 |
+
parser.add_argument(
|
| 166 |
+
"--output_image_path",
|
| 167 |
+
type=str,
|
| 168 |
+
default=None,
|
| 169 |
+
help="Path to save the output image with contour",
|
| 170 |
+
)
|
| 171 |
+
parser.add_argument(
|
| 172 |
+
"--data_type",
|
| 173 |
+
help="data dtype",
|
| 174 |
+
type=str,
|
| 175 |
+
choices=["fp16", "bf16", "fp32"],
|
| 176 |
+
default="bf16",
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
args = parser.parse_args()
|
| 180 |
+
data_dtype = TORCH_DTYPE_MAP[args.data_type]
|
| 181 |
+
|
| 182 |
+
# Load the image
|
| 183 |
+
img = Image.open(args.image_path).convert("RGB")
|
| 184 |
+
|
| 185 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 186 |
+
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
|
| 187 |
+
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
| 188 |
+
|
| 189 |
+
image_size = img.size # (width, height)
|
| 190 |
+
|
| 191 |
+
# Prepare input_points or input_boxes
|
| 192 |
+
if args.use_box:
|
| 193 |
+
input_boxes = ast.literal_eval(args.box)
|
| 194 |
+
if args.normalized_coords:
|
| 195 |
+
input_boxes = denormalize_coordinates(input_boxes, image_size, is_box=True)
|
| 196 |
+
input_boxes = [[input_boxes]] # Add an extra level of nesting
|
| 197 |
+
print(f"Using input_boxes: {input_boxes}")
|
| 198 |
+
mask_np = apply_sam(img, input_boxes=input_boxes)
|
| 199 |
+
else:
|
| 200 |
+
input_points = ast.literal_eval(args.points)
|
| 201 |
+
if args.normalized_coords:
|
| 202 |
+
input_points = [
|
| 203 |
+
denormalize_coordinates(point, image_size) for point in input_points
|
| 204 |
+
]
|
| 205 |
+
# Assume all points are foreground
|
| 206 |
+
input_labels = [1] * len(input_points)
|
| 207 |
+
input_points = [[x, y] for x, y in input_points] # Convert to list of lists
|
| 208 |
+
input_points = [input_points] # Wrap in outer list
|
| 209 |
+
input_labels = [input_labels] # Wrap labels in list
|
| 210 |
+
print(f"Using input_points: {input_points}")
|
| 211 |
+
mask_np = apply_sam(img, input_points=input_points, input_labels=input_labels)
|
| 212 |
+
|
| 213 |
+
# build HF model
|
| 214 |
+
model = AutoModel.from_pretrained(
|
| 215 |
+
args.model_name_or_path,
|
| 216 |
+
trust_remote_code=True,
|
| 217 |
+
torch_dtype=data_dtype,
|
| 218 |
+
device_map="cuda:0",
|
| 219 |
+
).eval()
|
| 220 |
+
|
| 221 |
+
processor = AutoProcessor.from_pretrained(
|
| 222 |
+
args.model_name_or_path,
|
| 223 |
+
trust_remote_code=True,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# Get description
|
| 227 |
+
prompt_number = model.config.prompt_numbers
|
| 228 |
+
prompt_tokens = [f"<Prompt{i_p}>" for i_p in range(prompt_number)] + ["<NO_Prompt>"]
|
| 229 |
+
dataset = SingleRegionCaptionDataset(
|
| 230 |
+
image=img,
|
| 231 |
+
mask=mask_np,
|
| 232 |
+
processor=processor,
|
| 233 |
+
prompt_number=prompt_number,
|
| 234 |
+
visual_prompt_tokens=prompt_tokens,
|
| 235 |
+
data_dtype=data_dtype,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
data_sample = dataset[0]
|
| 239 |
+
|
| 240 |
+
with torch.no_grad():
|
| 241 |
+
generate_ids = model.generate(
|
| 242 |
+
**data_sample,
|
| 243 |
+
generation_config=GenerationConfig(
|
| 244 |
+
max_new_tokens=1024,
|
| 245 |
+
do_sample=False,
|
| 246 |
+
eos_token_id=processor.tokenizer.eos_token_id,
|
| 247 |
+
pad_token_id=processor.tokenizer.pad_token_id,
|
| 248 |
+
),
|
| 249 |
+
return_dict=True,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
outputs = processor.tokenizer.decode(
|
| 253 |
+
generate_ids.sequences[0], skip_special_tokens=True
|
| 254 |
+
).strip()
|
| 255 |
+
|
| 256 |
+
print(outputs) # Print model output for this image
|
| 257 |
+
|
| 258 |
+
if args.output_image_path:
|
| 259 |
+
img_np = np.asarray(img).astype(float) / 255.0
|
| 260 |
+
|
| 261 |
+
# Prepare visualization inputs
|
| 262 |
+
vis_points = input_points if not args.use_box else None
|
| 263 |
+
vis_boxes = input_boxes if args.use_box else None
|
| 264 |
+
|
| 265 |
+
img_with_contour_np = add_contour(
|
| 266 |
+
img_np, mask_np, input_points=vis_points, input_boxes=vis_boxes
|
| 267 |
+
)
|
| 268 |
+
img_with_contour_pil = Image.fromarray(
|
| 269 |
+
(img_with_contour_np * 255.0).astype(np.uint8)
|
| 270 |
+
)
|
| 271 |
+
img_with_contour_pil.save(args.output_image_path)
|
| 272 |
+
print(f"Output image with contour saved as {args.output_image_path}")
|
demo/gradio/.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
demo/gradio/README.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Please install segment-anything package through:
|
| 2 |
+
```
|
| 3 |
+
pip install git+https://github.com/facebookresearch/segment-anything.git
|
| 4 |
+
```
|
| 5 |
+
|
| 6 |
+
This demo is based on the Segment Anything demo under Apache 2.0 license. Please refer to the [Segment Anything LICENSE](https://github.com/facebookresearch/segment-anything/blob/main/LICENSE) for more details.
|
| 7 |
+
|
| 8 |
+
## Run the demo
|
| 9 |
+
```
|
| 10 |
+
python demo/gradio/app.py
|
| 11 |
+
```
|
demo/gradio/app.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# *************************************************************************
|
| 2 |
+
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
|
| 3 |
+
# difications”). All Bytedance Inc.'s Modifications are Copyright (2025) B-
|
| 4 |
+
# ytedance Inc..
|
| 5 |
+
# *************************************************************************
|
| 6 |
+
|
| 7 |
+
# Adapted from https://github.com/NVlabs/describe-anything/blob/main/examples/dam_with_sam.py
|
| 8 |
+
|
| 9 |
+
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 10 |
+
#
|
| 11 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 12 |
+
# you may not use this file except in compliance with the License.
|
| 13 |
+
# You may obtain a copy of the License at
|
| 14 |
+
#
|
| 15 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 16 |
+
#
|
| 17 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 18 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 19 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 20 |
+
# See the License for the specific language governing permissions and
|
| 21 |
+
# limitations under the License.
|
| 22 |
+
#
|
| 23 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import base64
|
| 27 |
+
import io
|
| 28 |
+
|
| 29 |
+
import cv2
|
| 30 |
+
import gradio as gr
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
from fastapi import FastAPI
|
| 34 |
+
from fastapi.staticfiles import StaticFiles
|
| 35 |
+
from PIL import Image
|
| 36 |
+
from segment_anything import SamPredictor, sam_model_registry
|
| 37 |
+
from transformers import (
|
| 38 |
+
AutoModel,
|
| 39 |
+
AutoProcessor,
|
| 40 |
+
GenerationConfig,
|
| 41 |
+
SamModel,
|
| 42 |
+
SamProcessor,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
from spaces import GPU
|
| 47 |
+
except ImportError:
|
| 48 |
+
print("Spaces not installed, using dummy GPU decorator")
|
| 49 |
+
|
| 50 |
+
def GPU(*args, **kwargs):
|
| 51 |
+
def decorator(fn):
|
| 52 |
+
return fn
|
| 53 |
+
|
| 54 |
+
return decorator
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
from evaluation.eval_dataset import SingleRegionCaptionDataset
|
| 58 |
+
|
| 59 |
+
# Load SAM model
|
| 60 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 61 |
+
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
|
| 62 |
+
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
| 63 |
+
|
| 64 |
+
# Initialize the captioning model and processor
|
| 65 |
+
model_path = "HaochenWang/GAR-1B"
|
| 66 |
+
model = AutoModel.from_pretrained(
|
| 67 |
+
model_path,
|
| 68 |
+
trust_remote_code=True,
|
| 69 |
+
torch_dtype=torch.bfloat16,
|
| 70 |
+
device_map="cuda:0",
|
| 71 |
+
).eval()
|
| 72 |
+
|
| 73 |
+
processor = AutoProcessor.from_pretrained(
|
| 74 |
+
model_path,
|
| 75 |
+
trust_remote_code=True,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@GPU(duration=75)
|
| 80 |
+
def image_to_sam_embedding(base64_image):
|
| 81 |
+
try:
|
| 82 |
+
# Decode base64 string to bytes
|
| 83 |
+
image_bytes = base64.b64decode(base64_image)
|
| 84 |
+
|
| 85 |
+
# Convert bytes to PIL Image
|
| 86 |
+
image = Image.open(io.BytesIO(image_bytes))
|
| 87 |
+
|
| 88 |
+
# Process image with SAM processor
|
| 89 |
+
inputs = sam_processor(image, return_tensors="pt").to(device)
|
| 90 |
+
|
| 91 |
+
# Get image embedding
|
| 92 |
+
with torch.no_grad():
|
| 93 |
+
image_embedding = sam_model.get_image_embeddings(inputs["pixel_values"])
|
| 94 |
+
|
| 95 |
+
# Convert to CPU and numpy
|
| 96 |
+
image_embedding = image_embedding.cpu().numpy()
|
| 97 |
+
|
| 98 |
+
# Encode the embedding as base64
|
| 99 |
+
embedding_bytes = image_embedding.tobytes()
|
| 100 |
+
embedding_base64 = base64.b64encode(embedding_bytes).decode("utf-8")
|
| 101 |
+
|
| 102 |
+
return embedding_base64
|
| 103 |
+
except Exception as e:
|
| 104 |
+
print(f"Error processing image: {str(e)}")
|
| 105 |
+
raise gr.Error(f"Failed to process image: {str(e)}")
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@GPU(duration=75)
|
| 109 |
+
def describe(image_base64: str, mask_base64: str, query: str):
|
| 110 |
+
# Convert base64 to PIL Image
|
| 111 |
+
image_bytes = base64.b64decode(
|
| 112 |
+
image_base64.split(",")[1] if "," in image_base64 else image_base64
|
| 113 |
+
)
|
| 114 |
+
img = Image.open(io.BytesIO(image_bytes))
|
| 115 |
+
mask_bytes = base64.b64decode(
|
| 116 |
+
mask_base64.split(",")[1] if "," in mask_base64 else mask_base64
|
| 117 |
+
)
|
| 118 |
+
mask = Image.open(io.BytesIO(mask_bytes))
|
| 119 |
+
mask = np.array(mask.convert("L"))
|
| 120 |
+
|
| 121 |
+
prompt_number = model.config.prompt_numbers
|
| 122 |
+
prompt_tokens = [f"<Prompt{i_p}>" for i_p in range(prompt_number)] + ["<NO_Prompt>"]
|
| 123 |
+
|
| 124 |
+
# Assuming mask is given as a numpy array and the image is a PIL image
|
| 125 |
+
dataset = SingleRegionCaptionDataset(
|
| 126 |
+
image=img,
|
| 127 |
+
mask=mask,
|
| 128 |
+
processor=processor,
|
| 129 |
+
prompt_number=prompt_number,
|
| 130 |
+
visual_prompt_tokens=prompt_tokens,
|
| 131 |
+
data_dtype=torch.bfloat16,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
data_sample = dataset[0]
|
| 135 |
+
|
| 136 |
+
# Generate the caption
|
| 137 |
+
with torch.no_grad():
|
| 138 |
+
generate_ids = model.generate(
|
| 139 |
+
**data_sample,
|
| 140 |
+
generation_config=GenerationConfig(
|
| 141 |
+
max_new_tokens=1024,
|
| 142 |
+
eos_token_id=processor.tokenizer.eos_token_id,
|
| 143 |
+
pad_token_id=processor.tokenizer.pad_token_id,
|
| 144 |
+
),
|
| 145 |
+
return_dict=True,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
output_caption = processor.tokenizer.decode(
|
| 149 |
+
generate_ids.sequences[0], skip_special_tokens=True
|
| 150 |
+
).strip()
|
| 151 |
+
|
| 152 |
+
# Stream the tokens
|
| 153 |
+
text = ""
|
| 154 |
+
for token in output_caption:
|
| 155 |
+
text += token
|
| 156 |
+
yield text
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
@GPU(duration=75)
|
| 160 |
+
def describe_without_streaming(image_base64: str, mask_base64: str, query: str):
|
| 161 |
+
# Convert base64 to PIL Image
|
| 162 |
+
image_bytes = base64.b64decode(
|
| 163 |
+
image_base64.split(",")[1] if "," in image_base64 else image_base64
|
| 164 |
+
)
|
| 165 |
+
img = Image.open(io.BytesIO(image_bytes))
|
| 166 |
+
mask_bytes = base64.b64decode(
|
| 167 |
+
mask_base64.split(",")[1] if "," in mask_base64 else mask_base64
|
| 168 |
+
)
|
| 169 |
+
mask = Image.open(io.BytesIO(mask_bytes))
|
| 170 |
+
mask = np.array(mask.convert("L"))
|
| 171 |
+
prompt_number = model.config.prompt_numbers
|
| 172 |
+
prompt_tokens = [f"<Prompt{i_p}>" for i_p in range(prompt_number)] + ["<NO_Prompt>"]
|
| 173 |
+
|
| 174 |
+
# Assuming mask is given as a numpy array and the image is a PIL image
|
| 175 |
+
dataset = SingleRegionCaptionDataset(
|
| 176 |
+
image=img,
|
| 177 |
+
mask=mask,
|
| 178 |
+
processor=processor,
|
| 179 |
+
prompt_number=prompt_number,
|
| 180 |
+
visual_prompt_tokens=prompt_tokens,
|
| 181 |
+
data_dtype=torch.bfloat16,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
data_sample = dataset[0]
|
| 185 |
+
|
| 186 |
+
# Generate the caption
|
| 187 |
+
with torch.no_grad():
|
| 188 |
+
generate_ids = model.generate(
|
| 189 |
+
**data_sample,
|
| 190 |
+
generation_config=GenerationConfig(
|
| 191 |
+
max_new_tokens=1024,
|
| 192 |
+
# do_sample=False,
|
| 193 |
+
eos_token_id=processor.tokenizer.eos_token_id,
|
| 194 |
+
pad_token_id=processor.tokenizer.pad_token_id,
|
| 195 |
+
),
|
| 196 |
+
return_dict=True,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
output_caption = processor.tokenizer.decode(
|
| 200 |
+
generate_ids.sequences[0], skip_special_tokens=True
|
| 201 |
+
).strip()
|
| 202 |
+
|
| 203 |
+
return output_caption
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
if __name__ == "__main__":
|
| 207 |
+
parser = argparse.ArgumentParser(description="Describe Anything gradio demo")
|
| 208 |
+
parser.add_argument(
|
| 209 |
+
"--server_addr",
|
| 210 |
+
"--host",
|
| 211 |
+
type=str,
|
| 212 |
+
default=None,
|
| 213 |
+
help="The server address to listen on.",
|
| 214 |
+
)
|
| 215 |
+
parser.add_argument(
|
| 216 |
+
"--server_port", "--port", type=int, default=None, help="The port to listen on."
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
args = parser.parse_args()
|
| 220 |
+
|
| 221 |
+
# Create Gradio interface
|
| 222 |
+
with gr.Blocks() as demo:
|
| 223 |
+
gr.Interface(
|
| 224 |
+
fn=image_to_sam_embedding,
|
| 225 |
+
inputs=gr.Textbox(label="Image Base64"),
|
| 226 |
+
outputs=gr.Textbox(label="Embedding Base64"),
|
| 227 |
+
title="Image Embedding Generator",
|
| 228 |
+
api_name="image_to_sam_embedding",
|
| 229 |
+
)
|
| 230 |
+
gr.Interface(
|
| 231 |
+
fn=describe,
|
| 232 |
+
inputs=[
|
| 233 |
+
gr.Textbox(label="Image Base64"),
|
| 234 |
+
gr.Text(label="Mask Base64"),
|
| 235 |
+
gr.Text(label="Prompt"),
|
| 236 |
+
],
|
| 237 |
+
outputs=[gr.Text(label="Description")],
|
| 238 |
+
title="Mask Description Generator",
|
| 239 |
+
api_name="describe",
|
| 240 |
+
)
|
| 241 |
+
gr.Interface(
|
| 242 |
+
fn=describe_without_streaming,
|
| 243 |
+
inputs=[
|
| 244 |
+
gr.Textbox(label="Image Base64"),
|
| 245 |
+
gr.Text(label="Mask Base64"),
|
| 246 |
+
gr.Text(label="Prompt"),
|
| 247 |
+
],
|
| 248 |
+
outputs=[gr.Text(label="Description")],
|
| 249 |
+
title="Mask Description Generator (Non-Streaming)",
|
| 250 |
+
api_name="describe_without_streaming",
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
demo._block_thread = demo.block_thread
|
| 254 |
+
demo.block_thread = lambda: None
|
| 255 |
+
demo.launch(
|
| 256 |
+
share=True,
|
| 257 |
+
server_name=args.server_addr,
|
| 258 |
+
server_port=args.server_port,
|
| 259 |
+
ssr_mode=False,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
for route in demo.app.routes:
|
| 263 |
+
if route.path == "/":
|
| 264 |
+
demo.app.routes.remove(route)
|
| 265 |
+
demo.app.mount("/", StaticFiles(directory="dist", html=True), name="demo")
|
| 266 |
+
|
| 267 |
+
demo._block_thread()
|
demo/gradio/frontend/README.md
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Segment Anything Simple Web demo
|
| 2 |
+
|
| 3 |
+
This **front-end only** React based web demo shows how to load a fixed image and corresponding `.npy` file of the SAM image embedding, and run the SAM ONNX model in the browser using Web Assembly with mulithreading enabled by `SharedArrayBuffer`, Web Worker, and SIMD128.
|
| 4 |
+
|
| 5 |
+
<img src="https://github.com/facebookresearch/segment-anything/raw/main/assets/minidemo.gif" width="500"/>
|
| 6 |
+
|
| 7 |
+
## Run the app
|
| 8 |
+
|
| 9 |
+
Install Yarn
|
| 10 |
+
|
| 11 |
+
```
|
| 12 |
+
npm install --g yarn
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
Build and run:
|
| 16 |
+
|
| 17 |
+
```
|
| 18 |
+
yarn && yarn start
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
Navigate to [`http://localhost:8081/`](http://localhost:8081/)
|
| 22 |
+
|
| 23 |
+
Move your cursor around to see the mask prediction update in real time.
|
| 24 |
+
|
| 25 |
+
## Export the image embedding
|
| 26 |
+
|
| 27 |
+
In the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb) upload the image of your choice and generate and save corresponding embedding.
|
| 28 |
+
|
| 29 |
+
Initialize the predictor:
|
| 30 |
+
|
| 31 |
+
```python
|
| 32 |
+
checkpoint = "sam_vit_h_4b8939.pth"
|
| 33 |
+
model_type = "vit_h"
|
| 34 |
+
sam = sam_model_registry[model_type](checkpoint=checkpoint)
|
| 35 |
+
sam.to(device='cuda')
|
| 36 |
+
predictor = SamPredictor(sam)
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
Set the new image and export the embedding:
|
| 40 |
+
|
| 41 |
+
```
|
| 42 |
+
image = cv2.imread('src/assets/dogs.jpg')
|
| 43 |
+
predictor.set_image(image)
|
| 44 |
+
image_embedding = predictor.get_image_embedding().cpu().numpy()
|
| 45 |
+
np.save("dogs_embedding.npy", image_embedding)
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
Save the new image and embedding in `src/assets/data`.
|
| 49 |
+
|
| 50 |
+
## Export the ONNX model
|
| 51 |
+
|
| 52 |
+
You also need to export the quantized ONNX model from the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb).
|
| 53 |
+
|
| 54 |
+
Run the cell in the notebook which saves the `sam_onnx_quantized_example.onnx` file, download it and copy it to the path `/model/sam_onnx_quantized_example.onnx`.
|
| 55 |
+
|
| 56 |
+
Here is a snippet of the export/quantization code:
|
| 57 |
+
|
| 58 |
+
```
|
| 59 |
+
onnx_model_path = "sam_onnx_example.onnx"
|
| 60 |
+
onnx_model_quantized_path = "sam_onnx_quantized_example.onnx"
|
| 61 |
+
quantize_dynamic(
|
| 62 |
+
model_input=onnx_model_path,
|
| 63 |
+
model_output=onnx_model_quantized_path,
|
| 64 |
+
optimize_model=True,
|
| 65 |
+
per_channel=False,
|
| 66 |
+
reduce_range=False,
|
| 67 |
+
weight_type=QuantType.QUInt8,
|
| 68 |
+
)
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
**NOTE: if you change the ONNX model by using a new checkpoint you need to also re-export the embedding.**
|
| 72 |
+
|
| 73 |
+
## Update the image, embedding, model in the app
|
| 74 |
+
|
| 75 |
+
Update the following file paths at the top of`App.tsx`:
|
| 76 |
+
|
| 77 |
+
```py
|
| 78 |
+
const IMAGE_PATH = "/assets/data/dogs.jpg";
|
| 79 |
+
const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy";
|
| 80 |
+
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
## ONNX multithreading with SharedArrayBuffer
|
| 84 |
+
|
| 85 |
+
To use multithreading, the appropriate headers need to be set to create a cross origin isolation state which will enable use of `SharedArrayBuffer` (see this [blog post](https://cloudblogs.microsoft.com/opensource/2021/09/02/onnx-runtime-web-running-your-machine-learning-model-in-browser/) for more details)
|
| 86 |
+
|
| 87 |
+
The headers below are set in `configs/webpack/dev.js`:
|
| 88 |
+
|
| 89 |
+
```js
|
| 90 |
+
headers: {
|
| 91 |
+
"Cross-Origin-Opener-Policy": "same-origin",
|
| 92 |
+
"Cross-Origin-Embedder-Policy": "credentialless",
|
| 93 |
+
}
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
## Structure of the app
|
| 97 |
+
|
| 98 |
+
**`App.tsx`**
|
| 99 |
+
|
| 100 |
+
- Initializes ONNX model
|
| 101 |
+
- Loads image embedding and image
|
| 102 |
+
- Runs the ONNX model based on input prompts
|
| 103 |
+
|
| 104 |
+
**`Stage.tsx`**
|
| 105 |
+
|
| 106 |
+
- Handles mouse move interaction to update the ONNX model prompt
|
| 107 |
+
|
| 108 |
+
**`Tool.tsx`**
|
| 109 |
+
|
| 110 |
+
- Renders the image and the mask prediction
|
| 111 |
+
|
| 112 |
+
**`helpers/maskUtils.tsx`**
|
| 113 |
+
|
| 114 |
+
- Conversion of ONNX model output from array to an HTMLImageElement
|
| 115 |
+
|
| 116 |
+
**`helpers/onnxModelAPI.tsx`**
|
| 117 |
+
|
| 118 |
+
- Formats the inputs for the ONNX model
|
| 119 |
+
|
| 120 |
+
**`helpers/scaleHelper.tsx`**
|
| 121 |
+
|
| 122 |
+
- Handles image scaling logic for SAM (longest size 1024)
|
| 123 |
+
|
| 124 |
+
**`hooks/`**
|
| 125 |
+
|
| 126 |
+
- Handle shared state for the app
|
demo/gradio/frontend/configs/webpack/common.js
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
const { resolve } = require("path");
|
| 8 |
+
const HtmlWebpackPlugin = require("html-webpack-plugin");
|
| 9 |
+
const FriendlyErrorsWebpackPlugin = require("friendly-errors-webpack-plugin");
|
| 10 |
+
const CopyPlugin = require("copy-webpack-plugin");
|
| 11 |
+
const webpack = require("webpack");
|
| 12 |
+
|
| 13 |
+
module.exports = {
|
| 14 |
+
entry: "./src/index.tsx",
|
| 15 |
+
resolve: {
|
| 16 |
+
extensions: [".js", ".jsx", ".ts", ".tsx"],
|
| 17 |
+
fallback: { 'process/browser': require.resolve('process/browser'), }
|
| 18 |
+
},
|
| 19 |
+
output: {
|
| 20 |
+
path: resolve(__dirname, "dist"),
|
| 21 |
+
},
|
| 22 |
+
module: {
|
| 23 |
+
rules: [
|
| 24 |
+
{
|
| 25 |
+
test: /\.mjs$/,
|
| 26 |
+
include: /node_modules/,
|
| 27 |
+
type: "javascript/auto",
|
| 28 |
+
resolve: {
|
| 29 |
+
fullySpecified: false,
|
| 30 |
+
},
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
test: [/\.jsx?$/, /\.tsx?$/],
|
| 34 |
+
use: ["ts-loader"],
|
| 35 |
+
exclude: /node_modules/,
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
test: /\.css$/,
|
| 39 |
+
use: ["style-loader", "css-loader"],
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
test: /\.(scss|sass)$/,
|
| 43 |
+
use: ["style-loader", "css-loader", "postcss-loader"],
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
test: /\.(jpe?g|png|gif|svg)$/i,
|
| 47 |
+
use: [
|
| 48 |
+
"file-loader?hash=sha512&digest=hex&name=img/[contenthash].[ext]",
|
| 49 |
+
"image-webpack-loader?bypassOnDebug&optipng.optimizationLevel=7&gifsicle.interlaced=false",
|
| 50 |
+
],
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
test: /\.(woff|woff2|ttf)$/,
|
| 54 |
+
use: {
|
| 55 |
+
loader: "url-loader",
|
| 56 |
+
},
|
| 57 |
+
},
|
| 58 |
+
],
|
| 59 |
+
},
|
| 60 |
+
plugins: [
|
| 61 |
+
new CopyPlugin({
|
| 62 |
+
patterns: [
|
| 63 |
+
{
|
| 64 |
+
from: "node_modules/onnxruntime-web/dist/*.wasm",
|
| 65 |
+
to: "[name][ext]",
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
from: "model",
|
| 69 |
+
to: "model",
|
| 70 |
+
},
|
| 71 |
+
{
|
| 72 |
+
from: "src/assets/examples",
|
| 73 |
+
to: "examples",
|
| 74 |
+
},
|
| 75 |
+
],
|
| 76 |
+
}),
|
| 77 |
+
new HtmlWebpackPlugin({
|
| 78 |
+
template: "./src/assets/index.html",
|
| 79 |
+
}),
|
| 80 |
+
new FriendlyErrorsWebpackPlugin(),
|
| 81 |
+
new webpack.ProvidePlugin({
|
| 82 |
+
process: "process/browser",
|
| 83 |
+
}),
|
| 84 |
+
],
|
| 85 |
+
};
|
demo/gradio/frontend/configs/webpack/dev.js
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
// development config
|
| 8 |
+
const { merge } = require("webpack-merge");
|
| 9 |
+
const commonConfig = require("./common");
|
| 10 |
+
|
| 11 |
+
module.exports = merge(commonConfig, {
|
| 12 |
+
mode: "development",
|
| 13 |
+
devServer: {
|
| 14 |
+
hot: true, // enable HMR on the server
|
| 15 |
+
open: true,
|
| 16 |
+
// These headers enable the cross origin isolation state
|
| 17 |
+
// needed to enable use of SharedArrayBuffer for ONNX
|
| 18 |
+
// multithreading.
|
| 19 |
+
headers: {
|
| 20 |
+
"Cross-Origin-Opener-Policy": "same-origin",
|
| 21 |
+
"Cross-Origin-Embedder-Policy": "credentialless",
|
| 22 |
+
},
|
| 23 |
+
},
|
| 24 |
+
devtool: "cheap-module-source-map",
|
| 25 |
+
});
|
demo/gradio/frontend/configs/webpack/prod.js
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
// production config
|
| 8 |
+
const { merge } = require("webpack-merge");
|
| 9 |
+
const { resolve } = require("path");
|
| 10 |
+
const Dotenv = require("dotenv-webpack");
|
| 11 |
+
const commonConfig = require("./common");
|
| 12 |
+
|
| 13 |
+
module.exports = merge(commonConfig, {
|
| 14 |
+
mode: "production",
|
| 15 |
+
output: {
|
| 16 |
+
filename: "js/bundle.[contenthash].min.js",
|
| 17 |
+
path: resolve(__dirname, "../../dist"),
|
| 18 |
+
publicPath: "/",
|
| 19 |
+
},
|
| 20 |
+
devtool: "source-map",
|
| 21 |
+
plugins: [new Dotenv()],
|
| 22 |
+
});
|
demo/gradio/frontend/package.json
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "segment-anything-mini-demo",
|
| 3 |
+
"version": "0.1.0",
|
| 4 |
+
"license": "MIT",
|
| 5 |
+
"scripts": {
|
| 6 |
+
"build": "yarn run clean-dist && webpack --config=configs/webpack/prod.js && mv dist/*.wasm dist/js && rsync -r --delete dist ../",
|
| 7 |
+
"clean-dist": "rimraf dist/*",
|
| 8 |
+
"lint": "eslint './src/**/*.{js,ts,tsx}' --quiet",
|
| 9 |
+
"start": "yarn run start-dev",
|
| 10 |
+
"test": "yarn run start-model-test",
|
| 11 |
+
"start-dev": "webpack serve --config=configs/webpack/dev.js"
|
| 12 |
+
},
|
| 13 |
+
"devDependencies": {
|
| 14 |
+
"@babel/core": "^7.18.13",
|
| 15 |
+
"@babel/preset-env": "^7.18.10",
|
| 16 |
+
"@babel/preset-react": "^7.18.6",
|
| 17 |
+
"@babel/preset-typescript": "^7.18.6",
|
| 18 |
+
"@pmmmwh/react-refresh-webpack-plugin": "^0.5.7",
|
| 19 |
+
"@testing-library/react": "^13.3.0",
|
| 20 |
+
"@types/node": "^18.7.13",
|
| 21 |
+
"@types/react": "^18.0.17",
|
| 22 |
+
"@types/react-dom": "^18.0.6",
|
| 23 |
+
"@types/underscore": "^1.11.4",
|
| 24 |
+
"@typescript-eslint/eslint-plugin": "^5.35.1",
|
| 25 |
+
"@typescript-eslint/parser": "^5.35.1",
|
| 26 |
+
"babel-loader": "^8.2.5",
|
| 27 |
+
"copy-webpack-plugin": "^11.0.0",
|
| 28 |
+
"css-loader": "^6.7.1",
|
| 29 |
+
"dotenv": "^16.0.2",
|
| 30 |
+
"dotenv-webpack": "^8.0.1",
|
| 31 |
+
"eslint": "^8.22.0",
|
| 32 |
+
"eslint-plugin-react": "^7.31.0",
|
| 33 |
+
"file-loader": "^6.2.0",
|
| 34 |
+
"fork-ts-checker-webpack-plugin": "^7.2.13",
|
| 35 |
+
"friendly-errors-webpack-plugin": "^1.7.0",
|
| 36 |
+
"html-webpack-plugin": "^5.5.0",
|
| 37 |
+
"image-webpack-loader": "^8.1.0",
|
| 38 |
+
"postcss-loader": "^7.0.1",
|
| 39 |
+
"postcss-preset-env": "^7.8.0",
|
| 40 |
+
"process": "^0.11.10",
|
| 41 |
+
"rimraf": "^3.0.2",
|
| 42 |
+
"sass": "^1.54.5",
|
| 43 |
+
"sass-loader": "^13.0.2",
|
| 44 |
+
"style-loader": "^3.3.1",
|
| 45 |
+
"tailwindcss": "^3.1.8",
|
| 46 |
+
"ts-loader": "^9.3.1",
|
| 47 |
+
"typescript": "^4.8.2",
|
| 48 |
+
"webpack": "^5.74.0",
|
| 49 |
+
"webpack-cli": "^4.10.0",
|
| 50 |
+
"webpack-dev-server": "^4.10.0",
|
| 51 |
+
"webpack-dotenv-plugin": "^2.1.0",
|
| 52 |
+
"webpack-merge": "^5.8.0"
|
| 53 |
+
},
|
| 54 |
+
"dependencies": {
|
| 55 |
+
"@gradio/client": "^1.7.1",
|
| 56 |
+
"npyjs": "^0.4.0",
|
| 57 |
+
"onnxruntime-web": "1.14.0",
|
| 58 |
+
"react": "^18.2.0",
|
| 59 |
+
"react-dom": "^18.2.0",
|
| 60 |
+
"react-refresh": "^0.14.0",
|
| 61 |
+
"underscore": "^1.13.6",
|
| 62 |
+
"axios": "^1.6.7"
|
| 63 |
+
}
|
| 64 |
+
}
|
demo/gradio/frontend/postcss.config.js
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
const tailwindcss = require("tailwindcss");
|
| 8 |
+
module.exports = {
|
| 9 |
+
plugins: ["postcss-preset-env", 'tailwindcss/nesting', tailwindcss],
|
| 10 |
+
};
|
demo/gradio/frontend/src/App.tsx
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import { InferenceSession, Tensor } from "onnxruntime-web";
|
| 8 |
+
import React, { useContext, useEffect, useState, useRef } from "react";
|
| 9 |
+
import axios from "axios";
|
| 10 |
+
import "./assets/scss/App.scss";
|
| 11 |
+
import { handleImageScale } from "./components/helpers/scaleHelper";
|
| 12 |
+
import { modelScaleProps, QueueStatus } from "./components/helpers/Interfaces";
|
| 13 |
+
import { onnxMaskToImage, arrayToImageData, imageDataToURL } from "./components/helpers/maskUtils";
|
| 14 |
+
import { modelData } from "./components/helpers/onnxModelAPI";
|
| 15 |
+
import Stage, { DescriptionState } from "./components/Stage";
|
| 16 |
+
import AppContext from "./components/hooks/createContext";
|
| 17 |
+
import { imageToSamEmbedding } from "./services/maskApi";
|
| 18 |
+
import LoadingOverlay from "./components/LoadingOverlay";
|
| 19 |
+
import ErrorModal from './components/ErrorModal';
|
| 20 |
+
import QueueStatusIndicator from "./components/QueueStatusIndicator";
|
| 21 |
+
|
| 22 |
+
const ort = require("onnxruntime-web");
|
| 23 |
+
|
| 24 |
+
// Define image and model paths
|
| 25 |
+
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
|
| 26 |
+
|
| 27 |
+
const App = () => {
|
| 28 |
+
const {
|
| 29 |
+
clicks: [clicks, setClicks],
|
| 30 |
+
image: [image, setImage],
|
| 31 |
+
maskImg: [maskImg, setMaskImg],
|
| 32 |
+
maskImgData: [maskImgData, setMaskImgData],
|
| 33 |
+
isClicked: [isClicked, setIsClicked]
|
| 34 |
+
} = useContext(AppContext)!;
|
| 35 |
+
const [model, setModel] = useState<InferenceSession | null>(null);
|
| 36 |
+
const [tensor, setTensor] = useState<Tensor | null>(null);
|
| 37 |
+
const [modelScale, setModelScale] = useState<modelScaleProps | null>(null);
|
| 38 |
+
const [isLoading, setIsLoading] = useState<boolean>(false);
|
| 39 |
+
const [error, setError] = useState<string | null>(null);
|
| 40 |
+
const [descriptionState, setDescriptionState] = useState<DescriptionState>({
|
| 41 |
+
state: 'ready',
|
| 42 |
+
description: ''
|
| 43 |
+
});
|
| 44 |
+
const [queueStatus, setQueueStatus] = useState<QueueStatus>({ inQueue: false });
|
| 45 |
+
|
| 46 |
+
// Initialize the ONNX model
|
| 47 |
+
useEffect(() => {
|
| 48 |
+
const initModel = async () => {
|
| 49 |
+
try {
|
| 50 |
+
if (MODEL_DIR === undefined) return;
|
| 51 |
+
const URL: string = MODEL_DIR;
|
| 52 |
+
const model = await InferenceSession.create(URL);
|
| 53 |
+
setModel(model);
|
| 54 |
+
} catch (e) {
|
| 55 |
+
console.log(e);
|
| 56 |
+
}
|
| 57 |
+
};
|
| 58 |
+
initModel();
|
| 59 |
+
}, []);
|
| 60 |
+
|
| 61 |
+
const handleImageUpload = async (event: React.ChangeEvent<HTMLInputElement>) => {
|
| 62 |
+
const file = event.target.files?.[0];
|
| 63 |
+
if (!file) return;
|
| 64 |
+
|
| 65 |
+
try {
|
| 66 |
+
const url = URL.createObjectURL(file);
|
| 67 |
+
await loadImage(new URL(url));
|
| 68 |
+
} catch (error) {
|
| 69 |
+
setError('Failed to load image. Please try again with a different image.');
|
| 70 |
+
console.error('Error loading image:', error);
|
| 71 |
+
}
|
| 72 |
+
};
|
| 73 |
+
|
| 74 |
+
const loadImage = async (url: URL) => {
|
| 75 |
+
try {
|
| 76 |
+
setIsLoading(true);
|
| 77 |
+
const img = new Image();
|
| 78 |
+
img.src = url.href;
|
| 79 |
+
img.onload = async () => {
|
| 80 |
+
const { height, width, samScale } = handleImageScale(img);
|
| 81 |
+
setModelScale({
|
| 82 |
+
height: height,
|
| 83 |
+
width: width,
|
| 84 |
+
samScale: samScale,
|
| 85 |
+
});
|
| 86 |
+
img.width = width;
|
| 87 |
+
img.height = height;
|
| 88 |
+
setImage(img);
|
| 89 |
+
|
| 90 |
+
// After image is loaded, fetch its embedding from Gradio
|
| 91 |
+
await fetchImageEmbedding(img);
|
| 92 |
+
setIsLoading(false);
|
| 93 |
+
};
|
| 94 |
+
} catch (error) {
|
| 95 |
+
console.log(error);
|
| 96 |
+
setIsLoading(false);
|
| 97 |
+
}
|
| 98 |
+
};
|
| 99 |
+
|
| 100 |
+
const fetchImageEmbedding = async (img: HTMLImageElement) => {
|
| 101 |
+
try {
|
| 102 |
+
// Create a canvas to convert the image to base64
|
| 103 |
+
const canvas = document.createElement('canvas');
|
| 104 |
+
canvas.width = img.width;
|
| 105 |
+
canvas.height = img.height;
|
| 106 |
+
const ctx = canvas.getContext('2d');
|
| 107 |
+
ctx?.drawImage(img, 0, 0);
|
| 108 |
+
|
| 109 |
+
// Convert image to base64 data URL and extract the base64 string
|
| 110 |
+
const base64Image = canvas.toDataURL('image/jpeg').split(',')[1];
|
| 111 |
+
|
| 112 |
+
// Make request to Gradio API
|
| 113 |
+
const samEmbedding = await imageToSamEmbedding(
|
| 114 |
+
base64Image,
|
| 115 |
+
(status: QueueStatus) => {
|
| 116 |
+
setQueueStatus(status);
|
| 117 |
+
}
|
| 118 |
+
);
|
| 119 |
+
|
| 120 |
+
// Convert base64 embedding back to array buffer
|
| 121 |
+
const binaryString = window.atob(samEmbedding);
|
| 122 |
+
const len = binaryString.length;
|
| 123 |
+
const bytes = new Uint8Array(len);
|
| 124 |
+
for (let i = 0; i < len; i++) {
|
| 125 |
+
bytes[i] = binaryString.charCodeAt(i);
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
// Create tensor from the embedding
|
| 129 |
+
const embedding = new ort.Tensor(
|
| 130 |
+
'float32',
|
| 131 |
+
new Float32Array(bytes.buffer), // Convert to Float32Array
|
| 132 |
+
[1, 256, 64, 64] // SAM embedding shape
|
| 133 |
+
);
|
| 134 |
+
setTensor(embedding);
|
| 135 |
+
} catch (error) {
|
| 136 |
+
setQueueStatus({ inQueue: false }); // Reset queue status on error
|
| 137 |
+
let errorMessage = 'Failed to process image. Please try again.';
|
| 138 |
+
if (axios.isAxiosError(error)) {
|
| 139 |
+
errorMessage = error.response?.data?.message || errorMessage;
|
| 140 |
+
}
|
| 141 |
+
setError(errorMessage);
|
| 142 |
+
console.error('Error fetching embedding:', error);
|
| 143 |
+
}
|
| 144 |
+
};
|
| 145 |
+
|
| 146 |
+
useEffect(() => {
|
| 147 |
+
const handleMaskUpdate = async () => {
|
| 148 |
+
await runONNX();
|
| 149 |
+
};
|
| 150 |
+
handleMaskUpdate();
|
| 151 |
+
}, [clicks]);
|
| 152 |
+
|
| 153 |
+
const runONNX = async () => {
|
| 154 |
+
try {
|
| 155 |
+
// Don't run if already described or is describing
|
| 156 |
+
if (descriptionState.state !== 'ready') return;
|
| 157 |
+
|
| 158 |
+
console.log('Running ONNX model with:', {
|
| 159 |
+
modelLoaded: model !== null,
|
| 160 |
+
hasClicks: clicks !== null,
|
| 161 |
+
hasTensor: tensor !== null,
|
| 162 |
+
hasModelScale: modelScale !== null
|
| 163 |
+
});
|
| 164 |
+
|
| 165 |
+
if (
|
| 166 |
+
model === null ||
|
| 167 |
+
clicks === null ||
|
| 168 |
+
tensor === null ||
|
| 169 |
+
modelScale === null
|
| 170 |
+
) {
|
| 171 |
+
console.log('Missing required inputs, returning early');
|
| 172 |
+
return;
|
| 173 |
+
}
|
| 174 |
+
else {
|
| 175 |
+
console.log('Preparing model feeds with:', {
|
| 176 |
+
clicks,
|
| 177 |
+
tensorShape: tensor.dims,
|
| 178 |
+
modelScale
|
| 179 |
+
});
|
| 180 |
+
|
| 181 |
+
const feeds = modelData({
|
| 182 |
+
clicks,
|
| 183 |
+
tensor,
|
| 184 |
+
modelScale,
|
| 185 |
+
});
|
| 186 |
+
|
| 187 |
+
if (feeds === undefined) {
|
| 188 |
+
console.log('Model feeds undefined, returning early');
|
| 189 |
+
return;
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
console.log('Running model with feeds:', feeds);
|
| 193 |
+
const results = await model.run(feeds);
|
| 194 |
+
console.log('Model run complete, got results:', results);
|
| 195 |
+
|
| 196 |
+
const output = results[model.outputNames[0]];
|
| 197 |
+
console.log('Processing output with dims:', output.dims);
|
| 198 |
+
|
| 199 |
+
// Calculate and log the mask area (number of non-zero values)
|
| 200 |
+
const maskArray = Array.from(output.data as Uint8Array);
|
| 201 |
+
const maskArea = maskArray.filter(val => val > 0).length;
|
| 202 |
+
console.log('Mask area (number of non-zero pixels):', maskArea);
|
| 203 |
+
|
| 204 |
+
// Double check that the state is ready before processing the mask since the state may have changed
|
| 205 |
+
if (descriptionState.state !== 'ready') return;
|
| 206 |
+
// If clicked, we only handle the first mask (note that mask will be cleared after clicking before handling to let us know if it's the first mask).
|
| 207 |
+
if (isClicked && maskImgData != null) return;
|
| 208 |
+
if (maskArea > 0) {
|
| 209 |
+
setMaskImg(onnxMaskToImage(output.data, output.dims[2], output.dims[3], false));
|
| 210 |
+
setMaskImgData(imageDataToURL(arrayToImageData(output.data, output.dims[2], output.dims[3], true)));
|
| 211 |
+
} else {
|
| 212 |
+
console.warn('No mask area detected, clearing mask');
|
| 213 |
+
setMaskImg(null);
|
| 214 |
+
// setMaskImgData(null);
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
console.log('Mask processing complete');
|
| 218 |
+
}
|
| 219 |
+
} catch (e) {
|
| 220 |
+
setError('Failed to process the image. Please try again.');
|
| 221 |
+
console.error('Error running ONNX model:', e);
|
| 222 |
+
}
|
| 223 |
+
};
|
| 224 |
+
|
| 225 |
+
const handleNewRegion = () => {
|
| 226 |
+
setDescriptionState({
|
| 227 |
+
state: 'ready',
|
| 228 |
+
description: ''
|
| 229 |
+
} as DescriptionState);
|
| 230 |
+
setMaskImg(null);
|
| 231 |
+
// setMaskImgData(null);
|
| 232 |
+
setIsClicked(false);
|
| 233 |
+
};
|
| 234 |
+
|
| 235 |
+
const handleCopyDescription = () => {
|
| 236 |
+
navigator.clipboard.writeText(descriptionState.description);
|
| 237 |
+
};
|
| 238 |
+
|
| 239 |
+
const handleReset = () => {
|
| 240 |
+
// Clear all states
|
| 241 |
+
setDescriptionState({
|
| 242 |
+
state: 'ready',
|
| 243 |
+
description: ''
|
| 244 |
+
} as DescriptionState);
|
| 245 |
+
setMaskImg(null);
|
| 246 |
+
// setMaskImgData(null);
|
| 247 |
+
setImage(null);
|
| 248 |
+
setClicks(null);
|
| 249 |
+
setIsClicked(false);
|
| 250 |
+
};
|
| 251 |
+
|
| 252 |
+
return (
|
| 253 |
+
<div className="flex flex-col h-screen">
|
| 254 |
+
{isLoading && <LoadingOverlay />}
|
| 255 |
+
{error && <ErrorModal message={error} onClose={() => setError(null)} />}
|
| 256 |
+
<QueueStatusIndicator queueStatus={queueStatus} />
|
| 257 |
+
<div className="flex-1">
|
| 258 |
+
<Stage
|
| 259 |
+
onImageUpload={handleImageUpload}
|
| 260 |
+
descriptionState={descriptionState}
|
| 261 |
+
setDescriptionState={setDescriptionState}
|
| 262 |
+
queueStatus={queueStatus}
|
| 263 |
+
setQueueStatus={setQueueStatus}
|
| 264 |
+
/>
|
| 265 |
+
</div>
|
| 266 |
+
<div className="description-container">
|
| 267 |
+
<div className={`description-box ${descriptionState.state !== 'described' ? descriptionState.state : ''}`}>
|
| 268 |
+
{descriptionState.description ? (
|
| 269 |
+
descriptionState.description + (descriptionState.state === 'describing' ? '...' : '')
|
| 270 |
+
) : descriptionState.state === 'describing' ? (
|
| 271 |
+
<em>Describing the region... (this may take a while if compute resources are busy)</em>
|
| 272 |
+
) : (
|
| 273 |
+
image ? (
|
| 274 |
+
<em>Click on the image to describe the region</em>
|
| 275 |
+
) : (
|
| 276 |
+
<em>Upload an image to describe the region</em>
|
| 277 |
+
)
|
| 278 |
+
)}
|
| 279 |
+
</div>
|
| 280 |
+
<div className="description-controls">
|
| 281 |
+
<button
|
| 282 |
+
onClick={handleCopyDescription}
|
| 283 |
+
disabled={descriptionState.state !== 'described'}
|
| 284 |
+
>
|
| 285 |
+
Copy description
|
| 286 |
+
</button>
|
| 287 |
+
<button
|
| 288 |
+
onClick={handleNewRegion}
|
| 289 |
+
disabled={descriptionState.state !== 'described'}
|
| 290 |
+
>
|
| 291 |
+
Describe a new region
|
| 292 |
+
</button>
|
| 293 |
+
<button
|
| 294 |
+
onClick={handleReset}
|
| 295 |
+
className="reset-button"
|
| 296 |
+
disabled={descriptionState.state === 'describing' || !image}
|
| 297 |
+
>
|
| 298 |
+
Try a new image
|
| 299 |
+
</button>
|
| 300 |
+
</div>
|
| 301 |
+
</div>
|
| 302 |
+
</div>
|
| 303 |
+
);
|
| 304 |
+
};
|
| 305 |
+
|
| 306 |
+
export default App;
|
demo/gradio/frontend/src/components/ErrorModal.tsx
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React from 'react';
|
| 2 |
+
|
| 3 |
+
interface ErrorModalProps {
|
| 4 |
+
message: string;
|
| 5 |
+
onClose: () => void;
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
const ErrorModal: React.FC<ErrorModalProps> = ({ message, onClose }) => {
|
| 9 |
+
return (
|
| 10 |
+
<div className="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center z-50">
|
| 11 |
+
<div className="bg-white p-6 rounded-lg shadow-xl max-w-md w-full mx-4">
|
| 12 |
+
<div className="flex flex-col items-center">
|
| 13 |
+
<div className="bg-red-100 p-4 rounded-full mb-4">
|
| 14 |
+
<svg className="w-6 h-6 text-red-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
| 15 |
+
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth="2" d="M6 18L18 6M6 6l12 12" />
|
| 16 |
+
</svg>
|
| 17 |
+
</div>
|
| 18 |
+
<h3 className="text-lg font-semibold text-gray-900 mb-2">Error</h3>
|
| 19 |
+
<p className="text-gray-600 text-center mb-6">{message}</p>
|
| 20 |
+
<button
|
| 21 |
+
onClick={onClose}
|
| 22 |
+
className="bg-red-600 text-white px-4 py-2 rounded hover:bg-red-700 transition-colors"
|
| 23 |
+
>
|
| 24 |
+
Close
|
| 25 |
+
</button>
|
| 26 |
+
</div>
|
| 27 |
+
</div>
|
| 28 |
+
</div>
|
| 29 |
+
);
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
export default ErrorModal;
|
demo/gradio/frontend/src/components/LoadingOverlay.tsx
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React from 'react';
|
| 2 |
+
|
| 3 |
+
const LoadingOverlay: React.FC = () => {
|
| 4 |
+
return (
|
| 5 |
+
<div className="fixed inset-0 bg-gray-500 bg-opacity-75 flex items-center justify-center z-50">
|
| 6 |
+
<div className="bg-white p-8 rounded-lg shadow-xl flex flex-col items-center">
|
| 7 |
+
<svg width="54" height="54" viewBox="0 0 54 54" fill="none" xmlns="http://www.w3.org/2000/svg" className="w-16 h-16 mb-4">
|
| 8 |
+
<path d="M5.92017 41.0562L27.0002 48.0802" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
|
| 9 |
+
<path d="M5.92017 12.9438L27.0002 26.9998" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
|
| 10 |
+
<path d="M27 5.91992L48.08 26.9999" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
|
| 11 |
+
<path d="M5.92017 41.0559L27.0002 5.91992" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
|
| 12 |
+
<path d="M27 48.08L48.08 27" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
|
| 13 |
+
<path d="M27 27H48.08" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
|
| 14 |
+
<path d="M5.92017 12.9439L27.0002 5.91992" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
|
| 15 |
+
<path d="M5.92017 41.056L27.0002 27" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
|
| 16 |
+
<path d="M5.92017 12.9438L27.0002 48.0798" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
|
| 17 |
+
<path d="M26.9998 31.9201C29.7171 31.9201 31.9198 29.7173 31.9198 27.0001C31.9198 24.2828 29.7171 22.0801 26.9998 22.0801C24.2826 22.0801 22.0798 24.2828 22.0798 27.0001C22.0798 29.7173 24.2826 31.9201 26.9998 31.9201Z" fill="white" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
|
| 18 |
+
<path d="M5.92 17.8639C8.63724 17.8639 10.84 15.6612 10.84 12.9439C10.84 10.2267 8.63724 8.02393 5.92 8.02393C3.20276 8.02393 1 10.2267 1 12.9439C1 15.6612 3.20276 17.8639 5.92 17.8639Z" fill="white" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
|
| 19 |
+
<path d="M5.92 45.9757C8.63724 45.9757 10.84 43.773 10.84 41.0557C10.84 38.3385 8.63724 36.1357 5.92 36.1357C3.20276 36.1357 1 38.3385 1 41.0557C1 43.773 3.20276 45.9757 5.92 45.9757Z" fill="white" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
|
| 20 |
+
<path d="M48.0806 31.9201C50.7979 31.9201 53.0006 29.7173 53.0006 27.0001C53.0006 24.2828 50.7979 22.0801 48.0806 22.0801C45.3634 22.0801 43.1606 24.2828 43.1606 27.0001C43.1606 29.7173 45.3634 31.9201 48.0806 31.9201Z" fill="white" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
|
| 21 |
+
<path d="M26.9998 53.0002C29.7171 53.0002 31.9198 50.7974 31.9198 48.0802C31.9198 45.3629 29.7171 43.1602 26.9998 43.1602C24.2826 43.1602 22.0798 45.3629 22.0798 48.0802C22.0798 50.7974 24.2826 53.0002 26.9998 53.0002Z" fill="white" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
|
| 22 |
+
<path d="M26.9998 10.84C29.7171 10.84 31.9198 8.63724 31.9198 5.92C31.9198 3.20276 29.7171 1 26.9998 1C24.2826 1 22.0798 3.20276 22.0798 5.92C22.0798 8.63724 24.2826 10.84 26.9998 10.84Z" fill="white" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
|
| 23 |
+
</svg>
|
| 24 |
+
<p className="text-lg font-semibold text-gray-800">Loading image embedding...</p>
|
| 25 |
+
</div>
|
| 26 |
+
</div>
|
| 27 |
+
);
|
| 28 |
+
};
|
| 29 |
+
|
| 30 |
+
export default LoadingOverlay;
|
demo/gradio/frontend/src/components/QueueStatusIndicator.tsx
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React from 'react';
|
| 2 |
+
import { QueueStatus } from './helpers/Interfaces';
|
| 3 |
+
|
| 4 |
+
interface QueueStatusIndicatorProps {
|
| 5 |
+
queueStatus: QueueStatus;
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
const QueueStatusIndicator: React.FC<QueueStatusIndicatorProps> = ({ queueStatus }) => {
|
| 9 |
+
if (!queueStatus.inQueue) return null;
|
| 10 |
+
|
| 11 |
+
return (
|
| 12 |
+
<div className="fixed top-4 right-4 bg-white rounded-lg shadow-lg p-4 z-50">
|
| 13 |
+
<div className="flex flex-col gap-2">
|
| 14 |
+
{queueStatus.rank === 0 ? (
|
| 15 |
+
<p className="text-sm">You're next in line! ({queueStatus.queueSize} total in queue)</p>
|
| 16 |
+
) : (
|
| 17 |
+
<p className="text-sm">Queue position: {queueStatus.rank! + 1} of {queueStatus.queueSize}</p>
|
| 18 |
+
)}
|
| 19 |
+
{queueStatus.rankEta && (
|
| 20 |
+
<p className="text-sm text-gray-600">
|
| 21 |
+
Estimated wait: {Math.ceil(queueStatus.rankEta)} seconds
|
| 22 |
+
</p>
|
| 23 |
+
)}
|
| 24 |
+
</div>
|
| 25 |
+
</div>
|
| 26 |
+
);
|
| 27 |
+
};
|
| 28 |
+
|
| 29 |
+
export default QueueStatusIndicator;
|
demo/gradio/frontend/src/components/Stage.tsx
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import React, { useContext, useState, useEffect } from "react";
|
| 8 |
+
import * as _ from "underscore";
|
| 9 |
+
import Tool from "./Tool";
|
| 10 |
+
import { modelInputProps, QueueStatus } from "./helpers/Interfaces";
|
| 11 |
+
import AppContext from "./hooks/createContext";
|
| 12 |
+
// import { describeMask } from '../services/maskApi';
|
| 13 |
+
|
| 14 |
+
interface DescriptionState {
|
| 15 |
+
state: string; // 'ready', 'describing', 'described'
|
| 16 |
+
description: string;
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
interface StageProps {
|
| 20 |
+
onImageUpload: (event: React.ChangeEvent<HTMLInputElement>) => Promise<void>;
|
| 21 |
+
descriptionState: DescriptionState;
|
| 22 |
+
setDescriptionState: React.Dispatch<React.SetStateAction<DescriptionState>>;
|
| 23 |
+
queueStatus: QueueStatus;
|
| 24 |
+
setQueueStatus: (status: QueueStatus) => void;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
const EXAMPLE_IMAGES = Array.from({ length: 21 }, (_, i) => `/examples/${i + 1}.jpg`);
|
| 28 |
+
const BREAKPOINT_MEDIUM = 2100;
|
| 29 |
+
const BREAKPOINT_SMALL = 1100;
|
| 30 |
+
|
| 31 |
+
const Stage = ({ onImageUpload, descriptionState, setDescriptionState, queueStatus, setQueueStatus }: StageProps) => {
|
| 32 |
+
const {
|
| 33 |
+
clicks: [, setClicks],
|
| 34 |
+
image: [image],
|
| 35 |
+
maskImg: [maskImg],
|
| 36 |
+
maskImgData: [maskImgData]
|
| 37 |
+
} = useContext(AppContext)!;
|
| 38 |
+
|
| 39 |
+
const [isDragging, setIsDragging] = useState(false);
|
| 40 |
+
const [currentPage, setCurrentPage] = useState(1);
|
| 41 |
+
const [imagesPerPage, setImagesPerPage] = useState(8);
|
| 42 |
+
|
| 43 |
+
useEffect(() => {
|
| 44 |
+
const handleResize = () => {
|
| 45 |
+
if (window.innerWidth < BREAKPOINT_SMALL) {
|
| 46 |
+
setImagesPerPage(1);
|
| 47 |
+
} else if (window.innerWidth < BREAKPOINT_MEDIUM) {
|
| 48 |
+
setImagesPerPage(4);
|
| 49 |
+
} else {
|
| 50 |
+
setImagesPerPage(8);
|
| 51 |
+
}
|
| 52 |
+
};
|
| 53 |
+
|
| 54 |
+
// Set initial value
|
| 55 |
+
handleResize();
|
| 56 |
+
|
| 57 |
+
// Add event listener
|
| 58 |
+
window.addEventListener('resize', handleResize);
|
| 59 |
+
|
| 60 |
+
// Cleanup
|
| 61 |
+
return () => window.removeEventListener('resize', handleResize);
|
| 62 |
+
}, []);
|
| 63 |
+
|
| 64 |
+
const getClick = (x: number, y: number): modelInputProps => {
|
| 65 |
+
const clickType = 1;
|
| 66 |
+
return { x, y, clickType };
|
| 67 |
+
};
|
| 68 |
+
|
| 69 |
+
const handleMouseMove = _.throttle((e: any) => {
|
| 70 |
+
if (descriptionState.state !== 'ready') return;
|
| 71 |
+
if (e.clientX === undefined || e.clientY === undefined) {
|
| 72 |
+
console.warn('Mouse move event does not contain clientX or clientY');
|
| 73 |
+
return;
|
| 74 |
+
}
|
| 75 |
+
let el = e.nativeEvent.target;
|
| 76 |
+
const rect = el.getBoundingClientRect();
|
| 77 |
+
|
| 78 |
+
// Calculate the actual dimensions of the contained image
|
| 79 |
+
const containerAspectRatio = el.offsetWidth / el.offsetHeight;
|
| 80 |
+
const imageAspectRatio = image ? image.width / image.height : 1;
|
| 81 |
+
|
| 82 |
+
let renderedWidth, renderedHeight;
|
| 83 |
+
if (containerAspectRatio > imageAspectRatio) {
|
| 84 |
+
// Image is constrained by height
|
| 85 |
+
renderedHeight = el.offsetHeight;
|
| 86 |
+
renderedWidth = renderedHeight * imageAspectRatio;
|
| 87 |
+
} else {
|
| 88 |
+
// Image is constrained by width
|
| 89 |
+
renderedWidth = el.offsetWidth;
|
| 90 |
+
renderedHeight = renderedWidth / imageAspectRatio;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
// Calculate the empty space offset
|
| 94 |
+
const offsetX = (el.offsetWidth - renderedWidth) / 2;
|
| 95 |
+
const offsetY = (el.offsetHeight - renderedHeight) / 2;
|
| 96 |
+
|
| 97 |
+
// Get click position relative to the actual image
|
| 98 |
+
let x = e.clientX - rect.left - offsetX;
|
| 99 |
+
let y = e.clientY - rect.top - offsetY;
|
| 100 |
+
|
| 101 |
+
// Convert to original image coordinates
|
| 102 |
+
const scaleX = image ? image.width / renderedWidth : 1;
|
| 103 |
+
const scaleY = image ? image.height / renderedHeight : 1;
|
| 104 |
+
x *= scaleX;
|
| 105 |
+
y *= scaleY;
|
| 106 |
+
|
| 107 |
+
// Ensure coordinates are within bounds
|
| 108 |
+
if (image) {
|
| 109 |
+
x = Math.max(0, Math.min(x, image.width));
|
| 110 |
+
y = Math.max(0, Math.min(y, image.height));
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
const click = getClick(x, y);
|
| 114 |
+
if (click) {
|
| 115 |
+
setClicks([click]);
|
| 116 |
+
}
|
| 117 |
+
}, 15);
|
| 118 |
+
|
| 119 |
+
const handleDragEnter = (e: React.DragEvent) => {
|
| 120 |
+
e.preventDefault();
|
| 121 |
+
e.stopPropagation();
|
| 122 |
+
setIsDragging(true);
|
| 123 |
+
};
|
| 124 |
+
|
| 125 |
+
const handleDragLeave = (e: React.DragEvent) => {
|
| 126 |
+
e.preventDefault();
|
| 127 |
+
e.stopPropagation();
|
| 128 |
+
setIsDragging(false);
|
| 129 |
+
};
|
| 130 |
+
|
| 131 |
+
const handleDragOver = (e: React.DragEvent) => {
|
| 132 |
+
e.preventDefault();
|
| 133 |
+
e.stopPropagation();
|
| 134 |
+
};
|
| 135 |
+
|
| 136 |
+
const handleDrop = async (e: React.DragEvent) => {
|
| 137 |
+
e.preventDefault();
|
| 138 |
+
e.stopPropagation();
|
| 139 |
+
setIsDragging(false);
|
| 140 |
+
|
| 141 |
+
const files = e.dataTransfer.files;
|
| 142 |
+
if (files && files[0]) {
|
| 143 |
+
const file = files[0];
|
| 144 |
+
// Cast to unknown first, then to the desired type
|
| 145 |
+
const syntheticEvent = {
|
| 146 |
+
target: {
|
| 147 |
+
files: [file]
|
| 148 |
+
}
|
| 149 |
+
} as unknown as React.ChangeEvent<HTMLInputElement>;
|
| 150 |
+
|
| 151 |
+
onImageUpload(syntheticEvent);
|
| 152 |
+
}
|
| 153 |
+
};
|
| 154 |
+
|
| 155 |
+
const flexCenterClasses = "flex items-center justify-center";
|
| 156 |
+
|
| 157 |
+
// const handleDescribeMask = async () => {
|
| 158 |
+
// if (!maskImg || !maskImgData || !image) {
|
| 159 |
+
// console.warn('No mask or image available to describe');
|
| 160 |
+
// return;
|
| 161 |
+
// }
|
| 162 |
+
|
| 163 |
+
// try {
|
| 164 |
+
// const canvas = document.createElement('canvas');
|
| 165 |
+
// canvas.width = image.width;
|
| 166 |
+
// canvas.height = image.height;
|
| 167 |
+
// const ctx = canvas.getContext('2d');
|
| 168 |
+
// ctx?.drawImage(image, 0, 0);
|
| 169 |
+
// const imageBase64 = canvas.toDataURL('image/jpeg').split(',')[1];
|
| 170 |
+
// const maskBase64 = maskImgData.split(',')[1];
|
| 171 |
+
|
| 172 |
+
// const result = await describeMask(maskBase64, imageBase64);
|
| 173 |
+
// console.log('Mask description:', result.description);
|
| 174 |
+
|
| 175 |
+
// alert("Mask description: " + result.description);
|
| 176 |
+
// } catch (error) {
|
| 177 |
+
// console.error('Failed to describe mask:', error);
|
| 178 |
+
// }
|
| 179 |
+
// };
|
| 180 |
+
|
| 181 |
+
return (
|
| 182 |
+
<div
|
| 183 |
+
className={`flex flex-col w-full h-full relative`}
|
| 184 |
+
onDragEnter={handleDragEnter}
|
| 185 |
+
onDragOver={handleDragOver}
|
| 186 |
+
onDragLeave={handleDragLeave}
|
| 187 |
+
onDrop={handleDrop}
|
| 188 |
+
>
|
| 189 |
+
{/* Title and Description */}
|
| 190 |
+
<div className="w-full px-8 mb-8 flex flex-col justify-center mt-4">
|
| 191 |
+
<div className="flex flex-col sm:flex-row justify-between items-center gap-4">
|
| 192 |
+
<h1 className="text-3xl font-bold text-center sm:text-left"><a href="/">Describe Anything Model Demo</a></h1>
|
| 193 |
+
<div className="flex flex-wrap justify-center gap-4 sm:space-x-8 text-lg font-medium">
|
| 194 |
+
<a href="https://describe-anything.github.io/" target="_blank" rel="noopener noreferrer" className="text-gray-600 hover:text-gray-800">Project Page</a>
|
| 195 |
+
<a href="https://github.com/NVlabs/describe-anything?tab=readme-ov-file#simple-gradio-demo-for-detailed-localized-video-descriptions" target="_blank" rel="noopener noreferrer" className="text-gray-600 hover:text-gray-800">DAM for video</a>
|
| 196 |
+
</div>
|
| 197 |
+
</div>
|
| 198 |
+
<div className="border-b border-gray-300 mt-4 mb-4"></div>
|
| 199 |
+
{!image && <div className="space-y-4 text-gray-600 text-left">
|
| 200 |
+
<p>Describe Anything Model (DAM) takes in a region of an image or a video in the form of points/boxes/scribbles/masks and outputs detailed descriptions to the region. For videos, it is sufficient to supply annotation on any frame.</p>
|
| 201 |
+
<p>This demo supports DAM model that takes points on images as queries. For other use cases, please refer to the <a href="" className="text-gray-600 hover:text-gray-800 underline">inference scripts and video demo</a> for more details.</p>
|
| 202 |
+
</div>}
|
| 203 |
+
</div>
|
| 204 |
+
|
| 205 |
+
{/* Main Content Area */}
|
| 206 |
+
<div className={`flex items-center justify-center flex-grow overflow-hidden`}>
|
| 207 |
+
{/* Main Stage */}
|
| 208 |
+
<div
|
| 209 |
+
className={`${flexCenterClasses} relative w-full h-full max-h-[calc(100vh-300px)] px-8 ${
|
| 210 |
+
isDragging ? 'border-4 border-dashed border-blue-500 bg-blue-50' : ''
|
| 211 |
+
}`}
|
| 212 |
+
>
|
| 213 |
+
{image ? (
|
| 214 |
+
<>
|
| 215 |
+
<Tool
|
| 216 |
+
handleMouseMove={handleMouseMove}
|
| 217 |
+
descriptionState={descriptionState}
|
| 218 |
+
setDescriptionState={setDescriptionState}
|
| 219 |
+
queueStatus={queueStatus}
|
| 220 |
+
setQueueStatus={setQueueStatus}
|
| 221 |
+
/>
|
| 222 |
+
</>
|
| 223 |
+
) : (
|
| 224 |
+
<>
|
| 225 |
+
<div className="flex flex-col items-center gap-6 w-full h-full">
|
| 226 |
+
<div className="flex-1" />
|
| 227 |
+
|
| 228 |
+
<div className="text-gray-500 text-lg">
|
| 229 |
+
{isDragging ? 'Drop image here' : 'Upload your own image'}
|
| 230 |
+
</div>
|
| 231 |
+
<div className="flex gap-4 mb-8">
|
| 232 |
+
<input
|
| 233 |
+
type="file"
|
| 234 |
+
id="imageUpload"
|
| 235 |
+
accept="image/*"
|
| 236 |
+
onChange={onImageUpload}
|
| 237 |
+
className="hidden"
|
| 238 |
+
/>
|
| 239 |
+
<label
|
| 240 |
+
htmlFor="imageUpload"
|
| 241 |
+
className="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded cursor-pointer"
|
| 242 |
+
>
|
| 243 |
+
Upload Image
|
| 244 |
+
</label>
|
| 245 |
+
</div>
|
| 246 |
+
|
| 247 |
+
<div className="text-gray-500 text-lg">
|
| 248 |
+
or choose an example image below
|
| 249 |
+
</div>
|
| 250 |
+
|
| 251 |
+
<div className="relative w-full max-w-[2200px]">
|
| 252 |
+
{/* Left Arrow */}
|
| 253 |
+
<button
|
| 254 |
+
onClick={() => setCurrentPage(prev => Math.max(prev - 1, 1))}
|
| 255 |
+
disabled={currentPage === 1}
|
| 256 |
+
className={`absolute left-0 top-1/2 -translate-y-1/2 z-10 p-4 ${
|
| 257 |
+
currentPage === 1
|
| 258 |
+
? 'text-gray-300 cursor-not-allowed'
|
| 259 |
+
: 'text-gray-600 hover:text-gray-800'
|
| 260 |
+
}`}
|
| 261 |
+
>
|
| 262 |
+
<svg xmlns="http://www.w3.org/2000/svg" className="h-8 w-8" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
| 263 |
+
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M15 19l-7-7 7-7" />
|
| 264 |
+
</svg>
|
| 265 |
+
</button>
|
| 266 |
+
|
| 267 |
+
{/* Example Images */}
|
| 268 |
+
<div className="flex flex-wrap justify-center gap-8 px-16">
|
| 269 |
+
{EXAMPLE_IMAGES.slice(
|
| 270 |
+
(currentPage - 1) * imagesPerPage,
|
| 271 |
+
currentPage * imagesPerPage
|
| 272 |
+
).map((src, index) => (
|
| 273 |
+
<img
|
| 274 |
+
key={index}
|
| 275 |
+
src={src}
|
| 276 |
+
alt={`Example ${index + 1}`}
|
| 277 |
+
className="w-[200px] h-[150px] object-cover rounded-sm cursor-pointer hover:opacity-80 transition-opacity"
|
| 278 |
+
onClick={() => {
|
| 279 |
+
fetch(src)
|
| 280 |
+
.then(res => res.blob())
|
| 281 |
+
.then(blob => {
|
| 282 |
+
const file = new File([blob], `example-${index + 1}.jpg`, { type: 'image/jpeg' });
|
| 283 |
+
const syntheticEvent = {
|
| 284 |
+
target: {
|
| 285 |
+
files: [file]
|
| 286 |
+
}
|
| 287 |
+
} as unknown as React.ChangeEvent<HTMLInputElement>;
|
| 288 |
+
|
| 289 |
+
onImageUpload(syntheticEvent);
|
| 290 |
+
});
|
| 291 |
+
}}
|
| 292 |
+
/>
|
| 293 |
+
))}
|
| 294 |
+
</div>
|
| 295 |
+
|
| 296 |
+
{/* Right Arrow */}
|
| 297 |
+
<button
|
| 298 |
+
onClick={() => setCurrentPage(prev => Math.min(prev + 1, Math.ceil(EXAMPLE_IMAGES.length / imagesPerPage)))}
|
| 299 |
+
disabled={currentPage === Math.ceil(EXAMPLE_IMAGES.length / imagesPerPage)}
|
| 300 |
+
className={`absolute right-0 top-1/2 -translate-y-1/2 z-10 p-4 ${
|
| 301 |
+
currentPage === Math.ceil(EXAMPLE_IMAGES.length / imagesPerPage)
|
| 302 |
+
? 'text-gray-300 cursor-not-allowed'
|
| 303 |
+
: 'text-gray-600 hover:text-gray-800'
|
| 304 |
+
}`}
|
| 305 |
+
>
|
| 306 |
+
<svg xmlns="http://www.w3.org/2000/svg" className="h-8 w-8" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
| 307 |
+
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 5l7 7-7 7" />
|
| 308 |
+
</svg>
|
| 309 |
+
</button>
|
| 310 |
+
|
| 311 |
+
{/* Page Indicator */}
|
| 312 |
+
{/* <div className="w-full text-center mt-4 text-gray-600">
|
| 313 |
+
Page {currentPage} of {Math.ceil(EXAMPLE_IMAGES.length / imagesPerPage)}
|
| 314 |
+
</div> */}
|
| 315 |
+
</div>
|
| 316 |
+
|
| 317 |
+
<div className="flex-1" /> {/* Bottom spacer */}
|
| 318 |
+
{/* Image Credits */}
|
| 319 |
+
{!image && (
|
| 320 |
+
<div className="pl-5 pr-5 text-gray-500 text-sm">
|
| 321 |
+
Image credit for example images: {' '}
|
| 322 |
+
<a
|
| 323 |
+
href="https://segment-anything.com/terms"
|
| 324 |
+
target="_blank"
|
| 325 |
+
className="text-gray-600 hover:text-gray-800 underline"
|
| 326 |
+
>
|
| 327 |
+
Segment Anything Materials
|
| 328 |
+
</a>
|
| 329 |
+
{' '}(CC BY-SA 4.0)
|
| 330 |
+
</div>
|
| 331 |
+
)}
|
| 332 |
+
</div>
|
| 333 |
+
</>
|
| 334 |
+
)}
|
| 335 |
+
</div>
|
| 336 |
+
</div>
|
| 337 |
+
|
| 338 |
+
</div>
|
| 339 |
+
);
|
| 340 |
+
};
|
| 341 |
+
|
| 342 |
+
export default Stage;
|
| 343 |
+
export type { DescriptionState };
|
demo/gradio/frontend/src/components/Tool.tsx
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { useContext, useEffect, useState } from "react";
|
| 2 |
+
import AppContext from "./hooks/createContext";
|
| 3 |
+
import { ToolProps, QueueStatus } from "./helpers/Interfaces";
|
| 4 |
+
import * as _ from "underscore";
|
| 5 |
+
import { describeMask, describeMaskWithoutStreaming } from "../services/maskApi";
|
| 6 |
+
import ErrorModal from './ErrorModal';
|
| 7 |
+
import { DescriptionState } from "./Stage";
|
| 8 |
+
|
| 9 |
+
const prompt = "<image>\nDescribe the masked region in detail.";
|
| 10 |
+
|
| 11 |
+
const Tool = ({
|
| 12 |
+
handleMouseMove,
|
| 13 |
+
descriptionState,
|
| 14 |
+
setDescriptionState,
|
| 15 |
+
queueStatus,
|
| 16 |
+
setQueueStatus
|
| 17 |
+
}: ToolProps) => {
|
| 18 |
+
console.log("Tool handleMouseMove");
|
| 19 |
+
const {
|
| 20 |
+
image: [image],
|
| 21 |
+
maskImg: [maskImg, setMaskImg],
|
| 22 |
+
maskImgData: [maskImgData, setMaskImgData],
|
| 23 |
+
isClicked: [isClicked, setIsClicked]
|
| 24 |
+
} = useContext(AppContext)!;
|
| 25 |
+
|
| 26 |
+
const [shouldFitToWidth, setShouldFitToWidth] = useState(true);
|
| 27 |
+
const bodyEl = document.body;
|
| 28 |
+
const fitToPage = () => {
|
| 29 |
+
if (!image) return;
|
| 30 |
+
const maxWidth = window.innerWidth - 64; // Account for padding (32px on each side)
|
| 31 |
+
const maxHeight = window.innerHeight - 200; // Account for header and some padding
|
| 32 |
+
const imageAspectRatio = image.width / image.height;
|
| 33 |
+
const containerAspectRatio = maxWidth / maxHeight;
|
| 34 |
+
|
| 35 |
+
setShouldFitToWidth(
|
| 36 |
+
imageAspectRatio > containerAspectRatio ||
|
| 37 |
+
image.width > maxWidth
|
| 38 |
+
);
|
| 39 |
+
};
|
| 40 |
+
const resizeObserver = new ResizeObserver((entries) => {
|
| 41 |
+
for (const entry of entries) {
|
| 42 |
+
if (entry.target === bodyEl) {
|
| 43 |
+
fitToPage();
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
});
|
| 47 |
+
useEffect(() => {
|
| 48 |
+
fitToPage();
|
| 49 |
+
resizeObserver.observe(bodyEl);
|
| 50 |
+
return () => {
|
| 51 |
+
resizeObserver.unobserve(bodyEl);
|
| 52 |
+
};
|
| 53 |
+
}, [image]);
|
| 54 |
+
|
| 55 |
+
const imageClasses = "";
|
| 56 |
+
const maskImageClasses = `absolute opacity-40 pointer-events-none`;
|
| 57 |
+
|
| 58 |
+
const [error, setError] = useState<string | null>(null);
|
| 59 |
+
const [useStreaming, setUseStreaming] = useState(true);
|
| 60 |
+
|
| 61 |
+
useEffect(() => {
|
| 62 |
+
if (!isClicked || !maskImg || !maskImgData || !image || descriptionState.state !== 'ready') {
|
| 63 |
+
console.log("Not ready to call model, isClicked:", isClicked, "maskImg:", maskImg !== null, "maskImgData:", maskImgData !== null, "image:", image !== null, "descriptionState.state:", descriptionState.state);
|
| 64 |
+
return;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
try {
|
| 68 |
+
setDescriptionState({
|
| 69 |
+
state: 'describing',
|
| 70 |
+
description: ''
|
| 71 |
+
} as DescriptionState);
|
| 72 |
+
|
| 73 |
+
const canvas = document.createElement('canvas');
|
| 74 |
+
canvas.width = image.width;
|
| 75 |
+
canvas.height = image.height;
|
| 76 |
+
const ctx = canvas.getContext('2d');
|
| 77 |
+
ctx?.drawImage(image, 0, 0);
|
| 78 |
+
const imageBase64 = canvas.toDataURL('image/jpeg').split(',')[1];
|
| 79 |
+
const maskBase64 = maskImgData.split(',')[1];
|
| 80 |
+
|
| 81 |
+
const describeMaskWithFallback = async (useStreamingInFunction: boolean) => {
|
| 82 |
+
try {
|
| 83 |
+
let result;
|
| 84 |
+
console.log("useStreaming", useStreaming, "useStreamingInFunction", useStreamingInFunction);
|
| 85 |
+
if (useStreamingInFunction) {
|
| 86 |
+
result = await describeMask(
|
| 87 |
+
maskBase64,
|
| 88 |
+
imageBase64,
|
| 89 |
+
prompt,
|
| 90 |
+
(streamResult: string) => {
|
| 91 |
+
setDescriptionState({
|
| 92 |
+
state: 'describing',
|
| 93 |
+
description: streamResult
|
| 94 |
+
} as DescriptionState);
|
| 95 |
+
},
|
| 96 |
+
(status: QueueStatus) => {
|
| 97 |
+
setQueueStatus(status);
|
| 98 |
+
}
|
| 99 |
+
);
|
| 100 |
+
} else {
|
| 101 |
+
result = await describeMaskWithoutStreaming(
|
| 102 |
+
maskBase64,
|
| 103 |
+
imageBase64,
|
| 104 |
+
prompt
|
| 105 |
+
);
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
setDescriptionState({
|
| 109 |
+
state: 'described',
|
| 110 |
+
description: result
|
| 111 |
+
} as DescriptionState);
|
| 112 |
+
setQueueStatus({ inQueue: false });
|
| 113 |
+
setIsClicked(false);
|
| 114 |
+
} catch (error) {
|
| 115 |
+
if (useStreaming) {
|
| 116 |
+
console.log("Error describing mask, switching to non-streaming", error);
|
| 117 |
+
setUseStreaming(false);
|
| 118 |
+
describeMaskWithFallback(false);
|
| 119 |
+
} else {
|
| 120 |
+
setError('Failed to generate description. Please try again.');
|
| 121 |
+
setDescriptionState({
|
| 122 |
+
state: 'ready',
|
| 123 |
+
description: ''
|
| 124 |
+
} as DescriptionState);
|
| 125 |
+
setIsClicked(false);
|
| 126 |
+
console.error('Failed to describe mask:', error);
|
| 127 |
+
}
|
| 128 |
+
}
|
| 129 |
+
};
|
| 130 |
+
|
| 131 |
+
describeMaskWithFallback(useStreaming);
|
| 132 |
+
|
| 133 |
+
} catch (error) {
|
| 134 |
+
setIsClicked(false);
|
| 135 |
+
setError('Failed to generate description. Please try again.');
|
| 136 |
+
setDescriptionState({
|
| 137 |
+
state: 'ready',
|
| 138 |
+
description: ''
|
| 139 |
+
} as DescriptionState);
|
| 140 |
+
console.error('Failed to describe mask:', error);
|
| 141 |
+
}
|
| 142 |
+
}, [maskImgData]);
|
| 143 |
+
|
| 144 |
+
const handleClick = async (e: React.MouseEvent<HTMLImageElement>) => {
|
| 145 |
+
if (descriptionState.state !== 'ready') return;
|
| 146 |
+
|
| 147 |
+
setMaskImg(null);
|
| 148 |
+
setMaskImgData(null);
|
| 149 |
+
setIsClicked(true);
|
| 150 |
+
handleMouseMove(e);
|
| 151 |
+
};
|
| 152 |
+
|
| 153 |
+
return (
|
| 154 |
+
<>
|
| 155 |
+
{error && <ErrorModal message={error} onClose={() => setError(null)} />}
|
| 156 |
+
<div className="relative flex items-center justify-center w-full h-full">
|
| 157 |
+
{image && (
|
| 158 |
+
<img
|
| 159 |
+
onMouseMove={handleMouseMove}
|
| 160 |
+
onMouseLeave={() => _.defer(() => (descriptionState.state === 'ready' && !isClicked) ? setMaskImg(null) : undefined)}
|
| 161 |
+
onTouchStart={handleMouseMove}
|
| 162 |
+
onClick={handleClick}
|
| 163 |
+
src={image.src}
|
| 164 |
+
className={`${
|
| 165 |
+
shouldFitToWidth ? "w-full" : "h-full"
|
| 166 |
+
} ${imageClasses} object-contain max-h-full max-w-full`}
|
| 167 |
+
></img>
|
| 168 |
+
)}
|
| 169 |
+
{maskImg && (
|
| 170 |
+
<img
|
| 171 |
+
src={maskImg.src}
|
| 172 |
+
className={`${
|
| 173 |
+
shouldFitToWidth ? "w-full" : "h-full"
|
| 174 |
+
} ${maskImageClasses} object-contain max-h-full max-w-full`}
|
| 175 |
+
></img>
|
| 176 |
+
)}
|
| 177 |
+
</div>
|
| 178 |
+
</>
|
| 179 |
+
);
|
| 180 |
+
};
|
| 181 |
+
|
| 182 |
+
export default Tool;
|
demo/gradio/frontend/src/components/helpers/Interfaces.tsx
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import { Tensor } from "onnxruntime-web";
|
| 8 |
+
import { DescriptionState } from "../Stage";
|
| 9 |
+
|
| 10 |
+
export interface modelScaleProps {
|
| 11 |
+
samScale: number;
|
| 12 |
+
height: number;
|
| 13 |
+
width: number;
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
export interface modelInputProps {
|
| 17 |
+
x: number;
|
| 18 |
+
y: number;
|
| 19 |
+
clickType: number;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
export interface modeDataProps {
|
| 23 |
+
clicks?: Array<modelInputProps>;
|
| 24 |
+
tensor: Tensor;
|
| 25 |
+
modelScale: modelScaleProps;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
export interface ToolProps {
|
| 29 |
+
handleMouseMove: (e: any) => void;
|
| 30 |
+
descriptionState: DescriptionState;
|
| 31 |
+
setDescriptionState: (value: DescriptionState) => void;
|
| 32 |
+
queueStatus: QueueStatus;
|
| 33 |
+
setQueueStatus: (value: QueueStatus) => void;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
export interface StageProps {
|
| 37 |
+
onImageUpload: (event: React.ChangeEvent<HTMLInputElement>) => void;
|
| 38 |
+
descriptionState: DescriptionState;
|
| 39 |
+
setDescriptionState: (value: DescriptionState) => void;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
export interface QueueStatus {
|
| 43 |
+
inQueue: boolean;
|
| 44 |
+
rank?: number;
|
| 45 |
+
queueSize?: number;
|
| 46 |
+
rankEta?: number | null;
|
| 47 |
+
}
|
demo/gradio/frontend/src/components/helpers/imageUtils.tsx
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { Buffer } from 'buffer';
|
| 2 |
+
|
| 3 |
+
export const base64ToImage = async (base64String: string): Promise<HTMLImageElement> => {
|
| 4 |
+
return new Promise((resolve, reject) => {
|
| 5 |
+
const img = new Image();
|
| 6 |
+
img.onload = () => resolve(img);
|
| 7 |
+
img.onerror = reject;
|
| 8 |
+
img.src = base64String.startsWith('data:') ?
|
| 9 |
+
base64String :
|
| 10 |
+
`data:image/png;base64,${base64String}`;
|
| 11 |
+
});
|
| 12 |
+
};
|
| 13 |
+
|
| 14 |
+
export const imageToBase64 = (img: HTMLImageElement): string => {
|
| 15 |
+
const canvas = document.createElement('canvas');
|
| 16 |
+
canvas.width = img.width;
|
| 17 |
+
canvas.height = img.height;
|
| 18 |
+
const ctx = canvas.getContext('2d');
|
| 19 |
+
ctx?.drawImage(img, 0, 0);
|
| 20 |
+
return canvas.toDataURL('image/png');
|
| 21 |
+
};
|
demo/gradio/frontend/src/components/helpers/maskUtils.tsx
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
// Convert the onnx model mask prediction to ImageData
|
| 8 |
+
function arrayToImageData(input: any, width: number, height: number, binary: boolean) {
|
| 9 |
+
let [r, g, b, a] = [0, 114, 189, 255]; // the masks's blue color
|
| 10 |
+
let [r_bg, g_bg, b_bg, a_bg] = [0, 0, 0, 0]; // the background's white color
|
| 11 |
+
if (binary) {
|
| 12 |
+
[r, g, b, a] = [255, 255, 255, 255]; // black and white
|
| 13 |
+
[r_bg, g_bg, b_bg, a_bg] = [0, 0, 0, 255]; // black and white
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
const arr = new Uint8ClampedArray(4 * width * height).fill(0);
|
| 17 |
+
for (let i = 0; i < input.length; i++) {
|
| 18 |
+
|
| 19 |
+
// Threshold the onnx model mask prediction at 0.0
|
| 20 |
+
// This is equivalent to thresholding the mask using predictor.model.mask_threshold
|
| 21 |
+
// in python
|
| 22 |
+
if (input[i] > 0.0) {
|
| 23 |
+
arr[4 * i + 0] = r;
|
| 24 |
+
arr[4 * i + 1] = g;
|
| 25 |
+
arr[4 * i + 2] = b;
|
| 26 |
+
arr[4 * i + 3] = a;
|
| 27 |
+
} else if (binary){
|
| 28 |
+
arr[4 * i + 0] = r_bg;
|
| 29 |
+
arr[4 * i + 1] = g_bg;
|
| 30 |
+
arr[4 * i + 2] = b_bg;
|
| 31 |
+
arr[4 * i + 3] = a_bg;
|
| 32 |
+
}
|
| 33 |
+
}
|
| 34 |
+
return new ImageData(arr, height, width);
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// Use a Canvas element to produce an image from ImageData
|
| 38 |
+
function imageDataToImage(imageData: ImageData) {
|
| 39 |
+
const canvas = imageDataToCanvas(imageData);
|
| 40 |
+
const image = new Image();
|
| 41 |
+
image.src = canvas.toDataURL();
|
| 42 |
+
return image;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
function imageDataToURL(imageData: ImageData) {
|
| 46 |
+
const canvas = imageDataToCanvas(imageData);
|
| 47 |
+
return canvas.toDataURL();
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
// Canvas elements can be created from ImageData
|
| 51 |
+
function imageDataToCanvas(imageData: ImageData) {
|
| 52 |
+
const canvas = document.createElement("canvas");
|
| 53 |
+
const ctx = canvas.getContext("2d");
|
| 54 |
+
canvas.width = imageData.width;
|
| 55 |
+
canvas.height = imageData.height;
|
| 56 |
+
ctx?.putImageData(imageData, 0, 0);
|
| 57 |
+
return canvas;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
// Convert the onnx model mask output to an HTMLImageElement
|
| 61 |
+
function onnxMaskToImage(input: any, width: number, height: number, binary: boolean) {
|
| 62 |
+
return imageDataToImage(arrayToImageData(input, width, height, binary));
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
export { arrayToImageData, imageDataToImage, onnxMaskToImage, imageDataToURL };
|
demo/gradio/frontend/src/components/helpers/onnxModelAPI.tsx
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import { Tensor } from "onnxruntime-web";
|
| 8 |
+
import { modeDataProps } from "./Interfaces";
|
| 9 |
+
|
| 10 |
+
const modelData = ({ clicks, tensor, modelScale }: modeDataProps) => {
|
| 11 |
+
const imageEmbedding = tensor;
|
| 12 |
+
let pointCoords;
|
| 13 |
+
let pointLabels;
|
| 14 |
+
let pointCoordsTensor;
|
| 15 |
+
let pointLabelsTensor;
|
| 16 |
+
|
| 17 |
+
// Check there are input click prompts
|
| 18 |
+
if (clicks) {
|
| 19 |
+
let n = clicks.length;
|
| 20 |
+
|
| 21 |
+
// If there is no box input, a single padding point with
|
| 22 |
+
// label -1 and coordinates (0.0, 0.0) should be concatenated
|
| 23 |
+
// so initialize the array to support (n + 1) points.
|
| 24 |
+
pointCoords = new Float32Array(2 * (n + 1));
|
| 25 |
+
pointLabels = new Float32Array(n + 1);
|
| 26 |
+
|
| 27 |
+
// Add clicks and scale to what SAM expects
|
| 28 |
+
for (let i = 0; i < n; i++) {
|
| 29 |
+
pointCoords[2 * i] = clicks[i].x * modelScale.samScale;
|
| 30 |
+
pointCoords[2 * i + 1] = clicks[i].y * modelScale.samScale;
|
| 31 |
+
pointLabels[i] = clicks[i].clickType;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
// Add in the extra point/label when only clicks and no box
|
| 35 |
+
// The extra point is at (0, 0) with label -1
|
| 36 |
+
pointCoords[2 * n] = 0.0;
|
| 37 |
+
pointCoords[2 * n + 1] = 0.0;
|
| 38 |
+
pointLabels[n] = -1.0;
|
| 39 |
+
|
| 40 |
+
// Create the tensor
|
| 41 |
+
pointCoordsTensor = new Tensor("float32", pointCoords, [1, n + 1, 2]);
|
| 42 |
+
pointLabelsTensor = new Tensor("float32", pointLabels, [1, n + 1]);
|
| 43 |
+
}
|
| 44 |
+
const imageSizeTensor = new Tensor("float32", [
|
| 45 |
+
modelScale.height,
|
| 46 |
+
modelScale.width,
|
| 47 |
+
]);
|
| 48 |
+
|
| 49 |
+
if (pointCoordsTensor === undefined || pointLabelsTensor === undefined)
|
| 50 |
+
return;
|
| 51 |
+
|
| 52 |
+
// There is no previous mask, so default to an empty tensor
|
| 53 |
+
const maskInput = new Tensor(
|
| 54 |
+
"float32",
|
| 55 |
+
new Float32Array(256 * 256),
|
| 56 |
+
[1, 1, 256, 256]
|
| 57 |
+
);
|
| 58 |
+
// There is no previous mask, so default to 0
|
| 59 |
+
const hasMaskInput = new Tensor("float32", [0]);
|
| 60 |
+
|
| 61 |
+
return {
|
| 62 |
+
image_embeddings: imageEmbedding,
|
| 63 |
+
point_coords: pointCoordsTensor,
|
| 64 |
+
point_labels: pointLabelsTensor,
|
| 65 |
+
orig_im_size: imageSizeTensor,
|
| 66 |
+
mask_input: maskInput,
|
| 67 |
+
has_mask_input: hasMaskInput,
|
| 68 |
+
};
|
| 69 |
+
};
|
| 70 |
+
|
| 71 |
+
export { modelData };
|
demo/gradio/frontend/src/components/helpers/scaleHelper.tsx
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
// Helper function for handling image scaling needed for SAM
|
| 9 |
+
const handleImageScale = (image: HTMLImageElement) => {
|
| 10 |
+
// Input images to SAM must be resized so the longest side is 1024
|
| 11 |
+
const LONG_SIDE_LENGTH = 1024;
|
| 12 |
+
let w = image.naturalWidth;
|
| 13 |
+
let h = image.naturalHeight;
|
| 14 |
+
const samScale = LONG_SIDE_LENGTH / Math.max(h, w);
|
| 15 |
+
return { height: h, width: w, samScale };
|
| 16 |
+
};
|
| 17 |
+
|
| 18 |
+
export { handleImageScale };
|
demo/gradio/frontend/src/components/hooks/context.tsx
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import React, { useState } from "react";
|
| 8 |
+
import { modelInputProps } from "../helpers/Interfaces";
|
| 9 |
+
import AppContext from "./createContext";
|
| 10 |
+
|
| 11 |
+
const AppContextProvider = (props: {
|
| 12 |
+
children: React.ReactElement<any, string | React.JSXElementConstructor<any>>;
|
| 13 |
+
}) => {
|
| 14 |
+
const [clicks, setClicks] = useState<Array<modelInputProps> | null>(null);
|
| 15 |
+
const [image, setImage] = useState<HTMLImageElement | null>(null);
|
| 16 |
+
const [maskImg, setMaskImg] = useState<HTMLImageElement | null>(null);
|
| 17 |
+
const [maskImgData, setMaskImgData] = useState<string | null>(null);
|
| 18 |
+
const [isClicked, setIsClicked] = useState<boolean>(false);
|
| 19 |
+
|
| 20 |
+
return (
|
| 21 |
+
<AppContext.Provider
|
| 22 |
+
value={{
|
| 23 |
+
clicks: [clicks, setClicks],
|
| 24 |
+
image: [image, setImage],
|
| 25 |
+
maskImg: [maskImg, setMaskImg],
|
| 26 |
+
maskImgData: [maskImgData, setMaskImgData],
|
| 27 |
+
isClicked: [isClicked, setIsClicked],
|
| 28 |
+
}}
|
| 29 |
+
>
|
| 30 |
+
{props.children}
|
| 31 |
+
</AppContext.Provider>
|
| 32 |
+
);
|
| 33 |
+
};
|
| 34 |
+
|
| 35 |
+
export default AppContextProvider;
|
demo/gradio/frontend/src/components/hooks/createContext.tsx
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import { createContext } from "react";
|
| 8 |
+
import { modelInputProps } from "../helpers/Interfaces";
|
| 9 |
+
|
| 10 |
+
interface contextProps {
|
| 11 |
+
clicks: [
|
| 12 |
+
clicks: modelInputProps[] | null,
|
| 13 |
+
setClicks: (e: modelInputProps[] | null) => void
|
| 14 |
+
];
|
| 15 |
+
image: [
|
| 16 |
+
image: HTMLImageElement | null,
|
| 17 |
+
setImage: (e: HTMLImageElement | null) => void
|
| 18 |
+
];
|
| 19 |
+
maskImg: [
|
| 20 |
+
maskImg: HTMLImageElement | null,
|
| 21 |
+
setMaskImg: (e: HTMLImageElement | null) => void
|
| 22 |
+
];
|
| 23 |
+
maskImgData: [
|
| 24 |
+
maskImgData: string | null,
|
| 25 |
+
setMaskImgData: (e: string | null) => void
|
| 26 |
+
];
|
| 27 |
+
isClicked: [
|
| 28 |
+
isClicked: boolean,
|
| 29 |
+
setIsClicked: (e: boolean) => void
|
| 30 |
+
];
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
const AppContext = createContext<contextProps | null>(null);
|
| 34 |
+
|
| 35 |
+
export default AppContext;
|
demo/gradio/frontend/src/index.tsx
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import * as React from "react";
|
| 8 |
+
import { createRoot } from "react-dom/client";
|
| 9 |
+
import AppContextProvider from "./components/hooks/context";
|
| 10 |
+
import App from "./App";
|
| 11 |
+
const container = document.getElementById("root");
|
| 12 |
+
const root = createRoot(container!);
|
| 13 |
+
root.render(
|
| 14 |
+
<AppContextProvider>
|
| 15 |
+
<App/>
|
| 16 |
+
</AppContextProvider>
|
| 17 |
+
);
|
demo/gradio/frontend/src/services/maskApi.tsx
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import axios from 'axios';
|
| 2 |
+
import * as _ from 'underscore';
|
| 3 |
+
|
| 4 |
+
const API_URL = process.env.NODE_ENV === 'development' ? 'http://localhost:7860/gradio_api' : '/gradio_api';
|
| 5 |
+
|
| 6 |
+
export const describeMaskWithoutStreaming = _.throttle(async (
|
| 7 |
+
maskBase64: string,
|
| 8 |
+
imageBase64: string,
|
| 9 |
+
query: string
|
| 10 |
+
): Promise<string> => {
|
| 11 |
+
try {
|
| 12 |
+
const response = await axios.post(`${API_URL}/run/describe_without_streaming`, {
|
| 13 |
+
data: [imageBase64, maskBase64, query],
|
| 14 |
+
});
|
| 15 |
+
|
| 16 |
+
console.log("response", response.data);
|
| 17 |
+
return response.data.data[0];
|
| 18 |
+
} catch (error) {
|
| 19 |
+
console.error('Error describing mask:', error);
|
| 20 |
+
throw error;
|
| 21 |
+
}
|
| 22 |
+
}, 100);
|
| 23 |
+
|
| 24 |
+
export const describeMask = _.throttle(async (
|
| 25 |
+
maskBase64: string,
|
| 26 |
+
imageBase64: string,
|
| 27 |
+
query: string,
|
| 28 |
+
onStreamUpdate: (token: string) => void,
|
| 29 |
+
onQueueUpdate?: (status: {
|
| 30 |
+
inQueue: boolean,
|
| 31 |
+
rank?: number,
|
| 32 |
+
queueSize?: number,
|
| 33 |
+
rankEta?: number | null
|
| 34 |
+
}) => void
|
| 35 |
+
): Promise<string> => {
|
| 36 |
+
console.log("describeMask");
|
| 37 |
+
const initiateResponse = await axios.post(`${API_URL}/call/describe`, {
|
| 38 |
+
data: [imageBase64, maskBase64, query],
|
| 39 |
+
});
|
| 40 |
+
|
| 41 |
+
const eventId = initiateResponse.data.event_id;
|
| 42 |
+
|
| 43 |
+
const response = await axios.get(`${API_URL}/queue/data?session_hash=${eventId}`, {
|
| 44 |
+
headers: {
|
| 45 |
+
'Accept': 'text/event-stream',
|
| 46 |
+
},
|
| 47 |
+
responseType: 'stream',
|
| 48 |
+
adapter: 'fetch',
|
| 49 |
+
});
|
| 50 |
+
|
| 51 |
+
const stream = response.data;
|
| 52 |
+
const reader = stream.pipeThrough(new TextDecoderStream()).getReader();
|
| 53 |
+
|
| 54 |
+
let result = '';
|
| 55 |
+
let partialMessage = '';
|
| 56 |
+
|
| 57 |
+
while (true) {
|
| 58 |
+
const { value, done } = await reader.read();
|
| 59 |
+
if (done) {
|
| 60 |
+
return result;
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
// Concatenate with any previous partial message
|
| 64 |
+
const currentData = partialMessage + value;
|
| 65 |
+
const lines = currentData.split('\n');
|
| 66 |
+
|
| 67 |
+
// Save the last line if it's incomplete
|
| 68 |
+
partialMessage = lines[lines.length - 1];
|
| 69 |
+
|
| 70 |
+
// Process all complete lines except the last one
|
| 71 |
+
let eventType = '';
|
| 72 |
+
for (let i = 0; i < lines.length - 1; i++) {
|
| 73 |
+
const line = lines[i];
|
| 74 |
+
if (line.startsWith('event: ')) {
|
| 75 |
+
eventType = line.slice(7); // Remove 'event: ' prefix
|
| 76 |
+
console.log('Event message', line);
|
| 77 |
+
} else if (line.startsWith('data: ')) {
|
| 78 |
+
const eventData = line.slice(6); // Remove 'data: ' prefix
|
| 79 |
+
try {
|
| 80 |
+
let data = JSON.parse(eventData);
|
| 81 |
+
if (data['msg']) {
|
| 82 |
+
eventType = data['msg'];
|
| 83 |
+
if (eventType === 'process_generating') {
|
| 84 |
+
eventType = 'generating';
|
| 85 |
+
data = data['output']['data'];
|
| 86 |
+
} else if (eventType === 'process_completed') {
|
| 87 |
+
eventType = 'complete';
|
| 88 |
+
data = data['output']['data'];
|
| 89 |
+
}
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
if (eventType === 'estimation' && onQueueUpdate) {
|
| 93 |
+
onQueueUpdate({
|
| 94 |
+
inQueue: true,
|
| 95 |
+
rank: data.rank,
|
| 96 |
+
queueSize: data.queue_size,
|
| 97 |
+
rankEta: data.rank_eta
|
| 98 |
+
});
|
| 99 |
+
} else if (eventType === 'process_starts' && onQueueUpdate) {
|
| 100 |
+
onQueueUpdate({
|
| 101 |
+
inQueue: false
|
| 102 |
+
});
|
| 103 |
+
} else if ((eventType === 'generating' || eventType === 'complete') && data[0]) {
|
| 104 |
+
result = data[0];
|
| 105 |
+
onStreamUpdate(data[0]);
|
| 106 |
+
|
| 107 |
+
if (eventType === 'complete') {
|
| 108 |
+
return result;
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
} catch (e) {
|
| 112 |
+
console.log('Error parsing SSE message:', e);
|
| 113 |
+
}
|
| 114 |
+
} else if (line !== '') {
|
| 115 |
+
console.log('Unknown message', line);
|
| 116 |
+
}
|
| 117 |
+
}
|
| 118 |
+
}
|
| 119 |
+
}, 100);
|
| 120 |
+
|
| 121 |
+
export const imageToSamEmbedding = _.throttle(async (
|
| 122 |
+
imageBase64: string,
|
| 123 |
+
onQueueUpdate?: (status: {
|
| 124 |
+
inQueue: boolean,
|
| 125 |
+
rank?: number,
|
| 126 |
+
queueSize?: number,
|
| 127 |
+
rankEta?: number | null
|
| 128 |
+
}) => void
|
| 129 |
+
): Promise<string> => {
|
| 130 |
+
// First call to initiate the process
|
| 131 |
+
const initiateResponse = await axios.post(`${API_URL}/call/image_to_sam_embedding`, {
|
| 132 |
+
data: [imageBase64]
|
| 133 |
+
});
|
| 134 |
+
|
| 135 |
+
const eventId = initiateResponse.data.event_id;
|
| 136 |
+
|
| 137 |
+
// Get the stream for queue updates and results
|
| 138 |
+
const response = await axios.get(`${API_URL}/queue/data?session_hash=${eventId}`, {
|
| 139 |
+
headers: {
|
| 140 |
+
'Accept': 'text/event-stream',
|
| 141 |
+
},
|
| 142 |
+
responseType: 'stream',
|
| 143 |
+
adapter: 'fetch',
|
| 144 |
+
});
|
| 145 |
+
|
| 146 |
+
const stream = response.data;
|
| 147 |
+
const reader = stream.pipeThrough(new TextDecoderStream()).getReader();
|
| 148 |
+
|
| 149 |
+
let result = '';
|
| 150 |
+
let partialMessage = '';
|
| 151 |
+
|
| 152 |
+
while (true) {
|
| 153 |
+
const { value, done } = await reader.read();
|
| 154 |
+
if (done) {
|
| 155 |
+
return result;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
// Concatenate with any previous partial message
|
| 159 |
+
const currentData = partialMessage + value;
|
| 160 |
+
const lines = currentData.split('\n');
|
| 161 |
+
|
| 162 |
+
// Save the last line if it's incomplete (doesn't end with \n)
|
| 163 |
+
// The endpoint will send an empty line to indicate the end of a message, so it's ok to not process the partial message.
|
| 164 |
+
partialMessage = lines[lines.length - 1];
|
| 165 |
+
|
| 166 |
+
// Process all complete lines except the last one
|
| 167 |
+
let eventType = '';
|
| 168 |
+
for (let i = 0; i < lines.length - 1; i++) {
|
| 169 |
+
const line = lines[i];
|
| 170 |
+
if (line.startsWith('event: ')) {
|
| 171 |
+
eventType = line.slice(7);
|
| 172 |
+
} else if (line.startsWith('data: ')) {
|
| 173 |
+
const eventData = line.slice(6);
|
| 174 |
+
try {
|
| 175 |
+
let data = JSON.parse(eventData);
|
| 176 |
+
if (data['msg']) {
|
| 177 |
+
eventType = data['msg'];
|
| 178 |
+
console.log("Event type:", eventType);
|
| 179 |
+
if (eventType === 'process_completed') {
|
| 180 |
+
eventType = 'complete';
|
| 181 |
+
data = data['output']['data'];
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
if (eventType === 'estimation' && onQueueUpdate) {
|
| 186 |
+
onQueueUpdate({
|
| 187 |
+
inQueue: true,
|
| 188 |
+
rank: data.rank,
|
| 189 |
+
queueSize: data.queue_size,
|
| 190 |
+
rankEta: data.rank_eta
|
| 191 |
+
});
|
| 192 |
+
} else if (eventType === 'process_starts' && onQueueUpdate) {
|
| 193 |
+
onQueueUpdate({
|
| 194 |
+
inQueue: false
|
| 195 |
+
});
|
| 196 |
+
} else if (eventType === 'complete' && data[0]) {
|
| 197 |
+
result = data[0];
|
| 198 |
+
console.log("Result for image to sam embedding:", result);
|
| 199 |
+
return result;
|
| 200 |
+
} else {
|
| 201 |
+
console.log("Unknown event type:", eventType);
|
| 202 |
+
}
|
| 203 |
+
} catch (e) {
|
| 204 |
+
console.log('Error parsing SSE message:', e, 'Raw data:', eventData);
|
| 205 |
+
}
|
| 206 |
+
}
|
| 207 |
+
}
|
| 208 |
+
}
|
| 209 |
+
}, 100);
|
| 210 |
+
|
| 211 |
+
export { API_URL };
|
demo/gradio/frontend/tailwind.config.js
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
/** @type {import('tailwindcss').Config} */
|
| 8 |
+
module.exports = {
|
| 9 |
+
content: ["./src/**/*.{html,js,tsx}"],
|
| 10 |
+
theme: {},
|
| 11 |
+
plugins: [],
|
| 12 |
+
};
|
demo/gradio/frontend/tsconfig.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"compilerOptions": {
|
| 3 |
+
"lib": ["dom", "dom.iterable", "esnext"],
|
| 4 |
+
"allowJs": true,
|
| 5 |
+
"skipLibCheck": true,
|
| 6 |
+
"strict": true,
|
| 7 |
+
"forceConsistentCasingInFileNames": true,
|
| 8 |
+
"noEmit": false,
|
| 9 |
+
"esModuleInterop": true,
|
| 10 |
+
"module": "esnext",
|
| 11 |
+
"moduleResolution": "node",
|
| 12 |
+
"resolveJsonModule": true,
|
| 13 |
+
"isolatedModules": true,
|
| 14 |
+
"jsx": "react",
|
| 15 |
+
"incremental": true,
|
| 16 |
+
"target": "ESNext",
|
| 17 |
+
"useDefineForClassFields": true,
|
| 18 |
+
"allowSyntheticDefaultImports": true,
|
| 19 |
+
"outDir": "./dist/",
|
| 20 |
+
"sourceMap": true
|
| 21 |
+
},
|
| 22 |
+
"include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", "src"],
|
| 23 |
+
"exclude": ["node_modules"]
|
| 24 |
+
}
|
demo/gradio/frontend/yarn.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
demo/gradio/requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
sentencepiece
|
| 2 |
+
accelerate>=0.28.0
|
| 3 |
+
pydantic>=2.10.1
|
| 4 |
+
numpy>=1.23.5,<2.0.0
|
| 5 |
+
pillow>=9.4.0
|
| 6 |
+
gradio>=5.5.0
|
| 7 |
+
requests
|
| 8 |
+
httpx
|
| 9 |
+
uvicorn
|
| 10 |
+
fastapi
|
| 11 |
+
protobuf
|
| 12 |
+
opencv-python
|
| 13 |
+
openai>=1.55.0
|
| 14 |
+
spaces==0.30.4
|
| 15 |
+
git+https://github.com/facebookresearch/segment-anything.git
|
evaluation/DLC-Bench/annotations/annotations.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
evaluation/DLC-Bench/annotations/class_names.json
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"2391781": "wild bird",
|
| 3 |
+
"2580323": "picture/frame",
|
| 4 |
+
"4782942": "megaphone/speaker",
|
| 5 |
+
"6037269": "showerhead",
|
| 6 |
+
"7050495": "handbag",
|
| 7 |
+
"8331699": "computer box",
|
| 8 |
+
"8556676": "apple",
|
| 9 |
+
"11012500": "taco",
|
| 10 |
+
"12348080": "scissors",
|
| 11 |
+
"16951734": "potato",
|
| 12 |
+
"17265254": "rickshaw",
|
| 13 |
+
"18845103": "spoon",
|
| 14 |
+
"20993402": "tape",
|
| 15 |
+
"21529954": "can/container",
|
| 16 |
+
"22879790": "garlic",
|
| 17 |
+
"24010373": "guitar",
|
| 18 |
+
"24694197": "avocado",
|
| 19 |
+
"279135": "ski",
|
| 20 |
+
"622329": "eraser",
|
| 21 |
+
"622332": "stapler",
|
| 22 |
+
"1075308": "monitor/tv",
|
| 23 |
+
"1770866": "sign/banner",
|
| 24 |
+
"2391761": "boat",
|
| 25 |
+
"2580318": "mouse",
|
| 26 |
+
"2588513": "wood block",
|
| 27 |
+
"3993075": "marker",
|
| 28 |
+
"4027486": "truck",
|
| 29 |
+
"4243725": "soap",
|
| 30 |
+
"4781902": "stool",
|
| 31 |
+
"4782949": "drum",
|
| 32 |
+
"5211280": "rice cooker",
|
| 33 |
+
"5718392": "storage box",
|
| 34 |
+
"6037272": "bottle",
|
| 35 |
+
"6820594": "cat",
|
| 36 |
+
"5718424": "sneakers",
|
| 37 |
+
"6055310": "tape measure/ruler",
|
| 38 |
+
"8201777": "van",
|
| 39 |
+
"8331685": "headphone",
|
| 40 |
+
"8331718": "notebook",
|
| 41 |
+
"8557176": "watch",
|
| 42 |
+
"8557195": "toaster",
|
| 43 |
+
"9766617": "duck/goose",
|
| 44 |
+
"11021544": "faucet",
|
| 45 |
+
"11775390": "sandals",
|
| 46 |
+
"11950619": "table tennis paddle",
|
| 47 |
+
"12178946": "bottle",
|
| 48 |
+
"12348079": "scale",
|
| 49 |
+
"14832137": "barrel/bucket",
|
| 50 |
+
"15050320": "wine glass",
|
| 51 |
+
"16957916": "lettuce",
|
| 52 |
+
"17385866": "ice cream",
|
| 53 |
+
"17404769": "suv",
|
| 54 |
+
"18217373": "glasses",
|
| 55 |
+
"19455186": "cart/trolley",
|
| 56 |
+
"19610023": "slippers",
|
| 57 |
+
"19610025": "rabbit",
|
| 58 |
+
"20568676": "pot",
|
| 59 |
+
"21107974": "gavel/mallet",
|
| 60 |
+
"22064315": "antelope",
|
| 61 |
+
"22107522": "bow tie",
|
| 62 |
+
"24017816": "car",
|
| 63 |
+
"24498027": "street lights",
|
| 64 |
+
"24581953": "dog",
|
| 65 |
+
"24786060": "towel",
|
| 66 |
+
"25054869": "toilet",
|
| 67 |
+
"25273553": "tripod",
|
| 68 |
+
"25419495": "tong",
|
| 69 |
+
"25419516": "stuffed toy",
|
| 70 |
+
"25579493": "bowl",
|
| 71 |
+
"297718": "sushi",
|
| 72 |
+
"361105": "herb",
|
| 73 |
+
"1196168": "air conditioner",
|
| 74 |
+
"1894089": "screwdriver",
|
| 75 |
+
"2391780": "wild bird",
|
| 76 |
+
"4502267": "green bean",
|
| 77 |
+
"4604873": "crane",
|
| 78 |
+
"4916799": "globe",
|
| 79 |
+
"5718415": "tent",
|
| 80 |
+
"6012878": "traffic light",
|
| 81 |
+
"6820595": "cat",
|
| 82 |
+
"8556674": "orange/tangerine",
|
| 83 |
+
"8906172": "earphone",
|
| 84 |
+
"10666665": "clock",
|
| 85 |
+
"10811497": "key",
|
| 86 |
+
"11021562": "microwave",
|
| 87 |
+
"11021563": "stove",
|
| 88 |
+
"12348078": "person",
|
| 89 |
+
"13138178": "stool",
|
| 90 |
+
"13187927": "motorcycle",
|
| 91 |
+
"14490578": "seal",
|
| 92 |
+
"14640483": "cutting/chopping board",
|
| 93 |
+
"16010041": "chopsticks",
|
| 94 |
+
"17072759": "belt",
|
| 95 |
+
"17072764": "pear",
|
| 96 |
+
"18301585": "bench",
|
| 97 |
+
"18680641": "carpet",
|
| 98 |
+
"25273528": "hot air balloon",
|
| 99 |
+
"25419509": "fork",
|
| 100 |
+
"25612310": "basket",
|
| 101 |
+
"17265253": "rickshaw"
|
| 102 |
+
}
|
evaluation/DLC-Bench/annotations/qa.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
evaluation/DLC-Bench/eval_gpt_with_image.py
ADDED
|
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# *************************************************************************
|
| 2 |
+
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
|
| 3 |
+
# difications”). All Bytedance Inc.'s Modifications are Copyright (2025) B-
|
| 4 |
+
# ytedance Inc..
|
| 5 |
+
# *************************************************************************
|
| 6 |
+
|
| 7 |
+
# Adapted from https://github.com/NVlabs/describe-anything/blob/main/evaluation/eval_model_outputs.py
|
| 8 |
+
|
| 9 |
+
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 10 |
+
#
|
| 11 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 12 |
+
# you may not use this file except in compliance with the License.
|
| 13 |
+
# You may obtain a copy of the License at
|
| 14 |
+
#
|
| 15 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 16 |
+
#
|
| 17 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 18 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 19 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 20 |
+
# See the License for the specific language governing permissions and
|
| 21 |
+
# limitations under the License.
|
| 22 |
+
#
|
| 23 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import base64
|
| 27 |
+
import io
|
| 28 |
+
import json
|
| 29 |
+
import os
|
| 30 |
+
|
| 31 |
+
import inflect
|
| 32 |
+
import numpy as np
|
| 33 |
+
import openai
|
| 34 |
+
from PIL import Image
|
| 35 |
+
from pycocotools.coco import COCO
|
| 36 |
+
from tqdm import tqdm
|
| 37 |
+
|
| 38 |
+
# Define Azure OpenAI details
|
| 39 |
+
model_name = "gpt-4o-2024-11-20"
|
| 40 |
+
max_tokens = 1000 # range: [1, 4095]
|
| 41 |
+
|
| 42 |
+
# Initialize the Azure client
|
| 43 |
+
client = openai.AzureOpenAI(
|
| 44 |
+
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
| 45 |
+
api_key=os.getenv("AZURE_OPENAI_KEY"),
|
| 46 |
+
api_version="2024-03-01-preview",
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
prompt_eval = """Answer the multiple-choice question based on the text description of an object in this image. You need to follow these rules:
|
| 50 |
+
1. Do not output any reasoning. Do not perform correction. Please output exactly one answer from the choices for each question. Do not repeat the question.
|
| 51 |
+
2. There is no need for exact matching. Please choose the closest option based on the description.
|
| 52 |
+
|
| 53 |
+
The description is:
|
| 54 |
+
{pred_caption}
|
| 55 |
+
|
| 56 |
+
From the description above, please answer the following question with one of the choices:
|
| 57 |
+
{question_text_str}
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
api_call_count = 0
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def query(prompt, images, temperature, max_tokens):
|
| 64 |
+
global api_call_count
|
| 65 |
+
if api_call_count >= args.api_call_limit:
|
| 66 |
+
raise Exception("API call limit reached")
|
| 67 |
+
|
| 68 |
+
api_call_count += 1
|
| 69 |
+
content = [
|
| 70 |
+
{"type": "text", "text": "The image:\n"},
|
| 71 |
+
{
|
| 72 |
+
"type": "image_url",
|
| 73 |
+
"image_url": {"url": f"data:image/jpeg;base64,{images[0]}"},
|
| 74 |
+
},
|
| 75 |
+
{"type": "text", "text": "\nThe mask of the image:\n"},
|
| 76 |
+
{
|
| 77 |
+
"type": "image_url",
|
| 78 |
+
"image_url": {"url": f"data:image/jpeg;base64,{images[1]}"},
|
| 79 |
+
},
|
| 80 |
+
{"type": "text", "text": f"\n{prompt}\n"},
|
| 81 |
+
]
|
| 82 |
+
|
| 83 |
+
# Adjusted to use the Azure OpenAI client with the specified parameters
|
| 84 |
+
response = client.chat.completions.create(
|
| 85 |
+
model=model_name,
|
| 86 |
+
messages=[{"role": "user", "content": content}],
|
| 87 |
+
max_tokens=max_tokens,
|
| 88 |
+
temperature=temperature,
|
| 89 |
+
top_p=1,
|
| 90 |
+
frequency_penalty=0,
|
| 91 |
+
presence_penalty=0,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
message = response.choices[0].message.content
|
| 95 |
+
return message
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def parse_pred(pred, choices, key):
|
| 99 |
+
pred = pred.strip().lower()
|
| 100 |
+
substr_indices = []
|
| 101 |
+
for index, choice in enumerate(choices):
|
| 102 |
+
choice = choice.strip().lower()
|
| 103 |
+
prefix = "abcde"[index]
|
| 104 |
+
if choice == pred or pred == f"{prefix}. {choice}" or pred == prefix:
|
| 105 |
+
return index
|
| 106 |
+
if choice in pred:
|
| 107 |
+
substr_indices.append((index, pred.index(choice), len(choice)))
|
| 108 |
+
|
| 109 |
+
if len(substr_indices) == 1:
|
| 110 |
+
return substr_indices[0][0]
|
| 111 |
+
|
| 112 |
+
choices_label = "abcde"
|
| 113 |
+
if pred[0] in choices_label and pred[1] == ".":
|
| 114 |
+
ret = choices_label.index(pred[0])
|
| 115 |
+
return ret
|
| 116 |
+
|
| 117 |
+
if substr_indices:
|
| 118 |
+
if len(substr_indices) > 1:
|
| 119 |
+
ret, ret_pos, _ = max(substr_indices, key=lambda x: x[1])
|
| 120 |
+
max_items = [item for item in substr_indices if item[1] == ret_pos]
|
| 121 |
+
if len(max_items) > 1:
|
| 122 |
+
ret = max(max_items, key=lambda x: x[2])[0]
|
| 123 |
+
return ret
|
| 124 |
+
else:
|
| 125 |
+
ret = substr_indices[0][0]
|
| 126 |
+
return ret
|
| 127 |
+
|
| 128 |
+
match_lengths = []
|
| 129 |
+
for index, choice in enumerate(choices):
|
| 130 |
+
choice = choice.strip().lower()
|
| 131 |
+
if pred in choice:
|
| 132 |
+
match_lengths.append((index, len(choice)))
|
| 133 |
+
if match_lengths:
|
| 134 |
+
if len(match_lengths) > 1:
|
| 135 |
+
ret = max(match_lengths, key=lambda x: x[1])[0]
|
| 136 |
+
else:
|
| 137 |
+
ret = match_lengths[0][0]
|
| 138 |
+
return ret
|
| 139 |
+
|
| 140 |
+
if pred[0] in "abcde" and (len(pred.strip()) == 1 or pred[1] == "\n"):
|
| 141 |
+
ret = "abcde".index(pred[0])
|
| 142 |
+
return ret
|
| 143 |
+
|
| 144 |
+
return None
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def evaluate(
|
| 148 |
+
question_dicts,
|
| 149 |
+
pred_caption,
|
| 150 |
+
temperature,
|
| 151 |
+
max_tokens,
|
| 152 |
+
images,
|
| 153 |
+
*,
|
| 154 |
+
response_override=None,
|
| 155 |
+
key,
|
| 156 |
+
verbose=False,
|
| 157 |
+
) -> dict:
|
| 158 |
+
pred_answers = []
|
| 159 |
+
prompt = []
|
| 160 |
+
response = []
|
| 161 |
+
for index, question_dict in enumerate(question_dicts):
|
| 162 |
+
question_text_str = f"{question_dict['question']}\n"
|
| 163 |
+
choices_text = ""
|
| 164 |
+
for choice_index, (choice, score) in enumerate(question_dict["choices"]):
|
| 165 |
+
choice_index = "ABCDE"[choice_index]
|
| 166 |
+
choices_text += f"{choice_index}. {choice}\n"
|
| 167 |
+
question_text_str += choices_text
|
| 168 |
+
prompt_item = prompt_eval.format(
|
| 169 |
+
pred_caption=pred_caption, question_text_str=question_text_str.strip()
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
if (
|
| 173 |
+
response_override is None
|
| 174 |
+
or len(response_override) < index
|
| 175 |
+
or response_override[index] is None
|
| 176 |
+
):
|
| 177 |
+
response_item = query(prompt_item, images, temperature, max_tokens)
|
| 178 |
+
else:
|
| 179 |
+
response_item = response_override[index]
|
| 180 |
+
|
| 181 |
+
pred_answer = response_item.strip()
|
| 182 |
+
pred_answers.append(pred_answer)
|
| 183 |
+
prompt.append(prompt_item)
|
| 184 |
+
response.append(response_item)
|
| 185 |
+
|
| 186 |
+
pred_indices = [
|
| 187 |
+
parse_pred(
|
| 188 |
+
pred_answer, [choice for choice, score in question_dict["choices"]], key
|
| 189 |
+
)
|
| 190 |
+
for pred_answer, question_dict in zip(pred_answers, question_dicts)
|
| 191 |
+
]
|
| 192 |
+
parsed_eval_results = [
|
| 193 |
+
question_dict["choices"][pred_index][1] if pred_index is not None else 0
|
| 194 |
+
for pred_index, question_dict in zip(pred_indices, question_dicts)
|
| 195 |
+
]
|
| 196 |
+
|
| 197 |
+
parsed_eval_results_positives = []
|
| 198 |
+
parsed_eval_results_negatives = []
|
| 199 |
+
details_positives = []
|
| 200 |
+
details_negatives = []
|
| 201 |
+
details_recognition = []
|
| 202 |
+
recognition_result = None
|
| 203 |
+
for question_index, (parsed_eval_result, question_dict) in enumerate(
|
| 204 |
+
zip(parsed_eval_results, question_dicts)
|
| 205 |
+
):
|
| 206 |
+
if question_dict["type"] == "recognition":
|
| 207 |
+
if parsed_eval_result == "correct":
|
| 208 |
+
recognition_result = True
|
| 209 |
+
elif parsed_eval_result == "incorrect":
|
| 210 |
+
recognition_result = False
|
| 211 |
+
print(
|
| 212 |
+
f"Recognition is incorrect for key {key}, setting score to at most 0 for all questions"
|
| 213 |
+
)
|
| 214 |
+
else:
|
| 215 |
+
raise ValueError(f"Invalid recognition result: {parsed_eval_result}")
|
| 216 |
+
details_recognition.append(
|
| 217 |
+
{
|
| 218 |
+
**question_dict,
|
| 219 |
+
"pred_answer": pred_answers[question_index],
|
| 220 |
+
"pred_index": pred_indices[question_index],
|
| 221 |
+
"eval_result": parsed_eval_result,
|
| 222 |
+
}
|
| 223 |
+
)
|
| 224 |
+
elif question_dict["type"] == "negative":
|
| 225 |
+
if recognition_result is False:
|
| 226 |
+
parsed_eval_result = min(0, parsed_eval_result)
|
| 227 |
+
parsed_eval_results_negatives.append(parsed_eval_result)
|
| 228 |
+
|
| 229 |
+
details_negatives.append(
|
| 230 |
+
{
|
| 231 |
+
**question_dict,
|
| 232 |
+
"pred_answer": pred_answers[question_index],
|
| 233 |
+
"pred_index": pred_indices[question_index],
|
| 234 |
+
"eval_result": parsed_eval_result,
|
| 235 |
+
}
|
| 236 |
+
)
|
| 237 |
+
elif question_dict["type"] == "positive":
|
| 238 |
+
if recognition_result is False:
|
| 239 |
+
parsed_eval_result = min(0, parsed_eval_result)
|
| 240 |
+
parsed_eval_results_positives.append(parsed_eval_result)
|
| 241 |
+
|
| 242 |
+
details_positives.append(
|
| 243 |
+
{
|
| 244 |
+
**question_dict,
|
| 245 |
+
"pred_answer": pred_answers[question_index],
|
| 246 |
+
"pred_index": pred_indices[question_index],
|
| 247 |
+
"eval_result": parsed_eval_result,
|
| 248 |
+
}
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
score_pos = sum(parsed_eval_results_positives) / len(parsed_eval_results_positives)
|
| 252 |
+
score_neg = (
|
| 253 |
+
sum(parsed_eval_results_negatives) / len(parsed_eval_results_negatives)
|
| 254 |
+
if parsed_eval_results_negatives
|
| 255 |
+
else None
|
| 256 |
+
)
|
| 257 |
+
score = (
|
| 258 |
+
sum(parsed_eval_results_positives) + sum(parsed_eval_results_negatives)
|
| 259 |
+
) / (len(parsed_eval_results_positives) + len(parsed_eval_results_negatives))
|
| 260 |
+
|
| 261 |
+
info = dict(
|
| 262 |
+
details_positives=details_positives,
|
| 263 |
+
details_negatives=details_negatives,
|
| 264 |
+
details_recognition=details_recognition,
|
| 265 |
+
prompt=prompt,
|
| 266 |
+
response=response,
|
| 267 |
+
score=score,
|
| 268 |
+
score_pos=score_pos,
|
| 269 |
+
score_neg=score_neg,
|
| 270 |
+
recognition_result=recognition_result,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
return info
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def is_plural(string):
|
| 277 |
+
if string == "bus":
|
| 278 |
+
return False
|
| 279 |
+
return p.singular_noun(string) is not False
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def select_ann(img_id, area_min=None, area_max=None):
|
| 283 |
+
cat_ids = coco.getCatIds()
|
| 284 |
+
ann_ids = coco.getAnnIds(imgIds=[img_id], catIds=cat_ids, iscrowd=None)
|
| 285 |
+
|
| 286 |
+
if area_min is not None:
|
| 287 |
+
ann_ids = [
|
| 288 |
+
ann_id for ann_id in ann_ids if coco.anns[ann_id]["area"] >= area_min
|
| 289 |
+
]
|
| 290 |
+
|
| 291 |
+
if area_max is not None:
|
| 292 |
+
ann_ids = [
|
| 293 |
+
ann_id for ann_id in ann_ids if coco.anns[ann_id]["area"] <= area_max
|
| 294 |
+
]
|
| 295 |
+
|
| 296 |
+
return ann_ids
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def mask_to_box(mask_np):
|
| 300 |
+
mask_coords = np.argwhere(mask_np)
|
| 301 |
+
y0, x0 = mask_coords.min(axis=0)
|
| 302 |
+
y1, x1 = mask_coords.max(axis=0) + 1
|
| 303 |
+
|
| 304 |
+
h = y1 - y0
|
| 305 |
+
w = x1 - x0
|
| 306 |
+
|
| 307 |
+
return x0, y0, w, h
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def encode_pil_image_to_base64(pil_image):
|
| 311 |
+
buffered = io.BytesIO()
|
| 312 |
+
pil_image.save(buffered, format="PNG")
|
| 313 |
+
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 314 |
+
return img_str
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
if __name__ == "__main__":
|
| 318 |
+
parser = argparse.ArgumentParser(description="Evaluate model outputs")
|
| 319 |
+
parser.add_argument(
|
| 320 |
+
"--pred", type=str, help="Path to the prediction JSON file", required=True
|
| 321 |
+
)
|
| 322 |
+
parser.add_argument(
|
| 323 |
+
"--qa",
|
| 324 |
+
type=str,
|
| 325 |
+
help="Path to the reference QA file",
|
| 326 |
+
default="evaluation/DLC-Bench/annotations/qa.json",
|
| 327 |
+
)
|
| 328 |
+
parser.add_argument(
|
| 329 |
+
"--class-names",
|
| 330 |
+
type=str,
|
| 331 |
+
help="Path to the class names JSON file",
|
| 332 |
+
default="evaluation/DLC-Bench/annotations/class_names.json",
|
| 333 |
+
)
|
| 334 |
+
parser.add_argument(
|
| 335 |
+
"--api-call-limit", type=int, default=1000, help="API call limit"
|
| 336 |
+
)
|
| 337 |
+
parser.add_argument(
|
| 338 |
+
"--suffix", type=str, default="", help="Suffix for the evaluation file"
|
| 339 |
+
)
|
| 340 |
+
parser.add_argument("--verbose", action="store_true", help="Enable verbose mode")
|
| 341 |
+
parser.add_argument(
|
| 342 |
+
"--quiet", action="store_true", help="Enable quiet mode (result only)"
|
| 343 |
+
)
|
| 344 |
+
parser.add_argument("--csv", action="store_true", help="Output results as CSV only")
|
| 345 |
+
parser.add_argument(
|
| 346 |
+
"--data-root", type=str, default="evaluation/DLC-Bench/annotations"
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
args = parser.parse_args()
|
| 350 |
+
|
| 351 |
+
eval_file = os.path.splitext(args.pred)[0] + f"_eval_gpt{args.suffix}.json"
|
| 352 |
+
|
| 353 |
+
eval_results = {}
|
| 354 |
+
|
| 355 |
+
if os.path.exists(eval_file):
|
| 356 |
+
with open(eval_file) as f:
|
| 357 |
+
eval_results = json.load(f)
|
| 358 |
+
|
| 359 |
+
with open(args.pred) as f:
|
| 360 |
+
data_pred = json.load(f)
|
| 361 |
+
|
| 362 |
+
with open(args.qa) as f:
|
| 363 |
+
data_qa = json.load(f)
|
| 364 |
+
|
| 365 |
+
with open(args.class_names) as f:
|
| 366 |
+
data_class_names = json.load(f)
|
| 367 |
+
|
| 368 |
+
scores = {}
|
| 369 |
+
scores_pos = {}
|
| 370 |
+
scores_neg = {}
|
| 371 |
+
|
| 372 |
+
keys = list(data_qa.keys())
|
| 373 |
+
p = inflect.engine()
|
| 374 |
+
|
| 375 |
+
annotations_file = os.path.join(args.data_root, "annotations.json")
|
| 376 |
+
coco = COCO(annotations_file)
|
| 377 |
+
|
| 378 |
+
with open(annotations_file, "r") as f:
|
| 379 |
+
data = json.load(f)
|
| 380 |
+
|
| 381 |
+
missing_key_count = 0
|
| 382 |
+
for key in tqdm(keys, disable=args.quiet):
|
| 383 |
+
key = str(key)
|
| 384 |
+
for item in data["annotations"]:
|
| 385 |
+
if int(item["id"]) == int(key):
|
| 386 |
+
img_id = item["image_id"]
|
| 387 |
+
|
| 388 |
+
img_info = coco.loadImgs(img_id)[0]
|
| 389 |
+
img_path = os.path.join(args.data_root, "images", img_info["file_name"])
|
| 390 |
+
img = Image.open(img_path)
|
| 391 |
+
|
| 392 |
+
anns = coco.loadAnns([int(key)])
|
| 393 |
+
mask_np = coco.annToMask(anns[0]).astype(bool)
|
| 394 |
+
|
| 395 |
+
img_np = np.array(img)
|
| 396 |
+
pil_mask = Image.fromarray((mask_np * 255).astype(np.uint8))
|
| 397 |
+
|
| 398 |
+
assert (
|
| 399 |
+
img_np.shape[:2] == mask_np.shape
|
| 400 |
+
), f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}"
|
| 401 |
+
img_h, img_w = img_np.shape[:2]
|
| 402 |
+
|
| 403 |
+
x0, y0, w, h = mask_to_box(mask_np)
|
| 404 |
+
xc, yc = x0 + w / 2, y0 + h / 2
|
| 405 |
+
|
| 406 |
+
# focal_crop: need to have at least min_box_w and min_box_h pixels, otherwise resizing to (384, 384) leads to artifacts that may be OOD
|
| 407 |
+
w, h = max(w, 56), max(h, 56)
|
| 408 |
+
x0, y0 = int(xc - w / 2), int(yc - h / 2)
|
| 409 |
+
|
| 410 |
+
# focal crop
|
| 411 |
+
cropped_img_np = img_np[
|
| 412 |
+
max(y0 - h, 0) : min(y0 + 2 * h, img_h),
|
| 413 |
+
max(x0 - w, 0) : min(x0 + 2 * w, img_w),
|
| 414 |
+
]
|
| 415 |
+
cropped_mask_np = mask_np[
|
| 416 |
+
max(y0 - h, 0) : min(y0 + 2 * h, img_h),
|
| 417 |
+
max(x0 - w, 0) : min(x0 + 2 * w, img_w),
|
| 418 |
+
]
|
| 419 |
+
|
| 420 |
+
cropped_pil_img = Image.fromarray(cropped_img_np)
|
| 421 |
+
cropped_pil_mask = Image.fromarray((cropped_mask_np * 255).astype(np.uint8))
|
| 422 |
+
|
| 423 |
+
base64_image = encode_pil_image_to_base64(img)
|
| 424 |
+
base64_mask = encode_pil_image_to_base64(pil_mask)
|
| 425 |
+
base64_cropped_image = encode_pil_image_to_base64(cropped_pil_img)
|
| 426 |
+
base64_cropped_mask = encode_pil_image_to_base64(cropped_pil_mask)
|
| 427 |
+
images = [base64_cropped_image, base64_cropped_mask]
|
| 428 |
+
|
| 429 |
+
if key in eval_results:
|
| 430 |
+
response_override = eval_results[key]["response"]
|
| 431 |
+
else:
|
| 432 |
+
response_override = None
|
| 433 |
+
|
| 434 |
+
if key not in data_pred:
|
| 435 |
+
if args.default_prediction is None:
|
| 436 |
+
raise ValueError(f"Key {key} not found in prediction data")
|
| 437 |
+
else:
|
| 438 |
+
pred_value = args.default_prediction
|
| 439 |
+
missing_key_count += 1
|
| 440 |
+
else:
|
| 441 |
+
pred_value = data_pred[key]
|
| 442 |
+
|
| 443 |
+
class_name = data_class_names[key]
|
| 444 |
+
recognition_question = f"The object in the image is {class_name}. Based on the image, is it likely that the object in the description is given class: {class_name} or object of a similar type?"
|
| 445 |
+
recognition_question_dict = {
|
| 446 |
+
"question": recognition_question,
|
| 447 |
+
"choices": [("Yes", "correct"), ("No", "incorrect")],
|
| 448 |
+
"type": "recognition",
|
| 449 |
+
}
|
| 450 |
+
|
| 451 |
+
question_dicts = [recognition_question_dict, *data_qa[key]]
|
| 452 |
+
info = evaluate(
|
| 453 |
+
question_dicts=question_dicts,
|
| 454 |
+
pred_caption=pred_value,
|
| 455 |
+
images=images,
|
| 456 |
+
temperature=0.0,
|
| 457 |
+
max_tokens=300,
|
| 458 |
+
response_override=response_override,
|
| 459 |
+
key=key,
|
| 460 |
+
)
|
| 461 |
+
score = info["score"]
|
| 462 |
+
scores[key] = score
|
| 463 |
+
scores_pos[key] = info["score_pos"]
|
| 464 |
+
scores_neg[key] = info["score_neg"]
|
| 465 |
+
eval_results[key] = {"pred": pred_value, **info}
|
| 466 |
+
|
| 467 |
+
avg_score_pos = sum(scores_pos.values()) / len(scores_pos)
|
| 468 |
+
avg_score_neg = sum(
|
| 469 |
+
[item for item in scores_neg.values() if item is not None]
|
| 470 |
+
) / len(scores_neg)
|
| 471 |
+
eval_results["avg_pos"] = avg_score_pos
|
| 472 |
+
eval_results["avg_neg"] = avg_score_neg
|
| 473 |
+
|
| 474 |
+
with open(eval_file, "w") as f:
|
| 475 |
+
json.dump(eval_results, f, indent=4)
|
| 476 |
+
|
| 477 |
+
print(f"Average Positive Score: {avg_score_pos:.3f}")
|
| 478 |
+
print(f"Average Negative Score: {avg_score_neg:.3f}")
|
| 479 |
+
print(
|
| 480 |
+
f"Summary (Pos\tNeg\tAvg(Pos, Neg)):\t{avg_score_pos:.3f},\t{avg_score_neg:.3f},\t{(avg_score_pos + avg_score_neg) / 2:.3f}"
|
| 481 |
+
)
|
| 482 |
+
print(f"QA Scores: {scores}")
|
| 483 |
+
print(f"Evaluation data saved to {eval_file}")
|
evaluation/DLC-Bench/eval_llama_without_image.py
ADDED
|
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
import inflect
|
| 22 |
+
from openai import OpenAI
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
|
| 25 |
+
prompt_eval = """Answer the multiple-choice question based on the text description of an object in an image. You need to follow these rules:
|
| 26 |
+
1. Do not output any reasoning. Do not perform correction. Please output exactly one answer from the choices for each question. Do not repeat the question.
|
| 27 |
+
2. There is no need for exact matching. Please choose the closest option based on the description.
|
| 28 |
+
|
| 29 |
+
The description is:
|
| 30 |
+
{pred_caption}
|
| 31 |
+
|
| 32 |
+
From the description above, please answer the following question with one of the choices:
|
| 33 |
+
{question_text_str}
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
api_call_count = 0
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def query(prompt, temperature, max_tokens, model):
|
| 40 |
+
global api_call_count
|
| 41 |
+
if api_call_count >= args.api_call_limit:
|
| 42 |
+
raise Exception("API call limit reached")
|
| 43 |
+
|
| 44 |
+
api_call_count += 1
|
| 45 |
+
response = client.chat.completions.create(
|
| 46 |
+
model=model,
|
| 47 |
+
messages=[{"role": "user", "content": prompt}],
|
| 48 |
+
temperature=temperature,
|
| 49 |
+
max_tokens=max_tokens,
|
| 50 |
+
top_p=1,
|
| 51 |
+
frequency_penalty=0,
|
| 52 |
+
presence_penalty=0,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
message = response.choices[0].message.content
|
| 56 |
+
return message
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def parse_pred(pred, choices, key):
|
| 60 |
+
pred = pred.strip().lower()
|
| 61 |
+
substr_indices = []
|
| 62 |
+
for index, choice in enumerate(choices):
|
| 63 |
+
choice = choice.strip().lower()
|
| 64 |
+
prefix = "abcde"[index]
|
| 65 |
+
if choice == pred or pred == f"{prefix}. {choice}" or pred == prefix:
|
| 66 |
+
return index
|
| 67 |
+
if choice in pred:
|
| 68 |
+
substr_indices.append((index, pred.index(choice), len(choice)))
|
| 69 |
+
|
| 70 |
+
# Only one match (choice in prediction)
|
| 71 |
+
if len(substr_indices) == 1:
|
| 72 |
+
return substr_indices[0][0]
|
| 73 |
+
|
| 74 |
+
# Prefix match
|
| 75 |
+
choices_label = "abcde"
|
| 76 |
+
if pred[0] in choices_label and pred[1] == ".":
|
| 77 |
+
ret = choices_label.index(pred[0])
|
| 78 |
+
# print(f"{key}: Chosen {ret} for pred: {pred}, choices: {choices}")
|
| 79 |
+
# print(f"{key}: More than one occurrence found or no substr of choice in pred: pred {pred}, choices {choices}, substr indices: {substr_indices}, returning {ret} (choice {choices_label})")
|
| 80 |
+
return ret
|
| 81 |
+
|
| 82 |
+
# More than one match
|
| 83 |
+
if substr_indices:
|
| 84 |
+
# Return the last occurrence if there are multiple matches (referenced from MMMU): https://github.com/MMMU-Benchmark/MMMU/blob/b119c944a15c145c10d52a58e841c5b9cb6a535e/eval/utils/eval_utils.py#L57
|
| 85 |
+
if len(substr_indices) > 1:
|
| 86 |
+
ret, ret_pos, _ = max(substr_indices, key=lambda x: x[1])
|
| 87 |
+
max_items = [item for item in substr_indices if item[1] == ret_pos]
|
| 88 |
+
if len(max_items) > 1:
|
| 89 |
+
# select the item with the longest match if there are multiple occurrence at the same place
|
| 90 |
+
ret = max(max_items, key=lambda x: x[2])[0]
|
| 91 |
+
print(
|
| 92 |
+
f"{key}: More than one occurrence found: pred {pred}, choices {choices}, {substr_indices}, returning {ret} (choice {choices_label})"
|
| 93 |
+
)
|
| 94 |
+
else:
|
| 95 |
+
ret = substr_indices[0][0]
|
| 96 |
+
return ret
|
| 97 |
+
|
| 98 |
+
# Parse the case where pred is a substr of choice
|
| 99 |
+
match_lengths = []
|
| 100 |
+
for index, choice in enumerate(choices):
|
| 101 |
+
choice = choice.strip().lower()
|
| 102 |
+
if pred in choice:
|
| 103 |
+
match_lengths.append((index, len(choice)))
|
| 104 |
+
if match_lengths:
|
| 105 |
+
# Return the longest matched substring if there are multiple matches
|
| 106 |
+
if len(match_lengths) > 1:
|
| 107 |
+
ret = max(match_lengths, key=lambda x: x[1])[0]
|
| 108 |
+
print(
|
| 109 |
+
f"{key}: More than one occurrence found: pred {pred}, choices {choices}, {match_lengths}, returning {ret}"
|
| 110 |
+
)
|
| 111 |
+
else:
|
| 112 |
+
ret = match_lengths[0][0]
|
| 113 |
+
return ret
|
| 114 |
+
|
| 115 |
+
if pred[0] in "abcde" and (len(pred.strip()) == 1 or pred[1] == "\n"):
|
| 116 |
+
ret = "abcde".index(pred[0])
|
| 117 |
+
print(f"{key}: Chosen {ret} for pred: {pred}, choices: {choices}")
|
| 118 |
+
return ret
|
| 119 |
+
|
| 120 |
+
print(f"*WARNING*: {key}: No match found. Pred: {pred}, choices: {choices}")
|
| 121 |
+
|
| 122 |
+
# If no matching choice is found, raise an error.
|
| 123 |
+
# raise ValueError(f"No match found. Pred: {pred}, Choices: {choices}")
|
| 124 |
+
# If no matching choice is found, return None (treat as no mention, score 0).
|
| 125 |
+
return None
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def evaluate(
|
| 129 |
+
question_dicts,
|
| 130 |
+
pred_caption,
|
| 131 |
+
temperature,
|
| 132 |
+
max_tokens,
|
| 133 |
+
model,
|
| 134 |
+
*,
|
| 135 |
+
response_override=None,
|
| 136 |
+
key,
|
| 137 |
+
verbose=False,
|
| 138 |
+
) -> dict:
|
| 139 |
+
pred_answers = []
|
| 140 |
+
prompt = []
|
| 141 |
+
response = []
|
| 142 |
+
for index, question_dict in enumerate(question_dicts):
|
| 143 |
+
question_text_str = f"{question_dict['question']}\n"
|
| 144 |
+
choices_text = ""
|
| 145 |
+
for choice_index, (choice, score) in enumerate(question_dict["choices"]):
|
| 146 |
+
choice_index = "ABCDE"[choice_index]
|
| 147 |
+
choices_text += f"{choice_index}. {choice}\n"
|
| 148 |
+
question_text_str += choices_text
|
| 149 |
+
prompt_item = prompt_eval.format(
|
| 150 |
+
pred_caption=pred_caption, question_text_str=question_text_str.strip()
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
if (
|
| 154 |
+
response_override is None
|
| 155 |
+
or len(response_override) < index
|
| 156 |
+
or response_override[index] is None
|
| 157 |
+
):
|
| 158 |
+
response_item = query(prompt_item, temperature, max_tokens, model)
|
| 159 |
+
# print(f"Prompt:\n{prompt_item}")
|
| 160 |
+
# print(f"Output: {response_item}")
|
| 161 |
+
else:
|
| 162 |
+
response_item = response_override[index]
|
| 163 |
+
|
| 164 |
+
pred_answer = response_item.strip()
|
| 165 |
+
pred_answers.append(pred_answer)
|
| 166 |
+
prompt.append(prompt_item)
|
| 167 |
+
response.append(response_item)
|
| 168 |
+
|
| 169 |
+
assert len(pred_answers) == len(
|
| 170 |
+
question_dicts
|
| 171 |
+
), f"Length mismatch for key {key} question {index}: pred: {len(pred_answers)} vs question: {len(question_dicts)}"
|
| 172 |
+
pred_indices = [
|
| 173 |
+
parse_pred(
|
| 174 |
+
pred_answer, [choice for choice, score in question_dict["choices"]], key
|
| 175 |
+
)
|
| 176 |
+
for pred_answer, question_dict in zip(pred_answers, question_dicts)
|
| 177 |
+
]
|
| 178 |
+
|
| 179 |
+
assert len(pred_indices) == len(
|
| 180 |
+
question_dicts
|
| 181 |
+
), f"Length mismatch for key {key} question {index}: pred: {len(pred_indices)} vs question: {len(question_dicts)}"
|
| 182 |
+
|
| 183 |
+
# If no matching, treat as no mention.
|
| 184 |
+
try:
|
| 185 |
+
parsed_eval_results = [
|
| 186 |
+
question_dict["choices"][pred_index][1] if pred_index is not None else 0
|
| 187 |
+
for pred_index, question_dict in zip(pred_indices, question_dicts)
|
| 188 |
+
]
|
| 189 |
+
except IndexError as e:
|
| 190 |
+
print(
|
| 191 |
+
f"Error: {e}, key: {key}, pred_indices: {pred_indices}, question_dicts: {question_dicts}"
|
| 192 |
+
)
|
| 193 |
+
raise e
|
| 194 |
+
|
| 195 |
+
parsed_eval_results_positives = []
|
| 196 |
+
parsed_eval_results_negatives = []
|
| 197 |
+
|
| 198 |
+
details_positives = []
|
| 199 |
+
details_negatives = []
|
| 200 |
+
details_recognition = []
|
| 201 |
+
recognition_result = None
|
| 202 |
+
for question_index, (parsed_eval_result, question_dict) in enumerate(
|
| 203 |
+
zip(parsed_eval_results, question_dicts)
|
| 204 |
+
):
|
| 205 |
+
if question_dict["type"] == "recognition":
|
| 206 |
+
# If the type is recognition, it's the recognition question.
|
| 207 |
+
if parsed_eval_result == "correct":
|
| 208 |
+
recognition_result = True
|
| 209 |
+
elif parsed_eval_result == "incorrect":
|
| 210 |
+
recognition_result = False
|
| 211 |
+
print(
|
| 212 |
+
f"Recognition is incorrect for key {key}, setting score to at most 0 for all questions"
|
| 213 |
+
)
|
| 214 |
+
else:
|
| 215 |
+
raise ValueError(f"Invalid recognition result: {parsed_eval_result}")
|
| 216 |
+
details_recognition.append(
|
| 217 |
+
{
|
| 218 |
+
**question_dict,
|
| 219 |
+
"pred_answer": pred_answers[question_index],
|
| 220 |
+
"pred_index": pred_indices[question_index],
|
| 221 |
+
"eval_result": parsed_eval_result,
|
| 222 |
+
}
|
| 223 |
+
)
|
| 224 |
+
elif question_dict["type"] == "negative":
|
| 225 |
+
assert (
|
| 226 |
+
recognition_result is not None
|
| 227 |
+
), f"Negative questions come before recognition question in {key}, {question_dicts}"
|
| 228 |
+
if recognition_result is False:
|
| 229 |
+
if verbose:
|
| 230 |
+
print(
|
| 231 |
+
f"Processing negative question {question_index} for key {key}, setting score to at most 0 since recognition is incorrect"
|
| 232 |
+
)
|
| 233 |
+
parsed_eval_result = min(0, parsed_eval_result)
|
| 234 |
+
# If the type is negative, it's one of the negatives.
|
| 235 |
+
parsed_eval_results_negatives.append(parsed_eval_result)
|
| 236 |
+
details_negatives.append(
|
| 237 |
+
{
|
| 238 |
+
**question_dict,
|
| 239 |
+
"pred_answer": pred_answers[question_index],
|
| 240 |
+
"pred_index": pred_indices[question_index],
|
| 241 |
+
# Subtract 1 to get the index in the original question list (excluding the recognition question)
|
| 242 |
+
"question_index": question_index - 1,
|
| 243 |
+
"eval_result": parsed_eval_result,
|
| 244 |
+
}
|
| 245 |
+
)
|
| 246 |
+
elif question_dict["type"] == "positive":
|
| 247 |
+
assert (
|
| 248 |
+
recognition_result is not None
|
| 249 |
+
), f"Positive questions come before recognition question in {key}, {question_dicts}"
|
| 250 |
+
if recognition_result is False:
|
| 251 |
+
if verbose:
|
| 252 |
+
print(
|
| 253 |
+
f"Processing positive question {question_index} for key {key}, setting score to at most 0 since recognition is incorrect"
|
| 254 |
+
)
|
| 255 |
+
parsed_eval_result = min(0, parsed_eval_result)
|
| 256 |
+
parsed_eval_results_positives.append(parsed_eval_result)
|
| 257 |
+
details_positives.append(
|
| 258 |
+
{
|
| 259 |
+
**question_dict,
|
| 260 |
+
"pred_answer": pred_answers[question_index],
|
| 261 |
+
"pred_index": pred_indices[question_index],
|
| 262 |
+
# Subtract 1 to get the index in the original question list (excluding the recognition question)
|
| 263 |
+
"question_index": question_index - 1,
|
| 264 |
+
"eval_result": parsed_eval_result,
|
| 265 |
+
}
|
| 266 |
+
)
|
| 267 |
+
else:
|
| 268 |
+
raise ValueError(f"Invalid question type: {question_dict['type']}")
|
| 269 |
+
|
| 270 |
+
score_pos = sum(parsed_eval_results_positives) / len(parsed_eval_results_positives)
|
| 271 |
+
# It's possible that we don't have negatives for an instance. For this case, we skip over the instance for negative score calculation.
|
| 272 |
+
if len(parsed_eval_results_negatives):
|
| 273 |
+
score_neg = sum(parsed_eval_results_negatives) / len(
|
| 274 |
+
parsed_eval_results_negatives
|
| 275 |
+
)
|
| 276 |
+
else:
|
| 277 |
+
score_neg = None
|
| 278 |
+
|
| 279 |
+
# Overall score is the average of the positive and negative scores
|
| 280 |
+
info = dict(
|
| 281 |
+
details_positives=details_positives,
|
| 282 |
+
details_negatives=details_negatives,
|
| 283 |
+
details_recognition=details_recognition,
|
| 284 |
+
prompt=prompt,
|
| 285 |
+
response=response,
|
| 286 |
+
score=(sum(parsed_eval_results_positives) + sum(parsed_eval_results_negatives))
|
| 287 |
+
/ (len(parsed_eval_results_positives) + len(parsed_eval_results_negatives)),
|
| 288 |
+
score_pos=score_pos,
|
| 289 |
+
score_neg=score_neg,
|
| 290 |
+
neg_valid_num=len(parsed_eval_results_negatives),
|
| 291 |
+
recognition_result=recognition_result,
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
return info
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def is_plural(string):
|
| 298 |
+
# A case that the inflect library does not handle
|
| 299 |
+
if string == "bus":
|
| 300 |
+
return False
|
| 301 |
+
# singular_noun returns False if the word is already singular (otherwise it returns the singular form)
|
| 302 |
+
return p.singular_noun(string) is not False
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
if __name__ == "__main__":
|
| 306 |
+
# Example:
|
| 307 |
+
# python eval_model_outputs.py --pred model_outputs_cache/dam_3b_v1.json --base-url "http://localhost:9100/v1"
|
| 308 |
+
|
| 309 |
+
parser = argparse.ArgumentParser(description="Evaluate model outputs")
|
| 310 |
+
parser.add_argument(
|
| 311 |
+
"--pred", type=str, help="Path to the prediction JSON file", required=True
|
| 312 |
+
)
|
| 313 |
+
parser.add_argument(
|
| 314 |
+
"--qa",
|
| 315 |
+
type=str,
|
| 316 |
+
help="Path to the reference QA file",
|
| 317 |
+
default="evaluation/DLC-Bench/annotations/qa.json",
|
| 318 |
+
)
|
| 319 |
+
parser.add_argument(
|
| 320 |
+
"--class-names",
|
| 321 |
+
type=str,
|
| 322 |
+
help="Path to the class names JSON file",
|
| 323 |
+
default="evaluation/DLC-Bench/annotations/class_names.json",
|
| 324 |
+
)
|
| 325 |
+
parser.add_argument(
|
| 326 |
+
"--default-prediction",
|
| 327 |
+
type=str,
|
| 328 |
+
default=None,
|
| 329 |
+
help="Default prediction if key is not present in the prediction file",
|
| 330 |
+
)
|
| 331 |
+
parser.add_argument(
|
| 332 |
+
"--api-call-limit", type=int, default=1000, help="API call limit"
|
| 333 |
+
)
|
| 334 |
+
parser.add_argument(
|
| 335 |
+
"--api-key", type=str, default=None, help="Path to the OpenAI API key file"
|
| 336 |
+
)
|
| 337 |
+
parser.add_argument(
|
| 338 |
+
"--suffix", type=str, default="", help="Suffix for the evaluation file"
|
| 339 |
+
)
|
| 340 |
+
parser.add_argument("--model", type=str, default="llama3.1-8b", help="Model name")
|
| 341 |
+
parser.add_argument("--verbose", action="store_true", help="Enable verbose mode")
|
| 342 |
+
parser.add_argument(
|
| 343 |
+
"--quiet", action="store_true", help="Enable quiet mode (result only)"
|
| 344 |
+
)
|
| 345 |
+
parser.add_argument("--csv", action="store_true", help="Output results as CSV only")
|
| 346 |
+
|
| 347 |
+
parser.add_argument(
|
| 348 |
+
"--base-url",
|
| 349 |
+
type=str,
|
| 350 |
+
default="http://localhost:8007/v1",
|
| 351 |
+
help="Base URL for the API call",
|
| 352 |
+
)
|
| 353 |
+
args = parser.parse_args()
|
| 354 |
+
|
| 355 |
+
always_print = print
|
| 356 |
+
if args.quiet:
|
| 357 |
+
print = lambda *args, **kwargs: None
|
| 358 |
+
|
| 359 |
+
# v3 is from v2.1
|
| 360 |
+
eval_file = os.path.splitext(args.pred)[0] + f"_eval{args.suffix}.json"
|
| 361 |
+
eval_results = {}
|
| 362 |
+
|
| 363 |
+
if False:
|
| 364 |
+
assert not os.path.exists(eval_file), f"Evaluation file exists at {eval_file}"
|
| 365 |
+
else:
|
| 366 |
+
if os.path.exists(eval_file):
|
| 367 |
+
print(f"Loading existing evaluation data from {eval_file}")
|
| 368 |
+
try:
|
| 369 |
+
with open(eval_file) as f:
|
| 370 |
+
eval_results = json.load(f)
|
| 371 |
+
except Exception as e:
|
| 372 |
+
always_print(f"Error loading evaluation data {eval_file}: {e}")
|
| 373 |
+
raise e
|
| 374 |
+
|
| 375 |
+
if args.api_key:
|
| 376 |
+
with open(args.api_key) as f:
|
| 377 |
+
client = OpenAI(api_key=f.read().strip(), base_url=args.base_url)
|
| 378 |
+
else:
|
| 379 |
+
client = OpenAI(api_key="sk-abc123", base_url=args.base_url)
|
| 380 |
+
|
| 381 |
+
with open(args.pred) as f:
|
| 382 |
+
data_pred = json.load(f)
|
| 383 |
+
|
| 384 |
+
with open(args.qa) as f:
|
| 385 |
+
data_qa = json.load(f)
|
| 386 |
+
|
| 387 |
+
with open(args.class_names) as f:
|
| 388 |
+
data_class_names = json.load(f)
|
| 389 |
+
|
| 390 |
+
scores = {}
|
| 391 |
+
scores_pos = {}
|
| 392 |
+
scores_neg = {}
|
| 393 |
+
|
| 394 |
+
keys = list(data_qa.keys())
|
| 395 |
+
|
| 396 |
+
p = inflect.engine()
|
| 397 |
+
|
| 398 |
+
print(f"Using model {args.model}")
|
| 399 |
+
|
| 400 |
+
missing_key_count = 0
|
| 401 |
+
for key in tqdm(keys, disable=args.quiet):
|
| 402 |
+
key = str(key)
|
| 403 |
+
if key in eval_results:
|
| 404 |
+
if args.verbose:
|
| 405 |
+
print(f"Skipping {key}")
|
| 406 |
+
response_override = eval_results[key]["response"]
|
| 407 |
+
else:
|
| 408 |
+
response_override = None
|
| 409 |
+
|
| 410 |
+
if key not in data_pred:
|
| 411 |
+
if args.default_prediction is None:
|
| 412 |
+
raise ValueError(
|
| 413 |
+
f"Key {key} not found in prediction data, and no default prediction provided"
|
| 414 |
+
)
|
| 415 |
+
else:
|
| 416 |
+
print(
|
| 417 |
+
f"Key {key} not found in prediction data, using default prediction {args.default_prediction}"
|
| 418 |
+
)
|
| 419 |
+
pred_value = args.default_prediction
|
| 420 |
+
missing_key_count += 1
|
| 421 |
+
elif data_pred[key].startswith("Error:"):
|
| 422 |
+
if args.default_prediction is None:
|
| 423 |
+
raise ValueError(
|
| 424 |
+
f"Key {key} has an error in prediction data, and no default prediction provided: {data_pred[key]}"
|
| 425 |
+
)
|
| 426 |
+
else:
|
| 427 |
+
print(
|
| 428 |
+
f"Key {key} has an error in prediction: {data_pred[key]}, using default prediction {args.default_prediction}"
|
| 429 |
+
)
|
| 430 |
+
pred_value = args.default_prediction
|
| 431 |
+
missing_key_count += 1
|
| 432 |
+
else:
|
| 433 |
+
pred_value = data_pred[key]
|
| 434 |
+
|
| 435 |
+
# print(f"Evaluating {key}")
|
| 436 |
+
class_name = data_class_names[key]
|
| 437 |
+
|
| 438 |
+
if is_plural(class_name):
|
| 439 |
+
recognition_question = f"Is it likely that the objects in the description are {class_name} or objects of a similar type? Again, It does not have to be an exact match."
|
| 440 |
+
else:
|
| 441 |
+
recognition_question = f"Is it likely that the object in the description is {p.a(class_name)} or an object of a similar type? Again, It does not have to be an exact match."
|
| 442 |
+
recognition_question_dict = {
|
| 443 |
+
"question": recognition_question,
|
| 444 |
+
"choices": [("Yes", "correct"), ("No", "incorrect")],
|
| 445 |
+
"type": "recognition",
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
# Add the recognition question to the beginning of the list
|
| 449 |
+
question_dicts = [recognition_question_dict, *data_qa[key]]
|
| 450 |
+
info = evaluate(
|
| 451 |
+
question_dicts=question_dicts,
|
| 452 |
+
pred_caption=pred_value,
|
| 453 |
+
model=args.model,
|
| 454 |
+
temperature=0.0,
|
| 455 |
+
max_tokens=300,
|
| 456 |
+
response_override=response_override,
|
| 457 |
+
key=key,
|
| 458 |
+
verbose=args.verbose,
|
| 459 |
+
)
|
| 460 |
+
score = info["score"]
|
| 461 |
+
scores[key] = score
|
| 462 |
+
scores_pos[key] = info["score_pos"]
|
| 463 |
+
scores_neg[key] = info["score_neg"]
|
| 464 |
+
eval_results[key] = {"pred": pred_value, **info}
|
| 465 |
+
|
| 466 |
+
if args.verbose:
|
| 467 |
+
print(f"Score: {score}")
|
| 468 |
+
|
| 469 |
+
with open(eval_file, "w") as f:
|
| 470 |
+
json.dump(eval_results, f, indent=4)
|
| 471 |
+
|
| 472 |
+
avg_score_pos = sum(scores_pos.values()) / len(scores_pos)
|
| 473 |
+
scores_neg_valid_only = [item for item in scores_neg.values() if item is not None]
|
| 474 |
+
avg_score_neg = sum(scores_neg_valid_only) / len(scores_neg_valid_only)
|
| 475 |
+
|
| 476 |
+
if args.csv:
|
| 477 |
+
# Print comma-separated values directly to stdout
|
| 478 |
+
always_print(
|
| 479 |
+
f"{avg_score_pos:.3f},{avg_score_neg:.3f},{(avg_score_pos + avg_score_neg) / 2:.3f}"
|
| 480 |
+
)
|
| 481 |
+
else:
|
| 482 |
+
always_print(f"Result for {args.pred}")
|
| 483 |
+
always_print(f"Average Positive Score: {avg_score_pos:.3f}")
|
| 484 |
+
always_print(f"Average Negative Score: {avg_score_neg:.3f}")
|
| 485 |
+
always_print(
|
| 486 |
+
f"Average of Positive and Negative Scores: {(avg_score_pos + avg_score_neg) / 2:.3f}"
|
| 487 |
+
)
|
| 488 |
+
always_print(
|
| 489 |
+
f"Summary (Pos\tNeg\tAvg(Pos, Neg)):\t{avg_score_pos:.3f},\t{avg_score_neg:.3f},\t{(avg_score_pos + avg_score_neg) / 2:.3f}"
|
| 490 |
+
)
|
| 491 |
+
print(f"QA Scores: {scores}")
|
| 492 |
+
|
| 493 |
+
if missing_key_count:
|
| 494 |
+
print(
|
| 495 |
+
f"Note: Missing {missing_key_count} keys, using default prediction {args.default_prediction}"
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
eval_results["avg_pos"] = avg_score_pos
|
| 499 |
+
eval_results["avg_neg"] = avg_score_neg
|
| 500 |
+
with open(eval_file, "w") as f:
|
| 501 |
+
json.dump(eval_results, f, indent=4)
|
| 502 |
+
|
| 503 |
+
print(f"Evaluation data saved to {eval_file}")
|
evaluation/DLC-Bench/inference.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License")
|
| 4 |
+
# Grasp Any Region Project
|
| 5 |
+
# Written by Haochen Wang
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from pycocotools import mask as mask_utils
|
| 16 |
+
from pycocotools.coco import COCO
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from transformers import AutoModel, AutoProcessor, GenerationConfig
|
| 19 |
+
|
| 20 |
+
from evaluation.eval_dataset import SingleRegionCaptionDataset
|
| 21 |
+
|
| 22 |
+
TORCH_DTYPE_MAP = dict(fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def parse_args():
|
| 26 |
+
parser = argparse.ArgumentParser(
|
| 27 |
+
description="Inference of Grasp Any Region models on DLC-Bench."
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"--model_name_or_path",
|
| 32 |
+
help="HF model name or path",
|
| 33 |
+
default="HaochenWang/GAR-1B",
|
| 34 |
+
)
|
| 35 |
+
parser.add_argument(
|
| 36 |
+
"--cache_name",
|
| 37 |
+
help="cache name to save model outputs.",
|
| 38 |
+
default="gar_1b",
|
| 39 |
+
)
|
| 40 |
+
parser.add_argument(
|
| 41 |
+
"--data_type",
|
| 42 |
+
help="data dtype",
|
| 43 |
+
type=str,
|
| 44 |
+
choices=["fp16", "bf16", "fp32"],
|
| 45 |
+
default="bf16",
|
| 46 |
+
)
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
"--anno_file",
|
| 49 |
+
help="path to the annotation file.",
|
| 50 |
+
default="evaluation/DLC-Bench/annotations/annotations.json",
|
| 51 |
+
)
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--image_folder",
|
| 54 |
+
help="the folder of images",
|
| 55 |
+
default="evaluation/DLC-Bench/annotations",
|
| 56 |
+
)
|
| 57 |
+
parser.add_argument(
|
| 58 |
+
"--seed",
|
| 59 |
+
type=int,
|
| 60 |
+
default=0,
|
| 61 |
+
help="Random seed for reproducible text generation",
|
| 62 |
+
)
|
| 63 |
+
args = parser.parse_args()
|
| 64 |
+
return args
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def select_ann(coco, img_id, area_min=None, area_max=None):
|
| 68 |
+
cat_ids = coco.getCatIds()
|
| 69 |
+
ann_ids = coco.getAnnIds(imgIds=[img_id], catIds=cat_ids, iscrowd=None)
|
| 70 |
+
|
| 71 |
+
if area_min is not None:
|
| 72 |
+
ann_ids = [
|
| 73 |
+
ann_id for ann_id in ann_ids if coco.anns[ann_id]["area"] >= area_min
|
| 74 |
+
]
|
| 75 |
+
|
| 76 |
+
if area_max is not None:
|
| 77 |
+
ann_ids = [
|
| 78 |
+
ann_id for ann_id in ann_ids if coco.anns[ann_id]["area"] <= area_max
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
return ann_ids
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def main():
|
| 85 |
+
args = parse_args()
|
| 86 |
+
data_dtype = TORCH_DTYPE_MAP[args.data_type]
|
| 87 |
+
torch.manual_seed(args.seed)
|
| 88 |
+
|
| 89 |
+
# init ditribution for dispatch_modules in LLM
|
| 90 |
+
torch.cuda.set_device(0)
|
| 91 |
+
torch.distributed.init_process_group(backend="nccl")
|
| 92 |
+
|
| 93 |
+
# build HF model
|
| 94 |
+
model = AutoModel.from_pretrained(
|
| 95 |
+
args.model_name_or_path,
|
| 96 |
+
trust_remote_code=True,
|
| 97 |
+
torch_dtype=data_dtype,
|
| 98 |
+
)
|
| 99 |
+
model.cuda()
|
| 100 |
+
model.eval()
|
| 101 |
+
|
| 102 |
+
processor = AutoProcessor.from_pretrained(
|
| 103 |
+
args.model_name_or_path,
|
| 104 |
+
trust_remote_code=True,
|
| 105 |
+
)
|
| 106 |
+
model_outputs = {}
|
| 107 |
+
cache_name = args.cache_name
|
| 108 |
+
|
| 109 |
+
# This coco instance is actually an o365 subset. This is for code reuse.
|
| 110 |
+
coco = COCO(args.anno_file)
|
| 111 |
+
img_ids = list(coco.imgs.keys())
|
| 112 |
+
num_anns = len(coco.anns)
|
| 113 |
+
pbar = tqdm(total=num_anns)
|
| 114 |
+
|
| 115 |
+
for img_id in img_ids:
|
| 116 |
+
ann_ids = select_ann(coco, img_id)
|
| 117 |
+
img_info = coco.loadImgs(img_id)[0]
|
| 118 |
+
|
| 119 |
+
for i, ann_id in enumerate(ann_ids):
|
| 120 |
+
if ann_id in model_outputs.keys():
|
| 121 |
+
pbar.update(1)
|
| 122 |
+
continue
|
| 123 |
+
|
| 124 |
+
anns = coco.loadAnns([ann_id])
|
| 125 |
+
mask = coco.annToMask(anns[0])
|
| 126 |
+
|
| 127 |
+
img_path = os.path.join(args.image_folder, "images", img_info["file_name"])
|
| 128 |
+
img = Image.open(img_path)
|
| 129 |
+
|
| 130 |
+
prompt_number = model.config.prompt_numbers
|
| 131 |
+
prompt_tokens = [f"<Prompt{i_p}>" for i_p in range(prompt_number)] + [
|
| 132 |
+
"<NO_Prompt>"
|
| 133 |
+
]
|
| 134 |
+
dataset = SingleRegionCaptionDataset(
|
| 135 |
+
image=img,
|
| 136 |
+
mask=mask,
|
| 137 |
+
processor=processor,
|
| 138 |
+
prompt_number=prompt_number,
|
| 139 |
+
visual_prompt_tokens=prompt_tokens,
|
| 140 |
+
data_dtype=data_dtype,
|
| 141 |
+
)
|
| 142 |
+
data_sample = dataset[0]
|
| 143 |
+
|
| 144 |
+
with torch.no_grad():
|
| 145 |
+
generate_ids = model.generate(
|
| 146 |
+
**data_sample,
|
| 147 |
+
generation_config=GenerationConfig(
|
| 148 |
+
max_new_tokens=1024,
|
| 149 |
+
do_sample=False,
|
| 150 |
+
eos_token_id=processor.tokenizer.eos_token_id,
|
| 151 |
+
pad_token_id=processor.tokenizer.pad_token_id,
|
| 152 |
+
),
|
| 153 |
+
return_dict=True,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
outputs = processor.tokenizer.decode(
|
| 157 |
+
generate_ids.sequences[0], skip_special_tokens=True
|
| 158 |
+
).strip()
|
| 159 |
+
|
| 160 |
+
print(outputs) # Print model output for this image
|
| 161 |
+
|
| 162 |
+
model_outputs[ann_id] = outputs
|
| 163 |
+
pbar.update(1)
|
| 164 |
+
pbar.close()
|
| 165 |
+
|
| 166 |
+
with open(f"evaluation/DLC-Bench/model_outputs/{cache_name}.json", "w") as file:
|
| 167 |
+
json.dump(model_outputs, file, indent=4, ensure_ascii=False)
|
| 168 |
+
|
| 169 |
+
print(f"Cache name: {cache_name}")
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
if __name__ == "__main__":
|
| 173 |
+
main()
|
evaluation/DLC-Bench/model_outputs/gar_1b.json
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"279135": "The ski features a predominantly black surface with intricate orange and white geometric patterns along its length. It is equipped with a black binding system, including a silver and black mechanism, and an orange adjustment lever. The tip of the ski has a similar pattern to the rest of the ski.",
|
| 3 |
+
"297718": "A piece of sushi with a bed of white rice, topped with a layer of black seaweed, and filled with a mixture of pink and white fish roe. The top is garnished with a sprinkle of sesame seeds.",
|
| 4 |
+
"361105": "A small cluster of fresh, vibrant green leaves with a smooth texture, attached to a slender, slightly curved stem. The leaves are elongated with pointed tips and a glossy surface, showing a few small brown spots.",
|
| 5 |
+
"622329": "A rectangular, flat, beige eraser with rounded corners and a slightly textured surface.",
|
| 6 |
+
"622332": "A black rectangular stapler with a glossy finish, featuring a silver brand logo on the top right corner. The stapler has a visible metal stapling mechanism on the right side.",
|
| 7 |
+
"1075308": "A vintage-style television set with a boxy, black plastic casing. The front features a large, square screen with a slightly curved surface. The top of the television has a series of buttons and dials, and there is a small, rectangular display area above the screen.",
|
| 8 |
+
"1196168": "A rectangular, white air conditioner unit with a large circular fan grille on the left side. The grille has a grid pattern with multiple blades visible. To the right of the grille, there is a small rectangular panel with a circular emblem and some text.",
|
| 9 |
+
"1770866": "A white price tag with handwritten text in blue and red marker. The text reads \"Libra\" in blue at the top, followed by \"Lb\" in blue, \"per\" in blue, \"lb\" in blue, and \"950\" in red.",
|
| 10 |
+
"1894089": "A metallic screwdriver with a long, slender shaft and a flat, rectangular head. The shaft is smooth and tapers slightly towards the head, which is flat and has a small, circular indentation near the tip.",
|
| 11 |
+
"2391761": "The canoe has a wooden hull with horizontal planks and a blue tarpaulin cover draped over it. The tarpaulin is secured with ropes and has some white markings on it. The canoe also features a small outboard motor mounted on the stern.",
|
| 12 |
+
"2391780": "A bird with a long, slender neck and a pointed beak. Its plumage is predominantly brown with lighter, almost white, streaks on the wings and back. The bird's legs are thin and dark, and it has a small, rounded tail.",
|
| 13 |
+
"2391781": "The bird has a predominantly dark brown body with lighter brown and white markings on its wings and back. Its wings are outstretched, showing a mix of dark and light feathers. The bird's head is slightly turned, with a visible beak and a hint of white feathers around the neck area.",
|
| 14 |
+
"2580318": "The mouse has a sleek, metallic silver body with a smooth, reflective surface. The visible part of the mouse is triangular in shape, with a slightly curved edge and a subtle gradient of light reflecting off its surface.",
|
| 15 |
+
"2580323": "A rectangular wooden frame encloses a detailed architectural blueprint with various lines, symbols, and text. The frame has a natural wood finish and is mounted on a wall.",
|
| 16 |
+
"2588513": "A rectangular wooden block with a light beige top surface and a black bottom surface. The top surface has a smooth texture with visible wood grain patterns, while the bottom surface is solid black.",
|
| 17 |
+
"3993075": "A cylindrical marker with a white body featuring a colorful design, including a blue and green pattern near the middle and a red cap.",
|
| 18 |
+
"4027486": "The bus is predominantly blue with a white section on the right side. It has a black horizontal stripe running along the middle, with a green stripe above it. The rear of the bus features a white license plate with black text. There is a small, white, triangular logo with a black design on the blue section near the rear.",
|
| 19 |
+
"4243725": "The soap is a rectangular, slightly curved bar with a smooth, creamy beige surface.",
|
| 20 |
+
"4502267": "A green bean with a smooth, slightly curved surface, tapering to a point at one end and having a broader, rounded base at the other. The bean has a consistent green color with subtle variations in shading.",
|
| 21 |
+
"4604873": "A large, industrial crane with a lattice structure, featuring a long, horizontal boom extending from a vertical mast. The boom is supported by a series of diagonal cross-bracing and has a hook at the end. The mast is equipped with various mechanical components and a counterweight at the base.",
|
| 22 |
+
"4781902": "A dark brown, wooden ladder with evenly spaced, flat rungs and side rails that taper slightly towards the top.",
|
| 23 |
+
"4782942": "A large, dark-colored, conical-shaped horn with a wide, flared opening and a narrow, cylindrical body.",
|
| 24 |
+
"4782949": "The drum has a circular shape with a red body and a black rim. The drumhead is a light brown color with a blue circular patch in the center.",
|
| 25 |
+
"4916799": "A spherical sculpture with a textured surface composed of small, raised, silver-colored elements. The sphere is adorned with blue, three-dimensional letters spelling \"Reve\" and is mounted on a black, cylindrical base. A green band encircles the sphere, and there are colorful, abstract shapes and patterns on the left side.",
|
| 26 |
+
"5211280": "A stainless steel rice cooker with a black handle on top. The front features a digital display screen in the center, surrounded by various buttons and controls. The buttons are arranged in a circular pattern around the display, with additional buttons below. The cooker has a sleek, modern design with a smooth, reflective surface.",
|
| 27 |
+
"5718392": "A woven basket with a dark brown color and a pattern of interlocking diamond shapes, featuring a sturdy, slightly curved handle.",
|
| 28 |
+
"5718415": "The tent has a yellow canopy with a dark brown edge. The visible part of the tent includes a metal pole with a rusted section near the bottom.",
|
| 29 |
+
"5718424": "A black athletic shoe with a textured surface, featuring a prominent yellow swoosh logo on the side. The shoe has a low-top design with a padded collar and a lace-up closure. The sole is thick and rugged, designed for traction.",
|
| 30 |
+
"6012878": "A square pedestrian traffic light with a black background, featuring a red illuminated hand symbol on the left side.",
|
| 31 |
+
"6037269": "A metallic shower head with a curved, elongated handle. The handle is cylindrical and appears to be made of a light-colored material. The shower head itself is conical with a rounded tip and a slightly wider base, featuring a reflective surface.",
|
| 32 |
+
"6037272": "A green shampoo bottle with a slightly curved shape, featuring a white label with text and a small orange logo.",
|
| 33 |
+
"6055310": "A golden ruler with a series of evenly spaced, small, rectangular notches along its length.",
|
| 34 |
+
"6820594": "A medium-sized cat with a predominantly white face and chest, featuring a mix of brown and black tabby markings on its back and sides. The cat has large, expressive green eyes and a pink nose. Its ears are pointed and have a light brown color with darker tips.",
|
| 35 |
+
"6820595": "A cat with a white face and ears, featuring a mix of black and brown fur on its back and tail. The cat's body is predominantly white with black patches, and it has a short, sleek coat.",
|
| 36 |
+
"7050495": "A black leather handbag with a smooth texture and a slightly curved bottom edge.",
|
| 37 |
+
"8201777": "A black van with a rear window displaying the word \"TAXI\" in large yellow letters. The van has a yellow license plate with black text and a small white sticker on the lower left side of the rear bumper. The rear lights are vertically aligned on both sides of the van.",
|
| 38 |
+
"8331685": "The earphone features a sleek, curved design with a dark gray color. The earpiece is circular and appears to be cushioned for comfort. The headband is also dark gray and has a smooth, slightly glossy finish.",
|
| 39 |
+
"8331699": "The visible part of the printer is black with a smooth, curved surface.",
|
| 40 |
+
"8331718": "A black spiral-bound notebook with a white cover and the word \"Xtreme\" written in white on the cover.",
|
| 41 |
+
"8556674": "A round, orange fruit with a smooth, glossy surface. The fruit has a gradient of colors, transitioning from a deep orange at the bottom to a lighter, almost yellowish hue at the top. There is a small, white, irregularly shaped patch near the top center.",
|
| 42 |
+
"8556676": "A deep red apple with a glossy surface reflecting light, showcasing a smooth curvature and a small, visible portion of the stem.",
|
| 43 |
+
"8557176": "The watch features a rectangular gold-toned case with a black dial. It has a black leather strap with white stitching and a small metallic buckle.",
|
| 44 |
+
"8557195": "The microwave oven features a smooth, curved, off-white exterior with a slightly reflective surface. The visible part of the microwave includes a rounded edge and a small, dark-colored component at the top.",
|
| 45 |
+
"8906172": "A black, in-ear headphone with a sleek, curved design.",
|
| 46 |
+
"9766617": "The goose has a predominantly brown body with a pattern of darker brown and black feathers on its back. Its head is black with a white patch on the side of its neck. The underbelly is white, and the legs and feet are greenish.",
|
| 47 |
+
"10666665": "A round wall clock with a black frame and a white face. The clock features black Arabic numerals for each hour, with the numbers 1 through 12 clearly visible. The hour and minute hands are black and pointed, while the second hand is thin and black. The clock has a simple, minimalist design.",
|
| 48 |
+
"10811497": "A dark green, oval-shaped key with a smooth surface and a small, circular indentation near the bottom.",
|
| 49 |
+
"11012500": "A soft, round tortilla filled with fresh arugula, a slice of ripe tomato, shredded lettuce, and a dollop of creamy sauce.",
|
| 50 |
+
"11021544": "The faucet features a sleek, curved design with a polished chrome finish. It has a single lever handle on the right side for controlling water flow and temperature. The spout is slightly arched, extending outward with a smooth, flowing curve.",
|
| 51 |
+
"11021562": "The microwave oven features a white exterior with a prominent vertical handle on the left side of the door. The door has a series of horizontal ventilation slits near the top.",
|
| 52 |
+
"11021563": "A white stove with four black burners, featuring a control panel with knobs on the back.",
|
| 53 |
+
"11775390": "A green rubber shoe with a thick, textured sole and multiple circular holes on the side. The shoe features a black and white design on the side, with a prominent black section and white accents. The upper part of the shoe has a smooth, rounded shape with a slight curve at the top.",
|
| 54 |
+
"11950619": "The racket has a light-colored wooden handle with a smooth finish. The head of the racket is covered with a transparent protective guard, revealing a blue and white string bed. The guard has a rectangular shape with rounded edges and is secured to the head of the racket.",
|
| 55 |
+
"12178946": "A cylindrical bottle with a yellow cap and a blue label featuring white text.",
|
| 56 |
+
"12348078": "A woman with dark hair tied back in a bun, wearing a white t-shirt with red text and graphics on the front, and black pants. She is seated and holding a baby close to her chest.",
|
| 57 |
+
"12348079": "A digital weighing scale with a rectangular, blue weighing platform on top. The scale has a white base with a control panel on the left side, featuring several buttons and a display screen. The right side of the scale has a series of blue and black buttons.",
|
| 58 |
+
"12348080": "A pair of scissors with bright red handles, each handle forming a loop with a smooth, rounded edge. The blades are metallic and converge at a central pivot point, with one blade partially visible.",
|
| 59 |
+
"13138178": "The stool has a deep blue, glossy finish with a smooth, curved design. The visible part includes a rounded, arch-like structure with a slight indentation in the middle, creating a sleek and modern appearance.",
|
| 60 |
+
"13187927": "The motorcycle is a white scooter with a black seat and a rear storage compartment. It features a rear red tail light and a license plate mounted below the tail light. The handlebars are equipped with rearview mirrors, and the body has a sleek, modern design with a slightly curved front.",
|
| 61 |
+
"14490578": "The harbor seal has a sleek, elongated body with a dark brown to black coloration. Its skin appears smooth and slightly shiny, with a few lighter patches scattered across its back. The seal's head is rounded, and its body tapers towards the tail.",
|
| 62 |
+
"14640483": "A rectangular wooden chopping board with a smooth surface and rounded edges. The board has a natural wood grain pattern and a warm, honey-brown color.",
|
| 63 |
+
"14832137": "A cylindrical, dark blue plastic bucket with a smooth surface and a slightly flared rim. The bucket has a handle attached to the top edge, which is also dark blue.",
|
| 64 |
+
"15050320": "A dark brown wine glass with a wide, shallow bowl and a short stem. The glass has a smooth, reflective surface with a few light reflections visible on the bowl.",
|
| 65 |
+
"16010041": "A pair of light-colored chopsticks with a smooth, slightly tapered design, featuring a subtle gradient from a pale yellow to a soft orange hue at the tips.",
|
| 66 |
+
"16951734": "A triangular slice of mango with a smooth, light orange flesh and a slightly darker orange edge.",
|
| 67 |
+
"16957916": "A piece of green lettuce with a slightly curled edge and a mix of light and dark green hues, featuring a few small brown spots and a hint of red at the base.",
|
| 68 |
+
"17072759": "A black belt with a smooth texture, featuring a silver rectangular buckle.",
|
| 69 |
+
"17072764": "A partially visible pear with a smooth, light green skin transitioning to a yellowish hue towards the top. The pear has a small, brown stem protruding from its top left side.",
|
| 70 |
+
"17265253": "A black rickshaw with a single visible wheel featuring a silver rim and black tire. The wheel is attached to a black frame with a visible axle and a small, round, orange reflector on the side. The rickshaw has a black canopy with a slightly curved top edge.",
|
| 71 |
+
"17265254": "A traditional rickshaw with a black frame and a red seat cushion. It features a single large spoked wheel on each side, connected by a horizontal axle. The rickshaw has a curved handlebar at the front, and a footrest is visible beneath the seat. The wheels are equipped with black tires and silver rims.",
|
| 72 |
+
"17385866": "A scoop of vanilla ice cream with a swirl of red and yellow fruit toppings, possibly strawberry and lemon, on a light green and yellow marbled base.",
|
| 73 |
+
"17404769": "The car is a white SUV with a rear hatchback design. It features a rear window with a slight tint and a small, square fuel cap on the right side of the rear door. The taillights are vertically aligned and wrap around the side of the vehicle. The rear bumper is slightly curved, and the car has a visible rear wheel with a five-spoke alloy rim.",
|
| 74 |
+
"18217373": "The spectacles feature a thin, dark brown frame with a slightly curved bridge. The lenses are rectangular with rounded edges, and the frame has a subtle metallic sheen.",
|
| 75 |
+
"18301585": "The bench features a black metal frame with horizontal slats forming the backrest and seat. The backrest slats are evenly spaced and supported by white, rectangular concrete supports. The seat slats are also black and run parallel to the backrest. The bench has a sturdy, industrial design with a solid, robust appearance.",
|
| 76 |
+
"18680641": "A rectangular, plush, red carpet with a slightly uneven surface and a subtle gradient of darker red in the middle. The edges are bordered by a thin, dark gray trim.",
|
| 77 |
+
"18845103": "A metallic spoon with a slightly curved, elongated handle and a shallow, oval-shaped bowl. The handle has a smooth, polished finish, and the bowl is also metallic with a reflective surface.",
|
| 78 |
+
"19455186": "A blue metal handcart with two horizontal bars and two vertical supports. The cart has two black wheels at the bottom.",
|
| 79 |
+
"19610023": "A bright green croc-style shoe with a thick, textured sole and a wide, open toe design. The shoe features a smooth, rounded toe and a slightly raised heel.",
|
| 80 |
+
"19610025": "A white rabbit with large, upright ears and a red backpack. It is wearing a yellow shirt and blue pants. The rabbit has a playful expression with its mouth open and eyes wide.",
|
| 81 |
+
"20568676": "A stainless steel cooking pot with a rounded bottom and a rolled edge, featuring two riveted handles on opposite sides.",
|
| 82 |
+
"20993402": "A roll of white adhesive tape with a smooth, glossy surface and a slightly reflective sheen. The tape is wound tightly around a cylindrical core, with the outer edge appearing clean and unblemished.",
|
| 83 |
+
"21107974": "A wooden gavel with a cylindrical head featuring three evenly spaced, horizontal grooves. The handle is smooth and tapers slightly towards the end.",
|
| 84 |
+
"21529954": "A cylindrical can with a predominantly green label featuring the word \"Pepsi\" in white, bold letters. The top of the can is orange with a white cap.",
|
| 85 |
+
"22064315": "The visible part of the antelope shows two long, curved horns with a dark, almost black coloration, tapering to a point. The horns are covered in a pattern of ridges and grooves, giving them a textured appearance.",
|
| 86 |
+
"22107522": "A black bow tie with a smooth, satin-like finish, featuring a classic butterfly shape with pointed tips. The bow tie has a symmetrical design with a central knot and two loops that are slightly curved outward.",
|
| 87 |
+
"22879790": "A single, partially peeled white onion with a smooth, slightly shiny surface. The onion has a bulbous shape with a visible root end that is dry and brownish. The layers are tightly packed, and the outer skin is mostly intact, showing a few small, white root remnants.",
|
| 88 |
+
"24010373": "The guitar has a dark, glossy finish with a cutaway body design. It features a white pickguard and a circular soundhole with a simple rosette pattern. The fretboard is dark with white dot inlays, and the headstock is equipped with tuning pegs.",
|
| 89 |
+
"24017816": "The car features a dark-tinted side window with a black frame, and a portion of the front windshield is visible, also with a black frame.",
|
| 90 |
+
"24498027": "A tall, slender black pole with a decorative, ornate top featuring a small, pointed finial. The pole has a horizontal arm extending from the middle, supporting a lantern-style light fixture with a glass enclosure and a metal frame.",
|
| 91 |
+
"24581953": "A large, light gray dog with a short, smooth coat is lying down with its body stretched out. The dog has a long, slender tail that extends straight out behind it. Its legs are extended, with the front legs slightly bent and the hind legs stretched out. The dog's head is resting on the ground, and its ears are relaxed and folded back.",
|
| 92 |
+
"24694197": "A ripe avocado with a bumpy, dark green to almost black skin and a large, round, red to yellow-green pit nestled in the center.",
|
| 93 |
+
"24786060": "A light gray towel with a soft, plush texture, featuring a slightly wrinkled appearance. The towel has a rectangular shape with a visible fold running vertically down the center.",
|
| 94 |
+
"25054869": "A beige toilet cistern with a smooth, curved top surface and a slightly protruding front edge.",
|
| 95 |
+
"25273528": "A hot air balloon with a vibrant pattern of alternating vertical stripes in dark blue, red, and yellow. The balloon has a teardrop shape with a small basket attached at the bottom.",
|
| 96 |
+
"25273553": "A black tripod with three legs, each leg featuring a rubber foot for stability. The legs are connected at the top by a central column, which supports a mounting platform with a quick-release plate. The tripod has a telescopic head with a pan handle for adjusting the angle of the head.",
|
| 97 |
+
"25419495": "The tongs have a dark green handle with a black grip at the end. The metal arms are slightly curved and have a dark, matte finish. The tips of the arms are pointed and designed for precise gripping.",
|
| 98 |
+
"25419509": "A metallic fork with a slightly curved handle and four evenly spaced tines. The handle has a smooth, reflective surface with a subtle gradient from light to dark.",
|
| 99 |
+
"25419516": "The toy is a plush, blue creature with large, expressive eyes and prominent, pointed ears. It has a small tuft of hair on top of its head and a light blue underbelly.",
|
| 100 |
+
"25579493": "A square-shaped mixing bowl with rounded edges, featuring a light blue exterior and a cream-colored interior. The bowl contains a mixture of white and yellow ingredients, with a small piece of red garnish on top.",
|
| 101 |
+
"25612310": "A woven wicker basket with a dark brown hue, featuring a series of horizontal slats and a slightly curved edge."
|
| 102 |
+
}
|
evaluation/DLC-Bench/model_outputs/gar_1b_eval.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
evaluation/DLC-Bench/model_outputs/gar_1b_eval_gpt.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
evaluation/DLC-Bench/model_outputs/gar_8b.json
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"279135": "The ski features a predominantly black base with intricate orange and white geometric patterns. The design includes a series of interconnected shapes and lines, creating a dynamic and modern appearance. The tip of the ski is slightly curved and tapers to a point, with the pattern continuing seamlessly along the length of the ski.",
|
| 3 |
+
"297718": "A piece of sushi with a bed of white rice wrapped in a dark seaweed sheet, filled with a generous portion of pink and white crab meat. The top is sprinkled with sesame seeds and a light drizzle of soy sauce.",
|
| 4 |
+
"361105": "A small cluster of fresh, vibrant green leaves with a smooth texture, attached to a thin, green stem. The leaves are broad and slightly serrated at the edges, with a glossy surface.",
|
| 5 |
+
"622329": "A rectangular, flat, beige-colored eraser with a slightly rough texture and rounded edges.",
|
| 6 |
+
"622332": "A black, rectangular stapler with a glossy finish. The top surface features a white logo and text. The front edge has a slightly raised, horizontal groove.",
|
| 7 |
+
"1075308": "A vintage-style television set with a boxy, black frame and a slightly curved screen. The top of the television features a series of control buttons and a small display screen.",
|
| 8 |
+
"1196168": "A rectangular, wall-mounted air conditioner with a large circular vent on the left side, featuring a grid pattern. The right side of the unit has a smooth surface with a small rectangular panel and a few visible screws.",
|
| 9 |
+
"1770866": "A white tag with handwritten text in blue and red marker. The blue text reads \"LIBRA\" and \"my tabouts\" in a cursive style. Below, in red marker, the text \"Add $50\" is written in a bold, sans-serif font.",
|
| 10 |
+
"1894089": "A metallic screwdriver with a flathead tip and a cylindrical shaft. The handle is textured for grip and has a slight taper towards the tip.",
|
| 11 |
+
"2391761": "The canoe features a blue tarpaulin cover secured over its wooden frame. The visible part of the canoe's hull is made of wooden planks, with a natural brown finish. The canoe has a pointed bow and a slightly raised stern. A white fender is attached to the side, and a red and white lifebuoy is visible inside the canoe.",
|
| 12 |
+
"2391780": "The bird has a long, slender neck and a pointed beak. Its body is covered in brown feathers with a slightly lighter underbelly. The wings are outstretched, showing a mix of brown and white feathers with a distinct pattern. The tail feathers are short and pointed.",
|
| 13 |
+
"2391781": "The bird has a predominantly white body with a mix of gray and brown feathers on its wings and back. Its wings are outstretched, showing a gradient from white at the base to darker shades towards the tips. The bird's head is slightly turned, with a small, pointed beak visible.",
|
| 14 |
+
"2580318": "The mouse has a smooth, metallic surface with a slightly curved, ergonomic shape. The visible part is a triangular section with a gradient of light and dark shades, giving it a sleek and modern appearance.",
|
| 15 |
+
"2580323": "A rectangular wooden picture frame with a light brown finish, containing a detailed architectural floor plan and elevation drawings. The drawings are monochrome and feature various rooms, furniture, and structural elements. The frame has a simple, smooth design with slightly rounded edges.",
|
| 16 |
+
"2588513": "A rectangular wooden block with a light beige color and visible wood grain texture. The block has a black base and a white band wrapped around its middle.",
|
| 17 |
+
"3993075": "A white pen with a red cap and a green and blue design on the barrel.",
|
| 18 |
+
"4027486": "The bus is predominantly blue with a white section near the bottom. It has a rectangular window with a black frame and a visible license plate that reads \"SABF.\" The bus features a sleek, modern design with a slightly curved roof and a small, white, triangular logo near the bottom.",
|
| 19 |
+
"4243725": "A curved, elongated object with a gradient of colors ranging from light yellow to dark brown, featuring a smooth, glossy surface.",
|
| 20 |
+
"4502267": "A green bean with a smooth, slightly curved surface, featuring a gradient of light to dark green hues. The bean has a tapered end and a small, pointed tip.",
|
| 21 |
+
"4604873": "A tall, lattice-style mobile crane with a long, horizontal boom extending to the left. The crane has a rectangular base and a vertical mast with a series of diagonal cross-bracing. The boom is supported by a series of cables and pulleys, and there is a hook at the end of the boom.",
|
| 22 |
+
"4781902": "A dark brown wooden stool with a triangular seat and four legs, each leg angling outward and connected by a lower horizontal support beam.",
|
| 23 |
+
"4782942": "A dark-colored, conical-shaped horn with a wide, flared opening and a smooth, cylindrical body.",
|
| 24 |
+
"4782949": "A cylindrical drum with a dark brown, textured surface and a metallic rim. The drum has a blue and white striped pattern on the side.",
|
| 25 |
+
"4916799": "A spherical sculpture composed of numerous small, white, dome-shaped elements arranged in a grid pattern. The sphere is mounted on a cylindrical base and features a blue band with the word \"Pune\" in blue letters. There are also green and yellow accents on the sphere.",
|
| 26 |
+
"5211280": "A stainless steel crock pot with a curved, dark gray handle on top. The control panel features a digital display in the center, surrounded by various buttons and indicators. The buttons are arranged in a semi-circular pattern around the display, with labels in both English and another language. The crock pot has a smooth, reflective surface and a slightly tapered design towards the base.",
|
| 27 |
+
"5718392": "The box is a rectangular prism with a woven pattern of interlocking dark brown and light brown strips. The surface has a textured appearance, with the weave creating a series of small, diamond-shaped openings.",
|
| 28 |
+
"5718415": "The tent has a yellow canopy with a slightly curved edge. The visible part of the tent includes a vertical metal pole supporting the canopy.",
|
| 29 |
+
"5718424": "A rugged, dark-colored shoe with a thick, textured sole and a prominent, rounded toe. The shoe features a light-colored trim around the opening and a visible lace-up design.",
|
| 30 |
+
"6012878": "A square traffic light with a black background and a red illuminated hand symbol on the left side.",
|
| 31 |
+
"6037269": "A vintage-style shower head with a curved, metallic arm and a cylindrical, cream-colored handle. The shower head itself is round and metallic, with a slightly domed top and a flat bottom.",
|
| 32 |
+
"6037272": "A green, cylindrical shampoo bottle with a slightly tapered end. The bottle has a smooth surface with a small, circular, orange and white label near the top.",
|
| 33 |
+
"6055310": "A wooden measuring stick with a natural finish, featuring black measurement markings in centimeters and millimeters. The stick has a slightly tapered end and a metal tip at the opposite end.",
|
| 34 |
+
"6820594": "A medium-sized cat with a predominantly white face and underbelly, featuring a mix of dark brown and black patches on its back and sides. The cat has large, round, light green eyes and a pink nose. Its ears are upright, with the left ear having a light brown patch and the right ear being mostly white. The cat's fur is short and smooth.",
|
| 35 |
+
"6820595": "A cat with a white face, black ears, and a black patch over its left eye. The body is predominantly black with a white underbelly and a white patch on its right side. The tail is black.",
|
| 36 |
+
"7050495": "A black leather handbag with a smooth, slightly glossy finish. The visible part shows a rectangular shape with a subtle seam along the bottom edge.",
|
| 37 |
+
"8201777": "A black van with a rear window displaying the word \"TAXI\" in yellow letters. The van has a yellow license plate with black text and a small white sticker below it. The rear lights are vertically aligned on both sides, and the van has a small emblem above the license plate.",
|
| 38 |
+
"8331685": "A black over-ear headphone with a curved headband and a cushioned earcup. The earcup has a circular shape with a smooth, matte finish. The headband is attached to the earcup with a visible hinge mechanism.",
|
| 39 |
+
"8331699": "The visible part of the waste container is black with a smooth surface and a slightly curved edge.",
|
| 40 |
+
"8331718": "A black spiral-bound notebook with a white cover featuring the word \"Xtreme\" in a stylized font.",
|
| 41 |
+
"8556674": "A single, round orange with a smooth, glossy surface. The orange has a vibrant, bright orange color with a small, lighter patch near the top left.",
|
| 42 |
+
"8556676": "A deep red apple with a smooth, glossy surface. The apple has a slightly irregular shape with a prominent bulge on the left side and a smaller bulge on the right side. The bottom part of the apple is slightly darker, almost black, with a few small, reflective spots.",
|
| 43 |
+
"8557176": "The watch features a rectangular gold case with a white dial. The strap is black with a textured pattern and a gold buckle.",
|
| 44 |
+
"8557195": "A beige, rectangular bread maker with a smooth surface and slightly rounded edges. The top edge has a small, dark opening.",
|
| 45 |
+
"8906172": "A black, curved earphone with a smooth, glossy finish.",
|
| 46 |
+
"9766617": "The goose has a predominantly brown body with a pattern of darker brown and black feathers on its back. Its head is black with a white patch on the side of its neck. The beak is black, and the legs and feet are also black. The underbelly is white, and the tail feathers are black with a white tip.",
|
| 47 |
+
"10666665": "A round wall clock with a black frame and a white face. The clock features black Arabic numerals at each hour mark, with the numbers 12, 3, 6, and 9 in larger font. The clock has three black hands: an hour hand, a minute hand, and a second hand. The hour hand is pointing between the 10 and 11, the minute hand is pointing at 12, and the second hand is pointing at 6.",
|
| 48 |
+
"10811497": "The mouse is a dark green, oval-shaped device with a smooth surface. It has a small, circular indentation near the bottom edge.",
|
| 49 |
+
"11012500": "A burrito filled with fresh green arugula, a slice of ripe tomato, shredded lettuce, and a layer of seasoned ground meat, all wrapped in a soft, lightly toasted tortilla.",
|
| 50 |
+
"11021544": "A metallic, curved faucet with a polished finish, featuring a single lever handle and a long, slender spout.",
|
| 51 |
+
"11021562": "The microwave oven has a white exterior with a rectangular shape. It features a prominent, curved handle on the front door, which is also white. The control panel is located on the right side of the door, with a series of buttons and a small display screen. The top of the microwave has a vented section for ventilation.",
|
| 52 |
+
"11021563": "A stainless steel gas stove with a black control panel featuring four knobs. The stove has a rectangular shape with a slightly raised back panel. The control panel is positioned at the back, and the stove has a smooth, reflective surface.",
|
| 53 |
+
"11775390": "A green rubber shoe with a textured sole and multiple circular holes on the side. The shoe features a black and white design on the upper part, with green laces threaded through the eyelets.",
|
| 54 |
+
"11950619": "The dumbbell features a white, rectangular handle with rounded edges and a smooth surface. The handle is attached to a metallic, rectangular weight plate with a series of evenly spaced, vertical slots. The weight plate is secured to the handle with a visible screw.",
|
| 55 |
+
"12178946": "A yellow bottle with a blue label featuring white text.",
|
| 56 |
+
"12348078": "A woman with dark hair tied up in a bun, wearing a white t-shirt with red text and graphics on the front, and black pants. She is holding a baby in her arms.",
|
| 57 |
+
"12348079": "A rectangular digital weighing scale with a metallic blue weighing platform. The scale has a white base with a control panel on the left side, featuring several buttons and a small display screen. The edges of the scale are slightly rounded, and the weighing platform has a textured surface.",
|
| 58 |
+
"12348080": "A pair of scissors with bright red plastic handles and metallic blades. The handles are oval-shaped with a smooth, glossy finish. The blades are straight and sharp, with a slight taper towards the tips.",
|
| 59 |
+
"13138178": "A blue plastic stool with a smooth, curved seat and rounded legs. The stool has a simple, sturdy design with a slightly glossy finish.",
|
| 60 |
+
"13187927": "The motorcycle is a white scooter with a sleek, modern design. It features a black seat and a rear storage compartment with a red reflector. The rear light is integrated into the storage compartment, and the scooter has a visible license plate mounted below the light. The handlebars are equipped with rearview mirrors, and the front section includes a headlight and a windshield.",
|
| 61 |
+
"14490578": "The harbor seal has a sleek, elongated body with a dark, almost black coloration. Its skin appears smooth and slightly glossy, with a subtle gradient of lighter shades along its back. The seal's head is rounded, and its body tapers towards the tail.",
|
| 62 |
+
"14640483": "A rectangular wooden chopping board with a smooth surface and a natural wood grain pattern. The board has a slightly rounded edge and a visible handle on one side.",
|
| 63 |
+
"14832137": "A cylindrical, light purple plastic bucket with a smooth surface and a slightly flared rim. The bucket has a small, curved handle attached near the top.",
|
| 64 |
+
"15050320": "A dark brown wine glass with a wide, flat base and a slender stem.",
|
| 65 |
+
"16010041": "A pair of light-colored wooden chopsticks with a smooth, polished surface. The tips of the chopsticks are slightly tapered and have a subtle orange hue.",
|
| 66 |
+
"16951734": "A wedge of cantaloupe with a smooth, light orange flesh and a thin, pale rind.",
|
| 67 |
+
"16957916": "Fresh green lettuce leaves with ruffled edges and a crisp texture, exhibiting a gradient of color from pale green at the base to a darker green towards the tips.",
|
| 68 |
+
"17072759": "A black belt with a smooth texture, featuring a silver rectangular buckle. The belt has a single prong and a loop near the buckle for securing the tail end.",
|
| 69 |
+
"17072764": "A pear with a smooth, light green skin, featuring a slight yellowish hue on the upper right side. The pear has a short, brown stem attached to its top.",
|
| 70 |
+
"17265253": "A black rickshaw with a black canopy, featuring a single visible wheel with a silver rim and black tire. The wheel is attached to a black frame with a visible pedal mechanism.",
|
| 71 |
+
"17265254": "A traditional rickshaw with a black frame and a red seat, featuring a curved handlebar and a single front wheel with spokes and a rubber tire.",
|
| 72 |
+
"17385866": "A scoop of vanilla ice cream topped with a slice of red strawberry, resting on a bed of green mint leaves.",
|
| 73 |
+
"17404769": "The car is a white minivan with a rear design featuring a large, dark-tinted rear window and a smaller, rectangular window on the side. The rear lights are vertically aligned and wrap around the side of the vehicle. The car has a visible rear wheel with a five-spoke alloy rim. There is a small, square fuel cap located on the side panel near the rear wheel.",
|
| 74 |
+
"18217373": "The spectacles feature a round, gold-colored frame with a thin, dark brown temple arm. The lens is a light, translucent yellow.",
|
| 75 |
+
"18301585": "The bench features a black metal frame with horizontal slats forming the backrest and seat. The backrest consists of three horizontal slats, while the seat has two horizontal slats. The bench is supported by white concrete legs that are rectangular in shape and have a slightly tapered design.",
|
| 76 |
+
"18680641": "A rectangular, plush, red carpet with a slightly textured surface and a dark gray border along the edges.",
|
| 77 |
+
"18845103": "A metallic spoon with a slightly curved handle and a shallow, oval-shaped bowl. The handle has a smooth, reflective surface with a subtle taper towards the bowl.",
|
| 78 |
+
"19455186": "A blue metal cart with a rectangular frame and four black wheels. The cart has two horizontal blue bars across the front, with a small white label affixed to the upper bar.",
|
| 79 |
+
"19610023": "A bright green, frog-shaped slipper with a smooth, rounded body and a wide, open mouth. The slipper has a small, raised bump on the top of its head, resembling an eye.",
|
| 80 |
+
"19610025": "A white rabbit with upright ears, wearing a yellow shirt and blue pants, is holding a brown basket on its back.",
|
| 81 |
+
"20568676": "A stainless steel bowl filled with a mixture of chopped nuts and a yellow spatula resting on top.",
|
| 82 |
+
"20993402": "A roll of translucent adhesive tape with a smooth, glossy surface and a slightly reflective finish. The tape is wound tightly around a central cardboard core, which is visible at the top.",
|
| 83 |
+
"21107974": "A wooden gavel with a cylindrical head and a smooth, slightly tapered handle. The head features a prominent, rounded end and a series of horizontal grooves near the top. The handle is uniformly cylindrical and extends straight from the head.",
|
| 84 |
+
"21529954": "A cylindrical can with a white cap, featuring a vibrant design. The top half is orange with a small white logo, while the bottom half is green with a large, stylized white text. The can has a slightly curved shape and a glossy finish.",
|
| 85 |
+
"22064315": "The visible part of the gazelle shows a pair of long, curved horns with a dark, almost black coloration. The horns are smooth and taper to a point. The base of the horns is attached to a light brown, slightly textured head.",
|
| 86 |
+
"22107522": "A black bow tie with a smooth, satin-like finish, featuring a classic butterfly shape with pointed tips.",
|
| 87 |
+
"22879790": "A single, large, white onion with a smooth, slightly shiny surface. The onion has a bulbous shape with a few thin, papery layers visible near the top. The root end is dark brown and slightly shriveled, with a few small roots extending from it.",
|
| 88 |
+
"24010373": "The guitar has a dark, glossy body with a cutaway design. The neck is dark with white dot inlays on the fretboard. The headstock is also dark, matching the neck, and features tuning pegs. The body has a circular soundhole and a bridge with white bridge pins.",
|
| 89 |
+
"24017816": "The van is white with a large, rectangular side window and a side mirror. The window is tinted, and the side mirror is black. The van has a sleek, modern design with smooth lines and a slightly curved roof.",
|
| 90 |
+
"24498027": "A tall, slender black pole with a decorative, ornate top featuring a small, pointed finial. The pole has a rectangular, box-like structure attached near the top, and a smaller, horizontal arm extending from the middle section.",
|
| 91 |
+
"24581953": "A large, light-colored carnivore with a robust, muscular body and a thick, short coat. It has a broad head with small, rounded ears and a long, tapering tail. The legs are sturdy and strong, with large paws.",
|
| 92 |
+
"24694197": "A ripe avocado with a bumpy, dark green skin and a central pit cavity filled with a reddish-brown, creamy substance.",
|
| 93 |
+
"24786060": "A light gray towel with a soft, slightly wrinkled texture, hanging loosely with a gentle curve.",
|
| 94 |
+
"25054869": "The toilet features a smooth, rounded lid with a glossy finish, seamlessly integrated into the tank. The tank has a slightly curved, angular design with a uniform, light beige color.",
|
| 95 |
+
"25273528": "The balloon features a vibrant pattern with alternating vertical stripes of red, yellow, and green. The red stripes are the most prominent, with yellow and green stripes creating a striking contrast. The balloon has a teardrop shape with a small black basket attached at the bottom.",
|
| 96 |
+
"25273553": "A black tripod with a central column and three legs, each leg featuring a rubber foot for stability. The legs are connected to a central hub, which is part of the tripod's support structure.",
|
| 97 |
+
"25419495": "The tongs have a metallic, slightly curved arm with a black rubberized grip handle. The handle is ergonomically designed with a smooth, matte finish. The tongs are open, showing the inner surfaces of the arms, which are also metallic and slightly curved.",
|
| 98 |
+
"25419509": "A metallic fork with a slightly curved handle and four evenly spaced tines. The handle has a smooth, reflective surface with a gentle upward curve near the tines.",
|
| 99 |
+
"25419516": "A plush toy with a blue face, large white eyes with black pupils, and two pointed ears.",
|
| 100 |
+
"25579493": "A small, square-shaped bowl with rounded edges, featuring a light blue exterior and a white interior. The bowl contains a mixture of white rice and a small piece of red food item in the center.",
|
| 101 |
+
"25612310": "A dark brown wicker basket with a woven pattern, featuring a slightly curved edge and a visible portion of the basket's side."
|
| 102 |
+
}
|
evaluation/DLC-Bench/model_outputs/gar_8b_eval.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
evaluation/DLC-Bench/model_outputs/gar_8b_eval_gpt.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|