Spaces:
Runtime error
Runtime error
from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts | |
from htbuilder.units import percent, px | |
from htbuilder.funcs import rgba, rgb | |
import streamlit as st | |
import os | |
import sys | |
import argparse | |
import clip | |
import numpy as np | |
from PIL import Image | |
from dalle.models import Dalle | |
from dalle.utils.utils import set_seed, clip_score | |
def link(link, text, **style):
    """Build an anchor element that opens ``link`` in a new browser tab.

    Args:
        link: URL assigned to the anchor's ``href`` attribute.
        text: Content placed inside the anchor.
        **style: CSS properties forwarded to ``styles`` to form the inline
            ``style`` attribute.

    Returns:
        The result of calling the configured ``a`` builder with ``text``.
    """
    inline_css = styles(**style)
    anchor = a(_href=link, _target="_blank", style=inline_css)
    return anchor(text)
def layout(*args):
    """Render a fixed, full-width footer at the bottom of the Streamlit page.

    Two ``st.markdown`` calls are made with ``unsafe_allow_html=True``: the
    first injects a ``<style>`` block, the second injects the footer markup
    (a ``div`` holding an ``hr`` and a ``p`` with the given content).

    Args:
        *args: Footer content in display order. Only ``str`` and
            ``HtmlElement`` values are added to the footer paragraph; any
            other value is silently skipped.
    """
    # Written as explicit column-0 lines so the markup does not depend on
    # how the Markdown renderer treats leading indentation.
    # "#MainMenu" is an id selector; the previous "# MainMenu" (with a space
    # after '#') is not a valid selector, so that rule was being discarded.
    page_css = (
        "\n"
        "<style>\n"
        "#MainMenu {visibility: hidden;}\n"
        "footer {visibility: hidden;}\n"
        ".stApp { bottom: 105px; }\n"
        "</style>\n"
    )

    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        color="black",
        text_align="center",
        height="auto",
        opacity=1,
    )

    style_hr = styles(
        display="block",
        # NOTE(review): "auto" is passed through px() here; confirm that the
        # CSS htbuilder emits for this margin is what was intended.
        margin=px(8, 8, "auto", "auto"),
        border_style="inset",
        border_width=px(2),
    )

    # The paragraph is attached to the footer first and filled afterwards,
    # keeping the original order; this relies on calling an element to add
    # children to that same element.
    body = p()
    foot = div(style=style_div)(
        hr(style=style_hr),
        body,
    )

    st.markdown(page_css, unsafe_allow_html=True)

    for arg in args:
        if isinstance(arg, (str, HtmlElement)):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)
def footer():
    """Render the attribution footer by handing its pieces to ``layout``.

    The pieces are plain strings, links built with ``link`` and ``br()``
    line breaks, passed in the order they should appear.
    """
    layout(
        "Created by ",
        link("https://jonathanmalott.com", "Jonathan Malott"),
        br(),
        link("https://bridgingbarriers.utexas.edu/good-systems", "Good Systems"),
        " Grand Challenge",
        ", The University of Texas at Austin.",
        " Advised by Dr. Junfeng Jiao.",
        br(),
        br(),
    )
#footer() | |
def generate(prompt, crazy, k):
    """Sample one image for ``prompt`` and append it to the session results.

    The model and the CLIP scorer are loaded on every call and run on the
    CPU. The sampler is seeded with a fresh random integer each time.

    Args:
        prompt: Text prompt passed to the sampler and to CLIP scoring.
        crazy: Value passed to the sampler as ``softmax_temperature``.
        k: Value passed to the sampler as ``top_k``. It is also stored with
            the result, where ``drawGrid`` displays it as "top k".
            Previously the sampler ignored it and always used 2048.

    Side effects:
        Prints progress markers to stdout and appends a dict with keys
        ``prompt``, ``crazy``, ``k`` and ``image`` (a PIL image) to
        ``st.session_state.results``, which must already exist.
    """
    device = 'cpu'
    num_candidates = 1

    print("-2-")
    # This will automatically download the pretrained model.
    model = Dalle.from_pretrained('minDALL-E/1.3B')
    print("-3-")
    model.to(device=device)

    set_seed(np.random.randint(0, 10000))

    # Sampling
    images = model.sampling(prompt=prompt,
                            top_k=k,
                            top_p=None,
                            softmax_temperature=crazy,
                            num_candidates=num_candidates,
                            device=device).cpu().numpy()
    # Move axis 1 to the end (presumably channels-first -> channels-last,
    # which is what Image.fromarray needs below; confirm against the model).
    images = np.transpose(images, (0, 2, 3, 1))

    # CLIP Re-ranking
    model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
    model_clip.to(device=device)
    rank = clip_score(prompt=prompt,
                      images=images,
                      model_clip=model_clip,
                      preprocess_clip=preprocess_clip,
                      device=device)

    result = images[rank]
    # If ``rank`` is an array of indices the selection is still a batch
    # (assumed ordered best-first - TODO confirm against clip_score); keep
    # only its first image so Image.fromarray receives a single picture.
    if result.ndim == 4:
        result = result[0]

    item = {
        'prompt': prompt,
        'crazy': crazy,
        'k': k,
        # Scaling by 255 assumes pixel values in [0, 1] - TODO confirm.
        'image': Image.fromarray((result * 255).astype(np.uint8)),
    }
    st.session_state.results.append(item)
def drawGrid():
    """Draw every stored result, grouped by its generation settings.

    Results in ``st.session_state.results`` are visited newest first and
    grouped by prompt, temperature and top-k value. Each group gets a
    subheader and its images are dealt round-robin into three columns.
    """
    grouped = {}
    for entry in reversed(st.session_state.results):
        key = entry['prompt'] + f" {entry['crazy']} {entry['k']}"
        # dicts keep insertion order, so groups appear in the order their
        # newest member was encountered.
        grouped.setdefault(key, []).append(entry)

    for members in grouped.values():
        first = members[0]
        heading = (
            first['prompt']
            + f" (temperature:{first['crazy']}, top k:{first['k']})"
        )
        st.subheader(heading)

        columns = st.columns(3)
        for position, item in enumerate(members):
            # position % 3 selects the first, second or third column in turn.
            with columns[position % 3]:
                st.image(item["image"])