# demo/api.py
import argparse
import base64
import os
import threading
from io import BytesIO
from unittest.mock import patch

import torch
from flask import Flask, request, jsonify
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports

app = Flask(__name__)
# Parse command line arguments
parser = argparse.ArgumentParser(description='Start the Flask server with the specified model and device.')
parser.add_argument('--model-path', type=str, required=True, help='Path to the pretrained model')
parser.add_argument('--device', type=str, choices=['cpu', 'gpu', 'auto'], default='auto',
                    help='Device to use: "cpu", "gpu", or "auto"')
args = parser.parse_args()
# Determine the device
if args.device == 'auto':
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
elif args.device == 'gpu':
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        raise ValueError("GPU option specified but no GPU is available.")
else:
    device = "cpu"

torch_dtype = torch.float16 if device.startswith("cuda") else torch.float32

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Strip the flash_attn import from the Florence-2 remote modeling code.

    The upstream modeling_florence2.py lists flash_attn as a required import,
    which fails on machines where it is not installed even though the model
    can run with the SDPA attention implementation instead.
    """
    if not str(filename).endswith("modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

# Initialize the model and processor
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):  # workaround for the unnecessary flash_attn requirement
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, attn_implementation="sdpa", torch_dtype=torch_dtype, trust_remote_code=True
    ).to(device)
    processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)

lock = threading.Lock()  # Use a lock to ensure thread safety when accessing the model

def predict_image(image, task: str = "<OD>", prompt: str = None):
    """Run one generation pass for the given task token and optional text prompt."""
    prompt = task + " " + prompt if prompt else task
    print(f"Prompt: {prompt}")
    with lock:
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            generated_text, task=task, image_size=(image.width, image.height)
        )
    return parsed_answer
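
# Illustrative direct use of predict_image (a sketch, not part of the server flow).
# The task tokens shown are the ones commonly documented for Florence-2 checkpoints
# and may differ for the model actually loaded via --model-path:
#
#   img = Image.open("example.jpg").convert("RGB")
#   detections = predict_image(img, task="<OD>")
#   caption = predict_image(img, task="<CAPTION>")
#   grounding = predict_image(img, task="<CAPTION_TO_PHRASE_GROUNDING>", prompt="a red car")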

@app.route('/predict', methods=['POST'])
def predict():
    if request.is_json:
        data = request.get_json()
        if 'image' not in data:
            return jsonify({'error': 'No image found in JSON'}), 400
        # Accept both raw base64 strings and data URIs ("data:image/...;base64,<payload>")
        image_data = base64.b64decode(data['image'].split(",")[-1])
        image = Image.open(BytesIO(image_data))
    else:
        return jsonify({'error': 'No image file or JSON payload'}), 400

    task = data.get('task', "<OD>")
    prompt = data.get('prompt', None)
    prediction = predict_image(image, task, prompt)
    msgid = data.get('msgid', None)
    response = {
        'msgid': msgid,
        'prediction': prediction
    }
    return jsonify(response)

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, threaded=True)  # Enable multi-threading
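
# Minimal client sketch (assumes the server is reachable at localhost:5000 and that
# the `requests` package is installed; the JSON field names match the /predict route above):
#
#   import base64, requests
#   with open("example.jpg", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode()
#   payload = {"image": f"data:image/jpeg;base64,{b64}", "task": "<OD>", "msgid": "1"}
#   print(requests.post("http://localhost:5000/predict", json=payload).json())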