ginipick committed
Commit 5d47f79 · verified · 1 Parent(s): 6c32331

Update app.py

Files changed (1)
  1. app.py +85 -57
app.py CHANGED
@@ -143,20 +143,38 @@ def optimize_gpu_settings():
 
 def install_flash_attn():
     try:
-        logging.info("Installing flash-attn...")
-        subprocess.run(
-            ["pip", "install", "flash-attn", "--no-build-isolation"],
-            check=True,
-            capture_output=True
-        )
+        if not torch.cuda.is_available():
+            logging.warning("GPU not available, skipping flash-attn installation")
+            return False
+
+        cuda_version = torch.version.cuda
+        if cuda_version is None:
+            logging.warning("CUDA not available, skipping flash-attn installation")
+            return False
+
+        logging.info(f"Detected CUDA version: {cuda_version}")
+
+        # CUDA 11.8 specific wheel for Python 3.10
+        if cuda_version.startswith("11.8"):
+            flash_attn_url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu11torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
+            subprocess.run(
+                ["pip", "install", flash_attn_url],
+                check=True,
+                capture_output=True
+            )
+        else:
+            logging.warning(f"Unsupported CUDA version: {cuda_version}, skipping flash-attn installation")
+            return False
+
         logging.info("flash-attn installed successfully!")
-    except subprocess.CalledProcessError as e:
-        logging.error(f"Failed to install flash-attn: {e}")
-        raise
+        return True
+    except Exception as e:
+        logging.warning(f"Failed to install flash-attn: {e}")
+        return False
 
 def initialize_system():
     optimize_gpu_settings()
-    install_flash_attn()
+    has_flash_attn = install_flash_attn()
 
     from huggingface_hub import snapshot_download
 
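Note: the wheel pinned in this hunk only imports when the runtime matches what its filename encodes (CPython 3.10, torch 2.2, CUDA 11). A minimal sketch of that precondition, using a hypothetical helper that is not part of app.py:

    import sys
    import torch

    def wheel_matches_runtime() -> bool:
        # cu11torch2.2cxx11abiFALSE-cp310 in the wheel name: pip itself
        # rejects a cp310 wheel on other interpreters, while a torch or
        # CUDA mismatch installs cleanly and only fails at import time.
        py_ok = sys.version_info[:2] == (3, 10)
        torch_ok = torch.__version__.startswith("2.2")
        cuda_ok = (torch.version.cuda or "").startswith("11")
        return py_ok and torch_ok and cuda_ok

So install_flash_attn() returning True does not by itself guarantee a usable import, which is why infer() below re-probes with a bare `import flash_attn`.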
@@ -177,42 +195,6 @@ def initialize_system():
         logging.error(f"Directory error: {e}")
         raise
 
-@lru_cache(maxsize=100)
-def get_cached_file_path(content_hash, prefix):
-    return create_temp_file(content_hash, prefix)
-
-def empty_output_folder(output_dir):
-    try:
-        shutil.rmtree(output_dir)
-        os.makedirs(output_dir)
-        logging.info(f"Output folder cleaned: {output_dir}")
-    except Exception as e:
-        logging.error(f"Error cleaning output folder: {e}")
-        raise
-
-def create_temp_file(content, prefix, suffix=".txt"):
-    temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix)
-    content = content.strip() + "\n\n"
-    content = content.replace("\r\n", "\n").replace("\r", "\n")
-    temp_file.write(content)
-    temp_file.close()
-    logging.debug(f"Temporary file created: {temp_file.name}")
-    return temp_file.name
-
-def get_last_mp3_file(output_dir):
-    mp3_files = [f for f in os.listdir(output_dir) if f.endswith('.mp3')]
-    if not mp3_files:
-        logging.warning("No MP3 files found")
-        return None
-
-    mp3_files_with_path = [os.path.join(output_dir, f) for f in mp3_files]
-    mp3_files_with_path.sort(key=os.path.getmtime, reverse=True)
-    return mp3_files_with_path[0]
-
-
-
-
-
 def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
     try:
         # Model selection and configuration
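The helpers removed in this hunk are not dropped: the final hunk re-adds them verbatim below infer(). Defining them after their caller is safe because Python resolves names when a function is called, not when it is defined, as this toy example shows:

    def caller():
        # helper is looked up when caller() runs, so it may be defined
        # later in the module, just as infer() now precedes
        # empty_output_folder() in the new layout.
        return helper()

    def helper():
        return "ok"

    print(caller())  # -> ok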
@@ -234,7 +216,7 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
         os.makedirs(output_dir, exist_ok=True)
         empty_output_folder(output_dir)
 
-        # Build the command
+        # Build the base command
         command = [
             "python", "infer.py",
             "--stage1_model", model_path,
@@ -247,21 +229,31 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
             "--cuda_idx", "0",
             "--max_new_tokens", str(actual_max_tokens),
             "--temperature", str(config['temperature']),
-            "--disable_offload_model",
-            "--use_flash_attention_2",
-            "--bf16",
             "--chorus_strength", str(config['chorus_strength'])
         ]
 
+        # Add GPU-only options when a GPU is present and flash-attn is installed
+        if torch.cuda.is_available():
+            command.extend([
+                "--disable_offload_model",
+                "--bf16"
+            ])
+            try:
+                import flash_attn
+                command.append("--use_flash_attention_2")
+            except ImportError:
+                logging.info("flash-attn not available, skipping flash attention option")
+
         # Set CUDA environment variables
         env = os.environ.copy()
-        env.update({
-            "CUDA_VISIBLE_DEVICES": "0",
-            "CUDA_HOME": "/usr/local/cuda",
-            "PATH": f"/usr/local/cuda/bin:{env.get('PATH', '')}",
-            "LD_LIBRARY_PATH": f"/usr/local/cuda/lib64:{env.get('LD_LIBRARY_PATH', '')}",
-            "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512"
-        })
+        if torch.cuda.is_available():
+            env.update({
+                "CUDA_VISIBLE_DEVICES": "0",
+                "CUDA_HOME": "/usr/local/cuda",
+                "PATH": f"/usr/local/cuda/bin:{env.get('PATH', '')}",
+                "LD_LIBRARY_PATH": f"/usr/local/cuda/lib64:{env.get('LD_LIBRARY_PATH', '')}",
+                "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512"
+            })
 
         # Run the inference command
         process = subprocess.run(command, env=env, check=True, capture_output=True)
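With the gating above, a CPU-only machine keeps just the base command. A small sketch of how that branch can be exercised without a GPU; gpu_flags is a hypothetical stand-in that mirrors the logic in infer():

    from unittest import mock

    import torch

    def gpu_flags():
        flags = []
        if torch.cuda.is_available():
            flags += ["--disable_offload_model", "--bf16"]
            try:
                import flash_attn  # probe only; noqa: F401
                flags.append("--use_flash_attention_2")
            except ImportError:
                pass
        return flags

    with mock.patch("torch.cuda.is_available", return_value=False):
        assert gpu_flags() == []  # CPU path: no GPU-only flags are appended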
@@ -288,6 +280,42 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
         except Exception as e:
             logging.warning(f"Failed to remove temporary file {file}: {e}")
 
+@lru_cache(maxsize=100)
+def get_cached_file_path(content_hash, prefix):
+    return create_temp_file(content_hash, prefix)
+
+def empty_output_folder(output_dir):
+    try:
+        shutil.rmtree(output_dir)
+        os.makedirs(output_dir)
+        logging.info(f"Output folder cleaned: {output_dir}")
+    except Exception as e:
+        logging.error(f"Error cleaning output folder: {e}")
+        raise
+
+def create_temp_file(content, prefix, suffix=".txt"):
+    temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix)
+    content = content.strip() + "\n\n"
+    content = content.replace("\r\n", "\n").replace("\r", "\n")
+    temp_file.write(content)
+    temp_file.close()
+    logging.debug(f"Temporary file created: {temp_file.name}")
+    return temp_file.name
+
+def get_last_mp3_file(output_dir):
+    mp3_files = [f for f in os.listdir(output_dir) if f.endswith('.mp3')]
+    if not mp3_files:
+        logging.warning("No MP3 files found")
+        return None
+
+    mp3_files_with_path = [os.path.join(output_dir, f) for f in mp3_files]
+    mp3_files_with_path.sort(key=os.path.getmtime, reverse=True)
+    return mp3_files_with_path[0]
+
+
+
+
+
 # Gradio interface
 with gr.Blocks() as demo:
     with gr.Column():
 
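One caveat about the re-added @lru_cache wrapper: it keys on content_hash, but infer() removes temporary files after each run, so a cache hit can return a path whose file has already been deleted. A self-contained illustration (names are hypothetical):

    import os
    import tempfile
    from functools import lru_cache

    @lru_cache(maxsize=100)
    def cached_path(content: str) -> str:
        # mimics get_cached_file_path -> create_temp_file
        f = tempfile.NamedTemporaryFile(delete=False, mode="w", suffix=".txt")
        f.write(content)
        f.close()
        return f.name

    p = cached_path("genre: pop")
    os.remove(p)                            # cleanup elsewhere deletes the file
    assert cached_path("genre: pop") == p   # cache still hands back the old path
    assert not os.path.exists(p)            # ...which now points at nothing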