ucalyptus's picture
Add application file
880233f
raw
history blame
2.43 kB
import gradio as gr
import torch
from drag_gan import stylegan2, drag_gan
from PIL import Image
# Device string used for the generator and for every tensor this file
# creates explicitly (the drag mask and the initial latent sample).
device = 'cuda'
# Build the generator once at import time and move it to `device`; the
# callbacks below share this single instance.
g_ema = stylegan2().to(device)
def to_image(tensor):
    """Convert a generator output tensor into a displayable uint8 array.

    The leading axis is squeezed away, the first remaining axis is moved to
    the end, and the values are min-max rescaled to the 0..255 range.

    Args:
        tensor: torch tensor; assumed to be shaped (1, C, H, W) -- TODO
            confirm against the generator's output.

    Returns:
        A numpy ``uint8`` array, laid out as (H, W, C) under the shape
        assumption above. A constant-valued input yields an all-zero image.
    """
    arr = tensor.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    lowest = arr.min()
    value_range = arr.max() - lowest
    arr = arr - lowest
    # A constant image has zero range; dividing by it would produce NaNs,
    # and casting NaN to uint8 is undefined. Leave such an image at zero.
    if value_range > 0:
        arr = arr / value_range
    return (arr * 255).astype('uint8')
def on_click(image, target_point, points, evt: gr.SelectData):
    """Record a clicked point and paint a small marker onto the image.

    Args:
        image: the image array shown in the UI; it is modified in place.
        target_point: truthy to register the click as a target point
            (white marker), falsy to register it as a handle point
            (black marker).
        points: dict with 'target' and 'handle' lists; the clicked
            coordinate is appended to the matching list.
        evt: Gradio selection event. ``evt.index`` is presumably the
            (x, y) pixel position of the click -- confirm against the
            Gradio version in use.

    Returns:
        A tuple of the marked image and ``str(evt.index)``.
    """
    # The first array axis is indexed with evt.index[1] and the second with
    # evt.index[0]; points are stored in that same order.
    row = evt.index[1]
    col = evt.index[0]
    # 5x5 square whose top-left corner is the clicked pixel, all channels.
    patch = (slice(row, row + 5), slice(col, col + 5), slice(None))
    coordinate = [evt.index[1], evt.index[0]]

    if target_point:
        image[patch] = 255
        points['target'].append(coordinate)
    else:
        points['handle'].append(coordinate)
        image[patch] = 0
    return image, str(evt.index)
def on_drag(points, max_iters, state):
    """Run the drag optimisation and stream intermediate results to the UI.

    This is a generator: one tuple is yielded for every item produced by
    ``drag_gan``.

    Args:
        points: dict with 'handle' and 'target' lists of coordinates.
        max_iters: iteration limit; coerced to ``int`` before use.
        state: dict holding 'latent', 'noise' and 'F'; its 'latent' and
            'F' entries are overwritten after every step.

    Yields:
        ``(points, image, state)`` where ``points`` is a fresh empty point
        dict, ``image`` is the current sample converted by ``to_image``,
        and ``state`` is the updated state dict.
    """
    iteration_limit = int(max_iters)
    start_latent = state['latent']
    noise = state['noise']
    start_features = state['F']

    def _as_float_tensors(coords):
        # One float tensor per clicked coordinate.
        return [torch.tensor(c).float() for c in coords]

    handle_points = _as_float_tensors(points['handle'])
    target_points = _as_float_tensors(points['target'])

    # NOTE(review): the mask size and the enabled rectangle are hard-coded;
    # presumably they match a 1024x1024 generator output -- confirm.
    mask = torch.zeros((1, 1, 1024, 1024)).to(device)
    mask[..., 720:820, 390:600] = 1

    steps = drag_gan(
        g_ema, start_latent, noise, start_features,
        handle_points, target_points, mask,
        max_iters=iteration_limit,
    )
    for sample, new_latent, new_features in steps:
        # The emitted point dict is empty, so the stored clicks are reset.
        cleared_points = {'target': [], 'handle': []}
        frame = to_image(sample)

        state['F'] = new_features
        state['latent'] = new_latent
        yield cleared_points, frame, state
def main():
    """Generate an initial sample, build the Gradio UI and launch it."""
    # Fixed seed so the starting latent is the same on every launch.
    torch.cuda.manual_seed(25)
    initial_z = torch.randn([1, 512], device=device)
    latent, noise = g_ema.prepare([initial_z])
    first_sample, features = g_ema.generate(latent, noise)

    initial_state = {
        'latent': latent,
        'noise': noise,
        'F': features,
    }

    with gr.Blocks() as demo:
        # Components are created in the same order as before.
        session_state = gr.State(initial_state)
        iters_slider = gr.Slider(1, 100, 5, label='Max Iterations')
        canvas = gr.Image(to_image(first_sample)).style(height=512, width=512)
        click_info = gr.Textbox()
        drag_button = gr.Button('Drag it')
        clicked_points = gr.State({'target': [], 'handle': []})
        target_toggle = gr.Checkbox(label='Target Point')

        # Clicking the image records a point and shows the click index.
        canvas.select(
            on_click,
            [canvas, target_toggle, clicked_points],
            [canvas, click_info],
        )
        # The button starts the drag loop, which streams updated images.
        drag_button.click(
            on_drag,
            inputs=[clicked_points, iters_slider, session_state],
            outputs=[clicked_points, canvas, session_state],
        )

    demo.queue(concurrency_count=5, max_size=20).launch()
# Launch the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()