File size: 2,732 Bytes
56330ee 7a99816 56330ee 7a99816 56330ee 7a99816 56330ee 7a99816 56330ee 7a99816 56330ee feaab3e 56330ee 7a99816 56330ee 7a99816 56330ee 7a99816 56330ee 7a99816 56330ee 7a99816 56330ee 7a99816 56330ee 7a99816 56330ee 7a99816 56330ee 7a99816 56330ee 7a99816 56330ee 7a99816 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import gradio as gr
from src.image_prep import canny_from_pil
from src.pix2pix_turbo import Pix2Pix_Turbo
# Build the "edge_to_image" Pix2Pix_Turbo model once, at import time; the
# `process` callback reads this module-level instance on every request.
_MODEL_NAME = "edge_to_image"
model = Pix2Pix_Turbo(_MODEL_NAME)
def _floor_to_multiple(value, multiple=8):
    """Round the non-negative integer *value* down to a multiple of *multiple*."""
    return value - value % multiple


def _invert_edges(edges):
    """Return a uint8 array of *edges* with every intensity flipped (v -> 255 - v).

    *edges* may be a PIL image or an array-like; it is viewed as uint8 first so
    the subtraction stays in exact integer arithmetic (no float round-trip).
    """
    return 255 - np.asarray(edges, dtype=np.uint8)


def _model_device():
    """Return the device holding the model's first parameter, else CPU.

    NOTE(review): assumes ``Pix2Pix_Turbo`` exposes ``torch.nn.Module``-style
    ``parameters()``; when it does not (or has no parameters) CPU is returned,
    which matches the original code that never moved its input.
    """
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            return first.device
    return torch.device("cpu")


# Define the processing function
def process(input_image, prompt, low_threshold, high_threshold):
    """Extract Canny edges from *input_image* and render them with the model.

    Args:
        input_image: PIL image from the upload widget, or ``None`` when nothing
            has been uploaded yet.
        prompt: Text prompt forwarded to the model.
        low_threshold: Low threshold forwarded to ``canny_from_pil``.
        high_threshold: High threshold forwarded to ``canny_from_pil``.

    Returns:
        ``(canny_viz, output_pil)``: the edge map with intensities inverted for
        display, and the generated image. Both are ``None`` when no image has
        been provided.
    """
    # The threshold sliders trigger this callback on every change, even before
    # an image is uploaded; return empty outputs instead of crashing on None.
    if input_image is None:
        return None, None

    # Trim each side down to a multiple of 8 (presumably a model constraint).
    new_width = _floor_to_multiple(input_image.width, 8)
    new_height = _floor_to_multiple(input_image.height, 8)
    input_image = input_image.resize((new_width, new_height))

    # Generate canny edge image
    canny = canny_from_pil(input_image, low_threshold, high_threshold)

    with torch.no_grad():
        # Keep the conditioning tensor on the same device as the model weights.
        c_t = transforms.ToTensor()(canny).unsqueeze(0).to(_model_device())
        output_image = model(c_t, prompt)
        # `* 0.5 + 0.5` suggests the model emits values in [-1, 1]. Clamp to
        # [0, 1] because ToPILImage multiplies floats by 255 and casts to bytes,
        # so any overshoot would otherwise wrap around.
        decoded = (output_image[0].detach().cpu().float() * 0.5 + 0.5).clamp(0, 1)
        output_pil = transforms.ToPILImage()(decoded)

    # Visualize canny edges with inverted intensities (edges drawn dark on light).
    canny_viz = Image.fromarray(_invert_edges(canny))
    return canny_viz, output_pil
if __name__ == "__main__":
    # Lay out the demo: controls in the first column, then the inverted edge
    # map and the generated image in two further columns.
    with gr.Blocks() as demo:
        gr.Markdown("# Pix2pix-Turbo: **Canny Edge -> Image**")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source="upload", type="pil")
                prompt = gr.Textbox(label="Prompt")
                # Both threshold sliders share one range and step size.
                shared_slider_opts = dict(minimum=1, maximum=255, step=10)
                low_threshold = gr.Slider(
                    label="Canny low threshold", value=100, **shared_slider_opts
                )
                high_threshold = gr.Slider(
                    label="Canny high threshold", value=200, **shared_slider_opts
                )
                run_button = gr.Button(value="Run")
            with gr.Column():
                result_canny = gr.Image(type="pil")
            with gr.Column():
                result_output = gr.Image(type="pil")

        # Re-run `process` when the prompt is submitted, when either threshold
        # changes, or when Run is clicked (registered in that order).
        inputs = [input_image, prompt, low_threshold, high_threshold]
        outputs = [result_canny, result_output]
        listeners = (
            prompt.submit,
            low_threshold.change,
            high_threshold.change,
            run_button.click,
        )
        for register in listeners:
            register(fn=process, inputs=inputs, outputs=outputs)

    # Enable request queuing, then start the local server.
    demo.queue()
    demo.launch(debug=True, share=False)
|