GenAI-Arena / serve /vote_utils.py
tianleliphoebe's picture
fix a bug
340b8c3
raw
history blame
42.3 kB
import datetime
import json
import os
import time
import uuid
from pathlib import Path

import gradio as gr
import imageio
import regex as re

from .utils import *
from .log_utils import build_logger
from .constants import IMAGE_DIR, VIDEO_DIR
# Module-level loggers, one per task (image generation, image editing, video
# generation) and mode (single model vs. multi model). Each is created by
# build_logger from a logger name and a log file name.
ig_logger = build_logger("gradio_web_server_image_generation", "gr_web_image_generation.log") # ig = image generation, loggers for single model direct chat
igm_logger = build_logger("gradio_web_server_image_generation_multi", "gr_web_image_generation_multi.log") # igm = image generation multi, loggers for side-by-side and battle
ie_logger = build_logger("gradio_web_server_image_editing", "gr_web_image_editing.log") # ie = image editing, loggers for single model direct chat
iem_logger = build_logger("gradio_web_server_image_editing_multi", "gr_web_image_editing_multi.log") # iem = image editing multi, loggers for side-by-side and battle
vg_logger = build_logger("gradio_web_server_video_generation", "gr_web_video_generation.log") # vg = video generation, loggers for single model direct chat
vgm_logger = build_logger("gradio_web_server_video_generation_multi", "gr_web_video_generation_multi.log") # vgm = video generation multi, loggers for side-by-side and battle
def vote_last_response_ig(state, vote_type, model_selector, request: gr.Request):
    """Record a vote on a single-model image generation and persist the image.

    Appends one JSON line describing the vote to the conversation log,
    mirrors that record to the log server, then writes ``state.output`` as a
    JPEG under ``IMAGE_DIR/generation`` and uploads it.

    Args:
        state: Generation state; must provide ``dict()``, ``conv_id`` and an
            ``output`` object exposing ``save(fp, format)``.
        vote_type: Label stored under ``"type"`` in the log record.
        model_selector: Value stored under ``"model"`` in the log record.
        request: Gradio request, used only to resolve the client IP.
    """
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    # The video vote handlers create the target directory first; do the same
    # here so a vote cannot fail on a missing folder.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # JPEG data is binary, so the handle must be opened in binary mode.
    with open(output_file, 'wb') as f:
        state.output.save(f, 'JPEG')
    save_image_file_on_log_server(output_file)
def vote_last_response_igm(states, vote_type, model_selectors, request: gr.Request):
    """Record a vote on a multi-model image generation and persist each image.

    Appends one JSON line covering all states to the conversation log,
    mirrors it to the log server, then saves and uploads every state's output
    as a JPEG under ``IMAGE_DIR/generation``.

    Args:
        states: Iterable of generation states; each must provide ``dict()``,
            ``conv_id`` and an ``output`` exposing ``save(fp, format)``.
        vote_type: Label stored under ``"type"`` in the log record.
        model_selectors: Iterable of values stored under ``"models"``.
        request: Gradio request, used only to resolve the client IP.
    """
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": list(model_selectors),
            "states": [x.dict() for x in states],
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    for state in states:
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        # Match the video handlers: make sure the directory exists.
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # JPEG data is binary, so the handle must be opened in binary mode.
        with open(output_file, 'wb') as f:
            state.output.save(f, 'JPEG')
        save_image_file_on_log_server(output_file)
def vote_last_response_ie(state, vote_type, model_selector, request: gr.Request):
    """Record a vote on a single-model image edit and persist both images.

    Appends one JSON line describing the vote to the conversation log,
    mirrors it to the log server, then writes the edited output and the
    source image as JPEGs under ``IMAGE_DIR/edition`` and uploads both.

    Args:
        state: Editing state; must provide ``dict()``, ``conv_id``, and
            ``output`` / ``source_image`` objects exposing ``save(fp, format)``.
        vote_type: Label stored under ``"type"`` in the log record.
        model_selector: Value stored under ``"model"`` in the log record.
        request: Gradio request, used only to resolve the client IP.
    """
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg'
    source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg'
    # Match the video handlers: make sure the directory exists.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # JPEG data is binary, so both handles must be opened in binary mode.
    with open(output_file, 'wb') as f:
        state.output.save(f, 'JPEG')
    with open(source_file, 'wb') as sf:
        state.source_image.save(sf, 'JPEG')
    save_image_file_on_log_server(output_file)
    save_image_file_on_log_server(source_file)
def vote_last_response_iem(states, vote_type, model_selectors, request: gr.Request):
    """Record a vote on a multi-model image edit and persist every image pair.

    Appends one JSON line covering all states to the conversation log,
    mirrors it to the log server, then for each state saves and uploads the
    edited output and the source image as JPEGs under ``IMAGE_DIR/edition``.

    Args:
        states: Iterable of editing states; each must provide ``dict()``,
            ``conv_id``, and ``output`` / ``source_image`` objects exposing
            ``save(fp, format)``.
        vote_type: Label stored under ``"type"`` in the log record.
        model_selectors: Iterable of values stored under ``"models"``.
        request: Gradio request, used only to resolve the client IP.
    """
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": list(model_selectors),
            "states": [x.dict() for x in states],
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    for state in states:
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg'
        source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg'
        # Match the video handlers: make sure the directory exists.
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # JPEG data is binary, so both handles must be opened in binary mode.
        with open(output_file, 'wb') as f:
            state.output.save(f, 'JPEG')
        with open(source_file, 'wb') as sf:
            state.source_image.save(sf, 'JPEG')
        save_image_file_on_log_server(output_file)
        save_image_file_on_log_server(source_file)
def vote_last_response_vg(state, vote_type, model_selector, request: gr.Request):
    """Log a vote for a single-model video generation and store the clip.

    One JSON line is appended to the conversation log and mirrored to the log
    server. The video is then written to ``VIDEO_DIR/generation``: when the
    model name starts with ``'fal'`` the output is fetched with
    ``requests.get``, otherwise it is encoded with ``imageio.mimwrite``.
    The resulting file is uploaded afterwards.

    Args:
        state: Video state providing ``dict()``, ``conv_id``, ``model_name``
            and ``output``.
        vote_type: Label stored under ``"type"``.
        model_selector: Value stored under ``"model"``.
        request: Gradio request, used only to resolve the client IP.
    """
    with open(get_conv_log_filename(), "a") as log_handle:
        data = dict(
            tstamp=round(time.time(), 4),
            type=vote_type,
            model=model_selector,
            state=state.dict(),
            ip=get_ip(request),
        )
        log_handle.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    video_path = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(video_path), exist_ok=True)

    if not state.model_name.startswith('fal'):
        # Locally produced frames: report their shape, then encode to mp4.
        print("======== video shape: ========")
        print(state.output.shape)
        imageio.mimwrite(video_path, state.output, fps=8, quality=9)
    else:
        # Here state.output is passed to requests.get, i.e. used as a URL.
        response = requests.get(state.output)
        with open(video_path, 'wb') as video_handle:
            video_handle.write(response.content)

    save_video_file_on_log_server(video_path)
def vote_last_response_vgm(states, vote_type, model_selectors, request: gr.Request):
    """Log a vote for a multi-model video generation and store every clip.

    One JSON line covering all states is appended to the conversation log and
    mirrored to the log server. Each state's video is then written to
    ``VIDEO_DIR/generation`` (downloaded when the model name starts with
    ``'fal'``, encoded with ``imageio.mimwrite`` otherwise) and uploaded.

    Args:
        states: Iterable of video states providing ``dict()``, ``conv_id``,
            ``model_name`` and ``output``.
        vote_type: Label stored under ``"type"``.
        model_selectors: Iterable of values stored under ``"models"``.
        request: Gradio request, used only to resolve the client IP.
    """
    with open(get_conv_log_filename(), "a") as log_handle:
        data = dict(
            tstamp=round(time.time(), 4),
            type=vote_type,
            models=[selector for selector in model_selectors],
            states=[entry.dict() for entry in states],
            ip=get_ip(request),
        )
        log_handle.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    for entry in states:
        video_path = f'{VIDEO_DIR}/generation/{entry.conv_id}.mp4'
        os.makedirs(os.path.dirname(video_path), exist_ok=True)
        if not entry.model_name.startswith('fal'):
            # Locally produced frames: report their shape, then encode.
            print("======== video shape: ========")
            print(entry.output.shape)
            imageio.mimwrite(video_path, entry.output, fps=8, quality=9)
        else:
            # Here entry.output is passed to requests.get, i.e. used as a URL.
            response = requests.get(entry.output)
            with open(video_path, 'wb') as video_handle:
                video_handle.write(response.content)
        save_video_file_on_log_server(video_path)
## Image Generation (IG) Single Model Direct Chat
def upvote_last_response_ig(state, model_selector, request: gr.Request):
    """Handle an upvote on a single-model image generation.

    Logs the client IP, records the vote via ``vote_last_response_ig`` and
    returns an empty string followed by three ``disable_btn`` values.
    """
    client_ip = get_ip(request)
    ig_logger.info(f"upvote. ip: {client_ip}")
    vote_last_response_ig(state, "upvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def downvote_last_response_ig(state, model_selector, request: gr.Request):
    """Handle a downvote on a single-model image generation.

    Logs the client IP, records the vote via ``vote_last_response_ig`` and
    returns an empty string followed by three ``disable_btn`` values.
    """
    client_ip = get_ip(request)
    ig_logger.info(f"downvote. ip: {client_ip}")
    vote_last_response_ig(state, "downvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def flag_last_response_ig(state, model_selector, request: gr.Request):
    """Handle a flag on a single-model image generation.

    Logs the client IP, records the flag via ``vote_last_response_ig`` and
    returns an empty string followed by three ``disable_btn`` values.
    """
    client_ip = get_ip(request)
    ig_logger.info(f"flag. ip: {client_ip}")
    vote_last_response_ig(state, "flag", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
## Image Generation Multi (IGM) Side-by-Side and Battle
def leftvote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Handle a "left is better" vote on a two-model image generation.

    Records the vote, then returns an empty string, four ``disable_btn``
    values and two Markdown labels. When ``model_selector0`` is the empty
    string the labels show the second ``'_'``-separated segment of each
    state's model name; otherwise they are empty (still visible).
    """
    igm_logger.info(f"leftvote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        label_a = gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True)
        label_b = gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)
    else:
        label_a = gr.Markdown('', visible=True)
        label_b = gr.Markdown('', visible=True)
    return ("",) + (disable_btn,) * 4 + (label_a, label_b)
def rightvote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Handle a "right is better" vote on a two-model image generation.

    Records the vote, then returns an empty string, four ``disable_btn``
    values and two Markdown labels. When ``model_selector0`` is the empty
    string the labels show the second ``'_'``-separated segment of each
    state's model name; otherwise they are empty (still visible).

    The stray debug ``print`` calls were removed so this handler matches its
    left/tie/both-bad siblings.
    """
    igm_logger.info(f"rightvote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        return ("",) + (disable_btn,) * 4 + (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True))
    else:
        return ("",) + (disable_btn,) * 4 + (gr.Markdown('', visible=True), gr.Markdown('', visible=True))
def tievote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Handle a "tie" vote on a two-model image generation.

    Records the vote, then returns an empty string, four ``disable_btn``
    values and two Markdown labels. When ``model_selector0`` is the empty
    string the labels show the second ``'_'``-separated segment of each
    state's model name; otherwise they are empty (still visible).
    """
    igm_logger.info(f"tievote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        label_a = gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True)
        label_b = gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)
    else:
        label_a = gr.Markdown('', visible=True)
        label_b = gr.Markdown('', visible=True)
    return ("",) + (disable_btn,) * 4 + (label_a, label_b)
def bothbad_vote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Handle a "both are bad" vote on a two-model image generation.

    Records the vote, then returns an empty string, four ``disable_btn``
    values and two Markdown labels. When ``model_selector0`` is the empty
    string the labels show the second ``'_'``-separated segment of each
    state's model name; otherwise they are empty (still visible).
    """
    igm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    if model_selector0 == "":
        label_a = gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True)
        label_b = gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)
    else:
        label_a = gr.Markdown('', visible=True)
        label_b = gr.Markdown('', visible=True)
    return ("",) + (disable_btn,) * 4 + (label_a, label_b)
## Image Editing (IE) Single Model Direct Chat
def upvote_last_response_ie(state, model_selector, request: gr.Request):
    """Handle an upvote on a single-model image edit.

    Logs the client IP, records the vote via ``vote_last_response_ie`` and
    returns two empty strings, a fresh 512x512 PIL ``gr.Image``, another
    empty string and three ``disable_btn`` values.
    """
    client_ip = get_ip(request)
    ie_logger.info(f"upvote. ip: {client_ip}")
    vote_last_response_ie(state, "upvote", model_selector, request)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return ("", "", blank_image, "", disable_btn, disable_btn, disable_btn)
def downvote_last_response_ie(state, model_selector, request: gr.Request):
    """Handle a downvote on a single-model image edit.

    Logs the client IP, records the vote via ``vote_last_response_ie`` and
    returns two empty strings, a fresh 512x512 PIL ``gr.Image``, another
    empty string and three ``disable_btn`` values.
    """
    client_ip = get_ip(request)
    ie_logger.info(f"downvote. ip: {client_ip}")
    vote_last_response_ie(state, "downvote", model_selector, request)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return ("", "", blank_image, "", disable_btn, disable_btn, disable_btn)
def flag_last_response_ie(state, model_selector, request: gr.Request):
    """Handle a flag on a single-model image edit.

    Logs the client IP, records the flag via ``vote_last_response_ie`` and
    returns two empty strings, a fresh 512x512 PIL ``gr.Image``, another
    empty string and three ``disable_btn`` values.
    """
    client_ip = get_ip(request)
    ie_logger.info(f"flag. ip: {client_ip}")
    vote_last_response_ie(state, "flag", model_selector, request)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return ("", "", blank_image, "", disable_btn, disable_btn, disable_btn)
## Image Editing Multi (IEM) Side-by-Side and Battle
def leftvote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Handle a "left is better" vote on a two-model image edit.

    Records the vote, then returns two Markdown labels, two empty strings, a
    fresh 512x512 PIL ``gr.Image``, an empty string and four ``disable_btn``
    values. The labels name the models (second ``'_'``-separated segment of
    each state's model name) only when ``model_selector0`` is the empty
    string; otherwise they are empty and hidden.
    """
    iem_logger.info(f"leftvote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        names = (gr.Markdown('', visible=False), gr.Markdown('', visible=False))
    else:
        names = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    reset_inputs = ("", "", gr.Image(height=512, width=512, type="pil"), "")
    return names + reset_inputs + (disable_btn,) * 4
def rightvote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Handle a "right is better" vote on a two-model image edit.

    Records the vote, then returns two Markdown labels, two empty strings, a
    fresh 512x512 PIL ``gr.Image``, an empty string and four ``disable_btn``
    values. The labels name the models (second ``'_'``-separated segment of
    each state's model name) only when ``model_selector0`` is the empty
    string; otherwise they are empty and hidden.
    """
    iem_logger.info(f"rightvote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        names = (gr.Markdown('', visible=False), gr.Markdown('', visible=False))
    else:
        names = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    reset_inputs = ("", "", gr.Image(height=512, width=512, type="pil"), "")
    return names + reset_inputs + (disable_btn,) * 4
def tievote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Handle a "tie" vote on a two-model image edit.

    Records the vote, then returns two Markdown labels, two empty strings, a
    fresh 512x512 PIL ``gr.Image``, an empty string and four ``disable_btn``
    values. The labels name the models (second ``'_'``-separated segment of
    each state's model name) only when ``model_selector0`` is the empty
    string; otherwise they are empty and hidden.
    """
    iem_logger.info(f"tievote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        names = (gr.Markdown('', visible=False), gr.Markdown('', visible=False))
    else:
        names = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    reset_inputs = ("", "", gr.Image(height=512, width=512, type="pil"), "")
    return names + reset_inputs + (disable_btn,) * 4
def bothbad_vote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Handle a "both are bad" vote on a two-model image edit.

    Records the vote, then returns two Markdown labels, two empty strings, a
    fresh 512x512 PIL ``gr.Image``, an empty string and four ``disable_btn``
    values. The labels name the models (second ``'_'``-separated segment of
    each state's model name) only when ``model_selector0`` is the empty
    string; otherwise they are empty and hidden.
    """
    iem_logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        names = (gr.Markdown('', visible=False), gr.Markdown('', visible=False))
    else:
        names = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    reset_inputs = ("", "", gr.Image(height=512, width=512, type="pil"), "")
    return names + reset_inputs + (disable_btn,) * 4
## Video Generation (VG) Single Model Direct Chat
def upvote_last_response_vg(state, model_selector, request: gr.Request):
    """Handle an upvote on a single-model video generation.

    Logs the client IP, records the vote via ``vote_last_response_vg`` and
    returns an empty string followed by three ``disable_btn`` values.
    """
    client_ip = get_ip(request)
    vg_logger.info(f"upvote. ip: {client_ip}")
    vote_last_response_vg(state, "upvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def downvote_last_response_vg(state, model_selector, request: gr.Request):
    """Handle a downvote on a single-model video generation.

    Logs the client IP, records the vote via ``vote_last_response_vg`` and
    returns an empty string followed by three ``disable_btn`` values.
    """
    client_ip = get_ip(request)
    vg_logger.info(f"downvote. ip: {client_ip}")
    vote_last_response_vg(state, "downvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def flag_last_response_vg(state, model_selector, request: gr.Request):
    """Handle a flag on a single-model video generation.

    Logs the client IP, records the flag via ``vote_last_response_vg`` and
    returns an empty string followed by three ``disable_btn`` values.
    """
    client_ip = get_ip(request)
    vg_logger.info(f"flag. ip: {client_ip}")
    vote_last_response_vg(state, "flag", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
## Video Generation Multi (VGM) Side-by-Side and Battle
def leftvote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Handle a "left is better" vote on a two-model video generation.

    Records the vote, then returns an empty string, four ``disable_btn``
    values and two Markdown labels. The labels name the models (second
    ``'_'``-separated segment of each state's model name) only when
    ``model_selector0`` is the empty string; otherwise they are empty and
    hidden.
    """
    vgm_logger.info(f"leftvote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        labels = (gr.Markdown("", visible=False), gr.Markdown("", visible=False))
    else:
        labels = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + labels
def rightvote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Handle a "right is better" vote on a two-model video generation.

    Records the vote, then returns an empty string, four ``disable_btn``
    values and two Markdown labels. The labels name the models (second
    ``'_'``-separated segment of each state's model name) only when
    ``model_selector0`` is the empty string; otherwise they are empty and
    hidden.
    """
    vgm_logger.info(f"rightvote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        labels = (gr.Markdown("", visible=False), gr.Markdown("", visible=False))
    else:
        labels = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + labels
def tievote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Handle a "tie" vote on a two-model video generation.

    Records the vote, then returns an empty string, four ``disable_btn``
    values and two Markdown labels. The labels name the models (second
    ``'_'``-separated segment of each state's model name) only when
    ``model_selector0`` is the empty string; otherwise they are empty and
    hidden.
    """
    vgm_logger.info(f"tievote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        labels = (gr.Markdown("", visible=False), gr.Markdown("", visible=False))
    else:
        labels = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + labels
def bothbad_vote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Handle a "both are bad" vote on a two-model video generation.

    Records the vote, then returns an empty string, four ``disable_btn``
    values and two Markdown labels. The labels name the models (second
    ``'_'``-separated segment of each state's model name) only when
    ``model_selector0`` is the empty string; otherwise they are empty and
    hidden.
    """
    vgm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        labels = (gr.Markdown("", visible=False), gr.Markdown("", visible=False))
    else:
        labels = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + labels
# JavaScript source kept as a string. The function takes four arguments and
# returns them unchanged; as a side effect it renders the element with id
# "share-region-named" through html2canvas and triggers a download of the
# result as "chatbot-arena.png".
# NOTE(review): html2canvas is not defined in this string — presumably the
# hosting page loads it; confirm.
share_js = """
function (a, b, c, d) {
const captureElement = document.querySelector('#share-region-named');
html2canvas(captureElement)
.then(canvas => {
canvas.style.display = 'none'
document.body.appendChild(canvas)
return canvas
})
.then(canvas => {
const image = canvas.toDataURL('image/png')
const a = document.createElement('a')
a.setAttribute('download', 'chatbot-arena.png')
a.setAttribute('href', image)
a.click()
canvas.remove()
});
return [a, b, c, d];
}
"""
def share_click_igm(state0, state1, model_selector0, model_selector1, request: gr.Request):
    """Log a share click for two-model image generation.

    The click is always written to the logger; a ``"share"`` vote record is
    stored only when both states are present.
    """
    igm_logger.info(f"share (anony). ip: {get_ip(request)}")
    if state0 is None or state1 is None:
        return
    vote_last_response_igm(
        [state0, state1], "share", [model_selector0, model_selector1], request
    )
def share_click_iem(state0, state1, model_selector0, model_selector1, request: gr.Request):
    """Log a share click for two-model image editing.

    The click is always written to the logger; a ``"share"`` vote record is
    stored only when both states are present.
    """
    iem_logger.info(f"share (anony). ip: {get_ip(request)}")
    if state0 is None or state1 is None:
        return
    vote_last_response_iem(
        [state0, state1], "share", [model_selector0, model_selector1], request
    )
## All Generation Gradio Interface
class ImageStateIG:
    """State of one image-generation conversation.

    Holds a random hex conversation id, the model name, the prompt and the
    generated output. Prompt and output start as ``None``.
    """

    def __init__(self, model_name):
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        self.prompt = None
        self.output = None

    def dict(self):
        """Return the loggable fields as a dict; ``output`` is not included."""
        logged_fields = ("conv_id", "model_name", "prompt")
        return {field: getattr(self, field) for field in logged_fields}
class ImageStateIE:
    """State of one image-editing conversation.

    Holds a random hex conversation id, the model name, the three prompts
    (source, target, instruction), the source image and the edited output.
    Everything except the id and model name starts as ``None``.
    """

    def __init__(self, model_name):
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        self.source_prompt = None
        self.target_prompt = None
        self.instruct_prompt = None
        self.source_image = None
        self.output = None

    def dict(self):
        """Return the loggable fields as a dict; the images are not included."""
        logged_fields = (
            "conv_id",
            "model_name",
            "source_prompt",
            "target_prompt",
            "instruct_prompt",
        )
        return {field: getattr(self, field) for field in logged_fields}
class VideoStateVG:
    """State of one video-generation conversation.

    Holds a random hex conversation id, the model name, the prompt and the
    generated output. Prompt and output start as ``None``.
    """

    def __init__(self, model_name):
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name
        self.prompt = None
        self.output = None

    def dict(self):
        """Return the loggable fields as a dict; ``output`` is not included."""
        logged_fields = ("conv_id", "model_name", "prompt")
        return {field: getattr(self, field) for field in logged_fields}
def generate_ig(gen_func, state, text, model_name, request: gr.Request):
    """Generate one image, yield it to the UI, then log and persist it.

    Args:
        gen_func: Callable invoked as ``gen_func(text, model_name)``; its
            result is stored as ``state.output`` and later saved as a JPEG.
        state: Existing ``ImageStateIG`` or ``None`` to start a new one.
        text: Prompt; must be non-empty.
        model_name: Model identifier; must be non-empty.
        request: Gradio request, used to resolve the client IP.

    Yields:
        ``(state, generated_image)`` once, before any logging or saving.
    """
    # NOTE(review): whether gr.Warning is raisable depends on the Gradio
    # version — confirm these guards abort the event as intended.
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    if state is None:
        state = ImageStateIG(model_name)
    ip = get_ip(request)
    ig_logger.info(f"generate. ip: {ip}")

    start_tstamp = time.time()
    generated_image = gen_func(text, model_name)
    state.prompt = text
    state.output = generated_image
    state.model_name = model_name

    # Hand the result to the UI first; bookkeeping happens once the
    # generator is resumed.
    yield state, generated_image

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # JPEG data is binary, so the handle must be opened in binary mode.
    with open(output_file, 'wb') as f:
        state.output.save(f, 'JPEG')
    save_image_file_on_log_server(output_file)
def generate_igm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generate images with two named models, yield them, then log and save.

    Args:
        gen_func: Callable invoked as ``gen_func(text, model_name0,
            model_name1)`` and returning two images.
        state0, state1: Existing ``ImageStateIG`` objects or ``None``.
        text: Prompt; must be non-empty.
        model_name0, model_name1: Model identifiers; must be non-empty. A
            leading ``"### Model A: "`` / ``"### Model B: "`` label is removed.
        request: Gradio request, used to resolve the client IP.

    Yields:
        ``(state0, state1, image0, image1)`` once, before logging or saving.
    """
    # NOTE(review): whether gr.Warning is raisable depends on the Gradio
    # version — confirm these guards abort the event as intended.
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    if state0 is None:
        state0 = ImageStateIG(model_name0)
    if state1 is None:
        state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")

    start_tstamp = time.time()
    # Remove ### Model (A|B): from model name
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1 = gen_func(text, model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1

    yield state0, state1, generated_image0, generated_image1

    finish_tstamp = time.time()
    # One "chat" record per model, written and mirrored in order.
    with open(get_conv_log_filename(), "a") as fout:
        for name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # JPEG data is binary, so the handle must be opened in binary mode.
        with open(output_file, 'wb') as f:
            state.output.save(f, 'JPEG')
        save_image_file_on_log_server(output_file)
def generate_igm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generate images with two models chosen by ``gen_func``, then log and save.

    Unlike ``generate_igm`` the model names are not validated up front:
    ``gen_func`` returns the names it actually used along with the images.

    Args:
        gen_func: Callable invoked as ``gen_func(text, model_name0,
            model_name1)`` and returning ``(image0, image1, name0, name1)``.
        state0, state1: Existing ``ImageStateIG`` objects or ``None``.
        text: Prompt; must be non-empty.
        model_name0, model_name1: Incoming names; a leading
            ``"### Model A: "`` / ``"### Model B: "`` label is removed.
        request: Gradio request, used to resolve the client IP.

    Yields:
        ``(state0, state1, image0, image1, markdown_a, markdown_b)`` once,
        before logging or saving.
    """
    # NOTE(review): whether gr.Warning is raisable depends on the Gradio
    # version — confirm this guard aborts the event as intended.
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if state0 is None:
        state0 = ImageStateIG(model_name0)
    if state1 is None:
        state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")

    start_tstamp = time.time()
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1, model_name0, model_name1 = gen_func(text, model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1

    yield state0, state1, generated_image0, generated_image1, \
        gr.Markdown(f"### Model A: {model_name0}"), gr.Markdown(f"### Model B: {model_name1}")

    finish_tstamp = time.time()
    # One "chat" record per model, written and mirrored in order.
    with open(get_conv_log_filename(), "a") as fout:
        for name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # JPEG data is binary, so the handle must be opened in binary mode.
        with open(output_file, 'wb') as f:
            state.output.save(f, 'JPEG')
        save_image_file_on_log_server(output_file)
def generate_ie(gen_func, state, source_text, target_text, instruct_text, source_image, model_name, request: gr.Request):
    """Edit one image with a single model, yield it, then log and persist.

    Args:
        gen_func: Callable invoked as ``gen_func(source_text, target_text,
            instruct_text, source_image, model_name)``.
        state: Existing ``ImageStateIE`` or ``None`` to start a new one.
        source_text, target_text, instruct_text: Prompts; all must be
            non-empty.
        source_image: Image to edit; must be truthy and expose
            ``save(fp, format)``.
        model_name: Model identifier; must be non-empty.
        request: Gradio request, used to resolve the client IP.

    Yields:
        ``(state, generated_image)`` once, before logging or saving.
    """
    # NOTE(review): whether gr.Warning is raisable depends on the Gradio
    # version — confirm these guards abort the event as intended.
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    if state is None:
        state = ImageStateIE(model_name)
    ip = get_ip(request)
    # Image-editing events belong in the image-editing log, not ig_logger.
    ie_logger.info(f"generate. ip: {ip}")

    start_tstamp = time.time()
    generated_image = gen_func(source_text, target_text, instruct_text, source_image, model_name)
    state.source_prompt = source_text
    state.target_prompt = target_text
    state.instruct_prompt = instruct_text
    state.source_image = source_image
    state.output = generated_image
    state.model_name = model_name

    yield state, generated_image

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
    os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
    # JPEG data is binary, so both handles must be opened in binary mode.
    with open(src_img_file, 'wb') as f:
        state.source_image.save(f, 'JPEG')
    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
    with open(output_file, 'wb') as f:
        state.output.save(f, 'JPEG')
    save_image_file_on_log_server(src_img_file)
    save_image_file_on_log_server(output_file)
def generate_iem(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request):
    """Edit one image with two named models, yield results, then log and save.

    Args:
        gen_func: Callable invoked as ``gen_func(source_text, target_text,
            instruct_text, source_image, model_name0, model_name1)`` and
            returning two images.
        state0, state1: Existing ``ImageStateIE`` objects or ``None``.
        source_text, target_text, instruct_text: Prompts; all must be
            non-empty.
        source_image: Image to edit; must be truthy and expose
            ``save(fp, format)``.
        model_name0, model_name1: Model identifiers; must be non-empty. A
            leading ``"### Model A: "`` / ``"### Model B: "`` label is removed.
        request: Gradio request, used to resolve the client IP.

    Yields:
        ``(state0, state1, image0, image1)`` once, before logging or saving.
    """
    # NOTE(review): whether gr.Warning is raisable depends on the Gradio
    # version — confirm these guards abort the event as intended.
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    if state0 is None:
        state0 = ImageStateIE(model_name0)
    if state1 is None:
        state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # Multi-model editing events belong in the iem log, not igm_logger.
    iem_logger.info(f"generate. ip: {ip}")

    start_tstamp = time.time()
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1)
    for state, image, name in (
        (state0, generated_image0, model_name0),
        (state1, generated_image1, model_name1),
    ):
        state.source_prompt = source_text
        state.target_prompt = target_text
        state.instruct_prompt = instruct_text
        state.source_image = source_image
        state.output = image
        state.model_name = name

    yield state0, state1, generated_image0, generated_image1

    finish_tstamp = time.time()
    # One "chat" record per model, written and mirrored in order.
    with open(get_conv_log_filename(), "a") as fout:
        for name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # JPEG data is binary, so both handles must be opened in binary mode.
        with open(src_img_file, 'wb') as f:
            state.source_image.save(f, 'JPEG')
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        with open(output_file, 'wb') as f:
            state.output.save(f, 'JPEG')
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)
def generate_iem_annoy(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request):
    """Edit one image with two models chosen by ``gen_func``, then log and save.

    Unlike ``generate_iem`` the model names are not validated up front:
    ``gen_func`` returns the names it actually used along with the images.

    Args:
        gen_func: Callable invoked as ``gen_func(source_text, target_text,
            instruct_text, source_image, model_name0, model_name1)`` and
            returning ``(image0, image1, name0, name1)``.
        state0, state1: Existing ``ImageStateIE`` objects or ``None``.
        source_text, target_text, instruct_text: Prompts; all must be
            non-empty.
        source_image: Image to edit; must be truthy and expose
            ``save(fp, format)``.
        model_name0, model_name1: Incoming names; a leading
            ``"### Model A: "`` / ``"### Model B: "`` label is removed.
        request: Gradio request, used to resolve the client IP.

    Yields:
        ``(state0, state1, image0, image1, markdown_a, markdown_b)`` once,
        before logging or saving.
    """
    # NOTE(review): whether gr.Warning is raisable depends on the Gradio
    # version — confirm these guards abort the event as intended.
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    if state0 is None:
        state0 = ImageStateIE(model_name0)
    if state1 is None:
        state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # Multi-model editing events belong in the iem log, not igm_logger.
    iem_logger.info(f"generate. ip: {ip}")

    start_tstamp = time.time()
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1, model_name0, model_name1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1)
    for state, image, name in (
        (state0, generated_image0, model_name0),
        (state1, generated_image1, model_name1),
    ):
        state.source_prompt = source_text
        state.target_prompt = target_text
        state.instruct_prompt = instruct_text
        state.source_image = source_image
        state.output = image
        state.model_name = name

    yield state0, state1, generated_image0, generated_image1, \
        gr.Markdown(f"### Model A: {model_name0}"), gr.Markdown(f"### Model B: {model_name1}")

    finish_tstamp = time.time()
    # One "chat" record per model, written and mirrored in order.
    with open(get_conv_log_filename(), "a") as fout:
        for name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # JPEG data is binary, so both handles must be opened in binary mode.
        with open(src_img_file, 'wb') as f:
            state.source_image.save(f, 'JPEG')
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        with open(output_file, 'wb') as f:
            state.output.save(f, 'JPEG')
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)
def generate_vg(gen_func, state, text, model_name, request: gr.Request):
    """Generate one video, log it, write it to disk and yield its path.

    Args:
        gen_func: Callable invoked as ``gen_func(text, model_name)``.
        state: Existing ``VideoStateVG`` or ``None`` to start a new one.
        text: Prompt; must be non-empty.
        model_name: Model identifier; must be non-empty. Names starting with
            ``'fal'`` have their output fetched with ``requests.get``;
            all others are encoded with ``imageio.mimwrite``.
        request: Gradio request, used to resolve the client IP.

    Yields:
        ``(state, output_file)`` once, after the mp4 has been written and
        uploaded.
    """
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    if state is None:
        state = VideoStateVG(model_name)
    client_ip = get_ip(request)
    vg_logger.info(f"generate. ip: {client_ip}")

    started_at = time.time()
    video = gen_func(text, model_name)
    state.prompt = text
    state.output = video
    state.model_name = model_name
    finished_at = time.time()

    with open(get_conv_log_filename(), "a") as log_handle:
        data = dict(
            tstamp=round(finished_at, 4),
            type="chat",
            model=model_name,
            gen_params={},
            start=round(started_at, 4),
            finish=round(finished_at, 4),
            state=state.dict(),
            ip=get_ip(request),
        )
        log_handle.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, get_conv_log_filename())

    video_path = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    if not model_name.startswith('fal'):
        # Locally produced frames: report their shape, then encode to mp4.
        print("======== video shape: ========")
        print(state.output.shape)
        imageio.mimwrite(video_path, state.output, fps=8, quality=9)
    else:
        # Here state.output is passed to requests.get, i.e. used as a URL.
        response = requests.get(state.output)
        with open(video_path, 'wb') as video_handle:
            video_handle.write(response.content)
    save_video_file_on_log_server(video_path)

    yield state, video_path
def generate_vgm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generate videos with two named models, log and save them, then yield.

    Args:
        gen_func: Callable invoked as ``gen_func(text, model_name0,
            model_name1)`` and returning two videos.
        state0, state1: Existing ``VideoStateVG`` objects or ``None``.
        text: Prompt; must be non-empty.
        model_name0, model_name1: Model identifiers; must be non-empty. A
            leading ``"### Model A: "`` / ``"### Model B: "`` label is
            removed. Names starting with ``'fal'`` have their output fetched
            with ``requests.get``; others are encoded with
            ``imageio.mimwrite``.
        request: Gradio request, used to resolve the client IP.

    Yields:
        ``(state0, state1, path0, path1)`` once, after both mp4 files have
        been written and uploaded.
    """
    # NOTE(review): whether gr.Warning is raisable depends on the Gradio
    # version — confirm these guards abort the event as intended.
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    if state0 is None:
        state0 = VideoStateVG(model_name0)
    if state1 is None:
        state1 = VideoStateVG(model_name1)
    ip = get_ip(request)
    # Multi-model video events belong in the vgm log, not igm_logger.
    vgm_logger.info(f"generate. ip: {ip}")

    start_tstamp = time.time()
    # Remove ### Model (A|B): from model name
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_video0, generated_video1 = gen_func(text, model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_video0
    state1.output = generated_video1
    state0.model_name = model_name0
    state1.model_name = model_name1
    print("====== model name =========")
    print(state0.model_name)
    print(state1.model_name)
    finish_tstamp = time.time()

    # One "chat" record per model, written and mirrored in order.
    with open(get_conv_log_filename(), "a") as fout:
        for name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        print(state.model_name)
        if state.model_name.startswith('fal'):
            # Here state.output is passed to requests.get, i.e. used as a URL.
            r = requests.get(state.output)
            with open(output_file, 'wb') as outfile:
                outfile.write(r.content)
        else:
            # Only the shape is reported; printing the full array would flood
            # stdout with raw frame data.
            print("======== video shape: ========")
            print(state.output.shape)
            imageio.mimwrite(output_file, state.output, fps=8, quality=9)
        save_video_file_on_log_server(output_file)

    yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4'
def generate_vgm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generate videos with two models chosen by ``gen_func``, then yield them.

    ``gen_func`` is invoked as ``gen_func(text, model_name0, model_name1)``
    and returns ``(video0, video1, name0, name1)``. Both results are logged,
    written to ``VIDEO_DIR/generation`` (fetched with ``requests.get`` when
    the model name starts with ``'fal'``, encoded with ``imageio.mimwrite``
    otherwise) and uploaded.

    Args:
        gen_func: Generation callable described above.
        state0, state1: Existing ``VideoStateVG`` objects or ``None``.
        text: Prompt; must be non-empty.
        model_name0, model_name1: Incoming names; a leading
            ``"### Model A: "`` / ``"### Model B: "`` label is removed.
        request: Gradio request, used to resolve the client IP.

    Yields:
        ``(state0, state1, path0, path1, markdown_a, markdown_b)`` once,
        after both mp4 files have been written and uploaded.
    """
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if state0 is None:
        state0 = VideoStateVG(model_name0)
    if state1 is None:
        state1 = VideoStateVG(model_name1)
    client_ip = get_ip(request)
    vgm_logger.info(f"generate. ip: {client_ip}")

    started_at = time.time()
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    video0, video1, model_name0, model_name1 = gen_func(text, model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = video0
    state1.output = video1
    state0.model_name = model_name0
    state1.model_name = model_name1
    finished_at = time.time()

    # One "chat" record per model, written and mirrored in order.
    with open(get_conv_log_filename(), "a") as log_handle:
        for name, entry in ((model_name0, state0), (model_name1, state1)):
            data = dict(
                tstamp=round(finished_at, 4),
                type="chat",
                model=name,
                gen_params={},
                start=round(started_at, 4),
                finish=round(finished_at, 4),
                state=entry.dict(),
                ip=get_ip(request),
            )
            log_handle.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for entry in (state0, state1):
        video_path = f'{VIDEO_DIR}/generation/{entry.conv_id}.mp4'
        os.makedirs(os.path.dirname(video_path), exist_ok=True)
        if not entry.model_name.startswith('fal'):
            # Locally produced frames: report their shape, then encode.
            print("======== video shape: ========")
            print(entry.output.shape)
            imageio.mimwrite(video_path, entry.output, fps=8, quality=9)
        else:
            # Here entry.output is passed to requests.get, i.e. used as a URL.
            response = requests.get(entry.output)
            with open(video_path, 'wb') as video_handle:
                video_handle.write(response.content)
        save_video_file_on_log_server(video_path)

    path0 = f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4'
    path1 = f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4'
    yield state0, state1, path0, path1, \
        gr.Markdown(f"### Model A: {model_name0}"), gr.Markdown(f"### Model B: {model_name1}")