Spaces:
Runtime error
Runtime error
File size: 2,531 Bytes
e83adaa dc6157e e83adaa d1ca92a e83adaa 032f281 e83adaa be419c1 a10c3fd e83adaa d9e5ea9 1e6b2bf e83adaa 81ab75e e83adaa 81ab75e e83adaa 81ab75e 8124b54 81ab75e 8124b54 81ab75e e83adaa 81ab75e e83adaa 81ab75e e83adaa 81ab75e e83adaa 81ab75e e83adaa 81ab75e e83adaa 81ab75e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import diffusers
import torch
import os
import time
import accelerate
import streamlit as st
from stqdm import stqdm
from diffusers import DiffusionPipeline, UNet2DConditionModel
from PIL import Image
# Base Stable Diffusion checkpoint (Hugging Face Hub repo id) the app loads.
MODEL_REPO: str = 'OFA-Sys/small-stable-diffusion-v0'
# Local directory holding the LoRA attention-processor weights applied to the UNet.
LoRa_DIR: str = 'weights'
# Dataset the LoRA weights were fine-tuned on; only displayed in the page header.
DATASET_REPO: str = 'VESSL/Bored_Ape_NFT_text'
# Example output shown on the landing page, stored next to the LoRA weights.
SAMPLE_IMAGE: str = f'{LoRa_DIR}/Sample.png'
def load_pipeline_w_lora(*, model_repo=None, lora_dir=None, torch_dtype=torch.float32):
    """Build the diffusion pipeline and load the LoRA attention weights into its UNet.

    All parameters are keyword-only and optional; calling with no arguments
    reproduces the app's default configuration.

    Args:
        model_repo: Hub repo id of the base model. ``None`` means ``MODEL_REPO``.
        lora_dir: Directory with the LoRA attention-processor weights.
            ``None`` means ``LoRa_DIR``.
        torch_dtype: dtype used for both the UNet and the rest of the pipeline
            (``torch.float32`` by default).

    Returns:
        The ``DiffusionPipeline`` with LoRA attention processors loaded into
        its UNet and its built-in progress bar disabled.
    """
    # Resolve defaults at call time so the module-level constants remain the
    # single source of truth and this ``def`` does not depend on them existing yet.
    if model_repo is None:
        model_repo = MODEL_REPO
    if lora_dir is None:
        lora_dir = LoRa_DIR

    # Load the UNet separately so it can be handed to the pipeline explicitly.
    # Use the same dtype as the pipeline to avoid a mixed-precision model when a
    # non-default dtype is requested.
    unet = UNet2DConditionModel.from_pretrained(
        model_repo,
        subfolder="unet",
        revision=None,
        torch_dtype=torch_dtype,
    )
    pipeline = DiffusionPipeline.from_pretrained(
        model_repo,
        unet=unet,
        revision=None,
        torch_dtype=torch_dtype,
    )

    # Overlay the fine-tuned LoRA attention-layer weights onto the UNet.
    print('LoRa layers loading...')
    pipeline.unet.load_attn_procs(lora_dir)
    print('LoRa layers loaded')

    # Silence the pipeline's own per-call progress bar; presumably the caller
    # reports progress itself (the UI iterates images with stqdm) -- verify.
    pipeline.set_progress_bar_config(disable=True)
    return pipeline
def elapsed_time(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and measure how long the call took.

    Returns:
        A ``(elapsed, output)`` tuple. ``elapsed`` is the duration in seconds,
        formatted as a string with two decimals (e.g. ``"1.23"``). ``output``
        is whatever ``fn`` returned.
    """
    # perf_counter is monotonic and high-resolution, so the measurement cannot be
    # skewed by system-clock adjustments the way time.time() can.
    start = time.perf_counter()
    output = fn(*args, **kwargs)
    elapsed = f'{time.perf_counter() - start:.2f}'
    return elapsed, output
st.title("BAYC Text to IMAGE generator")
st.write(f"Stable diffusion model is fine-tuned by lora using dataset {DATASET_REPO}")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Streamlit re-executes this whole script on every interaction, including each
# form submit. Without caching, the model weights would be reloaded from disk on
# every click. st.cache_resource keeps one pipeline per process. Older Streamlit
# releases lack it, so fall back to the plain (uncached) loader there.
_cache_resource = getattr(st, "cache_resource", None)
_load_pipeline = (
    _cache_resource(load_pipeline_w_lora)
    if _cache_resource is not None
    else load_pipeline_w_lora
)

st.write("Loading models...")
elapsed, pipeline = elapsed_time(_load_pipeline)
st.write(f"Model is loaded in {elapsed} seconds!")
# Moving an already-placed pipeline again is a no-op, so this is safe on cache hits.
pipeline = pipeline.to(device)

# Open the sample inside a context manager so its file handle is released once
# st.image has consumed the image.
with Image.open(SAMPLE_IMAGE) as sample:
    st.image(sample, caption="Example image with prompt <An ape with solid gold fur and beanie>")
# Input form. clear_on_submit resets the widgets after a submit, so the submitted
# values are only available during the rerun triggered by that submit.
with st.form(key="information", clear_on_submit=True):
    prompt = st.text_input(
        label="Write prompt to generate your unique BAYC image! (e.g. An ape with golden fur)")
    num_images = st.number_input(label="Number of images to generate", min_value=1, max_value=10)
    seed = st.number_input(label="Seed for images", min_value=1, max_value=10000)
    submitted = st.form_submit_button(label="Submit")

if submitted and not prompt.strip():
    st.warning("Please enter a prompt before submitting.")
elif submitted:
    num_images = int(num_images)
    st.write(f"Generating {num_images} BAYC image with prompt <{prompt}>...")
    # A single generator is seeded once and shared by every call, so the images
    # differ from one another while the batch as a whole follows from the seed.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    images = [
        pipeline(prompt, num_inference_steps=30, generator=generator).images[0]
        for _ in stqdm(range(num_images))
    ]
    st.write("Done!")
    # When st.image receives a list of images it needs one caption per image.
    # A single string fails as soon as more than one image is generated.
    caption = f"Generated Images with <{prompt}>"
    st.image(images, width=150, caption=[caption] * len(images))
|