|
''' |
|
@File : utils.py |
|
@Time : 2023/04/05 19:18:00 |
|
@Author : Jiazheng Xu |
|
@Contact : xjz22@mails.tsinghua.edu.cn |
|
* Based on CLIP code base |
|
* https://github.com/openai/CLIP |
|
* Checkpoints of CLIP/BLIP/Aesthetic are from: |
|
* https://github.com/openai/CLIP |
|
* https://github.com/salesforce/BLIP |
|
* https://github.com/christophschuhmann/improved-aesthetic-predictor |
|
''' |
|
|
|
import os |
|
import urllib |
|
from typing import Union, List |
|
import pathlib |
|
|
|
import torch |
|
from tqdm import tqdm |
|
from huggingface_hub import hf_hub_download |
|
|
|
from .ImageReward import ImageReward |
|
from .models.CLIPScore import CLIPScore |
|
from .models.BLIPScore import BLIPScore |
|
from .models.AestheticScore import AestheticScore |
|
|
|
# Registry of released ImageReward checkpoints: public model name -> checkpoint URL.
# `load` looks names up here and hands the URL to `ImageReward_download`.
_MODELS = {"ImageReward-v1.0": "https://huggingface.co/THUDM/ImageReward/blob/main/ImageReward.pt"}
|
|
|
|
|
def available_models() -> List[str]:
    """Return the model names registered in `_MODELS`, in registration order."""
    return [model_name for model_name in _MODELS]
|
|
|
|
|
def ImageReward_download(url: str, root: str):
    """Fetch a file of the THUDM/ImageReward hub repository into `root`.

    Only the basename of `url` is used: it selects the file inside the
    "THUDM/ImageReward" repository; the rest of the URL is ignored.

    Parameters
    ----------
    url : str
        URL whose last path component names the file to fetch.
    root : str
        Target directory; created if it does not exist.

    Returns
    -------
    str
        `root` joined with the file's basename.
    """
    os.makedirs(root, exist_ok=True)
    remote_name = os.path.basename(url)
    hf_hub_download(repo_id="THUDM/ImageReward", filename=remote_name, local_dir=root)
    return os.path.join(root, remote_name)
|
|
|
|
|
def load(name: str = "ImageReward-v1.0",
         device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
         download_root: str = None,
         med_config_path: str = None):
    """Load an ImageReward model.

    Parameters
    ----------
    name : str
        A model name listed by `ImageReward.available_models()`, or the path to a
        model checkpoint containing the state_dict.
    device : Union[str, torch.device]
        The device to put the loaded model on.
    download_root : str
        Directory used to cache downloaded files; by default "~/.cache/ImageReward".
        A leading "~" is expanded to the user's home directory.
    med_config_path : str
        Path to the `med_config.json` handed to the `ImageReward` constructor.
        When omitted, `med_config.json` is looked up in the cache directory and
        downloaded there if it is missing.

    Returns
    -------
    model : torch.nn.Module
        The ImageReward model, switched to eval mode.

    Raises
    ------
    RuntimeError
        If `name` is neither a registered model name nor an existing file.
    """
    # Expand "~" explicitly: pathlib does not do it on its own, and without it the
    # default would create a literal "~" directory under the current working directory.
    cache_root = pathlib.Path(download_root or "~/.cache/ImageReward").expanduser()

    if name in _MODELS:
        checkpoint_url = _MODELS[name]
        # Derive the cached filename from the registry URL so it always matches
        # the file that ImageReward_download stores.
        model_path = cache_root / os.path.basename(checkpoint_url)
        if not model_path.exists():
            model_path = ImageReward_download(checkpoint_url, root=cache_root.as_posix())
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    print('-> load ImageReward model from %s' % model_path)
    # NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location='cpu')

    if med_config_path is None:
        med_config_path = cache_root / 'med_config.json'
        if not med_config_path.exists():
            med_config_path = ImageReward_download(
                "https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
                root=cache_root.as_posix())
        print('-> load ImageReward med_config from %s' % med_config_path)

    model = ImageReward(device=device, med_config=med_config_path).to(device)
    # strict=False: keys that do not match between checkpoint and model are tolerated.
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    return model
|
|
|
|
|
# Registry of baseline scorers: score name -> checkpoint URL fetched by `_download`
# when `load_score` is called with that name.
_SCORES = {
    "CLIP":
        "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "BLIP":
        "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth",
    "Aesthetic":
        "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac%2Blogos%2Bava1-l14-linearMSE.pth",
}
|
|
|
|
|
def available_scores() -> List[str]:
    """Return the score names registered in `_SCORES`, in registration order."""
    return [score_name for score_name in _SCORES]
|
|
|
|
|
def _download(url: str, root: str):
    """Download `url` into `root` unless the file is already there.

    The local filename is the basename of `url`. An existing regular file with
    that name is returned as-is without contacting the network.

    Parameters
    ----------
    url : str
        Location of the file to fetch.
    root : str
        Target directory; created if it does not exist.

    Returns
    -------
    str
        Path of the downloaded (or already present) file.

    Raises
    ------
    RuntimeError
        If the target path exists but is not a regular file.
    """
    # `import urllib` alone does not load the `request` submodule; import it
    # explicitly so this function does not depend on some other module having done so.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        return download_target

    # Stream into a temporary sibling file and move it into place only once the
    # transfer finished, so an interrupted download is never mistaken for a
    # complete cached checkpoint on the next call.
    partial_target = download_target + ".part"
    try:
        with urllib.request.urlopen(url) as source, open(partial_target, "wb") as output:
            # Content-Length is optional; tqdm shows an open-ended bar for total=None.
            content_length = source.info().get("Content-Length")
            if content_length is not None and content_length.strip().isdigit():
                total = int(content_length)
            else:
                total = None

            with tqdm(total=total, ncols=80, unit='iB', unit_scale=True,
                      unit_divisor=1024) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break

                    output.write(buffer)
                    loop.update(len(buffer))
    except BaseException:
        # Best-effort cleanup of the partial file; the original error is re-raised.
        try:
            os.remove(partial_target)
        except OSError:
            pass
        raise

    os.replace(partial_target, download_target)
    return download_target
|
|
|
|
|
def load_score(name: str = "CLIP", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
               download_root: str = None):
    """Load one of the baseline scoring models registered in `_SCORES`.

    Parameters
    ----------
    name : str
        A score name listed by `available_scores()` ("CLIP", "BLIP" or "Aesthetic").

    device : Union[str, torch.device]
        The device to put the loaded model on.

    download_root : str
        Directory for downloaded files; by default the expanded "~/.cache/ImageReward".

    Returns
    -------
    model : torch.nn.Module
        The requested scoring model, switched to eval mode.

    Raises
    ------
    RuntimeError
        If `name` is not a registered score.
    """
    # Reject unknown names before anything is downloaded.
    if name not in _SCORES:
        raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")

    cache_dir = download_root if download_root else os.path.expanduser("~/.cache/ImageReward")
    model_path = _download(_SCORES[name], cache_dir)

    print('load checkpoint from %s' % model_path)
    if name == "BLIP":
        blip_checkpoint = torch.load(model_path, map_location='cpu')
        # BLIP additionally needs med_config.json from the ImageReward hub repository.
        med_config = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
                                          cache_dir)
        model = BLIPScore(med_config=med_config, device=device).to(device)
        # The BLIP weights sit under the 'model' key of the checkpoint.
        model.blip.load_state_dict(blip_checkpoint['model'], strict=False)
    elif name == "CLIP":
        # No state dict is loaded here; CLIPScore only receives the download directory.
        model = CLIPScore(download_root=cache_dir, device=device).to(device)
    elif name == "Aesthetic":
        mlp_weights = torch.load(model_path, map_location='cpu')
        model = AestheticScore(download_root=cache_dir, device=device).to(device)
        # The downloaded checkpoint holds the weights of the `mlp` sub-module.
        model.mlp.load_state_dict(mlp_weights, strict=False)
    else:
        # Guards against a name registered in _SCORES without a loader branch above.
        raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")

    print("checkpoint loaded")
    model.eval()

    return model
|
|