AnsenH's picture
Update frontend/app.py
03344ff verified
import gradio as gr
import requests
from utils import resize_image, pil_to_b64, b64_to_pil, process_images_and_inpaint
# Backend switch read by run_inpainting: False runs process_images_and_inpaint
# in this process; True forwards the request to FAST_API_ENDPOINT instead.
USE_FASTAPI = False
# URL that call_inpainting_api POSTs the JSON inpainting request to.
FAST_API_ENDPOINT = 'http://127.0.0.1:5000/inpaint'
def run_inpainting(img_1, img_2, img_3, img_4, alpha_gradient_width, init_image_height):
    """Gradio click handler: turn the uploaded images into one inpainted image.

    Dispatches on the module-level ``USE_FASTAPI`` flag. When it is set, the
    arguments are forwarded unchanged to ``call_inpainting_api``, which does
    its own resizing and encoding. Otherwise the images are resized, base64
    encoded and handed to ``process_images_and_inpaint`` in this process.

    Args:
        img_1, img_2, img_3, img_4: Images from the four upload slots. A slot
            that is ``None`` is skipped.
        alpha_gradient_width: Value of the "Alpha Gradient Width" field;
            converted with ``int()`` on the in-process path.
        init_image_height: Value of the "Init Image Height" field; passed to
            ``resize_image`` as received and converted with ``int()`` for the
            inpainting call.

    Returns:
        The return value of ``call_inpainting_api`` on the API path, or the
        result of ``b64_to_pil`` on the in-process path.
    """
    if USE_FASTAPI:
        # Delegate before encoding anything: the API helper resizes and
        # encodes the images itself, so preparing them here first would be
        # work that is immediately thrown away.
        return call_inpainting_api(
            img_1, img_2, img_3, img_4, alpha_gradient_width, init_image_height
        )

    images = [
        pil_to_b64(resize_image(img, init_image_height))
        for img in (img_1, img_2, img_3, img_4)
        if img is not None
    ]
    return b64_to_pil(
        process_images_and_inpaint(
            images, int(alpha_gradient_width), int(init_image_height)
        )
    )
def call_inpainting_api(img_1, img_2, img_3, img_4, alpha_gradient_width, init_image_height):
    """POST the uploaded images to ``FAST_API_ENDPOINT`` and return the result image.

    Each non-``None`` image is resized with ``resize_image`` and base64
    encoded with ``pil_to_b64`` before being sent as JSON.

    Args:
        img_1, img_2, img_3, img_4: Images from the four upload slots. A slot
            that is ``None`` is skipped.
        alpha_gradient_width: Value of the "Alpha Gradient Width" field.
        init_image_height: Value of the "Init Image Height" field; also passed
            to ``resize_image`` as received.

    Returns:
        The ``"inpainted_image"`` field of the JSON response, decoded with
        ``b64_to_pil``.

    Raises:
        gr.Error: If the request cannot be completed or the endpoint answers
            with a status other than 200. The caller feeds this function's
            result into an image output, so failures are raised for Gradio to
            display rather than returned as a string.
    """
    images = [
        pil_to_b64(resize_image(img, init_image_height))
        for img in (img_1, img_2, img_3, img_4)
        if img is not None
    ]
    payload = {
        "images": images,
        # Send integers, matching the int() conversion run_inpainting applies
        # on the in-process path.
        "alpha_gradient_width": int(alpha_gradient_width),
        "init_image_height": int(init_image_height),
    }

    try:
        response = requests.post(FAST_API_ENDPOINT, json=payload)
    except requests.RequestException as err:
        raise gr.Error(f"Error calling inpainting API: {err}") from err

    if response.status_code != 200:
        raise gr.Error(
            f"Error calling inpainting API (HTTP {response.status_code})"
        )
    return b64_to_pil(response.json()["inpainted_image"])
# Page heading rendered through gr.HTML at the top of the demo.
TITLE = """<h2 align="center"> 🎞️ Memory Carousel </h2>"""
# Define the Gradio interface: four image slots and two numeric settings feed
# run_inpainting, and its return value is shown in a single output image.
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    with gr.Column():
        # Image inputs, two per row; type='pil' requests PIL images from Gradio.
        with gr.Row():
            input_image_1 = gr.Image(type='pil', label="First image")
            input_image_2 = gr.Image(type='pil', label="Second image")
        with gr.Row():
            input_image_3 = gr.Image(type='pil', label="Third image(optional)")
            input_image_4 = gr.Image(type='pil', label="Fourth image(optional)")
        # Numeric settings forwarded to run_inpainting.
        with gr.Row():
            alpha_gradient_width = gr.Number(value=100, label="Alpha Gradient Width")
            init_image_height = gr.Number(value=768, label="Init Image Height")
        generate_button = gr.Button("Generate")
        output = gr.Image(type='pil')
    # The inputs list is in the same order as run_inpainting's parameters;
    # the handler's return value is written to `output`.
    generate_button.click(
        fn=run_inpainting,
        inputs=[input_image_1, input_image_2, input_image_3, input_image_4, alpha_gradient_width, init_image_height],
        outputs=[output]
    )
# No __main__ guard: the demo is launched whenever this module is executed.
demo.launch()