import os
import sys
import time
import json
import torch
import base64
from PIL import Image
from io import BytesIO
# set CUDA_MODULE_LOADING=LAZY to speed up the serverless function
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
# set SAFETENSORS_FAST_GPU=1 to speed up the serverless function
os.environ["SAFETENSORS_FAST_GPU"] = "1"
sys.path.append(os.path.join(os.path.dirname(__file__), "seg2art"))
from seg2art.sstan_models.pix2pix_model import Pix2PixModel
from seg2art.options.test_options import TestOptions
from seg2art.inference_util import get_artwork
import uvicorn
from fastapi import FastAPI, Form
from fastapi.templating import Jinja2Templates
from fastapi.responses import PlainTextResponse, HTMLResponse
from fastapi.requests import Request
from fastapi.staticfiles import StaticFiles
# declare constants
HOST = "0.0.0.0"
PORT = 7860
# FastAPI
app = FastAPI(root_path=os.path.abspath(os.path.dirname(__file__)))
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
# initialize SEAN model.
opt = TestOptions().parse()
opt.status = "test"
model = Pix2PixModel(opt)
model = model.half() if torch.cuda.is_available() else model
model.eval()
from utils.umap_utils import get_code, load_boundries, modify_code
boundaries = load_boundries()
# per-user style codes, keyed by user_id; store at most max_user_num entries
current_codes = {}
max_user_num = 5
initial_code_path = os.path.join(os.path.dirname(__file__), "static/init_code")
map_location = None if torch.cuda.is_available() else torch.device("cpu")
initial_code = torch.load(initial_code_path, map_location=map_location)
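# encode a PIL image as base64 JPEG bytes for the JSON response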
def EncodeImage(img_pil):
with BytesIO() as buffer:
img_pil.save(buffer, "jpeg")
image_data = base64.b64encode(buffer.getvalue())
return image_data
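# decode a urlsafe-base64 string from the client back into an RGB PIL image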
def DecodeImage(img_pil):
img_pil = BytesIO(base64.urlsafe_b64decode(img_pil))
img_pil = Image.open(img_pil).convert("RGB")
return img_pil
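# parse the raw request body, update this user's style code, and run inference
# expected JSON fields: "user_id", "img" (base64 image), "model" (style domain), and optionally "move_range"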
def process_input(body, random=False):
global current_codes
json_body = json.loads(body.decode("utf-8"))
user_id = json_body["user_id"]
start_time = time.time()
# save current code for different users
if user_id not in current_codes:
current_codes[user_id] = initial_code.clone()
    # keep codes for at most max_user_num users, dropping the oldest other entry first
    if len(current_codes) > max_user_num:
        oldest_user = next(k for k in current_codes if k != user_id)
        del current_codes[oldest_user]
if random:
# randomize code
domain = json_body["model"]
current_codes[user_id] = get_code(domain, boundaries)
# get input
input_img = DecodeImage(json_body["img"])
    try:
        move_range = float(json_body["move_range"])
    except (KeyError, TypeError, ValueError):
        move_range = 0
    # use a large move range when randomizing the style code
    move_range = 3 if random else move_range
# print("Input image was received")
# get selected style
domain = json_body["model"]
if move_range != 0:
modified_code = modify_code(current_codes[user_id], boundaries, domain, move_range)
else:
        modified_code = current_codes[user_id].clone()
# inference
result = get_artwork(model, input_img, modified_code)
print("Time Cost: ", time.time() - start_time)
return EncodeImage(result)
@app.get("/", response_class=HTMLResponse)
def root(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
@app.get("/check_gpu")
async def check_gpu():
return torch.cuda.is_available()
@app.post("/predict")
async def predict(request: Request):
body = await request.body()
result = process_input(body, random=False)
return result
@app.post("/predict_random")
async def predict_random(request: Request):
body = await request.body()
result = process_input(body, random=True)
return result
if __name__ == "__main__":
uvicorn.run(app, host=HOST, port=PORT, log_level="info")
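# Example client request (a minimal sketch; "portrait" and input.jpg are placeholder
# values, and actual domain names depend on the boundaries returned by load_boundries):
#
#   import base64, requests
#   with open("input.jpg", "rb") as f:
#       img_b64 = base64.urlsafe_b64encode(f.read()).decode("utf-8")
#   payload = {"user_id": "demo", "img": img_b64, "model": "portrait", "move_range": "1.0"}
#   r = requests.post("http://localhost:7860/predict", json=payload)
#   result_jpeg = base64.b64decode(r.json())  # the response body is a base64-encoded JPEG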