frankleeeee commited on
Commit
a097e62
1 Parent(s): 348ea80
Files changed (1) hide show
  1. app.py +28 -7
app.py CHANGED
@@ -14,6 +14,7 @@ import sys
14
  import spaces
15
  import gradio as gr
16
  import torch
 
17
 
18
 
19
 
@@ -29,7 +30,7 @@ HF_STDIT_MAP = {
29
  "v1-HQ-16x512x512": "hpcai-tech/OpenSora-STDiT-v1-HQ-16x512x512",
30
  }
31
 
32
- def install_dependencies():
33
  """
34
  Install the required dependencies for the demo if they are not already installed.
35
  """
@@ -41,7 +42,9 @@ def install_dependencies():
41
  except (ImportError, ModuleNotFoundError):
42
  return False
43
 
44
- # install flash attention
 
 
45
  if not _is_package_available("flash_attn"):
46
  subprocess.run(
47
  f"{sys.executable} -m pip install flash-attn --no-build-isolation",
@@ -49,6 +52,25 @@ def install_dependencies():
49
  shell=True,
50
  )
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def read_config(config_path):
53
  """
54
  Read the configuration file.
@@ -114,6 +136,7 @@ def parse_args():
114
  parser.add_argument("--port", default=None, type=int, help="The port to run the Gradio App on.")
115
  parser.add_argument("--host", default=None, type=str, help="The host to run the Gradio App on.")
116
  parser.add_argument("--share", action="store_true", help="Whether to share this gradio demo.")
 
117
  return parser.parse_args()
118
 
119
 
@@ -128,11 +151,11 @@ config = read_config(CONFIG_MAP[args.model_type])
128
  os.makedirs(args.output, exist_ok=True)
129
 
130
  # disable torch jit as it can cause failure in gradio SDK
131
- # since gradio sdk uses torch with cuda 11.3
132
  torch.jit._state.disable()
133
 
134
  # set up
135
- install_dependencies()
136
 
137
  # build model
138
  vae, text_encoder, stdit, scheduler = build_models(args.model_type, config)
@@ -141,7 +164,6 @@ vae, text_encoder, stdit, scheduler = build_models(args.model_type, config)
141
  def run_inference(prompt_text):
142
  latent_size = get_latent_size(config, vae)
143
 
144
- from opensora.datasets import save_sample
145
  samples = scheduler.sample(
146
  stdit,
147
  text_encoder,
@@ -204,6 +226,5 @@ with gr.Blocks() as demo:
204
  )
205
 
206
  # launch
207
- # demo.launch(server_port=args.port, server_name=args.host, share=args.share)
208
- demo.launch()
209
 
 
14
  import spaces
15
  import gradio as gr
16
  import torch
17
+ from opensora.datasets import save_sample
18
 
19
 
20
 
 
30
  "v1-HQ-16x512x512": "hpcai-tech/OpenSora-STDiT-v1-HQ-16x512x512",
31
  }
32
 
33
+ def install_dependencies(enable_optimization=False):
34
  """
35
  Install the required dependencies for the demo if they are not already installed.
36
  """
 
42
  except (ImportError, ModuleNotFoundError):
43
  return False
44
 
45
+ # flash attention is needed no matter whether optimization is enabled or not
46
+ # because Hugging Face transformers detects that flash_attn is a dependency in STDiT
47
+ # thus, we need to install it no matter what
48
  if not _is_package_available("flash_attn"):
49
  subprocess.run(
50
  f"{sys.executable} -m pip install flash-attn --no-build-isolation",
 
52
  shell=True,
53
  )
54
 
55
+ if enable_optimization:
56
+ # install apex
57
+ if not _is_package_available("apex"):
58
+ subprocess.run(
59
+ f'{sys.executable} -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" git+https://github.com/NVIDIA/apex.git',
60
+ shell=True,
61
+ )
62
+
63
+ # install ninja
64
+ if not _is_package_available("ninja"):
65
+ subprocess.run(f"{sys.executable} -m pip install ninja", shell=True)
66
+
67
+ # install xformers
68
+ if not _is_package_available("xformers"):
69
+ subprocess.run(
70
+ f"{sys.executable} -m pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers",
71
+ shell=True,
72
+ )
73
+
74
  def read_config(config_path):
75
  """
76
  Read the configuration file.
 
136
  parser.add_argument("--port", default=None, type=int, help="The port to run the Gradio App on.")
137
  parser.add_argument("--host", default=None, type=str, help="The host to run the Gradio App on.")
138
  parser.add_argument("--share", action="store_true", help="Whether to share this gradio demo.")
139
+ parser.add_argument("--enable-optimization", action="store_true", help="Whether to enable optimization such as flash attention and fused layernorm")
140
  return parser.parse_args()
141
 
142
 
 
151
  os.makedirs(args.output, exist_ok=True)
152
 
153
  # disable torch jit as it can cause failure in gradio SDK
154
+ # gradio sdk uses torch with cuda 11.3
155
  torch.jit._state.disable()
156
 
157
  # set up
158
+ install_dependencies(enable_optimization=args.enable_optimization)
159
 
160
  # build model
161
  vae, text_encoder, stdit, scheduler = build_models(args.model_type, config)
 
164
  def run_inference(prompt_text):
165
  latent_size = get_latent_size(config, vae)
166
 
 
167
  samples = scheduler.sample(
168
  stdit,
169
  text_encoder,
 
226
  )
227
 
228
  # launch
229
+ demo.launch(server_port=args.port, server_name=args.host, share=args.share)
 
230