File size: 1,130 Bytes
d902dc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os

import torch
from safetensors.torch import load_file as load_safetensor
from transformers import CLIPTextModelWithProjection, CLIPTokenizer

# Device configuration: chosen once at import time. Use the CUDA device when
# torch reports a GPU is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def load(tokenizer_path="tokenizer", text_encoder_path="text_encoder",
         weights_filename="model.fp16.safetensors"):
    """Load the CLIP text encoder (with projection head) and its tokenizer.

    Args:
        tokenizer_path: Location passed to ``CLIPTokenizer.from_pretrained``.
        text_encoder_path: Directory containing ``config.json`` and the
            safetensors weights file for the text encoder.
        weights_filename: Name of the safetensors file inside
            ``text_encoder_path`` whose weights are loaded into the model.
            Defaults to the fp16 export; pass ``"model.safetensors"`` to use
            the full-precision file instead.

    Returns:
        tuple: ``(clip_model, tokenizer)``, with ``clip_model`` moved to the
        module-level ``device``.

    Raises:
        Whatever ``from_pretrained`` / ``load_file`` raise for missing files,
        and ``RuntimeError`` from ``load_state_dict`` if the weight keys do not
        match the model (strict loading is the PyTorch default).
    """
    # os.path.join keeps absolute paths absolute. The previous f"./{...}"
    # prefix turned "/abs/dir" into ".//abs/dir", a cwd-relative path that
    # disagreed with the path handed to from_pretrained below.
    weights_path = os.path.join(text_encoder_path, weights_filename)
    config_path = os.path.join(text_encoder_path, "config.json")

    # Load tokenizer
    tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    # Half precision only when running on GPU. Derive the choice from the same
    # `device` the model is moved to, so dtype and device always agree.
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    # Build the model from the directory and its config.json.
    clip_model = CLIPTextModelWithProjection.from_pretrained(
        text_encoder_path,
        config=config_path,
        torch_dtype=dtype,
    )

    # Explicitly load the selected safetensors file and overwrite the model's
    # parameters with it, so the chosen file (fp16 by default) is the one in use.
    state_dict = load_safetensor(weights_path)
    clip_model.load_state_dict(state_dict)
    clip_model = clip_model.to(device)

    return clip_model, tokenizer