# Spaces: Runtime error
from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts | |
from htbuilder.units import percent, px | |
from htbuilder.funcs import rgba, rgb | |
import streamlit as st | |
import os | |
import sys | |
import argparse | |
import clip | |
import numpy as np | |
from PIL import Image | |
from dalle.models import Dalle | |
from dalle.utils.utils import set_seed, clip_score | |
import streamlit.components.v1 as components | |
def link(link, text, **style):
    """Return an htbuilder ``<a>`` element pointing at *link* with label *text*.

    The anchor opens in a new tab (``target="_blank"``). Any extra keyword
    arguments are turned into the element's inline ``style`` attribute via
    htbuilder's ``styles`` helper.
    """
    anchor = a(_href=link, _target="_blank", style=styles(**style))
    return anchor(text)
def layout(*args):
    """Render a fixed-position footer at the bottom of the Streamlit page.

    Injects CSS that hides Streamlit's ``#MainMenu`` and default ``footer``
    and sets ``bottom: 105px`` on ``.stApp`` (presumably so the fixed footer
    does not overlap the app content). It then renders a full-width ``<div>``
    containing an ``<hr>`` followed by a ``<p>`` built from *args*.

    Args:
        *args: Footer contents, added to the ``<p>`` in order. Only ``str``
            values and htbuilder ``HtmlElement`` instances are used; anything
            else is silently skipped.
    """
    # The ID selector must be "#MainMenu" with no space; "# MainMenu" is
    # invalid CSS, so the whole rule would be dropped and the menu would stay
    # visible.
    style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
.stApp { bottom: 105px; }
</style>
"""

    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        color="black",
        text_align="center",
        height="auto",
        opacity=1,
    )

    style_hr = styles(
        display="block",
        margin=px(8, 8, "auto", "auto"),
        border_style="inset",
        border_width=px(2),
    )

    body = p()
    foot = div(style=style_div)(hr(style=style_hr), body)

    st.markdown(style, unsafe_allow_html=True)

    for arg in args:
        # Strings and htbuilder elements are both handed to the <p> body;
        # other argument types are ignored.
        if isinstance(arg, (str, HtmlElement)):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)
def footer():
    """Render the author/affiliation credits footer and embed the analytics tag.

    The credits are passed to ``layout``. The Google Analytics gtag snippet is
    then rendered through ``components.html``.
    """
    footer_items = (
        "Created by ",
        link("https://jonathanmalott.com", "Jonathan Malott"),
        br(),
        link("https://bridgingbarriers.utexas.edu/good-systems", "Good Systems Grand Challenge"),
        ", The University of Texas at Austin.",
        " Advised by Dr. Junfeng Jiao.",
        br(),
        br(),
    )
    layout(*footer_items)

    analytics_snippet = """
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-SB6NJ9DQS7"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'G-SB6NJ9DQS7');
</script>
"""
    # NOTE(review): the snippet has no visible output, yet height=600 reserves
    # a 600-px-tall component area; confirm the blank space is intended.
    components.html(analytics_snippet, height=600)
# Shared minDALL-E model. It stays False until the first generate() call loads it.
model = False
# Cached (model_clip, preprocess_clip) pair, loaded only when re-ranking is needed.
_clip_bundle = None


def _load_dalle(device):
    """Return the shared minDALL-E model, loading it on first use."""
    global model
    if model is False:
        # This will automatically download the pretrained model.
        model = Dalle.from_pretrained('minDALL-E/1.3B')
    model.to(device=device)
    return model


def _load_clip(device):
    """Return the cached CLIP ``(model, preprocess)`` pair, loading it once."""
    global _clip_bundle
    if _clip_bundle is None:
        _clip_bundle = clip.load("ViT-B/32", device=device)
    model_clip, preprocess_clip = _clip_bundle
    model_clip.to(device=device)
    return model_clip, preprocess_clip


def _best_candidate(images, prompt, device):
    """Return the single image from *images* that CLIP ranks best for *prompt*."""
    if len(images) == 1:
        # A single candidate needs no re-ranking; skip loading/running CLIP.
        return images[0]
    model_clip, preprocess_clip = _load_clip(device)
    rank = clip_score(prompt=prompt,
                      images=images,
                      model_clip=model_clip,
                      preprocess_clip=preprocess_clip,
                      device=device)
    # rank is assumed to be candidate indices ordered best-first (an argsort);
    # atleast_1d also copes with a 0-d result, and [0] keeps one image even
    # when several candidates were sampled.
    return images[np.atleast_1d(rank)[0]]


def generate(prompt, crazy, k, num_candidates=1):
    """Sample an image for *prompt* with minDALL-E and store it in session state.

    Args:
        prompt: User prompt. For sampling and CLIP scoring, " architecture" is
            appended unless the prompt already contains "architecture"
            (case-insensitive). The unmodified prompt is what gets stored.
        crazy: Softmax temperature passed to the sampler.
        k: Top-k cutoff passed to the sampler. It is also stored with the
            result and shown as "top k" in the grid. If None, the previous
            fixed cutoff of 2048 is used.
        num_candidates: Number of images to sample. When greater than 1, CLIP
            re-ranking picks the best one. Defaults to 1.

    Side effects:
        Appends a dict with keys 'prompt', 'crazy', 'k' and 'image' (a PIL
        image) to ``st.session_state.results``, which must already exist.
    """
    device = 'cpu'
    dalle = _load_dalle(device)

    # Draw the seed from a fresh, OS-seeded generator rather than np.random.
    # If set_seed also reseeds NumPy's global RNG (minDALL-E's helper does),
    # drawing from np.random would make every later seed a deterministic
    # function of the previous one, and generated images would eventually
    # repeat.
    set_seed(int(np.random.default_rng().integers(0, 10000)))

    new_prompt = prompt
    if "architecture" not in prompt.lower():
        new_prompt += " architecture"

    top_k = 2048 if k is None else int(k)

    # Sampling
    images = dalle.sampling(prompt=new_prompt,
                            top_k=top_k,
                            top_p=None,
                            softmax_temperature=crazy,
                            num_candidates=num_candidates,
                            device=device).cpu().numpy()
    # Move axis 1 to the end (presumably NCHW -> NHWC) so each image is
    # H x W x C for PIL.
    images = np.transpose(images, (0, 2, 3, 1))

    # CLIP Re-ranking
    result = _best_candidate(images, new_prompt, device)

    item = {
        'prompt': prompt,
        'crazy': crazy,
        'k': k,
        # Assumes the sampler returns values in [0, 1]; scale to 8-bit.
        'image': Image.fromarray((result * 255).astype(np.uint8)),
    }
    st.session_state.results.append(item)
def drawGrid():
    """Render ``st.session_state.results`` as titled three-column image grids.

    Results are grouped by (prompt, temperature, top k), newest result first.
    Each group gets a subheader followed by its images laid out across three
    columns in round-robin order. Every ``st.image`` return value is appended
    to ``st.session_state.images``. Both ``results`` and ``images`` must
    already exist in session state.
    """
    # Group newest-first; dict insertion order keeps the newest group on top.
    # The key uses the same string forms the heading displays, so 1 and 1.0
    # still form separate groups.
    groups = {}
    for result in reversed(st.session_state.results):
        key = (result['prompt'], str(result['crazy']), str(result['k']))
        groups.setdefault(key, []).append(result)

    placeholder = st.empty()
    with placeholder.container():
        for items in groups.values():
            first = items[0]
            st.subheader(first['prompt'] + " (temperature:" + str(first['crazy'])
                         + ", top k:" + str(first['k']) + ")")
            columns = st.columns(3)
            for ix, item in enumerate(items):
                with columns[ix % 3]:
                    st.session_state.images.append(st.image(item["image"]))