File size: 2,248 Bytes
34dc018
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os, getpass

import numpy as np
import replicate


class DalleImageGenerator:
    """Text-to-image generator backed by the ``kuprel/min-dalle`` model on Replicate.

    The API token reaches the ``replicate`` client only through the
    ``REPLICATE_API_TOKEN`` environment variable; it is not passed anywhere else.
    """

    def __init__(self, token=None):
        """Make sure a Replicate API token is configured, then load the model.

        Args:
            token: Replicate API token. When given, it is stored in
                ``REPLICATE_API_TOKEN``, replacing any value already there.
                When omitted and the variable is unset, the user is prompted
                for it interactively (input hidden via ``getpass``).
        """
        if token is not None:
            # An explicitly passed token takes precedence over the environment.
            os.environ["REPLICATE_API_TOKEN"] = token
        elif "REPLICATE_API_TOKEN" not in os.environ:
            print("Please go to https://replicate.com/docs/api for your Replicate API token.")
            os.environ["REPLICATE_API_TOKEN"] = getpass.getpass("Input Replicate API Token:")

        self.dalle = replicate.models.get("kuprel/min-dalle")

    def generate_images(self, text, grid_size, text_adherence=2):
        """Generate a grid of images for a prompt and split it into tiles.

        Args:
            text: text prompt sent to the model.
            grid_size: passed to the model as ``grid_size``; the downloaded grid
                image is then cut into ``grid_size`` x ``grid_size`` tiles.
            text_adherence: passed to the model as ``log2_supercondition_factor``.

        Returns:
            The tiles produced by ``blockshaped``, stacked along the first axis,
            each ``image_height // grid_size`` by ``image_width // grid_size``
            pixels. There are ``grid_size ** 2`` tiles provided the image
            dimensions are multiples of ``grid_size``; otherwise ``blockshaped``
            rejects the split.

        Raises:
            IndexError: if the model returns no output URLs.
        """
        output = self.dalle.predict(
            text=text,
            grid_size=grid_size,
            log2_supercondition_factor=text_adherence,
        )
        # A lone URL string would otherwise be split into characters by list().
        urls = [output] if isinstance(output, str) else list(output)
        # Only the last URL is used (presumably the final, fully refined grid).
        images = get_image(urls[-1])
        height, width = images.shape[:2]
        return blockshaped(images, height // grid_size, width // grid_size)


def get_image(url):
    """Download the image at *url* and return its pixels as a NumPy array.

    Args:
        url: address of the image to fetch.

    Returns:
        ``np.ndarray`` built with ``np.array`` from the decoded PIL image; its
        shape and dtype follow the image mode (e.g. (H, W, 3) uint8 for RGB).
    """
    from urllib.request import Request, urlopen
    import io
    import PIL.Image as Image

    # Browser-like request headers, presumably because some image hosts
    # reject urllib's default User-Agent.
    hdr = {
        'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.11 (KHTML, like Gecko) Chrome/23.0.1271.64 Safari/537.11',
        'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
        'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.3',
        'Accept-Encoding': 'none',
        'Accept-Language': 'en-US,en;q=0.8',
        'Connection': 'keep-alive'}

    req = Request(url, headers=hdr)
    # Read the whole body, then close the connection immediately instead of
    # leaving the response object open until garbage collection.
    with urlopen(req) as page:
        data = page.read()
    return np.array(Image.open(io.BytesIO(data)))


def blockshaped(arr, nrows, ncols):
    """
    Split an image-like array into equally sized, non-overlapping tiles.

    Tiles are emitted in row-major order (left to right, then top to bottom),
    and each tile keeps the spatial layout of its region of ``arr``.

    Args:
        arr: array of shape (h, w) or (h, w, *trailing), e.g. (h, w, channels).
        nrows: height of each tile; must divide ``h`` evenly.
        ncols: width of each tile; must divide ``w`` evenly.

    Returns:
        Array of shape (n, nrows, ncols, *trailing) where
        n = (h // nrows) * (w // ncols); its total size equals ``arr.size``.

    Raises:
        ValueError: if ``arr`` has fewer than two dimensions, or if ``h`` / ``w``
            is not evenly divisible by ``nrows`` / ``ncols``.
    """
    arr = np.asarray(arr)
    if arr.ndim < 2:
        raise ValueError(f"expected an array with at least 2 dimensions, got shape {arr.shape}")
    h, w = arr.shape[:2]
    # Trailing axes (e.g. colour channels) are carried through unchanged, so
    # greyscale, RGB and RGBA inputs are all split correctly.
    trailing = arr.shape[2:]
    if h % nrows != 0:
        raise ValueError(f"{h} rows is not evenly divisible by {nrows}")
    if w % ncols != 0:
        raise ValueError(f"{w} cols is not evenly divisible by {ncols}")
    # (h, w, ...) -> (block_row, nrows, block_col, ncols, ...); swapping axes
    # 1 and 2 groups the two block indices so they flatten into one tile axis.
    return (arr.reshape(h // nrows, nrows, w // ncols, ncols, *trailing)
               .swapaxes(1, 2)
               .reshape(-1, nrows, ncols, *trailing))