Tharunika1601 committed on
Commit
5190795
1 Parent(s): d7c38d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -25,13 +25,13 @@ if st.button("Generate Image") and text:
25
  image_features = clip_processor(images=example_image, return_tensors="pt", padding=True)
26
 
27
  # Ensure the dimensions of pixel_values are the same for text and image features
28
- max_len = max(text_features.pixel_values.shape[1], image_features.pixel_values.shape[1])
29
- text_features.pixel_values = torch.nn.functional.pad(text_features.pixel_values, (0, max_len - text_features.pixel_values.shape[1]))
30
- image_features.pixel_values = torch.nn.functional.pad(image_features.pixel_values, (0, max_len - image_features.pixel_values.shape[1]))
31
 
32
  # Concatenate text and image features
33
  combined_features = {
34
- "pixel_values": torch.cat([text_features.pixel_values, image_features.pixel_values], dim=1)
35
  }
36
 
37
  # Forward pass through CLIP
@@ -41,5 +41,5 @@ if st.button("Generate Image") and text:
41
  image_array = image_representation.squeeze().cpu().numpy()
42
  generated_image = Image.fromarray((image_array * 255).astype('uint8'))
43
 
44
-
45
  st.image(generated_image, caption="Generated Image", use_column_width=True)
 
25
  image_features = clip_processor(images=example_image, return_tensors="pt", padding=True)
26
 
27
  # Ensure the dimensions of pixel_values are the same for text and image features
28
+ max_len = max(text_features['pixel_values'].shape[1], image_features['pixel_values'].shape[1])
29
+ text_features['pixel_values'] = torch.nn.functional.pad(text_features['pixel_values'], (0, max_len - text_features['pixel_values'].shape[1]))
30
+ image_features['pixel_values'] = torch.nn.functional.pad(image_features['pixel_values'], (0, max_len - image_features['pixel_values'].shape[1]))
31
 
32
  # Concatenate text and image features
33
  combined_features = {
34
+ "pixel_values": torch.cat([text_features['pixel_values'], image_features['pixel_values']], dim=1)
35
  }
36
 
37
  # Forward pass through CLIP
 
41
  image_array = image_representation.squeeze().cpu().numpy()
42
  generated_image = Image.fromarray((image_array * 255).astype('uint8'))
43
 
44
+ # Display the generated image
45
  st.image(generated_image, caption="Generated Image", use_column_width=True)