xu3kev commited on
Commit
9bf1f45
1 Parent(s): 01e8894
Files changed (1) hide show
  1. app.py +432 -4
app.py CHANGED
@@ -1,7 +1,435 @@
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import numpy as np
4
+
5
  import gradio as gr
6
+ import requests
7
+ from openai import OpenAI
8
+ from func_timeout import FunctionTimedOut, func_timeout
9
+ from tqdm import tqdm
10
+
11
+ MOCK = True
12
+ TEST_FOLDER = "c4f5"
13
+
14
+ INPUT_STRUCTION_TEMPLATE = """Here is a gray scale images representing with integer values 0-9.
15
+ {image_str}
16
+ Please write a Python program that generates the image using our own custom turtle module"""
17
+
18
+ PROMPT_TEMPLATE = "### Instruction:\n{input_struction}\n### Response:\n"
19
+
20
+ TEST_IMAGE_STR ="00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000001222222000000000000\n00000000000002000002000000000000\n00000000000002022202000000000000\n00000000000002020202000000000000\n00000000000002020002000000000000\n00000000000002022223000000000000\n00000000000002000000000000000000\n00000000000002000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000"
21
+
22
+ MOCK_RESPONSE = [
23
+ """for i in range(7):
24
+ with fork_state():
25
+ for j in range(4):
26
+ forward(2*i)
27
+ left(90.0)
28
+ """
29
+ ] * 16
30
+
31
+ LOGO_HEADER = """from myturtle import Turtle
32
+ from myturtle import HALF_INF, INF, EPS_DIST, EPS_ANGLE
33
+
34
+ turtle = Turtle()
35
+ def forward(dist):
36
+ turtle.forward(dist)
37
+ def left(angle):
38
+ turtle.left(angle)
39
+ def right(angle):
40
+ turtle.right(angle)
41
+ def teleport(x, y, theta):
42
+ turtle.teleport(x, y, theta)
43
+ def penup():
44
+ turtle.penup()
45
+ def pendown():
46
+ turtle.pendown()
47
+ def position():
48
+ return turtle.x, turtle.y
49
+ def heading():
50
+ return turtle.heading
51
+ def isdown():
52
+ return turtle.is_down
53
+ def fork_state():
54
+ \"\"\"
55
+ Fork the current state of the turtle.
56
+
57
+ Usage:
58
+ with fork_state():
59
+ forward(100)
60
+ left(90)
61
+ forward(100)
62
+ \"\"\"
63
+ return turtle._TurtleState(turtle)"""
64
+
65
+
66
+ def invert_colors(image):
67
+ """
68
+ Inverts the colors of the input image.
69
+ Args:
70
+ - image (dict): Input image dictionary from Sketchpad.
71
+
72
+ Returns:
73
+ - numpy array: Color-inverted image array.
74
+ """
75
+ # Extract image data from the dictionary and convert to NumPy array
76
+ image_data = image['layers'][0]
77
+ image_array = np.array(image_data)
78
+
79
+
80
+ # Invert colors
81
+ inverted_image = 255 - image_array
82
+ return inverted_image
83
+
84
+ def crop_image_to_center(image, target_height=512, target_width=512, detect_cropping_non_white=False):
85
+ # Calculate the center of the original image
86
+ h, w = image.shape
87
+ center_y, center_x = h // 2, w // 2
88
+
89
+ # Calculate the top-left corner of the crop area
90
+ start_x = max(center_x - target_width // 2, 0)
91
+ start_y = max(center_y - target_height // 2, 0)
92
+
93
+ # Ensure the crop area does not exceed the image boundaries
94
+ end_x = min(start_x + target_width, w)
95
+ end_y = min(start_y + target_height, h)
96
+
97
+ # Crop the image
98
+ cropped_image = image[start_y:end_y, start_x:end_x]
99
+ if detect_cropping_non_white:
100
+ cropping_non_white = False
101
+ all_black_pixel_count = np.sum(image < 50)
102
+ cropped_black_pixel_count = np.sum(cropped_image < 50)
103
+ if cropped_black_pixel_count < all_black_pixel_count:
104
+ cropping_non_white = True
105
+
106
+ # If the cropped image is smaller than the target, pad it to the required size
107
+ if cropped_image.shape[0] < target_height or cropped_image.shape[1] < target_width:
108
+ pad_height = target_height - cropped_image.shape[0]
109
+ pad_width = target_width - cropped_image.shape[1]
110
+ cropped_image = cv2.copyMakeBorder(cropped_image, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT, value=255) # Using white padding
111
+
112
+ if detect_cropping_non_white:
113
+ if cropping_non_white:
114
+ return None
115
+ else:
116
+ return cropped_image
117
+ else:
118
+ return cropped_image
119
+
120
+ def downscale_image(image, block_size=8, black_threshold=50, gray_level=10, return_level=False):
121
+ # Calculate the size of the output image
122
+ h, w = image.shape
123
+ new_h, new_w = h // block_size, w // block_size
124
+
125
+ # Initialize the output image
126
+ downscaled = np.zeros((new_h, new_w), dtype=np.uint8)
127
+ image_with_level = np.zeros((new_h, new_w), dtype=np.uint8)
128
+ for i in range(0, h, block_size):
129
+ for j in range(0, w, block_size):
130
+ # Extract the block
131
+ block = image[i:i+block_size, j:j+block_size]
132
+
133
+ # Calculate the proportion of black pixels
134
+ black_pixels = np.sum(block < black_threshold)
135
+ total_pixels = block_size * block_size
136
+ proportion_of_black = black_pixels / total_pixels
137
+ discrete_gray_step = 1 / gray_level
138
+ if proportion_of_black >= 0.95:
139
+ proportion_of_black = 0.94
140
+ proportion_of_black = round (proportion_of_black / discrete_gray_step) * discrete_gray_step
141
+ # check that gray level is descretize to 0 ~ gray_level-1
142
+ try:
143
+ assert 0 <= round(proportion_of_black / discrete_gray_step) < gray_level
144
+ except:
145
+ breakpoint()
146
+
147
+ # Assign the new grayscale value (inverse proportion if needed)
148
+ grayscale_value = int(proportion_of_black * 255)
149
+
150
+ # Assign to the downscaled image
151
+ downscaled[i // block_size, j // block_size] = grayscale_value
152
+ image_with_level[i // block_size, j // block_size] = int(proportion_of_black // discrete_gray_step)
153
+ if return_level:
154
+ return downscaled, image_with_level
155
+ else:
156
+ return downscaled
157
+
158
+
159
+ PORT = 8008
160
+ MODEL_NAME="./axolotl/lora-logo_fix_full_deepseek33b_ds33i_epoch3_lr_0.0002_alpha_512_r_512_merged"
161
+ MODEL_NAME="./axolotl/lora-logo_fix_full_deepseek7b_ds33i_lr_0.0002_alpha_512_r_512_merged"
162
+
163
+ def generate_grid_images(folder):
164
+ import matplotlib.patches as patches
165
+ import matplotlib.pyplot as plt
166
+ num_rows, num_cols = 8,8
167
+ fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 12))
168
+ fig.tight_layout(pad=0)
169
+
170
+ # Plot each image with its AST count as a caption
171
+ # load all jpg images in the folder
172
+ import glob
173
+ import os
174
+ print(f"load file path")
175
+ image_files = glob.glob(os.path.join(folder, "*.jpg"))
176
+ print(f"load file path done")
177
+
178
+ images = []
179
+ for idx, image_file in enumerate(image_files):
180
+ img = load_img(image_file)
181
+ images.append(img)
182
+
183
+ print(f"Loaded {len(images)} images")
184
+
185
+ for idx, img in tqdm(enumerate(images)):
186
+ if idx >= num_rows * num_cols:
187
+ break
188
+ row, col = divmod(idx, num_cols)
189
+ ax = axes[row, col]
190
+ if img is None:
191
+ ax.axis('off')
192
+ continue
193
+ try:
194
+ ax.imshow(img, cmap='gray')
195
+ except:
196
+ breakpoint()
197
+ ax.axis('off')
198
+
199
+ # Hide remaining empty subplots
200
+ for idx in range(len(images), num_rows * num_cols):
201
+ row, col = divmod(idx, num_cols)
202
+ axes[row, col].axis('off')
203
+
204
+ # convert fig to numpy return image array
205
+ fig.canvas.draw()
206
+ image_array = np.array(fig.canvas.renderer.buffer_rgba())
207
+ plt.close(fig)
208
+ return image_array
209
+
210
+
211
+ def llm_call(question_prompt, model_name,
212
+ temperature=1, max_tokens=320,
213
+ top_p=1, n_samples=64, stop=None):
214
+
215
+ client = OpenAI(base_url=f"http://localhost:{PORT}/v1", api_key="empty")
216
+
217
+ response = client.completions.create(
218
+ prompt=question_prompt,
219
+ model=model_name,
220
+ temperature=temperature,
221
+ max_tokens=max_tokens,
222
+ top_p=top_p,
223
+ frequency_penalty=0,
224
+ presence_penalty=0,
225
+ n=n_samples,
226
+ stop=stop
227
+ )
228
+
229
+ return response
230
+
231
+
232
+ import cv2
233
+ def load_img(path):
234
+ img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
235
+
236
+ # Threshold the image to create a binary image (white background, black object)
237
+ _, thresh = cv2.threshold(img, 240, 255, cv2.THRESH_BINARY)
238
+
239
+ # Invert the binary image
240
+ thresh_inv = cv2.bitwise_not(thresh)
241
+
242
+ # Find the bounding box of the non-white area
243
+ x, y, w, h = cv2.boundingRect(thresh_inv)
244
+
245
+ # Extract the ROI (region of interest) of the non-white area
246
+ roi = img[y:y+h, x:x+w]
247
+
248
+ # If the ROI is larger than 200x200, resize it
249
+ if w > 256 or h > 256:
250
+ scale = min(256 / w, 256 / h)
251
+ new_w = int(w * scale)
252
+ new_h = int(h * scale)
253
+ roi = cv2.resize(roi, (new_w, new_h), interpolation=cv2.INTER_AREA)
254
+ w, h = new_w, new_h
255
+
256
+ # Create a new 200x200 white image
257
+ centered_img = np.ones((256, 256), dtype=np.uint8) * 255
258
+
259
+ # Calculate the position to center the ROI in the 200x200 image
260
+ start_x = max(0, (256 - w) // 2)
261
+ start_y = max(0, (256 - h) // 2)
262
+
263
+ # Place the ROI in the centered position
264
+ centered_img[start_y:start_y+h, start_x:start_x+w] = roi
265
+
266
+ return centered_img
267
+
268
+
269
+ def run_code(new_folder, counter, code):
270
+ import matplotlib
271
+ fname = f"{new_folder}/logo_{counter}_.jpg"
272
+ counter += 1
273
+ code_with_header_and_save= f"""
274
+ {LOGO_HEADER}
275
+ {code}
276
+ turtle.save('{fname}')
277
+ """
278
+ try:
279
+ func_timeout(3, exec, args=(code_with_header_and_save, {}))
280
+ matplotlib.pyplot.close()
281
+ # exec(code_with_header_and_save, globals())
282
+ except FunctionTimedOut:
283
+ print("Timeout")
284
+ except Exception as e:
285
+ print(e)
286
+
287
+ def run(img_str):
288
+ prompt = PROMPT_TEMPLATE.format(input_struction=INPUT_STRUCTION_TEMPLATE.format(image_str=img_str))
289
+ if not MOCK:
290
+ response = llm_call(prompt, MODEL_NAME)
291
+ print(response)
292
+ codes = []
293
+ for i, choice in enumerate(response.choices):
294
+ print(f"Choice {i}: {choice.text}")
295
+ codes.append(choice.text)
296
+ else:
297
+ codes = MOCK_RESPONSE
298
+
299
+ gradio_test_images_folder = "gradio_test_images"
300
+ import os
301
+ os.makedirs(gradio_test_images_folder, exist_ok=True)
302
+
303
+ counter = 0
304
+ # generate a random hash id
305
+ import hashlib
306
+ import random
307
+ random_id = hashlib.md5(str(random.random()).encode()).hexdigest()[0:4]
308
+ new_folder = os.path.join(gradio_test_images_folder, random_id)
309
+ os.makedirs(new_folder, exist_ok=True)
310
+
311
+
312
+
313
+ for code in tqdm(codes):
314
+ pass
315
+
316
+ from concurrent.futures import ProcessPoolExecutor
317
+ from concurrent.futures import as_completed
318
+ with ProcessPoolExecutor() as executor:
319
+ futures = [executor.submit(run_code, new_folder, i, code) for i, code in enumerate(codes)]
320
+ for future in as_completed(futures):
321
+ try:
322
+ future.result()
323
+ except Exception as exc:
324
+ print(f'Generated an exception: {exc}')
325
+
326
+ # with open("temp.py", 'w') as f:
327
+ # f.write(code_with_header_and_save)
328
+
329
+ # p = subprocess.Popen(["python", "temp.py"], stderr=subprocess.PIPE, stdout=subprocess.PIPE, env=my_env)
330
+ # out, errs = p.communicate()
331
+ # out, errs, = out.decode(), errs.decode()
332
+ # render
333
+ print(random_id)
334
+ folder_path = f"gradio_test_images/{random_id}"
335
+ return folder_path, codes
336
+
337
+
338
+ def test_gen_img_wrapper(_):
339
+ return generate_grid_images(f"gradio_test_images/{TEST_FOLDER}")
340
+
341
+ def int_img_to_str(integer_img):
342
+ lines = []
343
+ for row in integer_img:
344
+ print("".join([str(x) for x in row]))
345
+ lines.append("".join([str(x) for x in row]))
346
+ image_str = "\n".join(lines)
347
+ return image_str
348
+
349
+ def img_to_code_img(sketchpad_img):
350
+ img = sketchpad_img['layers'][0]
351
+ image_array = np.array(img)
352
+ image_array = 255 - image_array[:,:,3]
353
+
354
+ # height, width = image_array.shape
355
+ # output_size = 512
356
+ # block_size = max(height, width) // output_size
357
+
358
+ # # Create new downscaled image array
359
+ # new_image_array = np.zeros((output_size, output_size), dtype=np.uint8)
360
+ # # Process each block
361
+ # for i in range(output_size):
362
+ # for j in range(output_size):
363
+ # # Define the block
364
+ # block = image_array[i*block_size:(i+1)*block_size, j*block_size:(j+1)*block_size]
365
+ # # Calculate the number of pixels set to 255 in the block
366
+ # white_pixels = np.sum(block == 255)
367
+ # # Set the new pixel value
368
+ # if white_pixels >= (block_size * block_size) / 2:
369
+ # new_image_array[i, j] = 255
370
+ new_image_array= image_array
371
+
372
+ _, int_img = downscale_image(new_image_array, block_size=16, return_level=True)
373
+
374
+ if int_img is not None:
375
+ img_str = int_img_to_str(int_img)
376
+ print(img_str)
377
+
378
+ folder_path, codes = run(img_str)
379
+
380
+ generated_grid_img = generate_grid_images(folder_path)
381
+
382
+ return generated_grid_img
383
+
384
+
385
+ def main():
386
+ """
387
+ Sets up and launches the Gradio demo.
388
+ """
389
+ import gradio as gr
390
+ from gradio import Brush
391
+ theme = gr.themes.Default().set(
392
+ )
393
+ with gr.Blocks(theme=theme) as demo:
394
+ gr.Markdown('# Visual Program Synthesis with LLM')
395
+ gr.Markdown("""LOGO/Turtle graphics Programming-by-Example problems aims to synthesize a program that generates the given target image, where the program uses drawing library similar to Python Turtle.""")
396
+ gr.Markdown("""Here we can draw a target image using the sketchpad, and see what kinds of graphics program LLM generates. To allow the LLM to visually perceive the input image, we convert the image to ASCII strings.""")
397
+ gr.Markdown("## Draw logo")
398
+ with gr.Column():
399
+ canvas = gr.Sketchpad(canvas_size=(512,512), brush=Brush(colors=["black"], default_size=3, color_mode='fixed'))
400
+ submit_button = gr.Button("Submit")
401
+ output_image = gr.Image(label="output")
402
+
403
+ submit_button.click(img_to_code_img, inputs=canvas, outputs=output_image)
404
+ demo.load(
405
+ None,
406
+ None,
407
+ js="""
408
+ () => {
409
+ const params = new URLSearchParams(window.location.search);
410
+ if (!params.has('__theme')) {
411
+ params.set('__theme', 'light');
412
+ window.location.search = params.toString();
413
+ }
414
+ }""",
415
+ )
416
+
417
+ demo.launch(share=True)
418
 
419
+ if __name__ == "__main__":
420
+ # parser = argparse.ArgumentParser()
421
+ # parser.add_argument("--host", type=str, default=None)
422
+ # parser.add_argument("--port", type=int, default=8001)
423
+ # parser.add_argument("--model-url",
424
+ # type=str,
425
+ # default="http://localhost:8000/generate")
426
+ # args = parser.parse_args()
427
 
428
+ # main()
429
+ # run()
430
+
431
+ # demo = build_demo()
432
+ # demo.queue().launch(server_name=args.host,
433
+ # server_port=args.port,
434
+ # share=True)
435
+ main()