File size: 12,336 Bytes
b613c3c
 
 
 
 
91fb4ef
 
 
76eb17f
91fb4ef
 
 
b613c3c
91fb4ef
b613c3c
91fb4ef
b613c3c
 
91fb4ef
 
 
b613c3c
 
 
91fb4ef
76eb17f
91fb4ef
 
76eb17f
91fb4ef
 
 
 
b613c3c
 
 
 
91fb4ef
 
 
 
b613c3c
91fb4ef
 
 
 
 
 
76eb17f
 
91fb4ef
 
 
 
 
b613c3c
91fb4ef
 
 
b613c3c
91fb4ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b613c3c
 
91fb4ef
 
 
 
b613c3c
91fb4ef
 
 
 
 
 
 
 
 
 
 
 
 
 
b613c3c
91fb4ef
 
 
76eb17f
91fb4ef
 
 
 
 
 
 
 
 
 
76eb17f
91fb4ef
b613c3c
 
91fb4ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76eb17f
 
b613c3c
76eb17f
 
 
 
 
 
 
 
91fb4ef
 
 
 
 
 
 
b613c3c
91fb4ef
 
 
 
 
 
 
 
 
 
b613c3c
91fb4ef
 
 
 
76eb17f
91fb4ef
 
b613c3c
91fb4ef
76eb17f
 
 
b613c3c
91fb4ef
 
b613c3c
91fb4ef
 
 
 
76eb17f
 
91fb4ef
76eb17f
91fb4ef
76eb17f
91fb4ef
 
b613c3c
91fb4ef
 
76eb17f
b613c3c
91fb4ef
 
 
 
b613c3c
91fb4ef
 
76eb17f
 
 
 
 
 
 
 
 
 
b613c3c
76eb17f
 
 
 
 
 
 
 
 
 
 
 
b613c3c
76eb17f
 
 
b613c3c
76eb17f
 
 
 
b613c3c
76eb17f
 
91fb4ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b613c3c
 
91fb4ef
 
 
b613c3c
91fb4ef
 
 
 
b613c3c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
"""
File upload handler for Video Model Studio.
Processes uploaded files including videos, images, ZIPs, and WebDataset archives.
"""

import os
import shutil
import zipfile
import tarfile
import tempfile
import gradio as gr
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Any, Union
import logging
import traceback

from vms.config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
from vms.utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption, webdataset_handler

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

class FileUploadHandler:
    """Handles processing of uploaded files.

    Videos are copied into ``VIDEOS_TO_SPLIT_PATH``, images are normalized
    into ``STAGING_PATH``, ZIP archives are unpacked and their members routed
    the same way, and TAR files are handed to
    ``webdataset_handler.process_webdataset_shard``.
    """

    @staticmethod
    def _unique_target(directory: Path, stem: str, suffix: str) -> Path:
        """Return a path inside *directory* that does not exist yet.

        Args:
            directory: Destination directory
            stem: Desired file name without extension
            suffix: Extension including the leading dot (may be empty)

        Returns:
            ``directory/<stem><suffix>`` when that name is free, otherwise the
            first free ``directory/<stem>___<N><suffix>`` for N = 1, 2, ...
        """
        candidate = directory / f"{stem}{suffix}"
        counter = 1
        while candidate.exists():
            candidate = directory / f"{stem}___{counter}{suffix}"
            counter += 1
        return candidate

    @staticmethod
    def _count_phrase(count: int, noun: str) -> str:
        """Return ``"<count> <noun>"``, appending "s" unless count is 1."""
        return f"{count} {noun}{'s' if count != 1 else ''}"

    @staticmethod
    def _write_prefixed_caption(caption_source: Path, media_target: Path) -> None:
        """Write the caption for *media_target*, with the default prefix applied.

        Args:
            caption_source: Existing ``.txt`` file holding the raw caption
            media_target: Stored media file; the caption is written next to it
                under the same name with a ``.txt`` suffix
        """
        caption = caption_source.read_text()
        caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
        media_target.with_suffix('.txt').write_text(caption)

    def process_uploaded_files(self, file_paths: List[str]) -> str:
        """Process every uploaded file (ZIP, TAR, MP4/WebM, or image)

        Args:
            file_paths: File paths to the uploaded files from Gradio

        Returns:
            Status message string. With several files, the per-file status
            messages are joined with newlines.

        Raises:
            gr.Error: On the first file that is unsupported or fails to
                process; files handled before it stay stored.
        """
        if not file_paths:
            logger.warning("No files provided to process_uploaded_files")
            return "No files provided"

        statuses = []
        for raw_path in file_paths:
            file_path = Path(raw_path)
            try:
                original_name = file_path.name
                logger.info(f"Processing uploaded file: {original_name}")

                # Determine file type from name
                file_ext = file_path.suffix.lower()

                if file_ext == '.zip':
                    status = self.process_zip_file(file_path)
                elif file_ext == '.tar':
                    status = self.process_tar_file(file_path)
                elif file_ext == '.mp4' or file_ext == '.webm':
                    status = self.process_mp4_file(file_path, original_name)
                elif is_image_file(file_path):
                    status = self.process_image_file(file_path, original_name)
                else:
                    logger.error(f"Unsupported file type: {file_ext}")
                    raise gr.Error(f"Unsupported file type: {file_ext}")

                # Collect instead of returning so the remaining files are
                # processed as well.
                statuses.append(status)

            except gr.Error:
                # Already logged and user-facing; do not wrap it a second time.
                raise
            except Exception as e:
                logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True)
                raise gr.Error(f"Error processing file: {str(e)}") from e

        return "\n".join(statuses)

    def process_image_file(self, file_path: Path, original_name: str) -> str:
        """Process a single image file

        Args:
            file_path: Path to the image
            original_name: Original filename

        Returns:
            Status message string

        Raises:
            gr.Error: If normalization reports failure or any other error occurs
        """
        try:
            # Unique filename in staging, using the configured extension
            target_path = self._unique_target(
                STAGING_PATH, Path(original_name).stem, f".{NORMALIZE_IMAGES_TO}"
            )

            logger.info(f"Processing image file: {original_name} -> {target_path}")

            # normalize_image signals success through its return value
            if not normalize_image(file_path, target_path):
                logger.error(f"Failed to process image: {original_name}")
                raise gr.Error(f"Failed to process image: {original_name}")

            # Handle caption stored next to the upload, if any
            src_caption_path = file_path.with_suffix('.txt')
            if src_caption_path.exists():
                self._write_prefixed_caption(src_caption_path, target_path)

            logger.info(f"Successfully stored image: {target_path.name}")
            gr.Info(f"Successfully stored image: {target_path.name}")
            return f"Successfully stored image: {target_path.name}"

        except gr.Error:
            # Already logged above; propagate without re-wrapping the message.
            raise
        except Exception as e:
            logger.error(f"Error processing image file: {str(e)}", exc_info=True)
            raise gr.Error(f"Error processing image file: {str(e)}") from e

    def process_zip_file(self, file_path: Path) -> str:
        """Process uploaded ZIP file containing media files or WebDataset tar files

        Members that fail individually are logged and skipped; the remaining
        members are still processed.

        Args:
            file_path: Path to the uploaded ZIP file

        Returns:
            Status message string

        Raises:
            gr.Error: If the archive as a whole cannot be processed
        """
        try:
            video_count = 0
            image_count = 0
            tar_count = 0

            logger.info(f"Processing ZIP file: {file_path}")

            # Extract into a temporary directory that is removed afterwards
            with tempfile.TemporaryDirectory() as temp_dir:
                extract_dir = Path(temp_dir) / "extracted"
                extract_dir.mkdir()
                with zipfile.ZipFile(file_path, 'r') as zip_ref:
                    zip_ref.extractall(extract_dir)

                # Process each extracted file
                for root, _, files in os.walk(extract_dir):
                    for file in files:
                        if file.startswith('._'):  # Skip Mac metadata
                            continue

                        # Distinct name so the ZIP path parameter is not shadowed
                        member_path = Path(root) / file
                        caption_path = member_path.with_suffix('.txt')

                        try:
                            if file.lower().endswith('.tar'):
                                logger.info(f"Processing WebDataset archive from ZIP: {file}")
                                # Process WebDataset shard
                                vid_count, img_count = webdataset_handler.process_webdataset_shard(
                                    member_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
                                )
                                video_count += vid_count
                                image_count += img_count
                                tar_count += 1

                            elif is_video_file(member_path):
                                # Copy video to videos_to_split
                                target_path = self._unique_target(
                                    VIDEOS_TO_SPLIT_PATH, member_path.stem, member_path.suffix
                                )
                                shutil.copy2(member_path, target_path)
                                logger.info(f"Extracted video from ZIP: {file} -> {target_path.name}")
                                video_count += 1

                                # Video captions are copied verbatim (no prefix)
                                if caption_path.exists():
                                    shutil.copy2(caption_path, target_path.with_suffix('.txt'))
                                    logger.info(f"Copied caption file for {file}")

                            elif is_image_file(member_path):
                                # Convert image and save to staging
                                target_path = self._unique_target(
                                    STAGING_PATH, member_path.stem, f".{NORMALIZE_IMAGES_TO}"
                                )
                                if normalize_image(member_path, target_path):
                                    logger.info(f"Extracted image from ZIP: {file} -> {target_path.name}")
                                    image_count += 1

                                    # Only write a caption for an image that was
                                    # actually stored, to avoid orphan .txt files.
                                    if caption_path.exists():
                                        self._write_prefixed_caption(caption_path, target_path)
                                        logger.info(f"Processed caption for {file}")
                                else:
                                    logger.warning(f"Failed to normalize image from ZIP: {file}")

                        except Exception as e:
                            # Best effort: one bad member must not abort the archive
                            logger.error(f"Error processing {member_path.name} from ZIP: {str(e)}", exc_info=True)
                            continue

            # Generate status message
            parts = []
            if tar_count > 0:
                parts.append(self._count_phrase(tar_count, "WebDataset shard"))
            if video_count > 0:
                parts.append(self._count_phrase(video_count, "video"))
            if image_count > 0:
                parts.append(self._count_phrase(image_count, "image"))

            if not parts:
                logger.warning("No supported media files found in ZIP")
                return "No supported media files found in ZIP"

            status = f"Successfully stored {', '.join(parts)}"
            logger.info(status)
            gr.Info(status)
            return status

        except gr.Error:
            raise
        except Exception as e:
            logger.error(f"Error processing ZIP: {str(e)}", exc_info=True)
            raise gr.Error(f"Error processing ZIP: {str(e)}") from e

    def process_tar_file(self, file_path: Path) -> str:
        """Process a WebDataset tar file

        Args:
            file_path: Path to the uploaded tar file

        Returns:
            Status message string

        Raises:
            gr.Error: If the shard cannot be processed
        """
        try:
            logger.info(f"Processing WebDataset TAR file: {file_path}")
            video_count, image_count = webdataset_handler.process_webdataset_shard(
                file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
            )

            # Generate status message
            parts = []
            if video_count > 0:
                parts.append(self._count_phrase(video_count, "video"))
            if image_count > 0:
                parts.append(self._count_phrase(image_count, "image"))

            if not parts:
                logger.warning("No supported media files found in WebDataset")
                return "No supported media files found in WebDataset"

            status = f"Successfully extracted {' and '.join(parts)} from WebDataset"
            logger.info(status)
            gr.Info(status)
            return status

        except gr.Error:
            raise
        except Exception as e:
            logger.error(f"Error processing WebDataset tar file: {str(e)}", exc_info=True)
            raise gr.Error(f"Error processing WebDataset tar file: {str(e)}") from e

    def process_mp4_file(self, file_path: Path, original_name: str) -> str:
        """Process a single video file

        Args:
            file_path: Path to the file
            original_name: Original filename

        Returns:
            Status message string

        Raises:
            gr.Error: If the video cannot be copied
        """
        try:
            # Unique filename that keeps the upload's own extension, so a
            # colliding .webm is not stored under a .mp4 name.
            name = Path(original_name)
            target_path = self._unique_target(VIDEOS_TO_SPLIT_PATH, name.stem, name.suffix)

            logger.info(f"Processing video file: {original_name} -> {target_path}")

            # Copy the file to the target location
            shutil.copy2(file_path, target_path)

            logger.info(f"Successfully stored video: {target_path.name}")
            gr.Info(f"Successfully stored video: {target_path.name}")
            return f"Successfully stored video: {target_path.name}"

        except gr.Error:
            raise
        except Exception as e:
            logger.error(f"Error processing video file: {str(e)}", exc_info=True)
            raise gr.Error(f"Error processing video file: {str(e)}") from e