File size: 7,217 Bytes
761739f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fca8a9
761739f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import os
import logging
import asyncio
import requests
import fal_client
import json
from typing import Optional

# Configure logging to show more detailed information.
# Applies process-wide via the root logger: timestamped INFO-level messages.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

async def remove_background_birefnet(image_path: str) -> Optional[str]:
    """Remove the background from an image using the BiRefNet API.

    Args:
        image_path: URL of the image to process. Despite the name, callers
            pass a URL here (e.g. one returned by
            ``fal_client.upload_file_async``), not a local filesystem path.

    Returns:
        URL of the background-removed PNG on success, or ``None`` on any
        failure (API error, unknown status, malformed result payload).
    """
    # Single source of truth for the model id (was repeated three times).
    endpoint = "fal-ai/birefnet/v2"
    logging.info(f"Starting BiRefNet processing for: {image_path}")
    try:
        # Submit the request
        logging.info("Submitting request to BiRefNet API...")
        handler = await fal_client.submit_async(
            endpoint,
            arguments={
                "image_url": image_path,
                "model": "General Use (Heavy)",
                "operating_resolution": "1024x1024",
                "output_format": "png",
                "refine_foreground": True
            }
        )
        request_id = handler.request_id
        logging.info(f"🔄 Request submitted with ID: {request_id}")

        # Poll until the request reaches a terminal state, forwarding any
        # server-side logs to our logger along the way.
        while True:
            status = await fal_client.status_async(endpoint, request_id, with_logs=True)

            # Handle logs if available
            if hasattr(status, 'logs') and status.logs:
                for log in status.logs:
                    level = log.get('level', 'INFO')
                    message = log.get('message', '')
                    logging.info(f"🔄 BiRefNet {level}: {message}")

            # Dispatch on the concrete status type returned by fal_client.
            if isinstance(status, fal_client.Queued):
                logging.info("⏳ Request in queue")
            elif isinstance(status, fal_client.InProgress):
                logging.info("🔄 Request is being processed...")
            elif isinstance(status, fal_client.Completed):
                logging.info("✅ Request completed")
                break
            elif isinstance(status, fal_client.Failed):
                logging.error(f"❌ Request failed: {status.error}")
                return None
            else:
                logging.error(f"❌ Unknown status type: {type(status)}")
                return None

            await asyncio.sleep(1)  # Wait before checking again

        # Fetch the completed result payload.
        result = await fal_client.result_async(endpoint, request_id)

        if not result or not isinstance(result, dict):
            logging.error("❌ Invalid result from BiRefNet")
            return None

        image_data = result.get('image', {})
        if not image_data or not isinstance(image_data, dict):
            logging.error(f"❌ Missing or invalid image data in result: {result}")
            return None

        image_url = image_data.get('url')
        if not image_url:
            logging.error(f"❌ Missing image URL in result: {image_data}")
            return None

        # Log successful result with image details. Guard against a present
        # but null 'file_size' (``.get('file_size', 0)`` would return None
        # and the division would raise TypeError).
        size_bytes = image_data.get('file_size') or 0
        logging.info(f"✅ Got image: {image_data.get('width')}x{image_data.get('height')} "
                    f"({size_bytes / 1024 / 1024:.1f}MB)")
        return image_url

    except Exception as e:
        logging.error(f"❌ Unexpected error using BiRefNet API: {str(e)}", exc_info=True)
        return None

async def process_single_image(input_path: str, output_path: str) -> bool:
    """Upload one local image, remove its background, and save the result.

    Args:
        input_path: Local path of the image to process.
        output_path: Local path where the background-removed PNG is written.

    Returns:
        ``True`` if the processed image was downloaded and saved,
        ``False`` on any failure (upload, processing, or download).
    """
    try:
        # Upload the file to fal's temporary storage to obtain a URL.
        logging.info("📤 Uploading to temporary storage...")
        image_url = await fal_client.upload_file_async(input_path)
        logging.info(f"✅ Upload successful: {image_url}")

        # Process with BiRefNet
        result_url = await remove_background_birefnet(image_url)

        if result_url:
            # Download the result. A timeout is essential: without one an
            # unresponsive server would hang this coroutine forever.
            # NOTE(review): requests.get is blocking and stalls the event
            # loop while downloading; consider asyncio.to_thread if the
            # runtime supports it.
            logging.info("📥 Downloading result...")
            response = requests.get(result_url, timeout=60)
            response.raise_for_status()

            content_type = response.headers.get('content-type', '')
            if 'image' not in content_type:
                logging.error(f"❌ Invalid content type: {content_type}")
                return False

            with open(output_path, 'wb') as f:
                f.write(response.content)
            logging.info(f"✅ Successfully saved to {output_path}")
            return True

        return False
    except Exception as e:
        logging.error(f"❌ Error processing image: {str(e)}", exc_info=True)
        return False

async def iterate_over_directory(input_dir: str, output_dir: str):
    """Process all images in a directory using BiRefNet with async processing.

    Scans *input_dir* for .png/.jpg/.jpeg/.webp files, processes them in
    small concurrent batches, and writes results to *output_dir* under the
    same filenames. Files whose output already exists are skipped.

    Args:
        input_dir: Directory containing the source images.
        output_dir: Directory for processed images (created if missing).
    """
    logging.info(f"🚀 Starting BiRefNet processing for directory: {input_dir}")
    logging.info(f"📁 Output directory: {output_dir}")

    os.makedirs(output_dir, exist_ok=True)

    # Get list of files to process
    files = [f for f in os.listdir(input_dir)
             if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]
    total_files = len(files)

    processed = 0
    skipped = 0
    failed = 0

    logging.info(f"📊 Found {total_files} images to process")

    # Process files in small batches to control concurrency and avoid
    # overwhelming the API.
    batch_size = 3
    for i in range(0, len(files), batch_size):
        batch = files[i:i + batch_size]
        tasks = []
        task_names = []  # filenames paired 1:1 with tasks (skipped files excluded)

        for offset, filename in enumerate(batch):
            input_path = os.path.join(input_dir, filename)
            output_path = os.path.join(output_dir, filename)

            logging.info(f"\n{'='*50}")
            # Use the file's position in the full list; the previous
            # ``i + len(tasks) + 1`` miscounted whenever a file was skipped.
            logging.info(f"Processing [{i + offset + 1}/{total_files}]: {filename}")

            if os.path.exists(output_path):
                logging.info(f"⏭️ Skipping {filename} - already processed")
                skipped += 1
                continue

            tasks.append(process_single_image(input_path, output_path))
            task_names.append(filename)

        if tasks:  # Only process if we have tasks
            # return_exceptions=True so one failure doesn't cancel the batch.
            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Pair each result with the filename actually submitted. Zipping
            # against ``batch`` (as before) mislabels results whenever a file
            # in the batch was skipped.
            for filename, result in zip(task_names, results):
                if isinstance(result, Exception):
                    logging.error(f"❌ Failed to process {filename}: {str(result)}")
                    failed += 1
                elif result:
                    processed += 1
                else:
                    failed += 1

            # Add a small delay between batches
            await asyncio.sleep(1)

    logging.info(f"\n{'='*50}")
    logging.info("📊 Processing Summary:")
    logging.info(f"✅ Successfully processed: {processed}")
    logging.info(f"⏭️ Skipped (already existed): {skipped}")
    logging.info(f"❌ Failed: {failed}")
    logging.info(f"📁 Total files: {total_files}")

def process_directory(input_dir: str, output_dir: str):
    """Run the async directory processor from synchronous code.

    Convenience entry point: builds the coroutine and drives it to
    completion on a fresh event loop via ``asyncio.run``.
    """
    coro = iterate_over_directory(input_dir, output_dir)
    asyncio.run(coro)