# ViDove/pipeline.py
import argparse
import logging
import os
import subprocess
import time
from datetime import datetime

import openai
import stable_whisper
import torch
import whisper
from pytube import YouTube
from tqdm import tqdm

from SRT import SRT_script
from srt2ass import srt2ass
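
# Note: SRT (providing SRT_script) and srt2ass appear to be repo-local helper
# modules rather than PyPI packages; SRT_script handles subtitle parsing and
# processing, and srt2ass converts .srt subtitles to .ass.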

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--link", help="YouTube video link", default=None, type=str, required=False)
    parser.add_argument("--video_file", help="local video path", default=None, type=str, required=False)
    parser.add_argument("--audio_file", help="local audio path", default=None, type=str, required=False)
    parser.add_argument("--srt_file", help="input .srt file path", default=None, type=str, required=False)
    parser.add_argument("--download", help="download path", default='./downloads', type=str, required=False)
    parser.add_argument("--output_dir", help="translation result path", default='./results', type=str, required=False)
    parser.add_argument("--video_name", help="video name; if a video link is used as input, the name is auto-filled from the YouTube video title", default='placeholder', type=str, required=False)
    parser.add_argument("--model_name", help="model name; only gpt-4 and gpt-3.5-turbo are supported", type=str, required=False, default="gpt-4")
    parser.add_argument("--log_dir", help="log path", default='./logs', type=str, required=False)
    parser.add_argument("-only_srt", help="limit script output to the .srt file only", action='store_true')
    parser.add_argument("-v", help="automatically encode the subtitles into the video", action='store_true')
    args = parser.parse_args()
    return args
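
# Example invocations (paths and link are illustrative, not from the repo):
#   python pipeline.py --link "https://www.youtube.com/watch?v=XXXXXXXXXXX" -v
#   python pipeline.py --video_file ./downloads/video/demo.mp4 --model_name gpt-3.5-turbo -only_srt
# A valid OPENAI_API_KEY environment variable is required (see main()).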

def get_sources(args, download_path, result_path, video_name):
    """Resolve the audio/video sources from a YouTube link or local files.

    Returns a tuple of (audio_path, audio_file, video_path, video_name).
    """
    audio_path = None
    audio_file = None
    video_path = None
    if args.link is not None and args.video_file is None:
        # Download video and audio streams from YouTube
        video_link = args.link
        video = None
        audio = None
        try:
            yt = YouTube(video_link)
            video = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
            if video:
                video.download(f'{download_path}/video')
                print('Video download completed!')
            else:
                print("Error: Video stream not found")
                exit()
            audio = yt.streams.filter(only_audio=True, file_extension='mp4').first()
            if audio:
                audio.download(f'{download_path}/audio')
                print('Audio download completed!')
            else:
                print("Error: Audio stream not found")
                exit()
        except Exception as e:
            print("Connection Error")
            print(e)
            exit()
        video_path = f'{download_path}/video/{video.default_filename}'
        audio_path = f'{download_path}/audio/{audio.default_filename}'
        audio_file = open(audio_path, "rb")
        if video_name == 'placeholder':
            video_name = audio.default_filename.split('.')[0]
    elif args.video_file is not None:
        # Read from a local video file
        video_path = args.video_file
        if args.audio_file is not None:
            audio_file = open(args.audio_file, "rb")
            audio_path = args.audio_file
        else:
            # Extract the audio track with ffmpeg
            output_audio_path = f'{download_path}/audio/{video_name}.mp3'
            subprocess.run(['ffmpeg', '-i', video_path, '-f', 'mp3', '-ab', '192000', '-vn', output_audio_path])
            audio_file = open(output_audio_path, "rb")
            audio_path = output_audio_path

    if not os.path.exists(f'{result_path}/{video_name}'):
        os.mkdir(f'{result_path}/{video_name}')

    if args.audio_file is not None:
        # Audio-only input: use the provided audio file directly
        audio_file = open(args.audio_file, "rb")
        audio_path = args.audio_file

    return audio_path, audio_file, video_path, video_name
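
# On-disk layout produced by the pipeline (directories are created in main()
# and get_sources(); <video_name> is the resolved video name):
#   <download_path>/audio/        downloaded or ffmpeg-extracted audio
#   <download_path>/video/        downloaded video
#   <result_path>/<video_name>/   generated .srt/.ass subtitles and encoded .mp4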

def get_srt_class(srt_file_en, result_path, video_name, audio_path, audio_file=None, whisper_model='large', method="stable"):
    # Instead of using the script_en variable directly, we'll use script_input
    if srt_file_en is not None:
        srt = SRT_script.parse_from_srt_file(srt_file_en)
    else:
        # use whisper to perform speech-to-text and save it as <video_name>_en.srt under the result path
        srt_file_en = "{}/{}/{}_en.srt".format(result_path, video_name, video_name)
        if not os.path.exists(srt_file_en):
            if method == "api":
                # use the OpenAI API for transcription
                transcript = openai.Audio.transcribe("whisper-1", audio_file)
            elif method == "basic":
                # use a local whisper model (a large model may be used on the server)
                model = whisper.load_model(whisper_model)
                transcript = model.transcribe(audio_path)
            elif method == "stable":
                # use stable-whisper, on CUDA if available
                devices = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
                model = stable_whisper.load_model(whisper_model, device=devices)
                transcript = model.transcribe(audio_path, regroup=False, initial_prompt="Hello, welcome to my lecture. Are you good my friend?")
                # regroup the segments: split on sentence punctuation, merge short
                # fragments separated by small gaps, then split again
                (
                    transcript
                    .split_by_punctuation(['.', '。', '?'])
                    .merge_by_gap(.15, max_words=3)
                    .merge_by_punctuation([' '])
                    .split_by_punctuation(['.', '。', '?'])
                )
                transcript = transcript.to_dict()
            else:
                raise ValueError("invalid speech-to-text method")
            srt = SRT_script(transcript['segments'])  # read segments into the SRT class
        else:
            srt = SRT_script.parse_from_srt_file(srt_file_en)
    return srt_file_en, srt
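
# Minimal standalone sketch of the "stable" transcription path above (the
# model size and audio path here are illustrative):
#   model = stable_whisper.load_model('large', device='cpu')
#   result = model.transcribe('audio.mp3', regroup=False)
#   result.split_by_punctuation(['.', '。', '?']).merge_by_gap(.15, max_words=3)
#   segments = result.to_dict()['segments']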

# Split the video script by sentences and create chunks within the token limit.
# NOTE: this appears to assume script_in ends with '\n\n' (as produced upstream),
# so the final element of the split is empty, hence the len(...) - 1 below.
def script_split(script_in, chunk_size=1000):
    sentences = script_in.split('\n\n')
    script_arr = []
    range_arr = []
    start = 1
    end = 0
    script = ""
    for sentence in sentences:
        if len(script) + len(sentence) + 1 <= chunk_size:
            script += sentence + '\n\n'
            end += 1
        else:
            # current chunk is full: record its sentence range and start a new one
            range_arr.append((start, end))
            start = end + 1
            end += 1
            script_arr.append(script.strip())
            script = sentence + '\n\n'
    if script.strip():
        script_arr.append(script.strip())
        range_arr.append((start, len(sentences) - 1))
    assert len(script_arr) == len(range_arr)
    return script_arr, range_arr
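
# Illustrative example (hypothetical input, chunk_size=12):
#   script_split("aaaa\n\nbbbb\n\ncccc\n\n", chunk_size=12)
#   -> (['aaaa\n\nbbbb', 'cccc'], [(1, 2), (3, 3)])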

# check whether a previous translation is already done (disabled resume support)
# zh_file = "{}/{}/{}_zh.srt".format(RESULT_PATH, VIDEO_NAME, VIDEO_NAME)
# segidx = 1
# if os.path.exists(zh_file):
#     temp_file = "{}/{}/temp.srt".format(RESULT_PATH, VIDEO_NAME)
#     if os.path.exists(temp_file):
#         os.remove(temp_file)
#     with open(zh_file, "r") as f0:
#         for count, _ in enumerate(f0):
#             pass
#     count += 1
#     segidx = int(count / 4) + 1
#     en_file = "{}/{}/{}_en.srt".format(RESULT_PATH, VIDEO_NAME, VIDEO_NAME)
#     if args.srt_file is not None:
#         en_file = args.srt_file
#     with open(en_file, "r") as f1, open(temp_file, "a") as f2:
#         x = f1.readlines()
#         if count >= len(x):
#             print('Work already done! Please delete the {}_zh.srt file in the result directory first in order to rework'.format(VIDEO_NAME))
#             exit()
#         for i, line in enumerate(x):
#             if i >= count:
#                 f2.write(line)
#     srt = SRT_script.parse_from_srt_file(temp_file)
#     print('temp_contents')
#     print(srt.get_source_only())

def check_translation(sentence, translation):
    """
    Check for sentence-merging issues in the OpenAI translation by comparing
    the number of blank-line-separated segments in the source and the translation.
    """
    sentence_count = sentence.count('\n\n') + 1
    translation_count = translation.count('\n\n') + 1
    return sentence_count == translation_count
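
# e.g. check_translation("a\n\nb", "甲\n\n乙") -> True
#      check_translation("a\n\nb", "甲乙")     -> False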

def get_response(model_name, sentence):
    """
    Generate a translated response for a given sentence using the specified OpenAI model.

    Args:
        model_name (str): The OpenAI model to use for translation, either "gpt-3.5-turbo" or "gpt-4".
        sentence (str): The English sentence, related to StarCraft 2 videos, to be translated into Chinese.

    Returns:
        str: The translated Chinese sentence, maintaining the original format, meaning, and number of lines.
    """
    if model_name == "gpt-3.5-turbo" or model_name == "gpt-4":
        response = openai.ChatCompletion.create(
            model=model_name,
            messages=[
                # Earlier English prompts, kept for reference:
                # {"role": "system", "content": "You are a helpful assistant that translates English to Chinese and has a decent background in StarCraft 2."},
                # {"role": "system", "content": "Your translation has to keep the original format and be as accurate as possible."},
                # {"role": "system", "content": "Your translation needs to be consistent with the number of sentences in the original."},
                # {"role": "system", "content": "There is no need for you to add any comments or notes."},
                # {"role": "user", "content": 'Translate the following English text to Chinese: "{}"'.format(sentence)},

                # System prompt (in Chinese): "You are a translation assistant. Your task is to
                # translate StarCraft videos. You will be given an English passage split by lines,
                # and you must output the translated text while preserving the meaning and the
                # number of lines."
                {"role": "system", "content": "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"},
                {"role": "user", "content": sentence}
            ],
            temperature=0.15
        )
        return response['choices'][0]['message']['content'].strip()
    raise ValueError("invalid model name; only gpt-4 and gpt-3.5-turbo are supported")
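
# Example call (requires a valid OPENAI_API_KEY; the output shown is purely
# illustrative, since model responses are nondeterministic):
#   get_response("gpt-4", "Hello everyone\n\nWelcome back")
#   -> "大家好\n\n欢迎回来"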

# Translate and save
def translate(srt, script_arr, range_arr, model_name, video_name, video_link):
    logging.info("start translating...")
    previous_length = 0
    for sentence, seg_range in tqdm(zip(script_arr, range_arr)):
        # update the range based on the previous length
        seg_range = (seg_range[0] + previous_length, seg_range[1] + previous_length)
        print(f"now translating sentences {seg_range}")
        logging.info(f"now translating sentences {seg_range}, time: {datetime.now()}")
        flag = True
        while flag:
            flag = False
            try:
                # query the ChatGPT model
                translation = get_response(model_name, sentence)
                # detect sentence-merging issues and retry up to five times
                attempt_left = 5
                while not check_translation(sentence, translation) and attempt_left > 0:
                    translation = get_response(model_name, sentence)
                    attempt_left -= 1
                # if the failure persists, fall back to translating sentence by sentence
                if attempt_left == 0:
                    print("merge sentence issue found for range", seg_range)
                    single_sentences = sentence.split("\n\n")
                    translation = ""
                    for i, single_sentence in enumerate(single_sentences):
                        if i == len(single_sentences) - 1:
                            translation += get_response(model_name, single_sentence)
                        else:
                            translation += get_response(model_name, single_sentence) + "\n\n"
                    print("solved by individual translation!")
            except Exception as e:
                logging.debug("An error has occurred during translation: %s", e)
                print("An error has occurred during translation:", e)
                print("Retrying... the script will continue after 30 seconds.")
                time.sleep(30)
                flag = True
        srt.set_translation(translation, seg_range, model_name, video_name, video_link)

def main():
    args = parse_args()

    # input check: the input must be a video file, a YouTube link, an audio file, or an srt file
    if args.link is None and args.video_file is None and args.srt_file is None and args.audio_file is None:
        print("need video source or srt file")
        exit()

    # set up
    start_time = time.time()
    openai.api_key = os.getenv("OPENAI_API_KEY")
    DOWNLOAD_PATH = args.download
    if not os.path.exists(DOWNLOAD_PATH):
        os.mkdir(DOWNLOAD_PATH)
        os.mkdir(f'{DOWNLOAD_PATH}/audio')
        os.mkdir(f'{DOWNLOAD_PATH}/video')
    RESULT_PATH = args.output_dir
    if not os.path.exists(RESULT_PATH):
        os.mkdir(RESULT_PATH)

    # set the video name from the input file name if not specified
    if args.video_name == 'placeholder':
        if args.video_file is not None:
            VIDEO_NAME = args.video_file.split('/')[-1].split('.')[0]
        elif args.audio_file is not None:
            VIDEO_NAME = args.audio_file.split('/')[-1].split('.')[0]
        elif args.srt_file is not None:
            VIDEO_NAME = args.srt_file.split('/')[-1].split('.')[0].split("_")[0]
        else:
            VIDEO_NAME = args.video_name
    else:
        VIDEO_NAME = args.video_name

    audio_path, audio_file, video_path, VIDEO_NAME = get_sources(args, DOWNLOAD_PATH, RESULT_PATH, VIDEO_NAME)

    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler("{}/{}_{}.log".format(args.log_dir, VIDEO_NAME, datetime.now().strftime("%m%d%Y_%H%M%S")), 'w', encoding='utf-8')])
    logging.info("---------------------Video Info---------------------")
    logging.info("Video name: {}, translation model: {}, video link: {}".format(VIDEO_NAME, args.model_name, args.link))

    srt_file_en, srt = get_srt_class(args.srt_file, RESULT_PATH, VIDEO_NAME, audio_path, audio_file)

    # SRT class preprocessing
    logging.info("---------------------Start Preprocessing SRT class---------------------")
    srt.write_srt_file_src(srt_file_en)
    srt.form_whole_sentence()
    # srt.spell_check_term()
    srt.correct_with_force_term()
    processed_srt_file_en = srt_file_en.split('.srt')[0] + '_processed.srt'
    srt.write_srt_file_src(processed_srt_file_en)
    script_input = srt.get_source_only()

    # write the English .ass subtitle
    if not args.only_srt:
        logging.info("write English .srt file to .ass")
        assSub_en = srt2ass(processed_srt_file_en, "default", "No", "Modest")
        logging.info('ASS subtitle saved as: ' + assSub_en)

    script_arr, range_arr = script_split(script_input)
    logging.info("---------------------Start Translation--------------------")
    translate(srt, script_arr, range_arr, args.model_name, VIDEO_NAME, args.link)

    # SRT post-processing
    logging.info("---------------------Start Post-processing SRT class---------------------")
    srt.check_len_and_split()
    srt.remove_trans_punctuation()
    srt.write_srt_file_translate(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt")
    srt.write_srt_file_bilingual(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_bi.srt")

    # write the Chinese .ass subtitle
    if not args.only_srt:
        logging.info("write Chinese .srt file to .ass")
        assSub_zh = srt2ass(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt", "default", "No", "Modest")
        logging.info('ASS subtitle saved as: ' + assSub_zh)

    # encode the subtitles into an .mp4 video file
    if args.v:
        logging.info("encoding video file")
        if args.only_srt:
            os.system(f'ffmpeg -i {video_path} -vf "subtitles={RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt" {RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}.mp4')
        else:
            os.system(f'ffmpeg -i {video_path} -vf "subtitles={RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.ass" {RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}.mp4')

    end_time = time.time()
    logging.info("Pipeline finished, time duration: {}".format(time.strftime("%H:%M:%S", time.gmtime(end_time - start_time))))


if __name__ == "__main__":
    main()