marutitecblic commited on
Commit
0fe8b25
·
verified ·
1 Parent(s): c419b31

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +3 -126
handler.py CHANGED
@@ -1,127 +1,4 @@
1
- import torch
2
- from PIL import Image
3
- from transformers import AutoModelForCausalLM, AutoProcessor
4
- from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
5
- from transformers.image_transforms import resize, to_channel_dimension_format
6
- import os
7
- from typing import Dict, List, Any
8
 
9
- # Constants
10
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
-
12
- # HF_TASK = os.getenv('HF_TASK')
13
-
14
- # API_TOKEN = os.getenv('API_TOKEN') # Ensure you replace this with your actual API token
15
-
16
- # # Load processor and model
17
- # PROCESSOR = AutoProcessor.from_pretrained(
18
- # "marutitecblic/HtmlTocode",
19
- # trust_remote_code=True,
20
- # # token=API_TOKEN,
21
- # )
22
- # MODEL = AutoModelForCausalLM.from_pretrained(
23
- # "marutitecblic/HtmlTocode",
24
- # # token=API_TOKEN,
25
- # trust_remote_code=True,
26
- # torch_dtype=torch.bfloat16,
27
- # ).to(DEVICE)
28
-
29
- # image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
30
- # BOS_TOKEN = PROCESSOR.tokenizer.bos_token
31
- # BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
32
-
33
-
34
-
35
- # def preprocess(event):
36
- # image = Image.open(event["file"]).convert("RGB")
37
- # inputs = PROCESSOR.tokenizer(
38
- # f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
39
- # return_tensors="pt",
40
- # add_special_tokens=False,
41
- # )
42
- # inputs["pixel_values"] = PROCESSOR.image_processor([image], transform=custom_transform)
43
- # inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
44
- # return inputs
45
-
46
- # def inference(model_inputs):
47
- # inputs = preprocess(model_inputs)
48
- # generated_ids = MODEL.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_length=4096)
49
- # generated_text = PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]
50
- # return {"generated_text": generated_text}
51
-
52
- # def postprocess(model_outputs):
53
- # return model_outputs
54
-
55
- # def handle(event, context):
56
- # model_inputs = event
57
- # model_outputs = inference(model_inputs)
58
- # response = postprocess(model_outputs)
59
- # return response
60
-
61
- class EndpointHandler:
62
- def __init__(self,model_path:str):
63
- # Load processor and model
64
- self.PROCESSOR = AutoProcessor.from_pretrained(
65
- model_path,
66
- trust_remote_code=True,
67
- # token=API_TOKEN,
68
- )
69
- self.MODEL = AutoModelForCausalLM.from_pretrained(
70
- model_path,
71
- # token=API_TOKEN,
72
- trust_remote_code=True,
73
- torch_dtype=torch.bfloat16,
74
- ).to(DEVICE)
75
- self.image_seq_len = self.MODEL.config.perceiver_config.resampler_n_latents
76
- self.BOS_TOKEN = self.PROCESSOR.tokenizer.bos_token
77
- self.BAD_WORDS_IDS = self.PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
78
-
79
-
80
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
81
- # image = data.pop("inputs", data)
82
-
83
- # # process image
84
- # pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
85
-
86
- # # run prediction
87
- # generated_ids = self.model.generate(pixel_values)
88
-
89
- # # decode output
90
- # prediction = generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
91
- image = Image.open(data["file"]).convert("RGB")
92
- inputs = self.PROCESSOR.tokenizer(
93
- f"{self.BOS_TOKEN}<fake_token_around_image>{'<image>' * self.image_seq_len}<fake_token_around_image>",
94
- return_tensors="pt",
95
- add_special_tokens=False,
96
- )
97
- inputs["pixel_values"] = self.PROCESSOR.image_processor([image], transform=self.custom_transform)
98
- inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
99
- # inputs = preprocess(model_inputs)
100
- generated_ids = self.MODEL.generate(**inputs, bad_words_ids=self.BAD_WORDS_IDS, max_length=4096)
101
- generated_text = self.PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]
102
- return {"text": generated_text}
103
- # return {"text":prediction[0]}
104
-
105
- # @classmethod
106
- def convert_to_rgb(self, image):
107
- if image.mode == "RGB":
108
- return image
109
- image_rgba = image.convert("RGBA")
110
- background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
111
- alpha_composite = Image.alpha_composite(background, image_rgba)
112
- alpha_composite = alpha_composite.convert("RGB")
113
- return alpha_composite
114
- # @classmethod
115
- def custom_transform(self, x):
116
- x = self.convert_to_rgb(x)
117
- x = to_numpy_array(x)
118
- x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
119
- x = self.PROCESSOR.image_processor.rescale(x, scale=1 / 255)
120
- x = self.PROCESSOR.image_processor.normalize(
121
- x,
122
- mean=self.PROCESSOR.image_processor.image_mean,
123
- std=self.PROCESSOR.image_processor.image_std
124
- )
125
- x = to_channel_dimension_format(x, ChannelDimension.FIRST)
126
- x = torch.tensor(x)
127
- return x
 
1
+ from custom_image_to_text_pipeline import ImageToTextPipeline
 
 
 
 
 
 
2
 
3
+ def get_inference_handler(model_dir):
4
+ return ImageToTextPipeline(model_dir)