"""Gradio demo for the GAR region-captioning model ("Describe Anything").

Exposes three API endpoints: SAM image-embedding extraction, streaming
region captioning, and non-streaming region captioning. Also serves a
prebuilt frontend from `dist/`.
"""

import argparse
import base64
import io

import gradio as gr
import numpy as np
import torch
from fastapi.staticfiles import StaticFiles
from PIL import Image
from transformers import AutoModel, AutoProcessor, GenerationConfig, SamModel, SamProcessor

from eval_dataset import SingleRegionCaptionDataset

try:
    from spaces import GPU
except ImportError:
    print("Spaces not installed, using dummy GPU decorator")

    # Fallback no-op decorator so the script also runs outside HF Spaces.
    def GPU(*args, **kwargs):
        def decorator(fn):
            return fn

        return decorator

# SAM is used here only to compute image embeddings; the embedding is shipped
# back to the caller as base64 (see `image_to_sam_embedding`), presumably for
# a client-side mask decoder.
device = torch.device("cpu")
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
print("sam ready")

# Load the GAR captioning model and its processor.
model_path = "HaochenWang/GAR-1B"
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    use_flash_attn=False,
).eval()
processor = AutoProcessor.from_pretrained(
    model_path,
    trust_remote_code=True,
)

@GPU(duration=75)
def image_to_sam_embedding(base64_image):
    """Compute the SAM image embedding for a base64-encoded image."""
    try:
        # Decode the incoming image.
        image_bytes = base64.b64decode(base64_image)
        image = Image.open(io.BytesIO(image_bytes))

        # Preprocess and run the SAM image encoder.
        inputs = sam_processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            image_embedding = sam_model.get_image_embeddings(inputs["pixel_values"])

        # Return the raw embedding bytes to the client as base64.
        image_embedding = image_embedding.cpu().numpy()
        embedding_bytes = image_embedding.tobytes()
        embedding_base64 = base64.b64encode(embedding_bytes).decode("utf-8")
        return embedding_base64
    except Exception as e:
        print(f"Error processing image: {e}")
        raise gr.Error(f"Failed to process image: {e}") from e
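
# Note for clients: `tobytes()` drops dtype and shape, so the consumer must
# know both in advance. For SAM ViT-H the embedding should be float32 with
# shape (1, 256, 64, 64); a minimal reconstruction sketch, assuming those
# values:
#   emb = np.frombuffer(base64.b64decode(b64), dtype=np.float32).reshape(1, 256, 64, 64)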

@GPU(duration=75)
def describe(image_base64: str, mask_base64: str, query: str):
    """Caption the masked region, yielding the text incrementally."""
    # Decode the image and the region mask; both may arrive as data URLs
    # ("data:image/png;base64,..."), so strip the header if present.
    image_bytes = base64.b64decode(image_base64.split(",")[1] if "," in image_base64 else image_base64)
    img = Image.open(io.BytesIO(image_bytes))
    mask_bytes = base64.b64decode(mask_base64.split(",")[1] if "," in mask_base64 else mask_base64)
    mask = Image.open(io.BytesIO(mask_bytes))
    mask = np.array(mask.convert("L"))

    # Build the visual-prompt token vocabulary the model expects.
    prompt_number = model.config.prompt_numbers
    prompt_tokens = [f"<Prompt{i_p}>" for i_p in range(prompt_number)] + ["<NO_Prompt>"]

    # Wrap the single image/mask pair in the evaluation dataset, which handles
    # preprocessing. NOTE: `query` is currently unused; the dataset builds its
    # own prompt.
    dataset = SingleRegionCaptionDataset(
        image=img,
        mask=mask,
        processor=processor,
        prompt_number=prompt_number,
        visual_prompt_tokens=prompt_tokens,
        data_dtype=torch.bfloat16,
    )
    data_sample = dataset[0]

    with torch.no_grad():
        generate_ids = model.generate(
            **data_sample,
            generation_config=GenerationConfig(
                max_new_tokens=1024,
                eos_token_id=processor.tokenizer.eos_token_id,
                pad_token_id=processor.tokenizer.pad_token_id,
            ),
            return_dict=True,
        )
    output_caption = processor.tokenizer.decode(generate_ids.sequences[0], skip_special_tokens=True).strip()

    # Generation above is not incremental, so emulate streaming by yielding
    # the caption one character at a time.
    text = ""
    for char in output_caption:
        text += char
        yield text

@GPU(duration=75)
def describe_without_streaming(image_base64: str, mask_base64: str, query: str):
    """Same pipeline as `describe`, but returns the full caption at once."""
    # Decode the image and the region mask (see `describe`).
    image_bytes = base64.b64decode(image_base64.split(",")[1] if "," in image_base64 else image_base64)
    img = Image.open(io.BytesIO(image_bytes))
    mask_bytes = base64.b64decode(mask_base64.split(",")[1] if "," in mask_base64 else mask_base64)
    mask = Image.open(io.BytesIO(mask_bytes))
    mask = np.array(mask.convert("L"))

    prompt_number = model.config.prompt_numbers
    prompt_tokens = [f"<Prompt{i_p}>" for i_p in range(prompt_number)] + ["<NO_Prompt>"]

    dataset = SingleRegionCaptionDataset(
        image=img,
        mask=mask,
        processor=processor,
        prompt_number=prompt_number,
        visual_prompt_tokens=prompt_tokens,
        data_dtype=torch.bfloat16,
    )
    data_sample = dataset[0]

    with torch.no_grad():
        generate_ids = model.generate(
            **data_sample,
            generation_config=GenerationConfig(
                max_new_tokens=1024,
                eos_token_id=processor.tokenizer.eos_token_id,
                pad_token_id=processor.tokenizer.pad_token_id,
            ),
            return_dict=True,
        )
    output_caption = processor.tokenizer.decode(generate_ids.sequences[0], skip_special_tokens=True).strip()
    return output_caption

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Describe Anything gradio demo")
    parser.add_argument("--server_addr", "--host", type=str, default=None, help="The server address to listen on.")
    parser.add_argument("--server_port", "--port", type=int, default=None, help="The port to listen on.")
    parser.add_argument("--model-path", type=str, default="HaochenWang/GAR-1B", help="Path to the model checkpoint")
    parser.add_argument("--prompt-mode", type=str, default="full+focal_crop", help="Prompt mode")
    parser.add_argument("--conv-mode", type=str, default="v1", help="Conversation mode")
    parser.add_argument("--temperature", type=float, default=0.2, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.5, help="Top-p for sampling")
    args = parser.parse_args()
    # NOTE: only --server_addr and --server_port are consumed below; the model
    # is loaded at import time with the hard-coded `model_path` above.

    with gr.Blocks() as demo:
        # Register each function as a named API endpoint for the frontend.
        gr.Interface(
            fn=image_to_sam_embedding,
            inputs=gr.Textbox(label="Image Base64"),
            outputs=gr.Textbox(label="Embedding Base64"),
            title="Image Embedding Generator",
            api_name="image_to_sam_embedding",
        )
        gr.Interface(
            fn=describe,
            inputs=[
                gr.Textbox(label="Image Base64"),
                gr.Text(label="Mask Base64"),
                gr.Text(label="Prompt"),
            ],
            outputs=[
                gr.Text(label="Description"),
            ],
            title="Mask Description Generator",
            api_name="describe",
        )
        gr.Interface(
            fn=describe_without_streaming,
            inputs=[
                gr.Textbox(label="Image Base64"),
                gr.Text(label="Mask Base64"),
                gr.Text(label="Prompt"),
            ],
            outputs=[
                gr.Text(label="Description"),
            ],
            title="Mask Description Generator (Non-Streaming)",
            api_name="describe_without_streaming",
        )

    # launch() normally blocks; stub out block_thread() so the underlying
    # FastAPI app can still be customized after launch, then block manually
    # at the end.
    demo._block_thread = demo.block_thread
    demo.block_thread = lambda: None
    demo.launch(
        share=True,
        server_name=args.server_addr,
        server_port=args.server_port,
        ssr_mode=False,
    )

    # Drop Gradio's default root route and serve the prebuilt frontend from
    # `dist/` in its place.
    for route in demo.app.routes:
        if route.path == "/":
            demo.app.routes.remove(route)
    demo.app.mount("/", StaticFiles(directory="dist", html=True), name="demo")

    demo._block_thread()
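
# Example client call (a minimal sketch; assumes the demo is reachable at
# http://localhost:7860 and that image.png / mask.png exist locally):
#
#   import base64
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860")
#   image_b64 = base64.b64encode(open("image.png", "rb").read()).decode("utf-8")
#   mask_b64 = base64.b64encode(open("mask.png", "rb").read()).decode("utf-8")
#   print(client.predict(image_b64, mask_b64, "Describe the region.",
#                        api_name="/describe_without_streaming"))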