"""
FloorplanVLM SFT Training - Self-contained local script
Trains Qwen2.5-VL-3B with LoRA on CubiCasa5K to output structured JSON
(walls, doors, windows, rooms) from floor plan images.
Based on: FloorplanVLM (arxiv:2602.06507) + TRL VLM SFT
Usage:
pip install torch torchvision transformers trl peft datasets accelerate shapely Pillow lxml numpy tqdm huggingface_hub
huggingface-cli login
python train_floorplan_vlm.py
Auto-detects GPU vs CPU. On GPU with flash-attn installed, uses flash_attention_2.
Downloads CubiCasa5K (~5GB) automatically on first run.
"""
import os
import json
import math
import re
import zipfile
import subprocess
import torch
import numpy as np
from PIL import Image, ImageDraw
from xml.dom import minidom
from shapely.geometry import LineString, Polygon, Point
from shapely.ops import unary_union
from datasets import Dataset
from transformers import (
Qwen2_5_VLForConditionalGeneration,
AutoProcessor,
TrainerCallback,
)
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
# ══════════════════════════════════════════════════════════════════════════════
# CONFIGURATION — edit these
# ══════════════════════════════════════════════════════════════════════════════
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
HUB_MODEL_ID = "manitocross/floorplan-vlm-sft" # change to your username
OUTPUT_DIR = "./floorplan-vlm-sft"
DATA_DIR = "./cubicasa_data"
ZENODO_URL = "https://zenodo.org/record/2613548/files/cubicasa5k.zip?download=1"
MAX_SAMPLES = None # None = use all ~5000 samples; set to e.g. 100 for quick test
NUM_EPOCHS = 2
BATCH_SIZE = 1
GRAD_ACCUM = 8 # effective batch size = BATCH_SIZE * GRAD_ACCUM
LEARNING_RATE = 2e-5
MAX_JSON_CHARS = 10000 # skip plans with JSON > this (won't fit in context window)
PUSH_TO_HUB = True # set False to only save locally
# ══════════════════════════════════════════════════════════════════════════════
SYSTEM_PROMPT = (
"You are a floor plan vectorization expert. Extract wall, door, window geometry "
"from floor plan images into structured JSON.\n\n"
"Output ONLY valid JSON with this schema:\n"
'{"walls":[{"id":"wall_N","start":[x,y],"end":[x,y],"thickness":T,"curvature":0,'
'"openings":[{"type":"door"|"window","center":D,"width":W}]}],'
'"rooms":[{"label":"room_type","walls":["wall_N",...]}]}\n\n'
"Coordinates normalized so longer image edge = 1024."
)
USER_PROMPT = "Vectorize this floor plan into structured JSON with all walls, doors, windows, and rooms."
ROOM_MAP = {
"Alcove":"room","Attic":"room","Ballroom":"room","Bar":"room","Basement":"room",
"Bath":"bathroom","Bedroom":"bedroom","Below150cm":"room","CarPort":"garage",
"Church":"room","Closet":"storage","ConferenceRoom":"room","Conservatory":"room",
"Counter":"room","Den":"room","Dining":"dining","DraughtLobby":"hallway",
"DressingRoom":"storage","EatingArea":"dining","Elevated":"room","Elevator":"room",
"Entry":"hallway","ExerciseRoom":"room","Garage":"garage","Garbage":"room",
"Hall":"hallway","HallWay":"hallway","HotTub":"room","Kitchen":"kitchen",
"Library":"room","LivingRoom":"living_room","Loft":"room","Lounge":"living_room",
"MediaRoom":"room","MeetingRoom":"room","Museum":"room","Nook":"room",
"Office":"office","OpenToBelow":"room","Outdoor":"outdoor","Pantry":"room",
"Reception":"room","RecreationRoom":"room","RetailSpace":"room","Room":"room",
"Sanctuary":"room","Sauna":"bathroom","ServiceRoom":"room","ServingArea":"room",
"Skylights":"room","Stable":"room","Stage":"room","StairWell":"stairwell",
"Storage":"storage","SunRoom":"room","SwimmingPool":"room","TechnicalRoom":"room",
"Theatre":"room","Undefined":"room","UserDefined":"room","Utility":"utility",
}
# ── Data Download & Extraction ───────────────────────────────────────────────
def download_and_extract():
"""Download CubiCasa5K from Zenodo and extract. Skips if already done."""
os.makedirs(DATA_DIR, exist_ok=True)
# Check if already extracted
for d in os.listdir(DATA_DIR):
dp = os.path.join(DATA_DIR, d)
if os.path.isdir(dp) and d not in ("__MACOSX",):
count = 0
for root, dirs, files in os.walk(dp):
if 'model.svg' in files:
count += 1
if count >= 10:
print(f"βœ“ Data already extracted at {dp}")
return dp
zip_path = os.path.join(DATA_DIR, "cubicasa5k.zip")
if not os.path.exists(zip_path):
print("Downloading CubiCasa5K from Zenodo (~5GB)...")
print("This may take 10-30 minutes depending on your connection.")
subprocess.run(["wget", "-q", "--show-progress", ZENODO_URL, "-O", zip_path], check=True)
print("Extracting zip...")
with zipfile.ZipFile(zip_path, 'r') as z:
z.extractall(DATA_DIR)
for d in os.listdir(DATA_DIR):
dp = os.path.join(DATA_DIR, d)
if os.path.isdir(dp) and d not in ("__MACOSX",):
return dp
return DATA_DIR
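# Note: if wget is unavailable (e.g. on Windows), download cubicasa5k.zip from
# ZENODO_URL by hand and drop it into DATA_DIR; the existence check above then
# skips the download and goes straight to extraction on the next run.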
# ── SVG → JSON Conversion ───────────────────────────────────────────────────
def parse_svg_polygon(element):
"""Extract polygon coords from an SVG <g> element containing a <polygon>."""
for child in element.childNodes:
if child.nodeName == "polygon":
pts = child.getAttribute("points").split(' ')
X, Y = [], []
for p in pts:
p = p.strip()
if ',' in p:
parts = p.split(',')
try:
X.append(float(parts[0]))
Y.append(float(parts[1]))
except ValueError:
pass
if len(X) >= 3:
return np.array(X), np.array(Y)
return None, None
def parse_floorplan(svg_path, img_path):
"""Parse one CubiCasa5K SVG + image β†’ FloorplanVLM JSON dict."""
img = Image.open(img_path)
w, h = img.size
scale = 1024.0 / max(w, h)
svg = minidom.parse(svg_path)
walls, openings, rooms = [], [], []
for e in svg.getElementsByTagName('g'):
eid = e.getAttribute("id")
ecls = e.getAttribute("class")
# ── Walls ──
if eid == "Wall":
X, Y = parse_svg_polygon(e)
if X is None or len(X) < 4:
continue
X, Y = X * scale, Y * scale
dx, dy = abs(max(X) - min(X)), abs(max(Y) - min(Y))
if dx < 3 and dy < 3:
continue
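# Collapse the wall polygon to an axis-aligned centerline: the longer side of the
# bounding box sets the orientation, the shorter side becomes the thickness.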
if dx > dy:
cy = round((min(Y) + max(Y)) / 2)
start, end = [round(min(X)), cy], [round(max(X)), cy]
thickness = max(round(dy), 1)
else:
cx = round((min(X) + max(X)) / 2)
start, end = [cx, round(min(Y))], [cx, round(max(Y))]
thickness = max(round(dx), 1)
walls.append({
'start': start, 'end': end, 'thickness': thickness,
'centerline': LineString([start, end]),
})
# Nested doors/windows inside this wall element
for child in e.getElementsByTagName('g'):
cid = child.getAttribute("id")
if cid in ("Door", "Window"):
cX, cY = parse_svg_polygon(child)
if cX is not None and len(cX) >= 3:
cX, cY = cX * scale, cY * scale
center = [round(np.mean(cX)), round(np.mean(cY))]
ow = max(round(max(abs(max(cX)-min(cX)), abs(max(cY)-min(cY)))), 1)
openings.append({'type': cid.lower(), 'center_point': center, 'width': ow})
# ── Standalone Doors/Windows ──
elif eid in ("Door", "Window"):
# The nested scan above is recursive, so skip Doors/Windows that sit anywhere under a Wall group
anc = e.parentNode
while anc is not None and anc.nodeType == anc.ELEMENT_NODE and anc.getAttribute("id") != "Wall":
anc = anc.parentNode
if anc is not None and anc.nodeType == anc.ELEMENT_NODE:
continue  # already handled as nested
X, Y = parse_svg_polygon(e)
if X is None or len(X) < 3:
continue
X, Y = X * scale, Y * scale
center = [round(np.mean(X)), round(np.mean(Y))]
ow = max(round(max(abs(max(X)-min(X)), abs(max(Y)-min(Y)))), 1)
openings.append({'type': eid.lower(), 'center_point': center, 'width': ow})
# ── Rooms ──
elif "Space " in ecls:
name = ecls.replace("Space ", "").split(' ')[0]
label = ROOM_MAP.get(name, "room")
X, Y = parse_svg_polygon(e)
if X is not None and len(X) >= 3:
X, Y = X * scale, Y * scale
try:
poly = Polygon(list(zip(X, Y)))  # (x, y) order, consistent with wall centerlines
if not poly.is_valid:
poly = poly.buffer(0)
rooms.append({'label': label, 'polygon': poly})
except Exception:
pass
if not walls:
return None
# Assign openings to their nearest wall
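# 'center_along' is the opening's 1-D position along the wall (LineString.project
# returns the distance from the wall's start point); it becomes the schema's
# "center" field. Openings farther than 3x the nearest wall's thickness are
# never assigned a wall_idx and are silently dropped from the output.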
for op in openings:
oc = Point(op['center_point'])
best_i, best_d = None, float('inf')
for i, w in enumerate(walls):
d = w['centerline'].distance(oc)
if d < best_d:
best_d = d
best_i = i
if best_i is not None and best_d < walls[best_i]['thickness'] * 3:
op['wall_idx'] = best_i
op['center_along'] = round(walls[best_i]['centerline'].project(oc))
# Assign rooms to walls based on proximity
for room in rooms:
rp = room['polygon']
room['wall_ids'] = []
for i, w in enumerate(walls):
try:
if rp.boundary.distance(w['centerline']) < w['thickness'] * 2:
room['wall_ids'].append(f"wall_{i+1}")
except Exception:
pass
# Build final JSON
result = {"walls": [], "rooms": []}
for i, w in enumerate(walls):
entry = {
"id": f"wall_{i+1}",
"start": w['start'],
"end": w['end'],
"thickness": w['thickness'],
"curvature": 0,
"openings": [],
}
for op in openings:
if op.get('wall_idx') == i:
entry["openings"].append({
"type": op['type'],
"center": op['center_along'],
"width": op['width'],
})
result["walls"].append(entry)
for room in rooms:
if room.get('wall_ids'):
result["rooms"].append({"label": room['label'], "walls": room['wall_ids']})
return result
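# Quick sanity check (sketch): run the converter on one plan. The sample path is
# illustrative; substitute any extracted CubiCasa5K directory.
#   jd = parse_floorplan("cubicasa_data/high_quality/10/model.svg",
#                        "cubicasa_data/high_quality/10/F1_scaled.png")
#   if jd: print(len(jd["walls"]), "walls,", len(jd["rooms"]), "rooms")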
# ── Dataset Building ─────────────────────────────────────────────────────────
def build_dataset_from_cubicasa(data_dir, max_samples=None):
"""Convert CubiCasa5K to SFT training dataset."""
plans = []
for root, dirs, files in os.walk(data_dir):
if 'model.svg' in files and 'F1_scaled.png' in files:
plans.append(root)
print(f"Found {len(plans)} floor plans")
if max_samples:
plans = plans[:max_samples]
records, errors = [], 0
for i, pdir in enumerate(plans):
if i % 200 == 0:
print(f" Converting {i}/{len(plans)}... ({len(records)} ok, {errors} err)")
try:
jd = parse_floorplan(
os.path.join(pdir, 'model.svg'),
os.path.join(pdir, 'F1_scaled.png'),
)
if jd and len(jd['walls']) > 0:
js = json.dumps(jd, separators=(',', ':'))
if len(js) > MAX_JSON_CHARS:
continue
img = Image.open(os.path.join(pdir, 'F1_scaled.png')).convert("RGB")
# ALL content fields = list[dict] for Arrow compatibility
records.append({
"messages": [
{"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": USER_PROMPT}]},
{"role": "assistant", "content": [{"type": "text", "text": js}]},
],
"images": [img],
})
else:
errors += 1
except Exception as e:
errors += 1
if errors <= 3:
print(f" Error on {pdir}: {e}")
print(f"βœ“ Built {len(records)} training samples ({errors} errors)")
return Dataset.from_list(records)
def create_synthetic_fallback(n=20):
"""Create synthetic floor plans if real data download fails."""
print(f"Creating {n} synthetic floor plan samples...")
records = []
for i in range(n):
size = 256
img = Image.new('RGB', (size, size), 'white')
draw = ImageDraw.Draw(img)
s = 1024.0 / size
m = 30 + i * 3
wt = 6
draw.rectangle([m, m, size-m, size-m], outline='black', width=wt)
mid = size // 2 + i * 2
draw.line([(m, mid), (size-m, mid)], fill='black', width=wt)
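# Erase a short stretch of the dividing wall with a white stroke to simulate a
# door opening; the door's center/width in the JSON below use the 1024-normalized frame.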
dx = size // 3 + i * 8
draw.line([(dx, mid), (dx + 25, mid)], fill='white', width=wt + 2)
jd = {
"walls": [
{"id":"wall_1","start":[round(m*s),round(m*s)],"end":[round((size-m)*s),round(m*s)],"thickness":round(wt*s),"curvature":0,"openings":[]},
{"id":"wall_2","start":[round((size-m)*s),round(m*s)],"end":[round((size-m)*s),round((size-m)*s)],"thickness":round(wt*s),"curvature":0,"openings":[]},
{"id":"wall_3","start":[round((size-m)*s),round((size-m)*s)],"end":[round(m*s),round((size-m)*s)],"thickness":round(wt*s),"curvature":0,"openings":[]},
{"id":"wall_4","start":[round(m*s),round((size-m)*s)],"end":[round(m*s),round(m*s)],"thickness":round(wt*s),"curvature":0,"openings":[]},
{"id":"wall_5","start":[round(m*s),round(mid*s)],"end":[round((size-m)*s),round(mid*s)],"thickness":round(wt*s),"curvature":0,
"openings":[{"type":"door","center":round(dx*s),"width":round(25*s)}]},
],
"rooms": [
{"label":"bedroom","walls":["wall_1","wall_2","wall_5","wall_4"]},
{"label":"living_room","walls":["wall_5","wall_2","wall_3","wall_4"]},
],
}
records.append({
"messages": [
{"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": USER_PROMPT}]},
{"role": "assistant", "content": [{"type": "text", "text": json.dumps(jd, separators=(',', ':'))}]},
],
"images": [img],
})
return Dataset.from_list(records)
# ── Training Callback ────────────────────────────────────────────────────────
class TrainLogger(TrainerCallback):
def __init__(self):
self.best = float('inf')
def on_log(self, args, state, control, logs=None, **kwargs):
if logs and "loss" in logs:
loss = logs["loss"]
if loss < self.best:
self.best = loss
lr = logs.get("learning_rate", 0)
print(f" step {state.global_step:>5d} | loss {loss:.4f} | best {self.best:.4f} | lr {lr:.2e}")
def on_train_end(self, args, state, control, **kwargs):
print(f"\n βœ… Training complete! steps={state.global_step}, best_loss={self.best:.4f}")
# ── Main ─────────────────────────────────────────────────────────────────────
def main():
use_gpu = torch.cuda.is_available()
print("=" * 64)
print(f" FloorplanVLM SFT Training ({'GPU: ' + torch.cuda.get_device_name(0) if use_gpu else 'CPU'})")
print(f" Model : {MODEL_ID}")
print(f" Output : {HUB_MODEL_ID}")
print(f" Epochs : {NUM_EPOCHS}")
print(f" Batch : {BATCH_SIZE} Γ— {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM} effective")
print(f" LR : {LEARNING_RATE}")
print(f" Max samples: {MAX_SAMPLES or 'all'}")
print("=" * 64)
# ── 1. Data ──
print("\n[1/7] Getting training data...")
try:
data_dir = download_and_extract()
dataset = build_dataset_from_cubicasa(data_dir, max_samples=MAX_SAMPLES)
if len(dataset) < 5:
raise ValueError(f"Only {len(dataset)} samples found")
except Exception as e:
print(f"⚠ Real data unavailable ({e}), using synthetic fallback")
dataset = create_synthetic_fallback(20)
print(f" Dataset size: {len(dataset)} samples")
# ── 2. Processor ──
print("\n[2/7] Loading processor...")
proc_kwargs = {"min_pixels": 256 * 28 * 28, "max_pixels": 1280 * 28 * 28} if use_gpu else \
{"min_pixels": 64 * 28 * 28, "max_pixels": 256 * 28 * 28}
processor = AutoProcessor.from_pretrained(MODEL_ID, **proc_kwargs)
# ── 3. Model ──
print("\n[3/7] Loading model...")
model_kwargs = {"torch_dtype": torch.bfloat16}
# Try flash attention on GPU
if use_gpu:
try:
import flash_attn
model_kwargs["attn_implementation"] = "flash_attention_2"
print(" Using flash_attention_2")
except ImportError:
print(" flash-attn not installed, using default attention")
else:
model_kwargs["torch_dtype"] = torch.float32
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(MODEL_ID, **model_kwargs)
total_params = sum(p.numel() for p in model.parameters())
print(f" Total parameters: {total_params:,}")
# ── 4. LoRA ──
print("\n[4/7] Configuring LoRA...")
if use_gpu:
peft_config = LoraConfig(
r=16, lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
)
else:
peft_config = LoraConfig(
r=8, lora_alpha=16,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
)
# ── 5. Training Config ──
print("\n[5/7] Configuring training...")
sft_config = SFTConfig(
output_dir=OUTPUT_DIR,
num_train_epochs=NUM_EPOCHS,
per_device_train_batch_size=BATCH_SIZE,
gradient_accumulation_steps=GRAD_ACCUM,
learning_rate=LEARNING_RATE,
warmup_steps=20 if use_gpu else 2,
lr_scheduler_type="cosine",
bf16=use_gpu,
fp16=False,
gradient_checkpointing=True,
logging_steps=5 if use_gpu else 1,
logging_first_step=True,
logging_strategy="steps",
disable_tqdm=True,
save_steps=500 if use_gpu else 99999,
save_total_limit=2,
max_length=4096 if use_gpu else 1024,
remove_unused_columns=False,
dataset_kwargs={"skip_prepare_dataset": True},
push_to_hub=PUSH_TO_HUB,
hub_model_id=HUB_MODEL_ID if PUSH_TO_HUB else None,
report_to="none",
)
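# remove_unused_columns=False plus skip_prepare_dataset=True pass the raw
# "messages"/"images" columns through untouched, so chat templating and image
# preprocessing happen at collation time via the processor (recent TRL versions
# handle vision-language batches this way when given a processor as processing_class).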
# ── 6. Train ──
print("\n[6/7] Starting training...")
trainer = SFTTrainer(
model=model,
args=sft_config,
train_dataset=dataset,
peft_config=peft_config,
processing_class=processor,
callbacks=[TrainLogger()],
)
est_steps = max(1, len(dataset) * NUM_EPOCHS // (BATCH_SIZE * GRAD_ACCUM))
print(f" Estimated steps: ~{est_steps}")
print("-" * 50)
trainer.train()
print("-" * 50)
# ── 7. Save & Push ──
print("\n[7/7] Saving model...")
trainer.save_model(OUTPUT_DIR)
print(f" Saved locally to {OUTPUT_DIR}/")
print(f" Files: {os.listdir(OUTPUT_DIR)}")
if PUSH_TO_HUB:
try:
trainer.push_to_hub()
print(f"\n βœ… Pushed to https://huggingface.co/{HUB_MODEL_ID}")
except Exception as e:
print(f" Push via trainer failed ({e}), trying manual upload...")
try:
from huggingface_hub import HfApi
api = HfApi()
api.create_repo(HUB_MODEL_ID, exist_ok=True)
api.upload_folder(folder_path=OUTPUT_DIR, repo_id=HUB_MODEL_ID)
print(f" βœ… Pushed to https://huggingface.co/{HUB_MODEL_ID}")
except Exception as e2:
print(f" ❌ Push failed: {e2}")
print(f" Model saved locally at {OUTPUT_DIR}/")
# ── Quick inference test ──
print("\n[Bonus] Quick inference test...")
model.eval()
test_img = Image.new('RGB', (200, 200), 'white')
d = ImageDraw.Draw(test_img)
d.rectangle([20, 20, 180, 180], outline='black', width=5)
d.line([(20, 100), (180, 100)], fill='black', width=5)
msgs = [
{"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": USER_PROMPT}]},
]
text = processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[test_img], return_tensors="pt", padding=True)
if use_gpu:
inputs = inputs.to(model.device)  # BatchFeature.to() moves every tensor in one call
with torch.no_grad():
out = model.generate(**inputs, max_new_tokens=200, do_sample=False)
gen = processor.batch_decode(out[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
print(f" Generated: {gen[:400]}")
try:
m = re.search(r'\{[\s\S]*\}', gen)
if m:
parsed = json.loads(m.group())
print(f" ✅ Valid JSON! Walls: {len(parsed.get('walls', []))}, Rooms: {len(parsed.get('rooms', []))}")
else:
print(" ⚠ No JSON object found in output (expected early in training)")
except Exception:
print(" ⚠ JSON parse failed (may need more training)")
print("\n" + "=" * 64)
print(f" βœ… DONE!")
if PUSH_TO_HUB:
print(f" Model: https://huggingface.co/{HUB_MODEL_ID}")
print(f" Local: {os.path.abspath(OUTPUT_DIR)}/")
print("=" * 64)
if __name__ == "__main__":
main()