# Author: amongusrickroll68
# Commit message: Rename class_name to _class_name
# Commit: 3c48c5e
import tensorflow as tf
import numpy as np
from PIL import Image
from io import BytesIO
from scipy.stats import truncnorm
from skimage.transform import resize
from transformers import CLIPProcessor, CLIPModel
class TextToImageGenerator:
    """Generate a JPEG image from a text prompt.

    A pretrained CLIP model (the PyTorch ``CLIPModel`` from ``transformers``)
    encodes the prompt into a text-feature vector. A Keras generator model
    then maps that vector, plus random noise, to an image.
    """

    def __init__(
        self,
        *,
        generator_path='path/to/generator/model',
        clip_model_name='openai/clip-vit-base-patch32',
        noise_dim=256,
        image_size=(256, 256),
    ):
        """Load the CLIP model/processor and the Keras generator.

        Args:
            generator_path: Path given to ``tf.keras.models.load_model``.
                The default is the original placeholder path.
            clip_model_name: Model id used for both ``CLIPModel`` and
                ``CLIPProcessor``.
            noise_dim: Length of the noise vector passed to the generator
                next to the text features. Must match the generator's input.
            image_size: ``(width, height)`` of the returned JPEG.
        """
        self.clip = CLIPModel.from_pretrained(clip_model_name)
        self.processor = CLIPProcessor.from_pretrained(clip_model_name)
        self.generator = tf.keras.models.load_model(generator_path)
        self.noise_dim = noise_dim
        self.image_size = tuple(image_size)

    def generate_image(self, prompt):
        """Generate an image for ``prompt``.

        Args:
            prompt: The text to condition the generator on.

        Returns:
            bytes: The first generated image, JPEG-encoded.
        """
        # CLIPModel is a PyTorch model, so the processor must return PyTorch
        # tensors. Truncation keeps over-long prompts within CLIP's fixed
        # text length instead of failing inside the text encoder.
        encoded_prompt = self.processor(
            text=prompt, return_tensors="pt", padding=True, truncation=True
        )
        # get_text_features expects the token tensors as keyword arguments,
        # not the processor's dict as a single positional argument.
        text_features = self.clip.get_text_features(
            input_ids=encoded_prompt["input_ids"],
            attention_mask=encoded_prompt.get("attention_mask"),
        )
        # Hand a plain numpy array to Keras. detach() drops the autograd
        # graph, which .numpy() would otherwise refuse to convert.
        text_features = text_features.detach().cpu().numpy()
        # One noise row per encoded prompt, so the two generator inputs
        # always have the same batch size.
        noise = tf.random.normal([text_features.shape[0], self.noise_dim])
        image_features = self.generator([text_features, noise], training=False)[0]
        return self._postprocess_image(image_features)

    @staticmethod
    def _to_uint8(image_features):
        """Map values in ``[-1, 1]`` to ``uint8`` pixels in ``[0, 255]``.

        Out-of-range values are clipped. Values are rounded rather than
        truncated. A trailing channel axis of size 1 is dropped so Pillow
        treats the result as a single-channel (grayscale) image.
        """
        pixels = np.asarray(image_features, dtype=np.float32)
        pixels = np.clip((pixels + 1.0) / 2.0, 0.0, 1.0)  # [-1, 1] -> [0, 1]
        if pixels.ndim == 3 and pixels.shape[-1] == 1:
            pixels = pixels[..., 0]
        return np.round(pixels * 255.0).astype(np.uint8)

    def _postprocess_image(self, image_features):
        """Convert one generator output into JPEG bytes.

        Args:
            image_features: One generated image with values expected in
                ``[-1, 1]``. Presumably laid out as ``(H, W, C)``; confirm
                against the generator's output.

        Returns:
            bytes: The image resized to ``self.image_size`` and JPEG-encoded.
        """
        image = Image.fromarray(self._to_uint8(image_features))
        # JPEG only supports modes such as L/RGB. Drop alpha and other extra
        # channels before saving.
        if image.mode not in ("L", "RGB"):
            image = image.convert("RGB")
        image = image.resize(self.image_size)
        with BytesIO() as image_buffer:
            image.save(image_buffer, format='JPEG')
            return image_buffer.getvalue()