K00B404 commited on
Commit
dc8d200
1 Parent(s): 7eacdc0

Update tk_app.py

Browse files
Files changed (1) hide show
  1. tk_app.py +198 -0
tk_app.py CHANGED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tkinter as tk
2
+ from tkinter import ttk, scrolledtext, filedialog
3
+ from PIL import Image, ImageTk
4
+ import asyncio
5
+ import random
6
+ from threading import RLock
7
+ import requests
8
+ import io
9
+ import base64
10
+ from all_models import models
11
+
12
+ class ImageGeneratorApp:
13
+ def __init__(self, root):
14
+ self.root = root
15
+ self.root.title("Image Generator")
16
+ self.lock = RLock()
17
+
18
+ self.models = models # Use the imported models
19
+ self.num_models = 6 # Number of models to display
20
+
21
+ self.create_widgets()
22
+
23
+ def create_widgets(self):
24
+ self.notebook = ttk.Notebook(self.root)
25
+ self.notebook.pack(fill=tk.BOTH, expand=True)
26
+
27
+ self.create_local_tab()
28
+ self.create_api_tab()
29
+
30
+ def create_local_tab(self):
31
+ local_frame = ttk.Frame(self.notebook)
32
+ self.notebook.add(local_frame, text="Local Generation")
33
+
34
+ # Prompt input
35
+ prompt_label = ttk.Label(local_frame, text="Your prompt:")
36
+ prompt_label.pack(pady=5)
37
+ self.local_prompt_input = scrolledtext.ScrolledText(local_frame, height=4)
38
+ self.local_prompt_input.pack(pady=5, padx=10, fill=tk.X)
39
+
40
+ # Generate button
41
+ generate_button = ttk.Button(local_frame, text="Generate Images", command=self.generate_images)
42
+ generate_button.pack(pady=10)
43
+
44
+ # Image display area
45
+ image_frame = ttk.Frame(local_frame)
46
+ image_frame.pack(pady=10, padx=10)
47
+
48
+ self.image_labels = []
49
+ for i in range(self.num_models):
50
+ label = ttk.Label(image_frame)
51
+ label.grid(row=i//3, column=i%3, padx=5, pady=5)
52
+ self.image_labels.append(label)
53
+
54
+ # Model selection
55
+ model_frame = ttk.LabelFrame(local_frame, text="Model Selection")
56
+ model_frame.pack(pady=10, padx=10, fill=tk.X)
57
+
58
+ self.model_vars = []
59
+ for model in self.models[:self.num_models]:
60
+ var = tk.BooleanVar(value=True)
61
+ cb = ttk.Checkbutton(model_frame, text=model, variable=var)
62
+ cb.pack(anchor=tk.W)
63
+ self.model_vars.append(var)
64
+
65
+ def create_api_tab(self):
66
+ api_frame = ttk.Frame(self.notebook)
67
+ self.notebook.add(api_frame, text="API Generation")
68
+
69
+ # Model selection
70
+ model_label = ttk.Label(api_frame, text="Model:")
71
+ model_label.pack(pady=5)
72
+ self.api_model_var = tk.StringVar(value=self.models[0])
73
+ model_combobox = ttk.Combobox(api_frame, textvariable=self.api_model_var, values=self.models)
74
+ model_combobox.pack(pady=5, padx=10, fill=tk.X)
75
+
76
+ # Prompt input
77
+ prompt_label = ttk.Label(api_frame, text="Your prompt:")
78
+ prompt_label.pack(pady=5)
79
+ self.api_prompt_input = scrolledtext.ScrolledText(api_frame, height=4)
80
+ self.api_prompt_input.pack(pady=5, padx=10, fill=tk.X)
81
+
82
+ # Generate button
83
+ generate_button = ttk.Button(api_frame, text="Generate Image", command=self.generate_api_image)
84
+ generate_button.pack(pady=10)
85
+
86
+ # Image display
87
+ self.api_image_label = ttk.Label(api_frame)
88
+ self.api_image_label.pack(pady=10)
89
+
90
+ # Add to gallery button
91
+ add_gallery_button = ttk.Button(api_frame, text="Add to Gallery", command=self.add_to_gallery)
92
+ add_gallery_button.pack(pady=10)
93
+
94
+ # Gallery display
95
+ self.gallery_frame = ttk.Frame(api_frame)
96
+ self.gallery_frame.pack(pady=10, padx=10, fill=tk.BOTH, expand=True)
97
+
98
+ async def generate_image(self, model, prompt):
99
+ # This is a placeholder for actual model inference
100
+ # In a real application, you would call the actual model here
101
+ await asyncio.sleep(random.uniform(1, 3)) # Random delay to simulate processing time
102
+ return Image.new('RGB', (256, 256), color=random.choice(['red', 'green', 'blue']))
103
+
104
+ async def generate_all_images(self, prompt):
105
+ tasks = []
106
+ for model, var in zip(self.models[:self.num_models], self.model_vars):
107
+ if var.get():
108
+ task = asyncio.create_task(self.generate_image(model, prompt))
109
+ tasks.append(task)
110
+
111
+ results = await asyncio.gather(*tasks)
112
+
113
+ for label, image in zip(self.image_labels, results):
114
+ if image:
115
+ photo = ImageTk.PhotoImage(image)
116
+ label.configure(image=photo)
117
+ label.image = photo
118
+ else:
119
+ label.configure(image='')
120
+
121
+ def generate_images(self):
122
+ prompt = self.local_prompt_input.get("1.0", tk.END).strip()
123
+ if not prompt:
124
+ return
125
+
126
+ for label in self.image_labels:
127
+ label.configure(image='')
128
+
129
+ asyncio.run(self.generate_all_images(prompt))
130
+
131
+ def generate_api_image(self):
132
+ model_str = self.api_model_var.get()
133
+ prompt = self.api_prompt_input.get("1.0", tk.END).strip()
134
+
135
+ if not prompt:
136
+ return
137
+
138
+ # Make API call
139
+ url = "https://k00b404-huggingfacediffusion-custom.hf.space/run/gen_fn_4"
140
+ payload = {
141
+ "data": [
142
+ model_str,
143
+ prompt
144
+ ]
145
+ }
146
+ response = requests.post(url, json=payload)
147
+
148
+ if response.status_code == 200:
149
+ result = response.json()
150
+ image_data = base64.b64decode(result['data'][0].split(',')[1])
151
+ image = Image.open(io.BytesIO(image_data))
152
+ photo = ImageTk.PhotoImage(image)
153
+ self.api_image_label.configure(image=photo)
154
+ self.api_image_label.image = photo
155
+ self.current_api_image = image
156
+ else:
157
+ print(f"Error: {response.status_code}")
158
+
159
+ def add_to_gallery(self):
160
+ if not hasattr(self, 'current_api_image'):
161
+ return
162
+
163
+ # Save the image to a temporary file
164
+ temp_file = 'temp_image.png'
165
+ self.current_api_image.save(temp_file)
166
+
167
+ # Make API call to add to gallery
168
+ url = "https://k00b404-huggingfacediffusion-custom.hf.space/run/add_gallery_4"
169
+ files = {
170
+ 'data': ('image.png', open(temp_file, 'rb'), 'image/png')
171
+ }
172
+ data = {
173
+ 'data': ['', self.api_model_var.get(), '[]']
174
+ }
175
+ response = requests.post(url, files=files, data=data)
176
+
177
+ if response.status_code == 200:
178
+ result = response.json()
179
+ self.update_gallery(result['data'][0])
180
+ else:
181
+ print(f"Error: {response.status_code}")
182
+
183
+ def update_gallery(self, gallery_data):
184
+ for widget in self.gallery_frame.winfo_children():
185
+ widget.destroy()
186
+
187
+ for i, item in enumerate(gallery_data):
188
+ image_data = base64.b64decode(item['image'].split(',')[1])
189
+ image = Image.open(io.BytesIO(image_data))
190
+ photo = ImageTk.PhotoImage(image)
191
+ label = ttk.Label(self.gallery_frame, image=photo)
192
+ label.image = photo
193
+ label.grid(row=i//3, column=i%3, padx=5, pady=5)
194
+
195
+ if __name__ == "__main__":
196
+ root = tk.Tk()
197
+ app = ImageGeneratorApp(root)
198
+ root.mainloop()