| """ |
| FloorplanVLM SFT Training - Self-contained local script |
| |
| Trains Qwen2.5-VL-3B with LoRA on CubiCasa5K to output structured JSON |
| (walls, doors, windows, rooms) from floor plan images. |
| |
| Based on: FloorplanVLM (arxiv:2602.06507) + TRL VLM SFT |
| |
| Usage: |
| pip install torch torchvision transformers trl peft datasets accelerate shapely Pillow lxml numpy tqdm huggingface_hub |
| huggingface-cli login |
| python train_floorplan_vlm.py |
| |
| Auto-detects GPU vs CPU. On GPU with flash-attn installed, uses flash_attention_2. |
| Downloads CubiCasa5K (~5GB) automatically on first run. |
| """ |
|
|
| import os |
| import json |
| import math |
| import re |
| import zipfile |
| import subprocess |
| import torch |
| import numpy as np |
| from PIL import Image, ImageDraw |
| from xml.dom import minidom |
| from shapely.geometry import LineString, Polygon, Point |
| from shapely.ops import unary_union |
| from datasets import Dataset |
| from transformers import ( |
| Qwen2_5_VLForConditionalGeneration, |
| AutoProcessor, |
| TrainerCallback, |
| ) |
| from trl import SFTTrainer, SFTConfig |
| from peft import LoraConfig |
|
|
| |
| |
| |
# ---- Model / data locations ----
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"  # base vision-language model to fine-tune
HUB_MODEL_ID = "manitocross/floorplan-vlm-sft"  # Hub repo the trained model is pushed to
OUTPUT_DIR = "./floorplan-vlm-sft"  # local checkpoint / final-model directory
DATA_DIR = "./cubicasa_data"  # where the CubiCasa5K zip is downloaded and extracted
ZENODO_URL = "https://zenodo.org/record/2613548/files/cubicasa5k.zip?download=1"

# ---- Training hyperparameters ----
MAX_SAMPLES = None  # cap on converted plans; None = use all
NUM_EPOCHS = 2
BATCH_SIZE = 1  # per-device train batch size
GRAD_ACCUM = 8  # effective batch = BATCH_SIZE * GRAD_ACCUM
LEARNING_RATE = 2e-5
MAX_JSON_CHARS = 10000  # skip plans whose serialized JSON target is longer than this
PUSH_TO_HUB = True  # push the trained model to the Hugging Face Hub

# System prompt: pins the exact JSON schema the model must emit.
SYSTEM_PROMPT = (
    "You are a floor plan vectorization expert. Extract wall, door, window geometry "
    "from floor plan images into structured JSON.\n\n"
    "Output ONLY valid JSON with this schema:\n"
    '{"walls":[{"id":"wall_N","start":[x,y],"end":[x,y],"thickness":T,"curvature":0,'
    '"openings":[{"type":"door"|"window","center":D,"width":W}]}],'
    '"rooms":[{"label":"room_type","walls":["wall_N",...]}]}\n\n'
    "Coordinates normalized so longer image edge = 1024."
)

# Fixed user turn paired with every training image.
USER_PROMPT = "Vectorize this floor plan into structured JSON with all walls, doors, windows, and rooms."

# Map CubiCasa5K room class names to a small set of canonical labels
# (unknown classes fall back to "room" via ROOM_MAP.get(name, "room")).
ROOM_MAP = {
    "Alcove":"room","Attic":"room","Ballroom":"room","Bar":"room","Basement":"room",
    "Bath":"bathroom","Bedroom":"bedroom","Below150cm":"room","CarPort":"garage",
    "Church":"room","Closet":"storage","ConferenceRoom":"room","Conservatory":"room",
    "Counter":"room","Den":"room","Dining":"dining","DraughtLobby":"hallway",
    "DressingRoom":"storage","EatingArea":"dining","Elevated":"room","Elevator":"room",
    "Entry":"hallway","ExerciseRoom":"room","Garage":"garage","Garbage":"room",
    "Hall":"hallway","HallWay":"hallway","HotTub":"room","Kitchen":"kitchen",
    "Library":"room","LivingRoom":"living_room","Loft":"room","Lounge":"living_room",
    "MediaRoom":"room","MeetingRoom":"room","Museum":"room","Nook":"room",
    "Office":"office","OpenToBelow":"room","Outdoor":"outdoor","Pantry":"room",
    "Reception":"room","RecreationRoom":"room","RetailSpace":"room","Room":"room",
    "Sanctuary":"room","Sauna":"bathroom","ServiceRoom":"room","ServingArea":"room",
    "Skylights":"room","Stable":"room","Stage":"room","StairWell":"stairwell",
    "Storage":"storage","SunRoom":"room","SwimmingPool":"room","TechnicalRoom":"room",
    "Theatre":"room","Undefined":"room","UserDefined":"room","Utility":"utility",
}
|
|
|
|
| |
|
|
def download_and_extract():
    """Download CubiCasa5K from Zenodo and extract it under DATA_DIR.

    Returns the directory containing the extracted floor plans. Both the
    download and the extraction are skipped when at least 10 plans
    (``model.svg`` files) are already present on disk.
    """
    os.makedirs(DATA_DIR, exist_ok=True)

    # Fast path: treat the dataset as extracted once >= 10 model.svg files
    # are found under any non-junk subdirectory.
    for d in os.listdir(DATA_DIR):
        dp = os.path.join(DATA_DIR, d)
        if os.path.isdir(dp) and d not in ("__MACOSX",):
            count = 0
            for _root, _dirs, files in os.walk(dp):
                if 'model.svg' in files:
                    count += 1
                    if count >= 10:
                        print(f"✓ Data already extracted at {dp}")
                        return dp

    # Download to a temp name and rename atomically so an interrupted wget
    # cannot leave a partial cubicasa5k.zip that later runs would trust
    # (the exists-check below would otherwise skip re-downloading it).
    zip_path = os.path.join(DATA_DIR, "cubicasa5k.zip")
    if not os.path.exists(zip_path):
        print("Downloading CubiCasa5K from Zenodo (~5GB)...")
        print("This may take 10-30 minutes depending on your connection.")
        part_path = zip_path + ".part"
        subprocess.run(["wget", "-q", "--show-progress", ZENODO_URL, "-O", part_path], check=True)
        os.replace(part_path, zip_path)

    print("Extracting zip...")
    try:
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(DATA_DIR)
    except zipfile.BadZipFile:
        # A corrupt archive (e.g. from an old interrupted download) is
        # useless; delete it so the next run re-downloads from scratch.
        os.remove(zip_path)
        raise

    # Return the first real extracted directory, else fall back to DATA_DIR.
    for d in os.listdir(DATA_DIR):
        dp = os.path.join(DATA_DIR, d)
        if os.path.isdir(dp) and d not in ("__MACOSX",):
            return dp
    return DATA_DIR
|
|
|
|
| |
|
|
def parse_svg_polygon(element):
    """Extract polygon coords from an SVG <g> element containing a <polygon>.

    Returns (X, Y) as float numpy arrays for the first child <polygon> whose
    "points" attribute yields at least 3 parseable "x,y" pairs, or
    (None, None) when no such polygon exists.
    """
    for child in element.childNodes:
        if child.nodeName != "polygon":
            continue
        xs, ys = [], []
        for token in child.getAttribute("points").split(' '):
            token = token.strip()
            if ',' not in token:
                continue
            parts = token.split(',')
            # Convert BOTH coordinates before appending either, so a bad y
            # value (e.g. a token like "10,") cannot leave xs and ys with
            # different lengths — the original appended x first and could
            # return mismatched arrays.
            try:
                x, y = float(parts[0]), float(parts[1])
            except (ValueError, IndexError):
                continue
            xs.append(x)
            ys.append(y)
        if len(xs) >= 3:
            return np.array(xs), np.array(ys)
    return None, None
|
|
|
|
def parse_floorplan(svg_path, img_path):
    """Parse one CubiCasa5K SVG + image -> FloorplanVLM JSON dict.

    Walls are collapsed to axis-aligned centerlines, openings (doors and
    windows) are attached to their nearest wall, and room polygons are linked
    to the walls running along their boundary. All coordinates are scaled so
    the longer image edge maps to 1024. Returns None when no walls are found.
    """
    # Only the image dimensions are needed; close the file handle promptly
    # instead of leaking it (Image.open is lazy and keeps the file open).
    with Image.open(img_path) as img:
        w, h = img.size
    scale = 1024.0 / max(w, h)

    svg = minidom.parse(svg_path)
    walls, openings, rooms = [], [], []

    for e in svg.getElementsByTagName('g'):
        eid = e.getAttribute("id")
        ecls = e.getAttribute("class")

        # --- Walls: collapse the polygon to an axis-aligned centerline. ---
        if eid == "Wall":
            X, Y = parse_svg_polygon(e)
            if X is None or len(X) < 4:
                continue
            X, Y = X * scale, Y * scale
            dx, dy = abs(max(X) - min(X)), abs(max(Y) - min(Y))
            if dx < 3 and dy < 3:
                continue  # degenerate sliver; skip
            if dx > dy:
                # Horizontal wall: centerline at mid-height.
                cy = round((min(Y) + max(Y)) / 2)
                start, end = [round(min(X)), cy], [round(max(X)), cy]
                thickness = max(round(dy), 1)
            else:
                # Vertical wall: centerline at mid-width.
                cx = round((min(X) + max(X)) / 2)
                start, end = [cx, round(min(Y))], [cx, round(max(Y))]
                thickness = max(round(dx), 1)
            walls.append({
                'start': start, 'end': end, 'thickness': thickness,
                'centerline': LineString([start, end]),
            })
            # Openings nested inside this wall group.
            for child in e.getElementsByTagName('g'):
                cid = child.getAttribute("id")
                if cid in ("Door", "Window"):
                    cX, cY = parse_svg_polygon(child)
                    if cX is not None and len(cX) >= 3:
                        cX, cY = cX * scale, cY * scale
                        center = [round(np.mean(cX)), round(np.mean(cY))]
                        ow = max(round(max(abs(max(cX)-min(cX)), abs(max(cY)-min(cY)))), 1)
                        openings.append({'type': cid.lower(), 'center_point': center, 'width': ow})

        # --- Top-level openings not nested inside a wall group. ---
        elif eid in ("Door", "Window"):
            parent_id = e.parentNode.getAttribute("id") if e.parentNode else ""
            if parent_id == "Wall":
                continue  # already captured via the wall's children above
            X, Y = parse_svg_polygon(e)
            if X is None or len(X) < 3:
                continue
            X, Y = X * scale, Y * scale
            center = [round(np.mean(X)), round(np.mean(Y))]
            ow = max(round(max(abs(max(X)-min(X)), abs(max(Y)-min(Y)))), 1)
            openings.append({'type': eid.lower(), 'center_point': center, 'width': ow})

        # --- Rooms: elements classed "Space <Name> ...". ---
        elif "Space " in ecls:
            name = ecls.replace("Space ", "").split(' ')[0]
            label = ROOM_MAP.get(name, "room")
            X, Y = parse_svg_polygon(e)
            if X is not None and len(X) >= 3:
                X, Y = X * scale, Y * scale
                try:
                    # Build the polygon in (x, y) order so distances against
                    # the wall centerlines (also (x, y)) are meaningful.
                    # The original used zip(Y, X), transposing the room into
                    # a different coordinate space than the walls.
                    poly = Polygon(list(zip(X, Y)))
                    if not poly.is_valid:
                        poly = poly.buffer(0)
                    rooms.append({'label': label, 'polygon': poly})
                except Exception:
                    pass

    if not walls:
        return None

    # Attach each opening to its nearest wall, if within 3x wall thickness.
    for op in openings:
        oc = Point(op['center_point'])
        best_i, best_d = None, float('inf')
        for i, w in enumerate(walls):
            d = w['centerline'].distance(oc)
            if d < best_d:
                best_d = d
                best_i = i
        if best_i is not None and best_d < walls[best_i]['thickness'] * 3:
            op['wall_idx'] = best_i
            # Position of the opening along the wall's centerline.
            op['center_along'] = round(walls[best_i]['centerline'].project(oc))

    # Link each room to the walls running along its boundary.
    for room in rooms:
        rp = room['polygon']
        room['wall_ids'] = []
        for i, w in enumerate(walls):
            try:
                if rp.boundary.distance(w['centerline']) < w['thickness'] * 2:
                    room['wall_ids'].append(f"wall_{i+1}")
            except Exception:
                pass

    # Assemble the final JSON-serializable structure.
    result = {"walls": [], "rooms": []}
    for i, w in enumerate(walls):
        entry = {
            "id": f"wall_{i+1}",
            "start": w['start'],
            "end": w['end'],
            "thickness": w['thickness'],
            "curvature": 0,
            "openings": [],
        }
        for op in openings:
            if op.get('wall_idx') == i:
                entry["openings"].append({
                    "type": op['type'],
                    "center": op['center_along'],
                    "width": op['width'],
                })
        result["walls"].append(entry)

    # Only keep rooms that were associated with at least one wall.
    for room in rooms:
        if room.get('wall_ids'):
            result["rooms"].append({"label": room['label'], "walls": room['wall_ids']})

    return result
|
|
|
|
| |
|
|
def build_dataset_from_cubicasa(data_dir, max_samples=None):
    """Convert extracted CubiCasa5K plans into an SFT chat dataset.

    Each record pairs the plan's ``F1_scaled.png`` image with the serialized
    JSON geometry from parse_floorplan() as the assistant target. Plans that
    fail to parse, have no walls, or whose JSON exceeds MAX_JSON_CHARS are
    skipped.
    """
    plans = []
    for root, _dirs, files in os.walk(data_dir):
        if 'model.svg' in files and 'F1_scaled.png' in files:
            plans.append(root)

    print(f"Found {len(plans)} floor plans")
    if max_samples:
        plans = plans[:max_samples]

    records, errors = [], 0
    for i, pdir in enumerate(plans):
        if i % 200 == 0:
            print(f" Converting {i}/{len(plans)}... ({len(records)} ok, {errors} err)")
        try:
            jd = parse_floorplan(
                os.path.join(pdir, 'model.svg'),
                os.path.join(pdir, 'F1_scaled.png'),
            )
            if jd and len(jd['walls']) > 0:
                js = json.dumps(jd, separators=(',', ':'))
                if len(js) > MAX_JSON_CHARS:
                    continue  # overly complex plan; keep assistant targets short
                # Convert inside a context manager so the PNG file handle is
                # closed instead of leaking one handle per kept record
                # (Image.open is lazy; convert() forces the load).
                with Image.open(os.path.join(pdir, 'F1_scaled.png')) as f:
                    img = f.convert("RGB")
                records.append({
                    "messages": [
                        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
                        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": USER_PROMPT}]},
                        {"role": "assistant", "content": [{"type": "text", "text": js}]},
                    ],
                    "images": [img],
                })
            else:
                errors += 1
        except Exception as e:
            errors += 1
            if errors <= 3:
                print(f" Error on {pdir}: {e}")

    print(f"✓ Built {len(records)} training samples ({errors} errors)")
    return Dataset.from_list(records)
|
|
|
|
def create_synthetic_fallback(n=20):
    """Generate n simple synthetic floor-plan samples when real data is unavailable.

    Each sample is a square outline with one dividing wall and a door gap,
    paired with the corresponding ground-truth JSON.
    """
    print(f"Creating {n} synthetic floor plan samples...")
    samples = []
    for idx in range(n):
        canvas = 256
        scale = 1024.0 / canvas
        margin = 30 + idx * 3
        stroke = 6

        def px(v):
            # Map a canvas coordinate into the normalized 1024 space.
            return round(v * scale)

        # Draw: outer square, a horizontal divider, and a white door gap.
        image = Image.new('RGB', (canvas, canvas), 'white')
        pen = ImageDraw.Draw(image)
        pen.rectangle([margin, margin, canvas - margin, canvas - margin], outline='black', width=stroke)
        divider = canvas // 2 + idx * 2
        pen.line([(margin, divider), (canvas - margin, divider)], fill='black', width=stroke)
        door_x = canvas // 3 + idx * 8
        pen.line([(door_x, divider), (door_x + 25, divider)], fill='white', width=stroke + 2)

        lo, hi = px(margin), px(canvas - margin)
        thick = px(stroke)
        mid_y = px(divider)

        def wall(wid, start, end, openings=None):
            # One wall entry in the target schema; key order matters for the
            # serialized JSON target.
            return {"id": wid, "start": start, "end": end, "thickness": thick,
                    "curvature": 0, "openings": [] if openings is None else openings}

        payload = {
            "walls": [
                wall("wall_1", [lo, lo], [hi, lo]),
                wall("wall_2", [hi, lo], [hi, hi]),
                wall("wall_3", [hi, hi], [lo, hi]),
                wall("wall_4", [lo, hi], [lo, lo]),
                wall("wall_5", [lo, mid_y], [hi, mid_y],
                     [{"type": "door", "center": px(door_x), "width": px(25)}]),
            ],
            "rooms": [
                {"label": "bedroom", "walls": ["wall_1", "wall_2", "wall_5", "wall_4"]},
                {"label": "living_room", "walls": ["wall_5", "wall_2", "wall_3", "wall_4"]},
            ],
        }
        samples.append({
            "messages": [
                {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
                {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": USER_PROMPT}]},
                {"role": "assistant", "content": [{"type": "text", "text": json.dumps(payload, separators=(',', ':'))}]},
            ],
            "images": [image],
        })
    return Dataset.from_list(samples)
|
|
|
|
| |
|
|
class TrainLogger(TrainerCallback):
    """Console logging callback: prints loss per log step and tracks the best loss."""

    def __init__(self):
        # Lowest training loss observed so far.
        self.best = float('inf')

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print step, current loss, running best, and learning rate."""
        if logs and "loss" in logs:
            loss = logs["loss"]
            if loss < self.best:
                self.best = loss
            lr = logs.get("learning_rate", 0)
            print(f" step {state.global_step:>5d} | loss {loss:.4f} | best {self.best:.4f} | lr {lr:.2e}")

    def on_train_end(self, args, state, control, **kwargs):
        """Print the final training summary."""
        # Repaired: a mis-encoded character previously split this literal.
        print(f"\n ✓ Training complete! steps={state.global_step}, best_loss={self.best:.4f}")
|
|
|
|
| |
|
|
def _load_training_data():
    """Return the SFT dataset, falling back to synthetic plans on any failure."""
    try:
        data_dir = download_and_extract()
        dataset = build_dataset_from_cubicasa(data_dir, max_samples=MAX_SAMPLES)
        if len(dataset) < 5:
            raise ValueError(f"Only {len(dataset)} samples found")
        return dataset
    except Exception as e:
        print(f"✗ Real data unavailable ({e}), using synthetic fallback")
        return create_synthetic_fallback(20)


def _load_processor(use_gpu):
    """Load the Qwen processor; use a smaller image-token budget on CPU."""
    if use_gpu:
        proc_kwargs = {"min_pixels": 256 * 28 * 28, "max_pixels": 1280 * 28 * 28}
    else:
        proc_kwargs = {"min_pixels": 64 * 28 * 28, "max_pixels": 256 * 28 * 28}
    return AutoProcessor.from_pretrained(MODEL_ID, **proc_kwargs)


def _load_model(use_gpu):
    """Load the base VLM: bf16 (+flash-attn when available) on GPU, fp32 on CPU."""
    model_kwargs = {"torch_dtype": torch.bfloat16}
    if use_gpu:
        try:
            import flash_attn  # noqa: F401 -- only probing availability
            model_kwargs["attn_implementation"] = "flash_attention_2"
            print(" Using flash_attention_2")
        except ImportError:
            print(" flash-attn not installed, using default attention")
    else:
        model_kwargs["torch_dtype"] = torch.float32
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(MODEL_ID, **model_kwargs)
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
    return model


def _make_lora_config(use_gpu):
    """LoRA config: wide adapters on GPU, minimal q/v-only adapters on CPU."""
    if use_gpu:
        return LoraConfig(
            r=16, lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
        )
    return LoraConfig(
        r=8, lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
    )


def _make_sft_config(use_gpu):
    """Build the SFTConfig; lighter logging/saving/sequence limits on CPU."""
    return SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LEARNING_RATE,
        warmup_steps=20 if use_gpu else 2,
        lr_scheduler_type="cosine",
        bf16=use_gpu,
        fp16=False,
        gradient_checkpointing=True,
        logging_steps=5 if use_gpu else 1,
        logging_first_step=True,
        logging_strategy="steps",
        disable_tqdm=True,
        save_steps=500 if use_gpu else 99999,  # effectively never save on CPU
        save_total_limit=2,
        max_length=4096 if use_gpu else 1024,
        remove_unused_columns=False,
        dataset_kwargs={"skip_prepare_dataset": True},  # we pass raw messages/images
        push_to_hub=PUSH_TO_HUB,
        hub_model_id=HUB_MODEL_ID if PUSH_TO_HUB else None,
        report_to="none",
    )


def _push_model(trainer):
    """Push to the Hub via the trainer, falling back to a manual folder upload."""
    try:
        trainer.push_to_hub()
        print(f"\n ✓ Pushed to https://huggingface.co/{HUB_MODEL_ID}")
    except Exception as e:
        print(f" Push via trainer failed ({e}), trying manual upload...")
        try:
            from huggingface_hub import HfApi
            api = HfApi()
            api.create_repo(HUB_MODEL_ID, exist_ok=True)
            api.upload_folder(folder_path=OUTPUT_DIR, repo_id=HUB_MODEL_ID)
            print(f" ✓ Pushed to https://huggingface.co/{HUB_MODEL_ID}")
        except Exception as e2:
            print(f" ✗ Push failed: {e2}")
            print(f" Model saved locally at {OUTPUT_DIR}/")


def _smoke_test(model, processor, use_gpu):
    """Generate once on a tiny synthetic plan and check the output parses as JSON."""
    model.eval()
    test_img = Image.new('RGB', (200, 200), 'white')
    d = ImageDraw.Draw(test_img)
    d.rectangle([20, 20, 180, 180], outline='black', width=5)
    d.line([(20, 100), (180, 100)], fill='black', width=5)

    msgs = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": USER_PROMPT}]},
    ]
    text = processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[test_img], return_tensors="pt", padding=True)
    if use_gpu:
        # BatchFeature supports .to(device). Keep the BatchFeature object
        # (not a plain dict) so attribute access below keeps working — the
        # original rebuilt a dict here and then crashed on inputs.input_ids.
        inputs = inputs.to(model.device)

    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=200, do_sample=False)

    # Decode only the newly generated tokens (strip the prompt prefix).
    gen = processor.batch_decode(out[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
    print(f" Generated: {gen[:400]}")

    try:
        m = re.search(r'\{[\s\S]*\}', gen)
        if m:
            parsed = json.loads(m.group())
            print(f" ✓ Valid JSON! Walls: {len(parsed.get('walls', []))}, Rooms: {len(parsed.get('rooms', []))}")
    except Exception:
        print(" ✗ JSON parse failed (may need more training)")


def main():
    """Run the pipeline: data -> processor/model -> LoRA SFT -> save/push -> smoke test."""
    use_gpu = torch.cuda.is_available()
    print("=" * 64)
    print(f" FloorplanVLM SFT Training ({'GPU: ' + torch.cuda.get_device_name(0) if use_gpu else 'CPU'})")
    print(f" Model : {MODEL_ID}")
    print(f" Output : {HUB_MODEL_ID}")
    print(f" Epochs : {NUM_EPOCHS}")
    print(f" Batch : {BATCH_SIZE} x {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM} effective")
    print(f" LR : {LEARNING_RATE}")
    print(f" Max samples: {MAX_SAMPLES or 'all'}")
    print("=" * 64)

    print("\n[1/7] Getting training data...")
    dataset = _load_training_data()
    print(f" Dataset size: {len(dataset)} samples")

    print("\n[2/7] Loading processor...")
    processor = _load_processor(use_gpu)

    print("\n[3/7] Loading model...")
    model = _load_model(use_gpu)

    print("\n[4/7] Configuring LoRA...")
    peft_config = _make_lora_config(use_gpu)

    print("\n[5/7] Configuring training...")
    sft_config = _make_sft_config(use_gpu)

    print("\n[6/7] Starting training...")
    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=dataset,
        peft_config=peft_config,
        processing_class=processor,
        callbacks=[TrainLogger()],
    )

    est_steps = max(1, len(dataset) * NUM_EPOCHS // (BATCH_SIZE * GRAD_ACCUM))
    print(f" Estimated steps: ~{est_steps}")
    print("-" * 50)
    trainer.train()
    print("-" * 50)

    print("\n[7/7] Saving model...")
    trainer.save_model(OUTPUT_DIR)
    print(f" Saved locally to {OUTPUT_DIR}/")
    print(f" Files: {os.listdir(OUTPUT_DIR)}")

    if PUSH_TO_HUB:
        _push_model(trainer)

    print("\n[Bonus] Quick inference test...")
    _smoke_test(model, processor, use_gpu)

    print("\n" + "=" * 64)
    print(" ✓ DONE!")
    if PUSH_TO_HUB:
        print(f" Model: https://huggingface.co/{HUB_MODEL_ID}")
    print(f" Local: {os.path.abspath(OUTPUT_DIR)}/")
    print("=" * 64)
|
|
|
|
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    main()
|
|