#!/usr/bin/env python3
"""
Stage 1: Data Loading and Image Downloading
Downloads and preprocesses top 2000 images from parquet file
"""

import os
import json
import requests
import pandas as pd
from PIL import Image
from io import BytesIO
import concurrent.futures
from pathlib import Path
import time
import logging
import numpy as np
from typing import Optional, Tuple

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def setup_environment():
    """Setup data directory"""
    os.makedirs('./data', exist_ok=True)
    os.makedirs('./data/images', exist_ok=True)
    os.makedirs('./data/metadata', exist_ok=True)
    return True

def load_and_sample_data(parquet_path: str, n_samples: int = 2000) -> pd.DataFrame:
    """Load parquet file and sample top N rows"""
    logger.info(f"Loading data from {parquet_path}")
    df = pd.read_parquet(parquet_path)
    logger.info(f"Loaded {len(df)} rows, sampling top {n_samples}")
    return df.head(n_samples)

def has_white_edges(img: Image.Image, threshold: int = 240) -> bool:
    """Check if image has 3 or more white edges (mean RGB > threshold)"""
    try:
        img_array = np.array(img)
        
        # Define edge thickness (check 5 pixels from each edge)
        edge_thickness = 5
        
        # Get edges
        top_edge = img_array[:edge_thickness, :].mean(axis=(0, 1))
        bottom_edge = img_array[-edge_thickness:, :].mean(axis=(0, 1))
        left_edge = img_array[:, :edge_thickness].mean(axis=(0, 1))
        right_edge = img_array[:, -edge_thickness:].mean(axis=(0, 1))
        
        # Check if edge is white (all RGB channels > threshold)
        edges = [top_edge, bottom_edge, left_edge, right_edge]
        white_edges = sum(1 for edge in edges if np.all(edge > threshold))
        
        return white_edges >= 3
    except Exception as e:
        logger.debug(f"Error checking white edges: {e}")
        return False

def download_and_process_image(url: str, target_size: int = 256) -> Optional[Image.Image]:
    """Download an image and resize with center crop; return None on failure or if it has white edges"""
    try:
        response = requests.get(url, timeout=10, headers={'User-Agent': 'Mozilla/5.0'})
        response.raise_for_status()

        img = Image.open(BytesIO(response.content)).convert('RGB')
        
        # Check for white edges before processing
        if has_white_edges(img):
            logger.debug(f"Skipping image with white edges: {url}")
            return None
        
        # Resize and center crop to target_size x target_size
        width, height = img.size
        min_side = min(width, height)
        scale = target_size / min_side
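        # Scaling the shorter side to exactly target_size means both new
        # dimensions are >= target_size, so the center crop below always fits.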
        
        new_width = int(width * scale)
        new_height = int(height * scale)
        
        img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
        
        # Center crop
        left = (new_width - target_size) // 2
        top = (new_height - target_size) // 2
        right = left + target_size
        bottom = top + target_size
        
        img = img.crop((left, top, right, bottom))
        
        # Double-check after processing
        if has_white_edges(img):
            logger.debug(f"Skipping processed image with white edges: {url}")
            return None
            
        return img
    except Exception as e:
        logger.error(f"Error downloading {url}: {e}")
        return None

def process_single_image(args: Tuple[int, str, str, str]) -> bool:
    """Download and save a single image"""
    idx, url, hash_val, caption = args
    
    try:
        # Download and process image
        image = download_and_process_image(url)
        if image is None:
            logger.debug(f"Skipped image {idx} (white edges or download error)")
            return False
        
        # Save image
        image_path = f'./data/images/img_{idx}.png'
        image.save(image_path)
        
        # Save metadata for next stage
        metadata = {
            "idx": idx,
            "caption": caption,
            "url": url,
            "hash": hash_val,
            "image_path": image_path
        }
        
        metadata_path = f'./data/metadata/meta_{idx}.json'
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        
        logger.info(f"Downloaded and saved image {idx}")
        return True
        
    except Exception as e:
        logger.error(f"Error processing image {idx}: {e}")
        return False

def download_images(df: pd.DataFrame, max_workers: int = 20):
    """Download all images with parallel processing"""
    logger.info(f"Starting image download with {max_workers} workers...")
    
    args_list = [(i, row['url'], row['hash'], row['caption']) 
                 for i, (_, row) in enumerate(df.iterrows())]
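    # The enumerate index i becomes idx in process_single_image and is used as
    # the filename suffix (img_{idx}.png / meta_{idx}.json).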
    
    successful = 0
    skipped = 0  # counts both white-edge skips and download/processing failures
    
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
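        # A thread pool suits this workload: it is network/disk I/O bound, so
        # Python's GIL is not the limiting factor.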
        futures = [executor.submit(process_single_image, args) for args in args_list]
        
        for i, future in enumerate(concurrent.futures.as_completed(futures)):
            if future.result():
                successful += 1
            else:
                skipped += 1
            
            # Progress logging every 100 images
            if (i + 1) % 100 == 0:
                logger.info(f"Processed {i + 1}/{len(args_list)} images (successful: {successful}, skipped: {skipped_white_edges})")
            
            # Brief pause in the result-collection loop; note this does not
            # throttle the worker threads, which issue requests concurrently.
            time.sleep(0.01)
    
    logger.info(f"Download complete: {successful}/{len(args_list)} images downloaded, {skipped_white_edges} skipped (white edges)")
    
    # Save summary
    summary = {
        "total_images": len(args_list),
        "successful_downloads": successful,
        "skipped_white_edges": skipped_white_edges,
        "download_rate": f"{successful/len(args_list)*100:.1f}%",
        "stage": "download_complete"
    }
    
    with open('./data/stage1_summary.json', 'w') as f:
        json.dump(summary, f, indent=2)

def main():
    """Main execution for Stage 1"""
    logger.info("Starting Stage 1: Data Loading and Image Downloading...")
    
    # Setup
    setup_environment()
    
    # Load data
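    # Hard-coded, machine-specific path; adjust for your environment.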
    parquet_path = '/home/fal/partiprompt_clip/curated_part_00000.parquet'
    df = load_and_sample_data(parquet_path, n_samples=5000)
    
    # Save the dataframe for other stages
    df.to_pickle('./data/sampled_data.pkl')
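    # Row order is preserved in the pickle, so the i-th row here corresponds to
    # img_{i}.png / meta_{i}.json written by download_images().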
    
    # Download images with optimized settings
    download_images(df, max_workers=30)
    
    logger.info("Stage 1 completed successfully!")

if __name__ == "__main__":
    main()