File size: 17,399 Bytes
62bb9d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
from __future__ import annotations

import json
import os
import random
from io import BytesIO
from typing import Type

import av
import numpy as np
import torch
try:
    import torchaudio
    TORCH_AUDIO_AVAILABLE = True
except:
    TORCH_AUDIO_AVAILABLE = False
from PIL import Image as PILImage
from PIL.PngImagePlugin import PngInfo

import folder_paths

# used for image preview
from comfy.cli_args import args
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput


class SavedResult(dict):
    def __init__(self, filename: str, subfolder: str, type: FolderType):
        super().__init__(filename=filename, subfolder=subfolder,type=type.value)

    @property
    def filename(self) -> str:
        return self["filename"]

    @property
    def subfolder(self) -> str:
        return self["subfolder"]

    @property
    def type(self) -> FolderType:
        return FolderType(self["type"])


class SavedImages(_UIOutput):
    """A UI output class to represent one or more saved images, potentially animated."""
    def __init__(self, results: list[SavedResult], is_animated: bool = False):
        super().__init__()
        self.results = results
        self.is_animated = is_animated

    def as_dict(self) -> dict:
        data = {"images": self.results}
        if self.is_animated:
            data["animated"] = (True,)
        return data


class SavedAudios(_UIOutput):
    """UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
    def __init__(self, results: list[SavedResult]):
        super().__init__()
        self.results = results

    def as_dict(self) -> dict:
        return {"audio": self.results}


def _get_directory_by_folder_type(folder_type: FolderType) -> str:
    if folder_type == FolderType.input:
        return folder_paths.get_input_directory()
    if folder_type == FolderType.output:
        return folder_paths.get_output_directory()
    return folder_paths.get_temp_directory()


class ImageSaveHelper:
    """A helper class with static methods to handle image saving and metadata."""

    @staticmethod
    def _convert_tensor_to_pil(image_tensor: torch.Tensor) -> PILImage.Image:
        """Converts a single torch tensor to a PIL Image."""
        return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))

    @staticmethod
    def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
        """Creates a PngInfo object with prompt and extra_pnginfo."""
        if args.disable_metadata or cls is None or not cls.hidden:
            return None
        metadata = PngInfo()
        if cls.hidden.prompt:
            metadata.add_text("prompt", json.dumps(cls.hidden.prompt))
        if cls.hidden.extra_pnginfo:
            for x in cls.hidden.extra_pnginfo:
                metadata.add_text(x, json.dumps(cls.hidden.extra_pnginfo[x]))
        return metadata

    @staticmethod
    def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
        """Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
        if args.disable_metadata or cls is None or not cls.hidden:
            return None
        metadata = PngInfo()
        if cls.hidden.prompt:
            metadata.add(
                b"comf",
                "prompt".encode("latin-1", "strict")
                + b"\0"
                + json.dumps(cls.hidden.prompt).encode("latin-1", "strict"),
                after_idat=True,
            )
        if cls.hidden.extra_pnginfo:
            for x in cls.hidden.extra_pnginfo:
                metadata.add(
                    b"comf",
                    x.encode("latin-1", "strict")
                    + b"\0"
                    + json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"),
                    after_idat=True,
                )
        return metadata

    @staticmethod
    def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
        """Creates EXIF metadata bytes for WebP images."""
        exif_data = pil_image.getexif()
        if args.disable_metadata or cls is None or cls.hidden is None:
            return exif_data
        if cls.hidden.prompt is not None:
            exif_data[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt))  # EXIF 0x0110 = Model
        if cls.hidden.extra_pnginfo is not None:
            inital_exif_tag = 0x010F  # EXIF 0x010f = Make
            for key, value in cls.hidden.extra_pnginfo.items():
                exif_data[inital_exif_tag] = "{}:{}".format(key, json.dumps(value))
                inital_exif_tag -= 1
        return exif_data

    @staticmethod
    def save_images(
        images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
    ) -> list[SavedResult]:
        """Saves a batch of images as individual PNG files."""
        full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
            filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
        )
        results = []
        metadata = ImageSaveHelper._create_png_metadata(cls)
        for batch_number, image_tensor in enumerate(images):
            img = ImageSaveHelper._convert_tensor_to_pil(image_tensor)
            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
            file = f"{filename_with_batch_num}_{counter:05}_.png"
            img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level)
            results.append(SavedResult(file, subfolder, folder_type))
            counter += 1
        return results

    @staticmethod
    def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
        """Saves a batch of images and returns a UI object for the node output."""
        return SavedImages(
                ImageSaveHelper.save_images(
                images,
                filename_prefix=filename_prefix,
                folder_type=FolderType.output,
                cls=cls,
                compress_level=compress_level,
            )
        )

    @staticmethod
    def save_animated_png(
        images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
    ) -> SavedResult:
        """Saves a batch of images as a single animated PNG."""
        full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
            filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
        )
        pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
        metadata = ImageSaveHelper._create_animated_png_metadata(cls)
        file = f"{filename}_{counter:05}_.png"
        save_path = os.path.join(full_output_folder, file)
        pil_images[0].save(
            save_path,
            pnginfo=metadata,
            compress_level=compress_level,
            save_all=True,
            duration=int(1000.0 / fps),
            append_images=pil_images[1:],
        )
        return SavedResult(file, subfolder, folder_type)

    @staticmethod
    def get_save_animated_png_ui(
        images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
    ) -> SavedImages:
        """Saves an animated PNG and returns a UI object for the node output."""
        result = ImageSaveHelper.save_animated_png(
            images,
            filename_prefix=filename_prefix,
            folder_type=FolderType.output,
            cls=cls,
            fps=fps,
            compress_level=compress_level,
        )
        return SavedImages([result], is_animated=len(images) > 1)

    @staticmethod
    def save_animated_webp(
        images,
        filename_prefix: str,
        folder_type: FolderType,
        cls: Type[ComfyNode] | None,
        fps: float,
        lossless: bool,
        quality: int,
        method: int,
    ) -> SavedResult:
        """Saves a batch of images as a single animated WebP."""
        full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
            filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
        )
        pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
        pil_exif = ImageSaveHelper._create_webp_metadata(pil_images[0], cls)
        file = f"{filename}_{counter:05}_.webp"
        pil_images[0].save(
            os.path.join(full_output_folder, file),
            save_all=True,
            duration=int(1000.0 / fps),
            append_images=pil_images[1:],
            exif=pil_exif,
            lossless=lossless,
            quality=quality,
            method=method,
        )
        return SavedResult(file, subfolder, folder_type)

    @staticmethod
    def get_save_animated_webp_ui(
        images,
        filename_prefix: str,
        cls: Type[ComfyNode] | None,
        fps: float,
        lossless: bool,
        quality: int,
        method: int,
    ) -> SavedImages:
        """Saves an animated WebP and returns a UI object for the node output."""
        result = ImageSaveHelper.save_animated_webp(
            images,
            filename_prefix=filename_prefix,
            folder_type=FolderType.output,
            cls=cls,
            fps=fps,
            lossless=lossless,
            quality=quality,
            method=method,
        )
        return SavedImages([result], is_animated=len(images) > 1)


class AudioSaveHelper:
    """A helper class with static methods to handle audio saving and metadata."""
    _OPUS_RATES = [8000, 12000, 16000, 24000, 48000]

    @staticmethod
    def save_audio(
        audio: dict,
        filename_prefix: str,
        folder_type: FolderType,
        cls: Type[ComfyNode] | None,
        format: str = "flac",
        quality: str = "128k",
    ) -> list[SavedResult]:
        full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
            filename_prefix, _get_directory_by_folder_type(folder_type)
        )

        metadata = {}
        if not args.disable_metadata and cls is not None:
            if cls.hidden.prompt is not None:
                metadata["prompt"] = json.dumps(cls.hidden.prompt)
            if cls.hidden.extra_pnginfo is not None:
                for x in cls.hidden.extra_pnginfo:
                    metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])

        results = []
        for batch_number, waveform in enumerate(audio["waveform"].cpu()):
            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
            file = f"{filename_with_batch_num}_{counter:05}_.{format}"
            output_path = os.path.join(full_output_folder, file)

            # Use original sample rate initially
            sample_rate = audio["sample_rate"]

            # Handle Opus sample rate requirements
            if format == "opus":
                if sample_rate > 48000:
                    sample_rate = 48000
                elif sample_rate not in AudioSaveHelper._OPUS_RATES:
                    # Find the next highest supported rate
                    for rate in sorted(AudioSaveHelper._OPUS_RATES):
                        if rate > sample_rate:
                            sample_rate = rate
                            break
                    if sample_rate not in AudioSaveHelper._OPUS_RATES:  # Fallback if still not supported
                        sample_rate = 48000

                # Resample if necessary
                if sample_rate != audio["sample_rate"]:
                    if not TORCH_AUDIO_AVAILABLE:
                        raise Exception("torchaudio is not available; cannot resample audio.")
                    waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)

            # Create output with specified format
            output_buffer = BytesIO()
            output_container = av.open(output_buffer, mode="w", format=format)

            # Set metadata on the container
            for key, value in metadata.items():
                output_container.metadata[key] = value

            # Set up the output stream with appropriate properties
            if format == "opus":
                out_stream = output_container.add_stream("libopus", rate=sample_rate)
                if quality == "64k":
                    out_stream.bit_rate = 64000
                elif quality == "96k":
                    out_stream.bit_rate = 96000
                elif quality == "128k":
                    out_stream.bit_rate = 128000
                elif quality == "192k":
                    out_stream.bit_rate = 192000
                elif quality == "320k":
                    out_stream.bit_rate = 320000
            elif format == "mp3":
                out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
                if quality == "V0":
                    # TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
                    out_stream.codec_context.qscale = 1
                elif quality == "128k":
                    out_stream.bit_rate = 128000
                elif quality == "320k":
                    out_stream.bit_rate = 320000
            else:  # format == "flac":
                out_stream = output_container.add_stream("flac", rate=sample_rate)

            frame = av.AudioFrame.from_ndarray(
                waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
                format="flt",
                layout="mono" if waveform.shape[0] == 1 else "stereo",
            )
            frame.sample_rate = sample_rate
            frame.pts = 0
            output_container.mux(out_stream.encode(frame))

            # Flush encoder
            output_container.mux(out_stream.encode(None))

            # Close containers
            output_container.close()

            # Write the output to file
            output_buffer.seek(0)
            with open(output_path, "wb") as f:
                f.write(output_buffer.getbuffer())

            results.append(SavedResult(file, subfolder, folder_type))
            counter += 1

        return results

    @staticmethod
    def get_save_audio_ui(
        audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
    ) -> SavedAudios:
        """Save and instantly wrap for UI."""
        return SavedAudios(
            AudioSaveHelper.save_audio(
                audio,
                filename_prefix=filename_prefix,
                folder_type=FolderType.output,
                cls=cls,
                format=format,
                quality=quality,
            )
        )


class PreviewImage(_UIOutput):
    def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
        self.values = ImageSaveHelper.save_images(
            image,
            filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
            folder_type=FolderType.temp,
            cls=cls,
            compress_level=1,
        )
        self.animated = animated

    def as_dict(self):
        return {
            "images": self.values,
            "animated": (self.animated,)
        }


class PreviewMask(PreviewImage):
    def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs):
        preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
        super().__init__(preview, animated, cls, **kwargs)


class PreviewAudio(_UIOutput):
    def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
        self.values = AudioSaveHelper.save_audio(
            audio,
            filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
            folder_type=FolderType.temp,
            cls=cls,
            format="flac",
            quality="128k",
        )

    def as_dict(self) -> dict:
        return {"audio": self.values}


class PreviewVideo(_UIOutput):
    def __init__(self, values: list[SavedResult | dict], **kwargs):
        self.values = values

    def as_dict(self):
        return {"images": self.values, "animated": (True,)}


class PreviewUI3D(_UIOutput):
    def __init__(self, model_file, camera_info, **kwargs):
        self.model_file = model_file
        self.camera_info = camera_info

    def as_dict(self):
        return {"result": [self.model_file, self.camera_info]}


class PreviewText(_UIOutput):
    def __init__(self, value: str, **kwargs):
        self.value = value

    def as_dict(self):
        return {"text": (self.value,)}


class _UI:
    SavedResult = SavedResult
    SavedImages = SavedImages
    SavedAudios = SavedAudios
    ImageSaveHelper = ImageSaveHelper
    AudioSaveHelper = AudioSaveHelper
    PreviewImage = PreviewImage
    PreviewMask = PreviewMask
    PreviewAudio = PreviewAudio
    PreviewVideo = PreviewVideo
    PreviewUI3D = PreviewUI3D
    PreviewText = PreviewText