Spaces:
Running
Running
import gradio as gr | |
import os | |
import threading | |
import subprocess | |
import time | |
import re | |
from huggingface_hub import hf_hub_download | |
from handwriting_api import InputData, validate_input | |
from hand import Hand | |
# Every generated SVG/PNG is written beneath ./img, so the directory must exist
# before the first request arrives.
os.makedirs(name="img", exist_ok=True)

# Load the handwriting model once at import time; generate_handwriting uses this
# shared module-level instance for every call.
hand = Hand()
def generate_handwriting(
    text,
    style,
    bias=0.75,
    color="#000000",
    stroke_width=2,
    multiline=True,
    transparent_background=True,
):
    """Render ``text`` as handwriting with the synthesis model and return SVG markup.

    Args:
        text: Text to render. With ``multiline`` each ``"\\n"``-separated line is
            rendered as its own handwritten line; otherwise the whole string is
            treated as a single line.
        style: Style value forwarded to the model for every line (the UI offers 0-12).
        bias: Bias forwarded to the model for every line (the UI labels higher
            values as neater).
        color: Stroke colour applied to every line.
        stroke_width: Stroke width applied to every line.
        multiline: Whether to split ``text`` on newlines.
        transparent_background: If true, remove the ``fill="white"`` background
            ``<rect>`` from the SVG and persist that change to ``img/output.svg``.

    Returns:
        The SVG document as a string on success, or a string beginning with
        ``"Error: "`` describing the failure. Errors are returned rather than
        raised because the value is displayed directly in the UI.
    """
    try:
        lines = text.split("\n") if multiline else [text]
        line_count = len(lines)

        # The model takes one value per line for each of these parameters.
        stroke_colors = [color] * line_count
        stroke_widths = [stroke_width] * line_count
        biases = [bias] * line_count
        styles = [style] * line_count

        sanitized_lines = []
        for line_num, line in enumerate(lines, start=1):
            if len(line) > 75:
                return f"Error: Line {line_num} is too long (max 75 characters)"
            # Slashes and backslashes are rendered as dashes (advertised in the UI tips).
            sanitized_lines.append(line.replace("/", "-").replace("\\", "-"))

        data = InputData(
            text="\n".join(sanitized_lines),
            style=style,
            bias=bias,
            stroke_colors=stroke_colors,
            stroke_widths=stroke_widths,
        )
        try:
            validate_input(data)
        except ValueError as e:
            return f"Error: {e}"

        hand.write(
            filename="img/output.svg",
            lines=sanitized_lines,
            biases=biases,
            styles=styles,
            stroke_colors=stroke_colors,
            stroke_widths=stroke_widths,
        )

        # Use an explicit encoding instead of the platform default so reading
        # and rewriting the file behaves the same on every host.
        with open("img/output.svg", "r", encoding="utf-8") as f:
            svg_content = f.read()

        if transparent_background:
            # Strip the white background <rect> so only the strokes remain.
            # A single re.sub suffices; compare to see whether anything changed.
            stripped = re.sub(r'<rect[^>]*?fill="white"[^>]*?>', "", svg_content)
            if stripped != svg_content:
                svg_content = stripped
                with open("img/output.svg", "w", encoding="utf-8") as f:
                    f.write(svg_content)

        return svg_content
    except Exception as e:
        return f"Error: {e}"
def export_to_png(svg_content):
    """Convert SVG markup to a transparent PNG saved at ``img/output.png``.

    Steps:
        1. Remove the SVG's ``fill="white"`` background ``<rect>``, if any.
        2. Render the markup with CairoSVG at 2x scale.
        3. Use Pillow to make every near-white pixel (R, G and B all above 240)
           fully transparent.

    Args:
        svg_content: SVG markup, normally the return value of ``generate_handwriting``.

    Returns:
        ``"img/output.png"`` on success. Returns ``None`` in these cases:
        ``svg_content`` is empty, it is an ``"Error:"`` message, or the
        conversion fails (including when cairosvg/Pillow are not installed).
        Failures are printed, not raised.
    """
    try:
        # Nothing to render for empty input or an upstream error message.
        if not svg_content or svg_content.startswith("Error:"):
            return None

        # Imported lazily so the rest of the app runs without these packages.
        import io

        import cairosvg
        from PIL import Image

        # Drop any white background rectangle so it is never rasterised.
        svg_content = re.sub(r'<rect[^>]*?fill="white"[^>]*?>', "", svg_content)

        # Render entirely in memory. No intermediate files are left behind or
        # shared between concurrent requests.
        png_bytes = cairosvg.svg2png(
            bytestring=svg_content.encode("utf-8"),
            scale=2.0,
            background_color="none",  # keep the canvas transparent
        )

        # The context manager closes the decoder; convert() returns an
        # independent, fully loaded RGBA copy.
        with Image.open(io.BytesIO(png_bytes)) as rendered:
            img = rendered.convert("RGBA")

        # Second safety net: make near-white pixels fully transparent and keep
        # every other pixel unchanged.
        img.putdata([
            (255, 255, 255, 0) if r > 240 and g > 240 and b > 240 else (r, g, b, a)
            for r, g, b, a in img.getdata()
        ])
        img.save("img/output.png", "PNG")
        return "img/output.png"
    except Exception as e:
        print(f"Error converting to PNG: {e}")
        return None
def generate_lyrics_sample():
    """Return the opening four lines of the ``all_star`` lyrics as a list of strings."""
    from lyrics import all_star as song

    song_lines = song.split("\n")
    return song_lines[:4]
def generate_handwriting_wrapper(
    text,
    style,
    bias,
    color,
    stroke_width,
    multiline=True
):
    """Gradio callback: produce the SVG rendering and its PNG export together.

    Returns:
        A ``(svg, png_path)`` pair for the HTML and Image outputs. ``svg`` may
        be an ``"Error: ..."`` message. ``png_path`` is ``None`` whenever the
        PNG export produces no file, e.g. when ``svg`` is such a message.
    """
    svg_markup = generate_handwriting(
        text, style, bias, color, stroke_width, multiline
    )
    return svg_markup, export_to_png(svg_markup)
# Page-level CSS passed to gr.Blocks(css=...). The `.output-container` class is
# attached to the output components via elem_classes, and `.footer` styles the
# footer <div> rendered at the bottom of the page.
css = """
.container {max-width: 900px; margin: auto;}
.output-container {min-height: 300px;}
.gr-box {border-radius: 8px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);}
.footer {text-align: center; margin-top: 20px; font-size: 0.8em; color: #666;}
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🖋️ Handwriting Synthesis")
    gr.Markdown("Generate realistic handwritten text using neural networks.")

    with gr.Row():
        # Left column: text input, generation controls and actions.
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Text Input",
                placeholder="Enter text to convert to handwriting...",
                lines=5,
                max_lines=10,
            )
            with gr.Row():
                with gr.Column(scale=1):
                    style_select = gr.Slider(
                        minimum=0,
                        maximum=12,
                        step=1,
                        value=9,
                        label="Handwriting Style"
                    )
                with gr.Column(scale=1):
                    bias_slider = gr.Slider(
                        minimum=0.5,
                        maximum=1.0,
                        step=0.05,
                        value=0.75,
                        label="Neatness (Higher = Neater)"
                    )
            with gr.Row():
                with gr.Column(scale=1):
                    color_picker = gr.ColorPicker(
                        label="Ink Color",
                        value="#000000"
                    )
                with gr.Column(scale=1):
                    stroke_width = gr.Slider(
                        minimum=1,
                        maximum=4,
                        step=0.5,
                        value=2,
                        label="Stroke Width"
                    )
            with gr.Row():
                generate_btn = gr.Button("Generate Handwriting", variant="primary")
                clear_btn = gr.Button("Clear")
            with gr.Accordion("Examples", open=False):
                sample_btn = gr.Button("Insert Sample Text")

        # Right column: rendered outputs and download buttons.
        with gr.Column(scale=3):
            output_svg = gr.HTML(label="Generated Handwriting (SVG)", elem_classes=["output-container"])
            output_png = gr.Image(type="filepath", label="Generated Handwriting (PNG)", elem_classes=["output-container"])
            with gr.Row():
                download_svg_btn = gr.Button("Download SVG")
                download_png_btn = gr.Button("Download PNG")

    gr.Markdown("""
    ### Tips:
    - Try different styles (0-12) to get various handwriting appearances
    - Adjust the neatness slider to make writing more or less tidy
    - Each line should be 75 characters or less
    - The model works best for English text
    - Forward slashes (/) and backslashes (\\) will be replaced with dashes (-)
    - PNG output has transparency for easy integration into other documents
    """)
    gr.Markdown("""
    <div class="footer">
    Created with Gradio •
    </div>
    """)

    # File components populated by the download buttons; created here, at the
    # end of the page layout.
    svg_file = gr.File(label="Download SVG", file_count="single", file_types=[".svg"])
    png_file = gr.File(label="Download PNG", file_count="single", file_types=[".png"])

    def _write_svg_for_download(svg_content):
        """Write the SVG markup shown in ``output_svg`` to disk and return its path.

        gr.File needs a file path. ``output_svg`` holds raw SVG markup (or an
        ``"Error: ..."`` message), so the markup must be written to a file first.
        Returns ``None`` when there is nothing valid to download.
        """
        if not svg_content or svg_content.startswith("Error:"):
            return None
        path = "img/download.svg"
        with open(path, "w", encoding="utf-8") as f:
            f.write(svg_content)
        return path

    # Define interactions
    generate_btn.click(
        fn=generate_handwriting_wrapper,
        inputs=[text_input, style_select, bias_slider, color_picker, stroke_width],
        outputs=[output_svg, output_png]
    )
    # Reset the inputs to their initial widget values.
    clear_btn.click(
        fn=lambda: ("", 9, 0.75, "#000000", 2),
        inputs=None,
        outputs=[text_input, style_select, bias_slider, color_picker, stroke_width]
    )
    sample_btn.click(
        fn=lambda: ("\n".join(generate_lyrics_sample())),
        inputs=None,
        outputs=[text_input]
    )
    download_svg_btn.click(
        fn=_write_svg_for_download,
        inputs=[output_svg],
        outputs=[svg_file]
    )
    # output_png already holds a file path (type="filepath"), so pass it through.
    download_png_btn.click(
        fn=lambda x: x,
        inputs=[output_png],
        outputs=[png_file]
    )
if __name__ == "__main__":
    # A hosting platform may inject PORT; otherwise fall back to 7860.
    port = int(os.environ.get("PORT", 7860))

    # cairosvg and Pillow are only imported lazily during PNG export, so probe
    # for them here to report a missing dependency at startup.
    missing_packages = []
    try:
        import cairosvg  # noqa: F401  (presence check only)
    except ImportError:
        missing_packages += ["cairosvg"]
    try:
        from PIL import Image  # noqa: F401  (presence check only)
    except ImportError:
        missing_packages += ["pillow"]

    if not missing_packages:
        print("All required packages are installed and ready for transparent PNG export")
    else:
        print(
            "WARNING: The following packages are missing and required for "
            f"transparent PNG export: {', '.join(missing_packages)}"
        )
        print(f"Please install them using: pip install {' '.join(missing_packages)}")

    demo.launch(server_name="0.0.0.0", server_port=port)