johnpaulbin committed
Commit 7850b9e
1 Parent(s): b36119b

Update app.py

Files changed (1): app.py (+92 -116)
app.py CHANGED
@@ -1,119 +1,95 @@
  #!/usr/bin/env python
  # coding: utf-8
- # Uncomment to run on cpu
- #import os
- #os.environ["JAX_PLATFORM_NAME"] = "cpu"
  import random
- import jax
- import flax.linen as nn
- from flax.training.common_utils import shard
- from flax.jax_utils import replicate, unreplicate
- from transformers import BartTokenizer, FlaxBartForConditionalGeneration
- from PIL import Image
- import numpy as np
- import matplotlib.pyplot as plt
- from vqgan_jax.modeling_flax_vqgan import VQModel
- from dalle_mini.model import CustomFlaxBartForConditionalGeneration
- # ## CLIP Scoring
- from transformers import CLIPProcessor, FlaxCLIPModel
- import gradio as gr
- from dalle_mini.helpers import captioned_strip
- DALLE_REPO = 'flax-community/dalle-mini'
- DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'
- VQGAN_REPO = 'flax-community/vqgan_f16_16384'
- VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'
- tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
- model = CustomFlaxBartForConditionalGeneration.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
- vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
- def custom_to_pil(x):
-     x = np.clip(x, 0., 1.)
-     x = (255*x).astype(np.uint8)
-     x = Image.fromarray(x)
-     if not x.mode == "RGB":
-         x = x.convert("RGB")
-     return x
- def generate(input, rng, params):
-     return model.generate(
-         **input,
-         max_length=257,
-         num_beams=1,
-         do_sample=True,
-         prng_key=rng,
-         eos_token_id=50000,
-         pad_token_id=50000,
-         params=params,
-     )
- def get_images(indices, params):
-     return vqgan.decode_code(indices, params=params)
- p_generate = jax.pmap(generate, "batch")
- p_get_images = jax.pmap(get_images, "batch")
- bart_params = replicate(model.params)
- vqgan_params = replicate(vqgan.params)
- clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
- print("Initialize FlaxCLIPModel")
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
- print("Initialize CLIPProcessor")
- def hallucinate(prompt, num_images=64):
-     prompt = [prompt] * jax.device_count()
-     inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
-     inputs = shard(inputs)
-     all_images = []
-     for i in range(num_images // jax.device_count()):
-         key = random.randint(0, 1e7)
-         rng = jax.random.PRNGKey(key)
-         rngs = jax.random.split(rng, jax.local_device_count())
-         indices = p_generate(inputs, rngs, bart_params).sequences
-         indices = indices[:, :, 1:]
-         images = p_get_images(indices, vqgan_params)
-         images = np.squeeze(np.asarray(images), 1)
-         for image in images:
-             all_images.append(custom_to_pil(image))
-     return all_images
- def clip_top_k(prompt, images, k=8):
-     inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
-     outputs = clip(**inputs)
-     logits = outputs.logits_per_text
-     scores = np.array(logits[0]).argsort()[-k:][::-1]
-     return [images[score] for score in scores]
- def compose_predictions(images, caption=None):
-     increased_h = 0 if caption is None else 48
-     w, h = images[0].size[0], images[0].size[1]
-     img = Image.new("RGB", (len(images)*w, h + increased_h))
-     for i, img_ in enumerate(images):
-         img.paste(img_, (i*w, increased_h))
-     if caption is not None:
-         draw = ImageDraw.Draw(img)
-         font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
-         draw.text((20, 3), caption, (255,255,255), font=font)
-     return img
- def top_k_predictions(prompt, num_candidates=32, k=8):
-     images = hallucinate(prompt, num_images=num_candidates)
-     images = clip_top_k(prompt, images, k=k)
-     return images
- def run_inference(prompt, num_images=32, num_preds=8):
-     images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
-     predictions = captioned_strip(images)
-     output_title = f"""
- <b>{prompt}</b>
- """
-     return (output_title, predictions)
- outputs = [
-     gr.outputs.HTML(label=""), # To be used as title
-     gr.outputs.Image(label=''),
- ]
- description = """
- DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
- """
- gr.Interface(run_inference,
-     inputs=[gr.inputs.Textbox(label='What do you want to see?')],
-     outputs=outputs,
-     title='DALL·E mini',
-     description=description,
-     article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
-     layout='vertical',
-     theme='huggingface',
-     examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
-     allow_flagging=False,
-     live=False,
-     # server_port=8999
- ).launch(share=True)
+ from dalle_mini.backend import ServiceError, get_images_from_backend
+ import streamlit as st
+ # streamlit.session_state is not available in Huggingface spaces.
+ # Session state hack https://huggingface.slack.com/archives/C025LJDP962/p1626527367443200?thread_ts=1626525999.440500&cid=C025LJDP962
+ from streamlit.report_thread import get_report_ctx
+ def query_cache(q_emb=None):
+     ctx = get_report_ctx()
+     session_id = ctx.session_id
+     session = st.server.server.Server.get_current()._get_session_info(session_id).session
+     if not hasattr(session, "_query_state"):
+         setattr(session, "_query_state", q_emb)
+     if q_emb:
+         session._query_state = q_emb
+     return session._query_state
+ def set_run_again(state):
+     query_cache(state)
+ def should_run_again():
+     state = query_cache()
+     return state if state is not None else False
+ st.sidebar.markdown("""
+ <style>
+ .aligncenter {
+     text-align: center;
+ }
+ </style>
+ <p class="aligncenter">
+ <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/img/logo.png"/>
+ </p>
+ """, unsafe_allow_html=True)
+ st.sidebar.markdown("""
+ ___
+ <p style='text-align: center'>
+ DALL·E mini is an AI model that generates images from any prompt you give!
+ </p>
+ <p style='text-align: center'>
+ Created by Boris Dayma et al. 2021
+ <br/>
+ <a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
+ </p>
+ """, unsafe_allow_html=True)
+ st.header('DALL·E mini')
+ st.subheader('Generate images from text')
+ prompt = st.text_input("What do you want to see?")
+ test = st.empty()
+ DEBUG = False
+ if prompt != "" or (should_run_again() and prompt != ""):
+     container = st.empty()
+     # The following mimics `streamlit.info()`.
+     # I tried to get the secondary background color using `components.streamlit.config.get_options_for_section("theme")["secondaryBackgroundColor"]`
+     # but it returns None.
+     container.markdown(f"""
+ <style> p {{ margin:0 }} div {{ margin:0 }} </style>
+ <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
+ <div class="stAlert">
+ <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
+ <div class="st-b7">
+ <div class="css-whx05o e13vu3m50">
+ <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
+ <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/img/loading.gif" width="30"/>
+ Generating predictions for: <b>{prompt}</b>
+ </div>
+ </div>
+ </div>
+ </div>
+ </div>
+ </div>
+ <small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
+ """, unsafe_allow_html=True)
+     try:
+         backend_url = st.secrets["BACKEND_SERVER"]
+         print(f"Getting selections: {prompt}")
+         selected = get_images_from_backend(prompt, backend_url)
+         cols = st.beta_columns(4)
+         for i, img in enumerate(selected):
+             cols[i%4].image(img)
+         container.markdown(f"**{prompt}**")
+
+         set_run_again(st.button('Again!', key='again_button'))
+
+     except ServiceError as error:
+         container.text(f"Service unavailable, status: {error.status_code}")
+     except KeyError:
+         if DEBUG:
+             container.markdown("""
+ **Error: BACKEND_SERVER unset**
+ Please create a file called `.streamlit/secrets.toml` inside the app's folder and include a line to configure the server URL:
+ ```
+ BACKEND_SERVER="<server url>"
+ ```
+ """)
+         else:
+             container.markdown('Error -5, please try again or [report it](mailto:pcuenca-dalle@guenever.net).')
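
A note on the `query_cache` helpers added above: this app targets a pre-1.0 Streamlit where `st.session_state` does not exist, so state that must survive a script rerun (the "Again!" button) is attached to the server-side session object instead. A minimal sketch of the same pattern in isolation, assuming that old `streamlit.report_thread` API is available (the `session_flag` name and `_flag` attribute are illustrative, not part of this commit):

```python
# Minimal sketch of the per-session cache pattern used by query_cache above,
# assuming a pre-1.0 Streamlit that still ships streamlit.report_thread.
# The session_flag name and "_flag" attribute are illustrative only; on
# modern Streamlit releases the equivalent is simply st.session_state.
from streamlit.report_thread import get_report_ctx
from streamlit.server.server import Server

def session_flag(value=None):
    # Locate the server-side session behind the browser tab running this script.
    ctx = get_report_ctx()
    session = Server.get_current()._get_session_info(ctx.session_id).session
    if value is not None:                        # write path (e.g. button state)
        session._flag = value
    return getattr(session, "_flag", False)      # read path after the rerun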
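
The new app also delegates all generation to a separate inference server: `ServiceError` and `get_images_from_backend` are imported from `dalle_mini.backend`, which this diff does not show. A rough sketch of the contract the UI code relies on, assuming a JSON API that returns base64-encoded images (the request shape and the `images` key are assumptions, not part of this commit):

```python
# Hypothetical sketch of dalle_mini/backend.py -- not shown in this commit.
# Assumes the backend accepts a JSON prompt and returns base64-encoded images.
import base64
import io

import requests
from PIL import Image

class ServiceError(Exception):
    def __init__(self, status_code):
        self.status_code = status_code

def get_images_from_backend(prompt, backend_url):
    r = requests.post(backend_url, json={"prompt": prompt})
    if r.status_code != 200:
        # Surface HTTP failures so the app can show "Service unavailable".
        raise ServiceError(r.status_code)
    payload = r.json()
    # Each entry is assumed to be a base64-encoded image; decode to PIL for st.image.
    return [Image.open(io.BytesIO(base64.b64decode(img))) for img in payload["images"]]
```

Under that assumption, any non-200 response surfaces as `ServiceError`, which the `except ServiceError` branch above turns into the "Service unavailable" message.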