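# Streamlit demo of DragGAN: click "handle" points on a StyleGAN2-generated
# image, then "target" points to drag them toward; running the optimization
# updates the W+ latent so the image content follows the drag.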
import time
import torch
import streamlit as st
from PIL import Image, ImageDraw
from streamlit_image_coordinates import streamlit_image_coordinates
import draggan
import utils
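# draggan and utils are local modules: draggan wraps StyleGAN2 model loading,
# W+ latent generation, and the DragGAN optimization loop; utils provides the
# point-drawing helper used below.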
# Default to CPU if no GPU is available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
### Streamlit setup ###
st.set_page_config(
    page_title="DragGAN Demo",
    page_icon="🐉",
    layout="wide",
)
message_container = st.empty()
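# Two-column layout: settings panel on the left (col1), image and run controls on the right (col2)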
col1, col2 = st.columns([1, 2])
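# Clearing session state drops the stored click points and the cached W latent,
# so the next rerun regenerates the image from scratch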
def reset():
    st.session_state.clear()

def reset_rerun():
    reset()
    st.experimental_rerun()
### Run/Reset buttons in right col ###
with col2:
    st.markdown("")
    but_col1, but_col2 = st.columns([1, 7])
    run_button = but_col1.button("▶️ Run")
    reset_button = but_col2.button("🔁 Reset")
### Settings panel in left col ###
with col1:
    st.header("🐉 DragGAN")
    settings_col1, settings_col2 = st.columns([1, 1])
    # Models from Self-Distilled StyleGAN: https://github.com/self-distilled-stylegan/self-distilled-internet-photos
    model_options = {
        "Lions": "https://storage.googleapis.com/self-distilled-stylegan/lions_512_pytorch.pkl",
        "Faces (FFHQ)": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl",
        "Elephants": "https://storage.googleapis.com/self-distilled-stylegan/elephants_512_pytorch.pkl",
        "Parrots": "https://storage.googleapis.com/self-distilled-stylegan/parrots_512_pytorch.pkl",
        "Horses": "https://storage.googleapis.com/self-distilled-stylegan/horses_256_pytorch.pkl",
        "Bicycles": "https://storage.googleapis.com/self-distilled-stylegan/bicycles_256_pytorch.pkl",
        "Giraffes": "https://storage.googleapis.com/self-distilled-stylegan/giraffes_512_pytorch.pkl",
        "Dogs (1)": "https://storage.googleapis.com/self-distilled-stylegan/dogs_1024_pytorch.pkl",
        "Dogs (2)": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqdog.pkl",
        "Cats": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqcat.pkl",
        "Wildlife": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqwild.pkl",
        "MetFaces": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl",
    }
    model_name = str(settings_col1.selectbox("Model", list(model_options.keys()), on_change=reset, help="StyleGAN2 model to use, downloaded and cached on first run"))
    model_url = model_options[model_name]
    seed = settings_col2.number_input("Seed", value=22, step=1, min_value=0, on_change=reset, help="Random seed for generating the W+ latent")
    target_resolution = int(settings_col1.selectbox("Resolution", [256, 512, 1024], index=1, on_change=reset, help="Resize the generated image to this resolution (may differ from the model's native resolution)"))
    n_iter = int(settings_col1.number_input("Iterations", value=200, step=5, on_change=reset, help="Maximum number of optimization iterations to run"))
    step_size = settings_col2.number_input("Step Size", value=1e-3, step=1e-4, min_value=1e-4, format="%.4f", help="Step size (learning rate) for gradient descent")
    multiplier = settings_col1.number_input("Speed", value=1.0, step=0.05, min_value=0.05, help="Multiplier for target patch movement")
    tolerance = settings_col2.number_input("Tolerance", value=2, step=1, min_value=1, help="Stop once handles are within this many pixels of their targets")
    display_every = settings_col2.number_input("Display Every", value=25, step=1, min_value=1, help="Refresh the displayed image every n iterations during optimization")
    truncation_psi = settings_col1.number_input("Truncation", value=0.8, step=0.1, min_value=0.0, on_change=reset, help="Truncation trick value controlling diversity (higher = more diverse)")
    truncation_cutoff = settings_col2.number_input(
        "Truncation Cutoff", value=8, step=1, min_value=-1, max_value=18, on_change=reset, help="Number of layers to apply truncation to (-1 = all layers)"
    )
if reset_button:
    reset_rerun()

if "points" not in st.session_state:
    st.session_state["points"] = []
    st.session_state["points_types"] = []
    # State variable tracking whether the next click places a 'handle' or a 'target'
    st.session_state["next_click"] = "handle"
s = time.perf_counter()
G = draggan.load_model(model_url, device=device)
if "W" not in st.session_state:
W = draggan.generate_W(
G,
seed=int(seed),
truncation_psi=truncation_psi,
truncation_cutoff=int(truncation_cutoff),
network_pkl=model_url,
device=device,
)
else:
W = st.session_state["W"]
img, F0 = draggan.generate_image(W, G, network_pkl=model_url, device=device)
if img.size[0] != target_resolution:
img = img.resize((target_resolution, target_resolution))
print(f"Generated image in {(time.perf_counter() - s)*1000:.0f}ms")
# Split the stored clicks into handle/target lists and draw them on the image
if "points" in st.session_state and "points_types" in st.session_state:
    handles, targets = [], []
    for point, point_type in zip(
        st.session_state["points"], st.session_state["points_types"]
    ):
        if point_type == "handle":
            handles.append(point)
        else:
            targets.append(point)
    if len(handles) > 0:
        utils.draw_handle_target_points(img, handles, targets)
### Right column image container ###
with col2:
    empty = st.empty()
    with empty.container():
        value = streamlit_image_coordinates(img, key="pil")
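# streamlit_image_coordinates returns the last click as {"x": ..., "y": ...},
# or None before the image has been clicked; clicks alternate between handle
# and target points via the "next_click" state flag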
# A new point was clicked
if value is not None:
    point = value["x"], value["y"]
    if point not in st.session_state["points"]:
        st.session_state["points"].append(point)
        st.session_state["points_types"].append(st.session_state["next_click"])
        st.session_state["next_click"] = (
            "target" if st.session_state["next_click"] == "handle" else "handle"
        )
        st.experimental_rerun()
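# DragGAN alternates motion supervision (nudging features near each handle
# toward its target) with point tracking (re-locating the handles); r1 and r2
# below are the feature-patch radii used for those two steps.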
# Optimization loop
if run_button:
    if len(handles) > 0 and len(targets) > 0 and len(handles) == len(targets) and all(targets):
        W = draggan.optimize(
            W,
            G,
            handle_points=handles,
            target_points=targets,
            r1=3,
            r2=12,
            tolerance=tolerance,
            max_iter=n_iter,
            lr=step_size,
            multiplier=multiplier,
            empty=empty,
            display_every=display_every,
            target_resolution=target_resolution,
            device=device,
        )
        st.session_state.clear()
        st.session_state["W"] = W
        st.experimental_rerun()
    else:
        message_container.warning("Please add at least one handle and one target point.")
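# To launch locally (assuming this file is saved as app.py):
#   streamlit run app.py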