demo / api.py
zbing's picture
Upload folder using huggingface_hub
7dfaa81 verified
raw
history blame
2.49 kB
import argparse
import base64
import threading
from io import BytesIO

import torch
from flask import Flask, request, jsonify
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
app = Flask(__name__)

# Parse command line arguments.
# NOTE: this runs at import time, so the options are read from sys.argv
# whenever this module is loaded, not only when it is run as a script.
parser = argparse.ArgumentParser(description='Start the Flask server with specified model and device.')
parser.add_argument('--model-path', type=str, default="models/Florence-2-base", help='Path to the pretrained model')
# 'auto' has to be listed in `choices`: argparse does not check the default
# against `choices`, so without it `--device auto` is rejected on the command
# line even though it is the documented default.
parser.add_argument('--device', type=str, choices=['cpu', 'gpu', 'auto'], default='auto', help='Device to use: "cpu", "gpu", or "auto"')
args = parser.parse_args()

# Determine the device from the --device option:
#   'gpu'  -> CUDA, failing early with a clear message when it is unavailable
#   'auto' -> CUDA when available, otherwise CPU
#   'cpu'  -> CPU
if args.device == 'gpu':
    if not torch.cuda.is_available():
        parser.error('--device gpu was requested but CUDA is not available')
    device = "cuda"
elif args.device == 'auto':
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = "cpu"

# Initialize the model and processor once at startup; both are shared by
# every request handled by this process.
model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)

lock = threading.Lock()  # Use a lock to ensure thread safety when accessing the model
def predict_image(image, task: str = "<OD>", prompt: str = None):
    """Run the shared model on ``image`` for ``task`` and return the parsed result.

    Args:
        image: Image object exposing ``width`` and ``height``; it is handed
            to the module-level ``processor`` unchanged.
        task: Task token sent to the model and passed to the processor's
            post-processing step.
        prompt: Optional extra text. When truthy it is appended to ``task``
            after a single space; otherwise ``task`` alone is the prompt.

    Returns:
        Whatever ``processor.post_process_generation`` returns for the
        first decoded sequence.
    """
    if prompt:
        full_prompt = task + " " + prompt
    else:
        full_prompt = task
    print(f"Prompt: {full_prompt}")

    # The module-level lock serializes preprocessing, generation and
    # decoding so only one request uses the model at a time.
    with lock:
        encoded = processor(text=full_prompt, images=image, return_tensors="pt").to(device)
        output_ids = model.generate(
            input_ids=encoded["input_ids"],
            pixel_values=encoded["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
        # Special tokens are kept in the decoded text that is handed to
        # the post-processing step.
        decoded = processor.batch_decode(output_ids, skip_special_tokens=False)
        return processor.post_process_generation(
            decoded[0],
            task=task,
            image_size=(image.width, image.height),
        )
@app.route('/predict', methods=['POST'])
def predict():
    """Handle ``POST /predict``: decode the image, run the model, return JSON.

    Expects a JSON object with:
        image: base64-encoded image, either bare or as a data URL
            (``data:<mime>;base64,<payload>``).
        task: optional task token, defaults to ``"<OD>"``.
        prompt: optional extra prompt text.
        msgid: optional value echoed back unchanged in the response.

    Returns:
        ``{"msgid": ..., "prediction": ...}`` on success, or
        ``{"error": ...}`` with status 400 when the request is not JSON,
        has no usable ``image`` field, or the image cannot be decoded.
    """
    if not request.is_json:
        return jsonify({'error': 'No image file or JSON payload'}), 400

    data = request.get_json()
    # A JSON array/string/number is valid JSON but cannot carry the fields
    # read below, so treat it the same as a missing image.
    if not isinstance(data, dict) or 'image' not in data:
        return jsonify({'error': 'No image found in JSON'}), 400

    raw_image = data['image']
    if not isinstance(raw_image, str):
        return jsonify({'error': 'Invalid image data'}), 400

    # Accept both a data URL and a bare base64 string: keep only what
    # follows the first comma when there is one.
    _, separator, tail = raw_image.partition(",")
    encoded = tail if separator else raw_image

    try:
        image_data = base64.b64decode(encoded)
        # Image.open only reads the header; convert() forces the full decode
        # here so corrupt data is reported as a client error instead of
        # failing later inside the model lock.
        # NOTE(review): also normalizes alpha/palette/grayscale input to
        # 3-channel RGB, presumably what the processor expects - confirm
        # against the model's preprocessor config.
        image = Image.open(BytesIO(image_data)).convert("RGB")
    except (ValueError, OSError):
        # binascii.Error is a ValueError; PIL's UnidentifiedImageError and
        # truncated-file errors are OSErrors.
        return jsonify({'error': 'Invalid image data'}), 400

    task = data.get('task', "<OD>")
    prompt = data.get('prompt', None)
    prediction = predict_image(image, task, prompt)

    msgid = data.get('msgid', None)
    response = {
        'msgid': msgid,
        'prediction': prediction
    }
    return jsonify(response)
if __name__ == "__main__":
    # Start the server on all interfaces, port 5000, with threaded request
    # handling enabled.
    app.run(
        host="0.0.0.0",
        port=5000,
        threaded=True,
    )