awacke1 commited on
Commit
ea5b567
1 Parent(s): 2e8bac6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +215 -0
app.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import random
5
+ import uuid
6
+ import base64
7
+ import gradio as gr
8
+ import numpy as np
9
+ from PIL import Image
10
+ import spaces
11
+ import torch
12
+ import glob
13
+ from datetime import datetime
14
+
15
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
16
+
17
+ DESCRIPTION = """# DALL•E 3 XL v2 High Fi"""
18
+
19
+ def create_download_link(filename):
20
+ with open(filename, "rb") as file:
21
+ encoded_string = base64.b64encode(file.read()).decode('utf-8')
22
+ download_link = f'<a href="data:image/png;base64,{encoded_string}" download="{filename}">Download Image</a>'
23
+ return download_link
24
+
25
+ def save_image(img, prompt):
26
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
27
+ filename = f"{timestamp}_{prompt[:50]}.png" # Limit filename length
28
+ img.save(filename)
29
+ return filename
30
+
31
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
32
+ if randomize_seed:
33
+ seed = random.randint(0, MAX_SEED)
34
+ return seed
35
+
36
+ def get_image_gallery():
37
+ image_files = glob.glob("*.png")
38
+ image_files.sort(key=os.path.getmtime, reverse=True)
39
+ return image_files
40
+
41
+ MAX_SEED = np.iinfo(np.int32).max
42
+
43
+ if not torch.cuda.is_available():
44
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
45
+
46
+ USE_TORCH_COMPILE = 0
47
+ ENABLE_CPU_OFFLOAD = 0
48
+
49
+ if torch.cuda.is_available():
50
+ pipe = StableDiffusionXLPipeline.from_pretrained(
51
+ "fluently/Fluently-XL-v4",
52
+ torch_dtype=torch.float16,
53
+ use_safetensors=True,
54
+ )
55
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
56
+
57
+ pipe.load_lora_weights("ehristoforu/dalle-3-xl-v2", weight_name="dalle-3-xl-lora-v2.safetensors", adapter_name="dalle")
58
+ pipe.set_adapters("dalle")
59
+
60
+ pipe.to("cuda")
61
+
62
+ @spaces.GPU(enable_queue=True)
63
+ def generate(
64
+ prompt: str,
65
+ negative_prompt: str = "",
66
+ use_negative_prompt: bool = False,
67
+ seed: int = 0,
68
+ width: int = 1024,
69
+ height: int = 1024,
70
+ guidance_scale: float = 3,
71
+ randomize_seed: bool = False,
72
+ progress=gr.Progress(track_tqdm=True),
73
+ ):
74
+ seed = int(randomize_seed_fn(seed, randomize_seed))
75
+
76
+ if not use_negative_prompt:
77
+ negative_prompt = ""
78
+
79
+ images = pipe(
80
+ prompt=prompt,
81
+ negative_prompt=negative_prompt,
82
+ width=width,
83
+ height=height,
84
+ guidance_scale=guidance_scale,
85
+ num_inference_steps=20,
86
+ num_images_per_prompt=1,
87
+ cross_attention_kwargs={"scale": 0.65},
88
+ output_type="pil",
89
+ ).images
90
+ image_paths = [save_image(img, prompt) for img in images]
91
+ download_links = [create_download_link(path) for path in image_paths]
92
+
93
+ return image_paths, seed, download_links, get_image_gallery()
94
+
95
+ examples = [
96
+ "An elderly man engages in a virtual reality physical therapy session, guided by a compassionate AI therapist that adapts the exercises to his abilities and provides encouragement, all from the comfort of his own home.",
97
+ "In a bright, welcoming dental office, a young patient watches in awe as a dental robot efficiently and painlessly repairs a cavity using a laser system, while the dentist explains the procedure using interactive 3D images.",
98
+ "A team of biomedical engineers collaborate in a state-of-the-art research facility, designing and testing advanced prosthetic limbs that seamlessly integrate with the patient's nervous system for natural, intuitive control.",
99
+ "A pregnant woman undergoes a routine check-up, as a gentle robotic ultrasound system captures high-resolution 3D images of her developing baby, while the obstetrician provides reassurance and guidance via a holographic display.",
100
+ "In a cutting-edge cancer treatment center, a patient undergoes a precision radiation therapy session, where an AI-guided system delivers highly targeted doses to destroy cancer cells while preserving healthy tissue.",
101
+ "A group of medical students attend a virtual reality lecture, where they can interact with detailed, 3D anatomical models and simulate complex surgical procedures under the guidance of renowned experts from around the world.",
102
+ "In a remote village, a local healthcare worker uses a portable, AI-powered diagnostic device to quickly and accurately assess a patient's symptoms, while seamlessly connecting with specialists in distant cities for expert advice and treatment planning.",
103
+ "At an advanced fertility clinic, a couple watches in wonder as an AI-assisted system carefully selects the most viable embryos for implantation, while providing personalized guidance and emotional support throughout the process."
104
+ ]
105
+
106
+ css = '''
107
+ .gradio-container{max-width: 1024px !important}
108
+ h1{text-align:center}
109
+ footer {
110
+ visibility: hidden
111
+ }
112
+ '''
113
+
114
+ with gr.Blocks(css=css, theme="pseudolab/huggingface-korea-theme") as demo:
115
+ gr.Markdown(DESCRIPTION)
116
+
117
+ with gr.Group():
118
+ with gr.Row():
119
+ prompt = gr.Text(
120
+ label="Prompt",
121
+ show_label=False,
122
+ max_lines=1,
123
+ placeholder="Enter your prompt",
124
+ container=False,
125
+ )
126
+ run_button = gr.Button("Run", scale=0)
127
+ result = gr.Gallery(label="Result", columns=1, preview=True, show_label=False)
128
+ with gr.Accordion("Advanced options", open=False):
129
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
130
+ negative_prompt = gr.Text(
131
+ label="Negative prompt",
132
+ lines=4,
133
+ max_lines=6,
134
+ value="""(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, (NSFW:1.25)""",
135
+ placeholder="Enter a negative prompt",
136
+ visible=True,
137
+ )
138
+ seed = gr.Slider(
139
+ label="Seed",
140
+ minimum=0,
141
+ maximum=MAX_SEED,
142
+ step=1,
143
+ value=0,
144
+ visible=True
145
+ )
146
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
147
+ with gr.Row(visible=True):
148
+ width = gr.Slider(
149
+ label="Width",
150
+ minimum=512,
151
+ maximum=2048,
152
+ step=8,
153
+ value=1920,
154
+ )
155
+ height = gr.Slider(
156
+ label="Height",
157
+ minimum=512,
158
+ maximum=2048,
159
+ step=8,
160
+ value=1080,
161
+ )
162
+ with gr.Row():
163
+ guidance_scale = gr.Slider(
164
+ label="Guidance Scale",
165
+ minimum=0.1,
166
+ maximum=20.0,
167
+ step=0.1,
168
+ value=20.0,
169
+ )
170
+
171
+ image_gallery = gr.Gallery(label="Generated Images", show_label=True, columns=4, height="auto")
172
+
173
+ gr.Examples(
174
+ examples=examples,
175
+ inputs=prompt,
176
+ outputs=[result, seed],
177
+ fn=generate,
178
+ cache_examples=False,
179
+ )
180
+
181
+ use_negative_prompt.change(
182
+ fn=lambda x: gr.update(visible=x),
183
+ inputs=use_negative_prompt,
184
+ outputs=negative_prompt,
185
+ api_name=False,
186
+ )
187
+
188
+ def update_gallery():
189
+ return gr.update(value=get_image_gallery())
190
+
191
+ gr.on(
192
+ triggers=[
193
+ prompt.submit,
194
+ negative_prompt.submit,
195
+ run_button.click,
196
+ ],
197
+ fn=generate,
198
+ inputs=[
199
+ prompt,
200
+ negative_prompt,
201
+ use_negative_prompt,
202
+ seed,
203
+ width,
204
+ height,
205
+ guidance_scale,
206
+ randomize_seed,
207
+ ],
208
+ outputs=[result, seed, gr.HTML(visible=False), image_gallery],
209
+ api_name="run",
210
+ )
211
+
212
+ demo.load(fn=update_gallery, outputs=image_gallery)
213
+
214
+ if __name__ == "__main__":
215
+ demo.queue(max_size=20).launch(show_api=False, debug=False)