update arxiv
Browse files
README.md
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
---
|
2 |
title: Diffusion Cocktail
|
3 |
emoji: 🍸
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
python: 3.
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: Diffusion Cocktail
|
3 |
emoji: 🍸
|
4 |
+
colorFrom: orange
|
5 |
+
colorTo: pink
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.7.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
python: 3.9.17
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
@@ -6,6 +6,7 @@ import torch
|
|
6 |
import torchvision.transforms as T
|
7 |
|
8 |
from clip_interrogator import Config, Interrogator
|
|
|
9 |
|
10 |
from ditail import DitailDemo, seed_everything
|
11 |
|
@@ -74,6 +75,9 @@ class WebApp():
|
|
74 |
gtag('config', '{self.gtag}');
|
75 |
}}
|
76 |
"""
|
|
|
|
|
|
|
77 |
|
78 |
self.debug_mode = debug_mode # turn off clip interrogator when debugging for faster building speed
|
79 |
if not self.debug_mode:
|
@@ -81,7 +85,6 @@ class WebApp():
|
|
81 |
|
82 |
|
83 |
def init_interrogator(self):
|
84 |
-
# init clip interrogator
|
85 |
config = Config()
|
86 |
config.clip_model_name = self.args_base['clip_model_name']
|
87 |
config.caption_model_name = self.args_base['caption_model_name']
|
@@ -89,16 +92,25 @@ class WebApp():
|
|
89 |
self.ci.config.chunk_size = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
|
90 |
self.ci.config.flavor_intermediate_count = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
def title(self):
|
93 |
gr.HTML(
|
94 |
"""
|
95 |
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
96 |
<div>
|
97 |
<h1 >Diffusion Cocktail 🍸: Fused Generation from Diffusion Models</h1>
|
98 |
-
<div style="display: flex; justify-content: center; align-items: center; text-align: center; margin: 20px; gap: 10px
|
99 |
-
<a class="flex-item" href="https://arxiv.org/abs/
|
100 |
<img src="https://img.shields.io/badge/arXiv-paper-darkred.svg" alt="arXiv Paper">
|
101 |
-
</a>
|
102 |
<a class="flex-item" href="https://MAPS-research.github.io/Ditail" target="_blank">
|
103 |
<img src="https://img.shields.io/badge/Project_Page-Diffusion_Cocktail-yellow.svg" alt="Project Page">
|
104 |
</a>
|
|
|
6 |
import torchvision.transforms as T
|
7 |
|
8 |
from clip_interrogator import Config, Interrogator
|
9 |
+
from diffusers import StableDiffusionPipeline
|
10 |
|
11 |
from ditail import DitailDemo, seed_everything
|
12 |
|
|
|
75 |
gtag('config', '{self.gtag}');
|
76 |
}}
|
77 |
"""
|
78 |
+
|
79 |
+
# pre-download base model for better user experience
|
80 |
+
self._preload_pipeline()
|
81 |
|
82 |
self.debug_mode = debug_mode # turn off clip interrogator when debugging for faster building speed
|
83 |
if not self.debug_mode:
|
|
|
85 |
|
86 |
|
87 |
def init_interrogator(self):
|
|
|
88 |
config = Config()
|
89 |
config.clip_model_name = self.args_base['clip_model_name']
|
90 |
config.caption_model_name = self.args_base['caption_model_name']
|
|
|
92 |
self.ci.config.chunk_size = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
|
93 |
self.ci.config.flavor_intermediate_count = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
|
94 |
|
95 |
+
|
96 |
+
def _preload_pipeline(self):
|
97 |
+
for model in BASE_MODEL.values():
|
98 |
+
pipe = StableDiffusionPipeline.from_pretrained(
|
99 |
+
model, torch_dtype=torch.float16
|
100 |
+
).to(self.args_base['device'])
|
101 |
+
pipe = None
|
102 |
+
|
103 |
+
|
104 |
def title(self):
|
105 |
gr.HTML(
|
106 |
"""
|
107 |
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
108 |
<div>
|
109 |
<h1 >Diffusion Cocktail 🍸: Fused Generation from Diffusion Models</h1>
|
110 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center; margin: 20px; gap: 10px;">
|
111 |
+
<a class="flex-item" href="https://arxiv.org/abs/2312.08873" target="_blank">
|
112 |
<img src="https://img.shields.io/badge/arXiv-paper-darkred.svg" alt="arXiv Paper">
|
113 |
+
</a>
|
114 |
<a class="flex-item" href="https://MAPS-research.github.io/Ditail" target="_blank">
|
115 |
<img src="https://img.shields.io/badge/Project_Page-Diffusion_Cocktail-yellow.svg" alt="Project Page">
|
116 |
</a>
|