zbing committed on
Commit
c88b0bd
1 Parent(s): a5cb3bd

Upload folder using huggingface_hub

Files changed (1)
api.py +21 -11
api.py CHANGED
@@ -3,25 +3,34 @@ from flask import Flask, request, jsonify
 from PIL import Image
 from io import BytesIO
 import base64
+import torch
 from transformers import AutoProcessor, AutoModelForCausalLM
 import threading
-from unittest.mock import patch
-from transformers.dynamic_module_utils import get_imports
 
 app = Flask(__name__)
 
 # Parse command line arguments
 parser = argparse.ArgumentParser(description='Start the Flask server with specified model and device.')
-parser.add_argument('--model-path', type=str, default="models/Florence-2-base", help='Path to the pretrained model')
+parser.add_argument('--model-path', type=str, required=True, help='Path to the pretrained model')
 parser.add_argument('--device', type=str, choices=['cpu', 'gpu'], default='auto', help='Device to use: "cpu", "gpu", or "auto"')
 args = parser.parse_args()
 
 # Determine the device
-device = "cpu"
-# Initialize the model and processor
+if args.device == 'auto':
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+elif args.device == 'gpu':
+    if torch.cuda.is_available():
+        device = "cuda:0"
+    else:
+        raise ValueError("GPU option specified but no GPU is available.")
+else:
+    device = "cpu"
 
-with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): #workaround for unnecessary flash_attn requirement
-    model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="sdpa", torch_dtype=dtype,trust_remote_code=True)
+torch_dtype = torch.float16 if device.startswith("cuda") else torch.float32
+
+from unittest.mock import patch
+from transformers.dynamic_module_utils import get_imports
+import os
 
 def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
     if not str(filename).endswith("modeling_florence2.py"):
@@ -30,9 +39,10 @@ def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
     imports.remove("flash_attn")
     return imports
 
-with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): #workaround for unnecessary flash_attn requirement
-    model = AutoModelForCausalLM.from_pretrained(args.model_path, attn_implementation="sdpa", torch_dtype=dtype,trust_remote_code=True)
-    processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True, device_map=device)
+# Initialize the model and processor
+with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): #workaround for unnecessary flash_attn requirement
+    model = AutoModelForCausalLM.from_pretrained(args.model_path, attn_implementation="sdpa", torch_dtype=torch_dtype,trust_remote_code=True).to(device)
+    processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)
 
 lock = threading.Lock() # Use a lock to ensure thread safety when accessing the model
 
@@ -40,7 +50,7 @@ def predict_image(image, task: str = "<OD>", prompt: str = None):
     prompt = task + " " + prompt if prompt else task
     print(f"Prompt: {prompt}")
     with lock:
-        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
+        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
         generated_ids = model.generate(
             input_ids=inputs["input_ids"],
             pixel_values=inputs["pixel_values"],
 
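Beyond the device handling, this commit fixes an ordering bug: the previous revision entered `with patch(..., fixed_get_imports)` and referenced `model_path` and `dtype` before any of them were defined, which would raise a NameError at startup. The new revision defines `fixed_get_imports` first and loads the model afterwards. The patch itself works around Florence-2's remote code declaring `flash_attn` as a hard import even when SDPA attention is used. A minimal self-contained sketch of the pattern (the checkpoint path is a placeholder):

import os
from unittest.mock import patch

from transformers import AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    # Only touch the import list of Florence-2's modeling file.
    imports = get_imports(filename)
    if str(filename).endswith("modeling_florence2.py") and "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

# While the patch is active, transformers resolves the remote module's
# dependencies through fixed_get_imports, so flash_attn is never required.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(
        "models/Florence-2-base",  # placeholder checkpoint path
        attn_implementation="sdpa",
        trust_remote_code=True,
    )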
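The companion change in `predict_image` is easy to miss: with the model now in float16 on CUDA, the processor outputs must be cast with `.to(device, torch_dtype)`, otherwise float32 pixel values can trigger a dtype mismatch inside the vision encoder. Below is a sketch of the full inference path built on the objects defined above; the generation parameters and the `post_process_generation` call follow Florence-2's published usage examples rather than this diff:

from PIL import Image

image = Image.open("example.jpg").convert("RGB")  # placeholder input image
task = "<OD>"  # object detection task token

with lock:  # serialize access to the shared model, as api.py does
    # Match the model's device and dtype, exactly as the new code does.
    inputs = processor(text=task, images=image, return_tensors="pt").to(device, torch_dtype)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,  # illustrative value, not from this diff
        num_beams=3,          # illustrative value, not from this diff
    )
text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
# Florence-2's processor converts raw text back into task-specific structures,
# e.g. bounding boxes for <OD>; image.size is (width, height).
result = processor.post_process_generation(text, task=task, image_size=image.size)
print(result)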
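Since `--model-path` is now required, the server would be started along the lines of `python api.py --model-path models/Florence-2-base --device auto`. A hypothetical client is sketched below; the route path and JSON field names are assumptions, as the Flask handler itself is outside this diff:

import base64
import requests  # third-party HTTP client, not imported by api.py

with open("example.jpg", "rb") as f:
    payload = {
        "image": base64.b64encode(f.read()).decode("ascii"),  # assumed field name
        "task": "<OD>",                                       # assumed field name
    }

# "/predict" is a guess at the route; check api.py for the actual endpoint.
resp = requests.post("http://127.0.0.1:5000/predict", json=payload, timeout=120)
print(resp.json())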