# app.py - uploaded by 3morrrrr
# Commit message: "Update app.py"
# Commit: fdfe94f (verified)
import gradio as gr
import os
import threading
import subprocess
import time
import re
from huggingface_hub import hf_hub_download
from handwriting_api import InputData, validate_input
from hand import Hand
# Make sure the "img" output directory exists; exist_ok=True makes this a
# no-op when the directory is already there.
os.makedirs("img", exist_ok=True)
# Module-level Hand instance shared by every request; generate_handwriting()
# calls its write() method to produce the SVG.
hand = Hand()
# Create a function to generate handwriting
def generate_handwriting(
    text,
    style,
    bias=0.75,
    color="#000000",
    stroke_width=2,
    multiline=True,
    transparent_background=True
):
    """Generate handwritten text with the module-level ``hand`` model.

    The SVG is written to ``img/output.svg`` and its markup is returned.
    This function never raises: every failure is reported as a string that
    starts with ``"Error: "``, which callers use to detect failure.

    Args:
        text: Text to render. Split on newlines when ``multiline`` is true.
        style: Style value applied to every line.
        bias: Bias value applied to every line.
        color: Stroke color applied to every line.
        stroke_width: Stroke width applied to every line.
        multiline: When false, the whole text is treated as a single line.
        transparent_background: When true, any ``<rect ... fill="white" ...>``
            tag is stripped from the generated SVG.

    Returns:
        The SVG markup as a string, or an ``"Error: ..."`` message.
    """
    svg_path = "img/output.svg"
    try:
        # Give a readable message instead of leaking an AttributeError text.
        if text is None:
            return "Error: No text provided"

        lines = text.split('\n') if multiline else [text]

        # Validate every line before doing any work.
        for line_num, line in enumerate(lines, start=1):
            if len(line) > 75:
                return f"Error: Line {line_num} is too long (max 75 characters)"

        # Forward and back slashes are replaced with dashes before rendering.
        sanitized_lines = [
            line.replace('/', '-').replace('\\', '-') for line in lines
        ]

        # One entry per line for each per-line parameter.
        line_count = len(sanitized_lines)
        stroke_colors = [color] * line_count
        stroke_widths = [stroke_width] * line_count
        biases = [bias] * line_count
        styles = [style] * line_count

        data = InputData(
            text='\n'.join(sanitized_lines),
            style=style,
            bias=bias,
            stroke_colors=stroke_colors,
            stroke_widths=stroke_widths
        )
        try:
            validate_input(data)
        except ValueError as e:
            return f"Error: {str(e)}"

        # Generate the handwriting with sanitized lines
        hand.write(
            filename=svg_path,
            lines=sanitized_lines,
            biases=biases,
            styles=styles,
            stroke_colors=stroke_colors,
            stroke_widths=stroke_widths
        )

        # Read explicitly as UTF-8 rather than relying on the locale default.
        with open(svg_path, "r", encoding="utf-8") as f:
            svg_content = f.read()

        if transparent_background:
            # Drop the white background rectangle, if there is one.
            svg_content, removed = re.subn(
                r'<rect[^>]*?fill="white"[^>]*?>', '', svg_content
            )
            # Only rewrite the file when something actually changed.
            if removed:
                with open(svg_path, "w", encoding="utf-8") as f:
                    f.write(svg_content)

        return svg_content
    except Exception as e:
        return f"Error: {str(e)}"
def export_to_png(svg_content):
    """Convert SVG markup to a transparent PNG using CairoSVG and Pillow.

    The SVG (with any ``<rect ... fill="white" ...>`` tag removed) is written
    to ``img/temp.svg``, rasterised at 2x scale to ``img/output_temp.png``,
    and post-processed so that near-white pixels (R, G and B all above 240)
    become fully transparent. The result is saved as ``img/output.png``.

    Args:
        svg_content: SVG markup. Empty values and strings starting with
            ``"Error:"`` are treated as "nothing to convert".

    Returns:
        The path ``"img/output.png"`` on success, otherwise ``None``.
        Failures are printed rather than raised.
    """
    try:
        # Check the input first so an unusable value never touches the
        # optional third-party imports below.
        if not svg_content or svg_content.startswith("Error:"):
            return None

        import cairosvg
        from PIL import Image

        # Remove any white background rectangle so the render is transparent.
        svg_content = re.sub(r'<rect[^>]*?fill="white"[^>]*?>', '', svg_content)

        # Save the modified SVG to a temporary file (explicit UTF-8).
        with open("img/temp.svg", "w", encoding="utf-8") as f:
            f.write(svg_content)

        # Convert SVG to PNG with transparency using CairoSVG
        cairosvg.svg2png(
            url="img/temp.svg",
            write_to="img/output_temp.png",
            scale=2.0,
            background_color="none"  # This ensures transparency
        )

        # Open inside a context manager so the file handle is released
        # before the temporary PNG is deleted below.
        with Image.open("img/output_temp.png") as raw:
            img = raw if raw.mode == 'RGBA' else raw.convert('RGBA')
            size = img.size
            # White or near-white pixels become transparent; others are kept.
            pixels = [
                (255, 255, 255, 0)
                if px[0] > 240 and px[1] > 240 and px[2] > 240
                else px
                for px in img.getdata()
            ]

        transparent_img = Image.new('RGBA', size, (0, 0, 0, 0))
        transparent_img.putdata(pixels)
        transparent_img.save("img/output.png", "PNG")

        # Best-effort cleanup: a leftover temp file is harmless, so a failed
        # delete must not fail the export.
        try:
            os.remove("img/output_temp.png")
        except OSError:
            pass

        return "img/output.png"
    except Exception as e:
        print(f"Error converting to PNG: {str(e)}")
        return None
def generate_lyrics_sample():
    """Return the first four newline-separated pieces of ``lyrics.all_star``."""
    # Imported lazily, so the lyrics module is only loaded when a sample is
    # actually requested.
    from lyrics import all_star

    sample_lines = all_star.split("\n")
    return sample_lines[:4]
def generate_handwriting_wrapper(
    text,
    style,
    bias,
    color,
    stroke_width,
    multiline=True
):
    """Run SVG generation followed by PNG export for the UI.

    Returns:
        A ``(svg, png_path)`` tuple: the value returned by
        ``generate_handwriting`` and the value returned by ``export_to_png``
        for that same SVG.
    """
    svg_markup = generate_handwriting(
        text, style, bias, color, stroke_width, multiline
    )
    return svg_markup, export_to_png(svg_markup)
# Custom stylesheet passed to gr.Blocks(css=...) below. The "output-container"
# class is attached to the SVG and PNG output components via elem_classes, and
# "footer" is used by the footer <div> in the page markdown.
css = """
.container {max-width: 900px; margin: auto;}
.output-container {min-height: 300px;}
.gr-box {border-radius: 8px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);}
.footer {text-align: center; margin-top: 20px; font-size: 0.8em; color: #666;}
"""
# Build the Gradio UI: input controls in the left column, SVG/PNG outputs in
# the right column, followed by usage tips, a footer and the event wiring.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🖋️ Handwriting Synthesis")
    gr.Markdown("Generate realistic handwritten text using neural networks.")

    with gr.Row():
        # Left column: text input and generation parameters.
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Text Input",
                placeholder="Enter text to convert to handwriting...",
                lines=5,
                max_lines=10,
            )
            with gr.Row():
                with gr.Column(scale=1):
                    style_select = gr.Slider(
                        minimum=0,
                        maximum=12,
                        step=1,
                        value=9,
                        label="Handwriting Style"
                    )
                with gr.Column(scale=1):
                    bias_slider = gr.Slider(
                        minimum=0.5,
                        maximum=1.0,
                        step=0.05,
                        value=0.75,
                        label="Neatness (Higher = Neater)"
                    )
            with gr.Row():
                with gr.Column(scale=1):
                    color_picker = gr.ColorPicker(
                        label="Ink Color",
                        value="#000000"
                    )
                with gr.Column(scale=1):
                    stroke_width = gr.Slider(
                        minimum=1,
                        maximum=4,
                        step=0.5,
                        value=2,
                        label="Stroke Width"
                    )
            with gr.Row():
                generate_btn = gr.Button("Generate Handwriting", variant="primary")
                clear_btn = gr.Button("Clear")
            with gr.Accordion("Examples", open=False):
                sample_btn = gr.Button("Insert Sample Text")

        # Right column: rendered outputs and download buttons.
        with gr.Column(scale=3):
            output_svg = gr.HTML(label="Generated Handwriting (SVG)", elem_classes=["output-container"])
            output_png = gr.Image(type="filepath", label="Generated Handwriting (PNG)", elem_classes=["output-container"])
            with gr.Row():
                download_svg_btn = gr.Button("Download SVG")
                download_png_btn = gr.Button("Download PNG")

    # The markdown bodies below are kept flush-left so the rendered text is
    # exactly what the string literal contains.
    gr.Markdown("""
### Tips:
- Try different styles (0-12) to get various handwriting appearances
- Adjust the neatness slider to make writing more or less tidy
- Each line should be 75 characters or less
- The model works best for English text
- Forward slashes (/) and backslashes (\\) will be replaced with dashes (-)
- PNG output has transparency for easy integration into other documents
""")
    gr.Markdown("""
<div class="footer">
Created with Gradio •
</div>
""")

    # Define interactions
    generate_btn.click(
        fn=generate_handwriting_wrapper,
        inputs=[text_input, style_select, bias_slider, color_picker, stroke_width],
        outputs=[output_svg, output_png]
    )
    # Reset every input control to its initial value.
    clear_btn.click(
        fn=lambda: ("", 9, 0.75, "#000000", 2),
        inputs=None,
        outputs=[text_input, style_select, bias_slider, color_picker, stroke_width]
    )
    sample_btn.click(
        fn=lambda: ("\n".join(generate_lyrics_sample())),
        inputs=None,
        outputs=[text_input]
    )
    # output_svg holds SVG *markup*, but a gr.File output needs a file path.
    # Hand back the file that generate_handwriting() wrote, and nothing at
    # all when there is no successful result to download.
    download_svg_btn.click(
        fn=lambda svg: (
            "img/output.svg"
            if svg and not svg.startswith("Error:") and os.path.exists("img/output.svg")
            else None
        ),
        inputs=[output_svg],
        outputs=[gr.File(label="Download SVG", file_count="single", file_types=[".svg"])]
    )
    # output_png is a filepath-typed Image, so its value can be forwarded as is.
    download_png_btn.click(
        fn=lambda x: x,
        inputs=[output_png],
        outputs=[gr.File(label="Download PNG", file_count="single", file_types=[".png"])]
    )
if __name__ == "__main__":
    # Use the PORT environment variable when present, otherwise 7860.
    port = int(os.environ.get("PORT", 7860))

    # Probe the PNG-export dependencies as (importable module, pip package
    # name) pairs and collect the pip names of those that fail to import.
    missing_packages = []
    for module_name, pip_name in (("cairosvg", "cairosvg"), ("PIL.Image", "pillow")):
        try:
            __import__(module_name)
        except ImportError:
            missing_packages.append(pip_name)

    if not missing_packages:
        print("All required packages are installed and ready for transparent PNG export")
    else:
        print(f"WARNING: The following packages are missing and required for transparent PNG export: {', '.join(missing_packages)}")
        print("Please install them using: pip install " + " ".join(missing_packages))

    # Listen on all interfaces.
    demo.launch(server_name="0.0.0.0", server_port=port)