# ProductPlacement/utils/florence.py
import os
from typing import Any, Dict, Tuple, Union
from unittest.mock import patch

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.dynamic_module_utils import get_imports

# Checkpoint and the task prompt tokens understood by Florence-2.
FLORENCE_CHECKPOINT = "microsoft/Florence-2-base"
# FLORENCE_CHECKPOINT = "microsoft/Florence-2-large"
FLORENCE_OBJECT_DETECTION_TASK = '<OD>'
FLORENCE_DETAILED_CAPTION_TASK = '<MORE_DETAILED_CAPTION>'
FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK = '<CAPTION_TO_PHRASE_GROUNDING>'
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = '<OPEN_VOCABULARY_DETECTION>'
FLORENCE_DENSE_REGION_CAPTION_TASK = '<DENSE_REGION_CAPTION>'
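
# Example prompt composition (run_florence_inference below builds prompts as
# `task + text`; "a red sneaker" here is a hypothetical phrase to localise):
#
#     prompt = FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK + "a red sneaker"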

# Work around the unconditional flash_attn import in Florence-2's remote code,
# which breaks on CPU or MPS backends where the package is not installed.
def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
    if not str(filename).endswith("modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    # Guard the removal: a revision of the remote code that does not list
    # flash_attn would otherwise raise ValueError from list.remove.
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Module-level singletons, loaded once at import time with the flash_attn
# workaround applied.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    florence_model = AutoModelForCausalLM.from_pretrained(
        FLORENCE_CHECKPOINT, trust_remote_code=True).to(device).eval()
    florence_processor = AutoProcessor.from_pretrained(
        FLORENCE_CHECKPOINT, trust_remote_code=True)

def load_florence_model(
    device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
) -> Tuple[Any, Any]:
    """Load a Florence-2 model and processor onto `device` in eval mode."""
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint, trust_remote_code=True).to(device).eval()
        processor = AutoProcessor.from_pretrained(
            checkpoint, trust_remote_code=True)
        return model, processor
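
# Usage sketch (choosing CPU explicitly and the large checkpoint are
# illustrative assumptions, not how this module is wired up):
#
#     model, processor = load_florence_model(
#         torch.device("cpu"), checkpoint="microsoft/Florence-2-large")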

def run_florence_inference(
    model: Any,
    processor: Any,
    device: torch.device,
    image: Image.Image,
    task: str,
    text: str = ""
) -> Tuple[str, Dict]:
    """Run a single Florence-2 task and return (raw text, parsed response)."""
    prompt = task + text
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3
    )
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=False)[0]
    response = processor.post_process_generation(
        generated_text, task=task, image_size=image.size)
    return generated_text, response

if __name__ == "__main__":
    # Smoke test: reuse the detected device rather than hard-coding "cuda",
    # which fails on CPU-only machines, and keep the returned handles.
    model, processor = load_florence_model(device=device)
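    # Example sketch: object detection on a placeholder image. "photo.jpg" is
    # an assumption for illustration; for the '<OD>' task the parsed response
    # is a dict keyed by the task token with "bboxes" and "labels" lists.
    image = Image.open("photo.jpg").convert("RGB")
    _, response = run_florence_inference(
        model, processor, device, image, FLORENCE_OBJECT_DETECTION_TASK)
    detections = response[FLORENCE_OBJECT_DETECTION_TASK]
    for box, label in zip(detections["bboxes"], detections["labels"]):
        print(label, [round(v, 1) for v in box])  # [x1, y1, x2, y2] in pixels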