| |
| import logging |
| import os |
| from typing import Union, List |
|
|
| import cn_clip.clip as clip |
| import torch |
| from PIL import Image |
| from cn_clip.clip import load_from_name |
|
|
| from config import MODELS_PATH |
|
|
| |
# Configure root logging once at import time; module logger follows the
# standard getLogger(__name__) convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Chinese-CLIP checkpoint name; overridable via the MODEL_NAME_CN env var.
MODEL_NAME_CN = os.environ.get('MODEL_NAME_CN', 'ViT-B-16')


# Prefer GPU when available; all tensors below are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Lazily populated by init_clip_model(); None means "not initialized yet"
# (checked by is_clip_available()).
model = None
preprocess = None
|
|
def init_clip_model():
    """Initialize the global Chinese-CLIP model and preprocessing transform.

    Loads ``MODEL_NAME_CN`` from ``MODELS_PATH`` onto the module-level
    ``device``, switches the model to eval mode, and publishes it through
    the module globals ``model`` / ``preprocess``.

    Returns:
        bool: True on success, False if loading or setup failed.
    """
    global model, preprocess
    try:
        model, preprocess = load_from_name(MODEL_NAME_CN, device=device, download_root=MODELS_PATH)
        model.eval()
        logger.info(f"CLIP model initialized successfully, dimension: {model.visual.output_dim}")
        return True
    except Exception:
        # Reset both globals so is_clip_available() never reports a
        # half-initialized model (e.g. if eval() raised after assignment).
        model, preprocess = None, None
        # logger.exception records the full traceback, not just str(e).
        logger.exception("CLIP model initialization failed")
        return False
|
|
def is_clip_available():
    """Report whether init_clip_model() has populated the module globals."""
    return not (model is None or preprocess is None)
|
|
def encode_image(image_path: str) -> torch.Tensor:
    """Encode an image file into a unit-normalized CLIP embedding.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        torch.Tensor: L2-normalized image features on CPU
        (presumably shape (1, dim) — one row per image).

    Raises:
        RuntimeError: If the CLIP model has not been initialized.
    """
    if not is_clip_available():
        raise RuntimeError("CLIP模型未初始化")

    # Context manager closes the underlying file handle promptly instead
    # of leaking it until garbage collection; convert("RGB") returns a
    # detached in-memory image, so it is safe to use after the `with`.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image_tensor = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(image_tensor)
        # L2-normalize so downstream similarity is a plain dot product.
        features = features / features.norm(p=2, dim=-1, keepdim=True)
    return features.cpu()
|
|
def encode_text(text: Union[str, List[str]]) -> torch.Tensor:
    """Encode one text or a batch of texts into unit-normalized CLIP embeddings.

    A single string is treated as a batch of one; the returned tensor has
    one L2-normalized row per input text and lives on the CPU.

    Raises:
        RuntimeError: If the CLIP model has not been initialized.
    """
    if not is_clip_available():
        raise RuntimeError("CLIP模型未初始化")

    if isinstance(text, str):
        batch = [text]
    else:
        batch = text
    tokens = clip.tokenize(batch).to(device)
    with torch.no_grad():
        embeddings = model.encode_text(tokens)
        embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)
    return embeddings.cpu()
|
|