Update README.md
Browse files
README.md
CHANGED
@@ -47,6 +47,7 @@ import torch
|
|
47 |
import transformers
|
48 |
from PIL import Image
|
49 |
import sys
|
|
|
50 |
sys.path.append("..")
|
51 |
from otter.modeling_otter import OtterForConditionalGeneration
|
52 |
|
@@ -130,6 +131,7 @@ def get_response(input_data, prompt: str, model=None, image_processor=None, tens
|
|
130 |
return_tensors="pt",
|
131 |
)
|
132 |
|
|
|
133 |
generated_text = model.generate(
|
134 |
vision_x=vision_x.to(model.device, dtype=tensor_dtype),
|
135 |
lang_x=lang_x["input_ids"].to(model.device),
|
@@ -137,6 +139,7 @@ def get_response(input_data, prompt: str, model=None, image_processor=None, tens
|
|
137 |
max_new_tokens=512,
|
138 |
num_beams=3,
|
139 |
no_repeat_ngram_size=3,
|
|
|
140 |
)
|
141 |
parsed_output = (
|
142 |
model.text_tokenizer.decode(generated_text[0])
|
@@ -151,6 +154,7 @@ def get_response(input_data, prompt: str, model=None, image_processor=None, tens
|
|
151 |
)
|
152 |
return parsed_output
|
153 |
|
|
|
154 |
# ------------------- Main Function -------------------
|
155 |
load_bit = "fp16"
|
156 |
if load_bit == "fp16":
|
@@ -184,4 +188,5 @@ while True:
|
|
184 |
|
185 |
if prompts_input.lower() == "quit":
|
186 |
break
|
|
|
187 |
```
|
|
|
47 |
import transformers
|
48 |
from PIL import Image
|
49 |
import sys
|
50 |
+
|
51 |
sys.path.append("..")
|
52 |
from otter.modeling_otter import OtterForConditionalGeneration
|
53 |
|
|
|
131 |
return_tensors="pt",
|
132 |
)
|
133 |
|
134 |
+
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids
|
135 |
generated_text = model.generate(
|
136 |
vision_x=vision_x.to(model.device, dtype=tensor_dtype),
|
137 |
lang_x=lang_x["input_ids"].to(model.device),
|
|
|
139 |
max_new_tokens=512,
|
140 |
num_beams=3,
|
141 |
no_repeat_ngram_size=3,
|
142 |
+
bad_words_ids=bad_words_id,
|
143 |
)
|
144 |
parsed_output = (
|
145 |
model.text_tokenizer.decode(generated_text[0])
|
|
|
154 |
)
|
155 |
return parsed_output
|
156 |
|
157 |
+
|
158 |
# ------------------- Main Function -------------------
|
159 |
load_bit = "fp16"
|
160 |
if load_bit == "fp16":
|
|
|
188 |
|
189 |
if prompts_input.lower() == "quit":
|
190 |
break
|
191 |
+
|
192 |
```
|