import argparse

import torch
import open_clip
import pandas as pd
from fvcore.nn import FlopCountAnalysis, flop_count_str, ActivationCountAnalysis

parser = argparse.ArgumentParser(description='OpenCLIP Profiler')

# benchmark specific args
parser.add_argument('--model', metavar='NAME', default='',
                    help="model name(s) to profile, comma-separated, or 'all' to profile every model open_clip lists")
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                    help='output csv file for results')
def profile_fvcore(
        model,
        image_input_size=(3, 224, 224),
        text_input_size=(77,),
        batch_size=1,
        detailed=False,
        force_cpu=False,
):
    if force_cpu:
        model = model.to('cpu')
    device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
    # dummy image and text-token inputs; fvcore traces the full model with these to count ops
    example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
    example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
    fca = FlopCountAnalysis(model, (example_image_input, example_text_input))
    aca = ActivationCountAnalysis(model, (example_image_input, example_text_input))
    if detailed:
        fcs = flop_count_str(fca)
        print(fcs)
    return fca.total(), aca.total()
def profile_fvcore_text(
        model,
        text_input_size=(77,),
        batch_size=1,
        detailed=False,
        force_cpu=False,
):
    if force_cpu:
        model = model.to('cpu')
    device = next(model.parameters()).device
    example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
    fca = FlopCountAnalysis(model, example_input)
    aca = ActivationCountAnalysis(model, example_input)
    if detailed:
        fcs = flop_count_str(fca)
        print(fcs)
    return fca.total(), aca.total()
def profile_fvcore_image(
        model,
        image_input_size=(3, 224, 224),
        batch_size=1,
        detailed=False,
        force_cpu=False,
):
    if force_cpu:
        model = model.to('cpu')
    device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
    example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
    fca = FlopCountAnalysis(model, example_input)
    aca = ActivationCountAnalysis(model, example_input)
    if detailed:
        fcs = flop_count_str(fca)
        print(fcs)
    return fca.total(), aca.total()
def count_params(model):
    return sum(m.numel() for m in model.parameters())
def profile_model(model_name):
    model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False)
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    if isinstance(model.visual.image_size, (tuple, list)):
        image_input_size = (3,) + tuple(model.visual.image_size[-2:])
    else:
        image_input_size = (3, model.visual.image_size, model.visual.image_size)
    text_input_size = (77,)

    results = {}
    results['model'] = model_name
    results['image_size'] = image_input_size[1]

    model_cfg = open_clip.get_model_config(model_name)
    if model_cfg:
        vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg'])
        text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg'])
        results['image_width'] = int(vision_cfg.width)
        results['text_width'] = int(text_cfg.width)
        results['embed_dim'] = int(model_cfg['embed_dim'])
    else:
        results['image_width'] = 0
        results['text_width'] = 0
        results['embed_dim'] = 0

    # Profile on the current device first; if that raises (e.g. CUDA OOM),
    # retry once on CPU (force_cpu is True on the final attempt).
    retries = 2
    while retries:
        retries -= 1
        try:
            macs, acts = profile_fvcore(
                model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries)
            image_macs, image_acts = profile_fvcore_image(
                model.visual, image_input_size=image_input_size, force_cpu=not retries)
            text_macs, text_acts = profile_fvcore_text(
                model.text, text_input_size=text_input_size, force_cpu=not retries)

            results['gmacs'] = round(macs / 1e9, 2)
            results['macts'] = round(acts / 1e6, 2)
            results['mparams'] = round(count_params(model) / 1e6, 2)
            results['image_gmacs'] = round(image_macs / 1e9, 2)
            results['image_macts'] = round(image_acts / 1e6, 2)
            results['image_mparams'] = round(count_params(model.visual) / 1e6, 2)
            results['text_gmacs'] = round(text_macs / 1e9, 2)
            results['text_macts'] = round(text_acts / 1e6, 2)
            results['text_mparams'] = round(count_params(model.text) / 1e6, 2)
            break  # success, no need to fall back to CPU
        except RuntimeError:
            pass
    return results
def main():
    args = parser.parse_args()

    # FIXME accept a text file name to allow lists of models in txt/csv
    if args.model == 'all':
        parsed_model = open_clip.list_models()
    else:
        parsed_model = args.model.split(',')

    results = []
    for m in parsed_model:
        row = profile_model(m)
        results.append(row)

    df = pd.DataFrame(results, columns=results[0].keys())
    df = df.sort_values('gmacs')
    print(df)
    if args.results_file:
        df.to_csv(args.results_file, index=False)


if __name__ == '__main__':
    main()
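# Example invocation (a sketch; the script filename 'profiler.py' is an assumption,
# and 'ViT-B-32' / 'ViT-L-14' are just two of the names returned by open_clip.list_models()):
#
#   python profiler.py --model ViT-B-32,ViT-L-14 --results-file profile_results.csv
#
# Passing --model all profiles every architecture open_clip lists; models that do not
# fit in GPU memory are retried on CPU, so a full run can take a while.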