tdurbor committed on
Commit
761739f
·
1 Parent(s): 8436088

Add birefnet

Browse files
Files changed (1) hide show
  1. utils/birefnet.py +185 -0
utils/birefnet.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import asyncio
4
+ import requests
5
+ import fal_client
6
+ import json
7
+ from typing import Optional
8
+
9
# Root logger setup: emit INFO and above, prefixing every record with a
# second-resolution timestamp and its severity.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
15
+
16
async def remove_background_birefnet(image_path: str) -> Optional[str]:
    """Remove an image's background using the BiRefNet API asynchronously.

    Submits a request to the ``fal-ai/birefnet/v2`` endpoint, polls its
    status once per second (relaying any remote logs) until it completes,
    then fetches the result and extracts the output image URL.

    Args:
        image_path: Value sent to the API as ``image_url``.

    Returns:
        The URL of the processed image, or ``None`` when the request fails,
        the result is malformed, or any exception occurs (all are logged).
    """
    endpoint = "fal-ai/birefnet/v2"

    logging.info(f"Starting BiRefNet processing for: {image_path}")
    try:
        # Submit the request
        logging.info("Submitting request to BiRefNet API...")
        handler = await fal_client.submit_async(
            endpoint,
            arguments={
                "image_url": image_path,
                "model": "General Use (Heavy)",
                "operating_resolution": "2048x2048",
                "output_format": "png",
                "refine_foreground": True,
            },
        )
        request_id = handler.request_id
        logging.info(f"🔄 Request submitted with ID: {request_id}")

        # Poll for status with logs
        while True:
            status = await fal_client.status_async(endpoint, request_id, with_logs=True)

            # Relay remote logs when the status object carries any
            if hasattr(status, 'logs') and status.logs:
                for log in status.logs:
                    level = log.get('level', 'INFO')
                    message = log.get('message', '')
                    logging.info(f"🔄 BiRefNet {level}: {message}")

            # Check status based on object type
            # NOTE(review): confirm the installed fal_client exposes `Failed`;
            # if it does not, reaching that branch raises AttributeError,
            # which the outer handler logs before returning None.
            if isinstance(status, fal_client.Queued):
                logging.info("⏳ Request in queue")
            elif isinstance(status, fal_client.InProgress):
                logging.info("🔄 Request is being processed...")
            elif isinstance(status, fal_client.Completed):
                logging.info("✅ Request completed")
                break
            elif isinstance(status, fal_client.Failed):
                logging.error(f"❌ Request failed: {status.error}")
                return None
            else:
                logging.error(f"❌ Unknown status type: {type(status)}")
                return None

            await asyncio.sleep(1)  # Wait before checking again

        # Get the result
        result = await fal_client.result_async(endpoint, request_id)

        if not result or not isinstance(result, dict):
            logging.error("❌ Invalid result from BiRefNet")
            return None

        image_data = result.get('image', {})
        if not image_data or not isinstance(image_data, dict):
            logging.error(f"❌ Missing or invalid image data in result: {result}")
            return None

        image_url = image_data.get('url')
        if not image_url:
            logging.error(f"❌ Missing image URL in result: {image_data}")
            return None

        # `file_size` may be absent *or* present with a null value; `or 0`
        # covers both so this purely informational log line can never raise
        # and cause a valid URL to be discarded by the handler below.
        size_mb = (image_data.get('file_size') or 0) / 1024 / 1024
        logging.info(f"✅ Got image: {image_data.get('width')}x{image_data.get('height')} "
                     f"({size_mb:.1f}MB)")
        return image_url

    except Exception as e:
        logging.error(f"❌ Unexpected error using BiRefNet API: {str(e)}", exc_info=True)
        return None
88
+
89
async def process_single_image(input_path: str, output_path: str) -> bool:
    """Upload one image, remove its background, and save the result locally.

    Args:
        input_path: Local path of the image to upload.
        output_path: Local path the downloaded result is written to.

    Returns:
        ``True`` when the processed image was downloaded and saved,
        ``False`` otherwise. Failures are logged and never raised.
    """
    try:
        # Upload the file
        logging.info("📤 Uploading to temporary storage...")
        image_url = await fal_client.upload_file_async(input_path)
        logging.info(f"✅ Upload successful: {image_url}")

        # Process with BiRefNet
        result_url = await remove_background_birefnet(image_url)
        if not result_url:
            return False

        # Download the result
        logging.info("📥 Downloading result...")
        # `requests` is blocking: run it in the default executor so other
        # coroutines in the same batch keep progressing, and bound the wait
        # with (connect, read) timeouts so a stalled server cannot hang us.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, lambda: requests.get(result_url, timeout=(10, 120))
        )
        response.raise_for_status()

        content_type = response.headers.get('content-type', '')
        if 'image' not in content_type:
            logging.error(f"❌ Invalid content type: {content_type}")
            return False

        # Write to a sibling temporary file and rename it into place, so an
        # interrupted write never leaves a truncated file at `output_path`
        # that a later run would mistake for a finished result.
        partial_path = f"{output_path}.part"
        with open(partial_path, 'wb') as f:
            f.write(response.content)
        os.replace(partial_path, output_path)

        logging.info(f"✅ Successfully saved to {output_path}")
        return True
    except Exception as e:
        logging.error(f"❌ Error processing image: {str(e)}", exc_info=True)
        return False
120
+
121
async def iterate_over_directory(input_dir: str, output_dir: str):
    """Process all images in a directory using BiRefNet with async processing.

    Files in ``input_dir`` ending in .png/.jpg/.jpeg/.webp (case-insensitive)
    are processed in concurrent batches of three. A file is skipped when a
    file of the same name already exists in ``output_dir``. A summary of
    processed, skipped, and failed counts is logged at the end.

    Args:
        input_dir: Directory containing the source images.
        output_dir: Directory for results; created if missing.
    """
    logging.info(f"🚀 Starting BiRefNet processing for directory: {input_dir}")
    logging.info(f"📁 Output directory: {output_dir}")

    os.makedirs(output_dir, exist_ok=True)

    # Get list of files to process
    files = [f for f in os.listdir(input_dir)
             if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]
    total_files = len(files)

    processed = 0
    skipped = 0
    failed = 0

    logging.info(f"📊 Found {total_files} images to process")

    # Process files in batches to control concurrency
    batch_size = 3  # Small batches to avoid overwhelming the API
    for start in range(0, total_files, batch_size):
        batch = files[start:start + batch_size]
        # Kept in lockstep with `tasks` so each result is matched to the file
        # that produced it even when other files in the batch were skipped.
        pending_names = []
        tasks = []

        for position, filename in enumerate(batch, start=start + 1):
            input_path = os.path.join(input_dir, filename)
            output_path = os.path.join(output_dir, filename)

            logging.info(f"\n{'='*50}")
            logging.info(f"Processing [{position}/{total_files}]: {filename}")

            if os.path.exists(output_path):
                logging.info(f"⏭️ Skipping {filename} - already processed")
                skipped += 1
                continue

            pending_names.append(filename)
            tasks.append(process_single_image(input_path, output_path))

        if not tasks:
            # Whole batch was skipped: nothing to wait for, no API to rest.
            continue

        # Wait for batch to complete
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Process results
        for filename, result in zip(pending_names, results):
            # BaseException (not Exception) so a returned CancelledError is
            # counted as a failure instead of a truthy "success".
            if isinstance(result, BaseException):
                logging.error(f"❌ Failed to process {filename}: {str(result)}")
                failed += 1
            elif result:
                processed += 1
            else:
                failed += 1

        # Add a small delay between batches
        await asyncio.sleep(1)

    logging.info(f"\n{'='*50}")
    logging.info("📊 Processing Summary:")
    logging.info(f"✅ Successfully processed: {processed}")
    logging.info(f"⏭️ Skipped (already existed): {skipped}")
    logging.info(f"❌ Failed: {failed}")
    logging.info(f"📁 Total files: {total_files}")
182
+
183
def process_directory(input_dir: str, output_dir: str):
    """Blocking entry point that runs iterate_over_directory to completion.

    Args:
        input_dir: Forwarded unchanged to iterate_over_directory.
        output_dir: Forwarded unchanged to iterate_over_directory.
    """
    batch_job = iterate_over_directory(input_dir, output_dir)
    asyncio.run(batch_job)