from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from argparse import ArgumentParser
import time
import os, re
import fnmatch
import glob
import shutil
import zipfile
from tqdm import tqdm


def find_all_htmls(root_dir):
    # Collect every HTML/XHTML document inside the unpacked epub.
    html_files = []
    for foldername, subfolders, filenames in os.walk(root_dir):
        for extension in ['*.html', '*.xhtml', '*.htm']:
            for filename in fnmatch.filter(filenames, extension):
                file_path = os.path.join(foldername, filename)
                html_files.append(file_path)
    return html_files


def get_html_text_list(epub_path, text_length):
    data_list = []

    def clean_text(text):
        # Strip ruby (furigana) annotations first, then any remaining tags and newlines.
        text = re.sub(r'<rt[^>]*?>.*?</rt>', '', text)
        text = re.sub(r'<[^>]*>|\n', '', text)
        return text

    with open(epub_path, 'r', encoding='utf-8') as f:
        file_text = f.read()
    # Match heading and paragraph elements; group(2) is the inner markup/text.
    # Materialize the iterator so the emptiness check below actually works.
    matches = list(re.finditer(r'<(h[1-6]|p).*?>(.+?)</\1>', file_text, flags=re.DOTALL))
    if not matches:
        print("perhaps this file is a struct file")
        return data_list, file_text
    # Pack consecutive matches into chunks of at most text_length characters.
    groups = []
    text = ''
    pre_end = 0
    for match in matches:
        if len(text + match.group(2)) <= text_length:
            new_text = clean_text(match.group(2))
            if new_text:
                groups.append(match)
                text += '\n' + new_text
        else:
            data_list.append((text, groups, pre_end))
            pre_end = groups[-1].end()
            new_text = clean_text(match.group(2))
            if new_text:
                groups = [match]
                text = clean_text(match.group(2))
            else:
                groups = []
                text = ''
    if text:
        data_list.append((text, groups, pre_end))
    # TEST:
    # for d in data_list:
    #     print(f"{len(d[0])}", end=" ")
    return data_list, file_text


def get_prompt(input, model_version):
    # The instruction text means "Translate the following Japanese text into Chinese:".
    if model_version == '0.5' or model_version == '0.8':
        # v0.5/v0.8 prompt format: <reserved_106> marks the user turn, <reserved_107> the assistant turn.
        prompt = "<reserved_106>将下面的日文文本翻译成中文:" + input + "<reserved_107>"
        return prompt
    if model_version == '0.7':
        prompt = f"<|im_start|>user\n将下面的日文文本翻译成中文:{input}<|im_end|>\n<|im_start|>assistant\n"
        return prompt
    if model_version == '0.1':
        prompt = "Human: \n将下面的日文文本翻译成中文:" + input + "\n\nAssistant: \n"
        return prompt
    if model_version == '0.4':
        prompt = "User: 将下面的日文文本翻译成中文:" + input + "\nAssistant: "
        return prompt
    raise ValueError(f"Wrong model version {model_version}, please view https://huggingface.co/sakuraumi/Sakura-13B-Galgame")


def split_response(response, model_version):
    # Drop the EOS token, then cut everything before the assistant marker.
    response = response.replace("</s>", "")
    if model_version == '0.5' or model_version == '0.8':
        output = response.split("<reserved_107>")[1]
        return output
    if model_version == '0.7':
        output = response.split("<|im_start|>assistant\n")[1]
        return output
    if model_version == '0.1':
        output = response.split("\n\nAssistant: \n")[1]
        return output
    if model_version == '0.4':
        output = response.split("\nAssistant: ")[1]
        return output
    raise ValueError(f"Wrong model version {model_version}, please view https://huggingface.co/sakuraumi/Sakura-13B-Galgame")


def detect_degeneration(generation: list, model_version):
    # Only checked for v0.8: if the reply after the assistant marker (token id 196)
    # nearly fills the max_new_tokens budget, treat the output as degenerate.
    if model_version != "0.8":
        return False
    i = generation.index(196)
    generation = generation[i+1:]
    if len(generation) >= 1023:
        print("model degeneration detected, retrying...")
        return True
    else:
        return False


def get_model_response(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, prompt: str, model_version: str, generation_config: GenerationConfig, text_length: int):
    # Fallback configs used when the first generation degenerates:
    # stage 2 resamples with a higher temperature, stage 3 switches to contrastive search.
    backup_generation_config_stage2 = GenerationConfig(
        temperature=1,
        top_p=0.6,
        top_k=40,
        num_beams=1,
        bos_token_id=1,
        eos_token_id=2,
        pad_token_id=0,
        max_new_tokens=1024,
        min_new_tokens=1,
        do_sample=True
    )
    backup_generation_config_stage3 = GenerationConfig(
        top_k=5,
        num_beams=1,
        bos_token_id=1,
        eos_token_id=2,
        pad_token_id=0,
        max_new_tokens=1024,
        min_new_tokens=1,
        penalty_alpha=0.3
    )
    backup_generation_config = [backup_generation_config_stage2, backup_generation_config_stage3]
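    # First pass uses the caller-supplied generation_config. When the decoded
    # sequence is suspiciously long and detect_degeneration() flags it, retry:
    # the first retry falls back to higher-temperature sampling, the second to
    # contrastive search; after two retries the last result is kept as-is.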
    generation = model.generate(**tokenizer(prompt, return_tensors="pt").to(model.device), generation_config=generation_config)[0]
    if len(generation) > 2 * text_length:
        stage = 0
        while detect_degeneration(list(generation), model_version):
            stage += 1
            if stage > 2:
                print("model degeneration cannot be avoided.")
                break
            generation = model.generate(**tokenizer(prompt, return_tensors="pt").to(model.device), generation_config=backup_generation_config[stage-1])[0]
    response = tokenizer.decode(generation)
    output = split_response(response, model_version)
    return output


def main():
    parser = ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="SakuraLLM/Sakura-13B-LNovel-v0.8", help="model huggingface id or local path.")
    parser.add_argument("--use_gptq_model", action="store_true", help="whether your model is gptq quantized.")
    parser.add_argument("--model_version", type=str, default="0.8", help="model version written on huggingface readme, now we have ['0.1', '0.4', '0.5', '0.7', '0.8']")
    parser.add_argument("--data_path", type=str, default="", help="file path of the epub you want to translate.")
    parser.add_argument("--data_folder", type=str, default="", help="folder path of the epubs you want to translate.")
    parser.add_argument("--output_folder", type=str, default="", help="save folder path of the epubs model translated.")
    parser.add_argument("--text_length", type=int, default=512, help="input max length in each inference.")
    args = parser.parse_args()

    if args.use_gptq_model:
        from auto_gptq import AutoGPTQForCausalLM

    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.3,
        top_k=40,
        num_beams=1,
        bos_token_id=1,
        eos_token_id=2,
        pad_token_id=0,
        max_new_tokens=1024,
        min_new_tokens=1,
        do_sample=True
    )

    print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False, trust_remote_code=True)
    if args.use_gptq_model:
        model = AutoGPTQForCausalLM.from_quantized(args.model_name_or_path, device="cuda:0", trust_remote_code=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, device_map="auto", trust_remote_code=True)

    print("Start translating...")
    start = time.time()

    # Build the list of input epubs and the matching output paths.
    epub_list = []
    save_list = []
    if args.data_path:
        os.makedirs(args.output_folder, exist_ok=True)
        epub_list.append(args.data_path)
        save_list.append(os.path.join(args.output_folder, os.path.basename(args.data_path)))
    if args.data_folder:
        os.makedirs(args.output_folder, exist_ok=True)
        for f in os.listdir(args.data_folder):
            if f.endswith(".epub"):
                epub_list.append(os.path.join(args.data_folder, f))
                save_list.append(os.path.join(args.output_folder, f))

    for epub_path, save_path in zip(epub_list, save_list):
        print(f"translating {epub_path}...")
        start_epub = time.time()

        # An epub is a zip archive: unpack it so the html files can be rewritten in place.
        if os.path.exists('./temp'):
            shutil.rmtree('./temp')
        with zipfile.ZipFile(epub_path, 'r') as f:
            f.extractall('./temp')

        for html_path in find_all_htmls('./temp'):
            print(f"\ttranslating {html_path}...")
            start_html = time.time()
            translated = ''
            data_list, file_text = get_html_text_list(html_path, args.text_length)
            if len(data_list) == 0:
                continue
            for text, groups, pre_end in tqdm(data_list):
                prompt = get_prompt(text, args.model_version)
                output = get_model_response(model, tokenizer, prompt, args.model_version, generation_config, args.text_length)
                texts = output.strip().split('\n')
                # Align the translated lines with the original elements: pad with
                # empty strings, or fold any overflow lines into the last element.
                if len(texts) < len(groups):
                    texts += [''] * (len(groups) - len(texts))
                else:
                    texts = texts[:len(groups)-1] + ['<br/>'.join(texts[len(groups)-1:])]
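                # Splice each translated line back into its original <p>/<h1>-<h6>
                # element, copying the untouched markup between matches verbatim.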
                for t, match in zip(texts, groups):
                    t = match.group(0).replace(match.group(2), t)
                    translated += file_text[pre_end:match.start()] + t
                    pre_end = match.end()
            # Append everything after the last translated element (closing tags, etc.).
            translated += file_text[data_list[-1][1][-1].end():]
            with open(html_path, 'w', encoding='utf-8') as f:
                f.write(translated)
            end_html = time.time()
            print(f"\t{html_path} translated, used time: ", end_html-start_html)

        # Repack the temp folder into an epub (zip archive) at the output path.
        with zipfile.ZipFile(save_path, 'w', zipfile.ZIP_DEFLATED) as f:
            for file_path in glob.glob('./temp/**', recursive=True):
                if not os.path.isdir(file_path):
                    relative_path = os.path.relpath(file_path, './temp')
                    f.write(file_path, relative_path)
        shutil.rmtree('./temp')
        end_epub = time.time()
        print(f"{epub_path} translated, used time: ", end_epub-start_epub)

    end = time.time()
    print("translation completed, used time: ", end-start)


if __name__ == "__main__":
    main()
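

# Example invocation (a sketch, not from the original source: the script filename
# "translate_epub.py" and the data/output paths below are placeholders, and a
# CUDA-capable GPU with enough memory for the 13B model is assumed):
#
#   python translate_epub.py \
#       --model_name_or_path SakuraLLM/Sakura-13B-LNovel-v0.8 \
#       --model_version 0.8 \
#       --data_folder ./novels \
#       --output_folder ./output \
#       --text_length 512
#
# Add --use_gptq_model when --model_name_or_path points at a GPTQ-quantized checkpoint.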