File size: 3,344 Bytes
367eecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4ed719
367eecb
 
0a1f600
367eecb
 
 
 
 
 
 
0a1f600
 
367eecb
d4ed719
367eecb
 
0a1f600
367eecb
 
 
 
 
 
 
 
 
 
 
 
866d4b9
 
 
 
 
 
 
 
 
367eecb
866d4b9
367eecb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import base64
import numpy as np
from PIL import Image
import io
import requests
import gradio as gr

import replicate

from dotenv import load_dotenv, find_dotenv

# Resolve the path of a .env file via python-dotenv and load its variables
# into the process environment.
dotenv_path = find_dotenv()
load_dotenv(dotenv_path=dotenv_path)

# Replicate API credential, read from the (now populated) environment.
# NOTE(review): this name is not referenced elsewhere in the visible code; the
# `replicate` client presumably reads REPLICATE_API_TOKEN from os.environ on
# its own — confirm before removing.
REPLICATE_API_TOKEN = os.environ.get('REPLICATE_API_TOKEN')


# Number of images requested from the model; the Gradio UI shows this many slots.
_NUM_OUTPUTS = 3
# Starter images larger than this (in either dimension) are downscaled before upload.
_MAX_STARTER_SIDE = 512
# Size of the gray placeholder used when the model returns fewer images.
_PLACEHOLDER_SIZE = (512, 512)
# Seconds to wait for each generated image download before giving up.
_DOWNLOAD_TIMEOUT_S = 60
_NEGATIVE_PROMPT = "worst quality, low quality, illustration, 2d, painting, cartoons, sketch"
# NOTE: update to the new trained model version when one is available.
_MODEL_VERSION = "ltejedor/cmf:3af83ef60d86efbf374edb788fa4183a6067416e2fadafe709350dc1efe37d1d"


def _fit_within(width, height, max_side):
    """Return ``(width, height)`` shrunk so neither side exceeds ``max_side``.

    Sizes already within ``max_side`` in both dimensions are returned unchanged.
    Otherwise the longer side becomes ``max_side`` and the other side is scaled
    to preserve the aspect ratio (truncated, but never below 1 pixel).
    """
    if width <= max_side and height <= max_side:
        return width, height
    # Multiply before dividing so the integer product is exact and the single
    # division does not truncate e.g. 512.0 down to 511.
    if width > height:
        return max_side, max(1, int(max_side * height / width))
    return max(1, int(max_side * width / height)), max_side


def _encode_starter_image(starter_image, max_side=_MAX_STARTER_SIDE):
    """Return ``starter_image`` as a ``data:image/jpeg;base64,...`` URI.

    The array is cast to ``uint8``, converted to an RGB PIL image, downscaled to
    fit within ``max_side`` pixels if needed, and JPEG-encoded.
    """
    pil_image = Image.fromarray(starter_image.astype('uint8'))
    # JPEG cannot store alpha/palette modes (e.g. RGBA, LA, P); normalise to RGB.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    new_size = _fit_within(pil_image.size[0], pil_image.size[1], max_side)
    if new_size != pil_image.size:
        pil_image = pil_image.resize(new_size, Image.LANCZOS)

    # Encode for every starter image, not only resized ones.
    buffered = io.BytesIO()
    pil_image.save(buffered, format="JPEG")
    encoded = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return "data:image/jpeg;base64," + encoded


def _build_model_input(prompt, starter_image, image_strength):
    """Return the ``input`` dict for the Replicate model call.

    Without a starter image the request is text-to-image at 1024x1024; with one,
    the encoded image and a ``prompt_strength`` are added instead.
    """
    model_input = {
        "prompt": (prompt or "") + " in the style of TOK",
        "negative_prompt": _NEGATIVE_PROMPT,
        # "refine": "expert_ensemble_refiner",  # disabled option, kept for reference
        "apply_watermark": False,
        "num_inference_steps": 50,
        "num_outputs": _NUM_OUTPUTS,
        "lora_scale": .96,
    }
    if starter_image is None:
        model_input["width"] = 1024
        model_input["height"] = 1024
    else:
        model_input["image"] = _encode_starter_image(starter_image)
        # The UI slider is labelled "Image Strength"; the model receives its
        # complement as prompt_strength.
        model_input["prompt_strength"] = 1 - image_strength
    return model_input


def _download_images(output, count=_NUM_OUTPUTS):
    """Download up to ``count`` images from the URLs in ``output``.

    Raises ``requests.HTTPError`` if a download returns an error status. The
    result is padded with gray placeholders so it always has ``count`` items.
    """
    images = []
    for image_url in list(output or [])[:count]:
        response = requests.get(image_url, timeout=_DOWNLOAD_TIMEOUT_S)
        response.raise_for_status()
        images.append(Image.open(io.BytesIO(response.content)))
    # The UI always has `count` image slots, so fill any that are missing.
    images.extend(
        Image.new('RGB', _PLACEHOLDER_SIZE, 'gray') for _ in range(count - len(images))
    )
    return images


def image_classifier(prompt, starter_image, image_strength):
    """Generate images for ``prompt`` with the fine-tuned Replicate model.

    Despite its name, this function generates images rather than classifying them.

    Args:
        prompt: Text prompt; " in the style of TOK" is appended to it.
        starter_image: Optional image array from Gradio. When given, the request
            is image-to-image; otherwise it is text-to-image at 1024x1024.
        image_strength: Slider value in [0, 1]; sent as ``1 - image_strength``
            for ``prompt_strength`` when a starter image is given.

    Returns:
        A list of exactly 3 PIL images (gray placeholders fill missing slots).
    """
    model_input = _build_model_input(prompt, starter_image, image_strength)
    output = replicate.run(_MODEL_VERSION, input=model_input)
    print(output)  # debug: raw model output (image URLs)
    return _download_images(output)

# Gradio UI: a prompt box, an optional starter image and an "Image Strength"
# slider in; three image panels out (matching the three images returned).
demo = gr.Interface(
    fn=image_classifier,
    inputs=[
        "text",
        "image",
        gr.Slider(0, 1, step=0.025, value=0.2, label="Image Strength"),
    ],
    outputs=["image"] * 3,
)
# Start the app (runs at module level); share=False requests no public share link.
demo.launch(share=False)