Jordan Legg committed on
Commit
3ae9c83
·
1 Parent(s): e514cac

fix trying to fix image preprocessing

Browse files
Files changed (2) hide show
  1. app.py +75 -44
  2. requirements.txt +2 -1
app.py CHANGED
@@ -8,18 +8,32 @@ from torchvision import transforms
8
  from diffusers import DiffusionPipeline
9
 
10
  # Define constants
11
- dtype = torch.bfloat16
12
- device = "cuda" if torch.cuda.is_available() else "cpu"
13
  MAX_SEED = np.iinfo(np.int32).max
14
  MAX_IMAGE_SIZE = 2048
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # Load the diffusion pipeline
17
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
18
 
19
- def preprocess_image(image):
20
  # Preprocess the image for the VAE
21
  preprocess = transforms.Compose([
22
- transforms.Resize((512, 512)), # Adjust the size as needed
23
  transforms.ToTensor(),
24
  transforms.Normalize([0.5], [0.5])
25
  ])
@@ -32,44 +46,60 @@ def encode_image(image, vae):
32
  latents = vae.encode(image).latent_dist.sample() * 0.18215
33
  return latents
34
 
 
 
 
 
 
 
 
 
 
 
35
  @spaces.GPU()
36
- def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
37
- if randomize_seed:
38
- seed = random.randint(0, MAX_SEED)
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- if init_image is not None:
42
- # Process img2img
43
- init_image = init_image.convert("RGB")
44
- init_image = preprocess_image(init_image)
45
- latents = encode_image(init_image, pipe.vae)
46
- image = pipe(
47
- prompt=prompt,
48
- height=height,
49
- width=width,
50
- num_inference_steps=num_inference_steps,
51
- generator=generator,
52
- guidance_scale=0.0,
53
- latents=latents
54
- ).images[0]
55
- else:
56
- # Process text2img
57
- image = pipe(
58
- prompt=prompt,
59
- height=height,
60
- width=width,
61
- num_inference_steps=num_inference_steps,
62
- generator=generator,
63
- guidance_scale=0.0
64
- ).images[0]
65
-
66
- return image, seed
 
 
 
 
67
 
68
  # Define example prompts
69
  examples = [
70
  "a tiny astronaut hatching from an egg on the moon",
71
  "a cat holding a sign that says hello world",
72
  "an anime illustration of a wiener schnitzel",
 
 
73
  ]
74
 
75
  # CSS styling for the Japanese-inspired interface
@@ -122,7 +152,7 @@ with gr.Blocks(css=css) as demo:
122
  label="Prompt",
123
  show_label=False,
124
  max_lines=1,
125
- placeholder="Enter your prompt",
126
  container=False,
127
  )
128
  run_button = gr.Button("Run", scale=0)
@@ -144,17 +174,17 @@ with gr.Blocks(css=css) as demo:
144
  with gr.Row():
145
  width = gr.Slider(
146
  label="Width",
147
- minimum=256,
148
  maximum=MAX_IMAGE_SIZE,
149
- step=32,
150
- value=1024,
151
  )
152
  height = gr.Slider(
153
  label="Height",
154
- minimum=256,
155
  maximum=MAX_IMAGE_SIZE,
156
- step=32,
157
- value=1024,
158
  )
159
 
160
  with gr.Row():
@@ -181,7 +211,8 @@ with gr.Blocks(css=css) as demo:
181
  outputs=[result, seed]
182
  )
183
 
184
- demo.launch()
 
185
 
186
 
187
 
 
8
  from diffusers import DiffusionPipeline
9
 
10
  # Define constants
 
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
  MAX_IMAGE_SIZE = 2048
13
+ MIN_IMAGE_SIZE = 256
14
+ DEFAULT_IMAGE_SIZE = 1024
15
+ MAX_PROMPT_LENGTH = 500
16
+
17
+ # Check for GPU availability
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ if device == "cpu":
20
+ print("Warning: Running on CPU. This may be very slow.")
21
+
22
+ dtype = torch.float16 if device == "cuda" else torch.float32
23
+
24
def load_model():
    """Load the FLUX.1-schnell diffusion pipeline and move it to ``device``.

    Reads the module-level ``dtype`` and ``device`` settings.

    Returns:
        The object returned by ``DiffusionPipeline.from_pretrained(...)``
        after ``.to(device)`` has been applied.

    Raises:
        RuntimeError: If constructing the pipeline or moving it to the
            device fails. The underlying exception is attached as
            ``__cause__``.
    """
    try:
        pipeline = DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            torch_dtype=dtype,
        )
        return pipeline.to(device)
    except Exception as e:
        # Chain explicitly so the real failure (network, auth, OOM, ...)
        # stays visible in the traceback instead of only in the message.
        raise RuntimeError(f"Failed to load the model: {e}") from e
29
 
30
  # Load the diffusion pipeline
31
+ pipe = load_model()
32
 
33
+ def preprocess_image(image, target_size=(512, 512)):
34
  # Preprocess the image for the VAE
35
  preprocess = transforms.Compose([
36
+ transforms.Resize(target_size, interpolation=transforms.InterpolationMode.LANCZOS),
37
  transforms.ToTensor(),
38
  transforms.Normalize([0.5], [0.5])
39
  ])
 
46
  latents = vae.encode(image).latent_dist.sample() * 0.18215
47
  return latents
48
 
49
def validate_inputs(prompt, width, height, num_inference_steps):
    """Reject generation parameters that fall outside the supported ranges.

    Checks run in a fixed order (prompt, divisibility, bounds, step count)
    and the first failing check raises.

    Args:
        prompt: Text prompt; must be non-empty and at most
            ``MAX_PROMPT_LENGTH`` characters long.
        width: Requested image width.
        height: Requested image height.
        num_inference_steps: Requested number of inference steps.

    Raises:
        ValueError: If the prompt is empty or too long, a dimension is not
            a multiple of 8, a dimension lies outside
            ``[MIN_IMAGE_SIZE, MAX_IMAGE_SIZE]``, or the step count lies
            outside ``[1, 50]``.
    """
    dimensions = (width, height)

    prompt_ok = bool(prompt) and len(prompt) <= MAX_PROMPT_LENGTH
    if not prompt_ok:
        raise ValueError(f"Prompt must be between 1 and {MAX_PROMPT_LENGTH} characters.")

    if any(side % 8 != 0 for side in dimensions):
        raise ValueError("Width and height must be divisible by 8.")

    if any(side < MIN_IMAGE_SIZE or side > MAX_IMAGE_SIZE for side in dimensions):
        raise ValueError(f"Image dimensions must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}.")

    too_few_steps = num_inference_steps < 1
    if too_few_steps or num_inference_steps > 50:
        raise ValueError("Number of inference steps must be between 1 and 50.")
58
+
59
  @spaces.GPU()
60
def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=DEFAULT_IMAGE_SIZE, height=DEFAULT_IMAGE_SIZE, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for ``prompt``, optionally starting from ``init_image``.

    Args:
        prompt: Text prompt passed to the pipeline.
        init_image: Optional starting image. When given it is converted to
            RGB, preprocessed to ``(height, width)``, encoded with
            ``pipe.vae`` and passed to the pipeline as ``latents``.
        seed: Seed for the random generator; ignored when
            ``randomize_seed`` is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width, checked by ``validate_inputs``.
        height: Output height, checked by ``validate_inputs``.
        num_inference_steps: Number of pipeline steps.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        A ``(image, seed)`` tuple: the first entry of the pipeline output's
        ``.images`` and the seed that was actually used.

    Raises:
        gr.Error: On any failure (validation, preprocessing, generation),
            carrying the message of the underlying exception, which is
            attached as ``__cause__``.
    """
    try:
        validate_inputs(prompt, width, height, num_inference_steps)

        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        generator = torch.Generator(device=device).manual_seed(seed)

        # Arguments shared by the text-to-image and image-seeded paths.
        pipe_kwargs = {
            "prompt": prompt,
            "height": height,
            "width": width,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            "guidance_scale": 0.0,
        }

        if init_image is not None:
            init_image = init_image.convert("RGB")
            init_image = preprocess_image(init_image, (height, width))
            latents = encode_image(init_image, pipe.vae)
            # Force the latent grid to 1/8 of the requested output size.
            # NOTE(review): FLUX pipelines may expect a different (packed)
            # latent layout and VAE scaling than produced here -- confirm
            # against the diffusers Flux pipeline documentation.
            latents = torch.nn.functional.interpolate(
                latents,
                size=(height // 8, width // 8),
                mode='bilinear',
            )
            pipe_kwargs["latents"] = latents

        image = pipe(**pipe_kwargs).images[0]
        return image, seed
    except Exception as e:
        # Surface the failure in the Gradio UI while keeping the original
        # exception chained for server-side tracebacks.
        raise gr.Error(str(e)) from e
95
 
96
  # Define example prompts
97
  examples = [
98
  "a tiny astronaut hatching from an egg on the moon",
99
  "a cat holding a sign that says hello world",
100
  "an anime illustration of a wiener schnitzel",
101
+ "a surreal landscape with floating islands and waterfalls",
102
+ "a steampunk-inspired cityscape at sunset"
103
  ]
104
 
105
  # CSS styling for the Japanese-inspired interface
 
152
  label="Prompt",
153
  show_label=False,
154
  max_lines=1,
155
+ placeholder=f"Enter your prompt (max {MAX_PROMPT_LENGTH} characters)",
156
  container=False,
157
  )
158
  run_button = gr.Button("Run", scale=0)
 
174
  with gr.Row():
175
  width = gr.Slider(
176
  label="Width",
177
+ minimum=MIN_IMAGE_SIZE,
178
  maximum=MAX_IMAGE_SIZE,
179
+ step=8,
180
+ value=DEFAULT_IMAGE_SIZE,
181
  )
182
  height = gr.Slider(
183
  label="Height",
184
+ minimum=MIN_IMAGE_SIZE,
185
  maximum=MAX_IMAGE_SIZE,
186
+ step=8,
187
+ value=DEFAULT_IMAGE_SIZE,
188
  )
189
 
190
  with gr.Row():
 
211
  outputs=[result, seed]
212
  )
213
 
214
+ if __name__ == "__main__":
215
+ demo.launch()
216
 
217
 
218
 
requirements.txt CHANGED
@@ -6,4 +6,5 @@ transformers==4.42.4
6
  xformers
7
  sentencepiece
8
  gradio==4.29.0
9
- torchvision
 
 
6
  xformers
7
  sentencepiece
8
  gradio==4.29.0
9
+ torchvision
10
+ pillow