File size: 6,376 Bytes
588fd03
 
44dba41
b3a9f3b
2c25495
b3a9f3b
 
da746e1
588fd03
b3a9f3b
 
 
588fd03
b3a9f3b
 
44dba41
b3a9f3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588fd03
b3a9f3b
 
 
 
 
 
 
 
 
588fd03
 
b3a9f3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588fd03
b3a9f3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588fd03
 
 
da746e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3a9f3b
 
 
 
 
 
 
588fd03
 
b3a9f3b
 
 
 
588fd03
4ba0cad
588fd03
 
 
b3a9f3b
 
588fd03
 
 
efababe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gc
import spaces
import xml.etree.ElementTree as ET
import re
import os

# Clear GPU memory
# Release any cached CUDA allocations and collect garbage before loading the
# model, to maximize headroom on the (shared) Space GPU.
torch.cuda.empty_cache()
gc.collect()

# Alpaca prompt template
# Slots are filled positionally: (instruction, input, response). The response
# slot is left empty at inference time so the model completes it.
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

# Load model with memory optimizations
# Fine-tuned Qwen3-4B checkpoint hosted on the Hugging Face Hub.
model_path = "vinoku89/qwen3-4B-svg-code-gen"

tokenizer = AutoTokenizer.from_pretrained(model_path)

# fp16 weights + device_map="auto" sharding + low_cpu_mem_usage keep peak
# host/GPU memory down during load.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True  # Add this if needed for custom models
)

def validate_svg(svg_content):
    """
    Check whether *svg_content* is well-formed, renderable SVG.

    Returns a ``(bool, str)`` pair: ``(True, cleaned_svg)`` on success,
    or ``(False, error_message)`` when the content cannot be salvaged.
    """
    try:
        cleaned = svg_content.strip()

        if not cleaned.startswith('<svg'):
            # Try to pull a complete <svg>...</svg> document out of the text.
            found = re.search(r'<svg[^>]*>.*?</svg>', cleaned, re.DOTALL | re.IGNORECASE)
            if found is not None:
                cleaned = found.group(0)
            elif any(tag in cleaned.lower() for tag in ('<circle', '<rect', '<path', '<line', '<polygon', '<ellipse', '<text')):
                # Bare shape elements: wrap them in an <svg> root so they render.
                cleaned = f'<svg xmlns="http://www.w3.org/2000/svg" width="250" height="250">{cleaned}</svg>'
            else:
                raise ValueError("No valid SVG elements found")

        # Well-formedness check: parsing raises ParseError on malformed XML.
        ET.fromstring(cleaned)
        return True, cleaned

    except ET.ParseError as e:
        return False, f"XML Parse Error: {str(e)}"
    except Exception as e:
        return False, f"Validation Error: {str(e)}"

@spaces.GPU(duration=60)  # ZeroGPU: cap each call's GPU allocation at 60s
def generate_svg(prompt):
    """
    Generate SVG code from a natural-language description.

    Parameters
    ----------
    prompt : str
        Free-text description of the desired image.

    Returns
    -------
    tuple[str, str]
        ``(raw generated SVG code, HTML snippet for the preview pane)`` —
        the preview is either the validated SVG itself or a styled error box.
    """
    # Free cached GPU memory before generation.
    torch.cuda.empty_cache()
    
    # Fill the Alpaca template; the empty third slot is where the model
    # writes its response.
    instruction = "Generate SVG code based on the given description."
    formatted_prompt = alpaca_prompt.format(
        instruction,
        prompt,
        ""  # Empty response - model will fill this
    )
    
    inputs = tokenizer(formatted_prompt, return_tensors="pt")
    
    # Tokenizer returns CPU tensors; move them to the model's device
    # (device_map="auto" may have placed the model on GPU).
    if hasattr(model, 'device'):
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    with torch.no_grad():  # Disable gradient computation to save memory
        outputs = model.generate(
            **inputs,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            # BUGFIX: the original passed both max_length=1024 and
            # max_new_tokens=512; max_length conflicts and is ignored with a
            # warning by transformers. max_new_tokens alone bounds generation.
            max_new_tokens=512
        )
    
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Keep only the model's answer, which follows the "### Response:" marker.
    response_start = generated_text.find("### Response:")
    if response_start != -1:
        svg_code = generated_text[response_start + len("### Response:"):].strip()
    else:
        # Fallback: remove the echoed prompt prefix
        svg_code = generated_text[len(formatted_prompt):].strip()
    
    # Validate SVG
    is_valid, result = validate_svg(svg_code)
    
    if is_valid:
        validated_svg = result
        # Ensure the SVG has display dimensions. BUGFIX: add only the
        # missing attribute(s) — the original inserted BOTH width and height
        # whenever either was absent, which produced a duplicate attribute
        # (invalid XML) when exactly one already existed.
        if 'width=' not in validated_svg:
            validated_svg = validated_svg.replace('<svg', '<svg width="250"', 1)
        if 'height=' not in validated_svg:
            validated_svg = validated_svg.replace('<svg', '<svg height="250"', 1)
        svg_display = validated_svg
    else:
        # SVG is invalid: show a styled error box instead of broken markup.
        svg_display = f"""
        <div style="width: 250px; height: 200px; border: 2px dashed #ff6b6b; 
                    display: flex; align-items: center; justify-content: center; 
                    background-color: #fff5f5; border-radius: 8px; padding: 15px; 
                    text-align: center; color: #e03131; font-family: Arial, sans-serif;">
            <div>
                <h4 style="margin: 0 0 8px 0; color: #e03131;">🚫 Preview Not Available</h4>
                <p style="margin: 0; font-size: 12px;">Generated SVG contains errors:<br>
                <em style="font-size: 11px;">{result}</em></p>
            </div>
        </div>
        """
    
    # Free cached GPU memory after generation.
    torch.cuda.empty_cache()
    
    return svg_code, svg_display

# Authentication function using HF Space secrets
def authenticate(username, password):
    """
    Validate Gradio login credentials against HF Space secrets.

    The expected username and password are read from the "user" and
    "password" space secrets; when a secret is missing (e.g. during local
    testing) a hard-coded fallback is used and a warning is printed.
    Returns True when both values match, False otherwise.
    """
    expected_user = os.getenv("user")          # matches the secret name "user"
    expected_pass = os.getenv("password")      # matches the secret name "password"

    # Local-testing fallbacks when the secrets are not configured.
    if expected_user is None:
        print("Warning: 'user' secret not found, using fallback")
        expected_user = "user"

    if expected_pass is None:
        print("Warning: 'password' secret not found, using fallback")
        expected_pass = "password"

    return (username, password) == (expected_user, expected_pass)

# Minimal CSS for slightly larger HTML preview only
custom_css = """
div[data-testid="HTML"] {
    min-height: 320px !important;
}
"""

# Gradio UI: one text prompt in, generated code + rendered preview out.
gradio_app = gr.Interface(
    fn=generate_svg,
    inputs=gr.Textbox(
        lines=2, 
        placeholder="Describe the SVG you want (e.g., 'a red circle with blue border')..."
    ),
    outputs=[
        gr.Code(label="Generated SVG Code", language="html"),
        gr.HTML(label="SVG Preview")
    ],
    title="SVG Code Generator",
    description="Generate SVG code from natural language using a fine-tuned LLM.",
    css=custom_css
)

if __name__ == "__main__":
    # BUGFIX: use the authenticate() callable so that missing secrets fall
    # back to the documented defaults. The original built a static tuple
    # auth=(os.getenv("user"), os.getenv("password")), which becomes
    # (None, None) when the secrets are unset and left authenticate() unused.
    gradio_app.launch(auth=authenticate, share=True, ssr_mode=False)