"""Minimal Flask server exposing a Hyper-SD SDXL (OpenVINO int8) text-to-image endpoint."""
import io
import logging
import math
import os
from datetime import datetime, timezone

from flask import Flask, request, jsonify, send_file
from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionXLPipeline
from PIL import Image  # NOTE(review): not referenced in this file; kept in case something relies on it
import torch  # NOTE(review): not referenced in this file; kept in case something relies on it

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Loaded once at import time and shared by all requests.
# NOTE(review): an empty CACHE_DIR presumably leaves OpenVINO model caching disabled — confirm.
pipeline = OVStableDiffusionXLPipeline.from_pretrained(
    "rupeshs/hyper-sd-sdxl-1-step-openvino-int8",
    ov_config={"CACHE_DIR": ""},
)

DEFAULT_PROMPT = 'A futuristic cityscape at sunset, cyberpunk style, 8k'

# Request limits. Inputs come from untrusted HTTP clients, so bound the work one request
# can trigger. Operators can raise the limits through environment variables.
MIN_IMAGE_DIMENSION = 64
MAX_IMAGE_DIMENSION = int(os.getenv('MAX_IMAGE_DIMENSION', 2048))
MAX_INFERENCE_STEPS = int(os.getenv('MAX_INFERENCE_STEPS', 50))


class _InvalidParameter(ValueError):
    """Raised when a request parameter is malformed; mapped to an HTTP 400 response."""


def _int_param(data, key, default, low, high):
    """Return ``data[key]`` (or ``default``) as an int within ``[low, high]``.

    Integral floats such as ``512.0`` are accepted. Booleans are rejected, even though
    ``bool`` is a subclass of ``int``.

    Raises:
        _InvalidParameter: if the value is not an integer or lies outside the range.
    """
    value = data.get(key, default)
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    if isinstance(value, bool) or not isinstance(value, int):
        raise _InvalidParameter(f"'{key}' must be an integer")
    if not low <= value <= high:
        raise _InvalidParameter(f"'{key}' must be between {low} and {high}")
    return value


def _float_param(data, key, default):
    """Return ``data[key]`` (or ``default``) as a finite float.

    Python's JSON parser accepts ``NaN``/``Infinity``, so finiteness is checked explicitly.

    Raises:
        _InvalidParameter: if the value is not a finite number.
    """
    value = data.get(key, default)
    if isinstance(value, bool) or not isinstance(value, (int, float)) or not math.isfinite(value):
        raise _InvalidParameter(f"'{key}' must be a finite number")
    return float(value)


def _parse_generation_params(data):
    """Validate the JSON request object and return keyword arguments for ``pipeline``.

    Raises:
        _InvalidParameter: on any missing-type, range or divisibility violation.
    """
    prompt = data.get('prompt', DEFAULT_PROMPT)
    if not isinstance(prompt, str) or not prompt.strip():
        raise _InvalidParameter("'prompt' must be a non-empty string")

    width = _int_param(data, 'width', 512, MIN_IMAGE_DIMENSION, MAX_IMAGE_DIMENSION)
    height = _int_param(data, 'height', 512, MIN_IMAGE_DIMENSION, MAX_IMAGE_DIMENSION)
    # The SDXL pipeline rejects sizes not divisible by 8; report it as a client error up front.
    if width % 8 or height % 8:
        raise _InvalidParameter("'width' and 'height' must be multiples of 8")

    return {
        'prompt': prompt,
        'width': width,
        'height': height,
        'num_inference_steps': _int_param(data, 'num_inference_steps', 1, 1, MAX_INFERENCE_STEPS),
        'guidance_scale': _float_param(data, 'guidance_scale', 1.0),
    }


@app.route('/generate', methods=['POST'])
def generate_image():
    """Generate one image from a JSON request and return it as a PNG attachment.

    JSON body (all fields optional): ``prompt`` (str), ``width``/``height`` (int, multiple of 8),
    ``num_inference_steps`` (int), ``guidance_scale`` (number).

    Returns:
        200 with ``image/png`` on success.
        400 ``{'error': ...}`` for a malformed body or invalid parameters.
        500 ``{'error': ...}`` if generation fails.
    """
    # silent=True: a missing or non-JSON body yields None instead of raising, so it becomes a 400.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    try:
        params = _parse_generation_params(data)
    except _InvalidParameter as e:
        return jsonify({'error': str(e)}), 400

    try:
        image = pipeline(**params).images[0]

        # Save image to an in-memory buffer and rewind it so send_file reads from the start.
        img_io = io.BytesIO()
        image.save(img_io, 'PNG')
        img_io.seek(0)
    except Exception as e:
        # Top-level request boundary: log the full traceback, report failure to the client.
        logger.exception("Image generation failed")
        return jsonify({'error': str(e)}), 500

    return send_file(
        img_io,
        mimetype='image/png',
        as_attachment=True,
        download_name='generated_image.png'
    )


@app.route('/health', methods=['GET'])
def health_check():
    """Report that the application process is up.

    The timestamp is timezone-aware UTC in ISO 8601 form.
    """
    # TODO: add dependency checks (e.g. model/pipeline readiness) here if needed.
    try:
        health_status = {
            'status': 'healthy',
            'message': 'Application is running',
            'timestamp': datetime.now(timezone.utc).isoformat()
        }
        return jsonify(health_status), 200
    except Exception as e:
        logger.exception("Health check failed")
        error_response = {
            'status': 'unhealthy',
            'message': f'Health check failed: {str(e)}',
            'timestamp': datetime.now(timezone.utc).isoformat()
        }
        return jsonify(error_response), 500


if __name__ == '__main__':
    # Defaults to localhost as before; set HOST=0.0.0.0 to accept external connections
    # (e.g. inside a container).
    host = os.getenv('HOST', 'localhost')
    port = int(os.getenv('PORT', 7860))
    app.run(host=host, port=port)