# donut-base-ascii / speed_test.py
# Uploaded by nbroad (Hugging Face staff) — "Upload 2 files", revision cbde782
import argparse
import torch
from datasets import load_dataset
from transformers import AutoProcessor, VisionEncoderDecoderModel
def speedometer(
    model: torch.nn.Module,
    pixel_values: torch.Tensor,
    decoder_input_ids: torch.Tensor,
    processor: AutoProcessor,
    bad_words_ids: list,
    warmup_iters: int = 100,
    timing_iters: int = 100,
    num_tokens: int = 10,
) -> float:
    """Measure the average wall-clock time of one ``model.generate`` call.

    Runs ``warmup_iters`` untimed calls, then ``timing_iters`` timed calls
    bracketed by CUDA events. Generation length is pinned to exactly
    ``num_tokens`` tokens (``min_length == max_length``) with greedy
    decoding (``num_beams=1``) so every call does the same amount of work.

    Args:
        model: Vision encoder-decoder model already placed on a CUDA device.
        pixel_values: Preprocessed image tensor; moved to ``model.device`` here.
        decoder_input_ids: Prompt token ids for the decoder.
        processor: Processor whose tokenizer supplies pad/eos token ids.
        bad_words_ids: List of token-id lists that generation must avoid.
        warmup_iters: Number of untimed warmup calls.
        timing_iters: Number of timed calls.
        num_tokens: Fixed generation length in tokens.

    Returns:
        Mean time per ``generate`` call in milliseconds.
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Hoist the device transfers out of the loops: they are loop-invariant,
    # and repeating them inside the timed loop would fold (cached) host-to-
    # device copies into the measurement.
    pixel_values = pixel_values.to(model.device)
    decoder_input_ids = decoder_input_ids.to(model.device)

    # Single kwargs dict so warmup and timed runs are guaranteed identical
    # (the original duplicated this list in both loops).
    generate_kwargs = dict(
        decoder_input_ids=decoder_input_ids,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=bad_words_ids,
        return_dict_in_generate=True,
        min_length=num_tokens,
        max_length=num_tokens,
    )

    # Warmup runs: amortize CUDA context / kernel autotune startup cost.
    torch.cuda.synchronize()
    for _ in range(warmup_iters):
        model.generate(pixel_values, **generate_kwargs)

    # Timed runs, bracketed by events recorded on the current CUDA stream.
    start.record()
    for _ in range(timing_iters):
        model.generate(pixel_values, **generate_kwargs)
    end.record()
    torch.cuda.synchronize()  # wait for `end` so elapsed_time() is valid

    mean = start.elapsed_time(end) / timing_iters
    print(f"Mean time: {mean} ms")
    return mean
def get_ja_list_of_lists(processor):
    """Collect bad_words_ids entries for every Japanese token in the vocab.

    Scans ``processor.tokenizer.vocab`` and returns ``[[id], [id], ...]``
    (one singleton list per token id), the shape expected by the
    ``bad_words_ids`` argument of ``generate``.

    Args:
        processor: Object exposing ``tokenizer.vocab`` as a token->id mapping.

    Returns:
        List of single-element lists of token ids whose token text (with the
        leading SentencePiece marker stripped) is entirely Japanese.
    """

    def _is_japanese(text: str) -> bool:
        """True iff every character lies in a Japanese Unicode block.

        Ranges made by GPT-4:
        https://chat.openai.com/share/a795b15c-8534-40b9-9699-c8c1319f5f25

        NOTE(review): returns True for the empty string, so tokens that are
        only the "▁" prefix are classified as Japanese — preserved from the
        original; confirm this is intended.
        """
        return all(
            0x3040 <= cp <= 0x309F       # Hiragana
            or 0x30A0 <= cp <= 0x30FF    # Katakana
            or 0x4E00 <= cp <= 0x9FFF    # CJK Unified Ideographs
            or 0x3400 <= cp <= 0x4DBF    # CJK Extension A
            or 0x20000 <= cp <= 0x2A6DF  # CJK Extension B
            or 0x31F0 <= cp <= 0x31FF    # Katakana Phonetic Extensions
            or 0xFF00 <= cp <= 0xFFEF    # Halfwidth/Fullwidth Forms
            or 0x3000 <= cp <= 0x303F    # CJK Symbols and Punctuation
            or 0x3200 <= cp <= 0x32FF    # Enclosed CJK Letters and Months
            for cp in map(ord, text)
        )

    # The original used `id` (shadowing the builtin) and also built an
    # unused ja_tokens list; a single comprehension does the same work.
    return [
        [token_id]
        for token, token_id in processor.tokenizer.vocab.items()
        if _is_japanese(token.lstrip("▁"))
    ]
def main():
    """CLI entry point: load a model and run the generation speed benchmark.

    Flags:
        --model_path: path or Hub id of the VisionEncoderDecoder model (required).
        --ja_bad_words: additionally ban all Japanese vocab tokens during
            generation, to measure the cost of a large bad_words_ids list.
    """
    # The original help strings were template placeholders
    # ("Description of your program" / "Description for foo argument").
    parser = argparse.ArgumentParser(
        description="Benchmark VisionEncoderDecoder generation speed."
    )
    parser.add_argument(
        "--model_path",
        help="Path or Hugging Face Hub id of the model to benchmark",
        required=True,
    )
    parser.add_argument(
        "--ja_bad_words",
        help="Use ja bad_words_ids (ban every Japanese vocab token)",
        action="store_true",
        default=False,
    )
    args = parser.parse_args()

    print("Running speed test on model: ", args.model_path, "with ja_bad_words: ", args.ja_bad_words)

    processor = AutoProcessor.from_pretrained(args.model_path)
    model = VisionEncoderDecoderModel.from_pretrained(args.model_path)

    # Consistent torch.device on both branches (the original mixed a bare
    # int 0 with torch.device("cpu")); "cuda" resolves to device 0 here.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Single sample document image used as the fixed benchmark input.
    dataset = load_dataset("hf-internal-testing/example-documents", split="test")
    image = dataset[1]["image"]

    task_prompt = "<s_synthdog>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Always ban <unk>; optionally ban the whole Japanese vocabulary too.
    bad_words_ids = [[processor.tokenizer.unk_token_id]]
    if args.ja_bad_words:
        bad_words_ids += get_ja_list_of_lists(processor)

    print("Length of bad_words_ids: ", len(bad_words_ids))

    speedometer(
        model,
        pixel_values,
        decoder_input_ids,
        processor,
        bad_words_ids=bad_words_ids,
        warmup_iters=100,
        timing_iters=100,
        num_tokens=10,
    )
# Script entry point: run the benchmark only when executed directly.
if __name__ == "__main__":
    main()