File size: 1,718 Bytes
244fae2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import argparse
import functools
import os
import signal
import subprocess
import sys
import threading
import uuid

import clip
import cv2
import numpy as np
import streamlit as st
from dalle.models import Dalle
from dalle.utils.utils import set_seed, clip_score
from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts
from htbuilder.funcs import rgba, rgb
from htbuilder.units import percent, px
from PIL import Image

def signal_handler(sig, frame):
    """SIGINT handler: report the interrupt, then terminate with exit status 0.

    ``sig`` and ``frame`` are the standard arguments Python passes to signal
    handlers; neither is used here.
    """
    notice = 'You pressed Ctrl+C!'
    print(notice)
    # Equivalent to sys.exit(0): unwinds the interpreter via SystemExit(0).
    raise SystemExit(0)


@functools.lru_cache(maxsize=1)
def _load_models(device):
    """Load minDALL-E and CLIP onto *device* once and reuse them across calls.

    Returns a ``(dalle_model, clip_model, clip_preprocess)`` tuple.
    """
    model = Dalle.from_pretrained('minDALL-E/1.3B')  # This will automatically download the pretrained model.
    model.to(device=device)
    model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
    model_clip.to(device=device)
    return model, model_clip, preprocess_clip


def generate(prompt, crazy, *, num_candidates=3, device='cpu', out_dir='temp'):
    """Sample images for *prompt* with minDALL-E, rank them with CLIP, save them as JPEGs.

    Args:
        prompt: Text prompt passed to ``Dalle.sampling`` and ``clip_score``.
        crazy: Softmax temperature used for sampling; must be > 0.
        num_candidates: Number of images to sample (default 3).
        device: Torch device string used for both models (default ``'cpu'``).
        out_dir: Directory the JPEGs are written to; created if missing.
            Files are named ``<batch-id>_<rank>.jpeg`` where rank 0 is the
            best CLIP match.

    Returns:
        The sampled images as a channels-last numpy array, re-ordered
        best-first by CLIP score.

    Raises:
        ValueError: If ``crazy`` is not positive.
        OSError: If OpenCV fails to write an image file.
    """
    if crazy <= 0:
        raise ValueError(f"crazy (softmax temperature) must be > 0, got {crazy!r}")

    print("-------------------")

    # signal.signal() may only be called from the main thread (it raises
    # ValueError otherwise, e.g. inside Streamlit's script-runner thread).
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGINT, signal_handler)

    model, model_clip, preprocess_clip = _load_models(device)

    set_seed(np.random.randint(0, 10000))

    # Sampling
    images = model.sampling(prompt=prompt,
                            top_k=2048,
                            top_p=None,
                            softmax_temperature=crazy,
                            num_candidates=num_candidates,
                            device=device).cpu().numpy()
    # Move the channel axis last so each image is (H, W, C) for OpenCV.
    images = np.transpose(images, (0, 2, 3, 1))

    # CLIP re-ranking: `rank` is an index ordering of the candidates,
    # presumably best match first — confirm against dalle.utils.clip_score.
    rank = clip_score(prompt=prompt,
                      images=images,
                      model_clip=model_clip,
                      preprocess_clip=preprocess_clip,
                      device=device)
    images = images[rank]

    # Save images. A per-call batch id keeps filenames unique across calls;
    # the suffix preserves the CLIP ranking.
    os.makedirs(out_dir, exist_ok=True)
    batch_id = uuid.uuid4().hex
    for position, image in enumerate(images):
        # Assumes float RGB in [0, 1] (what clip_score also consumes);
        # cv2.imwrite needs 8-bit BGR, otherwise JPEGs come out black/miscoloured.
        pixels = np.clip(image * 255.0, 0, 255).round().astype(np.uint8)
        bgr = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
        path = os.path.join(out_dir, f"{batch_id}_{position}.jpeg")
        if not cv2.imwrite(path, bgr):
            raise OSError(f"cv2.imwrite failed for {path!r}")

    return images


if __name__ == "__main__":
    # Run the demo generation only when executed as a script (plain
    # `python file.py`, or `streamlit run`, which also executes the script as
    # __main__), not as a side effect of importing this module.
    generate("a pink house", 0.75)