zbing committed on
Commit
a5cb3bd
·
verified ·
1 Parent(s): ebc2977

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. api.py +15 -1
api.py CHANGED
@@ -5,6 +5,8 @@ from io import BytesIO
5
  import base64
6
  from transformers import AutoProcessor, AutoModelForCausalLM
7
  import threading
 
 
8
 
9
  app = Flask(__name__)
10
 
@@ -17,7 +19,19 @@ args = parser.parse_args()
17
  # Determine the device
18
  device = "cpu"
19
  # Initialize the model and processor
20
- model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True, device_map=device)
 
 
 
 
 
 
 
 
 
 
 
 
21
  processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True, device_map=device)
22
 
23
  lock = threading.Lock() # Use a lock to ensure thread safety when accessing the model
 
5
  import base64
6
  from transformers import AutoProcessor, AutoModelForCausalLM
7
  import threading
8
+ from unittest.mock import patch
9
+ from transformers.dynamic_module_utils import get_imports
10
 
11
  app = Flask(__name__)
12
 
 
19
  # Determine the device
20
  device = "cpu"
21
  # Initialize the model and processor
22
+
23
# NOTE(review): a duplicate model load previously sat here. It referenced
# fixed_get_imports before that function is defined (NameError at import
# time) and used the undefined names `model_path` and `dtype`. The model
# is loaded exactly once further down, after fixed_get_imports is defined.
25
+
26
def fixed_get_imports(filename) -> list[str]:
    """Return the imports transformers detects in *filename*, minus ``flash_attn``.

    Workaround for the Florence-2 remote code, whose ``modeling_florence2.py``
    declares ``flash_attn`` as a hard import even though it is optional;
    filtering it out lets the model load without flash-attn installed.

    Args:
        filename: Path of the module being scanned (str or os.PathLike),
            forwarded unchanged to ``transformers.dynamic_module_utils.get_imports``.

    Returns:
        The import list from ``get_imports``, with ``"flash_attn"`` removed
        for ``modeling_florence2.py`` only; all other files are untouched.
    """
    # NOTE(review): the parameter annotation previously used `os.PathLike`,
    # but no `import os` is visible in this file's import block — evaluating
    # that annotation would itself raise NameError. Dropped to be safe.
    imports = get_imports(filename)
    if str(filename).endswith("modeling_florence2.py") and "flash_attn" in imports:
        # Guarded membership test: the original called list.remove
        # unconditionally, which raises ValueError when "flash_attn"
        # is not present in the detected imports.
        imports.remove("flash_attn")
    return imports
32
+
33
# Load the model with the flash_attn import filtered out (see fixed_get_imports).
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):  # workaround for unnecessary flash_attn requirement
    # NOTE(review): `torch_dtype=dtype` was removed — `dtype` is never
    # defined anywhere in this file and raised NameError at startup.
    # `device_map=device` is restored from the pre-patch load so the model
    # still lands on the CPU device selected above.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        attn_implementation="sdpa",
        trust_remote_code=True,
        device_map=device,
    )
35
  processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True, device_map=device)
36
 
37
  lock = threading.Lock() # Use a lock to ensure thread safety when accessing the model