valhalla commited on
Commit
4d72a29
1 Parent(s): 47c5fb9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -45
app.py CHANGED
@@ -1,67 +1,57 @@
 
1
  import os
2
- import sys
 
 
3
 
4
- sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5
-
6
- import numpy as np
7
  import streamlit as st
8
  from PIL import Image
9
 
10
- import clip
11
- from dalle.models import Dalle
12
- from dalle.utils.utils import clip_score, download
13
 
14
- url = "https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz"
15
- root = os.path.expanduser("~/.cache/minDALLE")
16
- filename = os.path.basename(url)
17
- pathname = filename[:-len('.tar.gz')]
18
 
19
- expected_md5 = url.split("/")[-2]
20
- download_target = os.path.join(root, filename)
21
- result_path = os.path.join(root, pathname)
22
 
23
- if not os.path.exists(result_path):
24
- result_path = download(url, root)
25
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
- device = "cpu"
29
- model = Dalle.from_pretrained(result_path) # This will automatically download the pretrained model.
30
- model.to(device=device)
31
 
32
- model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
33
- model_clip.to(device=device)
34
 
35
 
36
- def sample(prompt):
37
- # Sampling
38
- images = (
39
- model.sampling(prompt=prompt, top_k=256, top_p=None, softmax_temperature=1.0, num_candidates=3, device=device)
40
- .cpu()
41
- .numpy()
42
- )
43
- images = np.transpose(images, (0, 2, 3, 1))
44
 
45
- # CLIP Re-ranking
46
- rank = clip_score(
47
- prompt=prompt, images=images, model_clip=model_clip, preprocess_clip=preprocess_clip, device=device
48
- )
49
 
50
- # Save images
51
- images = images[rank]
52
- # print(rank, images.shape)
53
- pil_images = []
54
- for i in range(len(images)):
55
- im = Image.fromarray((images[i] * 255).astype(np.uint8))
56
- pil_images.append(im)
57
-
58
- # im = Image.fromarray((images[0] * 255).astype(np.uint8))
59
- return pil_images
60
 
61
 
62
  st.header("minDALL-E")
63
  st.subheader("Generate images from text")
64
 
 
 
 
65
  prompt = st.text_input("What do you want to see?")
66
 
67
  DEBUG = False
@@ -90,9 +80,9 @@ if prompt != "":
90
  )
91
 
92
  print(f"Getting selections: {prompt}")
93
- selected = sample(prompt)
94
 
95
- margin = 0.1 #for better position of zoom in arrow
96
  n_columns = 3
97
  cols = st.columns([1] + [margin, 1] * (n_columns - 1))
98
  for i, img in enumerate(selected):
 
1
+ import base64
2
  import os
3
+ import time
4
+ from io import BytesIO
5
+ from multiprocessing import Process
6
 
 
 
 
7
  import streamlit as st
8
  from PIL import Image
9
 
10
+ import requests
 
 
11
 
 
 
 
 
12
 
13
+ def start_server():
14
+ os.system("uvicorn server:app --port 8080 --host 0.0.0.0 --workers 2")
 
15
 
 
 
16
 
17
+ def load_models():
18
+ if not is_port_in_use(8080):
19
+ with st.spinner(text="Loading models, please wait..."):
20
+ proc = Process(target=start_server, args=(), daemon=True)
21
+ proc.start()
22
+ while not is_port_in_use(8080):
23
+ time.sleep(1)
24
+ st.success("Model server started.")
25
+ else:
26
+ st.success("Model server already running...")
27
+ st.session_state["models_loaded"] = True
28
 
29
 
30
+ def is_port_in_use(port):
31
+ import socket
 
32
 
33
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
34
+ return s.connect_ex(("0.0.0.0", port)) == 0
35
 
36
 
37
+ def generate(prompt):
38
+ correct_request = f"http://0.0.0.0:8080/correct?prompt={prompt}"
39
+ response = requests.get(correct_request)
40
+ images = response.json()["images"]
41
+ images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
42
+ return images
 
 
43
 
 
 
 
 
44
 
45
+ if "models_loaded" not in st.session_state:
46
+ st.session_state["models_loaded"] = False
 
 
 
 
 
 
 
 
47
 
48
 
49
  st.header("minDALL-E")
50
  st.subheader("Generate images from text")
51
 
52
+ if not st.session_state["models_loaded"]:
53
+ load_models()
54
+
55
  prompt = st.text_input("What do you want to see?")
56
 
57
  DEBUG = False
 
80
  )
81
 
82
  print(f"Getting selections: {prompt}")
83
+ selected = generate(prompt)
84
 
85
+ margin = 0.1 # for better position of zoom in arrow
86
  n_columns = 3
87
  cols = st.columns([1] + [margin, 1] * (n_columns - 1))
88
  for i, img in enumerate(selected):