Wauplin HF staff commited on
Commit
c08f6e3
1 Parent(s): 70cd9d2

Create gallery_history.py

Browse files
Files changed (1) hide show
  1. gallery_history.py +128 -0
gallery_history.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ How to use:
3
+ 1. Create a Space with a Persistent Storage attached. Filesystem will be available under `/data`.
4
+ 2. Add `hf_oauth: true` to the Space metadata (README.md). Make sure to have Gradio>=3.41.0 configured.
5
+ 3. Add `HISTORY_FOLDER` as a Space variable (example. `"/data/history"`).
6
+ 4. Add `filelock` as dependency in `requirements.txt`.
7
+ 5. Add history gallery to your Gradio app:
8
+ a. Add imports: `from gallery_history import fetch_gallery_history, show_gallery_history`
9
+ a. Add `history = show_gallery_history()` within `gr.Blocks` context.
10
+ b. Add `.then(fn=fetch_gallery_history, inputs=[prompt, result], outputs=history)` on the generate event.
11
+ """
12
+ import json
13
+ import os
14
+ import numpy as np
15
+ import shutil
16
+ from pathlib import Path
17
+ from PIL import Image
18
+ from typing import Dict, List, Optional, Tuple
19
+ from uuid import uuid4
20
+
21
+ import gradio as gr
22
+ from filelock import FileLock
23
+
24
+ _folder = os.environ.get("HISTORY_FOLDER")
25
+ if _folder is None:
26
+ print(
27
+ "'HISTORY_FOLDER' environment variable not set. User history will be saved "
28
+ "locally and will be lost when the Space instance is restarted."
29
+ )
30
+ _folder = Path(__file__).parent / "history"
31
+ HISTORY_FOLDER_PATH = Path(_folder)
32
+
33
+ IMAGES_FOLDER_PATH = HISTORY_FOLDER_PATH / "images"
34
+ IMAGES_FOLDER_PATH.mkdir(parents=True, exist_ok=True)
35
+
36
+
37
+ def show_gallery_history():
38
+ gr.Markdown(
39
+ "## Your past generations\n\n(Log in to keep a gallery of your previous generations."
40
+ " Your history will be saved and available on your next visit.)"
41
+ )
42
+ with gr.Column():
43
+ with gr.Row():
44
+ gr.LoginButton(min_width=250)
45
+ gr.LogoutButton(min_width=250)
46
+ gallery = gr.Gallery(
47
+ label="Past images",
48
+ show_label=True,
49
+ elem_id="gallery",
50
+ object_fit="contain",
51
+ columns=3,
52
+ height=300,
53
+ preview=False,
54
+ show_share_button=False,
55
+ show_download_button=False,
56
+ )
57
+ gr.Markdown(
58
+ "Make sure to save your images from time to time, this gallery may be deleted in the future."
59
+ )
60
+ gallery.attach_load_event(fetch_gallery_history, every=None)
61
+ return gallery
62
+
63
+
64
+ def fetch_gallery_history(
65
+ prompt: Optional[str] = None,
66
+ result: Optional[np.ndarray] = None,
67
+ user: Optional[gr.OAuthProfile] = None,
68
+ ):
69
+ if user is None:
70
+ return []
71
+ try:
72
+ if prompt is not None and result is not None: # None values means no new images
73
+ new_image = Image.fromarray(result, 'RGB')
74
+ return _update_user_history(user["preferred_username"], new_image, prompt)
75
+ else:
76
+ return _read_user_history(user["preferred_username"])
77
+ except Exception as e:
78
+ raise gr.Error(f"Error while fetching history: {e}") from e
79
+
80
+
81
+ ####################
82
+ # Internal helpers #
83
+ ####################
84
+
85
+
86
+ def _read_user_history(username: str) -> List[Tuple[str, str]]:
87
+ """Return saved history for that user."""
88
+ with _user_lock(username):
89
+ path = _user_history_path(username)
90
+ if path.exists():
91
+ return json.loads(path.read_text())
92
+ return [] # No history yet
93
+
94
+
95
+ def _update_user_history(
96
+ username: str, new_image: Image.Image, prompt: str
97
+ ) -> List[Tuple[str, str]]:
98
+ """Update history for that user and return it."""
99
+ with _user_lock(username):
100
+ # Read existing
101
+ path = _user_history_path(username)
102
+ if path.exists():
103
+ images = json.loads(path.read_text())
104
+ else:
105
+ images = [] # No history yet
106
+
107
+ # Copy image to persistent folder
108
+ images = [(_copy_image(new_image), prompt)] + images
109
+
110
+ # Save and return
111
+ path.write_text(json.dumps(images))
112
+ return images
113
+
114
+
115
+ def _user_history_path(username: str) -> Path:
116
+ return HISTORY_FOLDER_PATH / f"{username}.json"
117
+
118
+
119
+ def _user_lock(username: str) -> FileLock:
120
+ """Ensure history is not corrupted if concurrent calls."""
121
+ return FileLock(f"{_user_history_path(username)}.lock")
122
+
123
+
124
+ def _copy_image(new_image: Image.Image) -> str:
125
+ """Copy image to the persistent storage."""
126
+ dst = str(IMAGES_FOLDER_PATH / f"{uuid4().hex}.png")
127
+ new_image.save(dst)
128
+ return dst