valhalla commited on
Commit
0ea4415
1 Parent(s): df7dab5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -1
app.py CHANGED
@@ -1,7 +1,89 @@
 
 
1
 
 
 
 
2
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- st.header("DALL·E mini")
 
 
 
 
 
 
 
 
 
 
 
 
5
  st.subheader("Generate images from text")
6
 
7
  prompt = st.text_input("What do you want to see?")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
 
4
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5
+
6
+ import numpy as np
7
  import streamlit as st
8
+ from PIL import Image
9
+
10
+ import clip
11
+ from dalle.models import Dalle
12
+ from dalle.utils.utils import clip_score
13
+
14
+
15
+ device = "cpu"
16
+ model = Dalle.from_pretrained("minDALL-E/1.3B") # This will automatically download the pretrained model.
17
+ model.to(device=device)
18
+
19
+ model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
20
+ model_clip.to(device=device)
21
+
22
+
23
+ def sample(prompt):
24
+ # Sampling
25
+ images = (
26
+ model.sampling(prompt=prompt, top_k=256, top_p=None, softmax_temperature=1.0, num_candidates=3, device=device)
27
+ .cpu()
28
+ .numpy()
29
+ )
30
+ images = np.transpose(images, (0, 2, 3, 1))
31
+
32
+ # CLIP Re-ranking
33
+ rank = clip_score(
34
+ prompt=prompt, images=images, model_clip=model_clip, preprocess_clip=preprocess_clip, device=device
35
+ )
36
 
37
+ # Save images
38
+ images = images[rank]
39
+ # print(rank, images.shape)
40
+ pil_images = []
41
+ for i in range(len(images)):
42
+ im = Image.fromarray((images[i] * 255).astype(np.uint8))
43
+ pil_images.append(im)
44
+
45
+ # im = Image.fromarray((images[0] * 255).astype(np.uint8))
46
+ return pil_images
47
+
48
+
49
+ st.header("minDALL-E")
50
  st.subheader("Generate images from text")
51
 
52
  prompt = st.text_input("What do you want to see?")
53
+
54
+ DEBUG = False
55
+ if prompt != "":
56
+ container = st.empty()
57
+ container.markdown(
58
+ f"""
59
+ <style> p {{ margin:0 }} div {{ margin:0 }} </style>
60
+ <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
61
+ <div class="stAlert">
62
+ <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
63
+ <div class="st-b7">
64
+ <div class="css-whx05o e13vu3m50">
65
+ <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
66
+ <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
67
+ Generating predictions for: <b>{prompt}</b>
68
+ </div>
69
+ </div>
70
+ </div>
71
+ </div>
72
+ </div>
73
+ </div>
74
+ <small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
75
+ """,
76
+ unsafe_allow_html=True,
77
+ )
78
+
79
+ print(f"Getting selections: {prompt}")
80
+ selected = sample(prompt)
81
+
82
+ margin = 0.1 # for better position of zoom in arrow
83
+ n_columns = 3
84
+ cols = st.columns([1] + [margin, 1] * (n_columns - 1))
85
+ for i, img in enumerate(selected):
86
+ cols[(i % n_columns) * 2].image(img)
87
+ container.markdown(f"**{prompt}**")
88
+
89
+ st.button("Again!", key="again_button")