Update README.md
README.md
CHANGED
@@ -33,7 +33,9 @@ license: mit
 ![](https://black.readthedocs.io/en/stable/_static/license.svg)
 ![](https://img.shields.io/badge/code%20style-black-000000.svg)
 
-An example of using this model to run on your video.
+An example of using this model to run on your video.
+Please first clone [Otter](https://github.com/Luodian/Otter) to your local disk.
+Place the following script inside the `Otter` folder so that it has access to `otter/modeling_otter.py`.
 
 ```python
 import mimetypes
@@ -44,7 +46,6 @@ import requests
 import torch
 import transformers
 from PIL import Image
-
 from otter.modeling_otter import OtterForConditionalGeneration
 
 # Disable warnings
@@ -61,7 +62,7 @@ def get_content_type(file_path):
 # ------------------- Image and Video Handling Functions -------------------
 
 
-def extract_frames(video_path, num_frames=
+def extract_frames(video_path, num_frames=16):
     video = cv2.VideoCapture(video_path)
     total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
     frame_step = total_frames // num_frames
@@ -83,9 +84,7 @@ def get_image(url: str) -> Union[Image.Image, list]:
     if "://" not in url:  # Local file
         content_type = get_content_type(url)
     else:  # Remote URL
-        content_type = requests.head(url, stream=True, verify=False).headers.get(
-            "Content-Type"
-        )
+        content_type = requests.head(url, stream=True, verify=False).headers.get("Content-Type")
 
     if "image" in content_type:
         if "://" not in url:  # Local file
@@ -114,25 +113,13 @@ def get_formatted_prompt(prompt: str) -> str:
     return f"<image>User: {prompt} GPT:<answer>"
 
 
-def get_response(input_data, prompt: str, model=None, image_processor=None) -> str:
+def get_response(input_data, prompt: str, model=None, image_processor=None, tensor_dtype=None) -> str:
     if isinstance(input_data, Image.Image):
-        vision_x = (
-            image_processor.preprocess([input_data], return_tensors="pt")[
-                "pixel_values"
-            ]
-            .unsqueeze(1)
-            .unsqueeze(0)
-        )
+        vision_x = image_processor.preprocess([input_data], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
     elif isinstance(input_data, list):  # list of video frames
-        vision_x = (
-            image_processor.preprocess(input_data, return_tensors="pt")["pixel_values"]
-            .unsqueeze(1)
-            .unsqueeze(0)
-        )
+        vision_x = image_processor.preprocess(input_data, return_tensors="pt")["pixel_values"].unsqueeze(0).unsqueeze(0)
     else:
-        raise ValueError(
-            "Invalid input data. Expected PIL Image or list of video frames."
-        )
+        raise ValueError("Invalid input data. Expected PIL Image or list of video frames.")
 
     lang_x = model.text_tokenizer(
         [
@@ -142,7 +129,7 @@ def get_response(input_data, prompt: str, model=None, image_processor=None) -> str:
     )
 
     generated_text = model.generate(
-        vision_x=vision_x.to(model.device),
+        vision_x=vision_x.to(model.device, dtype=tensor_dtype),
         lang_x=lang_x["input_ids"].to(model.device),
         attention_mask=lang_x["attention_mask"].to(model.device),
         max_new_tokens=512,
@@ -162,39 +149,37 @@ def get_response(input_data, prompt: str, model=None, image_processor=None) -> str:
     )
     return parsed_output
 
 
-if load_bit == "fp16":
-    precision = {"torch_dtype": torch.float16}
-elif load_bit == "bf16":
-    precision = {"torch_dtype": torch.bfloat16}
-elif load_bit == "fp32":
-    precision = {"torch_dtype": torch.float32}
-model.text_tokenizer.padding_side = "left"
-tokenizer = model.text_tokenizer
-image_processor = transformers.CLIPImageProcessor()
-model.eval()
+# ------------------- Main Function -------------------
+load_bit = "fp16"
+if load_bit == "fp16":
+    precision = {"torch_dtype": torch.float16}
+elif load_bit == "bf16":
+    precision = {"torch_dtype": torch.bfloat16}
+elif load_bit == "fp32":
+    precision = {"torch_dtype": torch.float32}
+
+# This model version is trained on MIMIC-IT DC dataset.
+model = OtterForConditionalGeneration.from_pretrained("luodian/OTTER-9B-DenseCaption", device_map="auto", **precision)
+tensor_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}[load_bit]
+
+model.text_tokenizer.padding_side = "left"
+tokenizer = model.text_tokenizer
+image_processor = transformers.CLIPImageProcessor()
+model.eval()
+
+while True:
+    video_url = "/path/to/your_video.mp4"  # Replace with the path to your video file; it can be any common format.
+
+    frames_list = get_image(video_url)
+
+    prompts_input = input("Enter prompts (comma-separated): ")
+    prompts = [prompt.strip() for prompt in prompts_input.split(",")]
+
+    for prompt in prompts:
+        print(f"\nPrompt: {prompt}")
+        response = get_response(frames_list, prompt, model, image_processor, tensor_dtype)
+        print(f"Response: {response}")
+
+    if prompts_input.lower() == "quit":
+        break
 ```
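For a one-off query instead of the interactive loop, the same helpers added in this diff can be called directly. A minimal sketch, assuming `model`, `image_processor`, and `tensor_dtype` have been set up as in the script above; the video path is a placeholder:

```python
# Minimal sketch of a single query reusing the helpers defined in the script above.
# Assumes model, image_processor, and tensor_dtype were created as shown in the diff;
# the path below is a placeholder for your own file.
frames_list = get_image("/path/to/your_video.mp4")  # a PIL Image for images, a list of frames for videos
response = get_response(frames_list, "What is happening in this video?", model, image_processor, tensor_dtype)
print(response)
```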