GFPGAN / app.py
anujakkulkarni's picture
Update app.py
0608b30 verified
raw
history blame
2.96 kB
import functools
import math
import os
import threading
import uuid

import cv2
import torch
from flask import Flask, render_template, request, send_file
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer
# Limit intra-op CPU threading to a single thread.
# torch and cv2 are already imported at this point, so an OpenMP runtime that
# read its environment at load time would ignore OMP_NUM_THREADS. The env var
# is kept for runtimes that read it lazily and for child processes, and torch's
# own thread count is also set through its runtime API so the limit holds
# regardless of initialisation order.
os.environ["OMP_NUM_THREADS"] = "1"
torch.set_num_threads(1)

app = Flask(__name__)

# Uploads and enhanced results are written under ./output.
os.makedirs("output", exist_ok=True)
# Background upsampler shared by every GFPGAN face enhancer built below:
# a compact Real-ESRGAN x4 network loaded from a local checkpoint.
model_path = 'realesr-general-x4v3.pth'
model = SRVGGNetCompact(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_conv=32,
    upscale=4,
    act_type='prelu',
)

# Half-precision inference is enabled only when a CUDA device is present.
half = torch.cuda.is_available()

# tile=0 processes the whole image in one pass (no tiling).
upsampler = RealESRGANer(
    scale=4,
    model_path=model_path,
    model=model,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    half=half,
)
# GFPGAN inference.

# Model names accepted by enhance_face(). The name is interpolated into the
# checkpoint path that GFPGANer loads, so it must be one of these known values
# and never an arbitrary caller/user-supplied string.
SUPPORTED_VERSIONS = ("v1.2", "v1.3", "v1.4", "RestoreFormer")

# The module-level `upsampler` and the cached GFPGANer instances are reused
# across calls (and across Flask request threads). They are not known to be
# thread-safe. NOTE(review): RealESRGANer appears to keep intermediate images
# on the instance, so inference is serialized to be safe.
_INFERENCE_LOCK = threading.Lock()


@functools.lru_cache(maxsize=len(SUPPORTED_VERSIONS))
def _get_face_enhancer(version):
    """Build the GFPGANer for a validated *version* once and reuse it."""
    if version == 'RestoreFormer':
        return GFPGANer(model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer',
                        channel_multiplier=2, bg_upsampler=upsampler)
    return GFPGANer(model_path=f"{version}.pth", upscale=2, arch='clean',
                    channel_multiplier=2, bg_upsampler=upsampler)


def enhance_face(img_path, version="v1.4", scale=2):
    """Restore faces in an image file and write the result under ``output/``.

    Args:
        img_path: Path of the image to read with ``cv2.imread``.
        version: One of ``SUPPORTED_VERSIONS``; selects ``<version>.pth``
            (``RestoreFormer.pth`` with the RestoreFormer architecture).
        scale: Final size relative to the input. The enhancer upscales by 2,
            so any other value resizes that 2x result by ``scale / 2``.

    Returns:
        Path of the written image: PNG when the input had 4 channels,
        JPEG otherwise. The name is derived from the input file name so
        different inputs never overwrite each other's result.

    Raises:
        ValueError: Unknown ``version``, non-positive or non-finite ``scale``,
            or an image OpenCV cannot read.
        OSError: The result could not be written.
    """
    # Validate the cheap arguments before touching the file or any model.
    if version not in SUPPORTED_VERSIONS:
        raise ValueError(f"unsupported version {version!r}; expected one of {SUPPORTED_VERSIONS}")
    if not (math.isfinite(scale) and scale > 0):
        raise ValueError(f"scale must be a positive finite number, got {scale!r}")

    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # imread reports missing/undecodable files by returning None.
        raise ValueError(f"could not read image: {img_path}")
    img_mode = None
    if img.ndim == 3 and img.shape[2] == 4:
        img_mode = 'RGBA'
    elif img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    with _INFERENCE_LOCK:
        face_enhancer = _get_face_enhancer(version)
        _, _, output = face_enhancer.enhance(
            img, has_aligned=False, only_center_face=False, paste_back=True)

    # Optional rescale of the 2x output; shrink with INTER_AREA, grow with Lanczos.
    if scale != 2:
        interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
        h, w = output.shape[0:2]
        # Clamp to 1 px: a zero dimension makes cv2.resize fail.
        new_size = (max(1, int(w * scale / 2)), max(1, int(h * scale / 2)))
        output = cv2.resize(output, new_size, interpolation=interpolation)

    extension = 'png' if img_mode == 'RGBA' else 'jpg'
    stem = os.path.splitext(os.path.basename(img_path))[0]
    save_path = os.path.join("output", f"{stem}_restored.{extension}")
    # imwrite returns False instead of raising when it cannot write.
    if not cv2.imwrite(save_path, output):
        raise OSError(f"failed to write {save_path}")
    return save_path
# Flask routes.

# Model names offered by the upload form below. The version selects a
# checkpoint file that gets loaded, so anything else is rejected.
_FORM_VERSIONS = ("v1.2", "v1.3", "v1.4", "RestoreFormer")


@app.route("/", methods=["GET", "POST"])
def index():
    """Serve the upload form (GET) or enhance an uploaded image (POST).

    POST form fields: ``image`` (file), ``version`` (one of the form's model
    names, default ``v1.4``) and ``scale`` (positive number, default 2).
    Responds with the enhanced image as a download, or HTTP 400 with a short
    message when the input is invalid.
    """
    if request.method == "POST":
        upload = request.files.get("image")
        # Browsers send an empty filename when no file was chosen.
        if upload is None or not upload.filename:
            return "No image uploaded.", 400
        version = request.form.get("version", "v1.4")
        if version not in _FORM_VERSIONS:
            return "Unsupported version.", 400
        try:
            scale = float(request.form.get("scale", 2))
        except ValueError:
            return "Rescale factor must be a number.", 400
        if not (math.isfinite(scale) and scale > 0):
            return "Rescale factor must be a positive number.", 400

        # Never build a path from the client-supplied filename (it may contain
        # "../"). Keep only a short alphanumeric extension and give every
        # upload a unique name so concurrent requests cannot collide.
        ext = os.path.splitext(upload.filename)[1]
        if not (ext[1:].isalnum() and len(ext) <= 6):
            ext = ""
        filepath = os.path.join("output", f"upload_{uuid.uuid4().hex}{ext}")
        upload.save(filepath)
        try:
            output_path = enhance_face(filepath, version, scale)
        except ValueError as err:
            app.logger.warning("Rejected upload: %s", err)
            return "Could not process the uploaded image.", 400
        finally:
            # The upload is only needed as input to enhance_face.
            os.remove(filepath)
        # The file was written relative to the CWD; pass an absolute path so
        # Flask does not resolve it against a different base directory.
        return send_file(os.path.abspath(output_path), as_attachment=True)
    return """
<h1>GFPGAN Face Restoration</h1>
<form method="post" enctype="multipart/form-data">
Upload Image: <input type="file" name="image"><br><br>
Version:
<select name="version">
<option value="v1.2">v1.2</option>
<option value="v1.3">v1.3</option>
<option value="v1.4" selected>v1.4</option>
<option value="RestoreFormer">RestoreFormer</option>
</select><br><br>
Rescale factor: <input type="number" step="0.1" name="scale" value="2"><br><br>
<input type="submit" value="Enhance">
</form>
"""
if __name__ == "__main__":
    # The Werkzeug interactive debugger lets anyone who can reach an error page
    # execute arbitrary Python, and this server listens on all interfaces.
    # Keep it off unless explicitly requested with FLASK_DEBUG=1 (local use only).
    app.run(host="0.0.0.0", port=7860, debug=os.environ.get("FLASK_DEBUG") == "1")