Spaces:
Runtime error
Runtime error
Update utils.py
Browse files
utils.py
CHANGED
@@ -66,7 +66,6 @@ def layout(*args):
|
|
66 |
st.markdown(str(foot), unsafe_allow_html=True)
|
67 |
|
68 |
|
69 |
-
|
70 |
def footer():
|
71 |
myargs = [
|
72 |
"Created by ",
|
@@ -96,7 +95,6 @@ def footer():
|
|
96 |
height=600,
|
97 |
)
|
98 |
|
99 |
-
|
100 |
model = False
|
101 |
def generate(prompt,crazy,k):
|
102 |
global model
|
@@ -113,7 +111,11 @@ def generate(prompt,crazy,k):
|
|
113 |
set_seed(np.random.randint(0,10000))
|
114 |
|
115 |
# Sampling
|
116 |
-
|
|
|
|
|
|
|
|
|
117 |
top_k=2048,
|
118 |
top_p=None,
|
119 |
softmax_temperature=crazy,
|
@@ -124,7 +126,7 @@ def generate(prompt,crazy,k):
|
|
124 |
# CLIP Re-ranking
|
125 |
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
|
126 |
model_clip.to(device=device)
|
127 |
-
rank = clip_score(prompt=
|
128 |
images=images,
|
129 |
model_clip=model_clip,
|
130 |
preprocess_clip=preprocess_clip,
|
@@ -143,35 +145,37 @@ def generate(prompt,crazy,k):
|
|
143 |
|
144 |
def drawGrid():
|
145 |
master = {}
|
146 |
-
order = 0
|
147 |
-
|
148 |
-
#print(st.session_state.results)
|
149 |
|
150 |
for r in st.session_state.results[::-1]:
|
151 |
_txt = r['prompt']+" "+str(r['crazy'])+" "+str(r['k'])
|
152 |
-
|
153 |
if(_txt not in master):
|
154 |
master[_txt] = [r]
|
155 |
-
order += 1
|
156 |
else:
|
157 |
master[_txt].append(r)
|
158 |
|
159 |
-
|
160 |
-
for
|
161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
|
|
177 |
|
|
|
66 |
st.markdown(str(foot), unsafe_allow_html=True)
|
67 |
|
68 |
|
|
|
69 |
def footer():
|
70 |
myargs = [
|
71 |
"Created by ",
|
|
|
95 |
height=600,
|
96 |
)
|
97 |
|
|
|
98 |
model = False
|
99 |
def generate(prompt,crazy,k):
|
100 |
global model
|
|
|
111 |
set_seed(np.random.randint(0,10000))
|
112 |
|
113 |
# Sampling
|
114 |
+
newPrompt = prompt
|
115 |
+
if("architecture" not in prompt.lower() ):
|
116 |
+
newPrompt += " architecture"
|
117 |
+
|
118 |
+
images = model.sampling(prompt=newPrompt,
|
119 |
top_k=2048,
|
120 |
top_p=None,
|
121 |
softmax_temperature=crazy,
|
|
|
126 |
# CLIP Re-ranking
|
127 |
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
|
128 |
model_clip.to(device=device)
|
129 |
+
rank = clip_score(prompt=newPrompt,
|
130 |
images=images,
|
131 |
model_clip=model_clip,
|
132 |
preprocess_clip=preprocess_clip,
|
|
|
145 |
|
146 |
def drawGrid():
|
147 |
master = {}
|
|
|
|
|
|
|
148 |
|
149 |
for r in st.session_state.results[::-1]:
|
150 |
_txt = r['prompt']+" "+str(r['crazy'])+" "+str(r['k'])
|
|
|
151 |
if(_txt not in master):
|
152 |
master[_txt] = [r]
|
|
|
153 |
else:
|
154 |
master[_txt].append(r)
|
155 |
|
156 |
+
|
157 |
+
for i in st.session_state.images:
|
158 |
+
im = st.empty()
|
159 |
+
|
160 |
+
|
161 |
+
placeholder = st.empty()
|
162 |
+
with placeholder.container():
|
163 |
+
|
164 |
+
for m in master:
|
165 |
|
166 |
+
txt = master[m][0]['prompt']+" (temperature:"+ str(master[m][0]['crazy']) + ", top k:" + str(master[m][0]['k']) + ")"
|
167 |
+
st.subheader(txt)
|
168 |
+
col1, col2, col3 = st.columns(3)
|
169 |
+
|
170 |
+
for ix, item in enumerate(master[m]):
|
171 |
+
if ix % 3 == 0:
|
172 |
+
with col1:
|
173 |
+
st.session_state.images.append(st.image(item["image"]))
|
174 |
+
if ix % 3 == 1:
|
175 |
+
with col2:
|
176 |
+
st.session_state.images.append(st.image(item["image"]))
|
177 |
+
if ix % 3 == 2:
|
178 |
+
with col3:
|
179 |
+
st.session_state.images.append(st.image(item["image"]))
|
180 |
+
|
181 |
|