Spaces:
Runtime error
Runtime error
| import torch | |
| from PIL import Image | |
| from RealESRGAN import RealESRGAN | |
| from flask import Flask, request, jsonify, send_file | |
| import io | |
| import logging | |
# Configure logging once at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask application object for this module.
app = Flask(__name__)

# Use the GPU when CUDA reports itself available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f'Using device: {device}')


def _load_model(scale):
    """Build the RealESRGAN model for ``scale``, load its weights, log success.

    The weights path is ``weights/RealESRGAN_x<scale>.pth`` and
    ``download=True`` is passed to ``load_weights``.
    """
    model = RealESRGAN(device, scale=scale)
    model.load_weights(f'weights/RealESRGAN_x{scale}.pth', download=True)
    logger.info(f'Model x{scale} loaded successfully')
    return model


# One model per supported factor, created eagerly in this order at import time.
model2 = _load_model(2)
model4 = _load_model(4)
model8 = _load_model(8)
def _fresh_model(scale):
    """Return a new RealESRGAN model for ``scale`` with its weights loaded.

    ``download=False`` is passed to ``load_weights``, matching the recovery
    path this helper replaces.
    """
    model = RealESRGAN(device, scale=scale)
    model.load_weights(f'weights/RealESRGAN_x{scale}.pth', download=False)
    return model


def _predict(scale, rgb_image):
    """Run the module-level model for ``scale`` on ``rgb_image``.

    The model is looked up in the module globals (``model2`` / ``model4`` /
    ``model8``) at call time. If the slot is empty because an earlier
    recovery attempt failed half-way, the model is loaded first.
    """
    state = globals()
    slot = f'model{scale}'
    if state.get(slot) is None:
        state[slot] = _fresh_model(scale)
    return state[slot].predict(rgb_image)


def inference(image, size):
    """Upscale ``image`` with the model selected by ``size``.

    Args:
        image: Object exposing ``size`` (unpacked as ``width, height``) and
            ``convert('RGB')``; the caller in this file passes a PIL image.
        size: ``'2x'`` or ``'4x'``. Any other value selects the x8 model,
            which is the historical behaviour and is kept for compatibility.

    Returns:
        ``(result, None)`` on success, where ``result`` is whatever the
        model's ``predict`` returns. ``(None, message)`` when the x8 path is
        given an image whose width or height is >= 5000, or when the device
        is still out of memory after the model has been reloaded once.

    Raises:
        Anything other than ``torch.cuda.OutOfMemoryError`` raised by the
        conversion or the model propagates unchanged.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        logger.info('CUDA cache cleared')

    logger.info(f'Starting inference with scale {size}')

    scale = {'2x': 2, '4x': 4}.get(size, 8)

    # The size limit only applies to the x8 path.
    if scale == 8:
        width, height = image.size
        if width >= 5000 or height >= 5000:
            return None, "The image is too large."

    rgb_image = image.convert('RGB')

    try:
        result = _predict(scale, rgb_image)
    except torch.cuda.OutOfMemoryError as e:
        logger.error(f'OutOfMemoryError: {e}')
    else:
        logger.info(f'Inference completed for scale {size}')
        return result, None

    # Recovery deliberately happens *after* the except block: while the
    # handler is active the exception's traceback keeps the failed call's
    # frames (and their tensors) alive, so memory could not be reclaimed.
    logger.info(f'Reloading model for scale {size}')

    # Drop the old model before building its replacement so the two never
    # have to fit on the device at the same time, then return cached blocks.
    globals()[f'model{scale}'] = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    try:
        result = _predict(scale, rgb_image)
    except torch.cuda.OutOfMemoryError as e:
        # Report through the (None, message) contract instead of letting a
        # second OOM escape as an unhandled exception.
        logger.error(f'OutOfMemoryError: {e}')
        return None, "Not enough GPU memory to upscale this image."

    logger.info(f'Model reloaded and inference completed for scale {size}')
    return result, None
# NOTE(review): the view was not registered on `app` anywhere in the visible
# file, so the server exposed no endpoint. The '/upscale' path is assumed
# from the function name -- confirm against the clients. POST follows from
# the handler reading `request.files` / `request.form`.
@app.route('/upscale', methods=['POST'])
def upscale():
    """Handle an upload: upscale the posted image and return it as PNG.

    Reads the file field ``image`` and the optional form field ``size``
    (default ``'2x'``) from the current request.

    Returns:
        A PNG ``send_file`` response on success, or a JSON body
        ``{"error": ...}`` with status 400 when no image was uploaded, the
        upload cannot be decoded, or ``inference`` reports an error.
    """
    if 'image' not in request.files:
        logger.warning('No image uploaded')
        return jsonify({"error": "No image uploaded"}), 400

    image_file = request.files['image']
    size = request.form.get('size', '2x')

    try:
        image = Image.open(image_file)
        # Image.open only identifies the file; force the pixel data to be
        # decoded here so truncated or corrupt uploads are rejected with a
        # 400 instead of failing later inside inference.
        image.load()
        logger.info('Image uploaded and opened successfully')
    except Exception as e:
        logger.error(f'Invalid image file: {e}')
        return jsonify({"error": "Invalid image file"}), 400

    result, error = inference(image, size)
    if error:
        logger.error(f'Error during inference: {error}')
        return jsonify({"error": error}), 400

    # Serialise into memory and rewind so send_file streams from the start.
    img_io = io.BytesIO()
    result.save(img_io, 'PNG')
    img_io.seek(0)

    logger.info('Image processing completed and ready to be sent back')
    return send_file(img_io, mimetype='image/png')
def _serve():
    """Log the startup message and run the Flask app on 0.0.0.0:5000."""
    logger.info('Starting the Flask server...')
    app.run(host='0.0.0.0', port=5000)


# Only start the server when this file is executed as a script.
if __name__ == '__main__':
    _serve()