sreejith8100 committed
Commit 03f7143 · verified · 1 Parent(s): fabcfe8

Upload endpoint_handler.py

Files changed (1): endpoint_handler.py (+88, -0)

endpoint_handler.py ADDED
import base64
import os
from io import BytesIO

import torch
from huggingface_hub import login
from PIL import Image
from transformers import AutoModel, AutoTokenizer


class EndpointHandler:
    def __init__(self, model_dir=None):
        print("[Init] Initializing EndpointHandler...")
        self.load_model()

    def load_model(self):
        hf_token = os.getenv("HF_TOKEN")
        model_path = "openbmb/MiniCPM-V-4"

        if hf_token:
            print("[Auth] Logging into Hugging Face Hub with token...")
            login(token=hf_token)

        # Prefer the pre-downloaded copy baked into the container image;
        # this overrides the Hub repo ID above.
        model_path = "/app/models/minicpmv"
        print(f"[Model Load] Attempting to load model from: {model_path}")
        try:
            self.model = AutoModel.from_pretrained(
                model_path,
                trust_remote_code=True,
                attn_implementation="sdpa",
                torch_dtype=torch.float16,
                device_map="auto",
            ).eval()
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        except Exception as e:
            print(f"[Model Load Failed]: {e}")
            raise  # fail fast: a handler without a model cannot serve requests

    def load_image(self, image_base64):
        try:
            print("[Image Load] Decoding base64 image...")
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
            print("[Image Load] Image successfully decoded and converted to RGB.")
            return image
        except Exception as e:
            print(f"[Image Load Error] {e}")
            raise ValueError(f"Failed to open image from base64 string: {e}")

    def predict(self, request):
        print(f"[Predict] Received request: {request}")

        inputs = request.get("inputs", {})
        image_base64 = inputs.get("image")
        question = inputs.get("question")
        stream = inputs.get("stream", False)

        if not image_base64 or not question:
            print("[Predict Error] Missing 'image' or 'question' in the request.")
            return {"error": "Missing 'image' or 'question' in inputs."}

        try:
            image = self.load_image(image_base64)
            msgs = [{"role": "user", "content": [image, question]}]

            print(f"[Predict] Asking model with question: {question}")
            print("[Predict] Starting chat inference...")

            res = self.model.chat(
                image=None,
                msgs=msgs,
                tokenizer=self.tokenizer,
                sampling=True,
                stream=stream,
            )

            if stream:
                # Wrap the token iterator in a nested generator so the
                # non-streaming branch below can still `return` a plain dict.
                # (Mixing `yield` and `return value` in one function would make
                # every call return a generator, breaking stream=False.)
                def token_stream():
                    for new_text in res:
                        yield {"output": new_text}

                return token_stream()

            generated_text = "".join(res)
            print("[Predict] Inference complete.")
            return {"output": generated_text}

        except Exception as e:
            print(f"[Predict Error] {e}")
            return {"error": str(e)}

    def __call__(self, data):
        print("[__call__] Invoked handler with data.")
        return self.predict(data)
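
For reference, a minimal client-side sketch of how this handler might be called once deployed. The endpoint URL and token below are placeholders (hypothetical), and the payload simply mirrors the `inputs.image` / `inputs.question` / `inputs.stream` keys that `predict` reads above:

# Hypothetical client sketch: base64-encode an image and POST it to the
# deployed endpoint. URL and token are placeholders, not real values.
import base64

import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder

with open("example.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": {
        "image": image_b64,
        "question": "What is in this picture?",
        "stream": False,  # matches the 'stream' flag predict() reads
    }
}

resp = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
)
print(resp.json())  # expected shape: {"output": "..."} or {"error": "..."}

The same payload dict can also be passed directly to `EndpointHandler()({...})` for a local smoke test before deploying.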