3morrrrr commited on
Commit
d5c46c4
·
verified ·
1 Parent(s): fec890e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +260 -388
app.py CHANGED
@@ -1,437 +1,309 @@
1
- import os, re, io
2
- import xml.etree.ElementTree as ET
3
  import gradio as gr
4
-
5
- from hand import Hand # your handwriting model wrapper
6
-
7
- # ---------------------------------- Setup ------------------------------------
 
 
 
 
 
 
8
  os.makedirs("img", exist_ok=True)
9
- hand = Hand()
10
-
11
- # ------------------------------ SVG / coords ---------------------------------
12
- def _parse_viewbox(root):
13
- vb = root.get("viewBox")
14
- if vb:
15
- s = re.sub(r"[,\s]+", " ", vb.strip())
16
- parts = [p for p in s.split(" ") if p]
17
- if len(parts) == 4:
18
- try:
19
- x, y, w, h = map(float, parts)
20
- if w <= 0: w = 1200.0
21
- if h <= 0: h = 400.0
22
- return (x, y, w, h)
23
- except ValueError:
24
- pass
25
- def _num(v, d):
26
- if not v: return float(d)
27
- v = v.strip().lower().replace("px", "").replace(",", ".")
28
- try: return float(v)
29
- except: return float(d)
30
- w = _num(root.get("width"), 1200.0)
31
- h = _num(root.get("height"), 400.0)
32
- return (0.0, 0.0, w, h)
33
-
34
- def _px_to_svg_x(x_px, img_w, vb):
35
- vx, vy, vw, vh = vb
36
- if img_w <= 0: return vx
37
- return vx + (x_px / float(img_w)) * vw
38
-
39
- def _px_to_svg_y(y_px, img_h, vb):
40
- vx, vy, vw, vh = vb
41
- if img_h <= 0: return vy
42
- return vy + (y_px / float(img_h)) * vh
43
-
44
- def _extract_paths(elem):
45
- return [e for e in elem.iter() if e.tag.endswith("path")]
46
-
47
- def _translate_group(elem, dx, dy):
48
- prev = elem.get("transform", "")
49
- elem.set("transform", (prev + f" translate({dx},{dy})").strip())
50
-
51
- # ----------------------------- Tokenization ----------------------------------
52
- def _tokenize_line(line):
53
- tokens, i, n = [], 0, len(line)
54
- while i < n:
55
- ch = line[i]
56
- if ch == "_":
57
- j = i
58
- while j < n and line[j] == "_": j += 1
59
- tokens.append(("sep_underscore", line[i:j]))
60
- i = j
61
- elif ch.isspace():
62
- j = i
63
- while j < n and line[j].isspace(): j += 1
64
- tokens.append(("sep_space", line[i:j]))
65
- i = j
66
- else:
67
- j = i
68
- while j < n and (line[j] != "_" and not line[j].isspace()):
69
- j += 1
70
- tokens.append(("text", line[i:j]))
71
- i = j
72
- return tokens
73
-
74
- def _display_text_from_tokens(tokens):
75
- return "".join([v if t == "text" else " " for t, v in tokens])
76
-
77
- # ------------------------ Rasterization & analysis ---------------------------
78
- def _rasterize_svg(svg_str, scale=3.0):
79
- import cairosvg
80
- from PIL import Image
81
- png = cairosvg.svg2png(bytestring=svg_str.encode("utf-8"), scale=scale, background_color="none")
82
- return Image.open(io.BytesIO(png)).convert("RGBA")
83
-
84
- def _find_blobs_and_gaps(alpha_img):
85
- w, h = alpha_img.size
86
- bbox = alpha_img.getbbox()
87
- if not bbox:
88
- return [], [], (0, 0, w, h)
89
- left, top, right, bottom = bbox
90
-
91
- def col_has_ink(x):
92
- for y in range(top, bottom):
93
- if alpha_img.getpixel((x, y)) > 0:
94
- return True
95
- return False
96
 
97
- blobs, gaps = [], []
98
- x = left
99
- in_blob = col_has_ink(x)
100
- start = x
101
- while x < right:
102
- has = col_has_ink(x)
103
- if has != in_blob:
104
- if in_blob: blobs.append((start, x))
105
- else: gaps.append((start, x))
106
- start, in_blob = x, has
107
- x += 1
108
- if in_blob: blobs.append((start, right))
109
- else: gaps.append((start, right))
110
-
111
- core_gaps = [(blobs[i][1], blobs[i + 1][0]) for i in range(len(blobs) - 1)]
112
- return blobs, core_gaps, (left, top, right, bottom)
113
-
114
- def _column_profile(alpha_img, top, bottom):
115
- w, h = alpha_img.size
116
- prof = [0] * w
117
- for x in range(w):
118
- s = 0
119
- for y in range(top, bottom):
120
- if alpha_img.getpixel((x, y)) > 0:
121
- s += 1
122
- prof[x] = s
123
- return prof
124
-
125
- def _synthesize_gap_near(alpha_img, content_bbox, target_x_px, min_w_px=14, search_pct=0.18):
126
- left, top, right, bottom = content_bbox
127
- prof = _column_profile(alpha_img, top, bottom)
128
- half = max(6, int((right - left) * search_pct))
129
- x0 = max(left + 1, int(target_x_px - half))
130
- x1 = min(right - 1, int(target_x_px + half))
131
- best_x, best_v = x0, prof[x0]
132
- for x in range(x0, x1):
133
- v = prof[x]
134
- if v < best_v:
135
- best_x, best_v = x, v
136
- half_w = max(7, min_w_px // 2)
137
- g0 = max(left + 1, best_x - half_w)
138
- g1 = min(right - 1, best_x + half_w)
139
- if g1 - g0 < min_w_px:
140
- pad = (min_w_px - (g1 - g0)) // 2 + 1
141
- g0 = max(left + 1, g0 - pad)
142
- g1 = min(right - 1, g1 + pad)
143
- return (g0, g1)
144
-
145
- def _draw_underscores_in_gap(root, gap_px, baseline_px, img_w, img_h, vb,
146
- color, stroke_width, n, between_px=3.0,
147
- min_len=10.0, max_len=48.0, frac_of_gap=0.60):
148
- gap_w = max(0.0, gap_px[1] - gap_px[0])
149
- if gap_w <= 2.0 or n <= 0:
150
- return
151
- u_len = max(min_len, min(max_len, frac_of_gap * gap_w))
152
- total_needed = n * u_len + (n - 1) * between_px
153
- if total_needed > gap_w:
154
- scale = gap_w / max(1.0, total_needed)
155
- u_len *= scale
156
- block_w = n * u_len + (n - 1) * between_px
157
- x0_px = gap_px[0] + (gap_w - block_w) / 2.0 # center in gap
158
- for i in range(n):
159
- xs = x0_px + i * (u_len + between_px)
160
- xe = xs + u_len
161
- x0 = _px_to_svg_x(xs, img_w, vb)
162
- x1 = _px_to_svg_x(xe, img_w, vb)
163
- y = _px_to_svg_y(baseline_px, img_h, vb)
164
- p = ET.Element("path")
165
- p.set("d", f"M{x0},{y} L{x1},{y}")
166
- p.set("stroke", color)
167
- p.set("stroke-width", str(max(1, stroke_width)))
168
- p.set("fill", "none")
169
- p.set("stroke-linecap", "round")
170
- root.append(p)
171
-
172
- def _draw_margin_underscores(root, edge_px, side, baseline_px, img_w, img_h, vb,
173
- color, stroke_width, n, between_px=4.0, len_px=22.0, pad_px=6.0):
174
- if n <= 0: return
175
- for i in range(n):
176
- if side == "left":
177
- x1_px = edge_px - pad_px - i * (len_px + between_px)
178
- x0_px = x1_px - len_px
179
- else:
180
- x0_px = edge_px + pad_px + i * (len_px + between_px)
181
- x1_px = x0_px + len_px
182
- x0 = _px_to_svg_x(x0_px, img_w, vb)
183
- x1 = _px_to_svg_x(x1_px, img_w, vb)
184
- y = _px_to_svg_y(baseline_px, img_h, vb)
185
- p = ET.Element("path")
186
- p.set("d", f"M{x0},{y} L{x1},{y}")
187
- p.set("stroke", color)
188
- p.set("stroke-width", str(max(1, stroke_width)))
189
- p.set("fill", "none")
190
- p.set("stroke-linecap", "round")
191
- root.append(p)
192
-
193
- # ----------------------- Core: one line w/ underscores -----------------------
194
- def render_line_svg_with_underscores(line, style, bias, color, stroke_width,
195
- force_in_largest_gap=False):
196
- # Tokenize original; render with underscores->spaces
197
- tokens = _tokenize_line(line)
198
- display_line = _display_text_from_tokens(tokens).replace("/", "-").replace("\\", "-")
199
-
200
- # Render the full line once (keeps model spacing intact)
201
- hand.write(
202
- filename="img/line.tmp.svg",
203
- lines=[display_line if display_line.strip() else " "],
204
- biases=[bias], styles=[style],
205
- stroke_colors=[color], stroke_widths=[stroke_width]
206
- )
207
- root = ET.parse("img/line.tmp.svg").getroot()
208
- vb = _parse_viewbox(root)
209
-
210
- # Put model paths into a group we'll augment
211
- g = ET.Element("g")
212
- for p in _extract_paths(root):
213
- g.append(p)
214
-
215
- # Rasterize & analyze
216
- img = _rasterize_svg(ET.tostring(root, encoding="unicode"), scale=3.0)
217
- alpha = img.split()[-1]
218
- blobs, gaps, content_bbox = _find_blobs_and_gaps(alpha)
219
- img_w, img_h = img.size
220
- left, top, right, bottom = content_bbox
221
- line_h_px = max(20, bottom - top)
222
- baseline_px = bottom - int(0.18 * line_h_px)
223
-
224
- # Count leading/trailing underscores
225
- leading_us = 0
226
- i = 0
227
- while i < len(tokens) and tokens[i][0] != "text":
228
- if tokens[i][0] == "sep_underscore":
229
- leading_us += len(tokens[i][1])
230
- i += 1
231
- trailing_us = 0
232
- j = len(tokens) - 1
233
- while j >= 0 and tokens[j][0] != "text":
234
- if tokens[j][0] == "sep_underscore":
235
- trailing_us += len(tokens[j][1])
236
- j -= 1
237
 
238
- # Build clusters (positions between text runs) + underscore counts
239
- clusters = []
240
- i = 0
241
- while i < len(tokens):
242
- t, v = tokens[i]
243
- if t == "text" and len(v) > 0:
244
- j = i + 1
245
- u_count, saw_sep = 0, False
246
- while j < len(tokens) and tokens[j][0].startswith("sep_"):
247
- saw_sep = True
248
- if tokens[j][0] == "sep_underscore":
249
- u_count += len(tokens[j][1])
250
- j += 1
251
- if saw_sep:
252
- clusters.append({"underscores": u_count})
253
- i = j
254
  else:
255
- i += 1
256
-
257
- words = [v for (t, v) in tokens if t == "text" and len(v) > 0]
258
- gaps_sorted_by_width = sorted(gaps, key=lambda ab: (ab[1] - ab[0]), reverse=True)
259
- used_gaps = [False] * len(gaps)
260
-
261
- def _take_gap_for_idx(idx):
262
- # 1) same index gap if available
263
- if 0 <= idx < len(gaps) and not used_gaps[idx]:
264
- used_gaps[idx] = True; return gaps[idx]
265
- # 2) widest unused real gap
266
- for gp in gaps_sorted_by_width:
267
- try:
268
- real_idx = gaps.index(gp)
269
- except ValueError:
270
- continue
271
- if not used_gaps[real_idx]:
272
- used_gaps[real_idx] = True; return gp
273
- # 3) synthesize near expected split
274
- seen_chars = sum(len(w) for w in words[:idx + 1]) if words else 1
275
- total_chars = sum(len(w) for w in words) or 1
276
- ratio = min(0.95, max(0.05, seen_chars / float(total_chars)))
277
- target_x_px = left + ratio * (right - left)
278
- return _synthesize_gap_near(alpha, content_bbox, target_x_px)
279
-
280
- # Draw interior underscores
281
- drew_any = False
282
- for idx, cluster in enumerate(clusters):
283
- n = int(cluster["underscores"])
284
- if n <= 0:
285
- continue
286
- gap_px = _take_gap_for_idx(idx)
287
- _draw_underscores_in_gap(
288
- g, gap_px, baseline_px, img_w, img_h, vb,
289
- color, stroke_width, n,
290
- between_px=3.0, min_len=10.0, max_len=48.0, frac_of_gap=0.60
291
  )
292
- drew_any = True
293
-
294
- # Leading/trailing
295
- if blobs:
296
- first_left = blobs[0][0]
297
- last_right = blobs[-1][1]
298
- if leading_us:
299
- _draw_margin_underscores(
300
- g, first_left, "left", baseline_px, img_w, img_h, vb,
301
- color, stroke_width, leading_us
302
- )
303
- drew_any = True
304
- if trailing_us:
305
- _draw_margin_underscores(
306
- g, last_right, "right", baseline_px, img_w, img_h, vb,
307
- color, stroke_width, trailing_us
308
- )
309
- drew_any = True
310
-
311
- # Optional: force-draw one underscore in widest gap if user asked and none drawn
312
- if force_in_largest_gap and not drew_any and gaps:
313
- widest = max(gaps, key=lambda ab: (ab[1] - ab[0]))
314
- _draw_underscores_in_gap(
315
- g, widest, baseline_px, img_w, img_h, vb,
316
- color, stroke_width, n=1,
317
- between_px=0.0, min_len=12.0, max_len=48.0, frac_of_gap=0.70
318
  )
319
-
320
- width_estimate = right - left if right > left else img_w // 2
321
- return g, width_estimate
322
-
323
- # -------------------------- Multi-line compositor ----------------------------
324
- def generate_handwriting(text, style, bias=0.75, color="#000000", stroke_width=2,
325
- multiline=True, force=False):
326
- try:
327
- lines = text.split("\n") if multiline else [text]
328
- for idx, ln in enumerate(lines):
329
- if len(ln) > 75:
330
- return f"Error: Line {idx+1} is too long (max 75 characters)"
331
-
332
- svg_root = ET.Element("svg", {"xmlns": "http://www.w3.org/2000/svg", "viewBox": "0 0 1200 800"})
333
- y0, line_gap, max_right = 80.0, 110.0, 0.0
334
-
335
- for i, line in enumerate(lines):
336
- g, w = render_line_svg_with_underscores(
337
- line.replace("/", "-").replace("\\", "-"),
338
- style, bias, color, stroke_width,
339
- force_in_largest_gap=force
340
- )
341
- _translate_group(g, dx=40, dy=y0 + i * line_gap)
342
- svg_root.append(g)
343
- max_right = max(max_right, 40 + w)
344
-
345
- height = int(y0 + len(lines) * line_gap + 80)
346
- width = max(300, int(max_right + 40))
347
- svg_root.set("viewBox", f"0 0 {width} {height}")
348
-
349
- svg_content = ET.tostring(svg_root, encoding="unicode")
350
- with open("img/output.svg", "w", encoding="utf-8") as f: f.write(svg_content)
351
  return svg_content
352
  except Exception as e:
353
  return f"Error: {str(e)}"
354
 
355
- # ------------------------------- PNG export ----------------------------------
356
  def export_to_png(svg_content):
 
357
  try:
358
  import cairosvg
359
  from PIL import Image
 
360
  if not svg_content or svg_content.startswith("Error:"):
361
  return None
362
- tmp_svg = "img/temp.svg"
363
- with open(tmp_svg, "w", encoding="utf-8") as f: f.write(svg_content)
364
- cairosvg.svg2png(url=tmp_svg, write_to="img/output_temp.png", scale=2.2, background_color="none")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
  img = Image.open("img/output_temp.png")
366
- if img.mode != "RGBA": img = img.convert("RGBA")
367
- data = img.getdata()
368
- img.putdata([(255,255,255,0) if r>240 and g>240 and b>240 else (r,g,b,a) for (r,g,b,a) in data])
369
- out_path = "img/output.png"
370
- img.save(out_path, "PNG")
371
- try: os.remove("img/output_temp.png")
372
- except: pass
373
- return out_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
  except Exception as e:
375
  print(f"Error converting to PNG: {str(e)}")
376
  return None
377
 
378
- # --------------------------------- UI ----------------------------------------
379
- def generate_handwriting_wrapper(text, style, bias, color, stroke_width, force):
380
- svg = generate_handwriting(text, style, bias, color, stroke_width, multiline=True, force=force)
381
- png = export_to_png(svg)
382
- return svg, png, "img/output.svg"
 
 
 
 
 
 
 
 
 
 
 
383
 
384
  css = """
385
  .container {max-width: 900px; margin: auto;}
386
  .output-container {min-height: 300px;}
 
 
387
  """
388
 
389
  with gr.Blocks(css=css) as demo:
390
- gr.Markdown("# 🖋️ Handwriting Synthesis — Underscores in Real Gaps")
391
- gr.Markdown("Spacing is the model’s own. We paint `_` *inside* the actual gaps. If needed, toggle force to drop one underscore into the widest gap.")
392
-
393
  with gr.Row():
394
  with gr.Column(scale=2):
395
  text_input = gr.Textbox(
396
  label="Text Input",
397
- placeholder="Try: user_name, zeb_3asba, long__name, __init__",
398
- lines=5
 
399
  )
 
400
  with gr.Row():
401
- style_select = gr.Slider(0, 12, step=1, value=9, label="Handwriting Style")
402
- bias_slider = gr.Slider(0.5, 1.0, step=0.05, value=0.75, label="Neatness (Higher = Neater)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  with gr.Row():
404
- color_picker = gr.ColorPicker(label="Ink Color", value="#000000")
405
- stroke_width = gr.Slider(1, 4, step=0.5, value=2, label="Stroke Width")
406
- force_toggle = gr.Checkbox(label="Force underscore in largest gap if none drawn", value=False)
407
- generate_btn = gr.Button("Generate Handwriting", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  with gr.Column(scale=3):
409
  output_svg = gr.HTML(label="Generated Handwriting (SVG)", elem_classes=["output-container"])
410
  output_png = gr.Image(type="filepath", label="Generated Handwriting (PNG)", elem_classes=["output-container"])
411
- download_svg_file = gr.File(label="Download SVG")
412
- download_png_file = gr.File(label="Download PNG")
413
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  generate_btn.click(
415
  fn=generate_handwriting_wrapper,
416
- inputs=[text_input, style_select, bias_slider, color_picker, stroke_width, force_toggle],
417
- outputs=[output_svg, output_png, download_svg_file]
418
- ).then(
419
- fn=lambda p: p,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  inputs=[output_png],
421
- outputs=[download_png_file]
422
  )
423
 
424
- # -------------------------------- main ---------------------------------------
425
  if __name__ == "__main__":
426
- missing = []
 
 
 
 
427
  try:
428
- import cairosvg # noqa
429
  except ImportError:
430
- missing.append("cairosvg")
 
431
  try:
432
- from PIL import Image # noqa
433
  except ImportError:
434
- missing.append("pillow")
435
- if missing:
436
- print("Install:", " ".join(missing))
437
- demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import os
3
+ import threading
4
+ import subprocess
5
+ import time
6
+ import re
7
+ from huggingface_hub import hf_hub_download
8
+ from handwriting_api import InputData, validate_input
9
+ from hand import Hand
10
+
11
+ # Create img directory if it doesn't exist
12
  os.makedirs("img", exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ # Initialize the handwriting model
15
+ hand = Hand()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ # Create a function to generate handwriting
18
+ def generate_handwriting(
19
+ text,
20
+ style,
21
+ bias=0.75,
22
+ color="#000000",
23
+ stroke_width=2,
24
+ multiline=True,
25
+ transparent_background=True
26
+ ):
27
+ """Generate handwritten text using the model"""
28
+ try:
29
+ # Process the text
30
+ if multiline:
31
+ lines = text.split('\n')
 
32
  else:
33
+ lines = [text]
34
+
35
+ # Create arrays for parameters
36
+ stroke_colors = [color] * len(lines)
37
+ stroke_widths = [stroke_width] * len(lines)
38
+ biases = [bias] * len(lines)
39
+ styles = [style] * len(lines)
40
+
41
+ # Process each line to replace slashes with dashes
42
+ sanitized_lines = []
43
+ for line_num, line in enumerate(lines):
44
+ if len(line) > 75:
45
+ return f"Error: Line {line_num+1} is too long (max 75 characters)"
46
+
47
+ # Replace slashes with dashes
48
+ sanitized_line = line.replace('/', '-').replace('\\', '-')
49
+ sanitized_lines.append(sanitized_line)
50
+
51
+ data = InputData(
52
+ text='\n'.join(sanitized_lines),
53
+ style=style,
54
+ bias=bias,
55
+ stroke_colors=stroke_colors,
56
+ stroke_widths=stroke_widths
 
 
 
 
 
 
 
 
 
 
 
 
57
  )
58
+
59
+ try:
60
+ validate_input(data)
61
+ except ValueError as e:
62
+ return f"Error: {str(e)}"
63
+
64
+ # Generate the handwriting with sanitized lines
65
+ hand.write(
66
+ filename='img/output.svg',
67
+ lines=sanitized_lines,
68
+ biases=biases,
69
+ styles=styles,
70
+ stroke_colors=stroke_colors,
71
+ stroke_widths=stroke_widths
 
 
 
 
 
 
 
 
 
 
 
 
72
  )
73
+
74
+ # Read the generated SVG
75
+ with open("img/output.svg", "r") as f:
76
+ svg_content = f.read()
77
+
78
+ # If transparent background is requested, modify the SVG
79
+ if transparent_background:
80
+ # Remove the background rectangle or make it transparent
81
+ pattern = r'<rect[^>]*?fill="white"[^>]*?>'
82
+ if re.search(pattern, svg_content):
83
+ svg_content = re.sub(pattern, '', svg_content)
84
+
85
+ # Write the modified SVG back
86
+ with open("img/output.svg", "w") as f:
87
+ f.write(svg_content)
88
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  return svg_content
90
  except Exception as e:
91
  return f"Error: {str(e)}"
92
 
 
93
  def export_to_png(svg_content):
94
+ """Convert SVG to transparent PNG using CairoSVG and Pillow for robust transparency"""
95
  try:
96
  import cairosvg
97
  from PIL import Image
98
+
99
  if not svg_content or svg_content.startswith("Error:"):
100
  return None
101
+
102
+ # Modify the SVG to ensure the background is transparent
103
+ # Remove any white background rectangle
104
+ pattern = r'<rect[^>]*?fill="white"[^>]*?>'
105
+ if re.search(pattern, svg_content):
106
+ svg_content = re.sub(pattern, '', svg_content)
107
+
108
+ # Save the modified SVG to a temporary file
109
+ with open("img/temp.svg", "w") as f:
110
+ f.write(svg_content)
111
+
112
+ # Convert SVG to PNG with transparency using CairoSVG
113
+ cairosvg.svg2png(
114
+ url="img/temp.svg",
115
+ write_to="img/output_temp.png",
116
+ scale=2.0,
117
+ background_color="none" # This ensures transparency
118
+ )
119
+
120
+ # Additional processing with Pillow to ensure transparency
121
  img = Image.open("img/output_temp.png")
122
+
123
+ # Convert to RGBA if not already
124
+ if img.mode != 'RGBA':
125
+ img = img.convert('RGBA')
126
+
127
+ # Create a transparent canvas
128
+ transparent_img = Image.new('RGBA', img.size, (0, 0, 0, 0))
129
+
130
+ # Process the image data to ensure white is transparent
131
+ datas = img.getdata()
132
+ new_data = []
133
+
134
+ for item in datas:
135
+ # If pixel is white or near-white, make it transparent
136
+ if item[0] > 240 and item[1] > 240 and item[2] > 240:
137
+ new_data.append((255, 255, 255, 0)) # Transparent
138
+ else:
139
+ new_data.append(item) # Keep original color
140
+
141
+ transparent_img.putdata(new_data)
142
+ transparent_img.save("img/output.png", "PNG")
143
+
144
+ # Clean up the temporary file
145
+ try:
146
+ os.remove("img/output_temp.png")
147
+ except:
148
+ pass
149
+
150
+ return "img/output.png"
151
  except Exception as e:
152
  print(f"Error converting to PNG: {str(e)}")
153
  return None
154
 
155
+ def generate_lyrics_sample():
156
+ """Generate a sample using lyrics"""
157
+ from lyrics import all_star
158
+ return all_star.split("\n")[0:4]
159
+
160
+ def generate_handwriting_wrapper(
161
+ text,
162
+ style,
163
+ bias,
164
+ color,
165
+ stroke_width,
166
+ multiline=True
167
+ ):
168
+ svg = generate_handwriting(text, style, bias, color, stroke_width, multiline)
169
+ png_path = export_to_png(svg)
170
+ return svg, png_path
171
 
172
  css = """
173
  .container {max-width: 900px; margin: auto;}
174
  .output-container {min-height: 300px;}
175
+ .gr-box {border-radius: 8px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);}
176
+ .footer {text-align: center; margin-top: 20px; font-size: 0.8em; color: #666;}
177
  """
178
 
179
  with gr.Blocks(css=css) as demo:
180
+ gr.Markdown("# 🖋️ Handwriting Synthesis")
181
+ gr.Markdown("Generate realistic handwritten text using neural networks.")
182
+
183
  with gr.Row():
184
  with gr.Column(scale=2):
185
  text_input = gr.Textbox(
186
  label="Text Input",
187
+ placeholder="Enter text to convert to handwriting...",
188
+ lines=5,
189
+ max_lines=10,
190
  )
191
+
192
  with gr.Row():
193
+ with gr.Column(scale=1):
194
+ style_select = gr.Slider(
195
+ minimum=0,
196
+ maximum=12,
197
+ step=1,
198
+ value=9,
199
+ label="Handwriting Style"
200
+ )
201
+ with gr.Column(scale=1):
202
+ bias_slider = gr.Slider(
203
+ minimum=0.5,
204
+ maximum=1.0,
205
+ step=0.05,
206
+ value=0.75,
207
+ label="Neatness (Higher = Neater)"
208
+ )
209
+
210
  with gr.Row():
211
+ with gr.Column(scale=1):
212
+ color_picker = gr.ColorPicker(
213
+ label="Ink Color",
214
+ value="#000000"
215
+ )
216
+ with gr.Column(scale=1):
217
+ stroke_width = gr.Slider(
218
+ minimum=1,
219
+ maximum=4,
220
+ step=0.5,
221
+ value=2,
222
+ label="Stroke Width"
223
+ )
224
+
225
+ with gr.Row():
226
+ generate_btn = gr.Button("Generate Handwriting", variant="primary")
227
+ clear_btn = gr.Button("Clear")
228
+
229
+ with gr.Accordion("Examples", open=False):
230
+ sample_btn = gr.Button("Insert Sample Text")
231
+
232
  with gr.Column(scale=3):
233
  output_svg = gr.HTML(label="Generated Handwriting (SVG)", elem_classes=["output-container"])
234
  output_png = gr.Image(type="filepath", label="Generated Handwriting (PNG)", elem_classes=["output-container"])
235
+
236
+ with gr.Row():
237
+ download_svg_btn = gr.Button("Download SVG")
238
+ download_png_btn = gr.Button("Download PNG")
239
+
240
+ gr.Markdown("""
241
+ ### Tips:
242
+ - Try different styles (0-12) to get various handwriting appearances
243
+ - Adjust the neatness slider to make writing more or less tidy
244
+ - Each line should be 75 characters or less
245
+ - The model works best for English text
246
+ - Forward slashes (/) and backslashes (\\) will be replaced with dashes (-)
247
+ - PNG output has transparency for easy integration into other documents
248
+ """)
249
+
250
+ gr.Markdown("""
251
+ <div class="footer">
252
+ Created with Gradio •
253
+ </div>
254
+ """)
255
+
256
+ # Define interactions
257
  generate_btn.click(
258
  fn=generate_handwriting_wrapper,
259
+ inputs=[text_input, style_select, bias_slider, color_picker, stroke_width],
260
+ outputs=[output_svg, output_png]
261
+ )
262
+
263
+ clear_btn.click(
264
+ fn=lambda: ("", 9, 0.75, "#000000", 2),
265
+ inputs=None,
266
+ outputs=[text_input, style_select, bias_slider, color_picker, stroke_width]
267
+ )
268
+
269
+ sample_btn.click(
270
+ fn=lambda: ("\n".join(generate_lyrics_sample())),
271
+ inputs=None,
272
+ outputs=[text_input]
273
+ )
274
+
275
+ download_svg_btn.click(
276
+ fn=lambda x: x,
277
+ inputs=[output_svg],
278
+ outputs=[gr.File(label="Download SVG", file_count="single", file_types=[".svg"])]
279
+ )
280
+
281
+ download_png_btn.click(
282
+ fn=lambda x: x,
283
  inputs=[output_png],
284
+ outputs=[gr.File(label="Download PNG", file_count="single", file_types=[".png"])]
285
  )
286
 
 
287
  if __name__ == "__main__":
288
+ # Set port based on environment variable or default to 7860
289
+ port = int(os.environ.get("PORT", 7860))
290
+
291
+ # Check if required packages are installed
292
+ missing_packages = []
293
  try:
294
+ import cairosvg
295
  except ImportError:
296
+ missing_packages.append("cairosvg")
297
+
298
  try:
299
+ from PIL import Image
300
  except ImportError:
301
+ missing_packages.append("pillow")
302
+
303
+ if missing_packages:
304
+ print(f"WARNING: The following packages are missing and required for transparent PNG export: {', '.join(missing_packages)}")
305
+ print("Please install them using: pip install " + " ".join(missing_packages))
306
+ else:
307
+ print("All required packages are installed and ready for transparent PNG export")
308
+
309
+ demo.launch(server_name="0.0.0.0", server_port=port)