er1t0 commited on
Commit
5680f2a
1 Parent(s): edd3bd3

flash attn fix

Browse files
Files changed (3) hide show
  1. app.py +13 -1
  2. requirements.txt +1 -2
  3. utils.py +17 -0
app.py CHANGED
@@ -10,6 +10,8 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
10
  import cv2
11
  import traceback
12
  import matplotlib.pyplot as plt
 
 
13
 
14
  # CUDA optimizations
15
  torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
@@ -26,9 +28,19 @@ sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
26
  image_predictor = SAM2ImagePredictor(sam2_model)
27
 
28
  model_id = 'microsoft/Florence-2-large'
29
- florence_model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16).eval().cuda()
 
 
 
 
 
 
 
 
 
30
  florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
31
 
 
32
  def apply_color_mask(frame, mask, obj_id):
33
  cmap = plt.get_cmap("tab10")
34
  color = np.array(cmap(obj_id % 10)[:3]) # Use modulo 10 to cycle through colors
 
10
  import cv2
11
  import traceback
12
  import matplotlib.pyplot as plt
13
+ from utils import load_model_without_flash_attn
14
+
15
 
16
  # CUDA optimizations
17
  torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
 
28
  image_predictor = SAM2ImagePredictor(sam2_model)
29
 
30
  model_id = 'microsoft/Florence-2-large'
31
+ device = "cuda" if torch.cuda.is_available() else "cpu"
32
+
33
+ def load_florence_model():
34
+ return AutoModelForCausalLM.from_pretrained(
35
+ model_id,
36
+ trust_remote_code=True,
37
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
38
+ ).eval().to(device)
39
+
40
+ florence_model = load_model_without_flash_attn(load_florence_model)
41
  florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
42
 
43
+
44
  def apply_color_mask(frame, mask, obj_id):
45
  cmap = plt.get_cmap("tab10")
46
  color = np.array(cmap(obj_id % 10)[:3]) # Use modulo 10 to cycle through colors
requirements.txt CHANGED
@@ -8,5 +8,4 @@ opencv-python
8
  matplotlib
9
  einops
10
  timm
11
- pytest
12
- flash_attn
 
8
  matplotlib
9
  einops
10
  timm
11
+ pytest
 
utils.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from unittest.mock import patch
3
+ from transformers.dynamic_module_utils import get_imports
4
+
5
+ def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
6
+ """Workaround for flash_attn import issue."""
7
+ if not str(filename).endswith(("modeling_phi.py", "configuration_florence2.py")):
8
+ return get_imports(filename)
9
+ imports = get_imports(filename)
10
+ if "flash_attn" in imports:
11
+ imports.remove("flash_attn")
12
+ return imports
13
+
14
+ def load_model_without_flash_attn(model_loader):
15
+ """Load a model using the flash_attn workaround."""
16
+ with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
17
+ return model_loader()