# Spaces: Running on T4
import base64
import datetime
import os
import sys
import time
from io import BytesIO
from pathlib import Path

import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
import streamlit as st

from demo_config import HUGGING_FACE, WORKER_URL

# The optimisation code lives in the "wise" directory next to this script. Put that
# directory on sys.path so the project imports below can be resolved.
PACKAGE_PARENT = 'wise'
SCRIPT_DIR = os.path.dirname(
    os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__)))
)
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))

# These imports depend on the sys.path change above, hence their position.
from parameter_optimization.parametric_styletransfer import CONFIG as ST_CONFIG, single_optimize  # noqa: E402
from parameter_optimization.strotss_org import pil_resize_long_edge_to, strotss  # noqa: E402
from helpers import np_to_torch, torch_to_np  # noqa: E402
def retrieve_for_results_from_server():
    """Download the results of the current server task and store them in session state.

    Uses the task id in ``st.session_state['current_server_task_id']`` to fetch:

    * ``WORKER_URL + "/get_vp"``: a NumPy archive containing a ``"vp"`` array, moved to
      CUDA and stored as ``st.session_state["result_vp"]``;
    * ``WORKER_URL + "/get_image"``: an image, converted with ``np_to_torch``, moved to
      CUDA and stored as ``st.session_state["effect_input"]``.

    The task id in session state is cleared whether or not the download succeeds.

    Raises:
        requests.HTTPError: if either request returned an HTTP error status.
    """
    task_id = st.session_state['current_server_task_id']
    vp_res = requests.get(WORKER_URL + "/get_vp", params={"task_id": task_id})
    image_res = requests.get(WORKER_URL + "/get_image", params={"task_id": task_id})

    # Both the success path and the failure path drop the task id.
    st.session_state['current_server_task_id'] = None

    if vp_res.status_code != 200 or image_res.status_code != 200:
        # Report the endpoint that was actually queried. The message previously said
        # "/image_res" and ran the URL and status code together.
        st.warning(f"got status {vp_res.status_code} for {WORKER_URL}/get_vp")
        st.warning(f"got status {image_res.status_code} for {WORKER_URL}/get_image")
        vp_res.raise_for_status()
        image_res.raise_for_status()
        # A non-200 success code (e.g. 204) does not raise. Nothing usable was received.
        return

    # The archive is indexed by key, so it is an .npz. Close it once the array is read.
    with np.load(BytesIO(vp_res.content)) as archive:
        vp = archive["vp"]
    print("received vp from server")
    print("got numpy array", vp.shape)
    vp = torch.from_numpy(vp).cuda()

    image = Image.open(BytesIO(image_res.content))
    print("received image from server")
    image = np_to_torch(np.asarray(image)).cuda()

    st.session_state["effect_input"] = image
    st.session_state["result_vp"] = vp
def monitor_task(progress_placeholder):
    """Poll the worker until the current server task ends, updating ``progress_placeholder``.

    Every 2 seconds, queries ``WORKER_URL + "/get_status"`` for the task id stored in
    ``st.session_state['current_server_task_id']``. The response JSON is read for the
    keys ``"status"``, ``"msg"`` and ``"progress"``.

    * ``"queued"``: shows the queue length and restarts the running-time estimate.
    * running with ``progress == 0.0``: fills the first half of the bar from elapsed
      time, using the original ~80 s estimate for the STROTSS stage.
    * running with ``progress > 0``: maps the reported progress onto the second half.
    * ``"finished"`` with an empty ``msg``: downloads the results via
      :func:`retrieve_for_results_from_server` and returns.
    * any other terminal status, or a terminal status carrying a ``msg``: shows an
      error, clears the task id and calls ``st.stop()``.

    After 3 failed status requests, an error is shown and the function returns
    without results.

    Args:
        progress_placeholder: Streamlit placeholder used for queue info, progress and errors.
    """
    task_id = st.session_state['current_server_task_id']
    started_time = time.time()
    retries = 3
    while True:
        status = requests.get(WORKER_URL + "/get_status", params={"task_id": task_id})
        if status.status_code != 200:
            print("get_status got status_code", status.status_code)
            st.warning(status.content)
            retries -= 1
            if retries == 0:
                # Don't give up silently: tell the user why no result will appear.
                progress_placeholder.error("could not retrieve the task status from the server")
                return
            time.sleep(2)
            continue

        status = status.json()
        print(status)
        state = status["status"]

        if state not in ("running", "queued"):
            # Terminal state: either success or failure.
            if status["msg"] != "":
                print("got error for task", task_id, ":", status["msg"])
                progress_placeholder.error(status["msg"])
                st.session_state['current_server_task_id'] = None
                st.stop()
            if state == "finished":
                retrieve_for_results_from_server()
                return
            # A non-"finished" terminal state with an empty message used to fall
            # through to the progress branch and keep polling forever.
            print("task", task_id, "ended with status", state)
            progress_placeholder.error(f"task ended with status '{state}'")
            st.session_state['current_server_task_id'] = None
            st.stop()

        if state == "queued":
            # The elapsed-time estimate should start when the task starts running.
            started_time = time.time()
            queue_res = requests.get(WORKER_URL + "/queue_length")
            if queue_res.status_code == 200:
                progress_placeholder.write(f"There are {queue_res.json()['length']} tasks in the queue")
        elif status["progress"] == 0.0:
            # No progress is reported yet, which is treated as the STROTSS stage.
            # Estimate it as ~80 s on the first half of the bar.
            progressed = min(0.5 * (time.time() - started_time) / 80.0, 0.5)
            progress_placeholder.progress(progressed)
        else:
            # Worker-reported progress fills the second half of the bar.
            progress_placeholder.progress(min(0.5 + status["progress"] / 2.0, 1.0))
        time.sleep(2)
def optimize_on_server(content, style, result_image_placeholder):
    """Upload a content/style pair to the remote worker and wait for the task to finish.

    Both images must have an aspect ratio (height / width) between 0.5 and 2.0.
    Otherwise an error is shown and the script is stopped via ``st.stop()``.

    The images are resized so their long edge is 1024 px and written to temporary
    JPEG files. They are then POSTed as ``style-image`` / ``content-image`` to
    ``WORKER_URL + "/upload"``. On success the returned task id is stored in
    ``st.session_state['current_server_task_id']`` and :func:`monitor_task` polls it.
    A non-200 upload response is shown as an error and stops the script.

    Args:
        content: PIL image used as the content input.
        style: PIL image used as the style input.
        result_image_placeholder: Streamlit placeholder for errors and progress output.
    """
    import tempfile  # local import keeps this function self-contained

    url = WORKER_URL + "/upload"

    asp_c, asp_s = content.height / content.width, style.height / style.width
    if any(a < 0.5 or a > 2.0 for a in (asp_c, asp_s)):
        result_image_placeholder.error('aspect ratio must be <= 2')
        st.stop()

    content = pil_resize_long_edge_to(content, 1024)
    style = pil_resize_long_edge_to(style, 1024)

    # Give each image its own unique temp file. The old names shared the prefix
    # "content-wise-uploaded" plus a timestamp, so two timestamps that came out equal
    # made the style image overwrite the content image. The names were also
    # hard-coded to /tmp.
    content_fd, content_path = tempfile.mkstemp(prefix="content-wise-uploaded", suffix=".jpg")
    style_fd, style_path = tempfile.mkstemp(prefix="style-wise-uploaded", suffix=".jpg")
    os.close(content_fd)
    os.close(style_fd)
    try:
        content.save(content_path)
        style.save(style_path)
        print("start-optimizing")
        # Close the upload handles once the request is done. They were leaked before.
        with open(style_path, "rb") as style_file, open(content_path, "rb") as content_file:
            files = {'style-image': style_file, "content-image": content_file}
            task_id_res = requests.post(url, files=files)
    finally:
        # Best-effort cleanup: the worker already has its copy of the images.
        for path in (content_path, style_path):
            try:
                os.remove(path)
            except OSError:
                pass

    if task_id_res.status_code != 200:
        result_image_placeholder.error(task_id_res.content)
        st.stop()

    st.session_state['current_server_task_id'] = task_id_res.json()['task_id']
    monitor_task(result_image_placeholder)
def _make_unique_dir(path):
    """Create directory *path* (and any parents) and return the path actually created.

    If *path* already exists, " (2)", " (3)", ... is appended until an unused name is
    found. This can happen, for example, when two runs start within the same second.
    """
    candidate, attempt = path, 1
    while True:
        try:
            os.makedirs(candidate)
            return candidate
        except FileExistsError:
            attempt += 1
            candidate = f"{path} ({attempt})"


def optimize_params(effect, preset, content, style, result_image_placeholder):
    """Optimise effect parameters locally so the effect output matches an NST reference.

    Steps:

    1. Run ``strotss`` on CUDA. Content and style are resized to a 1024 px long edge,
       with ``content_weight=16.0`` and ``space="uniform"``. This produces a reference
       stylisation.
    2. Save the reference (resized to a 720 px long edge) as ``reference.jpg``, and
       ``content`` as ``content.jpg``. Both go into a new directory under ``result/``
       named after the current time.
    3. Set the global ``ST_CONFIG["n_iterations"]`` to 300 (module-level side effect).
       Then call ``single_optimize`` with ``"l1"`` (presumably the loss, per the name),
       reporting iteration progress to a Streamlit progress bar.

    The resulting tensors are stored as ``st.session_state["effect_input"]`` and
    ``st.session_state["result_vp"]`` (moved to CUDA, detached).

    Args:
        effect: effect passed through to ``single_optimize``.
        preset: preset passed through to ``single_optimize``.
        content: PIL content image.
        style: PIL style image.
        result_image_placeholder: Streamlit placeholder for status text and the progress bar.
    """
    result_image_placeholder.text("Executing NST to create reference image..")
    # The timestamp only has one-second resolution. A plain os.makedirs raised
    # FileExistsError when two runs started within the same second.
    base_dir = _make_unique_dir(f"result/{datetime.datetime.now().strftime(r'%Y-%m-%d %H.%Mh %Ss')}")

    reference = strotss(pil_resize_long_edge_to(content, 1024),
                        pil_resize_long_edge_to(style, 1024), content_weight=16.0,
                        device=torch.device("cuda"), space="uniform")
    progress_bar = result_image_placeholder.progress(0.0)

    ref_save_path = os.path.join(base_dir, "reference.jpg")
    content_save_path = os.path.join(base_dir, "content.jpg")
    resize_to = 720
    reference = pil_resize_long_edge_to(reference, resize_to)
    reference.save(ref_save_path)
    content.save(content_save_path)

    ST_CONFIG["n_iterations"] = 300
    # The callback's index range is not visible here. Clamp it so st.progress never
    # gets a value above 1.0, matching the clamping in monitor_task.
    vp, content_img_cuda = single_optimize(
        effect, preset, "l1", content_save_path, ref_save_path,
        write_video=False, base_dir=base_dir,
        iter_callback=lambda i: progress_bar.progress(
            min(float(i) / ST_CONFIG["n_iterations"], 1.0)))
    st.session_state["effect_input"], st.session_state["result_vp"] = content_img_cuda.detach(), vp.cuda().detach()