import base64
import io
import os

import gradio as gr
import numpy as np
import replicate
import requests
from dotenv import find_dotenv, load_dotenv
from PIL import Image

# Load .env so the replicate client can pick up REPLICATE_API_TOKEN from the
# environment (the client library reads it itself; we also keep a copy here).
load_dotenv(find_dotenv())
REPLICATE_API_TOKEN = os.getenv('REPLICATE_API_TOKEN')

# Longest allowed side (px) for the img2img starter image before upload.
MAX_STARTER_DIM = 512
# Number of images requested from (and returned by) the model.
NUM_OUTPUTS = 3
MODEL_ID = "ltejedor/cmf:3af83ef60d86efbf374edb788fa4183a6067416e2fadafe709350dc1efe37d1d"
NEGATIVE_PROMPT = "worst quality, low quality, illustration, 2d, painting, cartoons, sketch"


def _encode_starter_image(starter_image):
    """Downscale *starter_image* (numpy array) to at most MAX_STARTER_DIM on its
    longest side, JPEG-encode it, and return the base64 string.

    Aspect ratio is preserved; images already within bounds are not resized.
    """
    # .convert('RGB') guards against RGBA/grayscale arrays — JPEG cannot
    # encode an alpha channel and would raise otherwise.
    pil_image = Image.fromarray(starter_image.astype('uint8')).convert('RGB')
    width, height = pil_image.size
    if width > MAX_STARTER_DIM or height > MAX_STARTER_DIM:
        # Scale so the longest side becomes exactly MAX_STARTER_DIM.
        scale = MAX_STARTER_DIM / max(width, height)
        new_size = (int(width * scale), int(height * scale))
        pil_image = pil_image.resize(new_size, Image.LANCZOS)
    buffered = io.BytesIO()
    pil_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def image_classifier(prompt, starter_image, image_strength):
    """Generate NUM_OUTPUTS images from *prompt* via the Replicate model.

    Parameters
    ----------
    prompt : str
        User prompt; the LoRA trigger suffix " in the style of TOK" is appended.
    starter_image : numpy.ndarray | None
        Optional img2img seed image (as supplied by the Gradio image input).
    image_strength : float
        0..1 slider value; higher means the starter image dominates more
        (the model's prompt_strength is 1 - image_strength).

    Returns
    -------
    list[PIL.Image.Image]
        Exactly NUM_OUTPUTS images; gray placeholders pad a short response.

    Raises
    ------
    requests.HTTPError
        If downloading a generated image fails.
    """
    # Named model_input (not `input`) to avoid shadowing the builtin.
    model_input = {
        "prompt": prompt + " in the style of TOK",
        "negative_prompt": NEGATIVE_PROMPT,
        #"refine": "expert_ensemble_refiner",
        "apply_watermark": False,
        "num_inference_steps": 50,
        "num_outputs": NUM_OUTPUTS,
        "lora_scale": .96,
    }
    if starter_image is not None:
        # img2img mode: send the (resized, base64) starter image.
        model_input["image"] = "data:image/jpeg;base64," + _encode_starter_image(starter_image)
        model_input["prompt_strength"] = 1 - image_strength
    else:
        # txt2img mode: explicit output resolution instead of a seed image.
        model_input["width"] = 1024
        model_input["height"] = 1024

    output = replicate.run(MODEL_ID, input=model_input)
    print(output)

    images = []
    for image_url in output[:NUM_OUTPUTS]:
        # Timeout + raise_for_status so a failed download surfaces as a clear
        # HTTP error instead of PIL choking on an error page's bytes.
        response = requests.get(image_url, timeout=60)
        response.raise_for_status()
        images.append(Image.open(io.BytesIO(response.content)))

    # Pad with gray placeholders if the model returned fewer than expected.
    while len(images) < NUM_OUTPUTS:
        images.append(Image.new('RGB', (512, 512), 'gray'))

    return images


demo = gr.Interface(
    fn=image_classifier,
    inputs=[
        "text",
        "image",
        gr.Slider(0, 1, step=0.025, value=0.2, label="Image Strength"),
    ],
    outputs=["image", "image", "image"],
)

# Standard Gradio pattern: only launch when run as a script, so the module
# can be imported (e.g. by `gradio app.py` hot-reload) without side effects.
if __name__ == "__main__":
    demo.launch(share=False)