Tharunika1601 commited on
Commit
1feb738
1 Parent(s): 83c0cc1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -1,22 +1,25 @@
1
  import streamlit as st
2
- from transformers import CLIPProcessor, CLIPModel, DiffusionModel
3
  import torch
4
  from PIL import Image
5
 
6
- st.title("Text to Image Generation")
7
 
8
  # Load pretrained models
9
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
10
  clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
11
- diffusion_model = DiffusionModel.from_pretrained("openai/guided-diffusion-clipped-coco")
12
 
13
  text = st.text_area("Enter a description:")
14
  if st.button("Generate Image") and text:
15
  # Process text and get CLIP features
16
  text_features = clip_processor(text, return_tensors="pt", padding=True)
17
 
18
- # Generate image from text using Guided Diffusion
19
- image = diffusion_model.generate_text_to_image(text_features["pixel_values"])
 
 
 
 
20
 
21
  # Display the generated image
22
  st.image(image, caption="Generated Image", use_column_width=True)
 
1
  import streamlit as st
2
+ from transformers import CLIPProcessor, CLIPModel
3
  import torch
4
  from PIL import Image
5
 
6
+ st.title("Text to Image Generation with CLIP")
7
 
8
  # Load pretrained models
9
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
10
  clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
 
11
 
12
  text = st.text_area("Enter a description:")
13
  if st.button("Generate Image") and text:
14
  # Process text and get CLIP features
15
  text_features = clip_processor(text, return_tensors="pt", padding=True)
16
 
17
+ # Use CLIP's image projection to generate an image representation
18
+ image_representation = clip_model.get_image_features(pixel_values=text_features.pixel_values)
19
+
20
+ # For visualization, you can convert the image representation back to an image
21
+ image_array = image_representation.squeeze().permute(1, 2, 0).cpu().numpy()
22
+ image = Image.fromarray((image_array * 255).astype('uint8'))
23
 
24
  # Display the generated image
25
  st.image(image, caption="Generated Image", use_column_width=True)