# HPS reward loss module (scrape artifacts from the hosting page removed).
import huggingface_hub
import torch
from hpsv2.src.open_clip import create_model, get_tokenizer
from rewards.base_reward import BaseRewardLoss
class HPSLoss(BaseRewardLoss):
    """HPS (Human Preference Score) v2.1 reward loss for optimization.

    Wraps the ViT-H-14 open_clip model with the HPSv2.1 fine-tuned weights
    and exposes image/text feature extraction plus a loss of the form
    ``1 - similarity`` for the first (image, text) pair.
    """

    def __init__(
        self,
        weighting: float,
        dtype: torch.dtype,
        device: torch.device,
        cache_dir: str,
        memsave: bool = False,
    ):
        """Download/load the HPSv2.1 checkpoint and prepare the frozen model.

        Args:
            weighting: Relative weight of this loss (forwarded to
                ``BaseRewardLoss``).
            dtype: Parameter dtype the model is cast to.
            device: Device the model and tokenized prompts live on.
            cache_dir: Cache directory for the model and checkpoint downloads.
            memsave: If True, convert the model with ``memsave_torch`` to
                reduce activation memory.
        """
        # Remember device/dtype so methods don't have to hard-code them
        # (fixes the previous hard-coded "cuda" in get_text_features).
        self.device = device
        self.dtype = dtype
        self.hps_model = create_model(
            "ViT-H-14",
            "laion2B-s32B-b79K",
            precision=dtype,
            device=device,
            cache_dir=cache_dir,
        )
        checkpoint_path = huggingface_hub.hf_hub_download(
            "xswu/HPSv2", "HPS_v2.1_compressed.pt", cache_dir=cache_dir
        )
        self.hps_model.load_state_dict(
            torch.load(checkpoint_path, map_location=device)["state_dict"]
        )
        self.hps_tokenizer = get_tokenizer("ViT-H-14")
        if memsave:
            # Optional dependency; imported lazily so it is only required
            # when memory saving is requested.
            import memsave_torch.nn

            self.hps_model = memsave_torch.nn.convert_to_memory_saving(
                self.hps_model
            )
        self.hps_model = self.hps_model.to(device, dtype=dtype)
        self.hps_model.eval()
        # The reward model is frozen; gradients flow only w.r.t. inputs.
        self.freeze_parameters(self.hps_model.parameters())
        super().__init__("HPS", weighting)
        # Trade compute for memory when backpropagating through the model.
        self.hps_model.set_grad_checkpointing(True)

    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images with the HPS CLIP image tower."""
        return self.hps_model.encode_image(image)

    def get_text_features(self, prompt: str) -> torch.Tensor:
        """Tokenize a prompt and encode it with the HPS CLIP text tower."""
        # Bug fix: previously hard-coded .to("cuda"); use the configured
        # device so the loss also works on CPU / non-default CUDA devices.
        hps_text = self.hps_tokenizer(prompt).to(self.device)
        return self.hps_model.encode_text(hps_text)

    def compute_loss(
        self, image_features: torch.Tensor, text_features: torch.Tensor
    ) -> torch.Tensor:
        """Return ``1 - similarity`` for the first (image, text) pair.

        NOTE(review): only element [0] of the diagonal is used, so batched
        features beyond the first pair do not contribute — presumably the
        caller optimizes a single image/prompt pair; confirm against callers.
        Features are presumably L2-normalized CLIP embeddings — TODO confirm.
        """
        logits_per_image = image_features @ text_features.T
        hps_loss = 1 - torch.diagonal(logits_per_image)[0]
        return hps_loss