"""Gradio demo that translates a bundled sample file with a HuggingFace model.

The Gradio interface exposes two textboxes (source and target language).
``main`` loads the model, translates ``sample_text/en.small.txt`` and returns
the content of the generated output file.
"""

import argparse
import glob
import math
import os
from typing import Optional, Tuple

import gradio
import torch
from accelerate import Accelerator, DistributedType, find_executable_batch_size
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    DataCollatorForSeq2Seq,
    PreTrainedTokenizerBase,
)

from dataset import DatasetReader, count_lines
from model import load_model_for_inference


def encode_string(text):
    """Escape carriage returns, newlines and tabs so *text* fits on one line."""
    return text.replace("\r", r"\r").replace("\n", r"\n").replace("\t", r"\t")


def get_dataloader(
    accelerator: Accelerator,
    filename: str,
    tokenizer: PreTrainedTokenizerBase,
    batch_size: int,
    max_length: int,
    prompt: str,
) -> DataLoader:
    """Build the inference DataLoader for the sentences stored in *filename*.

    Args:
        accelerator: Used only to detect whether we are running on TPU.
        filename: Path of the text file handed to ``DatasetReader``.
        tokenizer: Tokenizer used by the dataset and by the collator.
        batch_size: Number of examples per batch.
        max_length: Maximum length forwarded to ``DatasetReader`` (and used as
            the fixed padding length on TPU).
        prompt: Optional prompt template forwarded to ``DatasetReader``.

    Returns:
        A single-process ``DataLoader`` (``num_workers=0``) over the dataset.
    """
    dataset = DatasetReader(
        filename=filename,
        tokenizer=tokenizer,
        max_length=max_length,
        prompt=prompt,
    )

    if accelerator.distributed_type == DistributedType.TPU:
        # On TPU every batch is padded to the same fixed length.
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            padding="max_length",
            max_length=max_length,
            label_pad_token_id=tokenizer.pad_token_id,
            return_tensors="pt",
        )
    else:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            padding=True,
            label_pad_token_id=tokenizer.pad_token_id,
            # max_length=max_length, No need to set max_length here, we already truncate in the preprocess function
            pad_to_multiple_of=8,
            return_tensors="pt",
        )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=data_collator,
        num_workers=0,  # Disable multiprocessing
    )


def _normalize_lang(lang: Optional[str]) -> Optional[str]:
    """Return *lang* without surrounding whitespace, or None when it is blank.

    Gradio textboxes deliver an empty string for an empty field, whereas the
    checks in ``main`` use ``None`` to mean "language not provided".
    """
    if lang is None:
        return None
    lang = lang.strip()
    return lang or None


def _resolve_precision(precision: Optional[str]) -> Tuple[Optional[int], Optional[str]]:
    """Map the ``precision`` option to a ``(quantization, dtype)`` pair.

    Raises:
        ValueError: If *precision* is not one of 8, 4, fp16, bf16, 32 or None.
    """
    if precision is None:
        return None, None
    if precision in ("8", "4"):
        return int(precision), None

    dtype = {"fp16": "float16", "bf16": "bfloat16", "32": "float32"}.get(precision)
    if dtype is None:
        raise ValueError(
            f"Precision {precision} not supported. Please choose between 8, 4, fp16, bf16, 32 or None."
        )
    return None, dtype


def main(
    source_lang: Optional[str],
    target_lang: Optional[str],
    starting_batch_size: int = 8,
    model_name: str = "facebook/m2m100_1.2B",
    lora_weights_name_or_path: Optional[str] = None,
    force_auto_device_map: bool = False,
    precision: Optional[str] = None,
    max_length: int = 256,
    num_beams: int = 4,
    num_return_sequences: int = 1,
    do_sample: bool = False,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 1.0,
    keep_special_tokens: bool = False,
    keep_tokenization_spaces: bool = False,
    repetition_penalty: Optional[float] = None,
    prompt: Optional[str] = None,
    trust_remote_code: bool = False,
) -> str:
    """Translate the bundled sample file and return the translated text.

    The input (``sample_text/en.small.txt``) and output
    (``sample_text/en2es.translation.m2m100_12B.txt``) paths are fixed; only
    the languages and the model/generation options are configurable.

    Args:
        source_lang: Source language id. Blank strings are treated as None.
        target_lang: Target language id. Blank strings are treated as None.
        starting_batch_size: Initial batch size handed to
            ``find_executable_batch_size``.
        model_name: Model name or path passed to ``load_model_for_inference``.
        lora_weights_name_or_path: Optional LoRA weights passed to the loader.
        force_auto_device_map: Forwarded to the model loader.
        precision: One of "8", "4", "fp16", "bf16", "32" or None.
        max_length: Max tokens for the tokenized input and ``max_new_tokens``.
        num_beams: Number of beams for generation.
        num_return_sequences: Number of outputs generated per input sentence.
        do_sample: Whether to sample during generation.
        temperature: Generation temperature.
        top_k: Generation top-k.
        top_p: Generation top-p.
        keep_special_tokens: Keep special tokens in the decoded text.
        keep_tokenization_spaces: Do not clean up tokenization spaces.
        repetition_penalty: Optional repetition penalty for generation.
        prompt: Optional prompt template; must contain ``%%SENTENCE%%``.
        trust_remote_code: Forwarded to the model loader.

    Returns:
        The full content of the output file once translation has finished.

    Raises:
        ValueError: On an unsupported precision, a prompt without the
            ``%%SENTENCE%%`` token, or missing/unsupported languages.
        KeyError: If a language is not known to a translation tokenizer.
    """
    # Gradio sends "" for an empty textbox; the checks below rely on None.
    source_lang = _normalize_lang(source_lang)
    target_lang = _normalize_lang(target_lang)

    accelerator = Accelerator()
    sentences_path = "sample_text/en.small.txt"
    output_path = "sample_text/en2es.translation.m2m100_12B.txt"

    if force_auto_device_map and starting_batch_size >= 64:
        print(
            f"WARNING: You are using a very large batch size ({starting_batch_size}) and the auto_device_map flag. "
            "auto_device_map will offload model parameters to the CPU when they don't fit on the GPU VRAM. "
            "If you use a very large batch size, it will offload a lot of parameters to the CPU and slow down the "
            "inference. You should consider using a smaller batch size, i.e '--starting_batch_size 8'"
        )

    quantization, dtype = _resolve_precision(precision)

    model, tokenizer = load_model_for_inference(
        weights_path=model_name,
        quantization=quantization,
        lora_weights_name_or_path=lora_weights_name_or_path,
        torch_dtype=dtype,
        force_auto_device_map=force_auto_device_map,
        trust_remote_code=trust_remote_code,
    )

    # Tokenizers exposing `lang_code_to_id` are treated as translation models.
    is_translation_model = hasattr(tokenizer, "lang_code_to_id")
    lang_code_to_idx = None

    if (
        is_translation_model
        and (source_lang is None or target_lang is None)
        and "small100" not in model_name
    ):
        raise ValueError(
            "The model you are using requires a source and target language. "
            "Please specify them with --source-lang and --target-lang. "
            f"The supported languages are: {tokenizer.lang_code_to_id.keys()}"
        )

    if not is_translation_model and (
        source_lang is not None or target_lang is not None
    ):
        if prompt is None:
            print(
                "WARNING: You are using a model that does not support source and target languages parameters "
                "but you specified them. You probably want to use m2m100/nllb200 for translation or "
                "set --prompt to define the task for you model. "
            )
        else:
            print(
                "WARNING: You are using a model that does not support source and target languages parameters "
                "but you specified them."
            )

    if prompt is not None and "%%SENTENCE%%" not in prompt:
        raise ValueError(
            "The prompt must contain the %%SENTENCE%% token to indicate where the sentence should be inserted. "
            f"Your prompt: {prompt}"
        )

    if is_translation_model:
        try:
            _ = tokenizer.lang_code_to_id[source_lang]
        except KeyError:
            raise KeyError(
                f"Language {source_lang} not found in tokenizer. Available languages: {tokenizer.lang_code_to_id.keys()}"
            ) from None
        tokenizer.src_lang = source_lang

        try:
            lang_code_to_idx = tokenizer.lang_code_to_id[target_lang]
        except KeyError:
            raise KeyError(
                f"Language {target_lang} not found in tokenizer. Available languages: {tokenizer.lang_code_to_id.keys()}"
            ) from None

        if "small100" in model_name:
            tokenizer.tgt_lang = target_lang
            # We don't need to force the BOS token, so we set is_translation_model to False
            is_translation_model = False

    if model.config.model_type == "seamless_m4t":
        # Loading a seamless_m4t model, we need to set a few things to ensure compatibility
        supported_langs = tokenizer.additional_special_tokens
        supported_langs = [lang.replace("__", "") for lang in supported_langs]

        if source_lang is None or target_lang is None:
            raise ValueError(
                "The model you are using requires a source and target language. "
                "Please specify them with --source-lang and --target-lang. "
                f"The supported languages are: {supported_langs}"
            )

        if source_lang not in supported_langs:
            raise ValueError(
                f"Language {source_lang} not found in tokenizer. Available languages: {supported_langs}"
            )
        if target_lang not in supported_langs:
            raise ValueError(
                f"Language {target_lang} not found in tokenizer. Available languages: {supported_langs}"
            )

        tokenizer.src_lang = source_lang

    gen_kwargs = {
        "max_new_tokens": max_length,
        "num_beams": num_beams,
        "num_return_sequences": num_return_sequences,
        "do_sample": do_sample,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
    }

    if repetition_penalty is not None:
        gen_kwargs["repetition_penalty"] = repetition_penalty

    if is_translation_model:
        gen_kwargs["forced_bos_token_id"] = lang_code_to_idx

    if model.config.model_type == "seamless_m4t":
        gen_kwargs["tgt_lang"] = target_lang

    if accelerator.is_main_process:
        print(
            f"** Translation **\n"
            f"Input file: {sentences_path}\n"
            f"Output file: {output_path}\n"
            f"Source language: {source_lang}\n"
            f"Target language: {target_lang}\n"
            f"Force target lang as BOS token: {is_translation_model}\n"
            f"Prompt: {prompt}\n"
            f"Starting batch size: {starting_batch_size}\n"
            f"Device: {str(accelerator.device).split(':')[0]}\n"
            f"Num. Devices: {accelerator.num_processes}\n"
            f"Distributed_type: {accelerator.distributed_type}\n"
            f"Max length: {max_length}\n"
            f"Quantization: {quantization}\n"
            f"Precision: {dtype}\n"
            f"Model: {model_name}\n"
            f"LoRA weights: {lora_weights_name_or_path}\n"
            f"Force auto device map: {force_auto_device_map}\n"
            f"Keep special tokens: {keep_special_tokens}\n"
            f"Keep tokenization spaces: {keep_tokenization_spaces}\n"
        )
        print("** Generation parameters **")
        print("\n".join(f"{k}: {v}" for k, v in gen_kwargs.items()))
        print("\n")

    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def inference(batch_size, sentences_path, output_path):
        """Translate *sentences_path* into *output_path* using *batch_size*.

        ``batch_size`` is injected by ``find_executable_batch_size``, starting
        from ``starting_batch_size``.
        """
        # Only `model` is rebound (by accelerator.prepare); the other outer
        # variables are merely read through the closure.
        nonlocal model

        print(f"Translating {sentences_path} with batch size {batch_size}")

        total_lines: int = count_lines(sentences_path)

        data_loader = get_dataloader(
            accelerator=accelerator,
            filename=sentences_path,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_length=max_length,
            prompt=prompt,
        )

        model, data_loader = accelerator.prepare(model, data_loader)

        samples_seen: int = 0

        # Index of the final step each process runs; invariant across the loop.
        last_step = (
            math.ceil(math.ceil(total_lines / batch_size) / accelerator.num_processes)
            - 1
        )

        with tqdm(
            total=total_lines,
            desc="Dataset translation",
            leave=True,
            ascii=True,
            disable=(not accelerator.is_main_process),
        ) as pbar, open(output_path, "w", encoding="utf-8") as output_file:
            with torch.no_grad():
                for step, batch in enumerate(data_loader):
                    generated_tokens = accelerator.unwrap_model(model).generate(
                        **batch,
                        **gen_kwargs,
                    )

                    # Pad to a common length so the tensors can be gathered.
                    generated_tokens = accelerator.pad_across_processes(
                        generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
                    )

                    generated_tokens = (
                        accelerator.gather(generated_tokens).cpu().numpy()
                    )

                    tgt_text = tokenizer.batch_decode(
                        generated_tokens,
                        skip_special_tokens=not keep_special_tokens,
                        clean_up_tokenization_spaces=not keep_tokenization_spaces,
                    )

                    if accelerator.is_main_process:
                        if step == last_step:
                            # Keep only the outputs still owed for the file;
                            # presumably the surplus comes from samples
                            # duplicated to fill the last distributed batch
                            # (TODO confirm).
                            tgt_text = tgt_text[
                                : (total_lines * num_return_sequences) - samples_seen
                            ]
                        else:
                            samples_seen += len(tgt_text)

                        print(
                            "\n".join(
                                [encode_string(sentence) for sentence in tgt_text]
                            ),
                            file=output_file,
                        )

                    pbar.update(len(tgt_text) // gen_kwargs["num_return_sequences"])

        print(f"Translation done. Output written to {output_path}\n")

    os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
    inference(sentences_path=sentences_path, output_path=output_path)

    print("Translation done.\n")

    with open(output_path, "r", encoding="utf-8") as f:
        return f.read()


# NOTE: legacy command-line entry point, kept for reference only. The call to
# main() below passes sentences_path / sentences_dir / files_extension /
# output_path, which the current main() signature does not accept.
#
# if __name__ == "__main__":
#     parser = argparse.ArgumentParser(description="Run the translation experiments")
#
#     input_group = parser.add_mutually_exclusive_group(required=True)
#     input_group.add_argument(
#         "--sentences_path",
#         default=None,
#         type=str,
#         help="Path to a txt file containing the sentences to translate. One sentence per line.",
#     )
#
#     input_group.add_argument(
#         "--sentences_dir",
#         type=str,
#         default=None,
#         help="Path to a directory containing the sentences to translate. "
#         "Sentences must be in .txt files containing containing one sentence per line.",
#     )
#
#     parser.add_argument(
#         "--files_extension",
#         type=str,
#         default="txt",
#         help="If sentences_dir is specified, extension of the files to translate. Defaults to txt. "
#         "If set to an empty string, we will translate all files in the directory.",
#     )
#
#     parser.add_argument(
#         "--output_path",
#         type=str,
#         required=True,
#         help="Path to a txt file where the translated sentences will be written. If the input is a directory, "
#         "the output will be a directory with the same structure.",
#     )
#
#     parser.add_argument(
#         "--source_lang",
#         type=str,
#         default=None,
#         required=False,
#         help="Source language id. See: supported_languages.md. Required for m2m100 and nllb200",
#     )
#
#     parser.add_argument(
#         "--target_lang",
#         type=str,
#         default=None,
#         required=False,
#         help="Source language id. See: supported_languages.md. Required for m2m100 and nllb200",
#     )
#
#     parser.add_argument(
#         "--starting_batch_size",
#         type=int,
#         default=128,
#         help="Starting batch size, we will automatically reduce it if we find an OOM error."
#         "If you use multiple devices, we will divide this number by the number of devices.",
#     )
#
#     parser.add_argument(
#         "--model_name",
#         type=str,
#         default="facebook/m2m100_1.2B",
#         help="Path to the model to use. See: https://huggingface.co/models",
#     )
#
#     parser.add_argument(
#         "--lora_weights_name_or_path",
#         type=str,
#         default=None,
#         help="If the model uses LoRA weights, path to those weights. See: https://github.com/huggingface/peft",
#     )
#
#     parser.add_argument(
#         "--force_auto_device_map",
#         action="store_true",
#         help=" Whether to force the use of the auto device map. If set to True, "
#         "the model will be split across GPUs and CPU to fit the model in memory. "
#         "If set to False, a full copy of the model will be loaded into each GPU. Defaults to False.",
#     )
#
#     parser.add_argument(
#         "--max_length",
#         type=int,
#         default=256,
#         help="Maximum number of tokens in the source sentence and generated sentence. "
#         "Increase this value to translate longer sentences, at the cost of increasing memory usage.",
#     )
#
#     parser.add_argument(
#         "--num_beams",
#         type=int,
#         default=5,
#         help="Number of beams for beam search, m2m10 author recommends 5, but it might use too much memory",
#     )
#
#     parser.add_argument(
#         "--num_return_sequences",
#         type=int,
#         default=1,
#         help="Number of possible translation to return for each sentence (num_return_sequences<=num_beams).",
#     )
#
#     parser.add_argument(
#         "--precision",
#         type=str,
#         default=None,
#         choices=["bf16", "fp16", "32", "4", "8"],
#         help="Precision of the model. bf16, fp16 or 32, 8 , 4 "
#         "(4bits/8bits quantification, requires bitsandbytes library: https://github.com/TimDettmers/bitsandbytes). "
#         "If None, we will use the torch.dtype of the model weights.",
#     )
#
#     parser.add_argument(
#         "--do_sample",
#         action="store_true",
#         help="Use sampling instead of beam search.",
#     )
#
#     parser.add_argument(
#         "--temperature",
#         type=float,
#         default=0.8,
#         help="Temperature for sampling, value used only if do_sample is True.",
#     )
#
#     parser.add_argument(
#         "--top_k",
#         type=int,
#         default=100,
#         help="If do_sample is True, will sample from the top k most likely tokens.",
#     )
#
#     parser.add_argument(
#         "--top_p",
#         type=float,
#         default=0.75,
#         help="If do_sample is True, will sample from the top k most likely tokens.",
#     )
#
#     parser.add_argument(
#         "--keep_special_tokens",
#         action="store_true",
#         help="Keep special tokens in the decoded text.",
#     )
#
#     parser.add_argument(
#         "--keep_tokenization_spaces",
#         action="store_true",
#         help="Do not clean spaces in the decoded text.",
#     )
#
#     parser.add_argument(
#         "--repetition_penalty",
#         type=float,
#         default=None,
#         help="Repetition penalty.",
#     )
#
#     parser.add_argument(
#         "--prompt",
#         type=str,
#         default=None,
#         help="Prompt to use for generation. "
#         "It must include the special token %%SENTENCE%% which will be replaced by the sentence to translate.",
#     )
#
#     parser.add_argument(
#         "--trust_remote_code",
#         action="store_true",
#         help="If set we will trust remote code in HuggingFace models. This is required for some models.",
#     )
#
#     args = parser.parse_args()
#
#     main(
#         sentences_path=args.sentences_path,
#         sentences_dir=args.sentences_dir,
#         files_extension=args.files_extension,
#         output_path=args.output_path,
#         source_lang=args.source_lang,
#         target_lang=args.target_lang,
#         starting_batch_size=args.starting_batch_size,
#         model_name=args.model_name,
#         max_length=args.max_length,
#         num_beams=args.num_beams,
#         num_return_sequences=args.num_return_sequences,
#         precision=args.precision,
#         do_sample=args.do_sample,
#         temperature=args.temperature,
#         top_k=args.top_k,
#         top_p=args.top_p,
#         keep_special_tokens=args.keep_special_tokens,
#         keep_tokenization_spaces=args.keep_tokenization_spaces,
#         repetition_penalty=args.repetition_penalty,
#         prompt=args.prompt,
#         trust_remote_code=args.trust_remote_code,
#     )

# The two textboxes map positionally to main(source_lang, target_lang); every
# other parameter keeps its default value.
demo = gradio.Interface(fn=main, inputs=["textbox", "textbox"], outputs="textbox")
demo.launch(share=True)