luodian committed
Commit f77a3be
1 Parent(s): 2079d37

Update README.md

Files changed (1)
  1. README.md +36 -51
README.md CHANGED
@@ -33,7 +33,9 @@ license: mit
 ![](https://black.readthedocs.io/en/stable/_static/license.svg)
 ![](https://img.shields.io/badge/code%20style-black-000000.svg)
 
-An example of using this model to run on your video. Please first clone [Otter](https://github.com/Luodian/Otter) to your local disk. Place following script inside the `Otter` folder to make sure it has the access to `otter/modeling_otter.py`.
+An example of using this model to run on your video.
+Please first clone [Otter](https://github.com/Luodian/Otter) to your local disk.
+Place the following script inside the `Otter` folder so that it has access to `otter/modeling_otter.py`.
 
 ```python
 import mimetypes
@@ -44,7 +46,6 @@ import requests
 import torch
 import transformers
 from PIL import Image
-
 from otter.modeling_otter import OtterForConditionalGeneration
 
 # Disable warnings
@@ -61,7 +62,7 @@ def get_content_type(file_path):
 # ------------------- Image and Video Handling Functions -------------------
 
 
-def extract_frames(video_path, num_frames=128):
+def extract_frames(video_path, num_frames=16):
     video = cv2.VideoCapture(video_path)
     total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
     frame_step = total_frames // num_frames
@@ -83,9 +84,7 @@ def get_image(url: str) -> Union[Image.Image, list]:
     if "://" not in url:  # Local file
         content_type = get_content_type(url)
     else:  # Remote URL
-        content_type = requests.head(url, stream=True, verify=False).headers.get(
-            "Content-Type"
-        )
+        content_type = requests.head(url, stream=True, verify=False).headers.get("Content-Type")
 
     if "image" in content_type:
         if "://" not in url:  # Local file
@@ -114,25 +113,13 @@ def get_formatted_prompt(prompt: str) -> str:
     return f"<image>User: {prompt} GPT:<answer>"
 
 
-def get_response(input_data, prompt: str, model=None, image_processor=None) -> str:
+def get_response(input_data, prompt: str, model=None, image_processor=None, tensor_dtype=None) -> str:
     if isinstance(input_data, Image.Image):
-        vision_x = (
-            image_processor.preprocess([input_data], return_tensors="pt")[
-                "pixel_values"
-            ]
-            .unsqueeze(1)
-            .unsqueeze(0)
-        )
+        vision_x = image_processor.preprocess([input_data], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
     elif isinstance(input_data, list):  # list of video frames
-        vision_x = (
-            image_processor.preprocess(input_data, return_tensors="pt")["pixel_values"]
-            .unsqueeze(1)
-            .unsqueeze(0)
-        )
+        vision_x = image_processor.preprocess(input_data, return_tensors="pt")["pixel_values"].unsqueeze(0).unsqueeze(0)
     else:
-        raise ValueError(
-            "Invalid input data. Expected PIL Image or list of video frames."
-        )
+        raise ValueError("Invalid input data. Expected PIL Image or list of video frames.")
 
     lang_x = model.text_tokenizer(
         [
@@ -142,7 +129,7 @@ def get_response(input_data, prompt: str, model=None, image_processor=None) -> str:
     )
 
     generated_text = model.generate(
-        vision_x=vision_x.to(model.device),
+        vision_x=vision_x.to(model.device, dtype=tensor_dtype),
         lang_x=lang_x["input_ids"].to(model.device),
         attention_mask=lang_x["attention_mask"].to(model.device),
         max_new_tokens=512,
@@ -162,39 +149,37 @@ def get_response(input_data, prompt: str, model=None, image_processor=None) -> str:
     )
     return parsed_output
 
 
-if __name__ == "__main__":
-    # ------------------- Main Function -------------------
-    load_bit = "fp16"
-    if load_bit == "fp16":
-        precision = {"torch_dtype": torch.float16}
-    elif load_bit == "bf16":
-        precision = {"torch_dtype": torch.bfloat16}
-    elif load_bit == "fp32":
-        precision = {"torch_dtype": torch.float32}
+# ------------------- Main Function -------------------
+load_bit = "fp16"
+if load_bit == "fp16":
+    precision = {"torch_dtype": torch.float16}
+elif load_bit == "bf16":
+    precision = {"torch_dtype": torch.bfloat16}
+elif load_bit == "fp32":
+    precision = {"torch_dtype": torch.float32}
 
-    # This model version is trained on MIMIC-IT DC dataset.
-    model = OtterForConditionalGeneration.from_pretrained(
-        "luodian/otter-9b-dc-hf", device_map="auto", **precision
-    )
-    model.text_tokenizer.padding_side = "left"
-    tokenizer = model.text_tokenizer
-    image_processor = transformers.CLIPImageProcessor()
-    model.eval()
+# This model version is trained on MIMIC-IT DC dataset.
+model = OtterForConditionalGeneration.from_pretrained("luodian/OTTER-9B-DenseCaption", device_map="auto", **precision)
+tensor_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}[load_bit]
+
+model.text_tokenizer.padding_side = "left"
+tokenizer = model.text_tokenizer
+image_processor = transformers.CLIPImageProcessor()
+model.eval()
 
-    while True:
-        video_url = "demo.mp4"  # Replace with the path to your video file
+while True:
+    video_url = "/path/to/your_video.mp4"  # Replace with the path to your video file, could be any common format.
 
-        frames_list = get_image(video_url)
+    frames_list = get_image(video_url)
 
-        prompts_input = input("Enter prompts (comma-separated): ")
-        prompts = [prompt.strip() for prompt in prompts_input.split(",")]
+    prompts_input = input("Enter prompts (comma-separated): ")
+    prompts = [prompt.strip() for prompt in prompts_input.split(",")]
 
-        for prompt in prompts:
-            print(f"\nPrompt: {prompt}")
-            response = get_response(frames_list, prompt, model, image_processor)
-            print(f"Response: {response}")
+    for prompt in prompts:
+        print(f"\nPrompt: {prompt}")
+        response = get_response(frames_list, prompt, model, image_processor, tensor_dtype)
+        print(f"Response: {response}")
 
-        if prompts_input.lower() == "quit":
-            break
+    if prompts_input.lower() == "quit":
+        break
  ```
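
For reference, the updated helpers also cover a single image: `get_image` is annotated to return a `PIL.Image` for image content types, and `get_response` wraps a lone image the same way it wraps a list of video frames. A minimal sketch, assuming the set-up portion of the script above (everything before the `while True:` loop) has already been executed; the image path and the prompt are placeholders:

```python
# Hypothetical single-image usage -- not part of the committed README.
# Assumes model, image_processor, and tensor_dtype are already defined by the
# set-up section of the script above; the image path below is a placeholder.
image = get_image("/path/to/your_image.jpg")  # returns a PIL Image for image inputs
answer = get_response(image, "What is shown in this image?", model, image_processor, tensor_dtype)
print(answer)
```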