# text2speech / app.py
# Author: Tharunika1601
# Commit 73e6ec1 (verified): "Update app.py"
import streamlit as st
from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
import requests
from io import BytesIO
# NOTE: `pip install --upgrade torch` used to sit here. It is a shell command,
# not Python, and made this file a SyntaxError. Declare dependencies in
# requirements.txt instead.
st.title("Text to Image Generation with CLIP")


@st.cache_resource
def _load_clip(model_name="openai/clip-vit-base-patch16"):
    """Load and return the CLIP ``(processor, model)`` pair for *model_name*.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps the loaded objects across reruns, so the
    model weights are not reloaded each time.
    """
    processor = CLIPProcessor.from_pretrained(model_name)
    model = CLIPModel.from_pretrained(model_name)
    return processor, model


# Load pretrained models (cached across reruns; names kept for the code below).
clip_processor, clip_model = _load_clip()
text = st.text_area("Enter a description:")

# NOTE: CLIP is a contrastive text/image *encoder*. It has no image decoder
# and cannot synthesize pixels. The previous version passed tokenizer output
# off as `pixel_values` (KeyError) and tried to render an embedding as an
# image. What CLIP can do is score how well an image matches a description,
# so this handler fetches an image and reports that score.
# st.button is evaluated first so the button is always rendered.
if st.button("Generate Image") and text.strip():
    example_image_url = "https://source.unsplash.com/random"
    # NOTE(review): confirm this endpoint still serves images; replace it with
    # your own image source if it does not.
    try:
        response = requests.get(example_image_url, timeout=10)
        response.raise_for_status()
        example_image = Image.open(BytesIO(response.content))
    except (requests.RequestException, OSError) as err:
        # OSError also covers PIL's UnidentifiedImageError (non-image body).
        st.error(f"Could not load an image from {example_image_url}: {err}")
        st.stop()

    # Keyword arguments avoid positional-order ambiguity in the processor API.
    # truncation=True keeps long descriptions within the tokenizer's max length.
    text_inputs = clip_processor(
        text=[text], return_tensors="pt", padding=True, truncation=True
    )
    image_inputs = clip_processor(images=example_image, return_tensors="pt")

    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = clip_model(**text_inputs, **image_inputs)

    # Cosine similarity of the projected text and image embeddings, in [-1, 1].
    score = torch.nn.functional.cosine_similarity(
        outputs.text_embeds, outputs.image_embeds
    ).item()

    # Display the fetched image together with how well it matches the text.
    st.image(
        example_image,
        caption=f"CLIP text-image similarity: {score:.3f}",
        use_column_width=True,
    )