# Furdockgr1 / main.py
# Author: Ashrafb — commit f83fa14 (verified): "Rename app.py to main.py"
# (file-page residue: raw · history · blame · No virus · 3.99 kB)
from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import cv2
import os
import torch
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer
app = FastAPI()

# Expose the ./static directory under the /static URL prefix.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")

# Jinja2 template loader rooted at ./templates.
templates = Jinja2Templates(directory="templates")
# Download weights if they do not exist yet
def download_weights(dest_dir="."):
    """Download any missing model weight files into ``dest_dir``.

    Each file is fetched under a temporary ``.part`` name and only renamed to
    its final name once the download finished, so an interrupted download is
    never mistaken for a complete file on the next start.  A failed download
    is reported and skipped (best effort); it will be retried on the next call.

    Args:
        dest_dir: directory to store the weights in.  Defaults to the current
            working directory, which is where the rest of this module loads
            the weight files from (relative ``model_path`` values).
    """
    import urllib.request

    weights = [
        ('realesr-general-x4v3.pth', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'),
        ('GFPGANv1.2.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth'),
        ('GFPGANv1.3.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
        ('GFPGANv1.4.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'),
    ]
    for weight_file, weight_url in weights:
        target = os.path.join(dest_dir, weight_file)
        if os.path.exists(target):
            continue
        partial = target + ".part"
        print(f"Downloading {weight_url} -> {target}")
        try:
            urllib.request.urlretrieve(weight_url, partial)
        except OSError as exc:
            # URLError/HTTPError/ContentTooShortError are all OSError subclasses.
            # Report the failure and drop the incomplete file so it is not
            # treated as a finished download later.
            print(f"Error downloading {weight_url}: {exc}")
            if os.path.exists(partial):
                os.remove(partial)
            continue
        # Atomic on the same filesystem: the final name appears only when complete.
        os.replace(partial, target)
# Initialize the super-resolution network architecture (no weights loaded here)
def initialize_models():
    """Build the SRVGGNetCompact network used by the RealESR-General-x4v3 option.

    No weights are loaded here; the weight file is passed separately as
    ``model_path`` when ``RealESRGANer`` is constructed around this model.

    Returns:
        tuple: ``(model, half)`` where ``model`` is a 4x-upscaling
        ``SRVGGNetCompact`` and ``half`` is ``True`` when CUDA is available
        (it is forwarded as ``RealESRGANer``'s ``half`` flag), else ``False``.
    """
    model = SRVGGNetCompact(
        num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
    # is_available() already returns a bool; no conditional expression needed.
    half = torch.cuda.is_available()
    return model, half
# Perform image enhancement
def _build_enhancer(version, model, half):
    """Return the enhancer for ``version``, or None when the version is unknown."""
    gfpgan_weights = {
        'v1.2': 'GFPGANv1.2.pth',
        'v1.3': 'GFPGANv1.3.pth',
        'v1.4': 'GFPGANv1.4.pth',
    }
    if version in gfpgan_weights:
        return GFPGANer(
            model_path=gfpgan_weights[version], upscale=2, arch='clean',
            channel_multiplier=2, bg_upsampler=None)
    if version == 'RealESR-General-x4v3':
        return RealESRGANer(
            scale=4, model_path='realesr-general-x4v3.pth', model=model,
            tile=0, tile_pad=10, pre_pad=0, half=half)
    return None


def enhance_image(img_path, version, scale, model, half):
    """Enhance the image at ``img_path`` and write the result to ``output/out.jpg``.

    Args:
        img_path: path of the input image (read with ``cv2.imread``).
        version: 'v1.2', 'v1.3' or 'v1.4' (GFPGAN face restoration) or
            'RealESR-General-x4v3' (Real-ESRGAN upscaling using ``model``).
        scale: final size relative to the input image.  Every enhancer first
            produces a 2x image, which is then resized when ``scale != 2``.
        model: network passed to ``RealESRGANer`` for the RealESR option.
        half: ``half`` flag passed to ``RealESRGANer``.

    Returns:
        The output file path on success, or None on any failure (unknown
        version, unreadable input, invalid scale, processing or write error).
        Errors are printed rather than raised.
    """
    try:
        if scale <= 0:
            print(f"Error enhancing image: scale must be positive, got {scale}")
            return None
        input_img = cv2.imread(img_path)
        if input_img is None:
            # cv2.imread signals unreadable/undecodable files by returning None.
            print(f"Error enhancing image: could not read {img_path}")
            return None
        enhancer = _build_enhancer(version, model, half)
        if enhancer is None:
            return None
        if isinstance(enhancer, RealESRGANer):
            # RealESRGANer.enhance takes `outscale` (not the GFPGAN face kwargs)
            # and returns (image, img_mode).  Request 2x so its output matches
            # GFPGANer(upscale=2) and the shared rescale below stays correct.
            output, _ = enhancer.enhance(input_img, outscale=2)
        else:
            _, _, output = enhancer.enhance(
                input_img, has_aligned=False, only_center_face=False, paste_back=True)
        if scale != 2:
            # Shrink with INTER_AREA, enlarge with LANCZOS4.
            interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
            h, w = input_img.shape[0:2]
            output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
        output_path = 'output/out.jpg'
        # cv2.imwrite does not create directories and reports failure via False.
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        if not cv2.imwrite(output_path, output):
            print(f"Error enhancing image: failed to write {output_path}")
            return None
        return output_path
    except Exception as e:
        print(f"Error enhancing image: {e}")
        return None
# Fetch any weight files that are missing from the working directory.
download_weights()

# Build the shared SR network once at import time; `half` is True when CUDA is available.
(model, half) = initialize_models()
@app.post("/process_image/")
async def process_image(file: UploadFile = File(...), version: str = Form(...), scale: int = Form(...)):
    """Enhance an uploaded image and return the result as a JPEG.

    Form fields:
        file: the image to enhance.
        version: enhancer name forwarded to ``enhance_image``
            ('v1.2', 'v1.3', 'v1.4' or 'RealESR-General-x4v3').
        scale: output scale relative to the input, forwarded to ``enhance_image``.

    Returns:
        A ``FileResponse`` with the enhanced JPEG on success; otherwise a JSON
        object with an ``"error"`` key.
    """
    import tempfile

    img_path = None
    try:
        contents = await file.read()
        # A unique per-request input file instead of a shared "temp.jpg";
        # it is removed in `finally` once processing is done.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
            img_path = f.name
            f.write(contents)
        # NOTE(review): enhance_image is synchronous, so inference blocks the
        # event loop here.  It also writes to a fixed output path, so moving it
        # to a thread pool would first require unique output files.
        output_path = enhance_image(img_path, version, scale, model, half)
        if output_path:
            return FileResponse(output_path, media_type='image/jpeg')
        return {"error": "Failed to process the image."}
    except Exception as e:
        return {"error": f"An error occurred: {e}"}
    finally:
        if img_path is not None:
            try:
                os.remove(img_path)
            except OSError:
                # Best-effort cleanup of the temporary input; never mask the response.
                pass
@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page, static/index.html.

    The path is relative to the working directory, the same base that the
    ``StaticFiles(directory="static")`` mounts use.
    """
    return FileResponse(path=os.path.join("static", "index.html"), media_type="text/html")


# Registered last on purpose: a mount at "/" matches every path, so any route
# registered after it (previously `index` above) would be unreachable.  It uses
# its own name so it does not collide with the "/static" mount's "static".
app.mount("/", StaticFiles(directory="static", html=True), name="frontend")