Spaces:
Running
Running
## All Generation Gradio Interface | |
import uuid | |
import time | |
from .utils import * | |
from .vote_utils import t2s_logger, t2s_multi_logger, i2s_logger, i2s_multi_logger | |
from constants import RGBA_SERVER, LOG_SERVER, TEXT_PROMPT_PATH, PROMPT_NUM | |
# Fixed text-prompt set. A prompt that matches an entry here is treated as
# `offline` (identified by its list position) by the text-to-shape generators.
with open(TEXT_PROMPT_PATH, 'r', encoding='utf-8') as f:
    prompt_list = json.load(f)
# Explicit check rather than `assert`, so it still runs under `python -O`.
if len(prompt_list) != PROMPT_NUM:
    raise ValueError(f"Load {len(prompt_list)} text prompts, but expected {PROMPT_NUM}.")
class State:
    """Mutable per-session record of one model's generation round.

    Tracks the model used, whether the input is image-to-shape (``i2s_mode``),
    whether it came from the fixed prompt/image set (``offline``) and at which
    index, the prompt or image itself, and the rendered video outputs.
    """

    def __init__(self,
                 model_name, i2s_mode=False, offline=False,
                 prompt=None, image=None, offline_idx=None,
                 normal_video=None, rgb_video=None, geo_video=None,
                 evaluted_dims=0):
        # Fresh random hex identifier for every State instance.
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        self.i2s_mode = i2s_mode
        self.offline, self.offline_idx = offline, offline_idx
        self.prompt, self.image = prompt, image
        self.normal_video, self.rgb_video, self.geo_video = normal_video, rgb_video, geo_video
        # Spelling ("evaluted") is kept: the same name is used as a log key.
        self.evaluted_dims = evaluted_dims

    def dict(self):
        """Return the loggable subset of this state as a plain dict.

        ``image`` and the video fields are not included. ``offline_idx`` is
        appended only when ``offline`` is truthy.
        """
        payload = {
            "conv_id": self.conv_id,
            "model_name": self.model_name,
            "i2s_mode": self.i2s_mode,
            "offline": self.offline,
            "prompt": self.prompt,
            "evaluted_dims": self.evaluted_dims,
        }
        if not self.offline:
            return payload
        payload['offline_idx'] = self.offline_idx
        return payload
# class StateI2S: | |
# def __init__(self, model_name): | |
# self.conv_id = uuid.uuid4().hex | |
# self.model_name = model_name | |
# self.image = None | |
# self.output = None | |
# def dict(self): | |
# base = { | |
# "conv_id": self.conv_id, | |
# "model_name": self.model_name, | |
# } | |
# return base | |
def sample_t2s_model(state_0, state_1, model_list):
    """Pick two distinct models at random for a text-to-shape side-by-side round.

    ``model_list`` is passed through ``eval`` and must yield a sequence with at
    least two model names. A ``None`` state is replaced by a new ``State``;
    both states get their ``model_name`` set and ``i2s_mode`` cleared.

    Returns:
        ``(state_0, state_1, model_name_0, model_name_1)``
    """
    # SECURITY NOTE(review): `eval` executes arbitrary code. If `model_list`
    # can be supplied by the client (e.g. a Gradio component value), switch to
    # `ast.literal_eval`.
    candidates = eval(model_list)
    model_name_0, model_name_1 = random.sample(candidates, 2)
    prepared = []
    for st, name in ((state_0, model_name_0), (state_1, model_name_1)):
        if st is None:
            st = State(name, i2s_mode=False)
        st.model_name = name
        st.i2s_mode = False
        prepared.append(st)
    return prepared[0], prepared[1], model_name_0, model_name_1
def sample_i2s_model(state_0, state_1, model_list):
    """Pick two distinct models at random for an image-to-shape side-by-side round.

    ``model_list`` is passed through ``eval`` and must yield a sequence with at
    least two model names. A ``None`` state is replaced by a new ``State``;
    both states get their ``model_name`` set and ``i2s_mode`` enabled.

    Returns:
        ``(state_0, state_1, model_name_0, model_name_1)``
    """
    # SECURITY NOTE(review): `eval` executes arbitrary code. If `model_list`
    # can be supplied by the client, switch to `ast.literal_eval`.
    picked = random.sample(eval(model_list), 2)
    model_name_0, model_name_1 = picked

    def _prepare(st, name):
        # Create the state if missing, then point it at `name` in i2s mode.
        st = State(name, i2s_mode=True) if st is None else st
        st.model_name = name
        st.i2s_mode = True
        return st

    first = _prepare(state_0, model_name_0)
    second = _prepare(state_1, model_name_1)
    return first, second, model_name_0, model_name_1
def sample_prompt(state, model_name):
    """Draw a random prompt from ``prompt_list`` for a single-model round.

    Creates a ``State`` if ``state`` is ``None``, then marks it as an offline
    text-to-shape request pointing at the sampled prompt index.

    Returns:
        ``(state, prompt)``
    """
    if state is None:
        state = State(model_name)
    idx = random.randint(0, PROMPT_NUM - 1)
    prompt = prompt_list[idx]
    state.model_name = model_name
    state.prompt = prompt
    state.i2s_mode = False
    # Plain bool (no trailing comma): a tuple here would be logged verbatim
    # by `State.dict()` as `[true]`.
    state.offline = True
    state.offline_idx = idx
    return state, prompt
def sample_prompt_side_by_side(state_0, state_1, model_name_0, model_name_1):
    """Draw one random prompt from ``prompt_list`` and assign it to both sides.

    The model names are only used when a state is ``None`` and must be
    created; existing states keep their current ``model_name``.

    Returns:
        ``(state_0, state_1, prompt)``
    """
    state_0 = State(model_name_0) if state_0 is None else state_0
    state_1 = State(model_name_1) if state_1 is None else state_1
    chosen = random.randint(0, PROMPT_NUM - 1)
    prompt = prompt_list[chosen]
    for side in (state_0, state_1):
        side.i2s_mode = False
        side.offline = True
        side.offline_idx = chosen
        side.prompt = prompt
    return state_0, state_1, prompt
def sample_image(state, model_name):
    """Pick a random input image (served from ``RGBA_SERVER``) for a single-model round.

    Creates a ``State`` if ``state`` is ``None``, then marks it as an offline
    image-to-shape request pointing at the sampled image index.

    Returns:
        ``(state, img_url)``
    """
    if state is None:
        state = State(model_name)
    idx = random.randint(0, PROMPT_NUM - 1)
    img_url = f"{RGBA_SERVER}/{idx}.png"
    state.model_name = model_name
    state.image = img_url
    state.i2s_mode = True
    # Plain bool (no trailing comma): a tuple here would be logged verbatim
    # by `State.dict()` as `[true]`.
    state.offline = True
    state.offline_idx = idx
    return state, img_url
def sample_image_side_by_side(state_0, state_1, model_name_0, model_name_1):
    """Pick one random input image URL and assign it to both sides.

    The model names are only used when a state is ``None`` and must be
    created; existing states keep their current ``model_name``.

    Returns:
        ``(state_0, state_1, img_url)``
    """
    state_0 = State(model_name_0) if state_0 is None else state_0
    state_1 = State(model_name_1) if state_1 is None else state_1
    chosen = random.randint(0, PROMPT_NUM - 1)
    img_url = f"{RGBA_SERVER}/{chosen}.png"
    for side in (state_0, state_1):
        side.i2s_mode = True
        side.offline = True
        side.offline_idx = chosen
        side.image = img_url
    return state_0, state_1, img_url
def generate_t2s(gen_func, render_func,
                 state,
                 text,
                 model_name,
                 request: gr.Request):
    """Generate a shape from a text prompt with one model, yield it, then log the run.

    If the stripped prompt is an entry of ``prompt_list``, the result is
    obtained with ``gen_func(text, model_name, offline=True, offline_idx=idx)``.
    Otherwise ``gen_func(text, model_name)`` returns a shape that
    ``render_func(shape, model_name)`` turns into a mapping of videos.

    Args:
        gen_func: Generation callable (see the two call forms above).
        render_func: Renderer used only on the online path.
        state: Existing ``State`` or ``None`` to create a fresh one.
        text: User prompt.
        model_name: Model to run.
        request: Gradio request; used to obtain the client IP for logging.

    Yields:
        ``(state, geo_video, normal_video, rgb_video)`` once. ``geo_video`` is
        ``None`` if the returned mapping has no ``'geo'`` entry.

    Raises:
        gr.Error: If the prompt or the model name is empty.
    """
    # `gr.Warning` is a notification function, not an exception class, so
    # `raise gr.Warning(...)` would raise TypeError; `gr.Error` is the
    # exception Gradio shows to the user.
    if not text or not text.strip():
        raise gr.Error("Prompt cannot be empty.")
    if not model_name:
        raise gr.Error("Model name cannot be empty.")
    if state is None:
        state = State(model_name, i2s_mode=False, offline=False)

    text = text.strip()
    ip = get_ip(request)
    t2s_logger.info(f"generate. ip: {ip}")

    state.model_name = model_name
    state.prompt = text
    state.evaluted_dims = 0
    try:
        idx = prompt_list.index(text)
    except ValueError:
        # Free-form prompt: not part of the fixed prompt set.
        state.offline, state.offline_idx = False, None
    else:
        state.offline, state.offline_idx = True, idx

    start_time = time.time()
    timing = {}
    # `offline_idx` may legitimately be 0, so compare with None rather than
    # relying on truthiness.
    if state.offline and state.offline_idx is not None:
        log_type = "offline"
        videos = gen_func(text, model_name, offline=state.offline, offline_idx=state.offline_idx)
    else:
        log_type = "online"
        shape = gen_func(text, model_name)
        generate_time = time.time() - start_time
        videos = render_func(shape, model_name)
        finish_time = time.time()
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(finish_time - start_time - generate_time, 4),
        }

    state.normal_video = videos['normal']
    state.rgb_video = videos['rgb']
    state.geo_video = videos.get('geo')
    # Both paths yield the same four values.
    yield state, state.geo_video, state.normal_video, state.rgb_video

    # Logging runs only if the generator is resumed after the yield.
    finish_tstamp = time.time()
    data = {
        "tstamp": round(finish_tstamp, 4),
        "model": model_name,
        "type": log_type,
        "gen_params": {},
        "state": state.dict(),
        "start": round(start_time, 4),
        "finish": round(finish_tstamp, 4),
        **timing,
        "ip": ip,
    }
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, log_file)
# output_file = f'{IMAGE_DIR}/text2shape/{state.conv_id}.png' | |
# os.makedirs(os.path.dirname(output_file), exist_ok=True) | |
# with open(output_file, 'w') as f: | |
# state.output.save(f, 'PNG') | |
# save_image_file_on_log_server(output_file) | |
def generate_t2s_multi(gen_func, render_func,
                       state_0, state_1,
                       text,
                       model_name_0, model_name_1,
                       request: gr.Request):
    """Generate shapes for one text prompt with two named models, yield them, then log both runs.

    If the stripped prompt is an entry of ``prompt_list``, results come from
    ``gen_func(text, name_0, name_1, offline=True, offline_idx=idx)``.
    Otherwise ``gen_func(text, name_0, name_1)`` returns two shapes that
    ``render_func(shape_0, name_0, shape_1, name_1)`` renders.

    Yields:
        ``(state_0, state_1, geo_0, normal_0, rgb_0, geo_1, normal_1, rgb_1)`` once.

    Raises:
        gr.Error: If the prompt or either model name is empty.
    """
    # `gr.Warning` is a notification function, not an exception class;
    # `gr.Error` is the exception Gradio surfaces to the user.
    if not text or not text.strip():
        raise gr.Error("Prompt cannot be empty.")
    if not model_name_0:
        raise gr.Error("Model name A cannot be empty.")
    if not model_name_1:
        raise gr.Error("Model name B cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=False, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=False, offline=False)

    text = text.strip()
    ip = get_ip(request)
    t2s_multi_logger.info(f"generate. ip: {ip}")

    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.prompt, state_1.prompt = text, text
    state_0.evaluted_dims, state_1.evaluted_dims = 0, 0
    try:
        idx = prompt_list.index(text)
    except ValueError:
        # Free-form prompt: not part of the fixed prompt set.
        offline, idx = False, None
    else:
        offline = True
    state_0.offline, state_1.offline = offline, offline
    state_0.offline_idx, state_1.offline_idx = idx, idx

    start_time = time.time()
    timing = {}
    # `offline_idx` may legitimately be 0, so compare with None.
    if state_0.offline and state_0.offline_idx is not None:
        log_type = "offline"
        videos_0, videos_1 = gen_func(text, model_name_0, model_name_1,
                                      offline=state_0.offline, offline_idx=state_0.offline_idx)
    else:
        log_type = "online"
        shape_0, shape_1 = gen_func(text, model_name_0, model_name_1)
        generate_time = time.time() - start_time
        videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
        finish_time = time.time()
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(finish_time - start_time - generate_time, 4),
        }

    state_0.normal_video, state_0.rgb_video, state_0.geo_video = videos_0['normal'], videos_0['rgb'], videos_0['geo']
    state_1.normal_video, state_1.rgb_video, state_1.geo_video = videos_1['normal'], videos_1['rgb'], videos_1['geo']
    yield state_0, state_1, \
        videos_0['geo'], videos_0['normal'], videos_0['rgb'], \
        videos_1['geo'], videos_1['normal'], videos_1['rgb']

    # Logging runs only if the generator is resumed after the yield.
    finish_tstamp = time.time()
    records = [
        {
            "tstamp": round(finish_tstamp, 4),
            "model": name,
            "type": log_type,
            "gen_params": {},
            "state": st.dict(),
            "start": round(start_time, 4),
            "finish": round(finish_tstamp, 4),
            **timing,
            "ip": ip,
        }
        for name, st in ((model_name_0, state_0), (model_name_1, state_1))
    ]
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")
    for record in records:
        append_json_item_on_log_server(record, log_file)
# for i, state in enumerate([state_0, state_1]): | |
# output_file = f'{IMAGE_DIR}/text2shape/{state.conv_id}.png' | |
# os.makedirs(os.path.dirname(output_file), exist_ok=True) | |
# with open(output_file, 'w') as f: | |
# state.output.save(f, 'PNG') | |
# save_image_file_on_log_server(output_file) | |
def generate_t2s_multi_annoy(gen_func, render_func,
                             state_0, state_1,
                             text,
                             model_name_0, model_name_1,
                             request: gr.Request):
    """Anonymous side-by-side text-to-shape generation; yield results, then log both runs.

    Model names are not validated here. ``gen_func`` receives them and returns
    the names actually used as the last two elements of its result; those
    names replace the inputs for rendering, the states and the log records.

    Yields:
        ``(state_0, state_1, geo_0, normal_0, rgb_0, geo_1, normal_1, rgb_1,
        header_A, header_B)`` once, where the headers are ``gr.Markdown``
        components naming the models.

    Raises:
        gr.Error: If the prompt is empty.
    """
    # `gr.Warning` is a notification function, not an exception class;
    # `gr.Error` is the exception Gradio surfaces to the user.
    if not text or not text.strip():
        raise gr.Error("Prompt cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=False, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=False, offline=False)

    text = text.strip()
    ip = get_ip(request)
    t2s_multi_logger.info(f"generate. ip: {ip}")

    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.prompt, state_1.prompt = text, text
    state_0.evaluted_dims, state_1.evaluted_dims = 0, 0
    try:
        idx = prompt_list.index(text)
    except ValueError:
        # Free-form prompt: not part of the fixed prompt set.
        offline, idx = False, None
    else:
        offline = True
    state_0.offline, state_1.offline = offline, offline
    state_0.offline_idx, state_1.offline_idx = idx, idx

    start_time = time.time()
    timing = {}
    # `offline_idx` may legitimately be 0, so compare with None.
    if state_0.offline and state_0.offline_idx is not None:
        log_type = "offline"
        videos_0, videos_1, model_name_0, model_name_1 = gen_func(
            text, model_name_0, model_name_1,
            i2s_model=False, offline=state_0.offline, offline_idx=state_0.offline_idx)
    else:
        log_type = "online"
        shape_0, shape_1, model_name_0, model_name_1 = gen_func(
            text, model_name_0, model_name_1, i2s_model=False)
        generate_time = time.time() - start_time
        videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
        finish_time = time.time()
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(finish_time - start_time - generate_time, 4),
        }

    # Adopt the model names reported by gen_func.
    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.normal_video, state_0.rgb_video, state_0.geo_video = videos_0['normal'], videos_0['rgb'], videos_0['geo']
    state_1.normal_video, state_1.rgb_video, state_1.geo_video = videos_1['normal'], videos_1['rgb'], videos_1['geo']
    yield state_0, state_1, \
        videos_0['geo'], videos_0['normal'], videos_0['rgb'], \
        videos_1['geo'], videos_1['normal'], videos_1['rgb'], \
        gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")

    # Logging runs only if the generator is resumed after the yield.
    finish_tstamp = time.time()
    records = [
        {
            "tstamp": round(finish_tstamp, 4),
            "model": name,
            "type": log_type,
            "gen_params": {},
            "state": st.dict(),
            "start": round(start_time, 4),
            "finish": round(finish_tstamp, 4),
            **timing,
            "ip": ip,
        }
        for name, st in ((model_name_0, state_0), (model_name_1, state_1))
    ]
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")
    for record in records:
        append_json_item_on_log_server(record, log_file)
# for i, state in enumerate([state_0, state_1]): | |
# output_file = f'{IMAGE_DIR}/text2shape/{state.conv_id}.png' | |
# os.makedirs(os.path.dirname(output_file), exist_ok=True) | |
# with open(output_file, 'w') as f: | |
# state.output.save(f, 'PNG') | |
# save_image_file_on_log_server(output_file) | |
def generate_i2s(gen_func, render_func, state, image, model_name, request: gr.Request):
    """Generate a shape from an input image with one model, yield it, then log the run.

    The offline path calls ``gen_func(image, model_name, offline=..., offline_idx=...)``.
    The online path calls ``gen_func(image, model_name)`` and then
    ``render_func(shape, model_name)``.

    Yields:
        ``(state, geo_video, normal_video, rgb_video)`` once. ``geo_video`` is
        ``None`` if the returned mapping has no ``'geo'`` entry.

    Raises:
        gr.Error: If the image is missing or the model name is empty.
    """
    # `gr.Warning` is a notification function, not an exception class;
    # `gr.Error` is the exception Gradio surfaces to the user.
    if image is None:
        raise gr.Error("Image cannot be empty.")
    if not model_name:
        raise gr.Error("Model name cannot be empty.")
    if state is None:
        state = State(model_name, i2s_mode=True, offline=False)
    ip = get_ip(request)
    i2s_logger.info(f"generate. ip: {ip}")

    state.model_name = model_name
    state.image = image
    state.evaluted_dims = 0

    start_time = time.time()
    timing = {}
    # NOTE(review): `offline`/`offline_idx` are used as left on `state` by an
    # earlier call (e.g. sample_image); they are not re-derived from `image`.
    # `offline_idx` may legitimately be 0, so compare with None.
    if state.offline and state.offline_idx is not None:
        log_type = "offline"
        videos = gen_func(image, model_name, offline=state.offline, offline_idx=state.offline_idx)
    else:
        log_type = "online"
        shape = gen_func(image, model_name)
        generate_time = time.time() - start_time
        videos = render_func(shape, model_name)
        finish_time = time.time()
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(finish_time - start_time - generate_time, 4),
        }

    state.normal_video = videos['normal']
    state.rgb_video = videos['rgb']
    state.geo_video = videos.get('geo')
    # Both paths yield the same four values.
    yield state, state.geo_video, state.normal_video, state.rgb_video

    # Logging runs only if the generator is resumed after the yield.
    finish_tstamp = time.time()
    data = {
        "tstamp": round(finish_tstamp, 4),
        "model": model_name,
        "type": log_type,
        "gen_params": {},
        "state": state.dict(),
        "start": round(start_time, 4),
        "finish": round(finish_tstamp, 4),
        **timing,
        "ip": ip,
    }
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, log_file)
# src_img_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_src.png' | |
# os.makedirs(os.path.dirname(src_img_file), exist_ok=True) | |
# with open(src_img_file, 'w') as f: | |
# state.source_image.save(f, 'PNG') | |
# output_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_out.png' | |
# with open(output_file, 'w') as f: | |
# state.output.save(f, 'PNG') | |
# save_image_file_on_log_server(src_img_file) | |
# save_image_file_on_log_server(output_file) | |
def generate_i2s_multi(gen_func, render_func,
                       state_0, state_1,
                       image,
                       model_name_0, model_name_1,
                       request: gr.Request):
    """Generate shapes for one input image with two named models, yield them, then log both runs.

    The offline path calls ``gen_func(image, name_0, name_1, offline=..., offline_idx=...)``.
    The online path calls ``gen_func(image, name_0, name_1)`` and then
    ``render_func(shape_0, name_0, shape_1, name_1)``.

    Yields:
        ``(state_0, state_1, geo_0, normal_0, rgb_0, geo_1, normal_1, rgb_1,
        header_A, header_B)`` once, where the headers are ``gr.Markdown``
        components naming the models.

    Raises:
        gr.Error: If the image is missing or either model name is empty.
    """
    # `gr.Warning` is a notification function, not an exception class;
    # `gr.Error` is the exception Gradio surfaces to the user.
    if image is None:
        raise gr.Error("Image cannot be empty.")
    if not model_name_0:
        raise gr.Error("Model name A cannot be empty.")
    if not model_name_1:
        raise gr.Error("Model name B cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=True, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=True, offline=False)
    ip = get_ip(request)
    i2s_multi_logger.info(f"generate. ip: {ip}")

    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.image, state_1.image = image, image
    state_0.evaluted_dims, state_1.evaluted_dims = 0, 0

    start_time = time.time()
    timing = {}
    # `offline`/`offline_idx` come from the states as previously set (e.g. by
    # sample_image_side_by_side). `offline_idx` may be 0, so compare with None.
    if state_0.offline and state_0.offline_idx is not None:
        log_type = "offline"
        videos_0, videos_1 = gen_func(image, model_name_0, model_name_1,
                                      offline=state_0.offline, offline_idx=state_0.offline_idx)
    else:
        log_type = "online"
        shape_0, shape_1 = gen_func(image, model_name_0, model_name_1)
        generate_time = time.time() - start_time
        videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
        finish_time = time.time()
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(finish_time - start_time - generate_time, 4),
        }

    state_0.normal_video, state_0.rgb_video, state_0.geo_video = videos_0['normal'], videos_0['rgb'], videos_0['geo']
    state_1.normal_video, state_1.rgb_video, state_1.geo_video = videos_1['normal'], videos_1['rgb'], videos_1['geo']
    # Both paths yield the same ten values, including the two model-name
    # headers. NOTE(review): confirm this matches the outputs wired to this
    # handler in the UI.
    yield state_0, state_1, \
        videos_0['geo'], videos_0['normal'], videos_0['rgb'], \
        videos_1['geo'], videos_1['normal'], videos_1['rgb'], \
        gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")

    # Logging runs only if the generator is resumed after the yield.
    finish_tstamp = time.time()
    records = [
        {
            "tstamp": round(finish_tstamp, 4),
            "model": name,
            "type": log_type,
            "gen_params": {},
            "state": st.dict(),
            "start": round(start_time, 4),
            "finish": round(finish_tstamp, 4),
            **timing,
            "ip": ip,
        }
        for name, st in ((model_name_0, state_0), (model_name_1, state_1))
    ]
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")
    for record in records:
        append_json_item_on_log_server(record, log_file)
# for i, state in enumerate([state_0, state_1]): | |
# src_img_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_src.png' | |
# os.makedirs(os.path.dirname(src_img_file), exist_ok=True) | |
# with open(src_img_file, 'w') as f: | |
# state.source_image.save(f, 'PNG') | |
# output_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_out.png' | |
# with open(output_file, 'w') as f: | |
# state.output.save(f, 'PNG') | |
# save_image_file_on_log_server(src_img_file) | |
# save_image_file_on_log_server(output_file) | |
def generate_i2s_multi_annoy(gen_func, render_func,
                             state_0, state_1,
                             image,
                             model_name_0, model_name_1,
                             request: gr.Request):
    """Anonymous side-by-side image-to-shape generation; yield results, then log both runs.

    Model names are not validated here. ``gen_func`` receives them and returns
    the names actually used as the last two elements of its result; those
    names replace the inputs for rendering, the states and the log records.

    Yields:
        ``(state_0, state_1, geo_0, normal_0, rgb_0, geo_1, normal_1, rgb_1,
        header_A, header_B)`` once, where the headers are ``gr.Markdown``
        components naming the models.

    Raises:
        gr.Error: If the image is missing.
    """
    # `gr.Warning` is a notification function, not an exception class;
    # `gr.Error` is the exception Gradio surfaces to the user.
    if image is None:
        raise gr.Error("Image cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=True, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=True, offline=False)
    ip = get_ip(request)
    i2s_multi_logger.info(f"generate. ip: {ip}")

    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.image, state_1.image = image, image
    state_0.evaluted_dims, state_1.evaluted_dims = 0, 0

    start_time = time.time()
    timing = {}
    # `offline_idx` may legitimately be 0, so compare with None.
    if (state_0.offline and state_0.offline_idx is not None
            and state_1.offline and state_1.offline_idx is not None):
        log_type = "offline"
        videos_0, videos_1, model_name_0, model_name_1 = gen_func(
            image, model_name_0, model_name_1,
            i2s_model=True, offline=state_0.offline, offline_idx=state_0.offline_idx)
    else:
        log_type = "online"
        # gen_func returns (shape_0, shape_1, name_0, name_1) here, matching
        # its offline result shape; the returned names are needed because in
        # anonymous mode the inputs may not identify the models.
        # NOTE(review): confirm against gen_func's online return value.
        shape_0, shape_1, model_name_0, model_name_1 = gen_func(
            image, model_name_0, model_name_1, i2s_model=True)
        generate_time = time.time() - start_time
        videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
        finish_time = time.time()
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(finish_time - start_time - generate_time, 4),
        }

    # Adopt the model names reported by gen_func.
    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.normal_video, state_0.rgb_video, state_0.geo_video = videos_0['normal'], videos_0['rgb'], videos_0['geo']
    state_1.normal_video, state_1.rgb_video, state_1.geo_video = videos_1['normal'], videos_1['rgb'], videos_1['geo']
    yield state_0, state_1, \
        videos_0['geo'], videos_0['normal'], videos_0['rgb'], \
        videos_1['geo'], videos_1['normal'], videos_1['rgb'], \
        gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")

    # Logging runs only if the generator is resumed after the yield.
    finish_tstamp = time.time()
    records = [
        {
            "tstamp": round(finish_tstamp, 4),
            "model": name,
            "type": log_type,
            "gen_params": {},
            "state": st.dict(),
            "start": round(start_time, 4),
            "finish": round(finish_tstamp, 4),
            **timing,
            "ip": ip,
        }
        for name, st in ((model_name_0, state_0), (model_name_1, state_1))
    ]
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")
    for record in records:
        append_json_item_on_log_server(record, log_file)
# for i, state in enumerate([state_0, state_1]): | |
# src_img_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_src.png' | |
# os.makedirs(os.path.dirname(src_img_file), exist_ok=True) | |
# with open(src_img_file, 'w') as f: | |
# state.source_image.save(f, 'PNG') | |
# output_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_out.png' | |
# with open(output_file, 'w') as f: | |
# state.output.save(f, 'PNG') | |
# save_image_file_on_log_server(src_img_file) | |
# save_image_file_on_log_server(output_file) |