# Adapted from Cheng An Hsieh et al.: https://github.com/RewardMultiverse/reward-multiverse
from PIL import Image
import io
import numpy as np
import torch
import torch.nn as nn
import torchvision
from transformers import CLIPModel

def jpeg_compressibility(device):
    def _fn(images):
        """
        Args:
            images: batch of images, NCHW torch.Tensor in [0, 1]
                    (or an NHWC uint8 array-like).
        Returns:
            sizes: np.ndarray of JPEG-encoded sizes in kilobytes (the reward).
            transform_images_tensor: the same images as an NCHW tensor in [0, 1] on `device`.
        """
        org_type = images.dtype
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        transform_images_tensor = torch.Tensor(np.array(images)).to(device, dtype=org_type)
        transform_images_tensor = (transform_images_tensor.permute(0, 3, 1, 2) / 255).clamp(0, 1)  # NHWC -> NCHW
        transform_images_pil = [Image.fromarray(image) for image in images]
        buffers = [io.BytesIO() for _ in transform_images_pil]

        # Encode each image to an in-memory JPEG; the encoded size is the compressibility signal.
        for image, buffer in zip(transform_images_pil, buffers):
            image.save(buffer, format="JPEG", quality=95)

        # buffer.tell() is the number of bytes written; report kilobytes.
        sizes = [buffer.tell() / 1000 for buffer in buffers]

        return np.array(sizes), transform_images_tensor

    return _fn


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, 1),
        )

    def forward(self, embed):
        return self.layers(embed)
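
# Hedged shape note (not part of the original file): the head consumes 768-dim
# CLIP ViT-L/14 image embeddings and emits one scalar per image, e.g.
#
#   head = MLP()
#   head(torch.randn(4, 768)).shape  # -> torch.Size([4, 1]); batch size 4 is arbitrary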

def jpegcompression_loss_fn(target=None,
                            grad_scale=0,
                            device=None,
                            accelerator=None,
                            torch_dtype=None,
                            reward_model_resume_from=None):
    scorer = JpegCompressionScorer(dtype=torch_dtype,
                                   model_path=reward_model_resume_from).to(device, dtype=torch_dtype)
    scorer.requires_grad_(False)
    scorer.eval()

    def loss_fn(im_pix_un):
        # Score the images, using autocast when the accelerator runs in fp16.
        if accelerator.mixed_precision == "fp16":
            with accelerator.autocast():
                rewards = scorer(im_pix_un)
        else:
            rewards = scorer(im_pix_un)

        if target is None:
            # No target: the scaled predicted score itself is minimized.
            loss = rewards
        else:
            # With a target: penalize the distance between prediction and target.
            loss = torch.abs(rewards - target)
        return loss * grad_scale, rewards

    return loss_fn
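
# Hedged usage sketch (assumes an accelerate.Accelerator and a trained MLP
# checkpoint path; the names below are illustrative, not provided by this file):
#
#   from accelerate import Accelerator
#   accelerator = Accelerator()
#   loss_fn = jpegcompression_loss_fn(target=50.0,
#                                     grad_scale=1.0,
#                                     device=accelerator.device,
#                                     accelerator=accelerator,
#                                     torch_dtype=torch.float32,
#                                     reward_model_resume_from="jpeg_scorer.pt")
#   # im_pix: NCHW float image tensor in [0, 1]
#   loss, rewards = loss_fn(im_pix)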

class JpegCompressionScorer(nn.Module):
    """Scores images for JPEG compressibility: a frozen CLIP ViT-L/14 image
    encoder produces 768-dim embeddings, and the MLP head maps each embedding
    to a scalar score."""

    def __init__(self, dtype=None, model_path=None):
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.clip.requires_grad_(False)
        self.score_generator = MLP()

        if model_path:
            state_dict = torch.load(model_path)
            self.score_generator.load_state_dict(state_dict)
        if dtype:
            self.dtype = dtype
        self.target_size = (224, 224)
        self.normalize = torchvision.transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711])

    def set_device(self, device, inference_type):
        # self.clip.to(device, dtype=inference_type)
        self.score_generator.to(device)  # , dtype=inference_type

    def forward(self, images):
        # Resize and normalize to CLIP's expected 224x224 input, embed, then score.
        im_pix = torchvision.transforms.Resize(self.target_size)(images)
        im_pix = self.normalize(im_pix).to(images.dtype)
        embed = self.clip.get_image_features(pixel_values=im_pix)
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
        return self.score_generator(embed).squeeze(1)
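
# Minimal self-test sketch (illustrative only; batch size, resolution, and
# device below are assumptions, not values from the original training setup).
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dummy_images = torch.rand(2, 3, 256, 256)  # NCHW, values in [0, 1]
    reward_fn = jpeg_compressibility(device)
    sizes_kb, images_tensor = reward_fn(dummy_images)
    print("JPEG sizes (kB):", sizes_kb)
    print("returned tensor shape:", tuple(images_tensor.shape))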