Add birefnet
Browse files- utils/birefnet.py +185 -0
utils/birefnet.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
import asyncio
|
4 |
+
import requests
|
5 |
+
import fal_client
|
6 |
+
import json
|
7 |
+
from typing import Optional
|
8 |
+
|
9 |
+
# Root-logger setup: INFO and above, each record prefixed with a
# second-resolution timestamp and its level name.
_LOG_SETTINGS = {
    "level": logging.INFO,
    "format": '%(asctime)s - %(levelname)s - %(message)s',
    "datefmt": '%Y-%m-%d %H:%M:%S',
}
logging.basicConfig(**_LOG_SETTINGS)
|
15 |
+
|
16 |
+
async def remove_background_birefnet(image_path: str) -> Optional[str]:
    """Remove an image's background through the fal.ai BiRefNet v2 endpoint.

    Submits the job, polls its status once per second (relaying any remote
    log entries to the local log), then fetches the result and extracts the
    URL of the output image.

    Args:
        image_path: Value sent as the request's ``image_url`` argument.

    Returns:
        The URL of the processed image, or ``None`` when the request fails,
        the result is missing/malformed, or any exception is raised (errors
        are logged with a traceback instead of being propagated).
    """
    # Single source of truth for the endpoint used by submit/status/result.
    application = "fal-ai/birefnet/v2"

    logging.info(f"Starting BiRefNet processing for: {image_path}")
    try:
        # Submit the request
        logging.info("Submitting request to BiRefNet API...")
        handler = await fal_client.submit_async(
            application,
            arguments={
                "image_url": image_path,
                "model": "General Use (Heavy)",
                "operating_resolution": "2048x2048",
                "output_format": "png",
                "refine_foreground": True,
            },
        )
        request_id = handler.request_id
        logging.info(f"🔄 Request submitted with ID: {request_id}")

        # NOTE(review): some fal_client releases may not export a ``Failed``
        # status class; resolve it defensively so an unexpected status is
        # reported as "unknown" rather than surfacing as an AttributeError.
        # Confirm against the installed client version.
        failed_type = getattr(fal_client, "Failed", None)

        # Poll until the request reaches a terminal state.
        while True:
            status = await fal_client.status_async(application, request_id, with_logs=True)

            # Relay remote log entries when the status object carries any.
            for entry in getattr(status, "logs", None) or ():
                level = entry.get('level', 'INFO')
                message = entry.get('message', '')
                logging.info(f"🔄 BiRefNet {level}: {message}")

            if isinstance(status, fal_client.Queued):
                logging.info(f"⏳ Request in queue")
            elif isinstance(status, fal_client.InProgress):
                logging.info("🔄 Request is being processed...")
            elif isinstance(status, fal_client.Completed):
                logging.info("✅ Request completed")
                break
            elif failed_type is not None and isinstance(status, failed_type):
                logging.error(f"❌ Request failed: {status.error}")
                return None
            else:
                logging.error(f"❌ Unknown status type: {type(status)}")
                return None

            await asyncio.sleep(1)  # Wait before checking again

        # Fetch and validate the result payload.
        result = await fal_client.result_async(application, request_id)

        if not result or not isinstance(result, dict):
            logging.error("❌ Invalid result from BiRefNet")
            return None

        image_data = result.get('image', {})
        if not image_data or not isinstance(image_data, dict):
            logging.error(f"❌ Missing or invalid image data in result: {result}")
            return None

        image_url = image_data.get('url')
        if not image_url:
            logging.error(f"❌ Missing image URL in result: {image_data}")
            return None

        # The size is purely informational: a missing, null or non-numeric
        # ``file_size`` must not raise here and throw away a valid URL.
        raw_size = image_data.get('file_size')
        size_mb = raw_size / 1024 / 1024 if isinstance(raw_size, (int, float)) else 0.0
        logging.info(f"✅ Got image: {image_data.get('width')}x{image_data.get('height')} "
                     f"({size_mb:.1f}MB)")
        return image_url

    except Exception as e:
        logging.error(f"❌ Unexpected error using BiRefNet API: {str(e)}", exc_info=True)
        return None
|
88 |
+
|
89 |
+
async def process_single_image(input_path: str, output_path: str) -> bool:
    """Upload one image, remove its background, and save the result locally.

    Args:
        input_path: Local path of the image to upload.
        output_path: Local path the downloaded result is written to.

    Returns:
        ``True`` if the processed image was written to ``output_path``;
        ``False`` if background removal yielded no URL, the download did not
        have an image content type, or any exception occurred (errors are
        logged with a traceback instead of being propagated).
    """
    try:
        # Upload the file
        logging.info(f"📤 Uploading to temporary storage...")
        image_url = await fal_client.upload_file_async(input_path)
        logging.info(f"✅ Upload successful: {image_url}")

        # Process with BiRefNet
        result_url = await remove_background_birefnet(image_url)
        if not result_url:
            return False

        # Download the result
        logging.info(f"📥 Downloading result...")
        # ``requests`` is blocking: run it in the default executor so the
        # event loop (and the other images being processed concurrently)
        # keeps running. The (connect, read) timeout bounds the wait so a
        # stalled connection cannot hang the whole run.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, lambda: requests.get(result_url, timeout=(10, 120))
        )
        response.raise_for_status()

        content_type = response.headers.get('content-type', '')
        if 'image' not in content_type:
            logging.error(f"❌ Invalid content type: {content_type}")
            return False

        with open(output_path, 'wb') as f:
            f.write(response.content)
        logging.info(f"✅ Successfully saved to {output_path}")
        return True

    except Exception as e:
        logging.error(f"❌ Error processing image: {str(e)}", exc_info=True)
        return False
|
120 |
+
|
121 |
+
async def iterate_over_directory(input_dir: str, output_dir: str, batch_size: int = 3):
    """Run BiRefNet background removal over every image in a directory.

    Files in ``input_dir`` whose names end in .png/.jpg/.jpeg/.webp
    (case-insensitive) are processed in name order, ``batch_size`` at a time.
    Each result is written to ``output_dir`` under the same file name; a file
    whose output already exists is skipped. A summary is logged at the end.

    Args:
        input_dir: Directory to scan (not recursive).
        output_dir: Directory for results; created if missing.
        batch_size: Maximum number of images processed concurrently. Kept
            small by default to avoid overwhelming the API.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    logging.info(f"🚀 Starting BiRefNet processing for directory: {input_dir}")
    logging.info(f"📁 Output directory: {output_dir}")

    os.makedirs(output_dir, exist_ok=True)

    # os.listdir order is arbitrary; sort so runs and progress logs are stable.
    files = sorted(
        f for f in os.listdir(input_dir)
        if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))
    )
    total_files = len(files)

    processed = 0
    skipped = 0
    failed = 0

    logging.info(f"📊 Found {total_files} images to process")

    # Process files in batches to control concurrency
    for start in range(0, total_files, batch_size):
        batch = files[start:start + batch_size]
        # Names are tracked alongside their tasks so results are attributed to
        # the right file even when some files of the batch were skipped.
        pending_names = []
        tasks = []

        for offset, filename in enumerate(batch, start=1):
            input_path = os.path.join(input_dir, filename)
            output_path = os.path.join(output_dir, filename)

            logging.info(f"\n{'='*50}")
            logging.info(f"Processing [{start + offset}/{total_files}]: {filename}")

            if os.path.exists(output_path):
                logging.info(f"⏭️ Skipping {filename} - already processed")
                skipped += 1
                continue

            pending_names.append(filename)
            tasks.append(process_single_image(input_path, output_path))

        if not tasks:
            # Nothing was submitted for this batch, so no pacing delay is needed.
            continue

        # Wait for batch to complete
        results = await asyncio.gather(*tasks, return_exceptions=True)

        for filename, result in zip(pending_names, results):
            # BaseException also covers cancellation, which would otherwise
            # be truthy and miscounted as a success.
            if isinstance(result, BaseException):
                logging.error(f"❌ Failed to process {filename}: {str(result)}")
                failed += 1
            elif result:
                processed += 1
            else:
                failed += 1

        # Add a small delay between batches
        await asyncio.sleep(1)

    logging.info(f"\n{'='*50}")
    logging.info(f"📊 Processing Summary:")
    logging.info(f"✅ Successfully processed: {processed}")
    logging.info(f"⏭️ Skipped (already existed): {skipped}")
    logging.info(f"❌ Failed: {failed}")
    logging.info(f"📁 Total files: {total_files}")
|
182 |
+
|
183 |
+
def process_directory(input_dir: str, output_dir: str):
    """Blocking entry point for batch background removal.

    Builds the ``iterate_over_directory`` coroutine for the given input and
    output directories and drives it to completion with ``asyncio.run``.
    """
    batch_job = iterate_over_directory(input_dir, output_dir)
    asyncio.run(batch_job)
|