File size: 2,531 Bytes
e83adaa
 
 
 
dc6157e
e83adaa
 
d1ca92a
e83adaa
 
 
 
 
032f281
e83adaa
be419c1
a10c3fd
e83adaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9e5ea9
 
 
 
 
 
1e6b2bf
e83adaa
 
 
 
 
 
 
 
 
 
 
 
 
 
81ab75e
 
e83adaa
81ab75e
e83adaa
81ab75e
 
 
8124b54
81ab75e
8124b54
81ab75e
 
e83adaa
81ab75e
 
 
e83adaa
81ab75e
e83adaa
81ab75e
 
 
 
 
 
e83adaa
81ab75e
 
 
 
 
e83adaa
81ab75e
e83adaa
81ab75e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import diffusers
import torch
import os
import time
import accelerate
import streamlit as st

from stqdm import stqdm
from diffusers import DiffusionPipeline, UNet2DConditionModel
from PIL import Image


MODEL_REPO = 'OFA-Sys/small-stable-diffusion-v0'
LoRa_DIR = 'weights'
DATASET_REPO = 'VESSL/Bored_Ape_NFT_text'
SAMPLE_IMAGE = 'weights/Sample.png'

def load_pipeline_w_lora():
    """Build the diffusion pipeline and attach the fine-tuned LoRA weights.

    Pulls the base UNet and the full pipeline from the Hugging Face hub
    (MODEL_REPO), then loads the LoRA attention-processor weights stored
    under LoRa_DIR into the UNet's attention layers.

    Returns:
        A DiffusionPipeline with LoRA weights applied and its internal
        progress bar disabled.
    """
    # Fetch the pretrained UNet separately so the LoRA attention
    # processors can be attached to it below.
    base_unet = UNet2DConditionModel.from_pretrained(
        MODEL_REPO, subfolder="unet", revision=None
    )

    pipe = DiffusionPipeline.from_pretrained(
        MODEL_REPO,
        unet=base_unet,
        revision=None,
        torch_dtype=torch.float32,
    )

    print('LoRa layers loading...')
    pipe.unet.load_attn_procs(LoRa_DIR)
    print('LoRa layers loaded')

    # The app drives its own progress display (stqdm), so silence the
    # pipeline's built-in tqdm bar.
    pipe.set_progress_bar_config(disable=True)

    return pipe


def elapsed_time(fn, *args, **kwargs):
    """Call *fn* with the given arguments and time the call.

    Args:
        fn: Callable to invoke.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn`` (new; fully
            backward-compatible with positional-only callers).

    Returns:
        Tuple ``(elapsed, output)`` where ``elapsed`` is the wall-clock
        duration as a string with two decimal places and ``output`` is
        whatever ``fn`` returned.
    """
    # perf_counter is monotonic and higher-resolution than time.time(),
    # so intervals cannot go negative on system clock adjustments.
    start = time.perf_counter()
    output = fn(*args, **kwargs)
    end = time.perf_counter()

    elapsed = f'{end - start:.2f}'

    return elapsed, output


# --- Streamlit app body (re-executed top-to-bottom on every interaction) ---

st.title("BAYC Text to IMAGE generator")
st.write(f"Stable diffusion model is fine-tuned by lora using dataset {DATASET_REPO}")

# Prefer GPU when available; CPU generation works but is much slower.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): the pipeline is re-loaded on every Streamlit rerun —
# consider caching it (st.cache_resource) to avoid repeated hub fetches.
st.write("Loading models...")
elapsed, pipeline = elapsed_time(load_pipeline_w_lora)
st.write(f"Model is loaded in {elapsed} seconds!")

pipeline = pipeline.to(device)

# Show a pre-rendered sample so users see the expected output style.
sample = Image.open(SAMPLE_IMAGE)
st.image(sample, caption="Example image with prompt <An ape with solid gold fur and beanie>")

# Input form: free-text prompt, image count (1-10), and RNG seed (1-10000).
# clear_on_submit resets the widgets after the user presses Submit.
with st.form(key="information", clear_on_submit=True): 
    prompt = st.text_input(
        label="Write prompt to generate your unique BAYC image! (e.g. An ape with golden fur)")

    num_images = st.number_input(label="Number of images to generate", min_value=1, max_value=10)

    seed = st.number_input(label="Seed for images", min_value=1, max_value=10000)

    submitted = st.form_submit_button(label="Submit")
    
if submitted :
    st.write(f"Generating {num_images} BAYC image with prompt <{prompt}>...")

    # Fixed seed makes the whole batch reproducible; the generator's state
    # advances across iterations so each image in the batch still differs.
    generator = torch.Generator(device=device).manual_seed(seed)
    images = []
    for img_idx in stqdm(range(num_images)):
        generated_image = pipeline(prompt, num_inference_steps=30, generator=generator).images[0]
        images.append(generated_image)

    st.write("Done!")

    st.image(images, width=150, caption=f"Generated Images with <{prompt}>")