Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, File, UploadFile, Form, Request | |
from fastapi.responses import HTMLResponse, FileResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
import cv2 | |
import os | |
import torch | |
from basicsr.archs.srvgg_arch import SRVGGNetCompact | |
from gfpgan.utils import GFPGANer | |
from realesrgan.utils import RealESRGANer | |
app = FastAPI()
# Serve the ./static directory under the /static URL prefix.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
# Jinja2 template loader for ./templates (not referenced by the code visible here).
templates = Jinja2Templates(directory="templates")
# Download weights if not exists
def download_weights(weights=None, dest_dir='.'):
    """Download any model weight files that are not already present.

    Args:
        weights: iterable of ``(file_name, url)`` pairs. Defaults to the
            Real-ESRGAN general-x4v3 and GFPGAN v1.2/v1.3/v1.4 checkpoints
            that this app loads by relative path.
        dest_dir: directory the files are stored in (default: the current
            working directory).

    Downloads are best effort: a failure is printed and skipped rather than
    raised, so the app can still start. A file that is still missing will
    surface as an error when the corresponding model is loaded.
    """
    import shutil
    import urllib.request

    if weights is None:
        weights = [
            ('realesr-general-x4v3.pth', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'),
            ('GFPGANv1.2.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth'),
            ('GFPGANv1.3.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
            ('GFPGANv1.4.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'),
        ]
    for weight_file, weight_url in weights:
        target = os.path.join(dest_dir, weight_file)
        if os.path.exists(target):
            continue
        # Download into a side file and rename only after it completes, so an
        # interrupted download never leaves a truncated file that the
        # existence check above would then accept forever.
        partial = target + '.part'
        try:
            with urllib.request.urlopen(weight_url, timeout=60) as response:
                with open(partial, 'wb') as out:
                    shutil.copyfileobj(response, out)
            os.replace(partial, target)
        except Exception as e:  # best effort: report and keep going
            print(f"Failed to download {weight_file} from {weight_url}: {e}")
        finally:
            if os.path.exists(partial):
                os.remove(partial)
# Initialize model and weights
def initialize_models():
    """Build the Real-ESRGAN network and choose the precision mode.

    Returns:
        ``(model, half)``. ``model`` is an ``SRVGGNetCompact`` configured for
        3-channel input/output, 64 features, 32 conv layers, 4x upscale and
        PReLU activations; it is later handed to ``RealESRGANer`` together with
        the weight file path. ``half`` is ``True`` exactly when CUDA is
        available.
    """
    network = SRVGGNetCompact(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_conv=32,
        upscale=4,
        act_type='prelu',
    )
    use_half = bool(torch.cuda.is_available())
    return network, use_half
# Perform image enhancement
def _build_enhancer(version, model, half):
    """Create the enhancer for *version*; return ``None`` for an unknown version."""
    gfpgan_weights = {
        'v1.2': 'GFPGANv1.2.pth',
        'v1.3': 'GFPGANv1.3.pth',
        'v1.4': 'GFPGANv1.4.pth',
    }
    if version in gfpgan_weights:
        return GFPGANer(
            model_path=gfpgan_weights[version], upscale=2, arch='clean',
            channel_multiplier=2, bg_upsampler=None)
    if version == 'RealESR-General-x4v3':
        return RealESRGANer(
            scale=4, model_path='realesr-general-x4v3.pth', model=model,
            tile=0, tile_pad=10, pre_pad=0, half=half)
    return None


def _resize_to_scale(output, input_img, scale):
    """Resize *output* to *scale* times the size of *input_img* (no-op if already there)."""
    in_h, in_w = input_img.shape[0:2]
    target = (int(in_w * scale), int(in_h * scale))
    out_h, out_w = output.shape[0:2]
    if (out_w, out_h) == target:
        return output
    # Area averaging when shrinking, Lanczos when enlarging.
    interpolation = cv2.INTER_AREA if target[0] < out_w else cv2.INTER_LANCZOS4
    return cv2.resize(output, target, interpolation=interpolation)


def enhance_image(img_path, version, scale, model, half, output_path='output/out.jpg'):
    """Enhance the image at *img_path* and write the result as an image file.

    Args:
        img_path: path of the input image (read with ``cv2.imread``).
        version: ``'v1.2'``, ``'v1.3'`` or ``'v1.4'`` for GFPGAN face
            restoration (configured with ``upscale=2``), or
            ``'RealESR-General-x4v3'`` for Real-ESRGAN 4x upscaling.
        scale: size of the result relative to the input image; e.g. ``2``
            doubles width and height, ``1`` keeps the input size.
        model: network passed to ``RealESRGANer`` (Real-ESRGAN branch only).
        half: half-precision flag passed to ``RealESRGANer``.
        output_path: file the result is written to; its directory is created
            if missing.

    Returns:
        ``output_path`` on success, or ``None`` if the version is unknown, the
        input cannot be read, or enhancement/writing fails (errors are printed).
    """
    # NOTE(review): a new enhancer is constructed on every call (presumably
    # reloading its weights each time), and the default output_path is shared
    # by all callers, so concurrent calls would overwrite each other's result.
    try:
        input_img = cv2.imread(img_path)
        if input_img is None:
            # cv2.imread reports an unreadable/missing file by returning None.
            print(f"Error enhancing image: cannot read {img_path}")
            return None
        enhancer = _build_enhancer(version, model, half)
        if enhancer is None:
            return None
        if isinstance(enhancer, RealESRGANer):
            # RealESRGANer.enhance takes no face-related arguments and returns
            # (output, img_mode), unlike GFPGANer.enhance's 3-tuple.
            output, _ = enhancer.enhance(input_img)
        else:
            _, _, output = enhancer.enhance(
                input_img, has_aligned=False, only_center_face=False, paste_back=True)
        # Enhancers upscale by their own factor (2x or 4x); bring the result to
        # `scale` times the *input* size.
        output = _resize_to_scale(output, input_img, scale)
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # cv2.imwrite returns False instead of raising when it cannot write.
        if not cv2.imwrite(output_path, output):
            print(f"Error enhancing image: cannot write {output_path}")
            return None
        return output_path
    except Exception as e:
        print(f"Error enhancing image: {e}")
        return None
# Module start-up: fetch any missing weight files, then build the shared
# Real-ESRGAN network and precision flag that every request reuses.
download_weights()
model, half = initialize_models()
async def process_image(file: UploadFile = File(...), version: str = Form(...), scale: int = Form(...)):
    """Enhance an uploaded image and return the result.

    Args:
        file: the uploaded image.
        version: enhancer name forwarded to ``enhance_image``.
        scale: output scale forwarded to ``enhance_image``.

    Returns:
        A ``FileResponse`` (``image/jpeg``) with the enhanced image on success,
        otherwise a dict with an ``"error"`` message.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view — confirm it is registered with the app.
    import tempfile

    img_path = None
    try:
        contents = await file.read()
        # A private, uniquely named temp file instead of a fixed "temp.jpg" in
        # the working directory: it cannot clobber an unrelated file, and it
        # is deleted in `finally` so uploaded images do not linger on disk.
        fd, img_path = tempfile.mkstemp(suffix=".jpg")
        with os.fdopen(fd, "wb") as f:
            f.write(contents)
        output_path = enhance_image(img_path, version, scale, model, half)
        if output_path:
            return FileResponse(output_path, media_type='image/jpeg')
        return {"error": "Failed to process the image."}
    except Exception as e:
        return {"error": f"An error occurred: {e}"}
    finally:
        if img_path is not None:
            try:
                os.remove(img_path)
            except OSError:
                pass  # cleanup is best effort; the response is already decided
# Serve ./static at the site root; html=True lets directory requests fall back
# to their index.html.
# NOTE(review): a mount at "/" matches every path, so routes registered after
# this line are presumably unreachable; it also reuses the route name "static"
# already given to the /static mount — confirm both are intended.
app.mount(path="/", app=StaticFiles(directory="static", html=True), name="static")
def index() -> FileResponse:
    """Return the front-end page ``index.html`` from the ``static`` directory.

    The path is resolved relative to the working directory, matching the
    ``StaticFiles(directory="static")`` mounts, rather than being pinned to
    an absolute ``/app`` location.
    """
    # NOTE(review): no route decorator is visible here — confirm how this
    # handler is registered.
    return FileResponse(path=os.path.join("static", "index.html"), media_type="text/html")