|
import os |
|
import sys |
|
import time |
|
import json |
|
import torch |
|
import base64 |
|
from PIL import Image |
|
from io import BytesIO |
|
|
|
|
|
# Runtime flags for CUDA / safetensors. They are set here, before the seg2art
# imports below. NOTE(review): torch is already imported above, so these
# presumably only matter because they are read lazily at CUDA/loader
# initialisation — confirm.
os.environ.update(
    {
        "CUDA_MODULE_LOADING": "LAZY",
        "SAFETENSORS_FAST_GPU": "1",
    }
)

# Put the bundled seg2art directory on the import path so modules inside it
# can be imported by their top-level names.
sys.path.append(os.path.join(os.path.dirname(__file__), "seg2art"))
|
from seg2art.sstan_models.pix2pix_model import Pix2PixModel |
|
from seg2art.options.test_options import TestOptions |
|
from seg2art.inference_util import get_artwork |
|
|
|
import uvicorn |
|
from fastapi import FastAPI, Form |
|
from fastapi.templating import Jinja2Templates |
|
from fastapi.responses import PlainTextResponse, HTMLResponse |
|
from fastapi.requests import Request |
|
from fastapi.staticfiles import StaticFiles |
|
|
|
|
|
|
|
# Address uvicorn binds to when this file is run directly.
HOST = "0.0.0.0"
PORT = 7860

# Resolve bundled asset directories relative to this file (as initial_code_path
# does) so the server works no matter which directory it is launched from.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))

# root_path is the URL prefix added by a path-rewriting reverse proxy, not a
# filesystem location; passing the source directory there corrupts generated
# URLs (url_for, OpenAPI "servers"). The default (no prefix) is used instead.
app = FastAPI()

# Serve files under <app dir>/static at the /static URL path.
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(_APP_DIR, "static")),
    name="static",
)

# Jinja2 templates (e.g. index.html for the landing page) from <app dir>/templates.
templates = Jinja2Templates(directory=os.path.join(_APP_DIR, "templates"))
|
|
|
|
|
|
|
# Build the seg2art generator once, at import time, and share it across requests.
# NOTE(review): TestOptions().parse() presumably parses command-line flags —
# confirm it tolerates the argv the server is launched with.
opt = TestOptions().parse()
opt.status = "test"

model = Pix2PixModel(opt)
if torch.cuda.is_available():
    # Half precision only when CUDA is available; on CPU no dtype conversion is applied.
    model = model.half()
model.eval()
|
|
|
|
|
from utils.umap_utils import get_code, load_boundries, modify_code |
|
|
|
# Boundary data handed to get_code / modify_code by process_input.
boundaries = load_boundries()

# Per-user codes, keyed by the request's "user_id"; populated and mutated by
# process_input. (A `global` statement at module scope is a no-op, so none is used.)
current_codes = {}
# Cap that process_input uses when trimming cached per-user state.
max_user_num = 5

# Starting code cloned for every new user. On hosts without CUDA, any CUDA
# tensors stored in the file are remapped to the CPU so loading does not fail.
initial_code_path = os.path.join(os.path.dirname(__file__), "static/init_code")
initial_code = torch.load(
    initial_code_path,
    map_location=None if torch.cuda.is_available() else torch.device("cpu"),
)
|
|
|
|
|
def EncodeImage(img_pil):
    """Serialize an image as JPEG and return the base64 encoding.

    Args:
        img_pil: Object exposing PIL's ``save(fp, format)`` (normally a PIL image).

    Returns:
        bytes: standard-alphabet base64 of the JPEG data (not a ``str``).
    """
    with BytesIO() as stream:
        img_pil.save(stream, format="jpeg")
        jpeg_bytes = stream.getvalue()
    return base64.b64encode(jpeg_bytes)
|
|
|
|
|
def DecodeImage(img_pil):
    """Decode a base64 image payload into an RGB PIL image.

    Args:
        img_pil: Base64 text/bytes of an encoded image (despite the name, not a
            PIL image). Decoded with ``urlsafe_b64decode``.

    Returns:
        PIL.Image.Image converted to RGB mode.
    """
    raw_bytes = base64.urlsafe_b64decode(img_pil)
    return Image.open(BytesIO(raw_bytes)).convert("RGB")
|
|
|
|
|
def _parse_move_range(json_body):
    """Return the request's ``move_range`` as a float, or 0 if it is missing or malformed."""
    try:
        return float(json_body["move_range"])
    except (KeyError, TypeError, ValueError):
        # An absent or unparseable slider value means "no edit".
        return 0


def _get_user_code(user_id):
    """Return the cached code for ``user_id``, creating it from ``initial_code`` on first use.

    ``current_codes`` is kept in least-recently-used order (dicts preserve
    insertion order). Once it holds more than ``max_user_num`` users, the
    oldest entries are evicted. ``user_id`` itself is never evicted.
    """
    if user_id in current_codes:
        # Re-insert to mark this user as most recently used.
        current_codes[user_id] = current_codes.pop(user_id)
    else:
        current_codes[user_id] = initial_code.clone()

    while len(current_codes) > max_user_num:
        oldest = next(iter(current_codes))
        if oldest == user_id:
            break
        del current_codes[oldest]
    return current_codes[user_id]


def process_input(body, random=False):
    """Run one seg2art request and return the generated image as base64 JPEG.

    Args:
        body: Raw request body. UTF-8 JSON with keys ``user_id``, ``model``,
            ``img`` (base64 image, decoded by ``DecodeImage``) and optionally
            ``move_range`` (number; missing or invalid is treated as 0).
        random: If True, sample a fresh code for ``model`` via ``get_code``,
            store it as the user's code, and edit it with a fixed move range of 3.

    Returns:
        bytes: output of ``EncodeImage`` for the generated artwork.

    Raises:
        KeyError: if ``user_id``, ``model`` or ``img`` is missing from the JSON.
    """
    json_body = json.loads(body.decode("utf-8"))
    user_id = json_body["user_id"]
    start_time = time.perf_counter()

    _get_user_code(user_id)

    domain = json_body["model"]
    if random:
        current_codes[user_id] = get_code(domain, boundaries)

    input_img = DecodeImage(json_body["img"])

    move_range = 3 if random else _parse_move_range(json_body)

    user_code = current_codes[user_id]
    if move_range != 0:
        modified_code = modify_code(user_code, boundaries, domain, move_range)
    else:
        # Pass a copy so the cached per-user code is not shared with get_artwork.
        modified_code = user_code.clone()

    result = get_artwork(model, input_img, modified_code)
    print("Time Cost: ", time.perf_counter() - start_time)
    return EncodeImage(result)
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def root(request: Request):
    """Render the ``index.html`` template as the landing page."""
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
|
|
|
@app.get("/check_gpu")
async def check_gpu():
    """Return whether PyTorch reports an available CUDA device (a JSON boolean)."""
    has_cuda = torch.cuda.is_available()
    return has_cuda
|
|
|
@app.post("/predict")
async def predict(request: Request):
    """Generate artwork from the posted image using the user's current code.

    The raw body is forwarded unparsed to ``process_input`` with ``random=False``.
    That call is synchronous and is made directly from this coroutine, so the
    event loop is blocked until inference finishes.
    """
    raw_payload = await request.body()
    return process_input(raw_payload, random=False)
|
|
|
|
|
@app.post("/predict_random")
async def predict_random(request: Request):
    """Generate artwork after sampling a new code for the requested model.

    Same as ``/predict`` except that ``process_input`` is called with
    ``random=True``. The call is synchronous and blocks the event loop while it runs.
    """
    raw_payload = await request.body()
    return process_input(raw_payload, random=True)
|
|
|
|
|
if __name__ == "__main__":
    # Serve the app directly when this file is executed as a script.
    server_options = {"host": HOST, "port": PORT, "log_level": "info"}
    uvicorn.run(app, **server_options)
|
|