R-Kentaren committed
Commit 23e3b7a · verified · 1 Parent(s): 05f35b6

Create smolvlm_inference.py

Files changed (1)
  1. smolvlm_inference.py +23 -0
smolvlm_inference.py ADDED
@@ -0,0 +1,23 @@
+ import torch
+ from transformers import AutoModelForImageTextToText, AutoProcessor
+
+
+ class TransformersModel:
+     def __init__(self, model_id: str, to_device: str = "cuda"):
+         self.model_id = model_id
+         self.processor = AutoProcessor.from_pretrained(model_id)
+         self.processor.image_processor.size = {"longest_edge": 3 * 384}
+         self.model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(to_device)
+
+     def generate(self, messages: list[dict], **kwargs):
+         inputs = self.processor.apply_chat_template(
+             messages,
+             add_generation_prompt=True,
+             tokenize=True,
+             return_dict=True,
+             return_tensors="pt",
+         ).to(self.model.device, dtype=torch.bfloat16)
+         generated_ids = self.model.generate(**inputs, **kwargs)
+         return self.processor.batch_decode(
+             generated_ids[:, len(inputs["input_ids"][0]) :], skip_special_tokens=True
+         )[0]
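
For reference, a minimal usage sketch of the added class. The checkpoint id (HuggingFaceTB/SmolVLM-Instruct), the image URL, and the message layout below are illustrative assumptions rather than part of the commit; any checkpoint loadable via AutoModelForImageTextToText with a chat template that accepts image-by-URL content should work the same way.

# Usage sketch (assumptions): checkpoint id, image URL, and chat-message format
# are illustrative; smolvlm_inference.py is assumed to be importable from the repo root.
from smolvlm_inference import TransformersModel

model = TransformersModel("HuggingFaceTB/SmolVLM-Instruct", to_device="cuda")

# Chat-template style messages: the processor resolves the image from the URL
# when apply_chat_template is called with tokenize=True and return_dict=True.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/cat.jpg"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# kwargs are forwarded to model.generate; only the newly generated tokens are decoded.
print(model.generate(messages, max_new_tokens=128, do_sample=False))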