Spanicin commited on
Commit
218b3a7
1 Parent(s): 4bf86d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -22
app.py CHANGED
@@ -20,6 +20,9 @@ from elevenlabs import set_api_key, generate, play, clone
20
  from flask_cors import CORS, cross_origin
21
  from flask_swagger_ui import get_swaggerui_blueprint
22
  import uuid
 
 
 
23
 
24
  class AnimationConfig:
25
  def __init__(self, driven_audio_path, source_image_path, result_folder,pose_style,expression_scale,enhancer):
@@ -69,23 +72,22 @@ swagger_ui_blueprint = get_swaggerui_blueprint(
69
  )
70
 
71
  app = Flask(__name__)
 
 
72
  CORS(app)
73
  app.register_blueprint(swagger_ui_blueprint, url_prefix=SWAGGER_URL)
74
 
75
  app.config['temp_response'] = None
76
  app.config['generation_thread'] = None
77
  app.config['text_prompt'] = None
 
78
 
79
- TEMP_DIR = tempfile.TemporaryDirectory()
80
 
81
 
82
  def main(args):
83
  pic_path = args.source_image
84
  audio_path = args.driven_audio
85
  save_dir = args.result_dir
86
- # save_dir = os.path.join(args.result_folder, strftime("%Y_%m_%d_%H.%M.%S"))
87
- # os.makedirs(save_dir, exist_ok=True)
88
- print('save_dir',save_dir)
89
  pose_style = args.pose_style
90
  device = args.device
91
  batch_size = args.batch_size
@@ -100,7 +102,6 @@ def main(args):
100
  print('current_root_path ',current_root_path)
101
 
102
  sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess)
103
- print('sadtalker_paths ',sadtalker_paths)
104
 
105
 
106
 
@@ -144,7 +145,6 @@ def main(args):
144
  print('ref_eyeblink_coeff_path',ref_pose_coeff_path)
145
 
146
  batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
147
- print('batch',batch)
148
  coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
149
 
150
  if args.face3dvis:
@@ -154,19 +154,15 @@ def main(args):
154
  batch_size, input_yaw_list, input_pitch_list, input_roll_list,
155
  expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size)
156
 
157
- print('data ',data)
158
- print('save_dir ', save_dir)
159
- print('pic_path ',pic_path)
160
- print('crop ',crop_info)
161
 
162
- result, base64_video = animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \
163
  enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size)
164
 
165
 
166
  print('The generated video is named:')
167
  app.config['temp_response'] = base64_video
168
-
169
- return base64_video
170
 
171
  # shutil.move(result, save_dir+'.mp4')
172
 
@@ -174,7 +170,10 @@ def main(args):
174
  if not args.verbose:
175
  shutil.rmtree(save_dir)
176
 
177
- def save_uploaded_file(file, filename):
 
 
 
178
  unique_filename = str(uuid.uuid4()) + "_" + filename
179
  file_path = os.path.join(TEMP_DIR.name, unique_filename)
180
  file.save(file_path)
@@ -197,23 +196,28 @@ def translate_text(text, target_language):
197
 
198
  @app.route("/run", methods=['POST'])
199
  def generate_video():
 
 
200
  if request.method == 'POST':
201
  source_image = request.files['source_image']
202
  text_prompt = request.form['text_prompt']
 
203
  voice_cloning = request.form.get('voice_cloning', 'no')
204
- target_language = request.form.get('target_language', None)
 
205
  pose_style = int(request.form.get('pose_style', 1))
206
  expression_scale = int(request.form.get('expression_scale', 1))
207
  enhancer = request.form.get('enhancer', None)
208
  voice_gender = request.form.get('voice_gender', 'male')
209
 
210
- if target_language is not None:
211
  response = translate_text(text_prompt, target_language)
212
  text_prompt = response.choices[0].message.content.strip()
213
- print('text_prompt',text_prompt)
214
 
215
  app.config['text_prompt'] = text_prompt
216
- source_image_path = save_uploaded_file(source_image, 'source_image.png')
 
 
217
  print(source_image_path)
218
 
219
  if voice_cloning == 'no':
@@ -226,7 +230,7 @@ def generate_video():
226
  voice=voice,
227
  input = text_prompt)
228
 
229
- with tempfile.NamedTemporaryFile(suffix=".wav", prefix="text_to_speech_", delete=False) as temp_file:
230
  driven_audio_path = temp_file.name
231
 
232
  response.write_to_file(driven_audio_path)
@@ -234,7 +238,7 @@ def generate_video():
234
  elif voice_cloning == 'yes':
235
  user_voice = request.files['user_voice']
236
 
237
- with tempfile.NamedTemporaryFile(suffix=".wav", prefix="user_voice_", delete=False) as temp_file:
238
  user_voice_path = temp_file.name
239
  user_voice.save(user_voice_path)
240
  print('user_voice_path',user_voice_path)
@@ -244,11 +248,11 @@ def generate_video():
244
  files = [user_voice_path] )
245
 
246
  audio = generate(text = text_prompt, voice = voice, model = "eleven_multilingual_v2")
247
- with tempfile.NamedTemporaryFile(suffix=".mp3", prefix="cloned_audio_", delete=False) as temp_file:
248
  driven_audio_path = temp_file.name
249
  elevenlabs.save(audio, driven_audio_path)
250
 
251
- save_dir = tempfile.mkdtemp()
252
  result_folder = os.path.join(save_dir, "results")
253
  os.makedirs(result_folder, exist_ok=True)
254
 
@@ -275,6 +279,7 @@ def generate_video():
275
 
276
  @app.route("/status", methods=["GET"])
277
  def check_generation_status():
 
278
  response = {"base64_video": "","text_prompt":"", "status": ""}
279
  process_id = request.args.get('process_id', None)
280
 
@@ -289,6 +294,26 @@ def check_generation_status():
289
  response["base64_video"] = final_response
290
  response["text_prompt"] = app.config.get('text_prompt')
291
  response["status"] = "completed"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  return jsonify(response)
293
  return jsonify({"error":"No process id provided"})
294
 
 
20
  from flask_cors import CORS, cross_origin
21
  from flask_swagger_ui import get_swaggerui_blueprint
22
  import uuid
23
+ import time
24
+
25
+ start_time = time.time()
26
 
27
  class AnimationConfig:
28
  def __init__(self, driven_audio_path, source_image_path, result_folder,pose_style,expression_scale,enhancer):
 
72
  )
73
 
74
  app = Flask(__name__)
75
+
76
+ TEMP_DIR = None
77
  CORS(app)
78
  app.register_blueprint(swagger_ui_blueprint, url_prefix=SWAGGER_URL)
79
 
80
  app.config['temp_response'] = None
81
  app.config['generation_thread'] = None
82
  app.config['text_prompt'] = None
83
+ app.config['final_video_path'] = None
84
 
 
85
 
86
 
87
  def main(args):
88
  pic_path = args.source_image
89
  audio_path = args.driven_audio
90
  save_dir = args.result_dir
 
 
 
91
  pose_style = args.pose_style
92
  device = args.device
93
  batch_size = args.batch_size
 
102
  print('current_root_path ',current_root_path)
103
 
104
  sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess)
 
105
 
106
 
107
 
 
145
  print('ref_eyeblink_coeff_path',ref_pose_coeff_path)
146
 
147
  batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
 
148
  coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
149
 
150
  if args.face3dvis:
 
154
  batch_size, input_yaw_list, input_pitch_list, input_roll_list,
155
  expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size)
156
 
 
 
 
 
157
 
158
+ result, base64_video,temp_file_path= animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \
159
  enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size)
160
 
161
 
162
  print('The generated video is named:')
163
  app.config['temp_response'] = base64_video
164
+ app.config['final_video_path'] = temp_file_path
165
+ return base64_video, temp_file_path
166
 
167
  # shutil.move(result, save_dir+'.mp4')
168
 
 
170
  if not args.verbose:
171
  shutil.rmtree(save_dir)
172
 
173
+ def create_temp_dir():
174
+ return tempfile.TemporaryDirectory()
175
+
176
+ def save_uploaded_file(file, filename,TEMP_DIR):
177
  unique_filename = str(uuid.uuid4()) + "_" + filename
178
  file_path = os.path.join(TEMP_DIR.name, unique_filename)
179
  file.save(file_path)
 
196
 
197
  @app.route("/run", methods=['POST'])
198
  def generate_video():
199
+ global TEMP_DIR
200
+ TEMP_DIR = create_temp_dir()
201
  if request.method == 'POST':
202
  source_image = request.files['source_image']
203
  text_prompt = request.form['text_prompt']
204
+ print('Input text prompt: ',text_prompt)
205
  voice_cloning = request.form.get('voice_cloning', 'no')
206
+ target_language = request.form.get('target_language', 'original_text')
207
+ print('target_language',target_language)
208
  pose_style = int(request.form.get('pose_style', 1))
209
  expression_scale = int(request.form.get('expression_scale', 1))
210
  enhancer = request.form.get('enhancer', None)
211
  voice_gender = request.form.get('voice_gender', 'male')
212
 
213
+ if target_language != 'original_text':
214
  response = translate_text(text_prompt, target_language)
215
  text_prompt = response.choices[0].message.content.strip()
 
216
 
217
  app.config['text_prompt'] = text_prompt
218
+ print('Final text prompt: ',text_prompt)
219
+
220
+ source_image_path = save_uploaded_file(source_image, 'source_image.png',TEMP_DIR)
221
  print(source_image_path)
222
 
223
  if voice_cloning == 'no':
 
230
  voice=voice,
231
  input = text_prompt)
232
 
233
+ with tempfile.NamedTemporaryFile(suffix=".wav", prefix="text_to_speech_",dir=TEMP_DIR.name, delete=False) as temp_file:
234
  driven_audio_path = temp_file.name
235
 
236
  response.write_to_file(driven_audio_path)
 
238
  elif voice_cloning == 'yes':
239
  user_voice = request.files['user_voice']
240
 
241
+ with tempfile.NamedTemporaryFile(suffix=".wav", prefix="user_voice_",dir=TEMP_DIR.name, delete=False) as temp_file:
242
  user_voice_path = temp_file.name
243
  user_voice.save(user_voice_path)
244
  print('user_voice_path',user_voice_path)
 
248
  files = [user_voice_path] )
249
 
250
  audio = generate(text = text_prompt, voice = voice, model = "eleven_multilingual_v2")
251
+ with tempfile.NamedTemporaryFile(suffix=".mp3", prefix="cloned_audio_",dir=TEMP_DIR.name, delete=False) as temp_file:
252
  driven_audio_path = temp_file.name
253
  elevenlabs.save(audio, driven_audio_path)
254
 
255
+ save_dir = tempfile.mkdtemp(dir=TEMP_DIR.name)
256
  result_folder = os.path.join(save_dir, "results")
257
  os.makedirs(result_folder, exist_ok=True)
258
 
 
279
 
280
  @app.route("/status", methods=["GET"])
281
  def check_generation_status():
282
+ global TEMP_DIR
283
  response = {"base64_video": "","text_prompt":"", "status": ""}
284
  process_id = request.args.get('process_id', None)
285
 
 
294
  response["base64_video"] = final_response
295
  response["text_prompt"] = app.config.get('text_prompt')
296
  response["status"] = "completed"
297
+
298
+ final_video_path = app.config['final_video_path']
299
+ print('final_video_path',final_video_path)
300
+
301
+
302
+ if final_video_path and os.path.exists(final_video_path):
303
+ os.remove(final_video_path)
304
+ print("Deleted video file:", final_video_path)
305
+
306
+ TEMP_DIR.cleanup()
307
+ # print("Temporary Directory:", TEMP_DIR.name)
308
+ # if TEMP_DIR:
309
+ # print("Contents of Temporary Directory:")
310
+ # for filename in os.listdir(TEMP_DIR.name):
311
+ # print(filename)
312
+ # else:
313
+ # print("Temporary Directory is None or already cleaned up.")
314
+ end_time = time.time()
315
+ total_time = round(end_time - start_time, 2)
316
+ print("Total time taken for execution:", total_time, " seconds")
317
  return jsonify(response)
318
  return jsonify({"error":"No process id provided"})
319