# clip-variants/convert.py
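"""Convert OpenAI CLIP checkpoints to ONNX.

For each CLIP variant the visual encoder and the text encoder are exported as
separate ONNX files with a dynamic batch dimension; files that already exist
in the output directory are skipped.
"""
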
import os
from multiprocessing import Pool

import clip
import torch
import torch.onnx
from torch import nn
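
# clip.load() returns the full CLIP module, whose text path is only reachable
# through encode_text(); wrapping it in an nn.Module gives torch.onnx.export a
# plain forward() to trace for the textual model.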
class TextTransformer(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, x: torch.Tensor):
        return self.clip_model.encode_text(x)

def export(model, dummy_input, path):
    print(f"Exporting to {path}")
    torch.onnx.export(
        model,                     # model being run
        dummy_input,               # model input (or a tuple for multiple inputs)
        path,                      # where to save the model (can be a file or file-like object)
        export_params=True,        # store the trained parameter weights inside the model file
        opset_version=16,          # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=['input'],     # the model's input names
        output_names=['output'],   # the model's output names
        dynamic_axes={
            'input': {0: 'batch_size'},   # variable length axes
            'output': {0: 'batch_size'},
        },
    )

def convert(model_name, dashed_name):
    visual_path = f"{output_dir}/clip-{dashed_name}-visual.onnx"
    textual_path = f"{output_dir}/clip-{dashed_name}-textual.onnx"
    visual_exists = os.path.exists(visual_path)
    textual_exists = os.path.exists(textual_path)
    if visual_exists and textual_exists:
        print(f"{visual_path} exists, skipping")
        print(f"{textual_path} exists, skipping")
        return

    print(f"Model: {model_name}")
    print("Loading CLIP")
    model, _ = clip.load(model_name, device=device)
    model = model.to(device=device)

    if not visual_exists:
        # Dummy image batch at the model's native input resolution
        input_res = model.visual.input_resolution
        export(
            model.visual,
            torch.rand(1, 3, input_res, input_res).to(device),
            visual_path,
        )
    else:
        print(f"{visual_path} exists, skipping")

    if not textual_exists:
        # Dummy tokenized prompt for tracing the text encoder
        text_transformer = TextTransformer(model)
        export(
            text_transformer,
            clip.tokenize(["hello onnx"]).to(device),
            textual_path,
        )
    else:
        print(f"{textual_path} exists, skipping")
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
output_dir = "converted"
if __name__ == "__main__":
print(f"Torch device: {device}")
available_models = clip.available_models()
print(f"Available models: {available_models}")
models = [
("RN50", "resnet-50"),
("RN101", "resnet-101"),
("RN50x4", "resnet-50x4"),
("RN50x16", "resnet-50x16"),
("RN50x64", "resnet-50x64"),
("RN50", "resnet-50"),
("RN50", "resnet-50"),
("RN50", "resnet-50"),
("ViT-B/16", "vit-base-patch16"),
("ViT-B/32", "vit-base-patch32"),
("ViT-L/14", "vit-large-patch14"),
("ViT-L/14@336px", "vit-large-patch14-336"),
]
print(f"Converting models: {models}")
for model in models:
convert(*model)
# For converting multiple models at once
# with Pool(1) as p:
# p.starmap(convert, models)
print("done")