clement-bonnet committed on
Commit 999b913
Parent: 389de0a

feat: code wip

Files changed (6)
  1. app.py +126 -113
  2. imgs/pattern_1.png +0 -0
  3. imgs/pattern_2.png +0 -0
  4. inference.py +96 -0
  5. requirements.txt +1 -1
  6. utils.py +12 -0
app.py CHANGED
@@ -1,122 +1,135 @@
-import os
 import gradio as gr
 import numpy as np
 from PIL import Image

-# Placeholder for your actual model
-def generate_image(image_idx: int, x: float, y: float) -> Image.Image:
-    """
-    Replace this with your actual model inference
-    """
-    # This is just a placeholder - replace with your model
-    # Creating a simple gradient image as example output
-    width, height = 256, 256
-    gradient = np.zeros((height, width, 3), dtype=np.uint8)
-    gradient[:, :, 0] = np.linspace(0, 255 * x, width)
-    gradient[:, :, 1] = np.linspace(0, 255 * y, height)[:, np.newaxis]
-    gradient[:, :, 2] = image_idx * 30  # vary blue channel based on selected image
-    return Image.fromarray(gradient)
-
-def process_click(image_idx: int, evt: gr.SelectData) -> Image.Image:
-    """
-    Process the click event on the coordinate selector
-    """
-    # Extract coordinates from click event
-    x, y = evt.index[0], evt.index[1]
-    # Normalize coordinates to [0, 1]
-    x, y = x/100, y/100
-    # Generate image using the model
-    return generate_image(image_idx, x, y)

 with gr.Blocks() as demo:
-    gr.Markdown("""
-    # Interactive Image Generation
-    Choose a reference image and click on the coordinate selector to generate a new image.
-    """)
-
-    with gr.Row():
-        # Left column: Reference images and coordinate selector
-        with gr.Column(scale=1):
-            # Radio buttons for image selection
-            image_idx = gr.Radio(
-                choices=[i for i in range(4)],  # Replace with your actual number of images
-                value=0,
-                label="Select Reference Image",
-                type="index"
-            )
-
-            # Display reference images
-            gallery = gr.Gallery(
-                value=[
-                    "image_0.jpg",
-                    "image_0.jpg",
-                    "image_0.jpg",
-                    "image_0.jpg",
-                ],
-                columns=2,
-                rows=2,
-                height=300,
-                label="Reference Images"
-            )
-
-            # Coordinate selector (displayed as heatmap for click interaction)
-            coord_selector = gr.Plot(
-                value=None,
-                label="Click to select (x, y) coordinates"
-            )
-
-            # Initialize the coordinate selector
-            def create_selector():
-                import plotly.graph_objects as go
-                fig = go.Figure()
-
-                # Add a square shape
-                fig.add_trace(go.Scatter(
-                    x=[0, 100, 100, 0, 0],
-                    y=[0, 0, 100, 100, 0],
-                    mode='lines',
-                    line=dict(color='black'),
-                    showlegend=False
-                ))
-
-                # Update layout
-                fig.update_layout(
-                    width=300,
-                    height=300,
-                    margin=dict(l=0, r=0, t=0, b=0),
-                    xaxis=dict(
-                        range=[-5, 105],
-                        showgrid=False,
-                        zeroline=False,
-                        visible=False
-                    ),
-                    yaxis=dict(
-                        range=[-5, 105],
-                        showgrid=False,
-                        zeroline=False,
-                        visible=False,
-                        scaleanchor='x'
-                    ),
-                    plot_bgcolor='white'
-                )
-                return fig
-
-            # Initialize the coordinate selector
-            coord_selector.value = create_selector()
-
-        # Right column: Generated image
-        with gr.Column(scale=1):
-            output_image = gr.Image(
-                label="Generated Output",
-                height=300
-            )
-
-    # Handle click events
-    coord_selector.select(
-        process_click,
-        inputs=[image_idx],
-        outputs=output_image
     )

 # Launch the app
-demo.launch()
+# import gradio as gr
+# import numpy as np
+# from PIL import Image
+
+# from inference import generate_image
+
+
+# # Create a square image for the coordinate selector
+# def create_selector_image():
+#     # Create a white square with black border
+#     size = 400
+#     border = 2
+#     img = np.ones((size, size, 3), dtype=np.uint8) * 255
+#     # Add black border
+#     img[:border, :] = 0  # top
+#     img[-border:, :] = 0  # bottom
+#     img[:, :border] = 0  # left
+#     img[:, -border:] = 0  # right
+#     return Image.fromarray(img)
+
+
+# def process_click(image_idx: int, x: int, y: int) -> tuple[Image.Image, str]:
+#     """
+#     Process the click event on the coordinate selector
+#     """
+#     try:
+#         # Normalize coordinates to [0, 1]
+#         x_norm, y_norm = x / 400, y / 400  # Divide by image size (400x400)
+
+#         # Debug message
+#         debug_msg = f"Processing: image_idx={image_idx}, coordinates=({x_norm:.3f}, {y_norm:.3f})"
+#         print(debug_msg)
+
+#         # Generate image using the model
+#         generated_img = generate_image(image_idx, x_norm, y_norm)
+#         return generated_img, debug_msg
+#     except Exception as e:
+#         error_msg = f"Error: {str(e)}"
+#         print(error_msg)
+#         return None, error_msg
+
+
+# with gr.Blocks() as demo:
+#     gr.Markdown(
+#         """
+#         # Interactive Image Generation
+#         Choose a reference image and click on the coordinate selector to generate a new image.
+#         """
+#     )
+
+#     with gr.Row():
+#         # Left column: Reference images and coordinate selector
+#         with gr.Column(scale=1):
+#             # Radio buttons for image selection
+#             image_idx = gr.Radio(
+#                 choices=list(range(2)), value=0, label="Select Reference Image", type="index"
+#             )
+
+#             # Display reference images
+#             gallery = gr.Gallery(
+#                 value=["imgs/pattern_1.png", "imgs/pattern_2.png"],
+#                 columns=2,
+#                 rows=1,
+#                 height=500,
+#                 label="Different Tasks",
+#             )
+
+#             # Coordinate selector
+#             coord_selector = gr.Image(
+#                 value=create_selector_image(),
+#                 label="Click to select (x, y) coordinates",
+#                 show_label=True,
+#                 interactive=True,
+#                 height=400,
+#                 width=400,
+#             )
+
+#         # Right column: Generated image and debug info
+#         with gr.Column(scale=1):
+#             output_image = gr.Image(label="Generated Image", height=400)
+#             debug_text = gr.Textbox(label="Debug Info", interactive=False)
+
+#     # Handle click events using click instead of select
+#     coord_selector.click(
+#         fn=process_click,
+#         inputs=[image_idx, coord_selector],  # coord_selector will provide x, y coordinates
+#         outputs=[output_image, debug_text],
+#     )
+
+# if __name__ == "__main__":
+#     print("Starting Gradio app...")
+#     demo.launch(debug=True)
+
+
 import gradio as gr
 import numpy as np
 from PIL import Image

+
+def create_white_square(size=400):
+    # Create a white square image
+    print("Creating white square")
+    return np.full((size, size, 3), 255, dtype=np.uint8)
+
+
+def get_click_coordinates(evt: gr.SelectData):
+    # Get click coordinates
+    x, y = evt.index
+    print(f"Clicked at coordinates: x={x}, y={y}")
+    return f"Clicked at coordinates: x={x}, y={y}"
+
+
+# Create the interface
 with gr.Blocks() as demo:
+    gr.Markdown("## Click Coordinate Detector\nClick anywhere on the white square to see coordinates")
+
+    # Display the white square
+    image = gr.Image(
+        label="Click on the white square",
+        value=create_white_square(),
+        interactive=True,
+        height=400,
+        width=400,
+        mirror_webcam=False,
     )

+    # Display coordinates
+    output_text = gr.Textbox(label="Coordinates")
+    print("oh yeah")
+    # Handle click events
+    image.select(get_click_coordinates, inputs=[], outputs=output_text)
+
 # Launch the app
+if __name__ == "__main__":
+    demo.launch()
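Note: in this WIP revision the click handler only echoes coordinates; the model is not wired in yet. A minimal sketch of how the 400x400 selector above could feed inference.generate_image — the handler name, the hard-coded image_idx=1, and the output_image component are assumptions, not part of this commit:

    from inference import generate_image

    def generate_from_click(evt: gr.SelectData):
        # evt.index gives the clicked pixel (x, y) on the 400x400 square
        x, y = evt.index
        x_norm, y_norm = x / 400, y / 400  # normalize to [0, 1]
        # NOTE: image_idx=1 is a hard-coded assumption; it must be a key of
        # BLUE_LOCATION_INPUTS in inference.py (1 or 2)
        return generate_image(1, x_norm, y_norm), f"x={x_norm:.3f}, y={y_norm:.3f}"

    # Hypothetical wiring (output_image would be an extra gr.Image component):
    # image.select(generate_from_click, inputs=[], outputs=[output_image, output_text])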
imgs/pattern_1.png ADDED
imgs/pattern_2.png ADDED
inference.py ADDED
@@ -0,0 +1,96 @@
+import os
+import sys
+
+sys.path.append("..")
+
+from PIL import Image
+import matplotlib.pyplot as plt
+import hydra
+import omegaconf
+import jax
+import jax.numpy as jnp
+import optax
+from flax.training.train_state import TrainState
+from flax.serialization import from_bytes
+from huggingface_hub import snapshot_download
+
+# lpn imports
+from src.models.lpn import LPN
+from src.models.transformer import EncoderTransformer, DecoderTransformer
+from src.visualization import display_grid, ax_to_pil
+
+from utils import patch_target
+
+
+checkpoint_name = "quiet-thunder-789--checkpoint:v0"
+BLUE_LOCATION_INPUTS = {1: 13, 2: 9}
+
+local_dir = snapshot_download(repo_id="clement-bonnet/lpn-2d", allow_patterns=f"{checkpoint_name}/*")
+with open(f"{local_dir}/{checkpoint_name}/config.yaml", "r") as f:
+    cfg = omegaconf.OmegaConf.load(f)
+patch_target(cfg)
+
+encoder = EncoderTransformer(hydra.utils.instantiate(cfg.encoder_transformer))
+decoder = DecoderTransformer(hydra.utils.instantiate(cfg.decoder_transformer))
+lpn = LPN(encoder=encoder, decoder=decoder)
+
+key = jax.random.PRNGKey(0)
+grids = jax.random.randint(
+    key,
+    (1, 3, decoder.config.max_rows, decoder.config.max_cols, 2),
+    minval=0,
+    maxval=decoder.config.vocab_size,
+)
+shapes = jax.random.randint(
+    key,
+    (1, 3, 2, 2),
+    minval=1,
+    maxval=min(decoder.config.max_rows, decoder.config.max_cols) + 1,
+)
+variables = lpn.init(
+    key, grids, shapes, dropout_eval=False, prior_kl_coeff=0.0, pairwise_kl_coeff=0.0, mode="mean"
+)
+learning_rate, linear_warmup_steps = 0, 0
+linear_warmup_scheduler = optax.warmup_exponential_decay_schedule(
+    init_value=learning_rate / (linear_warmup_steps + 1),
+    peak_value=learning_rate,
+    warmup_steps=linear_warmup_steps,
+    transition_steps=1,
+    end_value=learning_rate,
+    decay_rate=1.0,
+)
+optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(linear_warmup_scheduler))
+optimizer = optax.MultiSteps(optimizer, every_k_schedule=1)
+train_state = TrainState.create(apply_fn=lpn.apply, tx=optimizer, params=variables["params"])
+
+with open(os.path.join(local_dir, checkpoint_name, "state.msgpack"), "rb") as data_file:
+    byte_data = data_file.read()
+loaded_state = from_bytes(train_state, byte_data)
+
+generate_output_from_context = jax.jit(
+    lambda context, input, input_grid_shape: lpn.apply(
+        {"params": loaded_state.params},
+        context=context,
+        input=input,
+        input_grid_shape=input_grid_shape,
+        dropout_eval=True,
+        method=lpn._generate_output_from_context,
+    )
+)
+
+
+def generate_image(image_idx: int, x: float, y: float, eps: float = 1e-4) -> Image.Image:
+    # Create the input image
+    input = jnp.zeros(16, int).at[BLUE_LOCATION_INPUTS[image_idx]].set(1).reshape(4, 4)
+    # Ensure x and y are in [eps, 1 - eps]
+    x = min(1 - eps, max(eps, x))
+    y = min(1 - eps, max(eps, y))
+    # Convert x and y to context in R^2
+    context = jax.scipy.stats.norm.ppf(jnp.array([x, y]))
+    output_grids, _ = generate_output_from_context(
+        context=context[None], input=input[None], input_grid_shape=jnp.array([4, 4])[None]
+    )
+    output_grid = output_grids[0]
+    _, ax = plt.subplots(1, 1, figsize=(4, 4))
+    display_grid(ax=ax, grid=output_grid, grid_shape=jnp.array([4, 4]))
+    return ax_to_pil(ax)
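For context, generate_image maps the normalized click position (x, y) in (0, 1) to a 2-dimensional latent context via the inverse Gaussian CDF (jax.scipy.stats.norm.ppf), clamping to [eps, 1 - eps] so the transform stays finite, then decodes a 4x4 grid from that context and renders it to a PIL image. A minimal, hedged usage sketch (the output filename is illustrative):

    from inference import generate_image

    # image_idx must be 1 or 2, matching the keys of BLUE_LOCATION_INPUTS
    img = generate_image(image_idx=1, x=0.25, y=0.75)
    img.save("generated_grid.png")  # hypothetical output path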
requirements.txt CHANGED
@@ -1,3 +1,3 @@
 gradio
 plotly
-git+https://github.com/clement-bonnet/lpn.git@edbe4722340719cc36b5a755fec7213cb8efb9f7
+git+https://github.com/clement-bonnet/lpn.git@f1bb82598454e897b3d4cb9f313d941943382877
utils.py ADDED
@@ -0,0 +1,12 @@
+import omegaconf
+
+
+def patch_target(config):
+    """Update the _target_ of cfg from src_v2 to src"""
+    for key, value in config.items():
+        if isinstance(value, omegaconf.DictConfig):
+            # Recursive call if the value is another DictConfig
+            patch_target(value)
+        elif isinstance(value, str) and value.startswith("src_v2"):
+            # Update the value if it matches the old_value
+            config[key] = value.replace("src_v2", "src")
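A small illustration of what patch_target does, assuming a Hydra-style config whose _target_ strings still reference the old src_v2 package (the example target path is hypothetical):

    import omegaconf
    from utils import patch_target

    cfg = omegaconf.OmegaConf.create(
        {"encoder_transformer": {"_target_": "src_v2.models.utils.EncoderTransformerConfig"}}  # hypothetical path
    )
    patch_target(cfg)
    print(cfg.encoder_transformer._target_)  # -> src.models.utils.EncoderTransformerConfig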