liltom-eth committed on
Commit
5f909cf
1 Parent(s): 112226e

Upload code/inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. code/inference.py +30 -3
code/inference.py CHANGED
@@ -6,9 +6,16 @@ from transformers import AutoTokenizer
6
 
7
  from llava.model import LlavaLlamaForCausalLM
8
  from llava.utils import disable_torch_init
9
- from llava.constants import IMAGE_TOKEN_INDEX
10
  from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
11
 
 
 
 
 
 
 
 
 
12
 
13
  def model_fn(model_dir):
14
  kwargs = {"device_map": "auto"}
@@ -32,11 +39,31 @@ def predict_fn(data, model_and_tokenizer):
32
 
33
  # get prompt & parameters
34
  image_file = data.pop("image", data)
35
- prompt = data.pop("question", data)
36
 
37
  max_new_tokens = data.pop("max_new_tokens", 1024)
38
  temperature = data.pop("temperature", 0.2)
39
- stop_str = data.pop("stop_str", "###")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  if image_file.startswith("http") or image_file.startswith("https"):
42
  response = requests.get(image_file)
 
6
 
7
  from llava.model import LlavaLlamaForCausalLM
8
  from llava.utils import disable_torch_init
 
9
  from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
10
 
11
+ from llava.conversation import conv_templates, SeparatorStyle
12
+ from llava.constants import (
13
+ IMAGE_TOKEN_INDEX,
14
+ DEFAULT_IMAGE_TOKEN,
15
+ DEFAULT_IM_START_TOKEN,
16
+ DEFAULT_IM_END_TOKEN,
17
+ )
18
+
19
 
20
  def model_fn(model_dir):
21
  kwargs = {"device_map": "auto"}
 
39
 
40
  # get prompt & parameters
41
  image_file = data.pop("image", data)
42
+ raw_prompt = data.pop("question", data)
43
 
44
  max_new_tokens = data.pop("max_new_tokens", 1024)
45
  temperature = data.pop("temperature", 0.2)
46
+ conv_mode = data.pop("conv_mode", "llava_v1")
47
+
48
+ if conv_mode == "raw":
49
+ # use raw_prompt as prompt
50
+ prompt = raw_prompt
51
+ stop_str = "###"
52
+ else:
53
+ conv = conv_templates[conv_mode].copy()
54
+ roles = conv.roles
55
+ inp = f"{roles[0]}: {raw_prompt}"
56
+ inp = (
57
+ DEFAULT_IM_START_TOKEN
58
+ + DEFAULT_IMAGE_TOKEN
59
+ + DEFAULT_IM_END_TOKEN
60
+ + "\n"
61
+ + inp
62
+ )
63
+ conv.append_message(conv.roles[0], inp)
64
+ conv.append_message(conv.roles[1], None)
65
+ prompt = conv.get_prompt()
66
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
67
 
68
  if image_file.startswith("http") or image_file.startswith("https"):
69
  response = requests.get(image_file)