jamescalam commited on
Commit
99caaea
1 Parent(s): da210e8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -0
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import StableDiffusionPipeline
2
+ import torch
3
+ import io
4
+ from PIL import Image
5
+ import os
6
+ from google.cloud import storage
7
+ import pinecone
8
+
9
+ # create Storage Cloud credentials
10
+ G_API = {
11
+ "type": os.environ["type"],
12
+ "project_id": os.environ["project_id"],
13
+ "private_key_id": os.environ["private_key_id"],
14
+ "private_key": os.environ["private_key"],
15
+ "client_email": os.environ["client_email"],
16
+ "client_id": os.environ["client_id"],
17
+ "auth_uri": os.environ["auth_uri"],
18
+ "token_uri": os.environ["token_uri"],
19
+ "auth_provider_x509_cert_url": os.environ["auth_provider_x509_cert_url"],
20
+ "client_x509_cert_url": os.environ["client_x509_cert_url"]
21
+ }
22
+ with open('cloud-storage.json', 'w', encoding='utf-8') as fp:
23
+ fp.write(json.dumps(G_API))
24
+ del G_API
25
+ # connect to Cloud Storage
26
+ os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'
27
+ storage_client = storage.Client()
28
+ bucket = storage_client.get_bucket('hf-diffusion-images')
29
+
30
+ # get api key for pinecone auth
31
+ PINECONE_KEY = os.environ['PINECONE_KEY']
32
+
33
+ index_id = "hf-diffusion"
34
+
35
+ # init connection to pinecone
36
+ pinecone.init(
37
+ api_key=PINECONE_KEY,
38
+ environment="us-west1-gcp"
39
+ )
40
+ if index_id not in pinecone.list_indexes():
41
+ raise ValueError(f"Index '{index_id}' not found")
42
+
43
+ index = pinecone.Index(index_id)
44
+
45
+ device = 'cpu'
46
+
47
+ # init all of the models and move them to a given GPU
48
+ pipe = StableDiffusionPipeline.from_pretrained(
49
+ "CompVis/stable-diffusion-v1-4", use_auth_token=True
50
+ )
51
+ pipe.to(device)
52
+
53
+ def encode_text(text: str):
54
+ text_inputs = pipe.tokenizer(
55
+ text, return_tensors='pt'
56
+ ).to(device)
57
+ text_embeds = pipe.text_encoder(**text_inputs)
58
+ text_embeds = text_embeds.pooler_output.cpu().tolist()[0]
59
+ return text_embeds
60
+
61
+ def prompt_query(text: str):
62
+ embeds = encode_text(text)
63
+ xc = index.query(embeds, top_k=30, include_metadata=True)
64
+ prompts = [
65
+ match['metadata']['prompt'] for match in xc['matches']
66
+ ]
67
+ # deduplicate while preserving order
68
+ prompts = list(dict.fromkeys(prompts))
69
+ return [[x] for x in prompts[:5]]
70
+
71
+ def get_image(url: str):
72
+ blob = bucket.blob(url).download_as_string()
73
+ blob_bytes = io.BytesIO(blob)
74
+ im = Image.open(blob_bytes)
75
+ return im
76
+
77
+ def prompt_image(text: str):
78
+ embeds = encode_text(text)
79
+ xc = index.query(embeds, top_k=9, include_metadata=True)
80
+ image_urls = [
81
+ match['metadata']['image_url'] for match in xc['matches']
82
+ ]
83
+ images = []
84
+ for image_url in image_urls:
85
+ try:
86
+ blob = bucket.blob(image_url).download_as_string()
87
+ blob_bytes = io.BytesIO(blob)
88
+ im = Image.open(blob_bytes)
89
+ images.append(im)
90
+ except ValueError:
91
+ print(f"error for '{image_url}'")
92
+ return images
93
+
94
+ # __APP FUNCTIONS__
95
+
96
+ def set_suggestion(text: str):
97
+ return gr.TextArea.update(value=text[0])
98
+
99
+ def set_images(text: str):
100
+ images = prompt_image(text)
101
+ return gr.Gallery.update(value=images)
102
+
103
+ # __CREATE APP__
104
+ demo = gr.Blocks()
105
+
106
+ with demo:
107
+ gr.Markdown(
108
+ """
109
+ # Dream Cacher
110
+ """
111
+ )
112
+ with gr.Row():
113
+ with gr.Column():
114
+ prompt = gr.TextArea(
115
+ value="A dream about a cat",
116
+ placeholder="Enter a prompt to dream about",
117
+ interactive=True
118
+ )
119
+ search = gr.Button(value="Search!")
120
+ suggestions = gr.Dataset(
121
+ components=[prompt],
122
+ samples=[
123
+ ["Something"],
124
+ ["something else"]
125
+ ]
126
+ )
127
+ # event listener for change in prompt
128
+ prompt.change(prompt_query, prompt, suggestions)
129
+ # event listener for click on suggestion
130
+ suggestions.click(
131
+ set_suggestion,
132
+ suggestions,
133
+ suggestions.components
134
+ )
135
+
136
+
137
+ # results column
138
+ with gr.Column():
139
+ pics = gr.Gallery()
140
+ pics.style(grid=3)
141
+ # search event listening
142
+ search.click(set_images, prompt, pics)
143
+
144
+ demo.launch()