text2speech / app.py
Tharunika1601's picture
Update app.py
73e6ec1 verified
raw history blame
No virus
2.16 kB
import streamlit as st
from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
import requests
from io import BytesIO
# Dependencies must be installed before the app starts (e.g. list them in
# requirements.txt, or run `pip install --upgrade torch` in a shell).
# A pip command is not valid Python and cannot live inside this module.

st.title("Text to Image Generation with CLIP")

# Load pretrained models.
# NOTE(review): Streamlit re-executes this script on every interaction, so
# the models are re-instantiated each time; consider wrapping the load in
# st.cache_resource if the installed Streamlit version provides it.
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")


def _embedding_to_image(embedding):
    """Render an embedding vector as a near-square grayscale PIL image.

    The values are min-max scaled to [0, 1] before conversion to uint8 so
    that negative or large components do not wrap around, and the vector is
    folded into the most square 2-D grid its length allows.
    """
    values = embedding.flatten().astype("float32")
    values = values - values.min()
    peak = values.max()
    if peak > 0:
        values = values / peak
    # Largest divisor of the length that does not exceed its square root.
    rows = next(
        r for r in range(int(values.size ** 0.5), 0, -1) if values.size % r == 0
    )
    return Image.fromarray((values.reshape(rows, -1) * 255).astype("uint8"))


text = st.text_area("Enter a description:")

if st.button("Generate Image") and text:
    # Load an example image from the web (replace this with your image
    # loading logic).
    example_image_url = "https://source.unsplash.com/random"
    try:
        example_image_response = requests.get(example_image_url, timeout=30)
        example_image_response.raise_for_status()
        # CLIP's image preprocessing expects 3-channel input.
        example_image = Image.open(
            BytesIO(example_image_response.content)
        ).convert("RGB")
    except (requests.RequestException, OSError) as err:
        # Network failure, HTTP error, or a response that is not an image.
        st.error(f"Could not load the example image: {err}")
        st.stop()

    # Tokenize the text and preprocess the image in a single call so the
    # model receives input_ids, attention_mask and pixel_values together.
    # Truncation keeps long descriptions within CLIP's text context length.
    inputs = clip_processor(
        text=[text],
        images=example_image,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Forward pass through CLIP (inference only, so no gradients needed).
    with torch.no_grad():
        outputs = clip_model(**inputs)

    # CLIP is not a generative model: it projects text and images into a
    # shared embedding space. Combine both projected embeddings into one
    # joint representation of the description and the example image.
    image_representation = (outputs.text_embeds + outputs.image_embeds) / 2

    # For visualization, convert the representation into an image.
    image_array = image_representation.squeeze().cpu().numpy()
    generated_image = _embedding_to_image(image_array)

    # Display the generated image
    st.image(generated_image, caption="Generated Image", use_column_width=True)