supersolar's picture
Update utils/utils.py
c2469bd verified
raw
history blame
3.39 kB
# utils.py
import torch
import supervision as sv
from PIL import Image
import os
from typing import Union, Any, Tuple, Dict
from unittest.mock import patch
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.dynamic_module_utils import get_imports
# Hugging Face Hub id of the Florence-2 checkpoint used by default.
FLORENCE_CHECKPOINT = "microsoft/Florence-2-base"

# Florence-2 task tokens. A task token is prepended to any free-text input to
# build the model prompt.
FLORENCE_OBJECT_DETECTION_TASK = "<OD>"
FLORENCE_DETAILED_CAPTION_TASK = "<MORE_DETAILED_CAPTION>"
FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK = "<CAPTION_TO_PHRASE_GROUNDING>"
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = "<OPEN_VOCABULARY_DETECTION>"
FLORENCE_DENSE_REGION_CAPTION_TASK = "<DENSE_REGION_CAPTION>"
def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
    """Return the imports of ``filename``, dropping ``flash_attn`` for Florence-2.

    Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72.

    Wraps ``transformers.dynamic_module_utils.get_imports``. For any file other
    than ``modeling_florence2.py`` the wrapped result is returned unchanged; for
    that file the ``flash_attn`` entry is removed from the list.

    Args:
        filename: Path of the remote-code module being inspected.

    Returns:
        The list of imported top-level package names.
    """
    imports = get_imports(filename)
    # Compare the base name instead of a "/"-prefixed suffix so that
    # Windows-style paths (backslash separators) are recognised as well.
    if os.path.basename(os.fspath(filename)) != "modeling_florence2.py":
        return imports
    # Filter rather than list.remove(): remove() raises ValueError when
    # "flash_attn" is not present in the list.
    return [name for name in imports if name != "flash_attn"]
def load_florence_model(
    device: Union[torch.device, str, None] = None,
    checkpoint: str = FLORENCE_CHECKPOINT,
) -> Tuple[Any, Any]:
    """Load a Florence-2 model and its processor.

    Args:
        device: Device to place the model on. When ``None``, ``"cuda:0"`` is
            used if CUDA is available, otherwise ``"cpu"``.
        checkpoint: Hugging Face Hub id (or local path) of the checkpoint to
            load. Previously this argument was ignored and the base checkpoint
            was always loaded.

    Returns:
        A ``(model, processor)`` tuple. The model uses float16 weights on a
        CUDA device and float32 weights otherwise.
    """
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Normalise str / torch.device inputs so ``.type`` is available below.
    device = torch.device(device)
    # Half precision on CUDA, full precision elsewhere.
    torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint, torch_dtype=torch_dtype, trust_remote_code=True
    ).to(device)
    processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
    return model, processor
def run_florence_inference(
    model: Any,
    processor: Any,
    device: torch.device,
    image: Image.Image,
    task: str,
    text: str = "",
    max_new_tokens: int = 1024,
    num_beams: int = 3,
) -> Tuple[str, Dict]:
    """Run a single Florence-2 task on an image.

    Args:
        model: Florence-2 model (e.g. from ``load_florence_model``).
        processor: Matching Florence-2 processor.
        device: Device the processed inputs are moved to.
        image: Input PIL image.
        task: Florence-2 task token, e.g. ``FLORENCE_OBJECT_DETECTION_TASK``.
        text: Optional free text appended to the task token.
        max_new_tokens: Generation length limit passed to ``model.generate``.
        num_beams: Beam-search width passed to ``model.generate``.

    Returns:
        ``(generated_text, response)``: the raw decoded output (special tokens
        kept) and the processor's post-processed result for ``task``.
    """
    prompt = task + text
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    pixel_values = inputs["pixel_values"]
    # The model may be loaded in float16 (on CUDA) while the processor emits
    # float32 pixels; cast to the model's dtype to avoid a dtype mismatch.
    model_dtype = getattr(model, "dtype", None)
    if model_dtype is not None:
        pixel_values = pixel_values.to(model_dtype)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=pixel_values,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
    )
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=False
    )[0]
    response = processor.post_process_generation(
        generated_text, task=task, image_size=image.size
    )
    return generated_text, response
def detect_objects_in_image(
    image_input_path,
    texts,
    device,
    *,
    florence_model: Any = None,
    florence_processor: Any = None,
    sam_model: Any = None,
):
    """Detect objects matching text queries and refine them with SAM.

    For each query in ``texts``, Florence-2 open-vocabulary detection is run
    and its result is converted to ``sv.Detections``. Each result is then
    passed through ``run_sam_inference``. All per-query detections are merged
    and passed through ``run_sam_inference`` once more.

    Args:
        image_input_path: Anything accepted by ``PIL.Image.open``.
        texts: Iterable of text queries; a single string is treated as one
            query (it was previously iterated character by character).
        device: Device the models are moved to before inference.
        florence_model: Optional Florence-2 model; defaults to the module-level
            ``FLORENCE_MODEL``.
        florence_processor: Optional Florence-2 processor; defaults to
            ``FLORENCE_PROCESSOR``.
        sam_model: Optional SAM model; defaults to ``SAM_IMAGE_MODEL``.

    Returns:
        The merged ``sv.Detections`` after the final SAM pass.
    """
    # NOTE(review): FLORENCE_MODEL, FLORENCE_PROCESSOR, SAM_IMAGE_MODEL and
    # run_sam_inference are not defined in this file. Pass the models
    # explicitly, or ensure these names exist on this module before relying on
    # the defaults.
    if florence_model is None:
        florence_model = FLORENCE_MODEL
    if florence_processor is None:
        florence_processor = FLORENCE_PROCESSOR
    if sam_model is None:
        sam_model = SAM_IMAGE_MODEL
    # Move each model once instead of on every loop iteration.
    florence_model = florence_model.to(device)
    sam_model = sam_model.to(device)

    if isinstance(texts, str):
        texts = [texts]

    # Load the image fully and release the file handle. Convert to RGB so that
    # palette/alpha/grayscale files yield 3-channel input (presumably what both
    # models expect — confirm against run_sam_inference).
    with Image.open(image_input_path) as img:
        image_input = img.convert("RGB")

    detections_list = []
    for text in texts:
        _, result = run_florence_inference(
            model=florence_model,
            processor=florence_processor,
            device=device,
            image=image_input,
            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
            text=text,
        )
        # Build supervision detections from the Florence-2 result.
        detections = sv.Detections.from_lmm(
            lmm=sv.LMM.FLORENCE_2,
            result=result,
            resolution_wh=image_input.size,
        )
        # NOTE(review): the merged detections go through SAM again below, so
        # this per-query pass may be redundant; verify run_sam_inference
        # before removing it.
        detections = run_sam_inference(sam_model, image_input, detections)
        detections_list.append(detections)

    # Merge all per-query detections, then run SAM once more on the union.
    detections = sv.Detections.merge(detections_list)
    detections = run_sam_inference(sam_model, image_input, detections)
    return detections