|
from transformers import AutoModel, AutoConfig |
|
from DaViT.modeling_davit import DaViTModel |
|
from DaViT.configuration_davit import DaViTConfig |
|
from unittest.mock import patch |
|
import os |
|
import logging |
|
import requests |
|
from PIL import Image |
|
import torch |
|
from transformers import AutoProcessor, AutoModelForCausalLM |
|
from unittest.mock import patch |
|
from transformers.dynamic_module_utils import get_imports |
|
from typing import Tuple, Dict, Any, Union, List |
|
|
|
|
|
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return the imports of ``filename``, dropping ``flash_attn`` for Florence-2.

    Stand-in for ``transformers.dynamic_module_utils.get_imports`` that works
    around the import error related to ``flash_attn``: when ``filename`` is
    ``modeling_florence2.py`` the ``"flash_attn"`` entry is removed from the
    returned list; any other file is passed through unchanged.

    Args:
        filename (str | os.PathLike): The file whose imports are checked.

    Returns:
        list[str]: The list produced by ``get_imports(filename)``, minus one
        ``"flash_attn"`` entry when ``filename`` is the Florence-2 modeling file.
    """
    is_florence_modeling = str(filename).endswith("modeling_florence2.py")
    required = get_imports(filename)
    # Only the Florence-2 modeling file is rewritten; everything else is untouched.
    if is_florence_modeling and "flash_attn" in required:
        required.remove("flash_attn")
    return required
|
|
|
|
|
# Downloads from the Hub (see `cache_dir=` below) are cached under the
# directory the script is launched from.
current_directory = os.getcwd()

# Make the custom DaViT classes resolvable through the Auto* factories in this
# process, so that AutoModel.from_config(...) can build a DaViTModel from a
# DaViTConfig. Registering once is enough: repeating these calls only
# re-writes the same mapping entries.
AutoConfig.register("davit", DaViTConfig)
AutoModel.register(DaViTConfig, DaViTModel)

# Record the classes for the Auto API so that save_pretrained / push_to_hub
# include the custom code and an `auto_map` in the pushed config (loading the
# result then requires trust_remote_code=True).
DaViTConfig.register_for_auto_class()
DaViTModel.register_for_auto_class("AutoModel")
|
|
|
|
|
# Hub checkpoint whose vision tower is extracted, and the repo it is pushed to.
SOURCE_REPO_ID = "microsoft/Florence-2-large-ft"
TARGET_REPO_ID = "DaViT-Florence-2-large-ft"

# NOTE(review): this relies on DaViTConfig()'s defaults matching the vision
# tower of SOURCE_REPO_ID; a mismatch surfaces as a key/size error in the
# strict load_state_dict call below.
config = DaViTConfig()

# Swap in fixed_get_imports while the remote code is fetched, so that the
# "flash_attn" import listed by modeling_florence2.py is not required.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(
        SOURCE_REPO_ID,
        trust_remote_code=True,
        cache_dir=current_directory,
        device_map="cpu",
        torch_dtype=torch.float16,
    )

    # Processors have no device placement, so no `device_map` here. Extra
    # kwargs are forwarded to the sub-processors' from_pretrained and may be
    # stored (and later pushed) as tokenizer init kwargs.
    processor = AutoProcessor.from_pretrained(
        SOURCE_REPO_ID,
        trust_remote_code=True,
        cache_dir=current_directory,
    )

# Build a standalone DaViT in the same dtype as the source model, then copy the
# vision-tower weights into it. load_state_dict is strict by default, so any
# missing/unexpected key or shape mismatch raises instead of being ignored.
model2 = AutoModel.from_config(config)
model2.to(torch.float16)
model2.load_state_dict(model.vision_tower.state_dict())

# Publish the extracted encoder together with the original processor.
model2.push_to_hub(TARGET_REPO_ID)
processor.push_to_hub(TARGET_REPO_ID)