Gopalag commited on
Commit
97a01d0
·
verified ·
1 Parent(s): 06d4d54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +281 -166
app.py CHANGED
@@ -1,132 +1,154 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
 
4
  import torch
5
- from diffusers import StableDiffusionPipeline
6
  from PIL import Image
7
- import os
8
 
9
- MAX_SEED = 10000
10
- MAX_IMAGE_SIZE = 1024
11
 
12
- def get_edge_color(image):
13
- """Get a random color from the edge of the image"""
14
- img_array = np.array(image)
15
- top_edge = img_array[0, :, :]
16
- bottom_edge = img_array[-1, :, :]
17
- left_edge = img_array[:, 0, :]
18
- right_edge = img_array[:, -1, :]
19
- edge_pixels = np.concatenate([top_edge, bottom_edge, left_edge, right_edge])
20
- random_edge_color = tuple(edge_pixels[random.randint(0, len(edge_pixels)-1)])
21
- return random_edge_color
22
 
23
- def color_match_tshirt(tshirt_image, target_color, threshold=30):
24
- """Change white/near-white areas of the t-shirt to the target color"""
25
- img_array = np.array(tshirt_image)
26
- white_mask = np.all(np.abs(img_array - [255, 255, 255]) < threshold, axis=2)
27
- img_array[white_mask] = target_color
28
- return Image.fromarray(img_array)
29
 
30
- def add_watermark(image, logo_path, position='bottom-right', size_percentage=10):
31
- """Add a watermark to an image"""
32
- try:
33
- if not os.path.exists(logo_path):
34
- return image
35
-
36
- logo = Image.open(logo_path).convert('RGBA')
37
- main_width, main_height = image.size
38
- logo_width = int(main_width * size_percentage / 100)
39
- logo_height = int(logo.size[1] * (logo_width / logo.size[0]))
40
- logo = logo.resize((logo_width, logo_height), Image.Resampling.LANCZOS)
41
-
42
- if image.mode != 'RGBA':
43
- image = image.convert('RGBA')
44
-
45
- watermarked = Image.new('RGBA', image.size, (0, 0, 0, 0))
46
- watermarked.paste(image, (0, 0))
47
-
48
- if position == 'bottom-right':
49
- pos = (main_width - logo_width - 10, main_height - logo_height - 10)
50
- elif position == 'bottom-left':
51
- pos = (10, main_height - logo_height - 10)
52
- elif position == 'top-right':
53
- pos = (main_width - logo_width - 10, 10)
54
- else: # top-left
55
- pos = (10, 10)
56
-
57
- watermarked.paste(logo, pos, logo)
58
- return watermarked.convert('RGB')
59
- except Exception as e:
60
- print(f"Failed to add watermark: {str(e)}")
61
- return image
62
 
63
- def create_tshirt_preview(design_image, tshirt_template_path, tshirt_color="white"):
64
- """Create a preview of the design on a t-shirt"""
65
- try:
66
- tshirt = Image.open(tshirt_template_path)
67
- tshirt_width, tshirt_height = tshirt.size
68
-
69
- edge_color = get_edge_color(design_image)
70
- tshirt = color_match_tshirt(tshirt, edge_color)
71
-
72
- design_width = int(tshirt_width * 0.35)
73
- design_height = int(design_width * design_image.size[1] / design_image.size[0])
74
- design_image = design_image.resize((design_width, design_height), Image.Resampling.LANCZOS)
75
-
76
- x = (tshirt_width - design_width) // 2
77
- y = int(tshirt_height * 0.2)
78
-
79
- if design_image.mode == 'RGBA':
80
- mask = design_image.split()[3]
81
- else:
82
- mask = None
83
-
84
- tshirt.paste(design_image, (x, y), mask)
85
- return tshirt
86
- except Exception as e:
87
- print(f"Failed to create t-shirt preview: {str(e)}")
88
- return design_image
 
 
 
89
 
90
- def enhance_prompt(prompt, style):
91
- """Enhance the prompt based on selected style"""
92
- if not style:
93
- return prompt
94
-
95
- style_prompts = {
96
- "minimal": "minimalist design, clean lines, simple shapes",
97
- "vintage": "vintage style, retro, distressed texture",
98
- "artistic": "artistic, creative, hand-drawn style",
99
- "geometric": "geometric patterns, abstract shapes",
100
- "typography": "modern typography, creative lettering",
101
- "realistic": "photorealistic, detailed illustration"
102
- }
103
 
104
- return f"{prompt}, {style_prompts.get(style, '')}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
- def initialize_pipeline():
107
- """Initialize the Stable Diffusion pipeline"""
108
- model_id = "stabilityai/stable-diffusion-2-1"
109
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
110
- if torch.cuda.is_available():
111
- pipe = pipe.to("cuda")
112
- return pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
- def generate_design(prompt, style, seed, width, height, num_inference_steps, pipe):
115
- """Generate the design using Stable Diffusion"""
116
- enhanced_prompt = enhance_prompt(prompt, style)
 
 
 
 
 
117
  generator = torch.Generator().manual_seed(seed)
118
 
119
- image = pipe(
 
120
  prompt=enhanced_prompt,
121
  width=width,
122
  height=height,
123
  num_inference_steps=num_inference_steps,
124
  generator=generator,
 
125
  ).images[0]
126
 
127
- return image
 
 
 
128
 
129
- # Constants
130
  TSHIRT_COLORS = {
131
  "White": "#FFFFFF",
132
  "Black": "#000000",
@@ -134,7 +156,14 @@ TSHIRT_COLORS = {
134
  "Gray": "#808080"
135
  }
136
 
137
- STYLES = [
 
 
 
 
 
 
 
138
  "minimal",
139
  "vintage",
140
  "artistic",
@@ -143,67 +172,153 @@ STYLES = [
143
  "realistic"
144
  ]
145
 
146
- EXAMPLES = [
147
- ["Cool geometric mountain landscape", "minimal", "White"],
148
- ["Vintage motorcycle with flames", "vintage", "Black"],
149
- ["Flamingo in scenic forest", "realistic", "White"],
150
- ["Adventure Starts typography", "typography", "White"]
151
- ]
152
-
153
- # Gradio Interface
154
- def create_interface():
155
- pipe = initialize_pipeline()
156
-
157
- def infer(prompt, style, tshirt_color, seed, randomize_seed, width, height, num_inference_steps):
158
- if randomize_seed:
159
- seed = random.randint(0, MAX_SEED)
160
-
161
- try:
162
- design_image = generate_design(prompt, style, seed, width, height, num_inference_steps, pipe)
163
- tshirt_preview = create_tshirt_preview(design_image, "tshirt_template.png", tshirt_color)
164
- return design_image, tshirt_preview, seed
165
- except Exception as e:
166
- print(f"Error during inference: {str(e)}")
167
- return None, None, seed
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
170
- with gr.Column():
171
- gr.Markdown("# 👕 Deradh's T-Shirt Design Generator")
172
-
173
- with gr.Row():
174
- prompt = gr.Textbox(label="Design Description", placeholder="Describe your t-shirt design idea")
175
- style = gr.Dropdown(choices=[""] + STYLES, value="", label="Style")
176
- tshirt_color = gr.Dropdown(choices=list(TSHIRT_COLORS.keys()), value="White", label="T-Shirt Color")
177
-
178
- run_button = gr.Button("✨ Generate")
179
-
180
- with gr.Row():
181
- result = gr.Image(label="Generated Design")
182
- preview = gr.Image(label="T-Shirt Preview")
183
-
184
- with gr.Accordion("Advanced Settings", open=False):
185
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
186
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
187
- width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
188
- height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
189
- num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=25)
190
-
191
- gr.Examples(
192
- examples=EXAMPLES,
193
- inputs=[prompt, style, tshirt_color],
194
- outputs=[result, preview, seed],
195
- fn=lambda p, s, c: infer(p, s, c, 0, True, 512, 512, 25),
196
- cache_examples=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  )
198
-
199
- run_button.click(
200
- fn=infer,
201
- inputs=[prompt, style, tshirt_color, seed, randomize_seed, width, height, num_inference_steps],
202
- outputs=[result, preview, seed]
203
  )
204
-
205
- return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
- if __name__ == "__main__":
208
- demo = create_interface()
209
- demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
+ import spaces
5
  import torch
6
+ from diffusers import DiffusionPipeline
7
  from PIL import Image
8
+ import io
9
 
10
+ dtype = torch.bfloat16
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ pipe = DiffusionPipeline.from_pretrained(
14
+ "black-forest-labs/FLUX.1-schnell",
15
+ torch_dtype=dtype
16
+ ).to(device)
 
 
 
 
 
 
17
 
18
+ MAX_SEED = np.iinfo(np.int32).max
19
+ MAX_IMAGE_SIZE = 2048
 
 
 
 
20
 
21
+ import numpy as np
22
+ from collections import Counter
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ def get_prominent_colors(image, num_colors=5):
25
+ """
26
+ Get the most prominent colors from an image, focusing on edges
27
+ """
28
+ # Convert to numpy array
29
+ img_array = np.array(image)
30
+
31
+ # Create a simple edge mask using gradient magnitude
32
+ gradient_x = np.gradient(img_array.mean(axis=2))[1]
33
+ gradient_y = np.gradient(img_array.mean(axis=2))[0]
34
+ gradient_magnitude = np.sqrt(gradient_x**2 + gradient_y**2)
35
+
36
+ # Threshold to get edge pixels
37
+ edge_threshold = np.percentile(gradient_magnitude, 90) # Adjust percentile as needed
38
+ edge_mask = gradient_magnitude > edge_threshold
39
+
40
+ # Get colors from edge pixels
41
+ edge_colors = img_array[edge_mask]
42
+
43
+ # Convert colors to tuples for counting
44
+ colors = [tuple(color) for color in edge_colors]
45
+
46
+ # Count occurrences of each color
47
+ color_counts = Counter(colors)
48
+
49
+ # Get most common colors
50
+ prominent_colors = color_counts.most_common(num_colors)
51
+
52
+ return prominent_colors
53
 
54
+ def create_tshirt_preview(design_image, tshirt_color="white"):
55
+ """
56
+ Overlay the design onto the existing t-shirt template and color match
57
+ """
58
+ # Load the template t-shirt image
59
+ tshirt = Image.open('image.jpeg')
60
+ tshirt_width, tshirt_height = tshirt.size
61
+
62
+ # Convert design to PIL Image if it's not already
63
+ if not isinstance(design_image, Image.Image):
64
+ design_image = Image.fromarray(design_image)
 
 
65
 
66
+ # Get prominent colors from the design
67
+ prominent_colors = get_prominent_colors(design_image)
68
+ if prominent_colors:
69
+ # Use the most prominent color for the t-shirt
70
+ main_color = prominent_colors[0][0] # RGB tuple of most common color
71
+ else:
72
+ # Fallback to white if no colors found
73
+ main_color = (255, 255, 255)
74
+
75
+ # Convert design to PIL Image if it's not already
76
+ if not isinstance(design_image, Image.Image):
77
+ design_image = Image.fromarray(design_image)
78
+
79
+ # Resize design to fit nicely on shirt (40% of shirt width)
80
+ design_width = int(tshirt_width * 0.35) # Adjust this percentage as needed
81
+ design_height = int(design_width * design_image.size[1] / design_image.size[0])
82
+ design_image = design_image.resize((design_width, design_height), Image.Resampling.LANCZOS)
83
+
84
+ # Calculate position to center design on shirt
85
+ x = (tshirt_width - design_width) // 2
86
+ y = int(tshirt_height * 0.2) # Adjust this value based on your template
87
+
88
+ # If design has transparency (RGBA), create mask
89
+ if design_image.mode == 'RGBA':
90
+ mask = design_image.split()[3]
91
+ else:
92
+ mask = None
93
+
94
+ # Paste design onto shirt
95
+ tshirt.paste(design_image, (x, y), mask)
96
+
97
+ return tshirt
98
 
99
+ def enhance_prompt_for_tshirt(prompt, style=None):
100
+ """Add specific terms to ensure good t-shirt designs."""
101
+ style_terms = {
102
+ "minimal": ["simple geometric shapes", "clean lines", "minimalist illustration"],
103
+ "vintage": ["distressed effect", "retro typography", "vintage illustration"],
104
+ "artistic": ["hand-drawn style", "watercolor effect", "artistic illustration"],
105
+ "geometric": ["abstract shapes", "geometric patterns", "modern design"],
106
+ "typography": ["bold typography", "creative lettering", "text-based design"],
107
+ "realistic": ["realistic", "cinematic", "photograph"]
108
+ }
109
+
110
+ base_terms = [
111
+ "create t-shirt design",
112
+ "with centered composition",
113
+ "high quality",
114
+ "professional design",
115
+ "clear background"
116
+ ]
117
+
118
+ enhanced_prompt = f"{prompt}, {', '.join(base_terms)}"
119
+
120
+ if style and style in style_terms:
121
+ style_specific_terms = style_terms[style]
122
+ enhanced_prompt = f"{enhanced_prompt}, {', '.join(style_specific_terms)}"
123
+
124
+ return enhanced_prompt
125
 
126
+ @spaces.GPU()
127
+ def infer(prompt, style=None, tshirt_color="white", seed=42, randomize_seed=False,
128
+ width=1024, height=1024, num_inference_steps=4,
129
+ progress=gr.Progress(track_tqdm=True)):
130
+ if randomize_seed:
131
+ seed = random.randint(0, MAX_SEED)
132
+
133
+ enhanced_prompt = enhance_prompt_for_tshirt(prompt, style)
134
  generator = torch.Generator().manual_seed(seed)
135
 
136
+ # Generate the design
137
+ design_image = pipe(
138
  prompt=enhanced_prompt,
139
  width=width,
140
  height=height,
141
  num_inference_steps=num_inference_steps,
142
  generator=generator,
143
+ guidance_scale=0.0
144
  ).images[0]
145
 
146
+ # Create t-shirt preview
147
+ tshirt_preview = create_tshirt_preview(design_image, tshirt_color)
148
+
149
+ return design_image, tshirt_preview, seed
150
 
151
+ # Available t-shirt colors
152
  TSHIRT_COLORS = {
153
  "White": "#FFFFFF",
154
  "Black": "#000000",
 
156
  "Gray": "#808080"
157
  }
158
 
159
+ examples = [
160
+ ["Cool geometric mountain landscape", "minimal", "White"],
161
+ ["Vintage motorcycle with flames", "vintage", "Black"],
162
+ ["flamingo in scenic forset", "realistic", "White"],
163
+ ["Adventure Starts typography", "typography", "White"]
164
+ ]
165
+
166
+ styles = [
167
  "minimal",
168
  "vintage",
169
  "artistic",
 
172
  "realistic"
173
  ]
174
 
175
+ css = """
176
+ #col-container {
177
+ margin: 0 auto;
178
+ max-width: 1200px !important;
179
+ padding: 20px;
180
+ }
181
+ .main-title {
182
+ text-align: center;
183
+ color: #2d3748;
184
+ margin-bottom: 1rem;
185
+ font-family: 'Poppins', sans-serif;
186
+ }
187
+ .subtitle {
188
+ text-align: center;
189
+ color: #4a5568;
190
+ margin-bottom: 2rem;
191
+ font-family: 'Inter', sans-serif;
192
+ font-size: 0.95rem;
193
+ line-height: 1.5;
194
+ }
195
+ .design-input {
196
+ border: 2px solid #e2e8f0;
197
+ border-radius: 10px;
198
+ padding: 12px !important;
199
+ margin-bottom: 1rem !important;
200
+ font-size: 1rem;
201
+ transition: all 0.3s ease;
202
+ }
203
+ .results-row {
204
+ display: grid;
205
+ grid-template-columns: 1fr 1fr;
206
+ gap: 20px;
207
+ margin-top: 20px;
208
+ }
209
+ """
210
 
211
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
212
+ with gr.Column(elem_id="col-container"):
213
+ gr.Markdown(
214
+ """
215
+ # 👕Deradh's T-Shirt Design Generator
216
+ """,
217
+ elem_classes=["main-title"]
218
+ )
219
+
220
+ gr.Markdown(
221
+ """
222
+ Create unique t-shirt designs using Deradh's AI.
223
+ Describe your design idea and select a style to generate professional-quality artwork
224
+ perfect for custom t-shirts.
225
+ """,
226
+ elem_classes=["subtitle"]
227
+ )
228
+
229
+ with gr.Row():
230
+ with gr.Column(scale=2):
231
+ prompt = gr.Text(
232
+ label="Design Description",
233
+ show_label=False,
234
+ max_lines=1,
235
+ placeholder="Describe your t-shirt design idea",
236
+ container=False,
237
+ elem_classes=["design-input"]
238
+ )
239
+ with gr.Column(scale=1):
240
+ style = gr.Dropdown(
241
+ choices=[""] + styles,
242
+ value="",
243
+ label="Style",
244
+ container=False
245
+ )
246
+ with gr.Column(scale=1):
247
+ tshirt_color = gr.Dropdown(
248
+ choices=list(TSHIRT_COLORS.keys()),
249
+ value="White",
250
+ label="T-Shirt Color",
251
+ container=False
252
+ )
253
+ run_button = gr.Button(
254
+ "✨ Generate",
255
+ scale=0,
256
+ elem_classes=["generate-button"]
257
+ )
258
+
259
+ with gr.Row(elem_classes=["results-row"]):
260
+ result = gr.Image(
261
+ label="Generated Design",
262
+ show_label=True,
263
+ elem_classes=["result-image"]
264
  )
265
+ preview = gr.Image(
266
+ label="T-Shirt Preview",
267
+ show_label=True,
268
+ elem_classes=["preview-image"]
 
269
  )
270
+
271
+ with gr.Accordion("🔧 Advanced Settings", open=False):
272
+ with gr.Group():
273
+ seed = gr.Slider(
274
+ label="Design Seed",
275
+ minimum=0,
276
+ maximum=MAX_SEED,
277
+ step=1,
278
+ value=0,
279
+ )
280
+ randomize_seed = gr.Checkbox(
281
+ label="Randomize Design",
282
+ value=True
283
+ )
284
+
285
+ with gr.Row():
286
+ width = gr.Slider(
287
+ label="Width",
288
+ minimum=256,
289
+ maximum=MAX_IMAGE_SIZE,
290
+ step=32,
291
+ value=1024,
292
+ )
293
+ height = gr.Slider(
294
+ label="Height",
295
+ minimum=256,
296
+ maximum=MAX_IMAGE_SIZE,
297
+ step=32,
298
+ value=1024,
299
+ )
300
+
301
+ num_inference_steps = gr.Slider(
302
+ label="Generation Quality (Steps)",
303
+ minimum=1,
304
+ maximum=50,
305
+ step=1,
306
+ value=4,
307
+ )
308
+
309
+ gr.Examples(
310
+ examples=examples,
311
+ fn=infer,
312
+ inputs=[prompt, style, tshirt_color],
313
+ outputs=[result, preview, seed],
314
+ cache_examples=True
315
+ )
316
+
317
+ gr.on(
318
+ triggers=[run_button.click, prompt.submit],
319
+ fn=infer,
320
+ inputs=[prompt, style, tshirt_color, seed, randomize_seed, width, height, num_inference_steps],
321
+ outputs=[result, preview, seed]
322
+ )
323
 
324
+ demo.launch()