|
|
import os |
|
|
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous,
# which is normally a debugging aid and slows GPU inference; confirm it is
# intended for production.
os.environ.update({'CUDA_LAUNCH_BLOCKING': '1'})

# Public base URL; endpoints append ":8001/<path>" to it to build file links.
web_url = "https://api.quantumgrove.tech"
|
|
|
|
|
from fastapi import FastAPI, UploadFile, File, BackgroundTasks |
|
|
|
|
|
import sqlite3 |
|
|
|
|
|
def create_database(db_path='api.db'):
    """Create the ``ban_list`` and ``request_log`` tables if they are missing.

    Safe to call repeatedly (``CREATE TABLE IF NOT EXISTS``). The connection
    is always closed, even if a statement fails.

    Args:
        db_path: SQLite database file; defaults to ``api.db`` in the working
            directory.
    """
    conn = sqlite3.connect(db_path)
    try:
        # The connection context manager commits on success, rolls back on error.
        with conn:
            conn.execute('''CREATE TABLE IF NOT EXISTS ban_list (
                id INTEGER PRIMARY KEY,
                ip TEXT NOT NULL,
                reason TEXT NOT NULL)''')
            conn.execute('''CREATE TABLE IF NOT EXISTS request_log (
                id INTEGER PRIMARY KEY,
                ip TEXT NOT NULL,
                url TEXT NOT NULL,
                method TEXT NOT NULL,
                endpoint TEXT NOT NULL,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)''')
    finally:
        conn.close()
|
|
|
|
|
def insert_ban(ip, reason, db_path='api.db'):
    """Add a row ``(ip, reason)`` to ``ban_list``.

    The insert is committed on success and rolled back on error. The
    connection is always closed.

    Args:
        ip: Client IP address to ban.
        reason: Human-readable reason stored alongside the ban.
        db_path: SQLite database file; defaults to ``api.db``.
    """
    conn = sqlite3.connect(db_path)
    try:
        with conn:
            conn.execute('''INSERT INTO ban_list (ip, reason) VALUES (?, ?)''', (ip, reason))
    finally:
        conn.close()
|
|
|
|
|
def search_ban(ip, db_path='api.db'):
    """Return every ``ban_list`` row for *ip*.

    Rows are ``(id, ip, reason)`` tuples. The result is an empty list when
    the IP is not banned. The connection is always closed.

    Args:
        ip: Client IP address to look up.
        db_path: SQLite database file; defaults to ``api.db``.
    """
    conn = sqlite3.connect(db_path)
    try:
        return conn.execute('''SELECT * FROM ban_list WHERE ip = ?''', (ip,)).fetchall()
    finally:
        conn.close()
|
|
|
|
|
def log_request(ip, url, method, endpoint, db_path='api.db'):
    """Append one row to ``request_log``.

    The ``timestamp`` column is filled by SQLite's ``CURRENT_TIMESTAMP``
    default. The connection is always closed.

    Args:
        ip: Client IP address.
        url: Full request URL.
        method: HTTP method.
        endpoint: Request path.
        db_path: SQLite database file; defaults to ``api.db``.
    """
    conn = sqlite3.connect(db_path)
    try:
        with conn:
            conn.execute(
                '''INSERT INTO request_log (ip, url, method, endpoint) VALUES (?, ?, ?, ?)''',
                (ip, url, method, endpoint),
            )
    finally:
        conn.close()
|
|
|
|
|
# create_database() uses CREATE TABLE IF NOT EXISTS, so it is idempotent and is
# run on every start. Guarding it with a file-existence check would skip table
# creation for an api.db that exists but has no tables (e.g. an interrupted
# first run). The logging middleware would then fail on every request.
create_database()
|
|
|
|
|
|
|
|
def generate_response(response_message, response_status, uuid_code, age, gender, metadata):
    """Build the JSON envelope that every endpoint returns.

    Returns:
        dict with top-level ``response_message`` and ``response_status``, plus
        a ``data`` section holding ``UUID``, ``info`` (``age`` / ``gender``)
        and the endpoint-specific ``metadata``.
    """
    data_section = {
        "UUID": uuid_code,
        "info": {"age": age, "gender": gender},
        "metadata": metadata,
    }
    return {
        "response_message": response_message,
        "response_status": response_status,
        "data": data_section,
    }
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import cv2 |
|
|
import time |
|
|
import io |
|
|
import gc |
|
|
import shutil |
|
|
import uuid |
|
|
import torch |
|
|
import subprocess |
|
|
import torchvision.transforms as transforms |
|
|
from scripts.psp import pSp |
|
|
from argparse import Namespace |
|
|
import dlib |
|
|
from scripts.align_all_parallel import align_face |
|
|
from scripts.augmentations import AgeTransformer |
|
|
from scripts.common import tensor2im |
|
|
from torchvision.transforms.functional import normalize |
|
|
from scripts.basicsr.utils import img2tensor, tensor2img |
|
|
from scripts.basicsr.utils.misc import get_device |
|
|
from scripts.facelib.utils.face_restoration_helper import FaceRestoreHelper |
|
|
from scripts.basicsr.utils.registry import ARCH_REGISTRY |
|
|
from scripts.basicsr.archs.rrdbnet_arch import RRDBNet |
|
|
from scripts.basicsr.utils.realesrgan_utils import RealESRGANer |
|
|
from scripts.facelib.utils.misc import is_gray |
|
|
import PIL.Image |
|
|
from scripts.erasescratches.models import Pix2PixHDModel_Mapping |
|
|
from scripts.erasescratches.options import Options |
|
|
from scripts.maskscratches import ScratchesDetector |
|
|
from scripts.util import irregular_hole_synthesize, tensor_to_ndarray |
|
|
import insightface |
|
|
from insightface.app import FaceAnalysis |
|
|
from pydub import AudioSegment |
|
|
import scripts.RRDBNet_arch as arch |
|
|
from RealESRGAN import RealESRGAN |
|
|
from PIL import Image |
|
|
|
|
|
# Default torch device string used by loaders and endpoints.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
|
|
|
# Age-slider (SAM / pSp) configuration.
EXPERIMENT_TYPE = 'ffhq_aging'
model_path = "./models/sam_ffhq_aging.pt"
# Set by load_model_age(); None while the model is not loaded.
model_age_slider = None

# Per-experiment preprocessing: resize to 256x256, convert to a tensor and
# normalise each channel with mean 0.5 / std 0.5.
EXPERIMENT_DATA_ARGS = {
    "ffhq_aging": {
        "transform": transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
    }
}
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]
|
|
|
|
|
|
|
|
def load_model_age():
    """Load the SAM age-transformation (pSp) model into ``model_age_slider``.

    The options stored under ``'opts'`` in the checkpoint at ``model_path``
    are turned into a Namespace for ``pSp``, with ``checkpoint_path`` pointed
    at the same file. The global is assigned only after the model has been
    built, put in eval mode and moved to the GPU. A failed load therefore
    never leaves a half-initialised model for ``get_model_state_age`` to
    report as loaded.

    Returns:
        True if the model is (now) loaded, False if loading failed.
    """
    global model_age_slider
    if get_model_state_age():
        return True
    try:
        ckpt = torch.load(model_path, map_location='cpu')
        opts = ckpt['opts']
        del ckpt  # only the stored options are needed here
        torch.cuda.empty_cache()
        opts['checkpoint_path'] = model_path
        network = pSp(Namespace(**opts))
        torch.cuda.empty_cache()
        network.eval()
        network.cuda()
    except Exception as exc:
        print(f"load_model_age failed: {exc}")
        return False
    model_age_slider = network
    torch.cuda.empty_cache()
    return True
|
|
|
|
|
def unload_model_age():
    """Drop the age-slider model reference and release cached CUDA memory.

    Returns:
        True when the model is unloaded (or was not loaded), False if an
        unexpected error occurs.
    """
    global model_age_slider
    try:
        if model_age_slider is None:
            return True
        model_age_slider = None
        torch.cuda.empty_cache()
    except BaseException:
        return False
    return True
|
|
|
|
|
def get_model_state_age():
    """Report whether the age-slider model is currently loaded."""
    return model_age_slider is not None
|
|
|
|
|
from scripts.BG import BG |
|
|
# Background-removal model; None while not loaded.
bg_model = None


def load_model_bg():
    """Instantiate the background-removal model once and keep it in ``bg_model``.

    Returns:
        True if the model is (now) loaded, False if construction failed. The
        failure reason is printed instead of being silently discarded.
    """
    global bg_model
    if bg_model is not None:
        return True
    try:
        bg_model = BG()
    except Exception as exc:
        print(f"load_model_bg failed: {exc}")
        return False
    return True
|
|
|
|
|
def unload_model_bg():
    """Drop the background-removal model reference.

    Returns:
        True on success (including when nothing was loaded), False on an
        unexpected error.
    """
    global bg_model
    try:
        if bg_model is None:
            return True
        bg_model = None
    except BaseException:
        return False
    return True
|
|
|
|
|
def get_model_state_bg():
    """Report whether the background-removal model is currently loaded."""
    return bg_model is not None
|
|
|
|
|
# CodeFormer family: the restoration network and its Real-ESRGAN upsampler.
# Both are None while not loaded.
model_upsampler = None
model_code_former = None


def set_realesrgan_cf():
    """Build the 2x Real-ESRGAN upsampler used by the CodeFormer endpoint.

    Half precision is enabled only when CUDA is available.

    Returns:
        A configured ``RealESRGANer`` instance.
    """
    backbone = RRDBNet(
        num_in_ch=3, num_out_ch=3, num_feat=64,
        num_block=23, num_grow_ch=32, scale=2,
    )
    return RealESRGANer(
        scale=2,
        model_path="./models/realesrgan/RealESRGAN_x2plus.pth",
        model=backbone,
        tile=400,
        tile_pad=40,
        pre_pad=0,
        half=torch.cuda.is_available(),
    )
|
|
|
|
|
|
|
|
|
|
|
def load_model_cf():
    """Load CodeFormer and its Real-ESRGAN upsampler.

    The network weights are the ``'params_ema'`` entry of
    ``models/CodeFormer/codeformer.pth``. Both globals
    (``model_code_former`` and ``model_upsampler``) are assigned only after
    everything has loaded. A failure therefore never leaves a network with
    random weights that ``get_model_state_cf`` would report as loaded.

    Returns:
        True on success, False if anything failed (the reason is printed).
    """
    global model_code_former, model_upsampler
    try:
        cf_device = get_device()
        network = ARCH_REGISTRY.get('CodeFormer')(
            dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
            connect_list=['32', '64', '128', '256'],
        ).to(cf_device)
        ckpt_path = 'models/CodeFormer/codeformer.pth'
        # map_location lets a GPU-saved checkpoint load on any host;
        # load_state_dict copies the tensors onto the network's device.
        checkpoint = torch.load(ckpt_path, map_location='cpu')['params_ema']
        network.load_state_dict(checkpoint)
        network.eval()
        upsampler = set_realesrgan_cf()
    except Exception as exc:
        print(f"load_model_cf failed: {exc}")
        return False
    model_code_former = network
    model_upsampler = upsampler
    return True
|
|
|
|
|
|
|
|
def get_model_state_cf():
    """Report whether the CodeFormer network is currently loaded."""
    return model_code_former is not None
|
|
|
|
|
|
|
|
def unload_model_cf():
    """Drop the CodeFormer network and its upsampler, releasing cached CUDA memory.

    Returns:
        True if nothing was loaded or both were released, False on an
        unexpected error.
    """
    global model_code_former, model_upsampler
    if not get_model_state_cf():
        return True
    try:
        model_code_former = None
        torch.cuda.empty_cache()
        model_upsampler = None
        torch.cuda.empty_cache()
    except BaseException:
        return False
    return True
|
|
|
|
|
# Face-swap family: the inswapper ONNX model and the FaceAnalysis detector
# that feeds it. Both are None while not loaded.
swap_model = None
app_model = None


def load_model_swap():
    """Load the face-swap model and its ``buffalo_l`` face analyser.

    The two globals are assigned together, and only once both have loaded.
    A detector failure therefore cannot leave ``swap_model`` set (reported as
    loaded by ``get_model_state_swap``) while ``app_model`` is still None.

    Returns:
        True if the models are (now) loaded, False on failure (the reason is
        printed).
    """
    global swap_model, app_model
    if swap_model is not None:
        return True
    try:
        swapper = insightface.model_zoo.get_model(
            "./models/inswapper_128.onnx", download=False, download_zip=False
        )
        analyser = FaceAnalysis(name="buffalo_l")
        analyser.prepare(ctx_id=0, det_size=(640, 640))
    except Exception as exc:
        print(f"load_model_swap failed: {exc}")
        return False
    swap_model = swapper
    app_model = analyser
    return True
|
|
|
|
|
def unload_model_swap():
    """Drop references to the face-swap model and its face analyser.

    Returns:
        True on success, False if an unexpected error occurs.
    """
    global swap_model, app_model
    try:
        if swap_model is not None:
            swap_model = None
        if app_model is not None:
            app_model = None
    except BaseException:
        return False
    return True
|
|
|
|
|
def get_model_state_swap():
    """Report whether the face-swap model is loaded (the analyser is not checked)."""
    return swap_model is not None
|
|
|
|
|
# Scratch-removal family: restoration model, scratch detector and the options
# the model was initialised with. All are None while not loaded.
model_scratches_remove = None
model_scratches_remove_detector = None
model_scratches_remove_options = None


def load_model_rest():
    """Load the scratch detector and the Pix2PixHD mapping restoration model.

    All GPUs reported by torch are passed to ``Options`` (an empty list on
    CPU-only hosts). The three globals are assigned only after the model has
    been initialised and put in eval mode. A failed ``initialize`` therefore
    cannot leave an uninitialised model that ``get_model_state_rest`` reports
    as loaded.

    Returns:
        True on success, False on failure (the reason is printed).
    """
    global model_scratches_remove, model_scratches_remove_detector, model_scratches_remove_options
    try:
        restoration_path = "./models/zeroscratches/restoration"
        detector = ScratchesDetector('./models/zeroscratches')
        gpu_ids = list(range(torch.cuda.device_count())) if torch.cuda.is_available() else []
        options = Options(restoration_path, gpu_ids)
        model = Pix2PixHDModel_Mapping()
        model.initialize(options)
        model.eval()
    except Exception as exc:
        print(f"load_model_rest failed: {exc}")
        return False
    model_scratches_remove_detector = detector
    model_scratches_remove_options = options
    model_scratches_remove = model
    return True
|
|
|
|
|
|
|
|
def get_model_state_rest():
    """Report whether the scratch-removal model is currently loaded."""
    return model_scratches_remove is not None
|
|
|
|
|
|
|
|
def unload_model_rest():
    """Drop the scratch-removal model, detector and options, then release cached CUDA memory.

    Returns:
        True if nothing was loaded or everything was released, False on an
        unexpected error.
    """
    global model_scratches_remove, model_scratches_remove_detector, model_scratches_remove_options
    if not get_model_state_rest():
        return True
    try:
        model_scratches_remove = None
        model_scratches_remove_detector = None
        model_scratches_remove_options = None
        torch.cuda.empty_cache()
    except BaseException:
        return False
    return True
|
|
|
|
|
from scripts.talker import sad_talker |
|
|
# Talking-head model; None while not loaded.
sad_talker_model = None


def load_model_talk():
    """Instantiate the ``sad_talker`` wrapper once and keep it in ``sad_talker_model``.

    Returns:
        True if the model is (now) loaded, False if construction failed. The
        failure reason is printed instead of being silently discarded.
    """
    global sad_talker_model
    if sad_talker_model is not None:
        return True
    try:
        sad_talker_model = sad_talker()
    except Exception as exc:
        print(f"load_model_talk failed: {exc}")
        return False
    return True
|
|
|
|
|
def unload_model_talk():
    """Drop the talking-head model reference.

    Returns:
        True on success (including when nothing was loaded), False on an
        unexpected error.
    """
    global sad_talker_model
    try:
        if sad_talker_model is None:
            return True
        sad_talker_model = None
    except BaseException:
        return False
    return True
|
|
|
|
|
def get_model_state_talk():
    """Report whether the talking-head model is currently loaded."""
    return sad_talker_model is not None
|
|
|
|
|
from scripts.SKY import SKY |
|
|
# Sky model; None while not loaded.
sky_model = None


def load_model_sky():
    """Instantiate the ``SKY`` model once and keep it in ``sky_model``.

    Returns:
        True if the model is (now) loaded, False if construction failed. The
        failure reason is printed instead of being silently discarded.
    """
    global sky_model
    if sky_model is not None:
        return True
    try:
        sky_model = SKY()
    except Exception as exc:
        print(f"load_model_sky failed: {exc}")
        return False
    return True
|
|
|
|
|
def unload_model_sky():
    """Drop the sky model reference.

    Returns:
        True on success (including when nothing was loaded), False on an
        unexpected error.
    """
    global sky_model
    try:
        if sky_model is None:
            return True
        sky_model = None
    except BaseException:
        return False
    return True
|
|
|
|
|
def get_model_state_sky():
    """Report whether the sky model is currently loaded."""
    return sky_model is not None
|
|
|
|
|
# 4x ESRGAN upscaler; None while not loaded.
model_upscale = None


def load_model_up():
    """Load the 4x ESRGAN (RRDBNet) upscaler into ``model_upscale``.

    The weights come from ``./models/RRDB_ESRGAN_x4.pth``. The network is
    converted to half precision, put in eval mode and moved to ``device``.
    The global is assigned only after the weights have loaded, so a failure
    never leaves a randomly initialised network reported as loaded by
    ``get_model_state_up``.

    Returns:
        True on success, False on failure (the reason is printed).
    """
    global model_upscale
    try:
        model_path_upscale = './models/RRDB_ESRGAN_x4.pth'
        network = arch.RRDBNet(3, 3, 64, 23, gc=32)
        # map_location lets a GPU-saved checkpoint load on CPU-only hosts too.
        network.load_state_dict(torch.load(model_path_upscale, map_location='cpu'), strict=True)
        # NOTE(review): .half() is applied unconditionally; fp16 inference may
        # not be supported on CPU for every op. Confirm for CPU deployments.
        network = network.half()
        network.eval()
        network = network.to(device)
    except Exception as exc:
        print(f"load_model_up failed: {exc}")
        return False
    model_upscale = network
    return True
|
|
|
|
|
|
|
|
def get_model_state_up():
    """Report whether the 4x ESRGAN upscaler is currently loaded."""
    return model_upscale is not None
|
|
|
|
|
|
|
|
def unload_model_up():
    """Drop the 4x ESRGAN upscaler and release cached CUDA memory.

    Returns:
        True if nothing was loaded or the model was released, False on an
        unexpected error.
    """
    global model_upscale
    if not get_model_state_up():
        return True
    try:
        model_upscale = None
        torch.cuda.empty_cache()
    except BaseException:
        return False
    return True
|
|
|
|
|
# dlib 68-point landmark predictor, loaded from disk on first use and then
# reused. Previously it was re-read on every call. It is not released by
# unload_model_all().
_landmark_predictor = None


def run_alignment(image_path):
    """Align the face in the image at *image_path* using dlib landmarks.

    Args:
        image_path: Path of the image file to align.

    Returns:
        Whatever ``align_face`` returns. The AgeSlider caller treats it as a
        PIL image (``.copy()`` / ``.resize()`` / ``.save()``).
    """
    global _landmark_predictor
    if _landmark_predictor is None:
        _landmark_predictor = dlib.shape_predictor("./models/shape_predictor_68_face_landmarks.dat")
    return align_face(filepath=image_path, predictor=_landmark_predictor)
|
|
|
|
|
def run_on_batch(inputs, model_age_slider):
    """Run the age model on a batch moved to CUDA as float32.

    Noise is not randomised and the model's output resize is disabled.

    Returns:
        The model's raw output for the batch.
    """
    batch = inputs.to("cuda").float()
    return model_age_slider(batch, randomize_noise=False, resize=False)
|
|
|
|
|
|
|
|
|
|
|
# 2x Real-ESRGAN upscaler (RealESRGAN package); None while not loaded.
model_esrgan = None


def load_model_esrgan():
    """Load the 2x Real-ESRGAN upscaler into ``model_esrgan``.

    The weights come from ``./models/realesrgan/RealESRGAN_x2plus.pth``
    (no download). The global is assigned only after the weights have loaded,
    so a failure never leaves a weightless model reported as loaded by
    ``get_model_state_esrgan``.

    Returns:
        True on success, False on failure (the reason is printed).
    """
    global model_esrgan
    try:
        upscaler = RealESRGAN(torch.device(device), scale=2)
        upscaler.load_weights('./models/realesrgan/RealESRGAN_x2plus.pth', download=False)
    except Exception as exc:
        print(f"load_model_esrgan failed: {exc}")
        return False
    model_esrgan = upscaler
    return True
|
|
|
|
|
|
|
|
def get_model_state_esrgan():
    """Report whether the 2x Real-ESRGAN upscaler is currently loaded."""
    return model_esrgan is not None
|
|
|
|
|
|
|
|
def unload_model_esrgan():
    """Drop the 2x Real-ESRGAN upscaler and release cached CUDA memory.

    Returns:
        True if nothing was loaded or the model was released, False on an
        unexpected error.
    """
    global model_esrgan
    if not get_model_state_esrgan():
        return True
    try:
        model_esrgan = None
        torch.cuda.empty_cache()
    except BaseException:
        return False
    return True
|
|
|
|
|
|
|
|
|
|
|
def unload_model_all():
    """Unload every model family.

    Endpoints call this before loading their own model, so only one family
    occupies memory at a time. Every unloader is attempted even if an earlier
    one fails.

    Returns:
        True only if every unloader reported success. An unloader that
        returns False or raises makes the result False.
    """
    unloaders = (
        unload_model_age,
        unload_model_bg,
        unload_model_cf,
        unload_model_up,
        unload_model_swap,
        unload_model_rest,
        unload_model_talk,
        unload_model_sky,
        unload_model_esrgan,
    )
    all_ok = True
    for unload in unloaders:
        try:
            ok = bool(unload())
        except Exception as exc:
            print(f"{unload.__name__} failed: {exc}")
            ok = False
        all_ok = all_ok and ok
    return all_ok
|
|
|
|
|
|
|
|
def start_http_file_server(directory='./outputs', port=8010):
    """Serve *directory* with ``python -m http.server`` in a child process.

    The directory is created if it does not exist.

    Args:
        directory: Directory to serve; defaults to ``./outputs``.
        port: TCP port to listen on; defaults to 8010.

    Returns:
        The ``subprocess.Popen`` handle of the server process.
    """
    import sys

    os.makedirs(directory, exist_ok=True)
    # Argument list instead of a shell string. There is no shell quoting, and
    # the returned Popen is the server itself rather than a /bin/sh wrapper,
    # so terminate()/kill() stop the server. sys.executable runs the same
    # interpreter/venv as this app instead of whatever "python" is on PATH.
    return subprocess.Popen(
        [sys.executable, "-m", "http.server", str(port), "--directory", directory]
    )
|
|
|
|
|
import asyncio |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
# Time of the most recent model-using request (naive local datetime), or None
# when there is no pending activity to time out.
time_last_call = None


async def continuous_function():
    """Idle watchdog that runs forever, waking once a minute.

    On each wake-up it empties the CUDA cache (when CUDA is available).
    If more than 1 minute has passed since ``time_last_call`` it runs the
    garbage collector. If more than 5 minutes have passed it unloads all
    models and clears ``time_last_call``.
    """
    global time_last_call
    while True:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if time_last_call is not None:
            idle_minutes = (datetime.now() - time_last_call).total_seconds() / 60
            if idle_minutes > 1:
                gc.collect()
            if idle_minutes > 5:
                model_delete_status = unload_model_all()
                print(f"model delete status = {model_delete_status}")
                time_last_call = None

        await asyncio.sleep(60*1)
|
|
|
|
|
async def start_continuous_function():
    """Background-task entry point: run the idle watchdog until cancelled."""
    watchdog = continuous_function()
    await watchdog
|
|
|
|
|
|
|
|
# The FastAPI application is created once, immediately before the CORS
# middleware is registered further below. A second instance created here
# would be discarded by that reassignment, so none is created.
|
|
|
|
|
from fastapi.responses import HTMLResponse |
|
|
|
|
|
|
|
|
from fastapi import Request |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.responses import JSONResponse |
|
|
|
|
|
app = FastAPI()

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# risky combination; list explicit origins if cookies or auth are ever used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
|
|
|
|
|
|
|
async def custom_middleware(request: Request, call_next):
    """Log every request, reject banned IPs, and ban clients that hit unknown routes.

    Every request is written to ``request_log``. A client whose IP is in
    ``ban_list`` gets a 401 before routing. If the routed response is a 404
    or 405, the client IP is added to ``ban_list`` and a 401 replaces the
    original response.
    """
    # request.client can be None (e.g. some test clients / non-TCP transports).
    # Such requests are logged under a placeholder and never banned. Banning
    # the shared placeholder would lock out every client-less caller at once.
    client = request.client
    client_ip = client.host if client is not None else "unknown"
    request_url = str(request.url)
    request_method = request.method
    endpoint = request.url.path

    log_request(client_ip, request_url, request_method, endpoint)

    if client is not None and search_ban(client_ip):
        return JSONResponse(status_code=401, content={"detail": "Unauthorized"})

    response = await call_next(request)

    # NOTE(review): behind a reverse proxy client.host is the proxy's address
    # unless proxy headers are trusted, in which case one 404 bans everyone.
    if response.status_code in (404, 405):
        if client is not None:
            insert_ban(client_ip, "Access using unapproved method/endpoint")
        return JSONResponse(status_code=401, content={"detail": "Unauthorized"})

    return response
|
|
|
|
|
# Register custom_middleware as an HTTP middleware on the app. It is added
# after CORSMiddleware, so presumably (Starlette puts later-added middleware
# outermost) it sees each request before CORS handling.
_register_http_middleware = app.middleware('http')
_register_http_middleware(custom_middleware)
|
|
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def read_index():
    """Serve the static landing page from ``./assests/index.html``."""
    with open('./assests/index.html', 'r') as page_file:
        page_html = page_file.read()
    return HTMLResponse(content=page_html)
|
|
|
|
|
from fastapi.responses import FileResponse |
|
|
|
|
|
@app.get("/favicon.ico")
async def get_image():
    """Serve ``./assests/logo.png`` as the site favicon."""
    return FileResponse("./assests/logo.png", media_type="image/png")
|
|
|
|
|
# Strong reference to the idle-watchdog task. The event loop keeps only weak
# references to tasks, so an unreferenced task can be garbage-collected
# mid-execution.
_watchdog_task = None


@app.on_event("startup")
async def startup_event():
    """Start the idle-model watchdog as a background task."""
    global _watchdog_task
    _watchdog_task = asyncio.create_task(start_continuous_function())
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Unload all models when the server shuts down.

    The event loop belongs to the ASGI server, which finishes its own
    shutdown and then closes the loop. Calling ``loop.stop()`` here would make
    the server's pending ``run_until_complete`` fail with "Event loop stopped
    before Future completed", so the loop is left alone.
    """
    print("Clearing Models")
    model_status = unload_model_all()
    print("Model clear status = "+str(model_status))
    print("Stopping Server")
|
|
|
|
|
# Accepted upload extensions (compared against the lower-cased filename suffix).
ALLOWED_IMAGE_FORMATS = "jpg jpeg png".split()
ALLOWED_AUDIO_FORMATS = "wav mp3".split()
|
|
|
|
|
def delete_uuid(uuid_to_delete, path="./outputs"):
    """Delete the first directory under *path* named *uuid_to_delete*.

    Only canonical UUID strings (the ``str(uuid.uuid4())`` form the endpoints
    generate) are accepted. Anything else returns False without touching the
    filesystem. This prevents a name such as ``"AGE"`` from matching, and
    wiping, an entire per-feature output folder.

    Args:
        uuid_to_delete: Name of the directory to remove.
        path: Root directory to search; defaults to ``./outputs``.

    Returns:
        True if a matching directory was found and removed, otherwise False.
    """
    if not isinstance(uuid_to_delete, str):
        return False
    try:
        canonical = str(uuid.UUID(uuid_to_delete))
    except ValueError:
        return False
    if canonical != uuid_to_delete:
        return False

    # Remove the first match in os.walk's top-down order. Matching on the
    # directory entries avoids splitting paths on a hard-coded "/".
    for root, dirs, _files in os.walk(path):
        if uuid_to_delete in dirs:
            shutil.rmtree(os.path.join(root, uuid_to_delete))
            return True
    return False
|
|
|
|
|
@app.delete("/Delete/{uuid}")
async def delete_item(uuid: str):
    """Delete the output directory for *uuid*.

    ``delete_uuid`` returns True when it removed a directory. The previous
    condition was inverted and reported "does not exist" after a successful
    delete.

    Returns:
        ``{"message": ...}`` on success, or ``{"error": ...}`` when the UUID
        does not exist or an unexpected error occurs.
    """
    gc.collect()
    try:
        if not delete_uuid(uuid):
            return {"error": f"UUID '{uuid}' does not exist"}
        return {"message": f"UUID '{uuid}' deleted"}
    except Exception:
        return {"error": "unexpected error"}
|
|
|
|
|
|
|
|
|
|
|
@app.put("/ClearModel_ALL")
async def clearModel_ALL():
    """Unload every model on demand.

    On success ``time_last_call`` is also cleared, so the idle watchdog has
    nothing pending.
    """
    global time_last_call
    try:
        if not unload_model_all():
            return {"message": "vram not cleared"}
        time_last_call = None
        return {"message": "vram cleared"}
    except BaseException:
        return {"error": "unexpected error"}
|
|
|
|
|
@app.post("/AgeSlider")
async def age_slider(background_tasks: BackgroundTasks , image: UploadFile = File(...)):
    """Generate aged / de-aged variants of a single-face portrait.

    The upload must be jpg/jpeg/png and contain exactly one face (checked
    with insightface, which also supplies the age/gender reported back). The
    aligned face is run through the age model at ages 0, 10, ..., 100. Each
    result is written to ``./outputs/AGE/<uuid>/output_<idx>.png``. The frame
    whose target age is closest to the detected age is replaced by the
    aligned input resized to 1024x1024. A background task then builds the
    morph video/GIF from those frames.
    """
    global time_last_call
    time_last_call = datetime.now()

    # Defined before the try so the error handler can always reference it,
    # even when the failure happens before an output directory exists.
    output_path = None
    try:
        image_format = image.filename.split(".")[-1].lower()
        if image_format not in ALLOWED_IMAGE_FORMATS:
            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
            response_status = "Failure"
            return generate_response(response_message , response_status ,'' , '' , '' , [])

        contents = await image.read()
        nparr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            response_message = "Image is empty"
            response_status = "Failure"
            return generate_response(response_message , response_status ,'' , '' , '' , [])

        # NOTE(review): a new FaceAnalysis is built and prepared on every
        # request; consider reusing one instance.
        app_face = FaceAnalysis(name="buffalo_l")
        app_face.prepare(ctx_id=0, det_size=(640, 640))
        faces = app_face.get(image)
        if len(faces) == 0:
            response_message = "Error no face detected"
            response_status = "Failure"
            return generate_response(response_message , response_status ,'' , '' , '' , [])
        elif len(faces) != 1:
            response_message = "Error more than 1 face detected"
            response_status = "Failure"
            return generate_response(response_message , response_status ,'' , '' , '' , [])

        info_age = faces[0]['age']
        info_gender = faces[0]['gender']

        unique_id = str(uuid.uuid4())
        output_path = f"./outputs/AGE/{unique_id}"
        os.makedirs(output_path)

        if not get_model_state_age():
            unload_model_all()
            model_status = load_model_age()
            print(f"model status =========={model_status}")

        cv2.imwrite(output_path+'/input.png', image)
        aligned_image = run_alignment(output_path+'/input.png')
        copy_aligned_image = aligned_image.copy()

        # The experiment transform resizes to 256x256 itself.
        img_transforms = EXPERIMENT_ARGS['transform']
        input_image = img_transforms(aligned_image)

        target_ages = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
        age_transformers = [AgeTransformer(target_age=age) for age in target_ages]

        images = []
        for age_transformer in age_transformers:
            with torch.no_grad():
                input_image_age = [age_transformer(input_image.cpu()).to('cuda')]
                input_image_age = torch.stack(input_image_age)
                result_tensor = run_on_batch(input_image_age, model_age_slider)[0]
                images.append(tensor2im(result_tensor))

        image_paths_temp = []
        result_paths_list = []
        for idx, result_image in enumerate(images):
            image_temp = cv2.cvtColor(np.array(result_image), cv2.COLOR_RGB2BGR)
            result_path_temp = output_path+"/output_"+str(idx)+".png"
            cv2.imwrite(result_path_temp, image_temp)

            # Public URL: "<web_url>:8001/AGE/<uuid>/output_<idx>.png".
            image_url_temp = '/'.join(result_path_temp.split('/')[-3:])
            image_url_temp = f"{web_url}:8001/{image_url_temp}"
            result_paths_list.append({"age" :str(target_ages[idx]), "image_url" : image_url_temp})
            image_paths_temp.append(result_path_temp)

        # Replace the frame nearest the detected age with the real (aligned) input.
        closest_age = min(target_ages, key=lambda x: abs(x - info_age))
        closest_age = target_ages.index(closest_age)
        os.remove(f"{output_path}/output_{closest_age}.png")
        copy_aligned_image = copy_aligned_image.resize((1024, 1024))
        copy_aligned_image.save(f"{output_path}/output_{closest_age}.png")

        background_tasks.add_task(Age_Slider_Video, unique_id, image_paths_temp)

        os.remove(output_path+'/input.png')

        response_message = "AgeSlider Ran Sucessfully"
        response_status = "Successful"
        return generate_response(response_message , response_status ,unique_id , info_age , ["Female","Male"][info_gender] , result_paths_list)

    except Exception as e:
        if output_path is not None:
            shutil.rmtree(output_path, ignore_errors=True)
        return generate_response(f"Unknown Internal Error : {e}" , "Failure" ,'' , '' , '' , [])
|
|
|
|
|
def Age_Slider_Video(uuid_input , image_files):
    """Build the AgeSlider morph video and GIF (run as a background task).

    ``doMorphing`` writes into a fresh scratch directory under ``./outputs``.
    ``output_combined.mp4`` and ``output.gif`` are then copied into
    ``./outputs/AGE/<uuid_input>/``. The scratch directory is removed in all
    cases; previously a failure leaked it inside the publicly served
    ``./outputs`` tree.

    Args:
        uuid_input: Id of the AgeSlider run whose folder receives the results.
        image_files: Ordered frame image paths to morph between.
    """
    from scripts.morph_video import doMorphing

    directory_path = f'./outputs/{uuid.uuid4()}'
    os.mkdir(directory_path)
    try:
        doMorphing(image_files, 0.3, 20, f"{directory_path}/output")
        video_path = f"{directory_path}/output_combined.mp4"
        gif_path = f"{directory_path}/output.gif"
        shutil.copy(video_path, f"./outputs/AGE/{uuid_input}/output_combined.mp4")
        shutil.copy(gif_path, f"./outputs/AGE/{uuid_input}/output.gif")
    finally:
        shutil.rmtree(directory_path, ignore_errors=True)
|
|
|
|
|
@app.post("/AgeSliderVideo/{uuid_input}")
async def ageslidervideo(uuid_input: str):
    """Wait up to 120 s for an AgeSlider run's morph video and GIF, then return their URLs.

    The files are produced by the ``Age_Slider_Video`` background task. The
    directory is polled every 0.5 s with ``await asyncio.sleep``. The old loop
    never awaited, which blocked the whole event loop (and spun a CPU core)
    for up to two minutes.

    Response metadata is ``[{"video": <url or ''>, "gif": <url or ''>}]``.
    The status is "Successful" only when both files exist. Otherwise, after
    the timeout, the response reports which (if any) was produced.
    """
    try:
        input_path = f"./outputs/AGE/{uuid_input}"
        if not os.path.exists(input_path):
            response_message = "Invalid UUID , file does not exist"
            response_status = "Failure"
            return generate_response(response_message , response_status ,'' , '' , '' , [])

        video_file = input_path+"/output_combined.mp4"
        gif_file = input_path+"/output.gif"
        output_video_path = ''
        output_gif_path = ''
        duration = 120
        poll_interval = 0.5
        start_time = time.time()

        while time.time() - start_time < duration:
            if output_video_path == '' and os.path.exists(video_file):
                output_video_path = f'{web_url}:8001/AGE/{uuid_input}'+"/output_combined.mp4"
            if output_gif_path == '' and os.path.exists(gif_file):
                output_gif_path = f'{web_url}:8001/AGE/{uuid_input}'+"/output.gif"
            if output_video_path != '' and output_gif_path != '':
                response_message = "Video Genrated Sucessfully"
                response_status = "Successful"
                return generate_response(response_message , response_status ,uuid_input , '' , '', [{"video":output_video_path , "gif":output_gif_path}])
            # Yield to the event loop so other requests keep being served.
            await asyncio.sleep(poll_interval)

        metadata = [{"video":output_video_path , "gif":output_gif_path}]
        if output_video_path != '':
            return generate_response("Only Video Genrated" , "Failure" ,uuid_input , '' , '', metadata)
        if output_gif_path != '':
            return generate_response("Only GIF Genrated" , "Failure" ,uuid_input , '' , '', metadata)
        return generate_response("Timeout Reached , Retry Later" , "Failure" ,uuid_input , '' , '', metadata)

    except Exception as e:
        return generate_response(f"Unknown Internal Error : {e}" , "Failure" ,'' , '' , '' , [])
|
|
|
|
|
|
|
|
|
|
|
@app.post("/BG_remove")
async def bg(image: UploadFile = File(...)):
    """Remove the background from an uploaded jpg/jpeg/png image.

    The result is written to ``./outputs/BG/<uuid>/output.png`` and its URL
    is returned in the response metadata.
    """
    global time_last_call
    time_last_call = datetime.now()

    # Defined before the try so the error handler never hits an unbound name.
    output_path = None
    try:
        image_format = image.filename.split(".")[-1].lower()
        if image_format not in ALLOWED_IMAGE_FORMATS:
            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
            response_status = "Failure"
            return generate_response(response_message , response_status ,'' , '' , '' , [])

        contents = await image.read()
        nparr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            response_message = "Image is empty"
            response_status = "Failure"
            return generate_response(response_message , response_status ,'' , '' , '' , [])

        unique_id = str(uuid.uuid4())
        output_path = f"./outputs/BG/{unique_id}"

        if not get_model_state_bg():
            unload_model_all()
            model_status = load_model_bg()
            print(f"model status =========={model_status}")

        output_image = bg_model.BG_remove(image)

        os.makedirs(output_path)
        output_image_path = output_path+'/output.png'
        cv2.imwrite(output_image_path, output_image)

        response_message = "BG Removed Sucessfully"
        response_status = "Successful"
        return generate_response(response_message , response_status ,unique_id , '' , '', [f"{web_url}:8001/BG/{unique_id}/output.png"])

    except Exception as e:
        # ignore_errors: the directory may not exist yet if the failure came
        # before os.makedirs; raising here would turn the error into a 500.
        if output_path is not None:
            shutil.rmtree(output_path, ignore_errors=True)
        return generate_response(f"Unknown Internal Error : {e}" , "Failure" ,'' , '' , '' , [])
|
|
|
|
|
|
|
|
@app.post("/CodeFormer")
async def codeformer(image: UploadFile = File(...)):
    """Restore faces in an uploaded image with CodeFormer.

    The image is first downscaled (keeping the aspect ratio) to at most about
    500*500 pixels of area. Faces are then detected, aligned, restored with
    ``model_code_former`` at fidelity 0.5 and pasted back onto the background,
    which is upscaled with ``model_upsampler`` when enabled. The result is
    written to ``./outputs/CF/<uuid>/output.png`` and its URL returned.
    """
    global time_last_call
    time_last_call = datetime.now()

    # Defined before the try so the error handler never hits an unbound name.
    output_path = None
    try:
        image_format = image.filename.split(".")[-1].lower()
        if image_format not in ALLOWED_IMAGE_FORMATS:
            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
            response_status = "Failure"
            return generate_response(response_message , response_status ,'' , '' , '' , [])

        contents = await image.read()
        nparr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            response_message = "Image is empty"
            response_status = "Failure"
            return generate_response(response_message , response_status ,'' , '' , '' , [])

        # Cap the working area at ~500x500 pixels, preserving the aspect ratio.
        height = image.shape[0]
        width = image.shape[1]
        if (height*width) > (500*500):
            aspect_ratio = width / height
            new_height = int(((500*500) / aspect_ratio) ** 0.5)
            new_width = int(aspect_ratio * new_height)
            image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)

        unique_id = str(uuid.uuid4())
        output_path = f"./outputs/CF/{unique_id}"

        if not get_model_state_cf():
            unload_model_all()
            model_status = load_model_cf()
            print(f"model status =========={model_status}")

        # Fixed inference settings.
        has_aligned = False
        only_center_face = False
        draw_box = False
        detection_model = "retinaface_resnet50"
        background_enhance = True
        face_upsample = True
        upscale = 2
        codeformer_fidelity = 0.5
        img = image

        # Clamp the upscale factor for large inputs.
        upscale = int(upscale)
        if upscale > 4:
            upscale = 4
        if upscale > 2 and max(img.shape[:2]) > 1000:
            upscale = 2
        if max(img.shape[:2]) > 1500:
            upscale = 1
            background_enhance = False
            face_upsample = False

        face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model=detection_model,
            save_ext="png",
            use_parse=True,
            device=device,
        )
        bg_upsampler = model_upsampler if background_enhance else None
        face_upsampler = model_upsampler if face_upsample else None

        if has_aligned:
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
            face_helper.is_gray = is_gray(img, threshold=5)
            if face_helper.is_gray:
                print('\tgrayscale input: True')
            face_helper.cropped_faces = [img]
        else:
            face_helper.read_image(img)
            num_det_faces = face_helper.get_face_landmarks_5(
                only_center_face=only_center_face, resize=640, eye_dist_threshold=5
            )
            print(f'\tdetect {num_det_faces} faces')
            face_helper.align_warp_face()

        for cropped_face in face_helper.cropped_faces:
            # BGR uint8 [0, 255] -> RGB float tensor normalised to [-1, 1].
            cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
            try:
                with torch.no_grad():
                    output = model_code_former(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
                    restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                del output
                torch.cuda.empty_cache()
            except RuntimeError as error:
                # Fall back to the unrestored crop if inference fails.
                print(f"Failed inference for CodeFormer: {error}")
                restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
            restored_face = restored_face.astype("uint8")
            face_helper.add_restored_face(restored_face)

        if not has_aligned:
            if bg_upsampler is not None:
                bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
            else:
                bg_img = None
            face_helper.get_inverse_affine(None)
            if face_upsample and face_upsampler is not None:
                restored_img = face_helper.paste_faces_to_input_image(
                    upsample_img=bg_img,
                    draw_box=draw_box,
                    face_upsampler=face_upsampler,
                )
            else:
                restored_img = face_helper.paste_faces_to_input_image(
                    upsample_img=bg_img, draw_box=draw_box
                )

        output_image_path = output_path+'/output.png'
        os.makedirs(output_path)
        cv2.imwrite(output_image_path, restored_img)

        response_message = "CodeFormer Image Sucessfully"
        response_status = "Successful"
        return generate_response(response_message , response_status ,unique_id , '' , '', [f"{web_url}:8001/CF/{unique_id}/output.png"])

    except Exception as e:
        # ignore_errors: the directory only exists if the failure came after
        # os.makedirs; raising here would turn the error into a 500.
        if output_path is not None:
            shutil.rmtree(output_path, ignore_errors=True)
        return generate_response(f"Unknown Internal Error : {e}" , "Failure" ,'' , '' , '' , [])
|
|
|
|
|
@app.post("/FaceSwap_single_image")
async def Swap_single_image(image: UploadFile = File(...)):
    """Swap the two faces detected in a single uploaded image.

    The upload must have an extension listed in ``ALLOWED_IMAGE_FORMATS``,
    must decode as an image, and must contain exactly two faces according to
    ``app_model``. On success the swapped image is written to
    ``./outputs/SWAP/<uuid>/output.png`` and a "Successful" response carrying
    its URL is returned. Validation failures and unexpected exceptions are
    both reported as a "Failure" response built by ``generate_response``.
    """
    global time_last_call
    time_last_call = datetime.now()

    # Bound before ``try`` so the cleanup in ``except`` cannot raise
    # UnboundLocalError when the failure happens before a UUID is allocated.
    output_path = None

    try:
        # ``UploadFile.filename`` may be None; treat that as an unsupported format.
        image_format = (image.filename or "").split(".")[-1].lower()
        if image_format not in ALLOWED_IMAGE_FORMATS:
            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        contents = await image.read()
        # cv2.imdecode asserts on an empty buffer instead of returning None,
        # so an empty upload is mapped to None explicitly.
        image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR) if contents else None
        if image is None:
            response_message = "Image is empty"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        unique_id = str(uuid.uuid4())
        output_path = f"./outputs/SWAP/{unique_id}"

        # Load the swap models on demand, unloading whatever else is loaded first.
        if not get_model_state_swap():
            unload_model_all()
            model_status = load_model_swap()
            print(f"model status =========={model_status}")

        faces = app_model.get(image)
        if len(faces) == 0:
            response_message = "Error no face detected"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])
        if len(faces) == 1:
            # Previously fell into the "more than 2" branch with a wrong message.
            response_message = "Error only 1 face detected"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])
        if len(faces) > 2:
            response_message = "Error more than 2 face detected"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        face1 = faces[0]
        face2 = faces[1]

        # Run the swapper once in each direction so both faces are exchanged.
        image = swap_model.get(image, face1, face2, paste_back=True)
        image = swap_model.get(image, face2, face1, paste_back=True)

        os.makedirs(output_path)
        output_image_path = output_path + '/output.png'
        cv2.imwrite(output_image_path, image)

        response_message = "Face Swapped Sucessfully"
        response_status = "Successful"
        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/SWAP/{unique_id}/output.png"])

    except Exception as e:
        # The directory may not exist yet (failure before makedirs); cleanup
        # must never mask the original error.
        if output_path is not None:
            shutil.rmtree(output_path, ignore_errors=True)
        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
|
|
|
|
|
|
|
|
|
|
|
@app.post("/FaceSwap_two_images")
async def swap_two_image(image1: UploadFile = File(...), image2: UploadFile = File(...)):
    """Swap the face from ``image2`` into ``image1``.

    Each upload must have an extension listed in ``ALLOWED_IMAGE_FORMATS``,
    must decode as an image, and must contain exactly one face according to
    ``app_model``. The swapper is applied to the first image and the result is
    written to ``./outputs/SWAP/<uuid>/output.png``. Errors are reported as a
    "Failure" response; image numbers in error messages are 1-based.
    """
    global time_last_call
    time_last_call = datetime.now()

    # Bound before ``try`` so the ``except`` cleanup never sees an unbound name.
    output_path = None

    try:
        images = []
        for idx, upload in enumerate([image1, image2], start=1):
            # ``UploadFile.filename`` may be None; treat that as unsupported.
            image_format = (upload.filename or "").split(".")[-1].lower()
            if image_format not in ALLOWED_IMAGE_FORMATS:
                response_message = f"Unsupported image{idx} format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
                response_status = "Failure"
                return generate_response(response_message, response_status, '', '', '', [])

            contents = await upload.read()
            # cv2.imdecode asserts on an empty buffer, so map empty uploads to None.
            decoded = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR) if contents else None
            if decoded is None:
                response_message = f"Image{idx} is empty"
                response_status = "Failure"
                return generate_response(response_message, response_status, '', '', '', [])
            images.append(decoded)

        unique_id = str(uuid.uuid4())
        output_path = f"./outputs/SWAP/{unique_id}"

        # Load the swap models on demand, unloading whatever else is loaded first.
        if not get_model_state_swap():
            unload_model_all()
            model_status = load_model_swap()
            print(f"model status =========={model_status}")

        faces = [app_model.get(img) for img in images]

        # 1-based numbering, consistent with the format/empty-image messages above.
        for idx, detected in enumerate(faces, start=1):
            if len(detected) == 0:
                response_message = f"Error no face detected in image{idx}"
                response_status = "Failure"
                return generate_response(response_message, response_status, '', '', '', [])
            if len(detected) > 1:
                response_message = f"Error more than 1 face detected in image{idx}"
                response_status = "Failure"
                return generate_response(response_message, response_status, '', '', '', [])

        face1 = faces[0][0]
        face2 = faces[1][0]

        image = swap_model.get(images[0], face1, face2, paste_back=True)

        os.makedirs(output_path)
        output_image_path = output_path + '/output.png'
        cv2.imwrite(output_image_path, image)

        response_message = "Face Swapped Sucessfully"
        response_status = "Successful"
        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/SWAP/{unique_id}/output.png"])

    except Exception as e:
        # Same cleanup policy as the other endpoints; ignore a missing directory.
        if output_path is not None:
            shutil.rmtree(output_path, ignore_errors=True)
        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
|
|
|
|
|
|
|
|
@app.post("/Restore_Images")
async def restore(image: UploadFile = File(...)):
    """Remove scratches from an uploaded photo and colorize the result.

    The upload is validated (extension in ``ALLOWED_IMAGE_FORMATS``, decodable)
    and passed through the scratch detector, the scratch-removal model and
    ``scripts.colorizer.colorize_image``. The result is written to
    ``./outputs/Restore/<uuid>/output.png`` and its URL is returned. Errors are
    reported as a "Failure" response.
    """
    global time_last_call
    time_last_call = datetime.now()

    # Bound before ``try`` so the ``except`` cleanup never sees an unbound name.
    output_path = None

    try:
        # ``UploadFile.filename`` may be None; treat that as unsupported.
        image_format = (image.filename or "").split(".")[-1].lower()
        if image_format not in ALLOWED_IMAGE_FORMATS:
            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        contents = await image.read()
        # cv2.imdecode asserts on an empty buffer, so map empty uploads to None.
        image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR) if contents else None
        if image is None:
            response_message = "Image is empty"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        unique_id = str(uuid.uuid4())
        output_path = f"./outputs/Restore/{unique_id}"
        os.makedirs(output_path)

        # The decoded image is round-tripped through a temporary PNG to obtain
        # the PIL image the scratch detector consumes.
        input_image_path = output_path + '/input.png'
        cv2.imwrite(input_image_path, image)

        # Load the restoration models on demand, unloading whatever else is loaded.
        if not get_model_state_rest():
            unload_model_all()
            model_status = load_model_rest()
            print(f"model status =========={model_status}")

        image = PIL.Image.open(input_image_path)
        # PIL opens lazily; read the pixels now, before the file is deleted.
        image.load()
        os.remove(input_image_path)

        transformed, mask = model_scratches_remove_detector.process(image)

        img_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        )
        mask_transform = transforms.ToTensor()

        # Optionally grow the detected scratch mask with a 3x3 dilation.
        if model_scratches_remove_options.mask_dilation != 0:
            kernel = np.ones((3, 3), np.uint8)
            mask = np.array(mask)
            mask = cv2.dilate(mask, kernel, iterations=model_scratches_remove_options.mask_dilation)
            mask = PIL.Image.fromarray(mask.astype('uint8'))

        transformed = irregular_hole_synthesize(transformed, mask)
        # Keep a single mask channel and add a batch dimension.
        mask = mask_transform(mask)[:1, :, :].unsqueeze(0)
        transformed = img_transform(transformed).unsqueeze(0)

        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            generated = model_scratches_remove.inference(transformed, mask)

        # (x + 1) / 2 undoes the (0.5, 0.5) normalisation applied by img_transform.
        tensor_restored = (generated.data.cpu() + 1.0) / 2.0
        image_to_show = tensor_restored.squeeze().numpy().transpose((1, 2, 0))
        # Scale to uint8 and reverse the channel order for cv2.imwrite.
        image_to_show = (image_to_show * 255).astype(np.uint8)[:, :, ::-1]

        from scripts.colorizer import colorize_image
        image_to_show = colorize_image(image_to_show)

        output_image_path = output_path + '/output.png'
        cv2.imwrite(output_image_path, image_to_show)

        response_message = "Restred Image Sucessfully"
        response_status = "Successful"
        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/Restore/{unique_id}/output.png"])

    except Exception as e:
        # Cleanup must never mask the original error.
        if output_path is not None:
            shutil.rmtree(output_path, ignore_errors=True)
        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
|
|
|
|
|
|
|
|
@app.post("/Sad_Talker")
async def sadtalker(image: UploadFile = File(...), audio: UploadFile = File(...)):
    """Generate a talking-head video from one face image and one audio clip.

    The image must have an allowed extension, decode, and contain exactly one
    face. The audio must have an allowed extension and be non-empty. On success
    the generated video is re-encoded with ffmpeg (libx264/aac) to
    ``./outputs/Sad_Talker/<uuid>/final.mp4``. The response carries its URL
    together with the detected age and gender. On failure a "Failure" response
    is returned and the per-request working directory is removed.
    """
    global time_last_call
    time_last_call = datetime.now()

    # Bound before ``try`` so the ``except`` cleanup never sees an unbound name.
    output_path = None

    try:
        # ``UploadFile.filename`` may be None; treat that as unsupported.
        image_format = (image.filename or "").split(".")[-1].lower()
        if image_format not in ALLOWED_IMAGE_FORMATS:
            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        contents = await image.read()
        # cv2.imdecode asserts on an empty buffer, so map empty uploads to None.
        image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR) if contents else None
        if image is None:
            response_message = "Image is empty"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        audio_format = (audio.filename or "").split(".")[-1].lower()
        if audio_format not in ALLOWED_AUDIO_FORMATS:
            response_message = f"Unsupported audio format. Supported formats: {', '.join(ALLOWED_AUDIO_FORMATS)}"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        audio_contents = await audio.read()
        if len(audio_contents) == 0:
            response_message = "Audio file is empty"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        unique_id = str(uuid.uuid4())
        output_path = f"./outputs/Sad_Talker/{unique_id}"
        os.makedirs(output_path)

        image_input_path = output_path + '/input.png'
        cv2.imwrite(image_input_path, image)

        # Decode the uploaded audio with pydub and re-export it into the work dir.
        audio_segment = AudioSegment.from_file(io.BytesIO(audio_contents), format=audio_format)
        audio_input_path = os.path.join(output_path, f"audio.{audio_format}")
        audio_segment.export(audio_input_path, format=audio_format)

        # NOTE(review): a FaceAnalysis instance is built and prepared on every
        # request; caching it would conflict with unload_model_all(), so it is
        # left as is.
        app_face = FaceAnalysis(name="buffalo_l")
        app_face.prepare(ctx_id=0, det_size=(640, 640))

        faces = app_face.get(image)
        if len(faces) == 0:
            # Rejected request: no UUID is returned, so drop the working files.
            shutil.rmtree(output_path, ignore_errors=True)
            response_message = "Error no face detected"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])
        if len(faces) != 1:
            shutil.rmtree(output_path, ignore_errors=True)
            response_message = "Error more than 1 face detected"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        info_age = faces[0]['age']
        info_gender = faces[0]['gender']

        # Load the talking-head model on demand, unloading whatever else is loaded.
        if not get_model_state_talk():
            unload_model_all()
            model_status = load_model_talk()
            print(f"model status =========={model_status}")

        response = sad_talker_model.genrate_video(image_input_path, audio_input_path, output_path, True)

        os.remove(image_input_path)
        os.remove(audio_input_path)

        if not response:
            # Generation failed and no UUID is returned: remove the leftovers.
            shutil.rmtree(output_path, ignore_errors=True)
            response_message = "Error Genrating Video, try another image"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        input_file = f"outputs/Sad_Talker/{unique_id}/output.mp4"
        output_file = f"outputs/Sad_Talker/{unique_id}/final.mp4"

        # Re-encode the generator's output to H.264 video with AAC audio.
        subprocess.run(['ffmpeg', '-i', input_file, '-c:v', 'libx264', '-preset', 'slow', '-crf', '23', '-c:a', 'aac', '-b:a', '192k', output_file], check=True)
        os.remove(input_file)

        response_message = "Video Generated Sucessfully"
        response_status = "Successful"
        return generate_response(response_message, response_status, unique_id, info_age, ["Female", "Male"][info_gender], [f"{web_url}:8001/Sad_Talker/{unique_id}/final.mp4"])

    except Exception as e:
        # Cleanup must never mask the original error.
        if output_path is not None:
            shutil.rmtree(output_path, ignore_errors=True)
        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
|
|
|
|
|
|
|
|
@app.post("/SKY_remove")
async def sky(image: UploadFile = File(...)):
    """Run ``sky_model.SKY_remove`` on an uploaded image.

    The upload must have an extension listed in ``ALLOWED_IMAGE_FORMATS`` and
    must decode as an image. The model output is written to
    ``./outputs/SKY/<uuid>/output.png`` and its URL is returned. Errors are
    reported as a "Failure" response.
    """
    global time_last_call
    time_last_call = datetime.now()

    # Bound before ``try`` so the ``except`` cleanup never sees an unbound name.
    output_path = None

    try:
        # ``UploadFile.filename`` may be None; treat that as unsupported.
        image_format = (image.filename or "").split(".")[-1].lower()
        if image_format not in ALLOWED_IMAGE_FORMATS:
            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        contents = await image.read()
        # cv2.imdecode asserts on an empty buffer, so map empty uploads to None.
        image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR) if contents else None
        if image is None:
            response_message = "Image is empty"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        unique_id = str(uuid.uuid4())
        output_path = f"./outputs/SKY/{unique_id}"

        # Load the sky model on demand, unloading whatever else is loaded first.
        if not get_model_state_sky():
            unload_model_all()
            model_status = load_model_sky()
            print(f"model status =========={model_status}")

        output_image = sky_model.SKY_remove(image)

        os.makedirs(output_path)
        output_image_path = output_path + '/output.png'
        cv2.imwrite(output_image_path, output_image)

        response_message = "SKY Removed Sucessfully"
        response_status = "Successful"
        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/SKY/{unique_id}/output.png"])

    except Exception as e:
        # The directory may not exist yet; cleanup must never mask the error.
        if output_path is not None:
            shutil.rmtree(output_path, ignore_errors=True)
        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
|
|
|
|
|
|
|
|
@app.post("/ImageUpscale")
async def upscale(image: UploadFile = File(...)):
    """Upscale an uploaded image with ``model_upscale``.

    Inputs larger than 200*200 pixels are first downscaled (aspect ratio
    preserved) to about that pixel count. The model runs in fp16 on
    ``device``, and the result is written to
    ``./outputs/Upscale/<uuid>/output.png``. Errors are reported as a
    "Failure" response.
    """
    global time_last_call
    time_last_call = datetime.now()

    # Bound before ``try`` so the ``except`` cleanup never sees an unbound name.
    output_path = None
    max_pixels = 200 * 200

    try:
        # ``UploadFile.filename`` may be None; treat that as unsupported.
        image_format = (image.filename or "").split(".")[-1].lower()
        if image_format not in ALLOWED_IMAGE_FORMATS:
            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        contents = await image.read()
        # cv2.imdecode asserts on an empty buffer, so map empty uploads to None.
        image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR) if contents else None
        if image is None:
            response_message = "Image is empty"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        height = image.shape[0]
        width = image.shape[1]

        if height * width > max_pixels:
            aspect_ratio = width / height
            # max(1, ...) stops extreme aspect ratios from collapsing a side to
            # 0 pixels, which cv2.resize rejects.
            new_height = max(1, int((max_pixels / aspect_ratio) ** 0.5))
            new_width = max(1, int(aspect_ratio * new_height))
            image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)

        unique_id = str(uuid.uuid4())
        output_path = f"./outputs/Upscale/{unique_id}"

        # Load the upscaler on demand, unloading whatever else is loaded first.
        if not get_model_state_up():
            unload_model_all()
            model_status = load_model_up()
            print(f"model status =========={model_status}")

        # HWC BGR uint8 -> CHW RGB float in [0, 1], batched, fp16 on device.
        image = image * 1.0 / 255
        image = torch.from_numpy(np.transpose(image[:, :, [2, 1, 0]], (2, 0, 1))).float()
        img_LR = image.unsqueeze(0).to(device).half()

        # Inference only: avoid building an autograd graph.
        with torch.no_grad():
            output = model_upscale(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()

        # CHW RGB in [0, 1] -> HWC BGR uint8 for cv2.imwrite.
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        image = (output * 255.0).round().astype(np.uint8)

        output_image_path = output_path + '/output.png'
        os.makedirs(output_path)
        cv2.imwrite(output_image_path, image)

        response_message = "Upscaled Image Sucessfully"
        response_status = "Successful"
        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/Upscale/{unique_id}/output.png"])

    except Exception as e:
        # The directory may not exist yet; cleanup must never mask the error.
        if output_path is not None:
            shutil.rmtree(output_path, ignore_errors=True)
        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
|
|
|
|
|
@app.post("/ImageDenoise")
async def deniose(image: UploadFile = File(...)):
    """Run an uploaded image through ``model_esrgan.predict``.

    Inputs larger than 500*500 pixels are first downscaled (aspect ratio
    preserved) to about that pixel count. The result is written to
    ``./outputs/Denoise/<uuid>/output.png`` and its URL is returned. Errors
    are reported as a "Failure" response.
    """
    global time_last_call
    time_last_call = datetime.now()

    # Bound before ``try`` so the ``except`` cleanup never sees an unbound name.
    output_path = None
    max_pixels = 500 * 500

    try:
        # ``UploadFile.filename`` may be None; treat that as unsupported.
        image_format = (image.filename or "").split(".")[-1].lower()
        if image_format not in ALLOWED_IMAGE_FORMATS:
            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        contents = await image.read()
        # cv2.imdecode asserts on an empty buffer, so map empty uploads to None.
        image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR) if contents else None
        if image is None:
            response_message = "Image is empty"
            response_status = "Failure"
            return generate_response(response_message, response_status, '', '', '', [])

        height = image.shape[0]
        width = image.shape[1]

        if height * width > max_pixels:
            aspect_ratio = width / height
            # max(1, ...) stops extreme aspect ratios from collapsing a side to
            # 0 pixels, which cv2.resize rejects.
            new_height = max(1, int((max_pixels / aspect_ratio) ** 0.5))
            new_width = max(1, int(aspect_ratio * new_height))
            image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)

        unique_id = str(uuid.uuid4())
        output_path = f"./outputs/Denoise/{unique_id}"

        # Load the model on demand, unloading whatever else is loaded first.
        # ``model_esrgan`` is only read below, so no ``global`` statement is needed.
        if not get_model_state_esrgan():
            unload_model_all()
            model_status = load_model_esrgan()
            print(f"model status =========={model_status}")

        # The model takes and returns PIL images; convert BGR <-> RGB around it.
        image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        result_pil = model_esrgan.predict(image_pil)
        result_np = np.array(result_pil.convert('RGB'))
        image = cv2.cvtColor(result_np, cv2.COLOR_RGB2BGR)

        output_image_path = output_path + '/output.png'
        os.makedirs(output_path)
        cv2.imwrite(output_image_path, image)

        # Previously reported "Upscaled" on the denoise endpoint.
        response_message = "Denoised Image Successfully"
        response_status = "Successful"
        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/Denoise/{unique_id}/output.png"])

    except Exception as e:
        # The directory may not exist yet; cleanup must never mask the error.
        if output_path is not None:
            shutil.rmtree(output_path, ignore_errors=True)
        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: start the static file server used for result URLs,
    # run the API under uvicorn, and always tear both down on exit.
    import uvicorn

    # Bound up front so the ``finally`` block cannot hit a NameError when the
    # interrupt arrives before the file server has started.
    http_file_server_process = None
    try:
        http_file_server_process = start_http_file_server()
        uvicorn.run("main:app", host='0.0.0.0', port=8080, workers=2, limit_max_requests=200)
        print("Server stopped, shutting down gracefully...")
    except KeyboardInterrupt:
        print("Received KeyboardInterrupt, shutting down gracefully...")
    finally:
        # Runs on normal exit, Ctrl+C, and any other exception, so the file
        # server child process is never left orphaned.
        if http_file_server_process is not None:
            http_file_server_process.kill()
        shutdown_event()
|
|
|