# File name: model.py
from io import BytesIO

import numpy as np
import requests
import torch
from PIL import Image
from starlette.requests import Request

from ray import serve
from clip_retrieval.load_clip import load_clip, get_tokenizer
# from clip_retrieval.clip_client import ClipClient, Modality
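
# CLIPTransform is a Ray Serve deployment that turns text prompts, image
# URLs, or already-preprocessed image tensors into normalized CLIP
# (ViT-L/14) embeddings over HTTP.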

@serve.deployment(num_replicas=6, ray_actor_options={"num_cpus": 0.2, "num_gpus": 0.1})
class CLIPTransform:
    def __init__(self):
        # Load the CLIP model and its preprocessing pipeline onto the GPU
        # when one is available, otherwise fall back to CPU.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self._clip_model = "ViT-L/14"
        self._clip_model_id = "laion5B-L-14"
        self.model, self.preprocess = load_clip(self._clip_model, use_jit=True, device=self.device)
        self.tokenizer = get_tokenizer(self._clip_model)

        print("Using device:", self.device)

    def text_to_embeddings(self, prompt):
        # Tokenize the prompt, encode it, and L2-normalize the embedding.
        text = self.tokenizer([prompt]).to(self.device)
        with torch.no_grad():
            prompt_embeddings = self.model.encode_text(text)
        prompt_embeddings /= prompt_embeddings.norm(dim=-1, keepdim=True)
        return prompt_embeddings
    
    def image_to_embeddings(self, input_im):
        # Preprocess the raw image array, encode it, and L2-normalize.
        input_im = Image.fromarray(input_im)
        prepro = self.preprocess(input_im).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_embeddings = self.model.encode_image(prepro)
        image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
        return image_embeddings

    def preprocessed_image_to_embeddings(self, prepro):
        # Encode an already-preprocessed image tensor and L2-normalize.
        with torch.no_grad():
            image_embeddings = self.model.encode_image(prepro)
        image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
        return image_embeddings

    async def __call__(self, http_request: Request) -> list:
        """Handle a multipart form request carrying exactly one of:

        - "text": UTF-8 prompt bytes,
        - "image_url": UTF-8 bytes of an image URL to download, or
        - "preprocessed_image": raw tensor bytes plus "shape" and "dtype" fields.

        Returns the normalized CLIP embedding as a nested list.
        """
        form_data = await http_request.form()

        embeddings = None
        if "text" in form_data:
            prompt = (await form_data["text"].read()).decode()
            embeddings = self.text_to_embeddings(prompt)
        elif "image_url" in form_data:
            image_url = (await form_data["image_url"].read()).decode()
            # download image from url
            import requests
            from io import BytesIO
            image_bytes = requests.get(image_url).content
            input_image = Image.open(BytesIO(image_bytes))
            input_image = input_image.convert('RGB')
            input_image = np.array(input_image)
            embeddings = self.image_to_embeddings(input_image)
        elif "preprocessed_image" in form_data:
            tensor_bytes = await form_data["preprocessed_image"].read()
            shape_bytes = await form_data["shape"].read()
            dtype_bytes = await form_data["dtype"].read()

            # Convert bytes back to original form
            dtype_mapping = {
                "torch.float32": torch.float32,
                "torch.float64": torch.float64,
                "torch.float16": torch.float16,
                "torch.uint8": torch.uint8,
                "torch.int8": torch.int8,
                "torch.int16": torch.int16,
                "torch.int32": torch.int32,
                "torch.int64": torch.int64,
                torch.float32: np.float32,
                torch.float64: np.float64,
                torch.float16: np.float16,
                torch.uint8: np.uint8,
                torch.int8: np.int8,
                torch.int16: np.int16,
                torch.int32: np.int32,
                torch.int64: np.int64,
                # add more if needed
            }
            dtype_str = dtype_bytes.decode()
            dtype_torch = dtype_mapping[dtype_str]
            dtype_numpy = dtype_mapping[dtype_torch]
            # shape = np.frombuffer(shape_bytes, dtype=np.int64)
            # TODO: fix shape so it is passed nicely
            shape = tuple([1, 3, 224, 224])

            tensor_numpy = np.frombuffer(tensor_bytes, dtype=dtype_numpy).reshape(shape)
            tensor = torch.from_numpy(tensor_numpy)
            prepro = tensor.to(self.device)
            embeddings = self.preprocessed_image_to_emdeddings(prepro)
        else:
            raise ValueError("Invalid request: expected a 'text', 'image_url', or 'preprocessed_image' form field")
        return embeddings.cpu().numpy().tolist()


deployment_graph = CLIPTransform.bind()
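
# Minimal client sketches for exercising the deployment. These are
# illustrative assumptions, not part of the original file: they presume the
# graph has been started at Ray Serve's default HTTP address, e.g. with
# `serve run model:deployment_graph`, and is served at the root route.
if __name__ == "__main__":
    # The deployment reads each form field as an uploaded file, so values
    # are sent as file-style multipart fields.
    response = requests.post(
        "http://127.0.0.1:8000/",
        files={"text": ("text", b"a photo of a dog")},
    )
    # ViT-L/14 yields a 768-dimensional normalized embedding, returned as a
    # nested list.
    print(len(response.json()[0]))

    # Sending an already-preprocessed tensor: the dtype travels as its string
    # name (e.g. "torch.float32"); note the server currently ignores the
    # "shape" field and assumes (1, 3, 224, 224).
    tensor = torch.zeros(1, 3, 224, 224, dtype=torch.float32)
    response = requests.post(
        "http://127.0.0.1:8000/",
        files={
            "preprocessed_image": ("tensor", tensor.numpy().tobytes()),
            "shape": ("shape", np.array(tensor.shape, dtype=np.int64).tobytes()),
            "dtype": ("dtype", str(tensor.dtype).encode()),
        },
    )
    print(len(response.json()[0]))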