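# Streamlit demo of the image-generation (diffusion) models behind the tumblr bot
# nostalgebraist-autoresponder: a base model generates a small image from the prompt text,
# then a super-resolution model upsamples it, with intermediate samples streamed to the UI.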
import streamlit as st
import time, uuid
from datetime import datetime
import numpy as np
from PIL import Image
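# Per-session state: a unique id used in log lines and a counter of generations this session.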
if 'session_id' not in st.session_state:
    st.session_state.session_id = str(uuid.uuid4())
    st.session_state.n_gen = 0
# constants
HF_REPO_NAME_DIFFUSION = 'nostalgebraist/nostalgebraist-autoresponder-diffusion'
model_path_diffusion = 'nostalgebraist-autoresponder-diffusion'
timestep_respacing_sres1 = '20' # '90,60,60,20,20'
timestep_respacing_sres2 = '20' # '250'
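# Default arguments for the sampling pipeline; the prompt text and guidance scale chosen
# in the UI are merged in per request (see handler() below).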
DIFFUSION_DEFAULTS = dict(
    batch_size=1,
    n_samples=1,
    clf_free_guidance=True,
    clf_free_guidance_sres=False,
    guidance_scale=1,
    guidance_scale_sres=0,
    yield_intermediates=True
)
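# One-time environment and model setup, cached across sessions via st.experimental_singleton:
# clone and install the improved-diffusion fork, fetch and unpack the model tarball from the
# Hugging Face Hub, and build the two-stage (base + super-resolution) sampling pipeline.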
@st.experimental_singleton
def setup():
    import os, subprocess, sys

    if not os.path.exists('improved_diffusion'):
        os.system("git clone https://github.com/nostalgebraist/improved-diffusion.git")
        os.system("cd improved-diffusion && git fetch origin nbar-space && git checkout nbar-space && pip install -e .")
        os.system("pip install tokenizers x-transformers==0.22.0 axial-positional-embedding")
        os.system("pip install einops==0.3.2")

    sys.path.append("improved-diffusion")

    import improved_diffusion.pipeline
    from transformer_utils.util.tfm_utils import get_local_path_from_huggingface_cdn

    if not os.path.exists(model_path_diffusion):
        model_tar_name = 'model.tar'
        model_tar_path = get_local_path_from_huggingface_cdn(
            HF_REPO_NAME_DIFFUSION, model_tar_name
        )
        subprocess.run(f"tar -xf {model_tar_path} && rm {model_tar_path}", shell=True)

    checkpoint_path_sres1 = os.path.join(model_path_diffusion, "sres1.pt")
    config_path_sres1 = os.path.join(model_path_diffusion, "config_sres1.json")

    checkpoint_path_sres2 = os.path.join(model_path_diffusion, "sres2.pt")
    config_path_sres2 = os.path.join(model_path_diffusion, "config_sres2.json")

    # load
    sampling_model_sres1 = improved_diffusion.pipeline.SamplingModel.from_config(
        checkpoint_path=checkpoint_path_sres1,
        config_path=config_path_sres1,
        timestep_respacing=timestep_respacing_sres1
    )
    sampling_model_sres2 = improved_diffusion.pipeline.SamplingModel.from_config(
        checkpoint_path=checkpoint_path_sres2,
        config_path=config_path_sres2,
        timestep_respacing=timestep_respacing_sres2
    )

    pipeline = improved_diffusion.pipeline.SamplingPipeline(sampling_model_sres1, sampling_model_sres2)
    return pipeline
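# Logging helpers: UTC timestamp plus session id and generation count, printed to stdout.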
def now_str():
    return datetime.utcnow().strftime('%Y-%m-%d %H-%M-%S')
def log(msg, st_state):
    session_id = st_state.session_id if 'session_id' in st_state else None
    n_gen = st_state.n_gen if 'n_gen' in st_state else None
    print(f"{now_str()} {session_id} ({n_gen}th gen):\n\t{msg}\n")
def handler(text, ts1, ts2, gs1, st_state):
    pipeline = setup()

    data = {'text': text[:380], 'guidance_scale': gs1}
    args = {k: v for k, v in DIFFUSION_DEFAULTS.items()}
    args.update(data)

    log_data = {'ts1': ts1, 'ts2': ts2}
    log_data.update(args)
    log(repr(log_data), st_state)

    pipeline.base_model.set_timestep_respacing(str(ts1))
    pipeline.super_res_model.set_timestep_respacing(str(ts2))

    return pipeline.sample(**args)
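# Streamlit UI: prompt entry, sampling settings, and live preview during generation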
FRESH = True
st.title('nostalgebraist-autoresponder image generation demo')
st.write("#### For a **much faster experience**, try the [Colab notebook](https://colab.research.google.com/drive/17BOTYmLv4fdurr8y5dcaGKy8JVY_A62a?usp=sharing) instead!")
st.write("A demo of the image models used in the tumblr bot [nostalgebraist-autoresponder](https://nostalgebraist-autoresponder.tumblr.com/).\n\nBy [nostalgebraist](https://nostalgebraist.tumblr.com/)")
st.write('##### What is this thing? How does it work?')
st.write("See [this post](https://nostalgebraist.tumblr.com/post/672300992964050944/franks-image-generation-model-explained) for an explanation.")
st.header('Prompt')
button_dril = st.button('Fill @dril tweet example text')
if FRESH and button_dril:
    st.session_state.fill_value = 'wint\nFollowing\n@dril\nthe wise man bowed his head solemnly and\nspoke: "theres actually zero difference\nbetween good & bad things. you imbecile.\nyou fucking moron'

if 'fill_value' in st.session_state:
    fill_value = st.session_state.fill_value
else:
    fill_value = ""
    st.session_state.fill_value = fill_value

text = st.text_area('Enter your text here (or leave blank for a textless image)', max_chars=380, height=230,
                    value=fill_value)
st.header('Settings')
st.write("The bot uses 250 base steps and 250 upsampling steps, with custom spacing (not available here) for the base part.\n\nSince this demo doesn't have a GPU, you'll probably want to use fewer than 250 steps unles you have a lot of patience.")
help_ts1 = "How long to run the base model. Larger values make the image more realistic / better. Smaller values are faster."
help_ts2 = "How long to run the upsampling model. Larger values sometimes make the big image crisper and more detailed. Smaller values are faster."
help_gs1 = "Guidance scale. Larger values make the image more likely to contain the text you wrote. If this is zero, the first part will be faster."
ts1 = st.slider('Steps (base)', min_value=5, max_value=500, value=50, step=5, help=help_ts1)
ts2 = st.slider('Steps (upsampling)', min_value=5, max_value=500, value=50, step=5, help=help_ts2)
gs1 = st.select_slider('Guidance scale (base)', [0.5*i for i in range(9)], value=1.0, help=help_gs1)
button_go = st.button('Generate')
button_stop = st.button('Stop')
st.write("During generation, the two images show different ways of looking at the same process.\n- The left image starts with 100% noise and gradually turns into 100 signal.\n- The right image shows the model's current 'guess' about what the left image will look like when all the noise has been removed.")
generating_marker = st.empty()
low_res = st.empty()
high_res = st.empty()
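# Main generation loop: iterate over intermediate samples from the pipeline, route each frame
# to the base (low-res) or upsampling (high-res) display slot based on its size, and report
# per-frame timing and progress under each image.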
if button_go:
    st.session_state.n_gen = st.session_state.n_gen + 1

    with generating_marker.container():
        st.write('**Generating...**')
        st.write('**Prompt:**')
        st.write(repr(text))

    count_low_res, count_high_res = 0, 0
    times_low, times_high = [], []
    t = time.time()

    for s, xs in handler(text, ts1, ts2, gs1, st.session_state):
        s = Image.fromarray(s[0])
        xs = Image.fromarray(xs[0])

        t2 = time.time()
        delta = t2 - t
        t = t2

        is_high_res = s.size[0] == 256

        if is_high_res:
            target = high_res
            count_high_res += 1
            count = count_high_res
            total = ts2
            times_high.append(delta)
            times = times_high
            prefix = "Part 2 of 2 (upsampling)"
        else:
            target = low_res
            count_low_res += 1
            count = count_low_res
            total = ts1
            times_low.append(delta)
            times = times_low
            prefix = "Part 1 of 2 (base model)"

        rate = sum(times)/len(times)

        with target.container():
            st.image([s, xs])
            st.write(f'{prefix} | {count:02d} / {total} frames | {rate:.2f} seconds/frame')

        if button_stop:
            log('gen stopped', st.session_state)
            break

    with generating_marker.container():
        log('gen complete', st.session_state)
        st.write('')