import typing |
import base64 |
import re |
import shutil |
import time |
import typing |
import asyncio |
import os |
import logging |
from kairyou import KatakanaUtil |
from easytl import EasyTL, Message, SystemTranslationMessage, ModelTranslationMessage |
import backoff |
from handlers.json_handler import JsonHandler |
from modules.common.file_ensurer import FileEnsurer |
from modules.common.toolkit import Toolkit |
from modules.common.exceptions import OpenAIAuthenticationError, MaxBatchDurationExceededException, DeepLAuthorizationException, OpenAIInternalServerError, OpenAIRateLimitError, OpenAIAPITimeoutError, GoogleAuthError, OpenAIAPIStatusError, OpenAIAPIConnectionError, DeepLException, GoogleAPIError |
from modules.common.decorators import permission_error_decorator |
from modules.common.gender_util import GenderUtil |
class Translator:

    """
    Translator is a class that is used to interact with translation methods and translate text.
    Currently supports OpenAI, Gemini, DeepL, and Google Translate.

    All state lives on the class itself and every method is a staticmethod, so the shared state has to be
    cleared with reset_static_variables() between runs in the same process (e.g. the webgui).
    """

    # Input lines, and the output buffers written out by write_translator_results().
    text_to_translate:typing.List[str] = []
    translated_text:typing.List[str] = []
    je_check_text:typing.List[str] = []
    error_text:typing.List[str] = []

    # Per-service batches built by build_translation_batches(). OpenAI and Gemini hold alternating
    # (instructions, prompt) entries; DeepL and Google Translate hold one prompt string per batch.
    openai_translation_batches:typing.List[SystemTranslationMessage | ModelTranslationMessage] = []
    gemini_translation_batches:typing.List[str] = []
    deepl_translation_batches:typing.List[str] = []
    google_translate_translation_batches:typing.List[str] = []

    num_occurred_malformed_batches = 0

    # Bounds the number of in-flight requests; replaced in commence_translation() with one sized by number_of_concurrent_batches.
    _semaphore = asyncio.Semaphore(5)

    TRANSLATION_METHOD:typing.Literal["openai", "gemini", "deepl", "google translate"] = "deepl"

    translation_print_result = ""

    # Values read from the translation settings file in commence_translation().
    prompt_assembly_mode:int
    number_of_lines_per_batch:int
    sentence_fragmenter_mode:int
    je_check_mode:int
    number_of_malformed_batch_retries:int
    batch_retry_timeout:float
    num_concurrent_batches:int
    gender_context_insertion:bool
    is_cote:bool

    # The attribute names commence_translation() actually assigns for the
    # number_of_malformed_batch_retries and batch_retry_timeout settings.
    num_of_malform_retries:int
    max_batch_duration:float

    decorator_to_use:typing.Callable

    is_cli = False
    pre_provided_api_key = ""

##-------------------start-of-get_max_batch_duration()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    def get_max_batch_duration() -> float:

        """
        Returns the max batch duration.
        Structured as a function so that it can be used as a lambda function in the backoff decorator. As decorators call the function when they are defined/runtime, not when they are called. Which I learned the hard way.

        Returns:
        max_batch_duration (float) : the max batch duration, as set by commence_translation().
        """

        return Translator.max_batch_duration

##-------------------start-of-log_retry()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    def log_retry(details) -> None:

        """
        Logs the retry message. Used as the backoff on_backoff handler.

        Parameters:
        details (dict) : the details of the retry, as supplied by backoff ('wait', 'tries', 'target', 'exception').
        """

        retry_msg = f"Retrying translation after {details['wait']} seconds after {details['tries']} tries {details['target']} due to {details['exception']}."

        logging.warning(retry_msg)

##-------------------start-of-log_failure()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    def log_failure(details) -> None:

        """
        Logs the translation batch failure message. Used as the backoff on_giveup handler.

        Parameters:
        details (dict) : the details of the failure, as supplied by backoff ('wait', 'tries', 'target').

        Raises:
        MaxBatchDurationExceededException : always; handle_translation() catches it to stop retrying the batch.
        """

        error_msg = f"Exceeded allowed duration of {details['wait']} seconds, returning untranslated text after {details['tries']} tries {details['target']}."

        logging.error(error_msg)

        raise MaxBatchDurationExceededException(error_msg)

##-------------------start-of-translate()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    async def translate() -> None:

        """
        Runs a full translation.

        Selects the method and API key, validates and (outside the CLI) confirms the settings, translates the loaded text,
        and builds the summary in translation_print_result. Any exception is handed to FileEnsurer.handle_critical_exception();
        the summary is assembled regardless.
        """

        time_start = time.time()

        try:

            await Translator.initialize()

            JsonHandler.validate_json()

            # Google Translate has no settings section to confirm.
            if(not Translator.is_cli and Translator.TRANSLATION_METHOD != "google translate"):
                await Translator.check_settings()

            # Restart the timer so time spent answering prompts is not counted.
            time_start = time.time()

            await Translator.commence_translation()

        except Exception as e:
            Translator.translation_print_result += "An error has occurred, outputting results so far..."
            FileEnsurer.handle_critical_exception(e)

        finally:
            time_end = time.time()

            Translator.assemble_results(time_start, time_end)

            if(Translator.is_cli):
                Toolkit.pause_console()

##-------------------start-of-initialize()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    async def initialize() -> None:

        """
        Sets the API Key for the respective service and loads the translation settings.

        Outside the CLI the user picks the method interactively; in the CLI the method comes from TRANSLATION_METHOD,
        which may be either a method name ("openai", "gemini", "deepl", "google translate") or its menu number ("1"-"4").
        Unrecognized values fall back to DeepL.
        """

        translation_methods = {
            "1": ("openai", FileEnsurer.openai_api_key_path),
            "2": ("gemini", FileEnsurer.gemini_api_key_path),
            "3": ("deepl", FileEnsurer.deepl_api_key_path),
            "4": ("google translate", FileEnsurer.google_translate_service_key_json_path)
        }

        # Lets TRANSLATION_METHOD be given by name as well as by number. Previously a name silently fell back to DeepL.
        name_to_number = {name: number for number, (name, _) in translation_methods.items()}

        if(not Translator.is_cli):
            method = input("What method would you like to use for translation? (1 for OpenAI, 2 for Gemini, 3 for Deepl, 4 for Google Translate, or any other key to exit) : \n")

            if(method not in translation_methods):
                print("\nThank you for using Kudasai, goodbye.")
                time.sleep(2)
                FileEnsurer.exit_kudasai()

            Toolkit.clear_console()

        else:
            method = name_to_number.get(str(Translator.TRANSLATION_METHOD).lower(), Translator.TRANSLATION_METHOD)

        Translator.TRANSLATION_METHOD, api_key_path = translation_methods.get(method, ("deepl", FileEnsurer.deepl_api_key_path))

        if(Translator.pre_provided_api_key != ""):

            # Write the key in the format init_api_key() reads back: the Google Translate service json is stored
            # as-is (its path is what gets handed to EasyTL), every other key is stored base64-encoded.
            if(Translator.TRANSLATION_METHOD == "google translate"):
                key_to_write = Translator.pre_provided_api_key
            else:
                key_to_write = base64.b64encode(Translator.pre_provided_api_key.encode('utf-8')).decode('utf-8')

            Translator.pre_provided_api_key = ""

            with open(api_key_path, 'w+', encoding='utf-8') as file:
                file.write(key_to_write)

        await Translator.init_api_key(Translator.TRANSLATION_METHOD.capitalize(), api_key_path, EasyTL.set_credentials, EasyTL.test_credentials)

        # If the settings file cannot be loaded, reset it to default and load again.
        try:
            JsonHandler.load_translation_settings()

        except Exception:
            JsonHandler.reset_translation_settings_to_default()
            JsonHandler.load_translation_settings()

        Toolkit.clear_console()

##-------------------start-of-init_api_key()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    async def init_api_key(service:str, api_key_path:str, api_key_setter:typing.Callable, api_key_tester:typing.Callable) -> None:

        """
        Sets up the api key for the respective service.

        First tries the key saved at api_key_path. If that fails for any reason, prompts the user for a new key,
        tests it, and saves it. An authorization error exits the program; any other error is logged and re-raised.

        Non-Google keys are stored base64-encoded. The Google Translate service json is stored as-is, and its file path
        (not its contents) is what gets passed to api_key_setter.

        Parameters:
        service (string) : the name of the service ("Google translate" selects the service-json handling).
        api_key_path (string) : the path to the api key.
        api_key_setter (callable) : the function that sets the api key, called as (service.lower(), key).
        api_key_tester (callable) : the function that tests the api key, returning (is_valid, exception_or_None).
        """

        is_google_translate = (service == "Google translate")

        def get_api_key_from_file():
            with open(api_key_path, 'r', encoding='utf-8') as file:
                return base64.b64decode((file.read()).encode('utf-8')).decode('utf-8')

        def save_api_key(api_key):
            if(not is_google_translate):
                encoded_key = base64.b64encode(api_key.encode('utf-8')).decode('utf-8')
                FileEnsurer.standard_overwrite_file(api_key_path, encoded_key, omit=True)
            else:
                FileEnsurer.standard_overwrite_file(api_key_path, api_key, omit=True)

        try:

            if(not is_google_translate):
                api_key = get_api_key_from_file()
            else:
                api_key = api_key_path

            api_key_setter(service.lower(), api_key)

            is_valid, e = api_key_tester(service.lower())

            if(not is_valid and e is not None):
                raise e

            logging.info(f"Used saved API key in {api_key_path}")
            time.sleep(2)

        # Any failure with the saved key (missing file, bad encoding, rejected key) falls through to prompting.
        except Exception:

            Toolkit.clear_console()

            input_message = (
                f"DO NOT DELETE YOUR COPY OF THE API KEY\n\nPlease enter the {service} API key you have : "
                if(not is_google_translate)
                else "DO NOT DELETE YOUR COPY OF THE SERVICE JSON\n\nPlease enter the contents of the service json file (on one line): "
            )

            api_key = input(input_message).strip('"').strip("'").strip()

            if(is_google_translate):
                # The setter receives the file path, so the json must be on disk before it can be tested.
                save_api_key(api_key)
                time.sleep(1)
                api_key = api_key_path

            try:
                api_key_setter(service.lower(), api_key)

                is_valid, e = api_key_tester(service.lower())

                if(not is_valid and e is not None):
                    raise e

                # The Google Translate json was already saved above. Saving again here would overwrite the
                # service json file with its own path.
                if(not is_google_translate):
                    save_api_key(api_key)

            except (GoogleAuthError, OpenAIAuthenticationError, DeepLAuthorizationException):
                Toolkit.clear_console()

                logging.error(f"Authorization error while setting up {service}, please double check your API key as it appears to be incorrect.")

                Toolkit.pause_console()

                exit(1)

            except Exception as e:
                Toolkit.clear_console()

                logging.error(f"Unknown error while setting up {service}, The error is as follows " + str(e) + "\nThe exception will now be raised.")

                Toolkit.pause_console()

                raise e

##-------------------start-of-reset_static_variables()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    def reset_static_variables() -> None:

        """
        Resets the static variables.
        Done to prevent issues with the webgui.
        """

        Translator.text_to_translate = []
        Translator.translated_text = []
        Translator.je_check_text = []
        Translator.error_text = []
        Translator.openai_translation_batches = []
        Translator.gemini_translation_batches = []
        Translator.deepl_translation_batches = []
        Translator.google_translate_translation_batches = []
        Translator.num_occurred_malformed_batches = 0
        Translator.translation_print_result = ""
        Translator.TRANSLATION_METHOD = "deepl"
        Translator.pre_provided_api_key = ""
        Translator.is_cli = False

##-------------------start-of-check_settings()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    async def check_settings() -> None:

        """
        Prompts the user to confirm the settings in the translation settings file.

        Shows the settings relevant to the current method and lets the user confirm or change them. Then it offers to
        replace the saved API key. If the settings cannot be displayed (likely an outdated file), the user may reset
        them to default or exit.
        """

        print("Are these settings okay? (1 for yes or 2 for no):")

        # method -> (settings section to display, service name passed to init_api_key(), api key path).
        # The Google Translate service name must be "Google translate" (what initialize() passes) because
        # init_api_key() compares against it and calls .lower() on it.
        method_to_section_dict = {
            "openai": ("openai settings", "OpenAI", FileEnsurer.openai_api_key_path),
            "gemini": ("gemini settings", "Gemini", FileEnsurer.gemini_api_key_path),
            "deepl": ("deepl settings", "DeepL", FileEnsurer.deepl_api_key_path),
            "google translate": (None, "Google translate", FileEnsurer.google_translate_service_key_json_path)
        }

        section_to_target, method_name, api_key_path = method_to_section_dict[Translator.TRANSLATION_METHOD]

        try:
            JsonHandler.log_translation_settings(output_to_console=True, specific_section=section_to_target)

        except Exception:
            Toolkit.clear_console()

            if(input("It's likely that you're using an outdated version of the translation settings file, press 1 to reset these to default or 2 to exit and resolve manually : ") == "1"):
                Toolkit.clear_console()
                JsonHandler.reset_translation_settings_to_default()
                JsonHandler.load_translation_settings()

                print("Are these settings okay? (1 for yes or 2 for no) : \n")
                JsonHandler.log_translation_settings(output_to_console=True, specific_section=section_to_target)

            else:
                FileEnsurer.exit_kudasai()

        if(input("\n") != "1"):
            JsonHandler.change_translation_settings()

        Toolkit.clear_console()

        print("Do you want to change your API key? (1 for yes or 2 for no) : ")

        if(input("\n") == "1"):

            # With the saved key removed, reading it back in init_api_key() fails and the user is prompted for a new one.
            if(os.path.exists(api_key_path)):
                os.remove(api_key_path)

            await Translator.init_api_key(method_name, api_key_path, EasyTL.set_credentials, EasyTL.test_credentials)

        Toolkit.clear_console()

##-------------------start-of-commence_translation()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    async def commence_translation(is_webgui:bool=False) -> None:

        """
        Uses all the other functions to translate the text provided by Kudasai.

        Loads the settings into class attributes, builds the retry decorator and the batches, shows the cost estimate,
        runs all batches concurrently (bounded by the semaphore), and redistributes the results in batch order.

        Parameters:
        is_webgui (bool | optional | default=False) : A bool representing whether the function is being called by the webgui.
        """

        if(os.path.exists(FileEnsurer.external_translation_genders_path) and not is_webgui):
            logging.info("External genders.json file found, overriding config...")
            shutil.copy2(FileEnsurer.external_translation_genders_path, FileEnsurer.config_translation_genders_path)

        if(not os.path.exists(FileEnsurer.external_translation_settings_path) and not is_webgui):
            logging.info("External translation_settings.json file not found, using config...")
            print("External translation_settings.json file not found, using config...")
            time.sleep(2)

        logging.debug(f"Translator Activated, Translation Method : {Translator.TRANSLATION_METHOD} "
                        f"Settings are as follows : ")

        JsonHandler.log_translation_settings()

        base_settings = JsonHandler.current_translation_settings["base translation settings"]

        Translator.prompt_assembly_mode = int(base_settings["prompt_assembly_mode"])
        Translator.number_of_lines_per_batch = int(base_settings["number_of_lines_per_batch"])
        Translator.sentence_fragmenter_mode = int(base_settings["sentence_fragmenter_mode"])
        Translator.je_check_mode = int(base_settings["je_check_mode"])
        Translator.num_of_malform_retries = int(base_settings["number_of_malformed_batch_retries"])
        Translator.max_batch_duration = float(base_settings["batch_retry_timeout"])
        Translator.num_concurrent_batches = int(base_settings["number_of_concurrent_batches"])
        Translator.gender_context_insertion = bool(base_settings["gender_context_insertion"])
        Translator.is_cote = bool(base_settings["is_cote"])

        GenderUtil.is_cote = Translator.is_cote

        Translator._semaphore = asyncio.Semaphore(Translator.num_concurrent_batches)

        openai_settings = JsonHandler.current_translation_settings["openai settings"]

        Translator.openai_model = openai_settings["openai_model"]
        Translator.openai_system_message = openai_settings["openai_system_message"]
        Translator.openai_temperature = float(openai_settings["openai_temperature"])
        Translator.openai_top_p = float(openai_settings["openai_top_p"])
        Translator.openai_n = int(openai_settings["openai_n"])
        Translator.openai_stream = bool(openai_settings["openai_stream"])
        Translator.openai_stop = openai_settings["openai_stop"]
        Translator.openai_logit_bias = openai_settings["openai_logit_bias"]
        Translator.openai_max_tokens = openai_settings["openai_max_tokens"]
        Translator.openai_presence_penalty = float(openai_settings["openai_presence_penalty"])
        Translator.openai_frequency_penalty = float(openai_settings["openai_frequency_penalty"])

        gemini_settings = JsonHandler.current_translation_settings["gemini settings"]

        Translator.gemini_model = gemini_settings["gemini_model"]
        Translator.gemini_prompt = gemini_settings["gemini_prompt"]
        Translator.gemini_temperature = float(gemini_settings["gemini_temperature"])
        Translator.gemini_top_p = gemini_settings["gemini_top_p"]
        Translator.gemini_top_k = gemini_settings["gemini_top_k"]
        Translator.gemini_candidate_count = gemini_settings["gemini_candidate_count"]
        Translator.gemini_stream = bool(gemini_settings["gemini_stream"])
        Translator.gemini_stop_sequences = gemini_settings["gemini_stop_sequences"]
        Translator.gemini_max_output_tokens = gemini_settings["gemini_max_output_tokens"]

        deepl_settings = JsonHandler.current_translation_settings["deepl settings"]

        Translator.deepl_context = deepl_settings["deepl_context"]
        Translator.deepl_split_sentences = deepl_settings["deepl_split_sentences"]
        Translator.deepl_preserve_formatting = deepl_settings["deepl_preserve_formatting"]
        Translator.deepl_formality = deepl_settings["deepl_formality"]

        # Exceptions that trigger a backoff retry, per service.
        exception_dict = {
            "openai": (OpenAIAuthenticationError, OpenAIInternalServerError, OpenAIRateLimitError, OpenAIAPITimeoutError, OpenAIAPIConnectionError, OpenAIAPIStatusError),
            "gemini": GoogleAPIError,
            "deepl": DeepLException,
            "google translate": GoogleAPIError
        }

        # max_time is a callable so backoff reads the current setting at call time rather than at decoration time.
        # log_failure raises MaxBatchDurationExceededException, so giving up surfaces as that exception.
        Translator.decorator_to_use = backoff.on_exception(
            backoff.expo,
            max_time=lambda: Translator.get_max_batch_duration(),
            exception=exception_dict.get(Translator.TRANSLATION_METHOD, None),
            on_backoff=lambda details: Translator.log_retry(details),
            on_giveup=lambda details: Translator.log_failure(details),
            raise_on_giveup=False
        )

        Toolkit.clear_console()

        logging.info("Starting Prompt Building...")
        Translator.build_translation_batches()

        translation_methods = {
            "openai": Translator.openai_model,
            "gemini": Translator.gemini_model,
            "deepl": "deepl",
            "google translate": "google translate"
        }

        model = translation_methods[Translator.TRANSLATION_METHOD]

        await Translator.handle_cost_estimate_prompt(model, omit_prompt=is_webgui or Translator.is_cli)

        Toolkit.clear_console()

        logging.info("Starting Translation...")

        async_requests = Translator.build_async_requests(model)

        results = await asyncio.gather(*async_requests)

        logging.info("Redistributing Translated Text...")

        # Each result is (batch_number, prompt, translation); redistribute in batch order.
        sorted_results = sorted(results, key=lambda x: x[0])

        for _, translated_prompt, translated_message in sorted_results:
            Translator.redistribute(translated_prompt, translated_message)

        if(Translator.je_check_mode == 2):
            Translator.je_check_text = Translator.fix_je()

        Toolkit.clear_console()

        logging.info("Done!")

##-------------------start-of-build_async_requests()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    def build_async_requests(model:str) -> list[typing.Coroutine]:

        """
        Builds the asynchronous requests.

        For OpenAI and Gemini, batches are consumed as (instructions, prompt) pairs. When gender_context_insertion is
        enabled, the instructions are rebuilt as a SystemTranslationMessage with the character-gender assumptions appended.
        For DeepL and Google Translate, each batch is a single prompt string.

        Parameters:
        model (string) : the model used to translate the text.

        Returns:
        async_requests (list - coroutine) : A list of coroutines that represent the asynchronous requests.
        """

        async_requests = []

        logging_message = "Built Messages: \n\n"

        translation_batches_methods = {
            method_name: getattr(Translator, f"{method_name}_translation_batches" if method_name != "google translate" else "google_translate_translation_batches")
            for method_name in ["openai", "gemini", "deepl", "google translate"]
        }

        translation_batches = translation_batches_methods[Translator.TRANSLATION_METHOD]

        batch_length = len(translation_batches)

        if(Translator.TRANSLATION_METHOD not in ["deepl", "google translate"]):

            for batch_number, (instructions, prompt) in enumerate(zip(translation_batches[::2], translation_batches[1::2]), start=1):

                assert isinstance(instructions, (SystemTranslationMessage, str))
                assert isinstance(prompt, (ModelTranslationMessage, str))

                if(Translator.gender_context_insertion):
                    assumption = list(set(GenderUtil.get_gender_assumption_for_system_prompt(prompt if isinstance(prompt, str) else prompt.content)))
                    assumption_string = "Additional Information:\nCharacter Genders:\n" + "".join(assumption) if len(assumption) > 0 else ""
                    instructions = SystemTranslationMessage(content=f"{instructions.content if isinstance(instructions, Message) else instructions}\n{assumption_string}")

                logging_message += f"\n------------------------\n{instructions.content if isinstance(instructions, Message) else instructions}\n{prompt if isinstance(prompt, str) else prompt.content}"

                # Two list entries per batch, so the number of batches is half the list length.
                async_requests.append(Translator.handle_translation(model, batch_number, batch_length//2, prompt, instructions))

        else:

            for batch_number, batch in enumerate(translation_batches, start=1):

                assert isinstance(batch, str)

                logging_message += f"\n------------------------\n{batch}"

                async_requests.append(Translator.handle_translation(model, batch_number, batch_length, batch, None))

        logging.debug(logging_message)

        return async_requests

##-------------------start-of-generate_text_to_translate_batches()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    def generate_text_to_translate_batches(index:int) -> tuple[typing.List[str],int]:

        """
        Generates prompts for the messages meant for the API.

        Collects up to number_of_lines_per_batch lines starting at index. POV/part markers are always kept,
        punctuation-only or spacing-only lines are skipped, and empty lines are skipped.

        Parameters:
        index (int) : An int representing where we currently are in the text file.

        Returns:
        prompt (list - string) : A list of Japanese lines that will be assembled into messages.
        index (int) : An updated int representing where we currently are in the text file.
        """

        prompt = []
        non_word_pattern = re.compile(r'^[\W_\s\n-]+$')
        special_chars = ["▼", "△", "◇"]
        quotes = ["「", "」", "『", "』", "【", "】", "\"", "'"]
        part_chars = ["1","2","3","4","5","6","7","8","9", " "]

        while(index < len(Translator.text_to_translate)):

            sentence = Translator.text_to_translate[index].strip()
            lowercase_sentence = sentence.lower()

            has_quotes = any(char in sentence for char in quotes)
            is_part_in_sentence = "part" in lowercase_sentence
            is_special_char = any(char in sentence for char in special_chars)
            # NOTE(review): this requires every one of part_chars to be present, which is rarely true; confirm `all` is intended.
            is_part_char = all(char in sentence for char in part_chars)

            if(len(prompt) < Translator.number_of_lines_per_batch):

                if(is_special_char or is_part_in_sentence or is_part_char):
                    prompt.append(f'{sentence}\n')
                    logging.debug(f"Sentence : {sentence} Sentence is a pov change or part marker... adding to prompt.")

                elif(non_word_pattern.match(sentence) or KatakanaUtil.is_punctuation(sentence) and not has_quotes):
                    logging.debug(f"Sentence : {sentence} Sentence is punctuation or spacing... skipping.")

                elif(sentence):
                    prompt.append(f'{sentence}\n')
                    logging.debug(f"Sentence : {sentence} Sentence is a valid sentence... adding to prompt.")

            else:
                # Batch is full; the caller resumes from this (unconsumed) index.
                return prompt, index

            index += 1

        return prompt, index

##-------------------start-of-build_translation_batches()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    def build_translation_batches() -> None:

        """
        Builds translations batches dict for the specified service.

        OpenAI batches are appended as a SystemTranslationMessage followed by a ModelTranslationMessage; Gemini batches
        as the Gemini prompt followed by the batch text; DeepL and Google Translate batches as the batch text alone.
        Batches whose lines were all skipped are not added.

        Raises:
        ValueError : if number_of_lines_per_batch is not positive (no progress could ever be made).
        """

        i = 0

        while(i < len(Translator.text_to_translate)):

            batch_lines, next_index = Translator.generate_text_to_translate_batches(i)

            # Only happens when number_of_lines_per_batch <= 0; previously this looped forever.
            if(next_index == i):
                raise ValueError("number_of_lines_per_batch must be at least 1.")

            i = next_index

            batch = ''.join(batch_lines)

            # Every line in this span was skipped (punctuation/spacing/blank); don't send an empty request.
            if(not batch):
                continue

            if(Translator.TRANSLATION_METHOD == 'openai'):
                # Both prompt_assembly_mode values currently use the same system message.
                system_msg = SystemTranslationMessage(content=str(Translator.openai_system_message))
                Translator.openai_translation_batches.append(system_msg)

                model_msg = ModelTranslationMessage(content=batch)
                Translator.openai_translation_batches.append(model_msg)

            elif(Translator.TRANSLATION_METHOD == 'gemini'):
                Translator.gemini_translation_batches.append(Translator.gemini_prompt)
                Translator.gemini_translation_batches.append(batch)

            elif(Translator.TRANSLATION_METHOD == 'deepl'):
                Translator.deepl_translation_batches.append(batch)

            elif(Translator.TRANSLATION_METHOD == 'google translate'):
                Translator.google_translate_translation_batches.append(batch)

##-------------------start-of-handle_cost_estimate_prompt()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    async def handle_cost_estimate_prompt(model:str, omit_prompt:bool=False) -> str:

        """
        Handles the cost estimate prompt.

        Logs the estimated entity count and minimum cost from EasyTL.calculate_cost(). Unless omit_prompt is set, it asks
        the user to continue and exits Kudasai if they decline.

        Parameters:
        model (string) : the model used to translate the text.
        omit_prompt (bool) : whether or not to omit the prompt.

        Returns:
        model (string) : the model used to translate the text, as returned by EasyTL.calculate_cost().
        """

        translation_instructions_methods = {
            "openai": Translator.openai_system_message,
            "gemini": Translator.gemini_prompt,
            "deepl": None,
            "google translate": None
        }

        translation_instructions = translation_instructions_methods[Translator.TRANSLATION_METHOD]

        num_entities, min_cost, model = EasyTL.calculate_cost(text=Translator.text_to_translate, service=Translator.TRANSLATION_METHOD, model=model,translation_instructions=translation_instructions)

        print("Note that the cost estimate is not always accurate, and may be higher than the actual cost. However cost calculation now includes output tokens.\n")

        if(Translator.TRANSLATION_METHOD == "gemini"):
            logging.info(f"As of Kudasai {Toolkit.CURRENT_VERSION}, Gemini Pro 1.0 is free to use under 15 requests per minute, Gemini Pro 1.5 is free to use under 2 requests per minute. Requests correspond to number_of_concurrent_batches in the translation settings.")

        entity_word = "tokens" if Translator.TRANSLATION_METHOD in ["openai", "gemini"] else "characters"

        logging.info(f"Estimated number of {entity_word} : " + str(num_entities))
        logging.info("Estimated minimum cost : " + str(min_cost) + " USD")

        if(not omit_prompt):
            if(input("\nContinue? (1 for yes or 2 for no) : ") == "1"):
                logging.info("User confirmed translation.")

            else:
                logging.info("User cancelled translation.")
                FileEnsurer.exit_kudasai()

        return model

##-------------------start-of-handle_translation()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    async def handle_translation(model:str,
                                batch_number:int,
                                length_of_batch:int,
                                text_to_translate:typing.Union[str, ModelTranslationMessage],
                                translation_instructions:typing.Union[str, SystemTranslationMessage, None]) -> tuple[int, str, str]:

        """
        Handles the translation requests for the specified service.

        Holds the shared semaphore while translating. A translation whose line count differs from the prompt is retried
        up to num_of_malform_retries times. This check is skipped for non-gpt-4 OpenAI models. If the retry decorator
        gives up, the batch stops retrying; when no translation was produced at all, the untranslated text is returned
        in its place.

        Parameters:
        model (string) : The model of the service used to translate the text.
        batch_number (int) : Which batch we are currently on.
        length_of_batch (int) : How long the batches are.
        text_to_translate (typing.Union[str, ModelTranslationMessage]) : The text to translate.
        translation_instructions (typing.Union[str, SystemTranslationMessage, None]) : The translation instructions.

        Returns:
        batch_number (int) : The batch index.
        text_to_translate (str) : The text to translate.
        translated_text (str) : The translated text

        Raises:
        Exception : if FileEnsurer.do_interrupt is set.
        """

        # Plain-text form of the prompt; this is what gets returned alongside the translation.
        prompt_text = text_to_translate.content if isinstance(text_to_translate, ModelTranslationMessage) else text_to_translate

        # build_async_requests() may have wrapped the Gemini prompt in a SystemTranslationMessage when inserting gender
        # context; Gemini is given the instruction text itself.
        plain_instructions = translation_instructions.content if isinstance(translation_instructions, Message) else translation_instructions

        async with Translator._semaphore:

            translation_methods = {
                "openai": EasyTL.openai_translate_async,
                "gemini": EasyTL.gemini_translate_async,
                "deepl": EasyTL.deepl_translate_async,
                "google translate": EasyTL.googletl_translate_async
            }

            translation_params = {
                "openai": {
                    "text": text_to_translate,
                    "decorator": Translator.decorator_to_use,
                    "translation_instructions": translation_instructions,
                    "model": model,
                    "temperature": Translator.openai_temperature,
                    "top_p": Translator.openai_top_p,
                    "stop": Translator.openai_stop,
                    "max_tokens": Translator.openai_max_tokens,
                    "presence_penalty": Translator.openai_presence_penalty,
                    "frequency_penalty": Translator.openai_frequency_penalty
                },
                "gemini": {
                    "text": text_to_translate,
                    "decorator": Translator.decorator_to_use,
                    # Previously omitted, so the configured gemini_prompt (and gender context) was never sent.
                    "translation_instructions": plain_instructions,
                    "model": model,
                    "temperature": Translator.gemini_temperature,
                    "top_p": Translator.gemini_top_p,
                    "top_k": Translator.gemini_top_k,
                    "stop_sequences": Translator.gemini_stop_sequences,
                    "max_output_tokens": Translator.gemini_max_output_tokens
                },
                "deepl": {
                    "text": text_to_translate,
                    "decorator": Translator.decorator_to_use,
                    "context": Translator.deepl_context,
                    "split_sentences": Translator.deepl_split_sentences,
                    "preserve_formatting": Translator.deepl_preserve_formatting,
                    "formality": Translator.deepl_formality
                },
                "google translate": {
                    "text": text_to_translate,
                    "decorator": Translator.decorator_to_use
                }
            }

            translate_function = translation_methods[Translator.TRANSLATION_METHOD]
            params = translation_params[Translator.TRANSLATION_METHOD]

            translated_message:typing.Union[typing.List[str], str, None] = None
            num_tries = 0

            while True:

                if(FileEnsurer.do_interrupt == True):
                    raise Exception("Interrupted by user.")

                logging.info(f"Trying translation for batch {batch_number} of {length_of_batch}...")

                try:
                    assert isinstance(text_to_translate, ModelTranslationMessage if Translator.TRANSLATION_METHOD == "openai" else str)

                    translated_message = await translate_function(**params)

                except MaxBatchDurationExceededException:
                    logging.error(f"Batch {batch_number} of {length_of_batch} was not translated due to exceeding the max request duration, returning the untranslated text...")

                    # Without this, a timeout on the first attempt left translated_message unbound.
                    if(translated_message is None):
                        translated_message = prompt_text

                    break

                # Only gpt-4 OpenAI output is line-count checked.
                if("gpt-4" not in model and Translator.TRANSLATION_METHOD == "openai"):
                    break

                if(await Translator.check_if_translation_is_good(translated_message, text_to_translate)):
                    break

                if(num_tries >= Translator.num_of_malform_retries):
                    logging.warning(f"Batch {batch_number} of {length_of_batch} was malformed but exceeded the max number of retries ({Translator.num_of_malform_retries})")
                    break

                else:
                    num_tries += 1
                    logging.warning(f"Batch {batch_number} of {length_of_batch} was malformed, retrying...")
                    Translator.num_occurred_malformed_batches += 1

            if(isinstance(translated_message, typing.List)):
                translated_message = ''.join(translated_message)

            logging.info(f"Translation for batch {batch_number} of {length_of_batch} completed.")

            return batch_number, prompt_text, translated_message

##-------------------start-of-check_if_translation_is_good()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    async def check_if_translation_is_good(translated_message:typing.Union[typing.List[str], str], text_to_translate:typing.Union[ModelTranslationMessage, str]) -> bool:

        """
        Checks if the translation is good, i.e. the number of lines in the prompt and the number of lines in the translated message are the same.
        Blank (whitespace-only) lines are ignored on both sides.

        Parameters:
        translated_message (str | list - str) : the translated message; a list is joined first.
        text_to_translate (typing.Union[str, Message]) : the translation prompt.

        Returns:
        is_valid (bool) : whether or not the translation is valid.
        """

        if(not isinstance(text_to_translate, str)):
            prompt = text_to_translate.content
        else:
            prompt = text_to_translate

        if(isinstance(translated_message, list)):
            translated_message = ''.join(translated_message)

        jap = [line for line in prompt.split('\n') if line.strip()]
        eng = [line for line in translated_message.split('\n') if line.strip()]

        return len(jap) == len(eng)

##-------------------start-of-redistribute()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    def redistribute(text_to_translate:typing.Union[Message, str], translated_message:str) -> None:

        """
        Puts translated text back into the text file.

        Appends the prompt/translation pair to je_check_text according to je_check_mode. The translation is then added
        to translated_text, either split into one sentence per line (sentence_fragmenter_mode 1, with multi-sentence
        double-quoted passages kept on one line) or as-is (mode 2).

        Parameters:
        text_to_translate (typing.Union[str, Message]) : the translation prompt.
        translated_message (str) : the translated message.
        """

        if(not isinstance(text_to_translate, str)):
            prompt = text_to_translate.content
        else:
            prompt = text_to_translate

        if(Translator.je_check_mode == 1):
            Translator.je_check_text.append("\n-------------------------\n"+ prompt + "\n\n")
            Translator.je_check_text.append(translated_message + '\n')

        elif(Translator.je_check_mode == 2):
            # Stored as raw pairs; fix_je() consumes them two at a time.
            Translator.je_check_text.append(prompt)
            Translator.je_check_text.append(translated_message)

        if(Translator.sentence_fragmenter_mode == 1):

            # NOTE(review): the pattern only matches text ending in one of the listed terminators followed by whitespace
            # or end of string; text with no such terminator (e.g. a final line without punctuation) is not matched.
            sentences = re.findall(r"(.*?(?:(?:\"|\'|-|~|!|\?|%|\(|\)|\.\.\.|\.|---|\[|\])))(?:\s|$)", translated_message)

            build_string = None

            for sentence in sentences:
                sentence:str = sentence

                if(sentence.startswith("\"") and not sentence.endswith("\"") and build_string is None):
                    # Opening part of a quote spanning several sentences; start collecting it.
                    build_string = sentence
                    continue

                elif(not sentence.startswith("\"") and sentence.endswith("\"") and build_string is not None):
                    # Closing part; emit the whole quote as one line. (Previously it was collected but never emitted.)
                    build_string += f" {sentence}"
                    Translator.translated_text.append(build_string + '\n')
                    build_string = None
                    continue

                elif(build_string is not None):
                    build_string += f" {sentence}"
                    continue

                Translator.translated_text.append(sentence + '\n')

            # A quote that never closes would otherwise be silently dropped.
            if(build_string is not None):
                Translator.translated_text.append(build_string + '\n')

        elif(Translator.sentence_fragmenter_mode == 2):
            Translator.translated_text.append(translated_message)

##-------------------start-of-fix_je()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    def fix_je() -> typing.List[str]:

        """
        Fixes the J->E text to be more j-e checker friendly.
        Note that fix_je() is not always accurate, and may use standard j-e formatting instead of the corrected formatting.

        je_check_text is read as consecutive (prompt, translation) pairs. When both have the same number of non-blank lines,
        they are interleaved line by line; otherwise the two blocks are emitted whole.

        Returns:
        final_list (list - str) : the 'fixed' J->E text.
        """

        i = 1
        final_list = []

        while(i < len(Translator.je_check_text)):

            jap = Translator.je_check_text[i-1].split('\n')
            eng = Translator.je_check_text[i].split('\n')

            jap = [line for line in jap if(line.strip())]
            eng = [line for line in eng if(line.strip())]

            final_list.append("-------------------------\n")

            if(len(jap) == len(eng)):

                for(jap_line, eng_line) in zip(jap, eng):
                    if(jap_line and eng_line):
                        final_list.append(jap_line + '\n\n')
                        final_list.append(eng_line + '\n\n')

                        final_list.append("--------------------------------------------------\n")

            else:

                final_list.append(Translator.je_check_text[i-1] + '\n\n')
                final_list.append(Translator.je_check_text[i] + '\n\n')

                final_list.append("--------------------------------------------------\n")

            i += 2

        return final_list

##-------------------start-of-assemble_results()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    def assemble_results(time_start:float, time_end:float) -> None:

        """
        Generates the Translator translation print result, does not directly output/return, but rather sets Translator.translation_print_result to the output.

        Parameters:
        time_start (float) : When the translation started.
        time_end (float) : When the translation finished.
        """

        result = (
            f"Time Elapsed : {Toolkit.get_elapsed_time(time_start, time_end)}\n"
            f"Number of malformed batches : {Translator.num_occurred_malformed_batches}\n\n"
            f"Debug text have been written to : {FileEnsurer.debug_log_path}\n"
            f"J->E text have been written to : {FileEnsurer.je_check_path}\n"
            f"Translated text has been written to : {FileEnsurer.translated_text_path}\n"
            f"Errors have been written to : {FileEnsurer.error_log_path}\n"
        )

        Translator.translation_print_result = result

##-------------------start-of-write_translator_results()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    @permission_error_decorator()
    def write_translator_results() -> None:

        """
        This function is called to write the results of the Translator module to the output directory.

        Appends error_text to the error log, overwrites the J->E and translated-text files, and then archives the
        translated text, J->E text, error text, and current debug log under a timestamp.
        """

        FileEnsurer.standard_create_directory(FileEnsurer.output_dir)

        with open(FileEnsurer.error_log_path, 'a+', encoding='utf-8') as file:
            file.writelines(Translator.error_text)

        with open(FileEnsurer.je_check_path, 'w', encoding='utf-8') as file:
            file.writelines(Translator.je_check_text)

        with open(FileEnsurer.translated_text_path, 'w', encoding='utf-8') as file:
            file.writelines(Translator.translated_text)

        FileEnsurer.standard_create_directory(FileEnsurer.archive_dir)

        timestamp = Toolkit.get_timestamp(is_archival=True)

        list_of_result_tuples = [('kudasai_translated_text', Translator.translated_text),
                                 ('kudasai_je_check_text', Translator.je_check_text),
                                 ('kudasai_error_log', Translator.error_text),
                                 ('debug_log', FileEnsurer.standard_read_file(FileEnsurer.debug_log_path))]

        FileEnsurer.archive_results(list_of_result_tuples,
                                    module='translator', timestamp=timestamp)