awacke1 commited on
Commit
02659a8
1 Parent(s): e6083e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -2
app.py CHANGED
@@ -9,11 +9,10 @@ import re
9
  import random
10
  import torch
11
  import time
12
- import time
13
 
14
  from PIL import Image
15
  from io import BytesIO
16
- from PIL import Image
17
  from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderTiny
18
 
19
  try:
@@ -33,6 +32,77 @@ device = torch.device(
33
  torch_device = device
34
  torch_dtype = torch.float16
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
37
  print(f"TORCH_COMPILE: {TORCH_COMPILE}")
38
  print(f"device: {device}")
 
9
  import random
10
  import torch
11
  import time
12
+ import shutil # Added for zip functionality
13
 
14
  from PIL import Image
15
  from io import BytesIO
 
16
  from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderTiny
17
 
18
  try:
 
32
  torch_device = device
33
  torch_dtype = torch.float16
34
 
35
+
36
+ # add file save and download and clear:
37
+ # Function to create a zip file from a list of files
38
+ def create_zip(files):
39
+ timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
40
+ zip_filename = f"images_{timestamp}.zip"
41
+ with zipfile.ZipFile(zip_filename, 'w') as zipf:
42
+ for file in files:
43
+ zipf.write(file, os.path.basename(file))
44
+ return zip_filename
45
+
46
+ # Function to encode a file to base64
47
+ def encode_file_to_base64(file_path):
48
+ with open(file_path, "rb") as file:
49
+ encoded = base64.b64encode(file.read()).decode()
50
+ return encoded
51
+
52
+
53
+ # Function to save all images as a zip file and provide a base64 download link
54
+ def save_all_images(images):
55
+ if len(images) == 0:
56
+ return None, None
57
+
58
+ zip_filename = create_zip(images) # Create a zip file from the list of image files
59
+ zip_base64 = encode_file_to_base64(zip_filename) # Encode the zip file to base64
60
+ download_link = f'<a href="data:application/zip;base64,{zip_base64}" download="{zip_filename}">Download All</a>'
61
+
62
+ return zip_filename, download_link
63
+
64
+
65
+ # Function to clear all image files
66
+ def clear_all_images():
67
+ base_dir = os.getcwd() # Get the current base directory
68
+ img_files = [file for file in os.listdir(base_dir) if file.lower().endswith((".png", ".jpg", ".jpeg"))] # List all files ending with ".jpg" or ".jpeg"
69
+
70
+ # Remove all image files
71
+ for file in img_files:
72
+ os.remove(file)
73
+
74
+
75
+ # Add "Save All" button with emoji
76
+ save_all_button = gr.Button("💾 Save All", scale=1)
77
+ # Add "Clear All" button with emoji
78
+ clear_all_button = gr.Button("🗑️ Clear All", scale=1)
79
+
80
+ # Function to handle "Save All" button click
81
+ def save_all_button_click():
82
+ images = [file for file in os.listdir() if file.lower().endswith((".png", ".jpg", ".jpeg"))]
83
+ zip_filename, download_link = save_all_images(images)
84
+
85
+ if download_link:
86
+ gr.write(download_link)
87
+
88
+ # Function to handle "Clear All" button click
89
+ def clear_all_button_click():
90
+ clear_all_images()
91
+
92
+ # Attach click event handlers to the buttons
93
+ save_all_button.click(save_all_button_click)
94
+ clear_all_button.click(clear_all_button_click)
95
+
96
+ # Add buttons to the Streamlit app
97
+ gr.button(save_all_button)
98
+ gr.button(clear_all_button)
99
+
100
+
101
+
102
+
103
+
104
+
105
+
106
  print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
107
  print(f"TORCH_COMPILE: {TORCH_COMPILE}")
108
  print(f"device: {device}")