cocktailpeanut commited on
Commit
d82dd30
·
1 Parent(s): da5331f
Files changed (3) hide show
  1. app.py +9 -5
  2. evosdxl_jp_v1.py +3 -1
  3. requirements.txt +2 -2
app.py CHANGED
@@ -6,10 +6,11 @@ import uuid
6
 
7
  import gradio as gr
8
  import numpy as np
9
- import spaces
10
  import torch
11
  from PIL import Image
12
  from evosdxl_jp_v1 import load_evosdxl_jp
 
13
 
14
  DESCRIPTION = """# 🐟 EvoSDXL-JP
15
  🤗 [モデル一覧](https://huggingface.co/SakanaAI) | 📚 [技術レポート](https://arxiv.org/abs/2403.13187) | 📝 [ブログ](https://sakana.ai/evosdxl-jp/) | 🐦 [Twitter](https://twitter.com/SakanaAILabs)
@@ -23,12 +24,14 @@ if not torch.cuda.is_available():
23
  MAX_SEED = np.iinfo(np.int32).max
24
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
25
 
26
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
27
 
28
  NUM_IMAGES_PER_PROMPT = 1
29
  ENABLE_CPU_OFFLOAD = False
30
  USE_TORCH_COMPILE = False
31
- SAFETY_CHECKER = True
 
32
  DEVELOP_MODE = True
33
  if SAFETY_CHECKER:
34
  from safety_checker import StableDiffusionSafetyChecker
@@ -53,7 +56,8 @@ if SAFETY_CHECKER:
53
  return images, has_nsfw_concepts
54
 
55
 
56
- pipe = load_evosdxl_jp("cpu").to("cuda")
 
57
 
58
  def show_warning(warning_text: str) -> gr.Blocks:
59
  with gr.Blocks() as demo:
@@ -154,4 +158,4 @@ with gr.Blocks(css=css) as demo:
154
  Sakana AIは、本モデルの使用によって生じた直接的または間接的な損失に対して、結果に関わらず、一切の責任を負いません。
155
  利用者は、本モデルの使用に伴うリスクを十分に理解し、自身の判断で使用することが必要です。""")
156
 
157
- demo.queue().launch()
 
6
 
7
  import gradio as gr
8
  import numpy as np
9
+ #import spaces
10
  import torch
11
  from PIL import Image
12
  from evosdxl_jp_v1 import load_evosdxl_jp
13
+ import devicetorch
14
 
15
  DESCRIPTION = """# 🐟 EvoSDXL-JP
16
  🤗 [モデル一覧](https://huggingface.co/SakanaAI) | 📚 [技術レポート](https://arxiv.org/abs/2403.13187) | 📝 [ブログ](https://sakana.ai/evosdxl-jp/) | 🐦 [Twitter](https://twitter.com/SakanaAILabs)
 
24
  MAX_SEED = np.iinfo(np.int32).max
25
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
26
 
27
+ #device = "cuda" if torch.cuda.is_available() else "cpu"
28
+ device = devicetorch.get(torch)
29
 
30
  NUM_IMAGES_PER_PROMPT = 1
31
  ENABLE_CPU_OFFLOAD = False
32
  USE_TORCH_COMPILE = False
33
+ #SAFETY_CHECKER = True
34
+ SAFETY_CHECKER = False
35
  DEVELOP_MODE = True
36
  if SAFETY_CHECKER:
37
  from safety_checker import StableDiffusionSafetyChecker
 
56
  return images, has_nsfw_concepts
57
 
58
 
59
+ #pipe = load_evosdxl_jp("cpu").to("cuda")
60
+ pipe = load_evosdxl_jp("cpu").to(device)
61
 
62
  def show_warning(warning_text: str) -> gr.Blocks:
63
  with gr.Blocks() as demo:
 
158
  Sakana AIは、本モデルの使用によって生じた直接的または間接的な損失に対して、結果に関わらず、一切の責任を負いません。
159
  利用者は、本モデルの使用に伴うリスクを十分に理解し、自身の判断で使用することが必要です。""")
160
 
161
+ demo.queue().launch()
evosdxl_jp_v1.py CHANGED
@@ -11,6 +11,7 @@ from diffusers import (
11
  EulerDiscreteScheduler,
12
  )
13
  from diffusers.loaders import LoraLoaderMixin
 
14
 
15
  SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
16
  JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
@@ -137,7 +138,8 @@ def load_evosdxl_jp(device="cuda") -> StableDiffusionXLPipeline:
137
  ],
138
  )
139
  del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
140
- torch.cuda.empty_cache()
 
141
  unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
142
  unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
143
  unet.load_state_dict({**new_conv, **new_attn})
 
11
  EulerDiscreteScheduler,
12
  )
13
  from diffusers.loaders import LoraLoaderMixin
14
+ import devicetorch
15
 
16
  SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
17
  JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
 
138
  ],
139
  )
140
  del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
141
+ devicetorch.empty_cache(torch)
142
+ #torch.cuda.empty_cache()
143
  unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
144
  unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
145
  unet.load_state_dict({**new_conv, **new_attn})
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
- torch
2
  diffusers==0.26.0
3
  transformers
4
  safetensors
5
  accelerate
6
- sentencepiece
 
1
+ #torch
2
  diffusers==0.26.0
3
  transformers
4
  safetensors
5
  accelerate
6
+ sentencepiece