merve (HF staff) committed
Commit 0029ec4
1 Parent(s): 8e9be15

Update app.py

Files changed (1)
  1. app.py +29 -13
app.py CHANGED
@@ -33,29 +33,45 @@ def model_inference(
     if isinstance(images, Image.Image):
         images = [images]
 
-    if isinstance(text, str):
-        text = "<image>" + text
-        text = [text]
-
-    inputs = processor(text=text, images=images, padding=True, return_tensors="pt").to("cuda")
+    resulting_messages = [
+        {
+            "role": "user",
+            "content": [{"type": "image"}] + [
+                {"type": "text", "text": text}
+            ]
+        }
+    ]
+
+    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
+    inputs = processor(text=prompt, images=[images], return_tensors="pt")
+    inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+    generation_args = {
+        "max_new_tokens": max_new_tokens,
+        "repetition_penalty": repetition_penalty,
+    }
 
     assert decoding_strategy in [
         "Greedy",
         "Top P Sampling",
     ]
     if decoding_strategy == "Greedy":
-        do_sample = False
+        generation_args["do_sample"] = False
     elif decoding_strategy == "Top P Sampling":
-        do_sample = True
+        generation_args["temperature"] = temperature
+        generation_args["do_sample"] = True
+        generation_args["top_p"] = top_p
+
+    generation_args.update(inputs)
 
     # Generate
-    generated_ids = model.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_new_tokens=max_new_tokens,
-                                   temperature=temperature, do_sample=do_sample, repetition_penalty=repetition_penalty,
-                                   top_p=top_p),
-    #generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
-    generated_texts = processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True)
-    print("INPUT:", text, "|OUTPUT:", generated_texts)
+    generated_ids = model.generate(**generation_args)
+
+    generated_texts = processor.batch_decode(generated_ids[:, generation_args["input_ids"].size(1):], skip_special_tokens=True)
 
     return generated_texts[0]
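For reference, below is a self-contained sketch of model_inference as it reads after this commit. The commit only touches the function body, so everything around it here is an assumption: the checkpoint ID (HuggingFaceM4/idefics2-8b), the dtype, and the parameter defaults are illustrative, not taken from the repository.

# Self-contained sketch of the post-commit inference path.
# Assumptions: checkpoint ID, dtype, and parameter defaults are illustrative.
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")  # assumed checkpoint
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceM4/idefics2-8b", torch_dtype=torch.float16
).to("cuda")

def model_inference(images, text, decoding_strategy="Greedy",
                    temperature=0.4, max_new_tokens=512,
                    repetition_penalty=1.2, top_p=0.8):
    if isinstance(images, Image.Image):
        images = [images]

    # Chat-style message; the processor's chat template inserts the image
    # tokens, replacing the old manual "<image>" + text concatenation.
    resulting_messages = [{
        "role": "user",
        "content": [{"type": "image"}] + [{"type": "text", "text": text}],
    }]
    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[images], return_tensors="pt")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # Sampling knobs are added only when sampling is enabled, so greedy
    # decoding never passes unused temperature/top_p values to generate().
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
    assert decoding_strategy in ["Greedy", "Top P Sampling"]
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p

    generation_args.update(inputs)
    generated_ids = model.generate(**generation_args)
    # Decode only the newly generated tokens, skipping the prompt.
    generated_texts = processor.batch_decode(
        generated_ids[:, generation_args["input_ids"].size(1):],
        skip_special_tokens=True,
    )
    return generated_texts[0]

One design choice worth noting: collecting the options in a plain generation_args dict and then calling generation_args.update(inputs) lets a single model.generate(**generation_args) call receive both the input tensors and the decoding settings, and keeps sampling-only parameters out of greedy runs.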