# Spaces:
# Runtime error
# Runtime error
import gradio as gr
import torch
from PIL import Image

from drag_gan import stylegan2, drag_gan
# Fall back to the CPU when no CUDA device is present so that moving the
# generator to `device` does not fail on GPU-less hosts; when CUDA is
# available the value is 'cuda', exactly as before.
# NOTE(review): whether the drag_gan generator actually runs on CPU is not
# visible from this file -- confirm.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Generator instance shared by on_drag() and main().
g_ema = stylegan2().to(device)
def to_image(tensor):
    """Convert a tensor into an 8-bit, channel-last image array.

    The leading dimension is squeezed away (when it has size 1), the
    remaining three axes are permuted with ``(1, 2, 0)``, and the values are
    min-max normalised to the 0-255 range.

    Args:
        tensor: A torch tensor that is 3-D after ``squeeze(0)``; presumably a
            ``(1, C, H, W)`` sample from the generator -- confirm with caller.

    Returns:
        A ``uint8`` NumPy array in the permuted layout. A constant-valued
        input yields an all-zero array.
    """
    arr = tensor.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    low = arr.min()
    span = arr.max() - low
    scaled = arr - low
    # A constant image has zero range: skipping the division avoids 0/0 -> NaN
    # (and the undefined NaN -> uint8 cast that would follow).
    if span > 0:
        scaled = scaled / span
    return (scaled * 255).astype('uint8')
def on_click(image, target_point, points, evt: gr.SelectData):
    """Record a clicked point and stamp a 5x5 marker onto the image.

    Args:
        image: Image array currently displayed; it is modified in place.
        target_point: Truthy to register the click as a target point (marker
            filled with 255), falsy to register it as a handle point (marker
            filled with 0).
        points: Dict with ``'target'`` and ``'handle'`` lists; the clicked
            coordinate is appended in place to the matching list.
        evt: Gradio selection event; ``evt.index`` holds the click position.

    Returns:
        A tuple of the marked image and ``str(evt.index)``.
    """
    first, second = evt.index[0], evt.index[1]
    kind, fill = ('target', 255) if target_point else ('handle', 0)
    # evt.index[1] addresses the first image axis and evt.index[0] the second,
    # and the stored point uses that same swapped order.
    image[second:second + 5, first:first + 5, :] = fill
    points[kind].append([second, first])
    return image, str(evt.index)
def on_drag(points, max_iters, state):
    """Run the drag optimisation, yielding a refreshed image after each step.

    Args:
        points: Dict with ``'handle'`` and ``'target'`` lists of coordinates.
        max_iters: Iteration limit forwarded to ``drag_gan``; coerced to int.
        state: Dict holding ``'latent'``, ``'noise'`` and ``'F'``. ``'latent'``
            and ``'F'`` are overwritten in place with every step's result.

    Yields:
        A tuple ``(points, image, state)`` where ``points`` is a fresh dict
        with empty lists, ``image`` is the step's sample converted by
        ``to_image``, and ``state`` is the updated state dict.
    """
    step_limit = int(max_iters)
    start_latent = state['latent']
    noise = state['noise']
    start_features = state['F']

    def as_float_tensors(coords):
        # One float tensor per clicked coordinate.
        return [torch.tensor(c).float() for c in coords]

    handles = as_float_tensors(points['handle'])
    targets = as_float_tensors(points['target'])

    # Fixed rectangular region set to 1 on an otherwise zero 1024x1024 mask.
    region = torch.zeros((1, 1, 1024, 1024)).to(device)
    region[..., 720:820, 390:600] = 1

    steps = drag_gan(
        g_ema, start_latent, noise, start_features,
        handles, targets, region,
        max_iters=step_limit,
    )
    for sample, new_latent, new_features in steps:
        # A new dict each step, so the stored points are cleared.
        cleared = {'target': [], 'handle': []}
        frame = to_image(sample)
        state['F'] = new_features
        state['latent'] = new_latent
        yield cleared, frame, state
def main():
    """Generate an initial sample, build the Gradio interface and launch it."""
    torch.cuda.manual_seed(25)
    z = torch.randn([1, 512], device=device)
    latent, noise = g_ema.prepare([z])
    sample, features = g_ema.generate(latent, noise)

    with gr.Blocks() as demo:
        # Per-session generator state consumed and updated by on_drag.
        gen_state = gr.State({
            'latent': latent,
            'noise': noise,
            'F': features,
        })
        iters_slider = gr.Slider(
            minimum=1, maximum=100, value=5, label='Max Iterations',
        )
        canvas = gr.Image(to_image(sample)).style(height=512, width=512)
        click_info = gr.Textbox()
        drag_button = gr.Button('Drag it')
        clicked = gr.State({'target': [], 'handle': []})
        is_target = gr.Checkbox(label='Target Point')

        # Clicking the image records a point; the button starts the drag.
        canvas.select(
            on_click,
            inputs=[canvas, is_target, clicked],
            outputs=[canvas, click_info],
        )
        drag_button.click(
            on_drag,
            inputs=[clicked, iters_slider, gen_state],
            outputs=[clicked, canvas, gen_state],
        )

    demo.queue(concurrency_count=5, max_size=20).launch()
# Start the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()