Spaces:
Running
Running
clement-bonnet
commited on
Commit
•
999b913
1
Parent(s):
389de0a
feat: code wip
Browse files- app.py +126 -113
- imgs/pattern_1.png +0 -0
- imgs/pattern_2.png +0 -0
- inference.py +96 -0
- requirements.txt +1 -1
- utils.py +12 -0
app.py
CHANGED
@@ -1,122 +1,135 @@
|
|
1 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import gradio as gr
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
|
6 |
-
# Placeholder for your actual model
|
7 |
-
def generate_image(image_idx: int, x: float, y: float) -> Image.Image:
|
8 |
-
"""
|
9 |
-
Replace this with your actual model inference
|
10 |
-
"""
|
11 |
-
# This is just a placeholder - replace with your model
|
12 |
-
# Creating a simple gradient image as example output
|
13 |
-
width, height = 256, 256
|
14 |
-
gradient = np.zeros((height, width, 3), dtype=np.uint8)
|
15 |
-
gradient[:, :, 0] = np.linspace(0, 255 * x, width)
|
16 |
-
gradient[:, :, 1] = np.linspace(0, 255 * y, height)[:, np.newaxis]
|
17 |
-
gradient[:, :, 2] = image_idx * 30 # vary blue channel based on selected image
|
18 |
-
return Image.fromarray(gradient)
|
19 |
-
|
20 |
-
def process_click(image_idx: int, evt: gr.SelectData) -> Image.Image:
|
21 |
-
"""
|
22 |
-
Process the click event on the coordinate selector
|
23 |
-
"""
|
24 |
-
# Extract coordinates from click event
|
25 |
-
x, y = evt.index[0], evt.index[1]
|
26 |
-
# Normalize coordinates to [0, 1]
|
27 |
-
x, y = x/100, y/100
|
28 |
-
# Generate image using the model
|
29 |
-
return generate_image(image_idx, x, y)
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
with gr.Blocks() as demo:
|
32 |
-
gr.Markdown(""
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
choices=[i for i in range(4)], # Replace with your actual number of images
|
43 |
-
value=0,
|
44 |
-
label="Select Reference Image",
|
45 |
-
type="index"
|
46 |
-
)
|
47 |
-
|
48 |
-
# Display reference images
|
49 |
-
gallery = gr.Gallery(
|
50 |
-
value=[
|
51 |
-
"image_0.jpg",
|
52 |
-
"image_0.jpg",
|
53 |
-
"image_0.jpg",
|
54 |
-
"image_0.jpg",
|
55 |
-
],
|
56 |
-
columns=2,
|
57 |
-
rows=2,
|
58 |
-
height=300,
|
59 |
-
label="Reference Images"
|
60 |
-
)
|
61 |
-
|
62 |
-
# Coordinate selector (displayed as heatmap for click interaction)
|
63 |
-
coord_selector = gr.Plot(
|
64 |
-
value=None,
|
65 |
-
label="Click to select (x, y) coordinates"
|
66 |
-
)
|
67 |
-
|
68 |
-
# Initialize the coordinate selector
|
69 |
-
def create_selector():
|
70 |
-
import plotly.graph_objects as go
|
71 |
-
fig = go.Figure()
|
72 |
-
|
73 |
-
# Add a square shape
|
74 |
-
fig.add_trace(go.Scatter(
|
75 |
-
x=[0, 100, 100, 0, 0],
|
76 |
-
y=[0, 0, 100, 100, 0],
|
77 |
-
mode='lines',
|
78 |
-
line=dict(color='black'),
|
79 |
-
showlegend=False
|
80 |
-
))
|
81 |
-
|
82 |
-
# Update layout
|
83 |
-
fig.update_layout(
|
84 |
-
width=300,
|
85 |
-
height=300,
|
86 |
-
margin=dict(l=0, r=0, t=0, b=0),
|
87 |
-
xaxis=dict(
|
88 |
-
range=[-5, 105],
|
89 |
-
showgrid=False,
|
90 |
-
zeroline=False,
|
91 |
-
visible=False
|
92 |
-
),
|
93 |
-
yaxis=dict(
|
94 |
-
range=[-5, 105],
|
95 |
-
showgrid=False,
|
96 |
-
zeroline=False,
|
97 |
-
visible=False,
|
98 |
-
scaleanchor='x'
|
99 |
-
),
|
100 |
-
plot_bgcolor='white'
|
101 |
-
)
|
102 |
-
return fig
|
103 |
-
|
104 |
-
# Initialize the coordinate selector
|
105 |
-
coord_selector.value = create_selector()
|
106 |
-
|
107 |
-
# Right column: Generated image
|
108 |
-
with gr.Column(scale=1):
|
109 |
-
output_image = gr.Image(
|
110 |
-
label="Generated Output",
|
111 |
-
height=300
|
112 |
-
)
|
113 |
-
|
114 |
-
# Handle click events
|
115 |
-
coord_selector.select(
|
116 |
-
process_click,
|
117 |
-
inputs=[image_idx],
|
118 |
-
outputs=output_image
|
119 |
)
|
120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
# Launch the app
|
122 |
-
|
|
|
|
1 |
+
# import gradio as gr
|
2 |
+
# import numpy as np
|
3 |
+
# from PIL import Image
|
4 |
+
|
5 |
+
# from inference import generate_image
|
6 |
+
|
7 |
+
|
8 |
+
# # Create a square image for the coordinate selector
|
9 |
+
# def create_selector_image():
|
10 |
+
# # Create a white square with black border
|
11 |
+
# size = 400
|
12 |
+
# border = 2
|
13 |
+
# img = np.ones((size, size, 3), dtype=np.uint8) * 255
|
14 |
+
# # Add black border
|
15 |
+
# img[:border, :] = 0 # top
|
16 |
+
# img[-border:, :] = 0 # bottom
|
17 |
+
# img[:, :border] = 0 # left
|
18 |
+
# img[:, -border:] = 0 # right
|
19 |
+
# return Image.fromarray(img)
|
20 |
+
|
21 |
+
|
22 |
+
# def process_click(image_idx: int, x: int, y: int) -> tuple[Image.Image, str]:
|
23 |
+
# """
|
24 |
+
# Process the click event on the coordinate selector
|
25 |
+
# """
|
26 |
+
# try:
|
27 |
+
# # Normalize coordinates to [0, 1]
|
28 |
+
# x_norm, y_norm = x / 400, y / 400 # Divide by image size (400x400)
|
29 |
+
|
30 |
+
# # Debug message
|
31 |
+
# debug_msg = f"Processing: image_idx={image_idx}, coordinates=({x_norm:.3f}, {y_norm:.3f})"
|
32 |
+
# print(debug_msg)
|
33 |
+
|
34 |
+
# # Generate image using the model
|
35 |
+
# generated_img = generate_image(image_idx, x_norm, y_norm)
|
36 |
+
# return generated_img, debug_msg
|
37 |
+
# except Exception as e:
|
38 |
+
# error_msg = f"Error: {str(e)}"
|
39 |
+
# print(error_msg)
|
40 |
+
# return None, error_msg
|
41 |
+
|
42 |
+
|
43 |
+
# with gr.Blocks() as demo:
|
44 |
+
# gr.Markdown(
|
45 |
+
# """
|
46 |
+
# # Interactive Image Generation
|
47 |
+
# Choose a reference image and click on the coordinate selector to generate a new image.
|
48 |
+
# """
|
49 |
+
# )
|
50 |
+
|
51 |
+
# with gr.Row():
|
52 |
+
# # Left column: Reference images and coordinate selector
|
53 |
+
# with gr.Column(scale=1):
|
54 |
+
# # Radio buttons for image selection
|
55 |
+
# image_idx = gr.Radio(
|
56 |
+
# choices=list(range(2)), value=0, label="Select Reference Image", type="index"
|
57 |
+
# )
|
58 |
+
|
59 |
+
# # Display reference images
|
60 |
+
# gallery = gr.Gallery(
|
61 |
+
# value=["imgs/pattern_1.png", "imgs/pattern_2.png"],
|
62 |
+
# columns=2,
|
63 |
+
# rows=1,
|
64 |
+
# height=500,
|
65 |
+
# label="Different Tasks",
|
66 |
+
# )
|
67 |
+
|
68 |
+
# # Coordinate selector
|
69 |
+
# coord_selector = gr.Image(
|
70 |
+
# value=create_selector_image(),
|
71 |
+
# label="Click to select (x, y) coordinates",
|
72 |
+
# show_label=True,
|
73 |
+
# interactive=True,
|
74 |
+
# height=400,
|
75 |
+
# width=400,
|
76 |
+
# )
|
77 |
+
|
78 |
+
# # Right column: Generated image and debug info
|
79 |
+
# with gr.Column(scale=1):
|
80 |
+
# output_image = gr.Image(label="Generated Image", height=400)
|
81 |
+
# debug_text = gr.Textbox(label="Debug Info", interactive=False)
|
82 |
+
|
83 |
+
# # Handle click events using click instead of select
|
84 |
+
# coord_selector.click(
|
85 |
+
# fn=process_click,
|
86 |
+
# inputs=[image_idx, coord_selector], # coord_selector will provide x, y coordinates
|
87 |
+
# outputs=[output_image, debug_text],
|
88 |
+
# )
|
89 |
+
|
90 |
+
# if __name__ == "__main__":
|
91 |
+
# print("Starting Gradio app...")
|
92 |
+
# demo.launch(debug=True)
|
93 |
+
|
94 |
+
|
95 |
import gradio as gr
|
96 |
import numpy as np
|
97 |
from PIL import Image
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
+
def create_white_square(size=400):
|
101 |
+
# Create a white square image
|
102 |
+
print("Creating white square")
|
103 |
+
return np.full((size, size, 3), 255, dtype=np.uint8)
|
104 |
+
|
105 |
+
|
106 |
+
def get_click_coordinates(evt: gr.SelectData):
|
107 |
+
# Get click coordinates
|
108 |
+
x, y = evt.index
|
109 |
+
print(f"Clicked at coordinates: x={x}, y={y}")
|
110 |
+
return f"Clicked at coordinates: x={x}, y={y}"
|
111 |
+
|
112 |
+
|
113 |
+
# Create the interface
|
114 |
with gr.Blocks() as demo:
|
115 |
+
gr.Markdown("## Click Coordinate Detector\nClick anywhere on the white square to see coordinates")
|
116 |
+
|
117 |
+
# Display the white square
|
118 |
+
image = gr.Image(
|
119 |
+
label="Click on the white square",
|
120 |
+
value=create_white_square(),
|
121 |
+
interactive=True,
|
122 |
+
height=400,
|
123 |
+
width=400,
|
124 |
+
mirror_webcam=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
)
|
126 |
|
127 |
+
# Display coordinates
|
128 |
+
output_text = gr.Textbox(label="Coordinates")
|
129 |
+
print("oh yeah")
|
130 |
+
# Handle click events
|
131 |
+
image.select(get_click_coordinates, inputs=[], outputs=output_text)
|
132 |
+
|
133 |
# Launch the app
|
134 |
+
if __name__ == "__main__":
|
135 |
+
demo.launch()
|
imgs/pattern_1.png
ADDED
imgs/pattern_2.png
ADDED
inference.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
sys.path.append("..")
|
5 |
+
|
6 |
+
from PIL import Image
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import hydra
|
9 |
+
import omegaconf
|
10 |
+
import jax
|
11 |
+
import jax.numpy as jnp
|
12 |
+
import optax
|
13 |
+
from flax.training.train_state import TrainState
|
14 |
+
from flax.serialization import from_bytes
|
15 |
+
from huggingface_hub import snapshot_download
|
16 |
+
|
17 |
+
# lpn imports
|
18 |
+
from src.models.lpn import LPN
|
19 |
+
from src.models.transformer import EncoderTransformer, DecoderTransformer
|
20 |
+
from src.visualization import display_grid, ax_to_pil
|
21 |
+
|
22 |
+
from utils import patch_target
|
23 |
+
|
24 |
+
|
25 |
+
checkpoint_name = "quiet-thunder-789--checkpoint:v0"
|
26 |
+
BLUE_LOCATION_INPUTS = {1: 13, 2: 9}
|
27 |
+
|
28 |
+
local_dir = snapshot_download(repo_id="clement-bonnet/lpn-2d", allow_patterns=f"{checkpoint_name}/*")
|
29 |
+
with open(f"{local_dir}/{checkpoint_name}/config.yaml", "r") as f:
|
30 |
+
cfg = omegaconf.OmegaConf.load(f)
|
31 |
+
patch_target(cfg)
|
32 |
+
|
33 |
+
encoder = EncoderTransformer(hydra.utils.instantiate(cfg.encoder_transformer))
|
34 |
+
decoder = DecoderTransformer(hydra.utils.instantiate(cfg.decoder_transformer))
|
35 |
+
lpn = LPN(encoder=encoder, decoder=decoder)
|
36 |
+
|
37 |
+
key = jax.random.PRNGKey(0)
|
38 |
+
grids = jax.random.randint(
|
39 |
+
key,
|
40 |
+
(1, 3, decoder.config.max_rows, decoder.config.max_cols, 2),
|
41 |
+
minval=0,
|
42 |
+
maxval=decoder.config.vocab_size,
|
43 |
+
)
|
44 |
+
shapes = jax.random.randint(
|
45 |
+
key,
|
46 |
+
(1, 3, 2, 2),
|
47 |
+
minval=1,
|
48 |
+
maxval=min(decoder.config.max_rows, decoder.config.max_cols) + 1,
|
49 |
+
)
|
50 |
+
variables = lpn.init(
|
51 |
+
key, grids, shapes, dropout_eval=False, prior_kl_coeff=0.0, pairwise_kl_coeff=0.0, mode="mean"
|
52 |
+
)
|
53 |
+
learning_rate, linear_warmup_steps = 0, 0
|
54 |
+
linear_warmup_scheduler = optax.warmup_exponential_decay_schedule(
|
55 |
+
init_value=learning_rate / (linear_warmup_steps + 1),
|
56 |
+
peak_value=learning_rate,
|
57 |
+
warmup_steps=linear_warmup_steps,
|
58 |
+
transition_steps=1,
|
59 |
+
end_value=learning_rate,
|
60 |
+
decay_rate=1.0,
|
61 |
+
)
|
62 |
+
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(linear_warmup_scheduler))
|
63 |
+
optimizer = optax.MultiSteps(optimizer, every_k_schedule=1)
|
64 |
+
train_state = TrainState.create(apply_fn=lpn.apply, tx=optimizer, params=variables["params"])
|
65 |
+
|
66 |
+
with open(os.path.join(local_dir, checkpoint_name, "state.msgpack"), "rb") as data_file:
|
67 |
+
byte_data = data_file.read()
|
68 |
+
loaded_state = from_bytes(train_state, byte_data)
|
69 |
+
|
70 |
+
generate_output_from_context = jax.jit(
|
71 |
+
lambda context, input, input_grid_shape: lpn.apply(
|
72 |
+
{"params": loaded_state.params},
|
73 |
+
context=context,
|
74 |
+
input=input,
|
75 |
+
input_grid_shape=input_grid_shape,
|
76 |
+
dropout_eval=True,
|
77 |
+
method=lpn._generate_output_from_context,
|
78 |
+
)
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
def generate_image(image_idx: int, x: float, y: float, eps: float = 1e-4) -> Image.Image:
|
83 |
+
# Create the input image
|
84 |
+
input = jnp.zeros(16, int).at[BLUE_LOCATION_INPUTS[image_idx]].set(1).reshape(4, 4)
|
85 |
+
# Ensure x and y are in [eps, 1 - eps]
|
86 |
+
x = min(1 - eps, max(eps, x))
|
87 |
+
y = min(1 - eps, max(eps, y))
|
88 |
+
# Convert x and y to context in R^2
|
89 |
+
context = jax.scipy.stats.norm.ppf(jnp.array([x, y]))
|
90 |
+
output_grids, _ = generate_output_from_context(
|
91 |
+
context=context[None], input=input[None], input_grid_shape=jnp.array([4, 4])[None]
|
92 |
+
)
|
93 |
+
output_grid = output_grids[0]
|
94 |
+
_, ax = plt.subplots(1, 1, figsize=(4, 4))
|
95 |
+
display_grid(ax=ax, grid=output_grid, grid_shape=jnp.array([4, 4]))
|
96 |
+
return ax_to_pil(ax)
|
requirements.txt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
gradio
|
2 |
plotly
|
3 |
-
git+https://github.com/clement-bonnet/lpn.git@
|
|
|
1 |
gradio
|
2 |
plotly
|
3 |
+
git+https://github.com/clement-bonnet/lpn.git@f1bb82598454e897b3d4cb9f313d941943382877
|
utils.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import omegaconf
|
2 |
+
|
3 |
+
|
4 |
+
def patch_target(config):
|
5 |
+
"""Update the _target_ of cfg from src_v2 to src"""
|
6 |
+
for key, value in config.items():
|
7 |
+
if isinstance(value, omegaconf.DictConfig):
|
8 |
+
# Recursive call if the value is another DictConfig
|
9 |
+
patch_target(value)
|
10 |
+
elif isinstance(value, str) and value.startswith("src_v2"):
|
11 |
+
# Update the value if it matches the old_value
|
12 |
+
config[key] = value.replace("src_v2", "src")
|