Update app.py
Browse files
app.py
CHANGED
@@ -22,25 +22,25 @@ from flask_swagger_ui import get_swaggerui_blueprint
|
|
22 |
import uuid
|
23 |
|
24 |
class AnimationConfig:
|
25 |
-
def __init__(self, driven_audio_path, source_image_path, result_folder):
|
26 |
self.driven_audio = driven_audio_path
|
27 |
self.source_image = source_image_path
|
28 |
self.ref_eyeblink = None
|
29 |
self.ref_pose = None
|
30 |
self.checkpoint_dir = './checkpoints'
|
31 |
self.result_dir = result_folder
|
32 |
-
self.pose_style =
|
33 |
-
self.batch_size =
|
34 |
self.size = 256
|
35 |
-
self.expression_scale =
|
36 |
self.input_yaw = None
|
37 |
self.input_pitch = None
|
38 |
self.input_roll = None
|
39 |
-
self.enhancer =
|
40 |
self.background_enhancer = None
|
41 |
self.cpu = False
|
42 |
self.face3dvis = False
|
43 |
-
self.still = False
|
44 |
self.preprocess = 'crop'
|
45 |
self.verbose = False
|
46 |
self.old_version = False
|
@@ -74,6 +74,7 @@ app.register_blueprint(swagger_ui_blueprint, url_prefix=SWAGGER_URL)
|
|
74 |
|
75 |
app.config['temp_response'] = None
|
76 |
app.config['generation_thread'] = None
|
|
|
77 |
|
78 |
TEMP_DIR = tempfile.TemporaryDirectory()
|
79 |
|
@@ -186,32 +187,43 @@ def translate_text(text, target_language):
|
|
186 |
model="gpt-4-0125-preview",
|
187 |
messages=[
|
188 |
{"role": "system", "content": "You are a helpful assistant."},
|
189 |
-
{"role": "user", "content": f"Translate the following text into {target_language}
|
190 |
],
|
191 |
max_tokens=len(text),
|
192 |
temperature=0.3,
|
193 |
)
|
194 |
return response
|
195 |
|
|
|
196 |
@app.route("/run", methods=['POST'])
|
197 |
def generate_video():
|
198 |
if request.method == 'POST':
|
199 |
source_image = request.files['source_image']
|
200 |
text_prompt = request.form['text_prompt']
|
201 |
voice_cloning = request.form.get('voice_cloning', 'no')
|
202 |
-
target_language = request.form.get('target_language',
|
|
|
|
|
|
|
|
|
203 |
|
204 |
-
if target_language
|
205 |
response = translate_text(text_prompt, target_language)
|
206 |
text_prompt = response.choices[0].message.content.strip()
|
207 |
print('text_prompt',text_prompt)
|
208 |
|
|
|
209 |
source_image_path = save_uploaded_file(source_image, 'source_image.png')
|
210 |
print(source_image_path)
|
211 |
|
212 |
if voice_cloning == 'no':
|
|
|
|
|
|
|
|
|
|
|
213 |
response = client.audio.speech.create(model="tts-1-hd",
|
214 |
-
voice=
|
215 |
input = text_prompt)
|
216 |
|
217 |
with tempfile.NamedTemporaryFile(suffix=".wav", prefix="text_to_speech_", delete=False) as temp_file:
|
@@ -241,7 +253,7 @@ def generate_video():
|
|
241 |
os.makedirs(result_folder, exist_ok=True)
|
242 |
|
243 |
# Example of using the class with some hypothetical paths
|
244 |
-
args = AnimationConfig(driven_audio_path=driven_audio_path, source_image_path=source_image_path, result_folder=result_folder)
|
245 |
|
246 |
if torch.cuda.is_available() and not args.cpu:
|
247 |
args.device = "cuda"
|
@@ -263,7 +275,7 @@ def generate_video():
|
|
263 |
|
264 |
@app.route("/status", methods=["GET"])
|
265 |
def check_generation_status():
|
266 |
-
response = {"base64_video": "", "status": ""}
|
267 |
process_id = request.args.get('process_id', None)
|
268 |
|
269 |
# process_id is required to check the status for that specific process
|
@@ -275,6 +287,7 @@ def check_generation_status():
|
|
275 |
# app.config['temp_response']['status'] = 'completed'
|
276 |
final_response = app.config['temp_response']
|
277 |
response["base64_video"] = final_response
|
|
|
278 |
response["status"] = "completed"
|
279 |
return jsonify(response)
|
280 |
return jsonify({"error":"No process id provided"})
|
|
|
22 |
import uuid
|
23 |
|
24 |
class AnimationConfig:
    """Argparse-style settings container for the talking-head generation run.

    Bundles the input/output paths and tuning knobs that the downstream
    animation pipeline reads as attributes (it is consumed like an
    ``argparse.Namespace``). Previously hard-coded values (batch size,
    render size, still mode, preprocess mode) are now keyword parameters
    with the same defaults, so existing callers are unaffected.
    """

    def __init__(self, driven_audio_path, source_image_path, result_folder,
                 pose_style, expression_scale, enhancer,
                 batch_size=2, size=256, still=False, preprocess='crop'):
        """Build a config.

        Args:
            driven_audio_path: Path to the audio file driving the animation.
            source_image_path: Path to the source face image.
            result_folder: Directory where results are written.
            pose_style: Integer pose-style index for the generator.
            expression_scale: Multiplier applied to facial expression intensity.
            enhancer: Name of the face enhancer to use, or None to disable.
            batch_size: Frames per inference batch (default 2).
            size: Output render size in pixels (default 256).
            still: If True, reduce head motion (default False).
            preprocess: Image preprocessing mode (default 'crop').
        """
        # Required inputs / outputs.
        self.driven_audio = driven_audio_path
        self.source_image = source_image_path
        self.result_dir = result_folder

        # Optional reference clips (not used by this app's callers).
        self.ref_eyeblink = None
        self.ref_pose = None

        # Model checkpoints are expected in a fixed local directory.
        self.checkpoint_dir = './checkpoints'

        # Generation tuning knobs.
        self.pose_style = pose_style
        self.batch_size = batch_size
        self.size = size
        self.expression_scale = expression_scale

        # Explicit head-rotation overrides; None lets the model decide.
        self.input_yaw = None
        self.input_pitch = None
        self.input_roll = None

        # Post-processing enhancers; background enhancement is disabled here.
        self.enhancer = enhancer
        self.background_enhancer = None

        # Execution / debugging flags. The caller flips device selection
        # based on torch.cuda.is_available() and this cpu flag.
        self.cpu = False
        self.face3dvis = False
        self.still = still
        self.preprocess = preprocess
        self.verbose = False
        self.old_version = False
|
|
|
74 |
|
75 |
app.config['temp_response'] = None
|
76 |
app.config['generation_thread'] = None
|
77 |
+
app.config['text_prompt'] = None
|
78 |
|
79 |
TEMP_DIR = tempfile.TemporaryDirectory()
|
80 |
|
|
|
187 |
model="gpt-4-0125-preview",
|
188 |
messages=[
|
189 |
{"role": "system", "content": "You are a helpful assistant."},
|
190 |
+
{"role": "user", "content": f"Translate and give just the following text into {target_language} as response: {text}\n"},
|
191 |
],
|
192 |
max_tokens=len(text),
|
193 |
temperature=0.3,
|
194 |
)
|
195 |
return response
|
196 |
|
197 |
+
|
198 |
@app.route("/run", methods=['POST'])
|
199 |
def generate_video():
|
200 |
if request.method == 'POST':
|
201 |
source_image = request.files['source_image']
|
202 |
text_prompt = request.form['text_prompt']
|
203 |
voice_cloning = request.form.get('voice_cloning', 'no')
|
204 |
+
target_language = request.form.get('target_language', None)
|
205 |
+
pose_style = int(request.form.get('pose_style', 1))
|
206 |
+
expression_scale = int(request.form.get('expression_scale', 1))
|
207 |
+
enhancer = request.form.get('enhancer', None)
|
208 |
+
voice_gender = request.form.get('voice_gender', 'male')
|
209 |
|
210 |
+
if target_language is not None:
|
211 |
response = translate_text(text_prompt, target_language)
|
212 |
text_prompt = response.choices[0].message.content.strip()
|
213 |
print('text_prompt',text_prompt)
|
214 |
|
215 |
+
app.config['text_prompt'] = text_prompt
|
216 |
source_image_path = save_uploaded_file(source_image, 'source_image.png')
|
217 |
print(source_image_path)
|
218 |
|
219 |
if voice_cloning == 'no':
|
220 |
+
if voice_gender == 'male':
|
221 |
+
voice = 'onyx'
|
222 |
+
else:
|
223 |
+
voice = 'nova'
|
224 |
+
|
225 |
response = client.audio.speech.create(model="tts-1-hd",
|
226 |
+
voice=voice,
|
227 |
input = text_prompt)
|
228 |
|
229 |
with tempfile.NamedTemporaryFile(suffix=".wav", prefix="text_to_speech_", delete=False) as temp_file:
|
|
|
253 |
os.makedirs(result_folder, exist_ok=True)
|
254 |
|
255 |
# Example of using the class with some hypothetical paths
|
256 |
+
args = AnimationConfig(driven_audio_path=driven_audio_path, source_image_path=source_image_path, result_folder=result_folder, pose_style=pose_style, expression_scale=expression_scale, enhancer=enhancer)
|
257 |
|
258 |
if torch.cuda.is_available() and not args.cpu:
|
259 |
args.device = "cuda"
|
|
|
275 |
|
276 |
@app.route("/status", methods=["GET"])
|
277 |
def check_generation_status():
|
278 |
+
response = {"base64_video": "","text_prompt":"", "status": ""}
|
279 |
process_id = request.args.get('process_id', None)
|
280 |
|
281 |
# process_id is required to check the status for that specific process
|
|
|
287 |
# app.config['temp_response']['status'] = 'completed'
|
288 |
final_response = app.config['temp_response']
|
289 |
response["base64_video"] = final_response
|
290 |
+
response["text_prompt"] = app.config.get('text_prompt')
|
291 |
response["status"] = "completed"
|
292 |
return jsonify(response)
|
293 |
return jsonify({"error":"No process id provided"})
|