import json
import tempfile
import zipfile
from datetime import datetime
from pathlib import Path
from uuid import uuid4

import gradio as gr
import numpy as np
from PIL import Image

from huggingface_hub import CommitScheduler, InferenceClient


IMAGE_DATASET_DIR = Path("image_dataset_1M") / f"train-{uuid4()}"

IMAGE_DATASET_DIR.mkdir(parents=True, exist_ok=True)
IMAGE_JSONL_PATH = IMAGE_DATASET_DIR / "metadata.jsonl"


class ZipScheduler(CommitScheduler):
    """
    Example of a custom CommitScheduler with overwritten `push_to_hub` to zip images before pushing them to the Hub.

    Workflow:
    1. Read metadata + list PNG files.
    2. Zip png files in a single archive.
    3. Create commit (metadata + archive).
    4. Delete local png files to avoid re-uploading them later.

    Only step 1 requires to activate the lock. Once the metadata is read, the lock is released and the rest of the
    process can be done without blocking the Gradio app.
    """

    def push_to_hub(self):
        # 1. Read metadata + list PNG files
        with self.lock:
            png_files = list(self.folder_path.glob("*.png"))
            if len(png_files) == 0:
                return None  # return early if nothing to commit

            # Read and delete metadata file
            metadata = IMAGE_JSONL_PATH.read_text()
            try:
                IMAGE_JSONL_PATH.unlink()
            except Exception:
                pass

        with tempfile.TemporaryDirectory() as tmpdir:
            # 2. Zip png files + metadata in a single archive
            archive_path = Path(tmpdir) / "train.zip"
            with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as zip_file:
                # PNG files
                for png_file in png_files:
                    zip_file.write(filename=png_file, arcname=png_file.name)

                # Metadata
                tmp_metadata = Path(tmpdir) / "metadata.jsonl"
                tmp_metadata.write_text(metadata)
                zip_file.write(filename=tmp_metadata, arcname="metadata.jsonl")

            # 3. Create commit
            self.api.upload_file(
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                revision=self.revision,
                path_in_repo=f"train-{uuid4()}.zip",
                path_or_fileobj=archive_path,
            )

        # 4. Delete local png files to avoid re-uploading them later
        for png_file in png_files:
            try:
                png_file.unlink()
            except Exception:
                pass


scheduler = ZipScheduler(
    repo_id="example-space-to-dataset-image-zip",
    repo_type="dataset",
    folder_path=IMAGE_DATASET_DIR,
)
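# Note: CommitScheduler also accepts an `every` parameter (minutes between background
# pushes); it is left at its default here (5 minutes at the time of writing).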

client = InferenceClient()
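# No model is pinned here: `InferenceClient()` resolves a recommended text-to-image
# model for the task. A specific model or token could be passed via
# `InferenceClient(model=..., token=...)` if needed (an optional tweak, not part of
# the original example).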


def generate_image(prompt: str) -> Image.Image:
    return client.text_to_image(prompt)


def save_image(prompt: str, image_array: np.ndarray) -> None:
    print("Saving: " + prompt)
    image_path = IMAGE_DATASET_DIR / f"{uuid4()}.png"

    with scheduler.lock:
        Image.fromarray(image_array).save(image_path)
        with IMAGE_JSONL_PATH.open("a") as f:
            json.dump({"prompt": prompt, "file_name": image_path.name, "datetime": datetime.now().isoformat()}, f)
            f.write("\n")


def get_demo():
    with gr.Row():
        prompt_value = gr.Textbox(label="Prompt")
        image_value = gr.Image(label="Generated image")
    text_to_image_btn = gr.Button("Generate")
    text_to_image_btn.click(fn=generate_image, inputs=prompt_value, outputs=image_value).success(
        fn=save_image,
        inputs=[prompt_value, image_value],
        outputs=None,
    )
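

# Minimal standalone entry point (an assumption for illustration: in the original
# Space, `get_demo()` is imported and mounted inside a shared `gr.Blocks` by a
# separate app file).
if __name__ == "__main__":
    with gr.Blocks() as demo:
        get_demo()  # components created inside the Blocks context are attached to it
    demo.launch()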