| |
|
| | import os |
| | import torch |
| | from torch.utils.data import Dataset, DataLoader |
| | from torch.optim import AdamW |
| | from transformers import get_scheduler |
| | from PIL import Image |
| | import json |
| | from pathlib import Path |
| | from tqdm import tqdm |
| | import requests |
| | from io import BytesIO |
| |
|
| | |
| | import sys |
| | sys.path.insert(0, str(Path(__file__).parent)) |
| | from oculus_unified_model import OculusForConditionalGeneration |
| |
|
class InstructionDataset(Dataset):
    """
    Dataset for Visual Instruction Tuning.

    Builds (image, prompt, caption) triples from COCO Captions
    (``annotations/captions_train2017.json``). If the annotation file is
    missing, falls back to a small synthetic sample list so the training
    script stays runnable end-to-end.

    Each item is encoded with the supplied HF processor and yields the
    tensors a vision-language model expects: ``pixel_values``,
    ``input_ids``, ``attention_mask`` and ``labels`` (padding positions
    masked to -100 so the loss ignores them).
    """
    def __init__(self, processor, data_dir="data/coco", max_samples=None):
        """
        Args:
            processor: HF image+text processor (assumed to expose a
                ``tokenizer`` attribute — TODO confirm for the Oculus
                processor; masking degrades gracefully if absent).
            data_dir: Root of the COCO layout (``annotations/``, ``images/``).
            max_samples: Optional cap on the number of loaded samples.
        """
        import random

        self.processor = processor
        self.samples = []

        ann_file = Path(data_dir) / "annotations" / "captions_train2017.json"
        if not ann_file.exists():
            print(f"⚠️ COCO Captions not found at {ann_file}. Using synthetic fallback.")
            # NOTE: keys must match what __getitem__ reads
            # ("image_path"/"question"/"answer") — the old "q"/"a" keys
            # raised KeyError at fetch time.
            self.samples = [
                {
                    "image_path": "data/coco/images/000000071345.jpg",
                    "question": "Describe this.",
                    "answer": "A car parked on the street.",
                }
            ] * 100
        else:
            print(f"Loading real instruction data from {ann_file}...")
            with open(ann_file) as f:
                coco = json.load(f)

            # image id -> file name lookup for resolving annotations.
            img_map = {img['id']: img['file_name'] for img in coco['images']}

            # Prompt templates paired randomly with captions to mimic
            # instruction-style supervision.
            prompts = [
                "Describe this image.",
                "What is going on here?",
                "Write a caption for this photo.",
                "What do you see?",
                "Provide a detailed description.",
                "Explain the scene."
            ]

            for ann in coco['annotations']:
                filename = img_map.get(ann['image_id'])
                if filename:
                    img_path = Path(data_dir) / "images" / filename
                    # Only keep samples whose image is actually on disk.
                    if img_path.exists():
                        self.samples.append({
                            "image_path": str(img_path),
                            "question": random.choice(prompts),
                            "answer": ann['caption'],
                        })

                if max_samples and len(self.samples) >= max_samples:
                    break

            print(f"✅ Loaded {len(self.samples)} instruction samples from COCO")

    def __len__(self):
        """Number of (image, prompt, answer) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one encoded training sample as a dict of tensors."""
        item = self.samples[idx]

        # Fall back to a blank image if the file is missing or corrupt so a
        # single bad sample cannot abort a whole training epoch.  Narrow
        # except: the old bare `except:` also swallowed KeyboardInterrupt.
        try:
            image = Image.open(item['image_path']).convert('RGB')
        except (OSError, ValueError):
            image = Image.new('RGB', (224, 224))

        question = item['question']
        answer = item['answer']

        encoding = self.processor(
            images=image,
            text=question,
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt"
        )

        labels = self.processor(
            text=answer,
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt",
        ).input_ids

        # Mask padding so CrossEntropyLoss (ignore_index=-100) skips pad
        # tokens — otherwise the model is literally trained to emit padding.
        tokenizer = getattr(self.processor, "tokenizer", None)
        pad_id = getattr(tokenizer, "pad_token_id", None)
        if pad_id is not None:
            labels = labels.masked_fill(labels == pad_id, -100)

        return {
            "pixel_values": encoding.pixel_values.squeeze(0),
            "input_ids": encoding.input_ids.squeeze(0),
            "attention_mask": encoding.attention_mask.squeeze(0),
            "labels": labels.squeeze(0)
        }
| |
|
def train():
    """Instruction-tune the Oculus VQA (reasoning) module on COCO captions.

    Loads the pretrained Oculus checkpoint, attaches its language model,
    runs a standard supervised fine-tuning loop, and saves the tuned VQA
    model + processor under ``checkpoints/oculus_instruct_v1``.
    """
    # Device preference: Apple MPS > CUDA > CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    model_path = "checkpoints/oculus_detection_v2/final"
    print(f"Loading Oculus from {model_path}...")
    backbone = OculusForConditionalGeneration.from_pretrained(model_path)
    backbone.load_language_model(device=device)

    # The VQA head is the trainable component here.
    reasoner = backbone.lm_vqa_model
    reasoner.train()

    optimizer = AdamW(reasoner.parameters(), lr=2e-5)

    loader = DataLoader(
        InstructionDataset(backbone.lm_vqa_processor, max_samples=5000),
        batch_size=4,
        shuffle=True,
    )

    print("\n🚀 Starting Instruction Tuning (Reasoning Module)...")
    epochs = 4

    for epoch in range(epochs):
        running_loss = 0.0
        progress = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in progress:
            batch = {name: tensor.to(device) for name, tensor in batch.items()}

            loss = reasoner(**batch).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            running_loss += loss.item()
            progress.set_postfix(loss=loss.item())

        print(f"Epoch {epoch+1} Avg Loss: {running_loss / len(loader):.4f}")

    output_dir = Path("checkpoints/oculus_instruct_v1")
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n💾 Saving tuned VQA model to {output_dir}")
    reasoner.save_pretrained(output_dir / "vqa_model")
    backbone.lm_vqa_processor.save_pretrained(output_dir / "vqa_model")

    print("✅ Instruction Tuning Complete!")
| |
|
| | if __name__ == "__main__": |
| | train() |
| |
|