clip-variants / convert.py
mlunar's picture
Initial code import
6a3ad5b
raw
history blame
No virus
3.27 kB
from cgitb import text
import os
import clip
import torch.onnx
import torch
from torch import nn
from multiprocessing import Pool
class TextTransformer(nn.Module):
    """Wrap a CLIP model so that calling the module runs only its text encoder.

    ``torch.onnx.export`` traces a module's ``forward``; routing ``forward``
    to ``encode_text`` lets the text tower be exported as a standalone graph.
    """

    def __init__(self, clip_model):
        super().__init__()
        # The full CLIP model is kept; only its encode_text method is used.
        self.clip_model = clip_model

    def forward(self, x: torch.Tensor):
        """Return ``self.clip_model.encode_text(x)``.

        ``x`` is presumably a batch of token ids as produced by
        ``clip.tokenize`` (that is how ``convert`` calls it) — TODO confirm
        for any other caller.
        """
        encode = self.clip_model.encode_text
        return encode(x)
def export(model, input, path):
    """Export ``model`` to ONNX at ``path``, tracing it with ``input``.

    The resulting graph has one input named ``input`` and one output named
    ``output``. Dimension 0 of both is marked dynamic (``batch_size``), so the
    exported model is not tied to the example's batch size. Trained weights
    are embedded in the file, constant folding is enabled and ONNX opset 16
    is targeted.

    Args:
        model: the module to export.
        input: example model input (or a tuple for multiple inputs) used for
            tracing.
        path: destination file path or file-like object.
    """
    print(f"Exporting to {path}")
    io_names = ("input", "output")
    # One fresh {0: "batch_size"} mapping per graph tensor: dim 0 is variable.
    dynamic = {name: {0: "batch_size"} for name in io_names}
    torch.onnx.export(
        model,
        input,
        path,
        export_params=True,        # embed the trained weights in the file
        opset_version=16,
        do_constant_folding=True,  # pre-compute constant subgraphs
        input_names=[io_names[0]],
        output_names=[io_names[1]],
        dynamic_axes=dynamic,
    )
def convert(model_name, dashed_name):
    """Export the visual and textual towers of one CLIP model to ONNX.

    Writes ``{output_dir}/clip-{dashed_name}-visual.onnx`` and
    ``{output_dir}/clip-{dashed_name}-textual.onnx``, using the module-level
    ``output_dir`` and ``device`` settings. A file that already exists is
    skipped. When both exist, the CLIP weights are not loaded at all.

    Args:
        model_name: name passed to ``clip.load`` (e.g. ``"RN50"``,
            ``"ViT-B/32"``).
        dashed_name: filesystem-friendly name used in the output file names.
    """
    visual_path = f"{output_dir}/clip-{dashed_name}-visual.onnx"
    textual_path = f"{output_dir}/clip-{dashed_name}-textual.onnx"
    visual_exists = os.path.exists(visual_path)
    textual_exists = os.path.exists(textual_path)
    if visual_exists and textual_exists:
        print(f"{visual_path} exists, skipping")
        print(f"{textual_path} exists, skipping")
        return

    # torch.onnx.export does not create parent directories, so a fresh
    # checkout without the output directory would fail on the first write.
    os.makedirs(output_dir, exist_ok=True)

    print(f"Model: {model_name}")
    print("Loading CLIP")
    model, _ = clip.load(model_name, device=device)
    model = model.to(device=device)

    if visual_exists:
        print(f"{visual_path} exists, skipping")
    else:
        input_res = model.visual.input_resolution
        # model.visual is exported directly, bypassing CLIP's own cast of the
        # image to model.dtype, so the dummy image must already match the
        # weights' device and dtype. (clip.load presumably keeps fp16 weights
        # on CUDA and casts to fp32 only on CPU; confirm for your clip version.)
        dummy_image = torch.rand(
            1, 3, input_res, input_res, device=device, dtype=model.dtype
        )
        export(model.visual, dummy_image, visual_path)

    if textual_exists:
        print(f"{textual_path} exists, skipping")
    else:
        text_transformer = TextTransformer(model)
        export(
            text_transformer,
            clip.tokenize(["hello onnx"]).to(device),
            textual_path,
        )
# Device used to load CLIP and trace the exports. It is pinned to the CPU:
# a `"cuda" if torch.cuda.is_available() else "cpu"` auto-selection is
# deliberately not used. NOTE(review): presumably this is to get fp32 ONNX
# graphs, since clip.load appears to cast weights to float32 only on CPU.
# Confirm before switching to CUDA.
device = "cpu"
# Directory the exported .onnx files are written to (relative to the CWD).
output_dir = "models"
if __name__ == "__main__":
    print(f"Torch device: {device}")
    available_models = clip.available_models()
    print(f"Available models: {available_models}")
    # (name passed to clip.load, dashed name used in the output file names).
    # Each model is listed once: convert() skips files that already exist, so
    # a repeated entry would only reprint "exists, skipping".
    models = [
        ("RN50", "resnet-50"),
        ("RN101", "resnet-101"),
        ("RN50x4", "resnet-50x4"),
        ("RN50x16", "resnet-50x16"),
        ("RN50x64", "resnet-50x64"),
        ("ViT-B/16", "vit-base-patch16"),
        ("ViT-B/32", "vit-base-patch32"),
        ("ViT-L/14", "vit-large-patch14"),
        ("ViT-L/14@336px", "vit-large-patch14-336"),
    ]
    print(f"Converting models: {models}")
    for model_name, dashed_name in models:
        convert(model_name, dashed_name)
    # Alternative: run conversions through a process pool instead of the loop
    # above (raise the pool size to convert several models at once):
    #     with Pool(1) as p:
    #         p.starmap(convert, models)
    print("done")