File size: 4,143 Bytes
99caaea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import io
import json
import os

import gradio as gr
import pinecone
import torch
from diffusers import StableDiffusionPipeline
from google.cloud import storage
from google.cloud.exceptions import NotFound
from PIL import Image

# create Storage Cloud credentials
# The service-account key is rebuilt from one environment variable per JSON
# field (a missing variable raises KeyError at import time), written to
# cloud-storage.json, and handed to the Storage client through
# GOOGLE_APPLICATION_CREDENTIALS.
G_API = {
    key: os.environ[key]
    for key in (
        "type",
        "project_id",
        "private_key_id",
        "private_key",
        "client_email",
        "client_id",
        "auth_uri",
        "token_uri",
        "auth_provider_x509_cert_url",
        "client_x509_cert_url",
    )
}
# The file contains a private key, so create it owner read/write only (0o600)
# rather than with the default umask-derived (often world-readable) mode.
# NOTE: the mode is only applied when the file is newly created.
with os.fdopen(
    os.open('cloud-storage.json', os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600),
    'w',
    encoding='utf-8',
) as fp:
    json.dump(G_API, fp)
del G_API
# connect to Cloud Storage
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'
storage_client = storage.Client()
bucket = storage_client.get_bucket('hf-diffusion-images')
# Name of the Pinecone index searched by prompt_query / prompt_image; its
# matches carry 'prompt' and 'image_url' metadata.
index_id = "hf-diffusion"

# Pinecone API key, taken from the environment (KeyError if unset).
PINECONE_KEY = os.environ['PINECONE_KEY']

# Open the Pinecone connection and fail fast if the index does not exist.
pinecone.init(api_key=PINECONE_KEY, environment="us-west1-gcp")
if index_id not in pinecone.list_indexes():
    raise ValueError(f"Index '{index_id}' not found")

index = pinecone.Index(index_id)

# All model work in this file runs on the CPU.
device = 'cpu'

# Load the Stable Diffusion v1.4 pipeline and place it on `device`. Only its
# tokenizer and text encoder are used in this file (by encode_text), to turn
# prompts into query vectors.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,
)
pipe.to(device)

def encode_text(text: str):
    """Embed *text* with the pipeline's text encoder.

    Args:
        text: Prompt to embed. Inputs longer than the tokenizer's
            ``model_max_length`` are truncated; without truncation an
            over-long prompt makes the text encoder fail (presumably on its
            fixed-size position embeddings).

    Returns:
        The encoder's ``pooler_output`` for the single input, as a flat
        Python list of floats, used as the Pinecone query vector.
    """
    tokenizer = pipe.tokenizer
    text_inputs = tokenizer(
        text,
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors='pt',
    ).to(device)
    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        text_embeds = pipe.text_encoder(**text_inputs)
    # [0] drops the batch dimension of the single encoded prompt.
    return text_embeds.pooler_output.cpu().tolist()[0]

def prompt_query(text: str):
    """Suggest up to five distinct stored prompts similar to *text*.

    Looks up the 30 nearest vectors in the Pinecone index, keeps the first
    occurrence of each distinct 'prompt' metadata value (in match order), and
    returns the first five wrapped as one-element rows, the sample format
    used by the suggestions Dataset.
    """
    query_vector = encode_text(text)
    response = index.query(query_vector, top_k=30, include_metadata=True)

    seen = set()
    unique_prompts = []
    for hit in response['matches']:
        candidate = hit['metadata']['prompt']
        if candidate in seen:
            continue
        seen.add(candidate)
        unique_prompts.append(candidate)

    return [[candidate] for candidate in unique_prompts[:5]]

def get_image(url: str):
    """Download the blob named *url* from the bucket and open it as an image.

    Args:
        url: Object name inside the Cloud Storage bucket (it is passed to
            ``bucket.blob`` as-is).

    Returns:
        A PIL image backed by an in-memory copy of the blob's bytes.

    Errors from the download or from ``Image.open`` propagate to the caller.
    """
    # download_as_bytes replaces the deprecated download_as_string.
    data = bucket.blob(url).download_as_bytes()
    return Image.open(io.BytesIO(data))

def prompt_image(text: str):
    """Return the stored images whose embeddings are closest to *text*.

    Queries the Pinecone index for the 9 nearest matches and loads each
    match's 'image_url' blob via get_image. An image that cannot be fetched
    or decoded is reported on stdout and skipped, so the result may hold
    fewer than 9 images.
    """
    embeds = encode_text(text)
    xc = index.query(embeds, top_k=9, include_metadata=True)
    image_urls = [match['metadata']['image_url'] for match in xc['matches']]
    images = []
    for image_url in image_urls:
        try:
            images.append(get_image(image_url))
        except (ValueError, OSError, NotFound) as err:
            # Best effort: one bad entry should not empty the whole gallery.
            # OSError covers PIL.UnidentifiedImageError (undecodable bytes);
            # NotFound covers objects missing from the bucket.
            print(f"error for '{image_url}': {err}")
    return images

# __APP FUNCTIONS__

def set_suggestion(text: list):
    """Copy a clicked suggestion into the prompt box.

    Args:
        text: The clicked sample row of the suggestions Dataset. Each row
            holds one value per Dataset component; here there is only the
            prompt, so the prompt string is ``text[0]``. (The previous ``str``
            annotation was wrong: indexing a str would yield one character.)

    Returns:
        A TextArea update setting the prompt to the suggestion.
    """
    return gr.TextArea.update(value=text[0])

def set_images(text: str):
    """Fill the results gallery with images matching the prompt *text*."""
    return gr.Gallery.update(value=prompt_image(text))

# __CREATE APP__
# Layout: a left column holding the prompt box, the search button and the
# clickable prompt suggestions; a right column holding the image gallery.
# Components are placed in the order they are created inside the contexts.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Dream Cacher
        """
    )
    with gr.Row():
        # query column
        with gr.Column():
            prompt = gr.TextArea(
                value="A dream about a cat",
                placeholder="Enter a prompt to dream about",
                interactive=True,
            )
            search = gr.Button(value="Search!")
            suggestions = gr.Dataset(
                components=[prompt],
                samples=[["Something"], ["something else"]],
            )
            # refresh the suggestions whenever the prompt text changes
            prompt.change(fn=prompt_query, inputs=prompt, outputs=suggestions)
            # clicking a suggestion copies it into the prompt box
            suggestions.click(
                fn=set_suggestion,
                inputs=suggestions,
                outputs=suggestions.components,
            )

        # results column
        with gr.Column():
            pics = gr.Gallery()
            pics.style(grid=3)
            # run the image search when the button is pressed
            search.click(fn=set_images, inputs=prompt, outputs=pics)

demo.launch()