Wauplin HF staff commited on
Commit
235adef
1 Parent(s): 02afadf

User history v2

Browse files
.gitignore CHANGED
@@ -137,4 +137,8 @@ dmypy.json
137
  .ruff_cache
138
 
139
  # Spell checker config
140
- cspell.json
 
 
 
 
 
137
  .ruff_cache
138
 
139
  # Spell checker config
140
+ cspell.json
141
+
142
+ mock
143
+ user_history
144
+ _history_snapshots
app.py CHANGED
@@ -7,26 +7,31 @@ import tempfile
7
  import gradio as gr
8
  from gradio_client import Client
9
 
 
 
10
 
11
  client = Client("runwayml/stable-diffusion-v1-5")
12
 
13
 
14
- def generate(prompt: str) -> tuple[str, list[str]]:
15
- negative_prompt = ""
16
- guidance_scale = 9.0
17
  out_dir = client.predict(prompt, fn_index=1)
18
 
19
- config = {
20
  "prompt": prompt,
21
- "negative_prompt": negative_prompt,
22
- "guidance_scale": guidance_scale,
23
  }
24
- with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as config_file:
25
- json.dump(config, config_file)
26
 
27
  with (pathlib.Path(out_dir) / "captions.json").open() as f:
28
  paths = list(json.load(f).keys())
29
- return paths
 
 
 
 
 
30
 
31
 
32
  with gr.Blocks(css="style.css") as demo:
@@ -39,15 +44,13 @@ with gr.Blocks(css="style.css") as demo:
39
  height="600px",
40
  object_fit="scale-down",
41
  )
 
42
 
43
- prompt.submit(
44
- fn=generate,
45
- inputs=prompt,
46
- outputs=gallery,
47
- )
48
-
49
  with gr.Tab("Past generations"):
50
- gr.Markdown("building...")
51
 
52
  if __name__ == "__main__":
53
- demo.launch()
 
7
  import gradio as gr
8
  from gradio_client import Client
9
 
10
+ import user_history
11
+
12
 
13
  client = Client("runwayml/stable-diffusion-v1-5")
14
 
15
 
16
+ def generate(prompt: str, profile: gr.OAuthProfile | None) -> tuple[str, list[str]]:
 
 
17
  out_dir = client.predict(prompt, fn_index=1)
18
 
19
+ metadata = {
20
  "prompt": prompt,
21
+ "negative_prompt": "",
22
+ "guidance_scale": 0.9,
23
  }
24
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as metadata_file:
25
+ json.dump(metadata, metadata_file)
26
 
27
  with (pathlib.Path(out_dir) / "captions.json").open() as f:
28
  paths = list(json.load(f).keys())
29
+
30
+ # Saving user history
31
+ for path in paths:
32
+ user_history.save_image(label=prompt, image=path, profile=profile, metadata=metadata)
33
+
34
+ return paths # type: ignore
35
 
36
 
37
  with gr.Blocks(css="style.css") as demo:
 
44
  height="600px",
45
  object_fit="scale-down",
46
  )
47
+ prompt.submit(fn=generate, inputs=prompt, outputs=gallery)
48
 
49
+ with gr.Blocks() as demo_with_history:
50
+ with gr.Tab("App"):
51
+ demo.render()
 
 
 
52
  with gr.Tab("Past generations"):
53
+ user_history.render()
54
 
55
  if __name__ == "__main__":
56
+ demo_with_history.queue().launch()
assets/icon_delete.png ADDED
assets/icon_download.png ADDED
assets/icon_refresh.png ADDED
gallery_history.py DELETED
@@ -1,122 +0,0 @@
1
- """
2
- How to use:
3
- 1. Create a Space with a Persistent Storage attached. Filesystem will be available under `/data`.
4
- 2. Add `hf_oauth: true` to the Space metadata (README.md). Make sure to have Gradio>=3.41.0 configured.
5
- 3. Add `HISTORY_FOLDER` as a Space variable (example. `"/data/history"`).
6
- 4. Add `filelock` as dependency in `requirements.txt`.
7
- 5. Add history gallery to your Gradio app:
8
- a. Add imports: `from gallery_history import fetch_gallery_history, show_gallery_history`
9
- a. Add `history = show_gallery_history()` within `gr.Blocks` context.
10
- b. Add `.then(fn=fetch_gallery_history, inputs=[prompt, result], outputs=history)` on the generate event.
11
- """
12
- import json
13
- import os
14
- import shutil
15
- from pathlib import Path
16
- from typing import Dict, List, Optional, Tuple
17
- from uuid import uuid4
18
-
19
- import gradio as gr
20
- from filelock import FileLock
21
-
22
-
23
- _folder = os.environ.get("HISTORY_FOLDER")
24
- if _folder is None:
25
- print(
26
- "'HISTORY_FOLDER' environment variable not set. User history will be saved "
27
- "locally and will be lost when the Space instance is restarted."
28
- )
29
- _folder = Path(__file__).parent / "history"
30
- HISTORY_FOLDER_PATH = Path(_folder)
31
-
32
- IMAGES_FOLDER_PATH = HISTORY_FOLDER_PATH / "images"
33
- IMAGES_FOLDER_PATH.mkdir(parents=True, exist_ok=True)
34
-
35
-
36
- def show_gallery_history():
37
- gr.Markdown(
38
- "## Your past generations\n\n(Log in to keep a gallery of your previous generations."
39
- " Your history will be saved and available on your next visit.)"
40
- )
41
- with gr.Column():
42
- with gr.Row():
43
- gr.LoginButton(min_width=250)
44
- gr.LogoutButton(min_width=250)
45
- gallery = gr.Gallery(
46
- label="Past images",
47
- show_label=True,
48
- elem_id="gallery",
49
- object_fit="contain",
50
- columns=3,
51
- height=300,
52
- preview=False,
53
- show_share_button=False,
54
- show_download_button=False,
55
- )
56
- gr.Markdown("Make sure to save your images from time to time, this gallery may be deleted in the future.")
57
- gallery.attach_load_event(fetch_gallery_history, every=None)
58
- return gallery
59
-
60
-
61
- def fetch_gallery_history(
62
- prompt: Optional[str] = None,
63
- result: Optional[Dict] = None,
64
- user: Optional[gr.OAuthProfile] = None,
65
- ):
66
- if user is None:
67
- return []
68
- try:
69
- if prompt is not None and result is not None: # None values means no new images
70
- return _update_user_history(user["preferred_username"], [(item["name"], prompt) for item in result])
71
- else:
72
- return _read_user_history(user["preferred_username"])
73
- except Exception as e:
74
- raise gr.Error(f"Error while fetching history: {e}") from e
75
-
76
-
77
- ####################
78
- # Internal helpers #
79
- ####################
80
-
81
-
82
- def _read_user_history(username: str) -> List[Tuple[str, str]]:
83
- """Return saved history for that user."""
84
- with _user_lock(username):
85
- path = _user_history_path(username)
86
- if path.exists():
87
- return json.loads(path.read_text())
88
- return [] # No history yet
89
-
90
-
91
- def _update_user_history(username: str, new_images: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
92
- """Update history for that user and return it."""
93
- with _user_lock(username):
94
- # Read existing
95
- path = _user_history_path(username)
96
- if path.exists():
97
- images = json.loads(path.read_text())
98
- else:
99
- images = [] # No history yet
100
-
101
- # Copy images to persistent folder
102
- images = [(_copy_image(src_path), prompt) for src_path, prompt in new_images] + images
103
-
104
- # Save and return
105
- path.write_text(json.dumps(images))
106
- return images
107
-
108
-
109
- def _user_history_path(username: str) -> Path:
110
- return HISTORY_FOLDER_PATH / f"{username}.json"
111
-
112
-
113
- def _user_lock(username: str) -> FileLock:
114
- """Ensure history is not corrupted if concurrent calls."""
115
- return FileLock(f"{_user_history_path(username)}.lock")
116
-
117
-
118
- def _copy_image(src: str) -> str:
119
- """Copy image to the persistent storage."""
120
- dst = IMAGES_FOLDER_PATH / f"{uuid4().hex}_{Path(src).name}" # keep file ext
121
- shutil.copyfile(src, dst)
122
- return str(dst)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pyproject.toml CHANGED
@@ -1,6 +1,6 @@
1
  [tool.black]
2
  line-length = 119
3
- target_version = ['py37', 'py38', 'py39', 'py310']
4
  preview = true
5
 
6
  [tool.mypy]
 
1
  [tool.black]
2
  line-length = 119
3
+ target_version = ['py38', 'py39', 'py310']
4
  preview = true
5
 
6
  [tool.mypy]
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
- gradio>=3.44
2
  huggingface_hub>=0.17
 
3
 
4
  # dev-deps
5
  ruff
 
1
+ gradio[oauth]>=3.44
2
  huggingface_hub>=0.17
3
+ Pillow
4
 
5
  # dev-deps
6
  ruff
user_history.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import shutil
4
+ import warnings
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ from typing import Dict, List, Tuple
8
+ from uuid import uuid4
9
+
10
+ import gradio as gr
11
+ import numpy as np
12
+ from filelock import FileLock
13
+ from PIL.Image import Image
14
+
15
+
16
+ def setup(
17
+ folder_path: str | Path | None = None,
18
+ delete_button: bool = True,
19
+ export_button: bool = True,
20
+ ) -> None:
21
+ user_history = _UserHistory()
22
+ user_history.folder_path = _resolve_folder_path(folder_path)
23
+ user_history.delete_button = delete_button
24
+ user_history.export_button = export_button
25
+ user_history.initialized = True
26
+
27
+
28
+ def render() -> None:
29
+ user_history = _UserHistory()
30
+
31
+ # initialize with default config
32
+ if not user_history.initialized:
33
+ print("Initializing user history with default config. Use `user_history.setup(...)` to customize.")
34
+ setup()
35
+
36
+ # deactivate if no persistent storage
37
+ if user_history.folder_path is None:
38
+ gr.Markdown(
39
+ "User history is deactivated as no Persistent Storage volume has been found. Please contact the Space"
40
+ " owner to either assign a [Persistent Storage](https://huggingface.co/docs/hub/spaces-storage) or set"
41
+ " `folder_path` to a temporary folder."
42
+ )
43
+ return
44
+
45
+ # Render user history tab
46
+ gr.Markdown(
47
+ "## Your past generations\n\n(Log in to keep a gallery of your previous generations."
48
+ " Your history will be saved and available on your next visit.)"
49
+ )
50
+ with gr.Row():
51
+ gr.LoginButton(min_width=250)
52
+ gr.LogoutButton(min_width=250)
53
+ refresh_button = gr.Button("Refresh", icon="./assets/icon_refresh.png")
54
+ export_button = gr.Button("Export", icon="./assets/icon_download.png")
55
+ delete_button = gr.Button("Delete history", icon="./assets/icon_delete.png")
56
+
57
+ # "Export zip" row (hidden by default)
58
+ with gr.Row():
59
+ export_file = gr.File(file_count="single", file_types=[".zip"], label="Exported history", visible=False)
60
+
61
+ # "Config deletion" row (hidden by default)
62
+ with gr.Row():
63
+ confirm_button = gr.Button("Confirm delete all history", variant="stop", visible=False)
64
+ cancel_button = gr.Button("Cancel", visible=False)
65
+
66
+ # Gallery
67
+ gallery = gr.Gallery(
68
+ label="Past images",
69
+ show_label=True,
70
+ elem_id="gallery",
71
+ object_fit="contain",
72
+ columns=5,
73
+ height=600,
74
+ preview=False,
75
+ show_share_button=False,
76
+ show_download_button=False,
77
+ )
78
+ gr.Markdown("Make sure to save your images from time to time, this gallery may be deleted in the future.")
79
+ gallery.attach_load_event(_fetch_user_history, every=None)
80
+
81
+ # Interactions
82
+ refresh_button.click(fn=_fetch_user_history, inputs=[], outputs=[gallery], queue=False)
83
+ export_button.click(fn=_export_user_history, inputs=[], outputs=[export_file], queue=False)
84
+
85
+ # Taken from https://github.com/gradio-app/gradio/issues/3324#issuecomment-1446382045
86
+ delete_button.click(
87
+ lambda: [gr.update(visible=True), gr.update(visible=True)],
88
+ outputs=[confirm_button, cancel_button],
89
+ queue=False,
90
+ )
91
+ cancel_button.click(
92
+ lambda: [gr.update(visible=False), gr.update(visible=False)],
93
+ outputs=[confirm_button, cancel_button],
94
+ queue=False,
95
+ )
96
+ confirm_button.click(_delete_user_history).then(
97
+ lambda: [gr.update(visible=False), gr.update(visible=False)],
98
+ outputs=[confirm_button, cancel_button],
99
+ queue=False,
100
+ )
101
+
102
+
103
+ def save_image(
104
+ profile: gr.OAuthProfile | None,
105
+ image: Image | np.ndarray | str | Path,
106
+ label: str | None = None,
107
+ metadata: Dict | None = None,
108
+ ):
109
+ # Ignore images from logged out users
110
+ if profile is None:
111
+ return
112
+ username = profile["preferred_username"]
113
+
114
+ # Ignore images if user history not used
115
+ user_history = _UserHistory()
116
+ if not user_history.initialized:
117
+ warnings.warn(
118
+ "User history is not set in Gradio demo. Saving image is ignored. You must use `user_history.render(...)`"
119
+ " first."
120
+ )
121
+ return
122
+
123
+ # Copy image to storage
124
+ image_path = _copy_image(image, dst_folder=user_history._user_images_path(username))
125
+
126
+ # Save new image + metadata
127
+ if metadata is None:
128
+ metadata = {}
129
+ if "datetime" not in metadata:
130
+ metadata["datetime"] = str(datetime.now())
131
+ data = {"path": str(image_path), "label": label, "metadata": metadata}
132
+ with user_history._user_lock(username):
133
+ with user_history._user_jsonl_path(username).open("a") as f:
134
+ f.write(json.dumps(data) + "\n")
135
+
136
+
137
+ #############
138
+ # Internals #
139
+ #############
140
+
141
+
142
+ class _UserHistory(object):
143
+ _instance = None
144
+ initialized: bool = False
145
+
146
+ folder_path: Path | None
147
+ delete_button: bool
148
+ export_button: bool
149
+
150
+ def __new__(cls):
151
+ # Using singleton pattern => we don't want to expose an object (more complex to use) but still want to keep
152
+ # state between `render` and `save_image` calls.
153
+ if cls._instance is None:
154
+ cls._instance = super(_UserHistory, cls).__new__(cls)
155
+ return cls._instance
156
+
157
+ def _user_path(self, username: str) -> Path:
158
+ if self.folder_path is None:
159
+ raise Exception("User history is deactivated.")
160
+ path = self.folder_path / username
161
+ path.mkdir(parents=True, exist_ok=True)
162
+ return path
163
+
164
+ def _user_lock(self, username: str) -> FileLock:
165
+ """Ensure history is not corrupted if concurrent calls."""
166
+ if self.folder_path is None:
167
+ raise Exception("User history is deactivated.")
168
+ return FileLock(self.folder_path / f"{username}.lock") # lock outside of folder => better when exporting ZIP
169
+
170
+ def _user_jsonl_path(self, username: str) -> Path:
171
+ return self._user_path(username) / "history.jsonl"
172
+
173
+ def _user_images_path(self, username: str) -> Path:
174
+ path = self._user_path(username) / "images"
175
+ path.mkdir(parents=True, exist_ok=True)
176
+ return path
177
+
178
+
179
+ def _fetch_user_history(profile: gr.OAuthProfile | None) -> List[Tuple[str, str]]:
180
+ """Return saved history for that user, if it exists."""
181
+ # Cannot load history for logged out users
182
+ if profile is None:
183
+ return []
184
+ username = profile["preferred_username"]
185
+
186
+ user_history = _UserHistory()
187
+ if not user_history.initialized:
188
+ warnings.warn("User history is not set in Gradio demo. You must use `user_history.render(...)` first.")
189
+ return []
190
+
191
+ with user_history._user_lock(username):
192
+ # No file => no history saved yet
193
+ jsonl_path = user_history._user_jsonl_path(username)
194
+ if not jsonl_path.is_file():
195
+ return []
196
+
197
+ # Read history
198
+ images = []
199
+ for line in jsonl_path.read_text().splitlines():
200
+ data = json.loads(line)
201
+ images.append((data["path"], data["label"] or ""))
202
+ return list(reversed(images))
203
+
204
+
205
+ def _export_user_history(profile: gr.OAuthProfile | None) -> Dict | None:
206
+ """Zip all history for that user, if it exists and return it as a downloadable file."""
207
+ # Cannot load history for logged out users
208
+ if profile is None:
209
+ return None
210
+ username = profile["preferred_username"]
211
+
212
+ user_history = _UserHistory()
213
+ if not user_history.initialized:
214
+ warnings.warn("User history is not set in Gradio demo. You must use `user_history.render(...)` first.")
215
+ return None
216
+
217
+ # Zip history
218
+ with user_history._user_lock(username):
219
+ path = shutil.make_archive(
220
+ str(_archives_path() / f"history_{username}"), "zip", user_history._user_path(username)
221
+ )
222
+
223
+ return gr.update(visible=True, value=path)
224
+
225
+
226
+ def _delete_user_history(profile: gr.OAuthProfile | None) -> None:
227
+ """Delete all history for that user."""
228
+ # Cannot load history for logged out users
229
+ if profile is None:
230
+ return
231
+ username = profile["preferred_username"]
232
+
233
+ user_history = _UserHistory()
234
+ if not user_history.initialized:
235
+ warnings.warn("User history is not set in Gradio demo. You must use `user_history.render(...)` first.")
236
+ return
237
+
238
+ with user_history._user_lock(username):
239
+ shutil.rmtree(user_history._user_path(username))
240
+
241
+
242
+ ####################
243
+ # Internal helpers #
244
+ ####################
245
+
246
+
247
+ def _copy_image(image: Image | np.ndarray | str | Path, dst_folder: Path) -> Path:
248
+ """Copy image to the images folder."""
249
+ # Already a path => copy it
250
+ if isinstance(image, str):
251
+ image = Path(image)
252
+ if isinstance(image, Path):
253
+ dst = dst_folder / f"{uuid4().hex}_{Path(image).name}" # keep file ext
254
+ shutil.copyfile(image, dst)
255
+ return dst
256
+
257
+ # Still a Python object => serialize it
258
+ if isinstance(image, np.ndarray):
259
+ image = Image.fromarray(image)
260
+ if isinstance(image, Image):
261
+ dst = dst_folder / f"{uuid4().hex}.png"
262
+ image.save(dst)
263
+ return dst
264
+
265
+ raise ValueError(f"Unsupported image type: {type(image)}")
266
+
267
+
268
+ def _resolve_folder_path(folder_path: str | Path | None) -> Path | None:
269
+ if folder_path is not None:
270
+ return Path(folder_path).expanduser().resolve()
271
+
272
+ if os.getenv("SYSTEM") == "spaces":
273
+ if os.path.exists("/data"): # Persistent storage is enabled!
274
+ return Path("/data") / "user_history"
275
+ else:
276
+ return None # No persistent storage => no user history
277
+
278
+ # Not in a Space => local folder
279
+ return Path(__file__).parent / "user_history"
280
+
281
+
282
+ def _archives_path() -> Path:
283
+ # Doesn't have to be on persistent storage as it's only used for download
284
+ path = Path(__file__).parent / "_history_snapshots"
285
+ path.mkdir(parents=True, exist_ok=True)
286
+ return path