import streamlit as st
from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
import requests
from io import BytesIO
# Requires: pip install --upgrade torch (plus streamlit, transformers, pillow, requests)
st.title("Text-Image Matching with CLIP")
st.caption("CLIP scores how well a description matches an image; it does not generate images.")
# Load the pretrained CLIP processor and model once and cache them across Streamlit reruns
@st.cache_resource
def load_clip():
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
    return processor, model

clip_processor, clip_model = load_clip()
text = st.text_area("Enter a description:")
if st.button("Compare with a random image") and text:
    # Tokenize the text and encode it with CLIP's text tower
    text_inputs = clip_processor(text=[text], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_features = clip_model.get_text_features(**text_inputs)

    # Load an example image from the web (replace this with your own image loading logic;
    # the Unsplash Source endpoint has been discontinued, so this URL may no longer resolve)
    example_image_url = "https://source.unsplash.com/random"
    example_image_response = requests.get(example_image_url)
    example_image = Image.open(BytesIO(example_image_response.content)).convert("RGB")

    # Preprocess the image and encode it with CLIP's vision tower
    image_inputs = clip_processor(images=example_image, return_tensors="pt")
    with torch.no_grad():
        image_features = clip_model.get_image_features(**image_inputs)

    # Text and image embeddings share one space; the cosine similarity of the
    # L2-normalized vectors measures how well the description matches the image
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    similarity = (text_features @ image_features.T).item()
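    # Alternatively, a single CLIPModel forward pass returns the scaled similarity
    # directly (logits_per_image = logit_scale * cosine similarity). A minimal sketch;
    # the variable names below are illustrative, not part of the original script:
    joint_inputs = clip_processor(text=[text], images=example_image, return_tensors="pt", padding=True)
    with torch.no_grad():
        joint_outputs = clip_model(**joint_inputs)
    st.write(f"logits_per_image (scaled similarity): {joint_outputs.logits_per_image.item():.2f}")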
    # Display the example image with its similarity score
    st.image(example_image, caption=f"CLIP similarity to your text: {similarity:.3f}", use_column_width=True)

    # CLIP cannot decode an embedding back into a photo, but the 512-dim text embedding
    # can be visualized as a 16x32 grayscale grid after min-max normalization
    embedding = text_features.squeeze().cpu().numpy()
    embedding = (embedding - embedding.min()) / (embedding.max() - embedding.min() + 1e-8)
    embedding_image = Image.fromarray((embedding.reshape(16, 32) * 255).astype("uint8"))
    st.image(embedding_image, caption="Text embedding visualized as a 16x32 grid", use_column_width=True)
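# To run this script locally (assuming it is saved as app.py):
#   pip install streamlit transformers torch pillow requests
#   streamlit run app.py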