luodian committed on
Commit
1d79cb5
1 Parent(s): 9cbec06

Update README.md

Files changed (1)
  1. README.md +164 -45
README.md CHANGED
@@ -30,55 +30,174 @@ license: other
  Here is an example of multi-modal ICL (in-context learning) with 🦦 Otter. We provide two demo images with corresponding instructions and answers, then we ask the model to generate an answer given our instruction. You may change your instruction and see how the model responds.

  ``` python
+ import mimetypes
+ import os
+ from io import BytesIO
+ from typing import Union
+ import cv2
  import requests
  import torch
  import transformers
  from PIL import Image
+ from torchvision.transforms import Compose, Resize, ToTensor
+ from tqdm import tqdm
+ import sys
+
  from otter.modeling_otter import OtterForConditionalGeneration

- model = OtterForConditionalGeneration.from_pretrained(
-     "luodian/otter-9b-hf", device_map="auto"
- )
-
- tokenizer = model.text_tokenizer
- image_processor = transformers.CLIPImageProcessor()
- demo_image_one = Image.open(
-     requests.get(
-         "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
-     ).raw
- )
- demo_image_two = Image.open(
-     requests.get(
-         "http://images.cocodataset.org/test-stuff2017/000000028137.jpg", stream=True
-     ).raw
- )
- query_image = Image.open(
-     requests.get(
-         "http://images.cocodataset.org/test-stuff2017/000000028352.jpg", stream=True
-     ).raw
- )
- vision_x = (
-     image_processor.preprocess(
-         [demo_image_one, demo_image_two, query_image], return_tensors="pt"
-     )["pixel_values"]
-     .unsqueeze(1)
-     .unsqueeze(0)
- )
- model.text_tokenizer.padding_side = "left"
- lang_x = model.text_tokenizer(
-     [
-         "<image> User: what does the image describe? GPT: <answer> two cats sleeping. <|endofchunk|> <image> User: what does the image describe? GPT: <answer> a bathroom sink. <|endofchunk|> <image> User: what does the image describe? GPT: <answer>"
-     ],
-     return_tensors="pt",
- )
- generated_text = model.generate(
-     vision_x=vision_x.to(model.device),
-     lang_x=lang_x["input_ids"].to(model.device),
-     attention_mask=lang_x["attention_mask"].to(model.device),
-     max_new_tokens=256,
-     num_beams=1,
-     no_repeat_ngram_size=3,
- )
-
- print("Generated text: ", model.text_tokenizer.decode(generated_text[0]))
+
+ # Disable warnings
+ requests.packages.urllib3.disable_warnings()
+
+ # ------------------- Utility Functions -------------------
+
+
+ def get_content_type(file_path):
+     content_type, _ = mimetypes.guess_type(file_path)
+     return content_type
+
+
+ # ------------------- Image and Video Handling Functions -------------------
+
+
+ def extract_frames(video_path, num_frames=16):
+     video = cv2.VideoCapture(video_path)
+     total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+     frame_step = total_frames // num_frames
+     frames = []
+
+     for i in range(num_frames):
+         video.set(cv2.CAP_PROP_POS_FRAMES, i * frame_step)
+         ret, frame = video.read()
+         if ret:
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             frame = Image.fromarray(frame).convert("RGB")
+             frames.append(frame)
+
+     video.release()
+     return frames
+
+
+ def get_image(url: str) -> Union[Image.Image, list]:
+     if "://" not in url:  # Local file
+         content_type = get_content_type(url)
+     else:  # Remote URL
+         content_type = requests.head(url, stream=True, verify=False).headers.get("Content-Type")
+
+     if "image" in content_type:
+         if "://" not in url:  # Local file
+             return Image.open(url)
+         else:  # Remote URL
+             return Image.open(requests.get(url, stream=True, verify=False).raw)
+     elif "video" in content_type:
+         video_path = "temp_video.mp4"
+         if "://" not in url:  # Local file
+             video_path = url
+         else:  # Remote URL
+             with open(video_path, "wb") as f:
+                 f.write(requests.get(url, stream=True, verify=False).content)
+         frames = extract_frames(video_path)
+         if "://" in url:  # Only remove the temporary video file if it was downloaded
+             os.remove(video_path)
+         return frames
+     else:
+         raise ValueError("Invalid content type. Expected image or video.")
+
+
+ # ------------------- OTTER Prompt and Response Functions -------------------
+
+
+ def get_formatted_prompt(prompt: str, in_context_prompts: list = []) -> str:
+     in_context_string = ""
+     for in_context_prompt, in_context_answer in in_context_prompts:
+         in_context_string += f"<image>User: {in_context_prompt} GPT:<answer> {in_context_answer}<|endofchunk|>"
+     return f"{in_context_string}<image>User: {prompt} GPT:<answer>"
+
+
+ def get_response(image_list, prompt: str, model=None, image_processor=None, in_context_prompts: list = []) -> str:
+     input_data = image_list
+
+     if isinstance(input_data, Image.Image):
+         vision_x = (
+             image_processor.preprocess([input_data], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
+         )
+     elif isinstance(input_data, list):  # list of video frames
+         vision_x = (
+             image_processor.preprocess(input_data, return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
+         )
+     else:
+         raise ValueError("Invalid input data. Expected PIL Image or list of video frames.")
+
+     lang_x = model.text_tokenizer(
+         [
+             get_formatted_prompt(prompt, in_context_prompts),
+         ],
+         return_tensors="pt",
+     )
+
+     generated_text = model.generate(
+         vision_x=vision_x.to(model.device),
+         lang_x=lang_x["input_ids"].to(model.device),
+         attention_mask=lang_x["attention_mask"].to(model.device),
+         max_new_tokens=512,
+         # num_beams=3,
+         # no_repeat_ngram_size=3,
+     )
+     parsed_output = (
+         model.text_tokenizer.decode(generated_text[0])
+         .split("<answer>")[-1]
+         .lstrip()
+         .rstrip()
+         .split("<|endofchunk|>")[0]
+         .lstrip()
+         .rstrip()
+         .lstrip('"')
+         .rstrip('"')
+     )
+     return parsed_output
+
+
+ # ------------------- Main Function -------------------
+
+ if __name__ == "__main__":
+     model = OtterForConditionalGeneration.from_pretrained(
+         "luodian/otter-9b-hf", device_map="auto"
+     )
+     model.text_tokenizer.padding_side = "left"
+     tokenizer = model.text_tokenizer
+     image_processor = transformers.CLIPImageProcessor()
+     model.eval()
+
+     while True:
+         urls = [
+             "https://images.cocodataset.org/train2017/000000339543.jpg",
+             "https://images.cocodataset.org/train2017/000000140285.jpg",
+         ]
+
+         encoded_frames_list = []
+         for url in urls:
+             frames = get_image(url)
+             encoded_frames_list.append(frames)
+
+         in_context_prompts = []
+         in_context_examples = [
+             "What does the image describe?::A family is taking picture in front of a snow mountain.",
+         ]
+         for in_context_input in in_context_examples:
+             in_context_prompt, in_context_answer = in_context_input.split("::")
+             in_context_prompts.append((in_context_prompt.strip(), in_context_answer.strip()))
+
+         # prompts_input = input("Enter the prompts separated by commas (or type 'quit' to exit): ")
+         prompts_input = "What does the image describe?"
+
+         prompts = [prompt.strip() for prompt in prompts_input.split(",")]
+
+         for prompt in prompts:
+             print(f"\nPrompt: {prompt}")
+             response = get_response(encoded_frames_list, prompt, model, image_processor, in_context_prompts)
+             print(f"Response: {response}")
+
+         if prompts_input.lower() == "quit":
+             break
  ```
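
For reference, the snippet below is a minimal standalone sketch of the in-context prompt string that `get_formatted_prompt` in the updated example assembles. It only reproduces the string formatting shown in the diff (no model, tokenizer, or image loading); the variable names `query`, `icl_prompt`, and `icl_answer` are illustrative and not part of the README code.

``` python
# Minimal sketch: rebuild the prompt layout produced by get_formatted_prompt()
# for the in-context example used in the updated README snippet.
in_context_prompts = [
    ("What does the image describe?", "A family is taking picture in front of a snow mountain."),
]
query = "What does the image describe?"

prompt = ""
for icl_prompt, icl_answer in in_context_prompts:
    # Each demonstration is one (image, instruction, answer) triple,
    # terminated by the <|endofchunk|> token.
    prompt += f"<image>User: {icl_prompt} GPT:<answer> {icl_answer}<|endofchunk|>"
# The final turn is left open after <answer> for the model to complete.
prompt += f"<image>User: {query} GPT:<answer>"
print(prompt)
```

Printing `prompt` yields the same string that `get_response` passes to `model.text_tokenizer` before calling `model.generate`.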