File size: 2,327 Bytes
dbf1d4a
 
 
 
 
 
 
b91759c
dbf1d4a
 
 
 
 
 
 
a56ba50
dbf1d4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b91759c
dbf1d4a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import gradio as gr
import requests
from utils import resize_image, pil_to_b64, b64_to_pil, process_images_and_inpaint

# Backend switch: when True, run_inpainting delegates to the HTTP service at
# FAST_API_ENDPOINT; when False, it calls process_images_and_inpaint in-process.
USE_FASTAPI: bool = False
# URL of the inpainting service used when USE_FASTAPI is enabled.
FAST_API_ENDPOINT: str = "http://127.0.0.1:5000/inpaint"

def run_inpainting(img_1, img_2, img_3, img_4, alpha_gradient_width, init_image_height):
    """Combine the supplied images and return the inpainted result.

    Dispatches to the remote service (``call_inpainting_api``) when
    ``USE_FASTAPI`` is true, otherwise to the local
    ``process_images_and_inpaint`` pipeline.

    Args:
        img_1, img_2, img_3, img_4: Images from the Gradio inputs (PIL, per the
            ``type='pil'`` components); ``None`` entries (empty slots) are skipped.
        alpha_gradient_width: Forwarded to the inpainting pipeline. Converted
            with ``int()`` because ``gr.Number`` may deliver a float.
        init_image_height: Height each image is resized to, also forwarded to
            the pipeline. Converted with ``int()`` for the same reason.

    Returns:
        The result of ``b64_to_pil`` on the local pipeline's output, or whatever
        ``call_inpainting_api`` returns on the remote path.
    """
    # Normalize once so resize_image and the pipeline see the same int value
    # (previously resize_image received the raw, possibly-float number).
    alpha_gradient_width = int(alpha_gradient_width)
    init_image_height = int(init_image_height)

    if USE_FASTAPI:
        # The API helper resizes/encodes the images itself; doing it here as
        # well would be wasted work.
        return call_inpainting_api(img_1, img_2, img_3, img_4, alpha_gradient_width, init_image_height)

    images = [
        pil_to_b64(resize_image(img, init_image_height))
        for img in (img_1, img_2, img_3, img_4)
        if img is not None
    ]
    return b64_to_pil(process_images_and_inpaint(images, alpha_gradient_width, init_image_height))

def call_inpainting_api(img_1, img_2, img_3, img_4, alpha_gradient_width, init_image_height, timeout=120):
    """Send the images to the remote inpainting endpoint and decode its reply.

    Args:
        img_1, img_2, img_3, img_4: Input images; ``None`` entries are skipped.
        alpha_gradient_width: Sent to the service, converted with ``int()`` to
            match the local pipeline path.
        init_image_height: Height each image is resized to before upload; also
            sent to the service as an ``int``.
        timeout: Seconds to wait for the service before giving up. Without a
            timeout, ``requests`` can block indefinitely on an unresponsive server.

    Returns:
        ``b64_to_pil`` applied to the ``"inpainted_image"`` field of the JSON reply.

    Raises:
        gr.Error: If the request fails or times out, the status is not 200, or
            the reply is not JSON containing ``"inpainted_image"``. Gradio shows
            this message in the UI, instead of trying to render a string as an image.
    """
    alpha_gradient_width = int(alpha_gradient_width)
    init_image_height = int(init_image_height)
    images = [
        pil_to_b64(resize_image(img, init_image_height))
        for img in (img_1, img_2, img_3, img_4)
        if img is not None
    ]

    try:
        response = requests.post(
            FAST_API_ENDPOINT,
            json={
                "images": images,
                "alpha_gradient_width": alpha_gradient_width,
                "init_image_height": init_image_height,
            },
            timeout=timeout,
        )
    except requests.RequestException as err:
        raise gr.Error(f"Error calling inpainting API: {err}") from err

    if response.status_code != 200:
        raise gr.Error(f"Error calling inpainting API (HTTP {response.status_code})")

    try:
        encoded = response.json()["inpainted_image"]
    except (ValueError, KeyError, TypeError) as err:
        # ValueError covers JSON decode failures; KeyError/TypeError cover a
        # reply that is valid JSON but not the expected object shape.
        raise gr.Error("Error calling inpainting API: malformed response") from err
    return b64_to_pil(encoded)

TITLE = """<h2 align="center"> 🎞️ Memory Carousel </h2>"""

# Define the Gradio interface: four image inputs (the last two labelled
# optional), two numeric parameters, a Generate button and one output image.
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    with gr.Column():
        with gr.Row():
            input_image_1 = gr.Image(type='pil', label="First image")
            input_image_2 = gr.Image(type='pil', label="Second image")
        with gr.Row():
            input_image_3 = gr.Image(type='pil', label="Third image(optional)")
            input_image_4 = gr.Image(type='pil', label="Fourth image(optional)")
        with gr.Row():
            # precision=0 makes Gradio round these and pass them as ints rather
            # than floats; they are used as pixel sizes downstream.
            alpha_gradient_width = gr.Number(value=100, label="Alpha Gradient Width", precision=0)
            init_image_height = gr.Number(value=768, label="Init Image Height", precision=0)
        generate_button = gr.Button("Generate")
        output = gr.Image(type='pil')

        generate_button.click(
            fn=run_inpainting,
            inputs=[input_image_1, input_image_2, input_image_3, input_image_4, alpha_gradient_width, init_image_height],
            outputs=[output],
        )

# Start the server only when executed as a script, so importing this module
# (e.g. from tests or a reload-mode runner that looks up `demo`) has no side effect.
if __name__ == "__main__":
    demo.launch()