liltom-eth committed
Commit f00035f
1 Parent(s): f255825

Upload inference.py with huggingface_hub
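A commit with this message is typically produced by a single call to huggingface_hub's upload_file API. A minimal sketch of such a call follows; the repo_id is a placeholder and is not taken from this page.

from huggingface_hub import upload_file

# Hypothetical reproduction of this commit: push the local handler script
# into the model repo. repo_id below is a placeholder, not from this commit.
upload_file(
    path_or_fileobj="inference.py",
    path_in_repo="inference.py",
    repo_id="<namespace>/<model-repo>",
    repo_type="model",
    commit_message="Upload inference.py with huggingface_hub",
)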

Files changed (1)
  1. inference.py +93 -0
inference.py ADDED
@@ -0,0 +1,93 @@
+import requests
+from PIL import Image
+from io import BytesIO
+import torch
+from transformers import AutoTokenizer
+
+from llava.model import LlavaLlamaForCausalLM
+from llava.utils import disable_torch_init
+from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
+
+from llava.conversation import conv_templates, SeparatorStyle
+from llava.constants import (
+    IMAGE_TOKEN_INDEX,
+    DEFAULT_IMAGE_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IM_END_TOKEN,
+)
+
+
+def model_fn(model_dir):
+    kwargs = {"device_map": "auto"}
+    kwargs["torch_dtype"] = torch.float16
+    model = LlavaLlamaForCausalLM.from_pretrained(
+        model_dir, low_cpu_mem_usage=True, **kwargs
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
+
+    vision_tower = model.get_vision_tower()
+    if not vision_tower.is_loaded:
+        vision_tower.load_model()
+    vision_tower.to(device="cuda", dtype=torch.float16)
+    image_processor = vision_tower.image_processor
+    return model, tokenizer, image_processor
+
+
+def predict_fn(data, model_and_tokenizer):
+    # unpack model and tokenizer
+    model, tokenizer, image_processor = model_and_tokenizer
+
+    # get prompt & parameters
+    image_file = data.pop("image", data)
+    raw_prompt = data.pop("question", data)
+
+    max_new_tokens = data.pop("max_new_tokens", 1024)
+    temperature = data.pop("temperature", 0.2)
+    conv_mode = data.pop("conv_mode", "llava_v1")
+
+    # conv_mode = "llava_v1"
+    conv = conv_templates[conv_mode].copy()
+    roles = conv.roles
+    inp = f"{roles[0]}: {raw_prompt}"
+    inp = (
+        DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp
+    )
+    conv.append_message(conv.roles[0], inp)
+    conv.append_message(conv.roles[1], None)
+    prompt = conv.get_prompt()
+    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+
+    if image_file.startswith("http") or image_file.startswith("https"):
+        response = requests.get(image_file)
+        image = Image.open(BytesIO(response.content)).convert("RGB")
+    else:
+        image = Image.open(image_file).convert("RGB")
+
+    disable_torch_init()
+    image_tensor = (
+        image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
+        .half()
+        .cuda()
+    )
+
+    keywords = [stop_str]
+    input_ids = (
+        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+        .unsqueeze(0)
+        .cuda()
+    )
+    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+    with torch.inference_mode():
+        output_ids = model.generate(
+            input_ids,
+            images=image_tensor,
+            do_sample=True,
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            use_cache=True,
+            stopping_criteria=[stopping_criteria],
+        )
+    outputs = tokenizer.decode(
+        output_ids[0, input_ids.shape[1] :], skip_special_tokens=True
+    ).strip()
+    return outputs
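
The file follows the handler convention used by SageMaker-style inference containers: model_fn loads the model, tokenizer, and vision-tower image processor once per worker, and predict_fn runs one generation per request. A minimal local smoke test could look like the sketch below; the model directory and image URL are placeholders, not part of the committed file.

# Hypothetical local smoke test for the handlers above (path and URL are placeholders).
model_and_tokenizer = model_fn("/opt/ml/model")  # directory holding the LLaVA weights

payload = {
    "image": "https://example.com/sample.jpg",   # local path or http(s) URL of an image
    "question": "What is shown in this image?",
    "max_new_tokens": 256,
    "temperature": 0.2,
}
print(predict_fn(payload, model_and_tokenizer))

Note that predict_fn pops "image" and "question" from the request dict and falls back to the dict itself when a key is missing, so both keys should always be supplied as strings.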