Update to latent-gcode diffusion model
Browse files- README.md +22 -7
- app.py +157 -89
- requirements.txt +3 -0
README.md
CHANGED
|
@@ -4,24 +4,39 @@ emoji: ✏️
|
|
| 4 |
colorFrom: gray
|
| 5 |
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
-
|
|
|
|
| 12 |
---
|
| 13 |
|
| 14 |
# dcode
|
| 15 |
|
| 16 |
-
Generate polargraph-compatible gcode from text prompts using
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
## Usage
|
| 19 |
|
| 20 |
-
1. Enter a prompt (e.g., "drawing of a cat")
|
| 21 |
-
2. Adjust
|
| 22 |
3. Click Generate
|
| 23 |
-
4. View preview and
|
| 24 |
|
| 25 |
## Model
|
| 26 |
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
colorFrom: gray
|
| 5 |
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: "4.44.0"
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
+
hardware: t4-small
|
| 12 |
+
short_description: Text to Polargraph Gcode via Latent Diffusion
|
| 13 |
---
|
| 14 |
|
| 15 |
# dcode
|
| 16 |
|
| 17 |
+
Generate polargraph-compatible gcode from text prompts using latent diffusion.
|
| 18 |
+
|
| 19 |
+
## How it works
|
| 20 |
+
|
| 21 |
+
1. **Text → Latent**: Stable Diffusion generates a latent representation from your text prompt
|
| 22 |
+
2. **Latent → Gcode**: Custom transformer decoder converts the latent to gcode commands
|
| 23 |
+
3. **Validation**: Coordinates are clamped to machine bounds
|
| 24 |
|
| 25 |
## Usage
|
| 26 |
|
| 27 |
+
1. Enter a prompt (e.g., "line drawing of a cat")
|
| 28 |
+
2. Adjust diffusion steps and guidance scale
|
| 29 |
3. Click Generate
|
| 30 |
+
4. View preview and copy gcode
|
| 31 |
|
| 32 |
## Model
|
| 33 |
|
| 34 |
+
- Base: Stable Diffusion 2.1
|
| 35 |
+
- Decoder: 6-layer transformer trained on 175k image-gcode pairs
|
| 36 |
+
- Final loss: 0.107
|
| 37 |
+
|
| 38 |
+
## Links
|
| 39 |
+
|
| 40 |
+
- [Model](https://huggingface.co/twarner/dcode-latent-gcode)
|
| 41 |
+
- [Dataset](https://huggingface.co/datasets/twarner/dcode-polargraph-gcode)
|
| 42 |
+
- [GitHub](https://github.com/Twarner491/dcode)
|
app.py
CHANGED
|
@@ -1,40 +1,113 @@
|
|
| 1 |
-
"""dcode Gradio Space - Text to Gcode
|
| 2 |
|
| 3 |
import re
|
| 4 |
import gradio as gr
|
| 5 |
import torch
|
| 6 |
-
from
|
| 7 |
-
|
| 8 |
-
# Available models
|
| 9 |
-
MODELS = {
|
| 10 |
-
"flan-t5-base (best)": "twarner/dcode-flan-t5-base",
|
| 11 |
-
}
|
| 12 |
|
| 13 |
# Machine limits
|
| 14 |
BOUNDS = {"left": -420.5, "right": 420.5, "top": 594.5, "bottom": -594.5}
|
| 15 |
|
| 16 |
-
#
|
| 17 |
-
|
| 18 |
|
| 19 |
|
| 20 |
-
def
|
| 21 |
-
"""Load and cache
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 25 |
dtype = torch.float16 if device == "cuda" else torch.float32
|
| 26 |
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
return
|
| 38 |
|
| 39 |
|
| 40 |
def validate_gcode(gcode: str) -> str:
|
|
@@ -73,14 +146,11 @@ def gcode_to_svg(gcode: str) -> str:
|
|
| 73 |
x, y = 0.0, 0.0
|
| 74 |
pen_down = False
|
| 75 |
|
| 76 |
-
# Split on newlines first, then also split commands that may be on same line
|
| 77 |
-
# Handle gcode that's all on one line by splitting on G0/G1/M commands
|
| 78 |
lines = []
|
| 79 |
for line in gcode.split("\n"):
|
| 80 |
line = line.strip()
|
| 81 |
if not line:
|
| 82 |
continue
|
| 83 |
-
# Split on gcode commands (G0, G1, G28, M280, etc.)
|
| 84 |
parts = re.split(r'(?=[GM]\d)', line)
|
| 85 |
for part in parts:
|
| 86 |
part = part.strip()
|
|
@@ -88,19 +158,16 @@ def gcode_to_svg(gcode: str) -> str:
|
|
| 88 |
lines.append(part)
|
| 89 |
|
| 90 |
for line in lines:
|
| 91 |
-
|
| 92 |
-
# Pen state from M280 servo commands
|
| 93 |
if "M280" in line.upper():
|
| 94 |
match = re.search(r"S(\d+)", line, re.IGNORECASE)
|
| 95 |
if match:
|
| 96 |
angle = int(match.group(1))
|
| 97 |
was_down = pen_down
|
| 98 |
-
pen_down = angle < 50
|
| 99 |
if was_down and not pen_down and len(current_path) > 1:
|
| 100 |
paths.append(current_path[:])
|
| 101 |
current_path = []
|
| 102 |
|
| 103 |
-
# Position from G0/G1 commands
|
| 104 |
x_match = re.search(r"X([-\d.]+)", line, re.IGNORECASE)
|
| 105 |
y_match = re.search(r"Y([-\d.]+)", line, re.IGNORECASE)
|
| 106 |
|
|
@@ -121,7 +188,6 @@ def gcode_to_svg(gcode: str) -> str:
|
|
| 121 |
if len(current_path) > 1:
|
| 122 |
paths.append(current_path)
|
| 123 |
|
| 124 |
-
# Build SVG - light mode with dark lines
|
| 125 |
w = BOUNDS["right"] - BOUNDS["left"]
|
| 126 |
h = BOUNDS["top"] - BOUNDS["bottom"]
|
| 127 |
padding = 20
|
|
@@ -129,91 +195,92 @@ def gcode_to_svg(gcode: str) -> str:
|
|
| 129 |
svg = f'''<svg xmlns="http://www.w3.org/2000/svg"
|
| 130 |
viewBox="{BOUNDS["left"] - padding} {-BOUNDS["top"] - padding} {w + 2*padding} {h + 2*padding}"
|
| 131 |
style="background: #fafafa; width: 100%; height: 500px; border-radius: 8px; border: 1px solid #e5e5e5;">
|
| 132 |
-
<!-- Work area border -->
|
| 133 |
<rect x="{BOUNDS["left"]}" y="{-BOUNDS["top"]}" width="{w}" height="{h}"
|
| 134 |
fill="#fff" stroke="#ccc" stroke-width="2"/>
|
| 135 |
-
<!-- Center crosshair -->
|
| 136 |
<line x1="0" y1="{-BOUNDS["top"]}" x2="0" y2="{-BOUNDS["bottom"]}" stroke="#ddd" stroke-width="1"/>
|
| 137 |
<line x1="{BOUNDS["left"]}" y1="0" x2="{BOUNDS["right"]}" y2="0" stroke="#ddd" stroke-width="1"/>
|
| 138 |
-
<!-- Grid -->
|
| 139 |
-
<defs>
|
| 140 |
-
<pattern id="grid" width="100" height="100" patternUnits="userSpaceOnUse">
|
| 141 |
-
<path d="M 100 0 L 0 0 0 100" fill="none" stroke="#eee" stroke-width="0.5"/>
|
| 142 |
-
</pattern>
|
| 143 |
-
</defs>
|
| 144 |
-
<rect x="{BOUNDS["left"]}" y="{-BOUNDS["top"]}" width="{w}" height="{h}" fill="url(#grid)"/>
|
| 145 |
'''
|
| 146 |
|
| 147 |
-
# Draw paths - dark lines
|
| 148 |
for path in paths:
|
| 149 |
if len(path) < 2:
|
| 150 |
continue
|
| 151 |
-
# SVG Y is inverted
|
| 152 |
d = " ".join(f"{'M' if i == 0 else 'L'}{p[0]:.1f},{-p[1]:.1f}" for i, p in enumerate(path))
|
| 153 |
svg += f'<path d="{d}" fill="none" stroke="#1a1a1a" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>'
|
| 154 |
|
| 155 |
-
# Stats
|
| 156 |
total_points = sum(len(p) for p in paths)
|
| 157 |
svg += f'''
|
| 158 |
<text x="{BOUNDS["left"] + 10}" y="{-BOUNDS["top"] + 25}" fill="#666" font-family="monospace" font-size="14">
|
| 159 |
Paths: {len(paths)} | Points: {total_points}
|
| 160 |
</text>
|
| 161 |
'''
|
| 162 |
-
|
| 163 |
svg += "</svg>"
|
| 164 |
return svg
|
| 165 |
|
| 166 |
|
| 167 |
-
def generate(prompt: str,
|
| 168 |
-
"""Generate gcode from prompt
|
| 169 |
if not prompt or not prompt.strip():
|
| 170 |
-
|
| 171 |
-
return "Enter a prompt to generate gcode", empty_svg
|
| 172 |
|
| 173 |
try:
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
-
|
| 178 |
-
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 179 |
-
|
| 180 |
with torch.no_grad():
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
top_p=0.9,
|
| 187 |
-
pad_token_id=tokenizer.eos_token_id,
|
| 188 |
)
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
gcode = validate_gcode(gcode)
|
| 197 |
line_count = len(gcode.split("\n"))
|
| 198 |
-
|
| 199 |
-
# Generate SVG preview
|
| 200 |
svg = gcode_to_svg(gcode)
|
| 201 |
|
| 202 |
-
gcode_with_header = f"; dcode output - {line_count} lines\n;
|
| 203 |
return gcode_with_header, svg
|
| 204 |
|
| 205 |
except Exception as e:
|
| 206 |
-
|
| 207 |
-
|
|
|
|
| 208 |
|
| 209 |
|
| 210 |
# Custom CSS
|
| 211 |
custom_css = """
|
| 212 |
-
#preview-container {
|
| 213 |
-
background: #0a0a0a;
|
| 214 |
-
border-radius: 8px;
|
| 215 |
-
padding: 0;
|
| 216 |
-
}
|
| 217 |
.gradio-container {
|
| 218 |
max-width: 1200px !important;
|
| 219 |
}
|
|
@@ -222,9 +289,11 @@ custom_css = """
|
|
| 222 |
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as demo:
|
| 223 |
gr.Markdown("""
|
| 224 |
# dcode
|
| 225 |
-
**Text → Polargraph Gcode
|
| 226 |
|
| 227 |
-
|
|
|
|
|
|
|
| 228 |
""")
|
| 229 |
|
| 230 |
with gr.Row():
|
|
@@ -234,24 +303,24 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as d
|
|
| 234 |
placeholder="drawing of a cat, abstract spiral, portrait...",
|
| 235 |
lines=2
|
| 236 |
)
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
label="
|
| 241 |
-
|
| 242 |
with gr.Row():
|
| 243 |
-
|
| 244 |
-
|
| 245 |
|
| 246 |
generate_btn = gr.Button("Generate", variant="primary", size="lg")
|
| 247 |
|
| 248 |
gr.Examples(
|
| 249 |
examples=[
|
| 250 |
-
["drawing of a cat"],
|
| 251 |
["abstract spiral pattern"],
|
| 252 |
["simple house with chimney"],
|
| 253 |
-
["portrait
|
| 254 |
-
["geometric shapes"],
|
| 255 |
],
|
| 256 |
inputs=prompt,
|
| 257 |
)
|
|
@@ -260,7 +329,6 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as d
|
|
| 260 |
preview = gr.HTML(
|
| 261 |
value=gcode_to_svg(""),
|
| 262 |
label="Preview",
|
| 263 |
-
elem_id="preview-container"
|
| 264 |
)
|
| 265 |
|
| 266 |
with gr.Accordion("Gcode Output", open=False):
|
|
@@ -273,12 +341,12 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as d
|
|
| 273 |
|
| 274 |
generate_btn.click(
|
| 275 |
generate,
|
| 276 |
-
[prompt,
|
| 277 |
[gcode_output, preview]
|
| 278 |
)
|
| 279 |
prompt.submit(
|
| 280 |
generate,
|
| 281 |
-
[prompt,
|
| 282 |
[gcode_output, preview]
|
| 283 |
)
|
| 284 |
|
|
|
|
| 1 |
+
"""dcode Gradio Space - Text to Gcode via Latent Diffusion."""
|
| 2 |
|
| 3 |
import re
|
| 4 |
import gradio as gr
|
| 5 |
import torch
|
| 6 |
+
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
# Machine limits
|
| 9 |
BOUNDS = {"left": -420.5, "right": 420.5, "top": 594.5, "bottom": -594.5}
|
| 10 |
|
| 11 |
+
# Model caches
|
| 12 |
+
_generator = None
|
| 13 |
|
| 14 |
|
| 15 |
+
def get_generator():
|
| 16 |
+
"""Load and cache the latent-gcode generator."""
|
| 17 |
+
global _generator
|
| 18 |
+
if _generator is None:
|
| 19 |
+
from diffusers import StableDiffusionPipeline, AutoencoderKL
|
| 20 |
+
from transformers import AutoTokenizer
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
|
| 23 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 24 |
dtype = torch.float16 if device == "cuda" else torch.float32
|
| 25 |
|
| 26 |
+
print("Loading Stable Diffusion pipeline...")
|
| 27 |
+
pipe = StableDiffusionPipeline.from_pretrained(
|
| 28 |
+
"stabilityai/stable-diffusion-2-1-base",
|
| 29 |
+
torch_dtype=dtype,
|
| 30 |
+
safety_checker=None,
|
| 31 |
+
).to(device)
|
| 32 |
+
|
| 33 |
+
print("Loading gcode decoder...")
|
| 34 |
+
from huggingface_hub import hf_hub_download
|
| 35 |
+
|
| 36 |
+
# Download model files
|
| 37 |
+
model_path = hf_hub_download("twarner/dcode-latent-gcode", "pytorch_model.bin")
|
| 38 |
+
config_path = hf_hub_download("twarner/dcode-latent-gcode", "config.json")
|
| 39 |
+
|
| 40 |
+
import json
|
| 41 |
+
with open(config_path) as f:
|
| 42 |
+
config = json.load(f)
|
| 43 |
+
|
| 44 |
+
# Load tokenizer
|
| 45 |
+
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
|
| 46 |
+
|
| 47 |
+
# Build decoder model
|
| 48 |
+
class LatentProjector(nn.Module):
|
| 49 |
+
def __init__(self, latent_dim, hidden_size):
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.proj = nn.Sequential(
|
| 52 |
+
nn.Linear(latent_dim, hidden_size * 2),
|
| 53 |
+
nn.GELU(),
|
| 54 |
+
nn.Linear(hidden_size * 2, hidden_size),
|
| 55 |
+
nn.LayerNorm(hidden_size),
|
| 56 |
+
)
|
| 57 |
+
def forward(self, x):
|
| 58 |
+
return self.proj(x)
|
| 59 |
|
| 60 |
+
class GcodeDecoder(nn.Module):
|
| 61 |
+
def __init__(self, hidden_size, vocab_size, num_layers, num_heads, max_seq_len):
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.embed = nn.Embedding(vocab_size, hidden_size)
|
| 64 |
+
self.pos_embed = nn.Embedding(max_seq_len, hidden_size)
|
| 65 |
+
layer = nn.TransformerDecoderLayer(hidden_size, num_heads, hidden_size * 4, batch_first=True)
|
| 66 |
+
self.decoder = nn.TransformerDecoder(layer, num_layers)
|
| 67 |
+
self.head = nn.Linear(hidden_size, vocab_size)
|
| 68 |
+
self.max_seq_len = max_seq_len
|
| 69 |
+
|
| 70 |
+
def forward(self, tgt, memory, tgt_mask=None):
|
| 71 |
+
pos = torch.arange(tgt.size(1), device=tgt.device)
|
| 72 |
+
x = self.embed(tgt) + self.pos_embed(pos)
|
| 73 |
+
x = self.decoder(x, memory, tgt_mask=tgt_mask)
|
| 74 |
+
return self.head(x)
|
| 75 |
|
| 76 |
+
# Initialize models
|
| 77 |
+
latent_dim = 4 * 64 * 64
|
| 78 |
+
hidden_size = config.get("hidden_size", 512)
|
| 79 |
+
vocab_size = tokenizer.vocab_size
|
| 80 |
+
num_layers = config.get("num_layers", 6)
|
| 81 |
+
num_heads = config.get("num_heads", 8)
|
| 82 |
+
max_seq_len = config.get("max_seq_len", 1024)
|
| 83 |
+
|
| 84 |
+
projector = LatentProjector(latent_dim, hidden_size).to(device, dtype)
|
| 85 |
+
decoder = GcodeDecoder(hidden_size, vocab_size, num_layers, num_heads, max_seq_len).to(device, dtype)
|
| 86 |
+
|
| 87 |
+
# Load weights
|
| 88 |
+
state_dict = torch.load(model_path, map_location=device)
|
| 89 |
+
|
| 90 |
+
proj_state = {k.replace("projector.", ""): v for k, v in state_dict.items() if k.startswith("projector.")}
|
| 91 |
+
dec_state = {k.replace("decoder.", ""): v for k, v in state_dict.items() if k.startswith("decoder.")}
|
| 92 |
+
|
| 93 |
+
projector.load_state_dict(proj_state)
|
| 94 |
+
decoder.load_state_dict(dec_state)
|
| 95 |
+
|
| 96 |
+
projector.eval()
|
| 97 |
+
decoder.eval()
|
| 98 |
+
|
| 99 |
+
_generator = {
|
| 100 |
+
"pipe": pipe,
|
| 101 |
+
"projector": projector,
|
| 102 |
+
"decoder": decoder,
|
| 103 |
+
"tokenizer": tokenizer,
|
| 104 |
+
"device": device,
|
| 105 |
+
"dtype": dtype,
|
| 106 |
+
"max_seq_len": max_seq_len,
|
| 107 |
+
}
|
| 108 |
+
print("Models loaded!")
|
| 109 |
|
| 110 |
+
return _generator
|
| 111 |
|
| 112 |
|
| 113 |
def validate_gcode(gcode: str) -> str:
|
|
|
|
| 146 |
x, y = 0.0, 0.0
|
| 147 |
pen_down = False
|
| 148 |
|
|
|
|
|
|
|
| 149 |
lines = []
|
| 150 |
for line in gcode.split("\n"):
|
| 151 |
line = line.strip()
|
| 152 |
if not line:
|
| 153 |
continue
|
|
|
|
| 154 |
parts = re.split(r'(?=[GM]\d)', line)
|
| 155 |
for part in parts:
|
| 156 |
part = part.strip()
|
|
|
|
| 158 |
lines.append(part)
|
| 159 |
|
| 160 |
for line in lines:
|
|
|
|
|
|
|
| 161 |
if "M280" in line.upper():
|
| 162 |
match = re.search(r"S(\d+)", line, re.IGNORECASE)
|
| 163 |
if match:
|
| 164 |
angle = int(match.group(1))
|
| 165 |
was_down = pen_down
|
| 166 |
+
pen_down = angle < 50
|
| 167 |
if was_down and not pen_down and len(current_path) > 1:
|
| 168 |
paths.append(current_path[:])
|
| 169 |
current_path = []
|
| 170 |
|
|
|
|
| 171 |
x_match = re.search(r"X([-\d.]+)", line, re.IGNORECASE)
|
| 172 |
y_match = re.search(r"Y([-\d.]+)", line, re.IGNORECASE)
|
| 173 |
|
|
|
|
| 188 |
if len(current_path) > 1:
|
| 189 |
paths.append(current_path)
|
| 190 |
|
|
|
|
| 191 |
w = BOUNDS["right"] - BOUNDS["left"]
|
| 192 |
h = BOUNDS["top"] - BOUNDS["bottom"]
|
| 193 |
padding = 20
|
|
|
|
| 195 |
svg = f'''<svg xmlns="http://www.w3.org/2000/svg"
|
| 196 |
viewBox="{BOUNDS["left"] - padding} {-BOUNDS["top"] - padding} {w + 2*padding} {h + 2*padding}"
|
| 197 |
style="background: #fafafa; width: 100%; height: 500px; border-radius: 8px; border: 1px solid #e5e5e5;">
|
|
|
|
| 198 |
<rect x="{BOUNDS["left"]}" y="{-BOUNDS["top"]}" width="{w}" height="{h}"
|
| 199 |
fill="#fff" stroke="#ccc" stroke-width="2"/>
|
|
|
|
| 200 |
<line x1="0" y1="{-BOUNDS["top"]}" x2="0" y2="{-BOUNDS["bottom"]}" stroke="#ddd" stroke-width="1"/>
|
| 201 |
<line x1="{BOUNDS["left"]}" y1="0" x2="{BOUNDS["right"]}" y2="0" stroke="#ddd" stroke-width="1"/>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
'''
|
| 203 |
|
|
|
|
| 204 |
for path in paths:
|
| 205 |
if len(path) < 2:
|
| 206 |
continue
|
|
|
|
| 207 |
d = " ".join(f"{'M' if i == 0 else 'L'}{p[0]:.1f},{-p[1]:.1f}" for i, p in enumerate(path))
|
| 208 |
svg += f'<path d="{d}" fill="none" stroke="#1a1a1a" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>'
|
| 209 |
|
|
|
|
| 210 |
total_points = sum(len(p) for p in paths)
|
| 211 |
svg += f'''
|
| 212 |
<text x="{BOUNDS["left"] + 10}" y="{-BOUNDS["top"] + 25}" fill="#666" font-family="monospace" font-size="14">
|
| 213 |
Paths: {len(paths)} | Points: {total_points}
|
| 214 |
</text>
|
| 215 |
'''
|
|
|
|
| 216 |
svg += "</svg>"
|
| 217 |
return svg
|
| 218 |
|
| 219 |
|
| 220 |
+
def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float):
|
| 221 |
+
"""Generate gcode from text prompt via latent diffusion."""
|
| 222 |
if not prompt or not prompt.strip():
|
| 223 |
+
return "Enter a prompt to generate gcode", gcode_to_svg("")
|
|
|
|
| 224 |
|
| 225 |
try:
|
| 226 |
+
gen = get_generator()
|
| 227 |
+
pipe = gen["pipe"]
|
| 228 |
+
projector = gen["projector"]
|
| 229 |
+
decoder = gen["decoder"]
|
| 230 |
+
tokenizer = gen["tokenizer"]
|
| 231 |
+
device = gen["device"]
|
| 232 |
+
dtype = gen["dtype"]
|
| 233 |
+
max_seq_len = gen["max_seq_len"]
|
| 234 |
|
| 235 |
+
# 1. Text -> Latent via Stable Diffusion
|
|
|
|
|
|
|
| 236 |
with torch.no_grad():
|
| 237 |
+
result = pipe(
|
| 238 |
+
prompt,
|
| 239 |
+
num_inference_steps=num_steps,
|
| 240 |
+
guidance_scale=guidance,
|
| 241 |
+
output_type="latent",
|
|
|
|
|
|
|
| 242 |
)
|
| 243 |
+
latent = result.images # [1, 4, 64, 64]
|
| 244 |
+
|
| 245 |
+
# 2. Latent -> Gcode via decoder
|
| 246 |
+
with torch.no_grad():
|
| 247 |
+
# Flatten and project latent
|
| 248 |
+
latent_flat = latent.view(1, -1).to(dtype) # [1, 4*64*64]
|
| 249 |
+
memory = projector(latent_flat).unsqueeze(1) # [1, 1, hidden]
|
| 250 |
+
|
| 251 |
+
# Autoregressive decoding
|
| 252 |
+
bos_id = tokenizer.bos_token_id or tokenizer.pad_token_id
|
| 253 |
+
eos_id = tokenizer.eos_token_id
|
| 254 |
+
|
| 255 |
+
tokens = torch.tensor([[bos_id]], device=device)
|
| 256 |
+
|
| 257 |
+
for _ in range(min(max_tokens, max_seq_len - 1)):
|
| 258 |
+
logits = decoder(tokens, memory)
|
| 259 |
+
next_logits = logits[:, -1, :] / temperature
|
| 260 |
+
probs = torch.softmax(next_logits, dim=-1)
|
| 261 |
+
next_token = torch.multinomial(probs, 1)
|
| 262 |
+
tokens = torch.cat([tokens, next_token], dim=1)
|
| 263 |
+
|
| 264 |
+
if next_token.item() == eos_id:
|
| 265 |
+
break
|
| 266 |
+
|
| 267 |
+
gcode = tokenizer.decode(tokens[0], skip_special_tokens=True)
|
| 268 |
|
| 269 |
gcode = validate_gcode(gcode)
|
| 270 |
line_count = len(gcode.split("\n"))
|
|
|
|
|
|
|
| 271 |
svg = gcode_to_svg(gcode)
|
| 272 |
|
| 273 |
+
gcode_with_header = f"; dcode output - {line_count} lines\n; Prompt: {prompt}\n; Machine validated\n\n{gcode}"
|
| 274 |
return gcode_with_header, svg
|
| 275 |
|
| 276 |
except Exception as e:
|
| 277 |
+
import traceback
|
| 278 |
+
traceback.print_exc()
|
| 279 |
+
return f"; Error: {e}", gcode_to_svg("")
|
| 280 |
|
| 281 |
|
| 282 |
# Custom CSS
|
| 283 |
custom_css = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
.gradio-container {
|
| 285 |
max-width: 1200px !important;
|
| 286 |
}
|
|
|
|
| 289 |
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as demo:
|
| 290 |
gr.Markdown("""
|
| 291 |
# dcode
|
| 292 |
+
**Text → Polargraph Gcode via Latent Diffusion**
|
| 293 |
|
| 294 |
+
Uses Stable Diffusion to generate latents from text, then decodes to machine gcode.
|
| 295 |
+
|
| 296 |
+
[GitHub](https://github.com/Twarner491/dcode) | [Model](https://huggingface.co/twarner/dcode-latent-gcode) | [Dataset](https://huggingface.co/datasets/twarner/dcode-polargraph-gcode)
|
| 297 |
""")
|
| 298 |
|
| 299 |
with gr.Row():
|
|
|
|
| 303 |
placeholder="drawing of a cat, abstract spiral, portrait...",
|
| 304 |
lines=2
|
| 305 |
)
|
| 306 |
+
|
| 307 |
+
with gr.Row():
|
| 308 |
+
temperature = gr.Slider(0.5, 1.5, value=0.9, label="Temperature")
|
| 309 |
+
max_tokens = gr.Slider(256, 1024, value=512, step=128, label="Max Tokens")
|
| 310 |
+
|
| 311 |
with gr.Row():
|
| 312 |
+
num_steps = gr.Slider(10, 50, value=25, step=5, label="Diffusion Steps")
|
| 313 |
+
guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="Guidance Scale")
|
| 314 |
|
| 315 |
generate_btn = gr.Button("Generate", variant="primary", size="lg")
|
| 316 |
|
| 317 |
gr.Examples(
|
| 318 |
examples=[
|
| 319 |
+
["line drawing of a cat"],
|
| 320 |
["abstract spiral pattern"],
|
| 321 |
["simple house with chimney"],
|
| 322 |
+
["portrait sketch"],
|
| 323 |
+
["geometric shapes and lines"],
|
| 324 |
],
|
| 325 |
inputs=prompt,
|
| 326 |
)
|
|
|
|
| 329 |
preview = gr.HTML(
|
| 330 |
value=gcode_to_svg(""),
|
| 331 |
label="Preview",
|
|
|
|
| 332 |
)
|
| 333 |
|
| 334 |
with gr.Accordion("Gcode Output", open=False):
|
|
|
|
| 341 |
|
| 342 |
generate_btn.click(
|
| 343 |
generate,
|
| 344 |
+
[prompt, temperature, max_tokens, num_steps, guidance],
|
| 345 |
[gcode_output, preview]
|
| 346 |
)
|
| 347 |
prompt.submit(
|
| 348 |
generate,
|
| 349 |
+
[prompt, temperature, max_tokens, num_steps, guidance],
|
| 350 |
[gcode_output, preview]
|
| 351 |
)
|
| 352 |
|
requirements.txt
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
|
|
| 1 |
torch>=2.0
|
| 2 |
transformers>=4.36
|
|
|
|
| 3 |
accelerate>=0.25
|
|
|
|
|
|
| 1 |
+
gradio>=4.44.0
|
| 2 |
torch>=2.0
|
| 3 |
transformers>=4.36
|
| 4 |
+
diffusers>=0.25
|
| 5 |
accelerate>=0.25
|
| 6 |
+
huggingface_hub>=0.20
|