"""Flask inference server for a Florence-2 vision-language model.

POST /predict accepts a JSON payload with a base64-encoded image (optionally
as a data URL), a task token, and an optional prompt, and returns the
processor's post-processed answer.  Model access is serialized with a lock
because HF ``generate()`` is not thread-safe and the server runs threaded.
"""

import argparse
import base64
import threading
from io import BytesIO

import torch
from flask import Flask, request, jsonify
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

app = Flask(__name__)

# Parse command line arguments.
parser = argparse.ArgumentParser(
    description='Start the Flask server with specified model and device.')
parser.add_argument('--model-path', type=str, default="models/Florence-2-base",
                    help='Path to the pretrained model')
# FIX: 'auto' was the default yet missing from choices, and the parsed value
# was previously ignored (device hard-coded to "cpu").
parser.add_argument('--device', type=str, choices=['cpu', 'gpu', 'auto'],
                    default='auto',
                    help='Device to use: "cpu", "gpu", or "auto"')
args = parser.parse_args()

# Resolve the requested device: 'auto' picks CUDA when available; 'gpu' maps
# to the torch device string "cuda".
if args.device == 'auto':
    device = "cuda" if torch.cuda.is_available() else "cpu"
elif args.device == 'gpu':
    device = "cuda"
else:
    device = "cpu"

# Initialize the model and processor.
model = AutoModelForCausalLM.from_pretrained(
    args.model_path, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)

# Use a lock to ensure thread safety when accessing the model.
lock = threading.Lock()


def predict_image(image, task: str = "", prompt: str = None):
    """Run the model on *image* for *task*, optionally appending *prompt*.

    Args:
        image: PIL image to run inference on.
        task: Florence-2 task token (e.g. "<OD>"); also used for
            post-processing the generated text.
        prompt: optional extra text appended to the task token.

    Returns:
        The processor's post-processed answer for the given task.
    """
    # NOTE: conditional applies to the whole expression —
    # (task + " " + prompt) if prompt else task.
    prompt = task + " " + prompt if prompt else task
    print(f"Prompt: {prompt}")
    with lock:  # one generation at a time; generate() is not thread-safe
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,          # deterministic beam search
            num_beams=3,
        )
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            generated_text, task=task, image_size=(image.width, image.height))
    return parsed_answer


@app.route('/predict', methods=['POST'])
def predict():
    """Handle a JSON payload: {image, task?, prompt?, msgid?} -> prediction."""
    if request.is_json:
        data = request.get_json()
        if 'image' not in data:
            return jsonify({'error': 'No image found in JSON'}), 400
        try:
            # Strip a data-URL prefix ("data:image/png;base64,....") when
            # present.  FIX: split(",")[1] raised IndexError (HTTP 500) on
            # plain base64 without a comma; [-1] handles both forms.
            encoded = data['image'].split(",", 1)[-1]
            image_data = base64.b64decode(encoded)
            image = Image.open(BytesIO(image_data))
        except (ValueError, OSError):
            # binascii.Error (bad base64) is a ValueError subclass;
            # PIL raises OSError/UnidentifiedImageError on undecodable bytes.
            return jsonify({'error': 'Invalid image data'}), 400
    else:
        return jsonify({'error': 'No image file or JSON payload'}), 400
    task = data.get('task', "")
    prompt = data.get('prompt', None)
    prediction = predict_image(image, task, prompt)
    msgid = data.get('msgid', None)
    response = {
        'msgid': msgid,
        'prediction': prediction,
    }
    return jsonify(response)


if __name__ == '__main__':
    # threaded=True enables concurrent requests; the module-level lock
    # serializes actual model access.
    app.run(host='0.0.0.0', port=5000, threaded=True)