liltom-eth committed on
Commit
bd82dd7
1 Parent(s): 5a6b6a9

Upload code/inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. code/inference.py +22 -3
code/inference.py CHANGED
@@ -6,9 +6,16 @@ from transformers import AutoTokenizer
6
 
7
  from llava.model import LlavaLlamaForCausalLM
8
  from llava.utils import disable_torch_init
9
- from llava.constants import IMAGE_TOKEN_INDEX
10
  from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
11
 
 
 
 
 
 
 
 
 
12
 
13
  def model_fn(model_dir):
14
  kwargs = {"device_map": "auto"}
@@ -32,11 +39,23 @@ def predict_fn(data, model_and_tokenizer):
32
 
33
  # get prompt & parameters
34
  image_file = data.pop("image", data)
35
- prompt = data.pop("question", data)
36
 
37
  max_new_tokens = data.pop("max_new_tokens", 1024)
38
  temperature = data.pop("temperature", 0.2)
39
- stop_str = data.pop("stop_str", "###")
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  if image_file.startswith("http") or image_file.startswith("https"):
42
  response = requests.get(image_file)
 
6
 
7
  from llava.model import LlavaLlamaForCausalLM
8
  from llava.utils import disable_torch_init
 
9
  from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
10
 
11
+ from llava.conversation import conv_templates, SeparatorStyle
12
+ from llava.constants import (
13
+ IMAGE_TOKEN_INDEX,
14
+ DEFAULT_IMAGE_TOKEN,
15
+ DEFAULT_IM_START_TOKEN,
16
+ DEFAULT_IM_END_TOKEN,
17
+ )
18
+
19
 
20
  def model_fn(model_dir):
21
  kwargs = {"device_map": "auto"}
 
39
 
40
  # get prompt & parameters
41
  image_file = data.pop("image", data)
42
+ raw_prompt = data.pop("question", data)
43
 
44
  max_new_tokens = data.pop("max_new_tokens", 1024)
45
  temperature = data.pop("temperature", 0.2)
46
+ conv_mode = data.pop("conv_mode", "llava_v1")
47
+
48
+ # conv_mode = "llava_v1"
49
+ conv = conv_templates[conv_mode].copy()
50
+ roles = conv.roles
51
+ inp = f"{roles[0]}: {raw_prompt}"
52
+ inp = (
53
+ DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp
54
+ )
55
+ conv.append_message(conv.roles[0], inp)
56
+ conv.append_message(conv.roles[1], None)
57
+ prompt = conv.get_prompt()
58
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
59
 
60
  if image_file.startswith("http") or image_file.startswith("https"):
61
  response = requests.get(image_file)