# Hugging Face Space app (ZeroGPU): SVG code generation demo for
# vinoku89/qwen3-4B-svg-code-gen.
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gc
import spaces
import xml.etree.ElementTree as ET
import re
import os
# Clear GPU memory before loading the model.
# NOTE(review): torch.cuda.empty_cache() is documented as a no-op when CUDA
# has not been initialized, so this should be safe on CPU-only hosts too.
torch.cuda.empty_cache()
gc.collect()
# Alpaca prompt template.
# Three positional {} slots, filled via str.format() in generate_svg():
# instruction, input, and response. The response slot is left empty at
# inference time so the model completes it.
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{}
### Input:
{}
### Response:
{}"""
# Load model with memory optimizations.
# Runs at import time: downloads/loads the fine-tuned checkpoint from the Hub.
model_path = "vinoku89/qwen3-4B-svg-code-gen"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,  # half precision halves weight memory
    device_map="auto",          # let accelerate place layers on available devices
    low_cpu_mem_usage=True,     # stream weights in to avoid a full CPU copy
    trust_remote_code=True  # Add this if needed for custom models
)
def validate_svg(svg_content):
    """Check that *svg_content* is well-formed, renderable SVG markup.

    Returns a ``(bool, str)`` pair: ``(True, svg)`` with the cleaned-up
    markup on success, or ``(False, message)`` describing the failure.
    """
    shape_tags = ('<circle', '<rect', '<path', '<line', '<polygon',
                  '<ellipse', '<text')
    try:
        candidate = svg_content.strip()
        if not candidate.startswith('<svg'):
            # Try to pull a complete <svg>...</svg> element out of the text.
            found = re.search(r'<svg[^>]*>.*?</svg>', candidate,
                              re.DOTALL | re.IGNORECASE)
            if found is not None:
                candidate = found.group(0)
            elif any(tag in candidate.lower() for tag in shape_tags):
                # Bare shape elements: wrap them in a root <svg> so they render.
                candidate = f'<svg xmlns="http://www.w3.org/2000/svg" width="250" height="250">{candidate}</svg>'
            else:
                raise ValueError("No valid SVG elements found")
        # Well-formedness check; raises ET.ParseError on malformed markup.
        ET.fromstring(candidate)
        return True, candidate
    except ET.ParseError as e:
        return False, f"XML Parse Error: {str(e)}"
    except Exception as e:
        return False, f"Validation Error: {str(e)}"
@spaces.GPU(duration=60)  # ZeroGPU: hold the GPU for at most 60 s per call
def generate_svg(prompt):
    """Generate SVG code from a natural-language description.

    Parameters:
        prompt: free-text description of the desired image.

    Returns:
        (svg_code, svg_display) — ``svg_code`` is the raw model response;
        ``svg_display`` is validated SVG markup for the preview pane, or an
        inline HTML error card when validation fails.
    """
    # Clear cache before generation to reduce fragmentation between requests.
    torch.cuda.empty_cache()

    # Format the prompt using the Alpaca template; the response slot stays
    # empty so the model fills it in.
    instruction = "Generate SVG code based on the given description."
    formatted_prompt = alpaca_prompt.format(
        instruction,
        prompt,
        ""  # Empty response - model will fill this
    )

    inputs = tokenizer(formatted_prompt, return_tensors="pt")
    # Move inputs to the same device the model was placed on (device_map="auto").
    if hasattr(model, 'device'):
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():  # inference only - no gradients needed
        outputs = model.generate(
            **inputs,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            # FIX: the original passed both max_length=1024 and
            # max_new_tokens=512; transformers ignores max_length when
            # max_new_tokens is set (and warns), so only the effective
            # limit is kept.
            max_new_tokens=512,
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract only the response part (after "### Response:").
    marker = "### Response:"
    response_start = generated_text.find(marker)
    if response_start != -1:
        svg_code = generated_text[response_start + len(marker):].strip()
    else:
        # Fallback: strip the original formatted prompt off the front.
        # NOTE(review): decode() may not reproduce the prompt byte-for-byte,
        # so this slice is best-effort only.
        svg_code = generated_text[len(formatted_prompt):].strip()

    # Validate the SVG and build the preview pane content.
    is_valid, result = validate_svg(svg_code)
    if is_valid:
        validated_svg = result
        # Ensure the SVG has display dimensions (keep a moderate size).
        if 'width=' not in validated_svg or 'height=' not in validated_svg:
            validated_svg = validated_svg.replace('<svg', '<svg width="250" height="250"', 1)
        svg_display = validated_svg
    else:
        # SVG is invalid: show an inline error card instead of a preview.
        svg_display = f"""
        <div style="width: 250px; height: 200px; border: 2px dashed #ff6b6b;
                    display: flex; align-items: center; justify-content: center;
                    background-color: #fff5f5; border-radius: 8px; padding: 15px;
                    text-align: center; color: #e03131; font-family: Arial, sans-serif;">
            <div>
                <h4 style="margin: 0 0 8px 0; color: #e03131;">🚫 Preview Not Available</h4>
                <p style="margin: 0; font-size: 12px;">Generated SVG contains errors:<br>
                <em style="font-size: 11px;">{result}</em></p>
            </div>
        </div>
        """

    # Clear cache after generation so the next request starts clean.
    torch.cuda.empty_cache()
    return svg_code, svg_display
# Authentication function using HF Space secrets
def authenticate(username, password):
    """Validate Gradio login credentials against HF Space secrets.

    Reads the expected username/password from the Space secrets named
    "user" and "password". When a secret is missing (e.g. local runs),
    falls back to the defaults "user"/"password" and prints a warning.
    Returns True only when both values match.
    """
    expected_user = os.getenv("user")      # matches the secret name "user"
    expected_pass = os.getenv("password")  # matches the secret name "password"

    if expected_user is None:
        print("Warning: 'user' secret not found, using fallback")
        expected_user = "user"
    if expected_pass is None:
        print("Warning: 'password' secret not found, using fallback")
        expected_pass = "password"

    return (username, password) == (expected_user, expected_pass)
# Minimal CSS for slightly larger HTML preview only
custom_css = """
div[data-testid="HTML"] {
min-height: 320px !important;
}
"""
# Build the Gradio UI: one prompt textbox in, generated code plus a live
# rendered preview out.
gradio_app = gr.Interface(
    fn=generate_svg,
    inputs=gr.Textbox(
        lines=2,
        placeholder="Describe the SVG you want (e.g., 'a red circle with blue border')..."
    ),
    outputs=[
        gr.Code(label="Generated SVG Code", language="html"),
        gr.HTML(label="SVG Preview")  # receives either SVG markup or an error card
    ],
    title="SVG Code Generator",
    description="Generate SVG code from natural language using a fine-tuned LLM.",
    css=custom_css
)
if __name__ == "__main__":
    # FIX: the original passed auth=(os.getenv("user"), os.getenv("password")),
    # which is (None, None) when the secrets are unset and never used the
    # authenticate() function defined above. Passing the callable lets Gradio
    # validate each login and keeps the local-testing fallbacks working.
    gradio_app.launch(auth=authenticate, share=True, ssr_mode=False)