ReNO / rewards /pickscore.py
fffiloni's picture
Upload 24 files
ca25718 verified
import torch
from transformers import AutoModel
from rewards.base_reward import BaseRewardLoss
class PickScoreLoss(BaseRewardLoss):
"""PickScore reward loss function for optimization."""
def __init__(
self,
weighting: float,
dtype: torch.dtype,
device: torch.device,
cache_dir: str,
tokenizer,
memsave: bool = False,
):
self.tokenizer = tokenizer
self.pickscore_model = AutoModel.from_pretrained(
"yuvalkirstain/PickScore_v1", cache_dir=cache_dir
).eval()
if memsave:
import memsave_torch.nn
self.pickscore_model = memsave_torch.nn.convert_to_memory_saving(
self.pickscore_model
)
self.pickscore_model = self.pickscore_model.to(device, dtype=dtype)
self.freeze_parameters(self.pickscore_model.parameters())
super().__init__("PickScore", weighting)
self.pickscore_model._set_gradient_checkpointing(True)
def get_image_features(self, image) -> torch.Tensor:
reward_img_features = self.pickscore_model.get_image_features(image)
return reward_img_features
def get_text_features(self, prompt: str) -> torch.Tensor:
prompt_token = self.tokenizer(
prompt, return_tensors="pt", padding=True, max_length=77, truncation=True
).to("cuda")
reward_text_features = self.pickscore_model.get_text_features(**prompt_token)
return reward_text_features
def compute_loss(
self, image_features: torch.Tensor, text_features: torch.Tensor
) -> torch.Tensor:
pickscore_loss = (
30
- (
self.pickscore_model.logit_scale.exp()
* (image_features @ text_features.T)
).mean()
)
return pickscore_loss