import streamlit as st
from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
import requests
from io import BytesIO
# Dependencies: pip install --upgrade torch transformers streamlit pillow requests


st.title("Text to Image Generation with CLIP")

# Load the pretrained CLIP processor and model
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")

text = st.text_area("Enter a description:")
if st.button("Generate Image") and text:
    # Tokenize the text and project it into CLIP's shared embedding space
    text_inputs = clip_processor(text=text, return_tensors="pt", padding=True)
    with torch.no_grad():
        text_features = clip_model.get_text_features(**text_inputs)

    # Load an example image from the web (replace this with your image loading logic)
    example_image_url = "https://source.unsplash.com/random"
    example_image_response = requests.get(example_image_url)
    example_image = Image.open(BytesIO(example_image_response.content))

    # Preprocess the image and project it into the same embedding space
    image_inputs = clip_processor(images=example_image, return_tensors="pt")
    with torch.no_grad():
        image_features = clip_model.get_image_features(**image_inputs)

    # Combine the text and image embeddings by averaging them in the shared space
    combined_features = (text_features + image_features) / 2

    # For visualization, rescale the combined embedding to [0, 255] and reshape it
    # into a 16x32 grid (the ViT-B/16 CLIP model projects to 512 dimensions)
    image_array = combined_features.squeeze().cpu().numpy()
    image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min() + 1e-8)
    generated_image = Image.fromarray((image_array.reshape(16, 32) * 255).astype("uint8"))

    # Display the generated image
    st.image(generated_image, caption="Generated Image", use_column_width=True)
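
    # Sanity-check readout (illustrative sketch): CLIP is trained to align text and image
    # embeddings, so the cosine similarity between the two embeddings computed above
    # indicates how well the fetched image matches the entered description.
    similarity = torch.nn.functional.cosine_similarity(text_features, image_features).item()
    st.write(f"CLIP text-image cosine similarity: {similarity:.3f}")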