"""Streamlit demo: generate BAYC-style images from a text prompt.

Loads a small Stable Diffusion pipeline (``MODEL_REPO``), applies LoRA
attention weights from ``LoRa_DIR`` (fine-tuned on ``DATASET_REPO``), and
renders the images generated for a user-supplied prompt and seed.
"""
import os
import time

import accelerate
import diffusers
import streamlit as st
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel
from PIL import Image
from stqdm import stqdm

MODEL_REPO = 'OFA-Sys/small-stable-diffusion-v0'
LoRa_DIR = 'weights'
DATASET_REPO = 'VESSL/Bored_Ape_NFT_text'
SAMPLE_IMAGE = 'weights/Sample.png'


def _cache_resource(func):
    """Cache the return value of ``func`` across Streamlit reruns.

    Streamlit re-executes this whole script on every widget interaction,
    including each form submission, so without caching the diffusion
    pipeline would be rebuilt on every click.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource(func)
    # Fallback for older Streamlit releases without st.cache_resource;
    # allow_output_mutation skips hashing the returned torch pipeline.
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def load_pipeline_w_lora():
    """Build the diffusion pipeline with LoRA attention weights applied.

    Returns:
        The ``DiffusionPipeline`` for ``MODEL_REPO`` in float32, with the
        UNet attention processors loaded from ``LoRa_DIR`` and the
        pipeline's own progress bar disabled.
    """
    pipeline = DiffusionPipeline.from_pretrained(
        MODEL_REPO,
        revision=None,
        torch_dtype=torch.float32,
    )
    # Load LoRA attention-layer weights into the UNet's attention layers.
    print('LoRa layers loading...')
    pipeline.unet.load_attn_procs(LoRa_DIR)
    print('LoRa layers loaded')
    # Progress is reported in the UI via stqdm (one tick per image) instead.
    pipeline.set_progress_bar_config(disable=True)
    return pipeline


def elapsed_time(fn, *args):
    """Call ``fn(*args)`` and measure how long the call took.

    Returns:
        A ``(elapsed, output)`` tuple: ``elapsed`` is the duration in
        seconds formatted with two decimals, ``output`` is ``fn``'s result.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    output = fn(*args)
    elapsed = f'{time.perf_counter() - start:.2f}'
    return elapsed, output


st.title("BAYC Text to IMAGE generator")
st.write(f"Stable diffusion model is fine-tuned by lora using dataset {DATASET_REPO}")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

st.write("Loading models...")
elapsed, pipeline = elapsed_time(load_pipeline_w_lora)
st.write(f"Model is loaded in {elapsed} seconds!")
pipeline = pipeline.to(device)

# Context manager releases the file handle once the image has been rendered.
with Image.open(SAMPLE_IMAGE) as sample:
    st.image(sample, caption="Example image with prompt ")

with st.form(key="information", clear_on_submit=True):
    prompt = st.text_input(
        label="Write prompt to generate your unique BAYC image! (e.g. An ape with golden fur)")
    num_images = st.number_input(label="Number of images to generate", min_value=1, max_value=10)
    seed = st.number_input(label="Seed for images", min_value=1, max_value=10000)
    submitted = st.form_submit_button(label="Submit")

if submitted:
    st.write(f"Generating {num_images} BAYC image with prompt <{prompt}>...")
    # One seeded generator shared by all calls: its state advances between
    # images, so they differ, yet the batch is reproducible for a given seed.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    images = [
        pipeline(prompt, num_inference_steps=30, generator=generator).images[0]
        for _ in stqdm(range(num_images))
    ]
    st.write("Done!")
    caption = f"Generated Images with <{prompt}>"
    st.image(images, width=300, caption=[caption] * len(images))