Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import threading | |
| import subprocess | |
| import time | |
| import re | |
| from huggingface_hub import hf_hub_download | |
| from handwriting_api import InputData, validate_input | |
| from hand import Hand | |
# Generated SVG/PNG files are written under "img/"; make sure the directory
# exists before any handler tries to write there.
os.makedirs(name="img", exist_ok=True)

# One model instance, created at import time and shared by every call to
# generate_handwriting (which uses hand.write).
hand = Hand()
# Create a function to generate handwriting
def generate_handwriting(
    text,
    style,
    bias=0.75,
    color="#000000",
    stroke_width=2,
    multiline=True,
    transparent_background=True
):
    """Generate handwritten text using the model.

    Every line is rendered with the same ``style``, ``bias``, ``color`` and
    ``stroke_width``. ``hand.write`` writes the SVG to ``img/output.svg``,
    which is then read back and returned.

    Args:
        text: Text to render. In multiline mode it is split into lines after
            normalizing ``"\\r\\n"`` and ``"\\r"`` line endings to ``"\\n"``.
        style: Handwriting style forwarded to ``InputData`` and ``hand.write``.
        bias: Bias forwarded to ``InputData`` and ``hand.write`` (the UI
            labels it "Neatness").
        color: Stroke colour used for every line.
        stroke_width: Stroke width used for every line.
        multiline: When False, ``text`` is treated as one single line.
        transparent_background: When True, ``<rect ... fill="white" ...>``
            elements are stripped from the SVG and ``img/output.svg`` is
            rewritten without them.

    Returns:
        The SVG markup on success, otherwise a message starting with
        ``"Error: "``. Exceptions are converted to such messages, not raised.
    """
    try:
        # Reject missing/blank input up front with a readable message.
        if not text or not text.strip():
            return "Error: Please enter some text"

        if multiline:
            # Without normalization a trailing "\r" from CRLF input would
            # count toward the 75-character limit and reach the model.
            normalized = text.replace("\r\n", "\n").replace("\r", "\n")
            lines = normalized.split("\n")
        else:
            lines = [text]

        # Per-line parameter lists (one entry per line) for hand.write.
        line_count = len(lines)
        stroke_colors = [color] * line_count
        stroke_widths = [stroke_width] * line_count
        biases = [bias] * line_count
        styles = [style] * line_count

        # Enforce the per-line length limit and replace both kinds of slash
        # with a dash.
        slash_to_dash = str.maketrans({"/": "-", "\\": "-"})
        sanitized_lines = []
        for line_num, line in enumerate(lines, start=1):
            if len(line) > 75:
                return f"Error: Line {line_num} is too long (max 75 characters)"
            sanitized_lines.append(line.translate(slash_to_dash))

        data = InputData(
            text="\n".join(sanitized_lines),
            style=style,
            bias=bias,
            stroke_colors=stroke_colors,
            stroke_widths=stroke_widths,
        )
        try:
            validate_input(data)
        except ValueError as e:
            return f"Error: {e}"

        svg_path = "img/output.svg"
        # Generate the handwriting with the sanitized lines.
        hand.write(
            filename=svg_path,
            lines=sanitized_lines,
            biases=biases,
            styles=styles,
            stroke_colors=stroke_colors,
            stroke_widths=stroke_widths,
        )

        with open(svg_path, "r", encoding="utf-8") as f:
            svg_content = f.read()

        if transparent_background:
            # Drop any white background rectangle; only rewrite the file when
            # something was actually removed.
            svg_content, removed = re.subn(
                r'<rect[^>]*?fill="white"[^>]*?>', "", svg_content
            )
            if removed:
                with open(svg_path, "w", encoding="utf-8") as f:
                    f.write(svg_content)

        return svg_content
    except Exception as e:
        # Best-effort UI handler: surface any failure as an error message.
        return f"Error: {e}"
def export_to_png(svg_content):
    """Convert SVG markup to a transparent PNG using CairoSVG and Pillow.

    Any ``<rect ... fill="white" ...>`` element is stripped first. The SVG is
    then rasterized in memory at 2x scale with no background colour. Finally,
    every pixel whose R, G and B channels are all above 240 is made fully
    transparent.

    Args:
        svg_content: SVG markup, typically the value returned by
            ``generate_handwriting``.

    Returns:
        ``"img/output.png"`` on success. ``None`` if the input is empty, is an
        ``"Error:"`` message, or the conversion fails for any reason
        (including cairosvg or Pillow not being installed).
    """
    try:
        if not svg_content or svg_content.startswith("Error:"):
            return None

        import io

        import cairosvg
        from PIL import Image

        # Remove any white background rectangle.
        svg_content = re.sub(r'<rect[^>]*?fill="white"[^>]*?>', "", svg_content)

        # With write_to omitted, svg2png returns the PNG bytes directly, so no
        # temporary SVG/PNG files are needed on disk.
        png_bytes = cairosvg.svg2png(
            bytestring=svg_content.encode("utf-8"),
            scale=2.0,
            background_color="none",
        )
        img = Image.open(io.BytesIO(png_bytes)).convert("RGBA")

        # Knock out white / near-white pixels. Note this also erases ink that
        # was drawn in a near-white colour.
        img.putdata([
            (255, 255, 255, 0) if r > 240 and g > 240 and b > 240 else (r, g, b, a)
            for r, g, b, a in img.getdata()
        ])

        output_path = "img/output.png"
        img.save(output_path, "PNG")
        return output_path
    except Exception as e:
        # Best-effort export: log and signal failure with None.
        print(f"Error converting to PNG: {str(e)}")
        return None
def generate_lyrics_sample():
    """Return the first four lines of ``lyrics.all_star`` as a list of strings."""
    from lyrics import all_star as lyric_text

    every_line = lyric_text.split("\n")
    return every_line[:4]
def generate_handwriting_wrapper(
    text,
    style,
    bias,
    color,
    stroke_width,
    multiline=True,
    transparent_background=True,
):
    """Produce both outputs for the "Generate Handwriting" button.

    Calls ``generate_handwriting`` and then ``export_to_png`` on its result.

    Args:
        text, style, bias, color, stroke_width, multiline,
        transparent_background: Forwarded unchanged to
            ``generate_handwriting``.

    Returns:
        ``(svg, png_path)``. ``svg`` is the SVG markup for the HTML output, or
        an HTML-escaped ``"Error: ..."`` message. ``png_path`` is the PNG file
        path, or ``None`` when no PNG could be produced.
    """
    import html

    svg = generate_handwriting(
        text, style, bias, color, stroke_width, multiline, transparent_background
    )
    png_path = export_to_png(svg)
    if isinstance(svg, str) and svg.startswith("Error:"):
        # Error text can include arbitrary exception messages; escape it so
        # the HTML output displays it rather than interpreting it as markup.
        svg = html.escape(svg)
    return svg, png_path
# Extra stylesheet passed to gr.Blocks: a centred 900px container, a minimum
# height for the "output-container" panels, rounded/shadowed boxes, and the
# "footer" text style.
_CSS_RULES = (
    ".container {max-width: 900px; margin: auto;}",
    ".output-container {min-height: 300px;}",
    ".gr-box {border-radius: 8px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);}",
    ".footer {text-align: center; margin-top: 20px; font-size: 0.8em; color: #666;}",
)
css = "".join(f"\n{rule}" for rule in _CSS_RULES) + "\n"
def _svg_download_path(svg_content):
    """Save the displayed SVG markup to a file for the "Download SVG" button.

    The HTML output holds SVG markup, but a ``gr.File`` output needs a file
    path, so the markup is written to ``img/download.svg`` first.

    Returns:
        The file path, or ``None`` when there is no valid SVG to download.
    """
    if not svg_content or svg_content.startswith("Error:"):
        return None
    path = "img/download.svg"
    with open(path, "w", encoding="utf-8") as f:
        f.write(svg_content)
    return path


with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🖋️ Handwriting Synthesis")
    gr.Markdown("Generate realistic handwritten text using neural networks.")

    with gr.Row():
        # Left column: text input and generation controls.
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Text Input",
                placeholder="Enter text to convert to handwriting...",
                lines=5,
                max_lines=10,
            )
            with gr.Row():
                with gr.Column(scale=1):
                    style_select = gr.Slider(
                        minimum=0,
                        maximum=12,
                        step=1,
                        value=9,
                        label="Handwriting Style"
                    )
                with gr.Column(scale=1):
                    bias_slider = gr.Slider(
                        minimum=0.5,
                        maximum=1.0,
                        step=0.05,
                        value=0.75,
                        label="Neatness (Higher = Neater)"
                    )
            with gr.Row():
                with gr.Column(scale=1):
                    color_picker = gr.ColorPicker(
                        label="Ink Color",
                        value="#000000"
                    )
                with gr.Column(scale=1):
                    stroke_width = gr.Slider(
                        minimum=1,
                        maximum=4,
                        step=0.5,
                        value=2,
                        label="Stroke Width"
                    )
            with gr.Row():
                generate_btn = gr.Button("Generate Handwriting", variant="primary")
                clear_btn = gr.Button("Clear")
            with gr.Accordion("Examples", open=False):
                sample_btn = gr.Button("Insert Sample Text")

        # Right column: rendered output and download buttons.
        with gr.Column(scale=3):
            output_svg = gr.HTML(label="Generated Handwriting (SVG)", elem_classes=["output-container"])
            output_png = gr.Image(type="filepath", label="Generated Handwriting (PNG)", elem_classes=["output-container"])
            with gr.Row():
                download_svg_btn = gr.Button("Download SVG")
                download_png_btn = gr.Button("Download PNG")

    # The backslash is placed in a code span: in Markdown "\)" is an escaped
    # parenthesis, so a bare backslash would not be displayed.
    gr.Markdown("""
    ### Tips:
    - Try different styles (0-12) to get various handwriting appearances
    - Adjust the neatness slider to make writing more or less tidy
    - Each line should be 75 characters or less
    - The model works best for English text
    - Forward slashes (/) and backslashes (`\\`) will be replaced with dashes (-)
    - PNG output has transparency for easy integration into other documents
    """)
    gr.Markdown("""
    <div class="footer">
    Created with Gradio •
    </div>
    """)

    # Define interactions
    generate_btn.click(
        fn=generate_handwriting_wrapper,
        inputs=[text_input, style_select, bias_slider, color_picker, stroke_width],
        outputs=[output_svg, output_png]
    )
    # Reset the controls to their initial values and clear the previous result.
    clear_btn.click(
        fn=lambda: ("", 9, 0.75, "#000000", 2, "", None),
        inputs=None,
        outputs=[
            text_input, style_select, bias_slider, color_picker, stroke_width,
            output_svg, output_png,
        ]
    )
    sample_btn.click(
        fn=lambda: "\n".join(generate_lyrics_sample()),
        inputs=None,
        outputs=[text_input]
    )

    # File outputs for the download buttons; both are rendered here, after the
    # footer, in this order.
    svg_file = gr.File(label="Download SVG", file_count="single", file_types=[".svg"])
    download_svg_btn.click(
        fn=_svg_download_path,
        inputs=[output_svg],
        outputs=[svg_file]
    )
    png_file = gr.File(label="Download PNG", file_count="single", file_types=[".png"])
    download_png_btn.click(
        # output_png already holds a file path (type="filepath").
        fn=lambda png_path: png_path,
        inputs=[output_png],
        outputs=[png_file]
    )
if __name__ == "__main__":
    # Listen on the port from the PORT environment variable, default 7860.
    port = int(os.environ.get("PORT", 7860))

    # cairosvg and Pillow are only needed for PNG export (export_to_png
    # imports them lazily), so just probe for them and report the result.
    missing_packages = []
    try:
        import cairosvg  # noqa: F401 - availability probe only
    except ImportError:
        missing_packages.append("cairosvg")
    try:
        from PIL import Image  # noqa: F401 - availability probe only
    except ImportError:
        missing_packages.append("pillow")

    if not missing_packages:
        print("All required packages are installed and ready for transparent PNG export")
    else:
        missing_list = ", ".join(missing_packages)
        print(
            f"WARNING: The following packages are missing and required for transparent PNG export: {missing_list}",
            "Please install them using: pip install " + " ".join(missing_packages),
            sep="\n",
        )

    # Bind to all network interfaces on the chosen port.
    demo.launch(server_name="0.0.0.0", server_port=port)