# Source: Hugging Face Space (page status at capture time: "Sleeping").
"""Streamlit demo: score how well a text description matches an image with CLIP.

CLIP (Contrastive Language-Image Pre-training) embeds text and images into a
shared vector space. It can compare a caption with an image, but it cannot
generate images. So this app fetches an example image and reports the cosine
similarity between the CLIP embeddings of the user's text and that image.
"""
from io import BytesIO

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# NOTE: Python dependencies (e.g. an upgraded torch) belong in requirements.txt.
# A bare `pip install ...` line is a shell command, and in a .py file it is a SyntaxError.

MODEL_NAME = "openai/clip-vit-base-patch16"
# Placeholder image source; replace with your own image loading logic.
# NOTE(review): source.unsplash.com has been deprecated by Unsplash; confirm
# this URL still returns an image, or switch to another source.
EXAMPLE_IMAGE_URL = "https://source.unsplash.com/random"
REQUEST_TIMEOUT_S = 10


@st.cache_resource
def load_clip(model_name=MODEL_NAME):
    """Load and cache the CLIP processor and model.

    Streamlit re-executes the whole script on every widget interaction.
    Without ``st.cache_resource`` the weights would be reloaded on each click.

    Returns:
        ``(processor, model)`` with the model in eval mode.
    """
    processor = CLIPProcessor.from_pretrained(model_name)
    model = CLIPModel.from_pretrained(model_name)
    model.eval()
    return processor, model


def fetch_image(url, timeout=REQUEST_TIMEOUT_S):
    """Download *url* and decode it as an RGB PIL image.

    Raises:
        requests.RequestException: on network errors or a non-2xx response.
        PIL.UnidentifiedImageError: if the response body is not a decodable image.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # Normalize palette/alpha modes to plain RGB for processing and display.
    return Image.open(BytesIO(response.content)).convert("RGB")


def clip_similarity(processor, model, text, image):
    """Return the cosine similarity of the CLIP embeddings of *text* and *image*.

    The result is a Python float in [-1, 1]. Higher means CLIP considers the
    description a better match for the image.
    """
    # One processor call prepares both modalities: input_ids/attention_mask for
    # the text and pixel_values for the image. truncation=True keeps long
    # descriptions within CLIP's fixed text context length instead of erroring.
    inputs = processor(
        text=[text],
        images=image,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        text_emb = model.get_text_features(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
        image_emb = model.get_image_features(pixel_values=inputs["pixel_values"])
    # Both embeddings have shape (1, dim); the result has shape (1,).
    return torch.nn.functional.cosine_similarity(text_emb, image_emb).item()


def main():
    """Render the UI: read a description, fetch an image, show CLIP's match score."""
    # CLIP scores text-image matches; it does not generate images, so the UI says so.
    st.title("Text-to-Image Matching with CLIP")
    processor, model = load_clip()

    text = st.text_area("Enter a description:")
    if st.button("Score Image") and text.strip():
        try:
            image = fetch_image(EXAMPLE_IMAGE_URL)
        except (requests.RequestException, OSError) as err:
            # OSError also covers PIL.UnidentifiedImageError (non-image responses).
            st.error(f"Could not load an example image: {err}")
            return
        score = clip_similarity(processor, model, text, image)
        st.image(
            image,
            caption=f"CLIP similarity to your description: {score:.3f}",
            use_column_width=True,
        )


# `streamlit run` executes this file as __main__. The guard keeps the helpers
# importable without launching the UI.
if __name__ == "__main__":
    main()