# cmf-fine-tuned / index.py
import os
import base64
import numpy as np
from PIL import Image
import io
import requests
import gradio as gr
import replicate
from dotenv import load_dotenv, find_dotenv

# Locate and load the .env file containing the Replicate API token
dotenv_path = find_dotenv()
load_dotenv(dotenv_path)
REPLICATE_API_TOKEN = os.getenv('REPLICATE_API_TOKEN')
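
# Optional guard (a sketch, not part of the original script): replicate.run()
# reads REPLICATE_API_TOKEN from the process environment, so failing fast here
# avoids a confusing error later at generation time.
# if not REPLICATE_API_TOKEN:
#     raise RuntimeError("REPLICATE_API_TOKEN is not set; add it to a .env file")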


def image_classifier(prompt, starter_image, image_strength):
    if starter_image is not None:
        starter_image_pil = Image.fromarray(starter_image.astype('uint8'))

        # Resize the starter image if either dimension is larger than 512 pixels
        if starter_image_pil.size[0] > 512 or starter_image_pil.size[1] > 512:
            # Calculate the new size while maintaining the aspect ratio
            if starter_image_pil.size[0] > starter_image_pil.size[1]:
                # Width is larger than height
                new_width = 512
                new_height = int((512 / starter_image_pil.size[0]) * starter_image_pil.size[1])
            else:
                # Height is larger than (or equal to) width
                new_height = 512
                new_width = int((512 / starter_image_pil.size[1]) * starter_image_pil.size[0])

            # Resize the image
            starter_image_pil = starter_image_pil.resize((new_width, new_height), Image.LANCZOS)

        # Save the starter image to a bytes buffer
        buffered = io.BytesIO()
        starter_image_pil.save(buffered, format="JPEG")

        # Encode the starter image to base64
        starter_image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

    if starter_image is not None:
        # Image-to-image: pass the starter image; a higher Image Strength slider
        # value means a lower prompt_strength, i.e. the output stays closer to the starter image
        input = {
            "prompt": prompt + " in the style of TOK",
            "negative_prompt": "worst quality, low quality, illustration, 2d, painting, cartoons, sketch",
            #"refine": "expert_ensemble_refiner",
            "apply_watermark": False,
            "num_inference_steps": 50,
            "num_outputs": 3,
            "lora_scale": .96,
            "image": "data:image/jpeg;base64," + starter_image_base64,
            "prompt_strength": 1 - image_strength,
        }
    else:
        # Text-to-image: no starter image, so generate a 1024x1024 image from the prompt alone
        input = {
            "width": 1024,
            "height": 1024,
            "prompt": prompt + " in the style of TOK",
            "negative_prompt": "worst quality, low quality, illustration, 2d, painting, cartoons, sketch",
            #"refine": "expert_ensemble_refiner",
            "apply_watermark": False,
            "num_inference_steps": 50,
            "num_outputs": 3,
            "lora_scale": .96,
        }

    output = replicate.run(
        # update to new trained model
        "ltejedor/cmf:3af83ef60d86efbf374edb788fa4183a6067416e2fadafe709350dc1efe37d1d",
        input=input
    )
    print(output)

    # Download up to three generated images from the returned URLs
    images = []
    for i in range(min(len(output), 3)):
        image_url = output[i]
        response = requests.get(image_url)
        images.append(Image.open(io.BytesIO(response.content)))

    # Add empty images if fewer than 3 were returned
    while len(images) < 3:
        images.append(Image.new('RGB', (512, 512), 'gray'))

    return images
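
# Example of calling the generation function outside Gradio (a sketch; the file
# name and prompt below are placeholders, not part of the original script):
# starter = np.array(Image.open("starter.jpg"))
# results = image_classifier("a matte black desk lamp", starter, 0.2)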

demo = gr.Interface(
    fn=image_classifier,
    inputs=["text", "image", gr.Slider(0, 1, step=0.025, value=0.2, label="Image Strength")],
    outputs=["image", "image", "image"],
)
demo.launch(share=False)