"""Export the visual and textual encoders of OpenAI CLIP models to ONNX."""

# NOTE(review): `text` is not referenced anywhere in this file, and the
# `cgitb` module was removed from the standard library in Python 3.13
# (PEP 594). The import is kept here to leave the file's imports untouched;
# confirm it can be dropped.
from cgitb import text
import os
from multiprocessing import Pool

import clip
import torch
import torch.onnx
from torch import nn


class TextTransformer(nn.Module):
    """Expose ``clip_model.encode_text`` as the ``forward`` of a module.

    ``torch.onnx.export`` traces a module's ``forward``; this wrapper lets
    the text encoder be exported on its own.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, x: torch.Tensor):
        """Return ``encode_text(x)`` of the wrapped CLIP model."""
        return self.clip_model.encode_text(x)


def export(model, input, path):
    """Export *model* to an ONNX file at *path*.

    Args:
        model: The module to export.
        input: Example input (or a tuple for multiple inputs) used to run
            the model during export.
        path: Where to save the model (a file path or file-like object).
    """
    print(f"Exporting to {path}")
    torch.onnx.export(
        model,  # model being run
        input,  # model input (or a tuple for multiple inputs)
        path,  # where to save the model (can be a file or file-like object)
        export_params=True,  # store the trained weights inside the model file
        opset_version=16,  # the ONNX opset version to export the model to
        do_constant_folding=True,  # execute constant folding for optimization
        input_names=["input"],  # the model's input names
        output_names=["output"],  # the model's output names
        dynamic_axes={
            # Axis 0 of both tensors is exported as a variable-length axis.
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        },
    )


def convert(model_name, dashed_name):
    """Export the visual and textual parts of one CLIP model to ONNX.

    Files are written to ``{output_dir}/clip-{dashed_name}-visual.onnx`` and
    ``{output_dir}/clip-{dashed_name}-textual.onnx``. A file that already
    exists is skipped; when both exist the CLIP model is not loaded at all.

    Args:
        model_name: Model name passed to ``clip.load`` (e.g. ``"ViT-B/32"``).
        dashed_name: File-name-safe name used in the output file names.
    """
    visual_path = f"{output_dir}/clip-{dashed_name}-visual.onnx"
    textual_path = f"{output_dir}/clip-{dashed_name}-textual.onnx"

    visual_exists = os.path.exists(visual_path)
    textual_exists = os.path.exists(textual_path)
    if visual_exists and textual_exists:
        print(f"{visual_path} exists, skipping")
        print(f"{textual_path} exists, skipping")
        return

    print(f"Model: {model_name}")
    print("Loading CLIP")
    model, _ = clip.load(model_name, device=device)
    model = model.to(device=device)

    # The exporter does not create missing parent directories, so make sure
    # the output directory exists before the first export.
    os.makedirs(output_dir, exist_ok=True)

    if not visual_exists:
        input_res = model.visual.input_resolution
        # Create the dummy image on the same device as the model, matching
        # how the text tokens are handled below.
        dummy_image = torch.rand(1, 3, input_res, input_res, device=device)
        export(model.visual, dummy_image, visual_path)
    else:
        print(f"{visual_path} exists, skipping")

    if not textual_exists:
        text_transformer = TextTransformer(model)
        dummy_tokens = clip.tokenize(["hello onnx"]).to(device)
        export(text_transformer, dummy_tokens, textual_path)
    else:
        print(f"{textual_path} exists, skipping")


# Export is pinned to the CPU. The previous code auto-detected CUDA and then
# immediately overwrote the result with "cpu"; only the effective value is
# kept here.
device = "cpu"
output_dir = "converted"


def main():
    """Convert every listed CLIP model, skipping files that already exist."""
    print(f"Torch device: {device}")

    available_models = clip.available_models()
    print(f"Available models: {available_models}")

    # (CLIP model name, dashed name used in the output file names)
    models = [
        ("RN50", "resnet-50"),
        ("RN101", "resnet-101"),
        ("RN50x4", "resnet-50x4"),
        ("RN50x16", "resnet-50x16"),
        ("RN50x64", "resnet-50x64"),
        ("ViT-B/16", "vit-base-patch16"),
        ("ViT-B/32", "vit-base-patch32"),
        ("ViT-L/14", "vit-large-patch14"),
        ("ViT-L/14@336px", "vit-large-patch14-336"),
    ]

    print(f"Converting models: {models}")

    for model in models:
        convert(*model)

    # For converting multiple models at once
    # with Pool(1) as p:
    #     p.starmap(convert, models)

    print("done")


if __name__ == "__main__":
    main()