leiwx52 committed on
Commit
5a444be
1 Parent(s): 57debeb

VLog hf gradio demo

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. app.py +159 -0
  2. examples/C8lMW0MODFs.log +3 -0
  3. examples/C8lMW0MODFs.mp4 +3 -0
  4. examples/XZVHmRvfDHM.log +3 -0
  5. examples/XZVHmRvfDHM.mp4 +3 -0
  6. examples/basketball_vlog.log +3 -0
  7. examples/basketball_vlog.mp4 +3 -0
  8. examples/buy_watermelon.log +3 -0
  9. examples/buy_watermelon.mp4 +3 -0
  10. examples/covid.log +3 -0
  11. examples/covid.mp4 +3 -0
  12. examples/huaqiang.log +3 -0
  13. examples/huaqiang.mp4 +3 -0
  14. examples/news.log +3 -0
  15. examples/news.mp4 +3 -0
  16. examples/outcGtbnMuQ.log +3 -0
  17. examples/outcGtbnMuQ.mp4 +3 -0
  18. examples/travel_in_roman.log +3 -0
  19. examples/travel_in_roman.mp4 +3 -0
  20. examples/travel_in_roman_full.log +3 -0
  21. examples/travel_in_roman_full.mp4 +3 -0
  22. examples/vlog.jpg +0 -0
  23. models/__init__.py +3 -0
  24. models/__pycache__/__init__.cpython-38.pyc +0 -0
  25. models/__pycache__/blip2_model.cpython-38.pyc +0 -0
  26. models/__pycache__/clip_model.cpython-38.pyc +0 -0
  27. models/__pycache__/gpt_model.cpython-38.pyc +0 -0
  28. models/__pycache__/grit_model.cpython-38.pyc +0 -0
  29. models/__pycache__/kts_model.cpython-38.pyc +0 -0
  30. models/__pycache__/vlog.cpython-38.pyc +0 -0
  31. models/__pycache__/whisper_model.cpython-38.pyc +0 -0
  32. models/blip2_model.py +47 -0
  33. models/clip_model.py +54 -0
  34. models/gpt_model.py +102 -0
  35. models/grit_model.py +21 -0
  36. models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc +0 -0
  37. models/grit_src/configs/Base.yaml +77 -0
  38. models/grit_src/configs/GRiT_B_DenseCap.yaml +20 -0
  39. models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml +23 -0
  40. models/grit_src/configs/GRiT_B_ObjectDet.yaml +20 -0
  41. models/grit_src/configs/GRiT_H_ObjectDet.yaml +21 -0
  42. models/grit_src/configs/GRiT_L_ObjectDet.yaml +20 -0
  43. models/grit_src/grit/__init__.py +7 -0
  44. models/grit_src/grit/__pycache__/__init__.cpython-38.pyc +0 -0
  45. models/grit_src/grit/__pycache__/config.cpython-38.pyc +0 -0
  46. models/grit_src/grit/__pycache__/predictor.cpython-38.pyc +0 -0
  47. models/grit_src/grit/config.py +50 -0
  48. models/grit_src/grit/custom_solver.py +88 -0
  49. models/grit_src/grit/data/__pycache__/custom_build_augmentation.cpython-38.pyc +0 -0
  50. models/grit_src/grit/data/__pycache__/custom_dataset_mapper.cpython-38.pyc +0 -0
app.py ADDED
@@ -0,0 +1,159 @@
+ import os
+ import gradio as gr
+ import openai
+ import requests
+ import csv
+ import argparse
+ from models.vlog import Vlogger
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--video_path', default='examples/huaqiang.mp4')
+ parser.add_argument('--alpha', default=10, type=int, help='Determines the maximum segment number for the KTS algorithm; the larger the value, the fewer segments.')
+ parser.add_argument('--beta', default=1, type=int, help='The smallest time gap between successive clips, in seconds.')
+ parser.add_argument('--data_dir', default='./examples', type=str, help='Directory for saving videos and logs.')
+ parser.add_argument('--tmp_dir', default='./tmp', type=str, help='Directory for saving intermediate files.')
+
+ # * Model settings *
+ parser.add_argument('--openai_api_key', default='xxx', type=str, help='OpenAI API key')
+ parser.add_argument('--image_caption', action='store_true', dest='image_caption', default=True, help='Set this flag to True if you want to use BLIP image captioning')
+ parser.add_argument('--dense_caption', action='store_true', dest='dense_caption', default=True, help='Set this flag to True if you want to use dense captioning')
+ parser.add_argument('--feature_extractor', default='openai/clip-vit-base-patch32', help='Select the feature extractor model for video segmentation')
+ parser.add_argument('--feature_extractor_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu')
+ parser.add_argument('--image_captioner', choices=['blip', 'blip2'], dest='captioner_base_model', default='blip2', help='blip2 requires about 15 GB of GPU memory, blip about 6 GB')
+ parser.add_argument('--image_captioner_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu; more than 14 GB of GPU memory is recommended')
+ parser.add_argument('--dense_captioner_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu; GPUs with less than 6 GB of memory are not recommended')
+ parser.add_argument('--audio_translator', default='large')
+ parser.add_argument('--audio_translator_device', choices=['cuda', 'cpu'], default='cuda')
+ parser.add_argument('--gpt_version', choices=['gpt-3.5-turbo'], default='gpt-3.5-turbo')
+
+ args = parser.parse_args()
+
+
+ def get_empty_state():
+     return {"total_tokens": 0, "messages": []}
+
+
+ def submit_api_key_fn(api_key, vlogger):
+     try:
+         vlogger.init_llm_with_api_key(api_key)
+         return gr.update(value="OpenAI key submitted successfully 🎉"), True, vlogger
+     except Exception as e:
+         return gr.update(value=f"Error: {e}"), False, vlogger
+
+
+ def submit_message(prompt, state, vlogger, api_key_submitted, vlog_loaded):
+     if not api_key_submitted:
+         return gr.update(value=''), [("👀", "Please enter your OpenAI API key 😊")], state, vlogger
+
+     if not vlog_loaded:
+         return gr.update(value=''), [("👀", "Please follow the instructions to select a video and generate the document for chatting 😊")], state, vlogger
+
+     history = state['messages']
+
+     if not prompt:
+         return gr.update(value=''), [(history[i]['content'], history[i + 1]['content']) for i in range(0, len(history) - 1, 2)], state, vlogger
+
+     prompt_msg = {"role": "user", "content": prompt}
+
+     try:
+         history.append(prompt_msg)
+         answer = vlogger.chat2video(prompt)
+         history.append({"role": "system", "content": answer})
+     except Exception as e:
+         history.append(prompt_msg)
+         history.append({
+             "role": "system",
+             "content": f"Error: {e}"
+         })
+
+     chat_messages = [(history[i]['content'], history[i + 1]['content']) for i in range(0, len(history) - 1, 2)]
+     return '', chat_messages, state, vlogger
+
+
+ def clear_conversation(vlogger):
+     vlogger.clean_history()
+     # return input_message, video_inp, chatbot, vlog_outp, state, vlogger, vlog_loaded
+     return gr.update(value=None, visible=True), gr.update(value=None, interactive=False), None, gr.update(value=None, visible=True), get_empty_state(), vlogger, False
+
+
+ def vlog_fn(vid_path, vlogger, api_key_submitted):
+     if not api_key_submitted:
+         log_text = "====== Please enter your OpenAI API key first 😊 ======"
+         return gr.update(value=log_text, visible=True), False, vlogger
+
+     print(vid_path)
+     if vid_path is None:
+         log_text = "====== Please select a video from the examples first 🤔 ======"
+         vloaded_flag = False
+     else:
+         log_list = vlogger.video2log(vid_path)
+         log_text = "\n".join(log_list)
+         vloaded_flag = True
+     return gr.update(value=log_text, visible=True), vloaded_flag, vlogger
+
+
+ css = """
+ #col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
+ #video_inp {min-height: 300px}
+ #chatbox {min-height: 100px;}
+ #header {text-align: center;}
+ #hint {font-size: 0.9em; padding: 0.5em; margin: 0;}
+ .message {font-size: 1.2em;}
+ """
+
+ with gr.Blocks(css=css) as demo:
+
+     state = gr.State(get_empty_state())
+     vlogger = gr.State(Vlogger(args))
+     vlog_loaded = gr.State(False)
+     api_key_submitted = gr.State(False)
+
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown("""## 🎞️ VLog Demo
+ Powered by BLIP2, GRIT, Whisper, ChatGPT and LangChain
+ Github: [https://github.com/showlab/VLog](https://github.com/showlab/VLog)""",
+                     elem_id="header")
+         gr.Markdown("*Instructions*: For the current demo, please enter your OpenAI API key, select an example video, click the button to generate a document, and try chatting over the video 😊", elem_id="hint")
+         with gr.Row():
+             with gr.Column(scale=6):
+                 video_inp = gr.Video(label="video_input", interactive=False)
+                 chatbot = gr.Chatbot(elem_id="chatbox")
+                 input_message = gr.Textbox(show_label=False, placeholder="Enter text and press enter", visible=True).style(container=False)
+                 btn_submit = gr.Button("Submit")
+                 btn_clear_conversation = gr.Button("🔃 Start New Conversation")
+
+             with gr.Column(scale=6):
+                 vlog_btn = gr.Button("Generate Video Document")
+                 vlog_outp = gr.Textbox(label="Document output", lines=30)
+
+             with gr.Column(scale=1):
+                 openai_api_key = gr.Textbox(
+                     placeholder="Input OpenAI API key and press Enter",
+                     show_label=False,
+                     label="OpenAI API Key",
+                     lines=1,
+                     type="password"
+                 )
+                 examples = gr.Examples(
+                     examples=[
+                         ["examples/basketball_vlog.mp4"],
+                         ["examples/travel_in_roman.mp4"],
+                         ["examples/C8lMW0MODFs.mp4"],
+                         ["examples/outcGtbnMuQ.mp4"],
+                         ["examples/huaqiang.mp4"],
+                     ],
+                     inputs=[video_inp],
+                 )
+
+         gr.HTML('''<br><br><br><center>You can duplicate this Space to skip the queue: <a href="https://huggingface.co/spaces/TencentARC/VLog?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a><br></center>''')
+
+     btn_submit.click(submit_message, [input_message, state, vlogger, api_key_submitted, vlog_loaded], [input_message, chatbot, state, vlogger])
+     input_message.submit(submit_message, [input_message, state, vlogger, api_key_submitted, vlog_loaded], [input_message, chatbot, state, vlogger])
+     btn_clear_conversation.click(clear_conversation, [vlogger], [input_message, video_inp, chatbot, vlog_outp, state, vlogger, vlog_loaded])
+     vlog_btn.click(vlog_fn, [video_inp, vlogger, api_key_submitted], [vlog_outp, vlog_loaded, vlogger])
+     openai_api_key.submit(submit_api_key_fn, [openai_api_key, vlogger], [vlog_outp, api_key_submitted, vlogger])
+     demo.load(queue=False)
+
+ demo.queue(concurrency_count=10)
+ demo.launch(height='800px', server_port=8749, debug=True, share=True)
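The event wiring above maps each UI action to one `Vlogger` call. As a hedged illustration (not part of this commit), the sketch below drives the same calls headless: the `argparse.Namespace` simply mirrors the parser defaults above, the API key is a placeholder, and the question string is hypothetical.

```python
# Minimal sketch: run the VLog pipeline without the Gradio UI, reusing app.py's defaults.
import argparse
from models.vlog import Vlogger

args = argparse.Namespace(
    video_path='examples/huaqiang.mp4', alpha=10, beta=1,
    data_dir='./examples', tmp_dir='./tmp', openai_api_key='sk-...',  # placeholder key
    image_caption=True, dense_caption=True,
    feature_extractor='openai/clip-vit-base-patch32', feature_extractor_device='cuda',
    captioner_base_model='blip2', image_captioner_device='cuda',
    dense_captioner_device='cuda', audio_translator='large',
    audio_translator_device='cuda', gpt_version='gpt-3.5-turbo',
)

vlogger = Vlogger(args)
vlogger.init_llm_with_api_key(args.openai_api_key)   # same call as submit_api_key_fn
log_lines = vlogger.video2log(args.video_path)        # same call as vlog_fn
print("\n".join(log_lines))
print(vlogger.chat2video("What happens in this video?"))  # same call as submit_message
```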
examples/C8lMW0MODFs.log ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b044e554f8dc7a790b02aa1ebc391165b84d93cce9579fa6b2fe0418cd4d1122
+ size 9075
examples/C8lMW0MODFs.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d094489e459ae952880f4cbd8fdbcc790df1a69ccf9fb4f6c5fca998b6871133
+ size 10537029
examples/XZVHmRvfDHM.log ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e37046ae268e20f3d44df7410954c3cf5ffd73116e6f5e3f9ef73a690f001d51
+ size 7262
examples/XZVHmRvfDHM.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c2da0eae7e0b18c04ad4f2b8124a09fbbde407eeedb0a532dbf40701c8c744b5
+ size 1961212
examples/basketball_vlog.log ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8b2d18c6d7d7c5061ae41b9cd2b8cc0828d2aee2b02b40b4286fdd26905b0ac0
+ size 23527
examples/basketball_vlog.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a5d6034c324f3e9de35278783ed68a85081ef74a252c9394e273b339f7d1b6c3
+ size 32376805
examples/buy_watermelon.log ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0cd0e7bfca9fba4b71428235d41b446083ffe8d7496ef43249f7438017def067
+ size 3922
examples/buy_watermelon.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:926ee7ec1ca4d3e0674a647bf84887bdf077961c3972148ae23fb569c22e0e4e
+ size 6209789
examples/covid.log ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:50281df7c21815c662e2f03e461b02dd5b2f8253a3f92bcd1dfca4229d89e3ce
+ size 9782
examples/covid.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:53c35480ff6ac15f2f8747aa9ba9dc36086d5f4e342ac79eac5e43e5bd248817
+ size 16090827
examples/huaqiang.log ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0cd0e7bfca9fba4b71428235d41b446083ffe8d7496ef43249f7438017def067
+ size 3922
examples/huaqiang.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:926ee7ec1ca4d3e0674a647bf84887bdf077961c3972148ae23fb569c22e0e4e
+ size 6209789
examples/news.log ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:42068b9573daee32bf33d5aa4049f937bfa2cb3c6472d40c42332f8d2173a929
+ size 8968
examples/news.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:905e453db16213c962d01371357877b8a168da50508676b81cf474d431d3d2ca
+ size 23599849
examples/outcGtbnMuQ.log ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:45a3911acfe78745ed9cfc9502deebef1ab6912dc89566735fcbdf7acda00b44
+ size 63033
examples/outcGtbnMuQ.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:47f4ddd4debd3c5955cb7c0a1f5e2ffa9c0d6a171931898ee085c5eab521f33d
+ size 98609326
examples/travel_in_roman.log ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f90d4b4c46322b6f15984b64aaedf35232b2fb21ddac518f1c5784fe25944e3c
+ size 9166
examples/travel_in_roman.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b02522bb72215bcb1a69657af9d08cad0141e1b3e30553024609cb0927471e04
+ size 34442658
examples/travel_in_roman_full.log ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9a163943c7676168b51adf08e5305dd78fba7c54e6cd00330c06541eb23d0d23
+ size 45295
examples/travel_in_roman_full.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c1cb54427e21a1ccbba23bfe5314e4ae2d45658d0b4b654f815abf1861c1ca3c
+ size 92642344
examples/vlog.jpg ADDED
models/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .kts_src import *
+ from .clip_model import *
+ from .grit_model import *
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (215 Bytes).
models/__pycache__/blip2_model.cpython-38.pyc ADDED
Binary file (2.02 kB).
models/__pycache__/clip_model.cpython-38.pyc ADDED
Binary file (1.91 kB).
models/__pycache__/gpt_model.cpython-38.pyc ADDED
Binary file (3.43 kB).
models/__pycache__/grit_model.cpython-38.pyc ADDED
Binary file (1.21 kB).
models/__pycache__/kts_model.cpython-38.pyc ADDED
Binary file (1.34 kB).
models/__pycache__/vlog.cpython-38.pyc ADDED
Binary file (4.34 kB).
models/__pycache__/whisper_model.cpython-38.pyc ADDED
Binary file (1.24 kB).
models/blip2_model.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ from PIL import Image
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration, BlipProcessor, BlipForConditionalGeneration
+
+ class ImageCaptioner:
+     def __init__(self, model_name="blip2-opt", device="cpu"):
+         self.model_name = model_name
+         self.device = device
+         self.processor, self.model = self.initialize_model()
+
+     def initialize_model(self):
+         if self.device == 'cpu':
+             self.data_type = torch.float32
+         else:
+             self.data_type = torch.float16
+         processor, model = None, None
+         if self.model_name == "blip2-opt":
+             processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b-coco")
+             model = Blip2ForConditionalGeneration.from_pretrained(
+                 "Salesforce/blip2-opt-2.7b-coco", torch_dtype=self.data_type, low_cpu_mem_usage=True)
+
+         elif self.model_name == "blip2-flan-t5":
+             processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
+             model = Blip2ForConditionalGeneration.from_pretrained(
+                 "Salesforce/blip2-flan-t5-xl", torch_dtype=self.data_type, low_cpu_mem_usage=True)
+
+         # for gpu with small memory
+         elif self.model_name == "blip":
+             processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+             model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+
+         else:
+             raise NotImplementedError(f"{self.model_name} not implemented.")
+         model.to(self.device)
+
+         if self.device != 'cpu':
+             model.half()
+         return processor, model
+
+     def image_caption(self, image):
+         inputs = self.processor(images=image, return_tensors="pt").to(self.device, self.data_type)
+         generated_ids = self.model.generate(**inputs)
+         generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+         return generated_text
+
+     def image_caption_debug(self, image_src):
+         return "A dish with salmon, broccoli, and something yellow."
models/clip_model.py ADDED
@@ -0,0 +1,54 @@
+ import os
+ import cv2
+ import pdb
+ import torch
+ import numpy as np
+ from PIL import Image
+ from transformers import CLIPProcessor, CLIPVisionModelWithProjection
+ from transformers import logging
+ logging.set_verbosity_error()
+
+ class FeatureExtractor():
+     def __init__(self, args):
+         self.device = args.feature_extractor_device
+         self.beta = args.beta
+         self.processor = CLIPProcessor.from_pretrained(args.feature_extractor)
+         self.model = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor).to(self.device)
+         self.data_dir = args.data_dir
+         self.tmp_dir = args.tmp_dir
+
+     def __call__(self, video_path, video_id):
+         cap = cv2.VideoCapture(video_path)
+         fps = cap.get(cv2.CAP_PROP_FPS)
+         frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+         video_length = frame_count / fps
+         sample_rate = int(fps) * self.beta
+
+         save_path = os.path.join(self.tmp_dir, video_id + '.npz')
+         if os.path.exists(save_path):
+             data = np.load(save_path)
+             clip_features = data['features']
+             return clip_features, video_length
+
+         clip_features = []
+         print("Extract the clip feature.")
+         while True:
+             ret, frame = cap.read()
+             if not ret:
+                 break
+
+             if cap.get(cv2.CAP_PROP_POS_FRAMES) % sample_rate == 0:
+                 image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                 inputs = self.processor(images=image, return_tensors="pt").pixel_values
+                 inputs = inputs.to(self.device)
+
+                 with torch.no_grad():
+                     feat = self.model(inputs)['image_embeds']
+                 clip_features.append(feat.cpu().numpy())
+         print("Finished.")
+
+         clip_features = np.concatenate(clip_features, axis=0)
+         np.savez_compressed(save_path, features=clip_features)
+
+         return clip_features, video_length
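`FeatureExtractor` samples roughly one frame every `beta` seconds, embeds each frame with CLIP, and caches the result as `<tmp_dir>/<video_id>.npz`. A minimal sketch of calling it directly, assuming a `Namespace` with just the fields the constructor reads:

```python
# Sketch only: extract (or reload cached) CLIP features for one example video.
import os
from argparse import Namespace
from models.clip_model import FeatureExtractor

args = Namespace(feature_extractor='openai/clip-vit-base-patch32',
                 feature_extractor_device='cpu',  # 'cuda' if a GPU is available
                 beta=1, data_dir='./examples', tmp_dir='./tmp')
os.makedirs(args.tmp_dir, exist_ok=True)

extractor = FeatureExtractor(args)
features, video_length = extractor('examples/huaqiang.mp4', 'huaqiang')
print(features.shape, video_length)  # one feature row per sampled frame
```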
models/gpt_model.py ADDED
@@ -0,0 +1,102 @@
+ import os
+ import pdb
+ import pickle
+ from langchain.llms import OpenAI
+ from langchain.vectorstores.faiss import FAISS
+ from langchain.chains import ChatVectorDBChain
+ from langchain.prompts.prompt import PromptTemplate
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.document_loaders import UnstructuredFileLoader
+ from langchain.embeddings import OpenAIEmbeddings
+
+ _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
+ You can assume the discussion is about the video content.
+ Chat History:
+ {chat_history}
+ Follow Up Input: {question}
+ Standalone question:"""
+ CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
+
+ qa_template = """You are an AI assistant designed for answering questions about a video.
+ You are given a document and a question; the document records what people see and hear in this video.
+ Try to connect this information and provide a conversational answer.
+ Question: {question}
+ =========
+ {context}
+ =========
+ """
+ QA_PROMPT = PromptTemplate(template=qa_template, input_variables=["question", "context"])
+
+
+ class LlmReasoner():
+     def __init__(self, args):
+         self.history = []
+         self.gpt_version = args.gpt_version
+         self.data_dir = args.data_dir
+         self.tmp_dir = args.tmp_dir
+         self.qa_chain = None
+         self.vectorstore = None
+         self.top_k = 3
+         self.llm = OpenAI(temperature=0, model_name=self.gpt_version)
+
+     def exist_vectorstore(self, video_id):
+         pkl_path = os.path.join(self.tmp_dir, f"{video_id}.pkl")
+         log_path = os.path.join(self.data_dir, f"{video_id}.log")
+         if os.path.exists(pkl_path) and os.path.exists(log_path):
+             with open(pkl_path, 'rb') as file:
+                 self.vectorstore = pickle.load(file)
+
+             self.qa_chain = ChatVectorDBChain.from_llm(
+                 self.llm,
+                 self.vectorstore,
+                 qa_prompt=QA_PROMPT,
+                 condense_question_prompt=CONDENSE_QUESTION_PROMPT,
+             )
+             self.qa_chain.top_k_docs_for_context = self.top_k
+             return True
+         return False
+
+     def create_vectorstore(self, video_id):
+         pkl_path = os.path.join(self.tmp_dir, f"{video_id}.pkl")
+
+         if not os.path.exists(pkl_path):
+             loader = UnstructuredFileLoader(os.path.join(self.data_dir, f"{video_id}.log"))
+             raw_documents = loader.load()
+
+             # Split text
+             text_splitter = RecursiveCharacterTextSplitter()
+             documents = text_splitter.split_documents(raw_documents)
+
+             # Load data into the vectorstore
+             embeddings = OpenAIEmbeddings()
+             vectorstore = FAISS.from_documents(documents, embeddings)
+
+             # Save vectorstore
+             with open(pkl_path, "wb") as f:
+                 pickle.dump(vectorstore, f)
+
+         with open(pkl_path, 'rb') as file:
+             self.vectorstore = pickle.load(file)
+
+         self.qa_chain = ChatVectorDBChain.from_llm(
+             self.llm,
+             self.vectorstore,
+             qa_prompt=QA_PROMPT,
+             condense_question_prompt=CONDENSE_QUESTION_PROMPT,
+         )
+         self.qa_chain.top_k_docs_for_context = self.top_k
+
+         return
+
+     def __call__(self, question):
+         print(f"Question: {question}")
+         response = self.qa_chain({"question": question, "chat_history": self.history})["answer"]
+         self.history.append((question, response))
+
+         print(f"Assistant: {response}")
+         print("\n")
+         return response
+
+     def clean_history(self):
+         self.history = []
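`LlmReasoner` embeds the generated `<video_id>.log` into a FAISS store (pickled under `tmp_dir`) and answers questions through `ChatVectorDBChain`. A hedged sketch of the intended call order, assuming `OPENAI_API_KEY` is set and a log for `huaqiang` was already produced by the pipeline:

```python
# Sketch only: question answering over an already-generated video log.
import os
from argparse import Namespace
from models.gpt_model import LlmReasoner

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder; used by OpenAI and OpenAIEmbeddings
args = Namespace(gpt_version='gpt-3.5-turbo', data_dir='./examples', tmp_dir='./tmp')

reasoner = LlmReasoner(args)
if not reasoner.exist_vectorstore('huaqiang'):   # reuse ./tmp/huaqiang.pkl if present
    reasoner.create_vectorstore('huaqiang')      # otherwise embed ./examples/huaqiang.log
print(reasoner("Who appears in the video?"))     # hypothetical question
reasoner.clean_history()
```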
models/grit_model.py ADDED
@@ -0,0 +1,21 @@
+ import os
+ from models.grit_src.image_dense_captions import image_caption_api
+
+ class DenseCaptioner():
+     def __init__(self, device):
+         self.device = device
+
+     def initialize_model(self):
+         pass
+
+     def image_dense_caption_debug(self, image_src):
+         dense_caption = """
+         1. the broccoli is green, [0, 0, 333, 325];
+         2. a piece of broccoli, [0, 147, 143, 324];
+         3. silver fork on plate, [4, 547, 252, 612];
+         """
+         return dense_caption
+
+     def image_dense_caption(self, image_src):
+         dense_caption = image_caption_api(image_src, self.device)
+         return dense_caption
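`DenseCaptioner` is a thin wrapper around `image_caption_api` from `grit_src`. A usage sketch; `image_caption_api` is outside this 50-file view, so passing a file path as `image_src` is an assumption here:

```python
# Sketch only: region-level captions through the DenseCaptioner wrapper above.
from models.grit_model import DenseCaptioner

dense_captioner = DenseCaptioner(device="cuda")  # GRiT inference on CPU is very slow
# Output follows the numbered "caption, [x1, y1, x2, y2]" format shown in
# image_dense_caption_debug() above.
print(dense_captioner.image_dense_caption("examples/vlog.jpg"))  # path argument assumed
```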
models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc ADDED
Binary file (2.33 kB).
models/grit_src/configs/Base.yaml ADDED
@@ -0,0 +1,77 @@
+ MODEL:
+   META_ARCHITECTURE: "GRiT"
+   MASK_ON: True
+   PROPOSAL_GENERATOR:
+     NAME: "CenterNet"
+   FPN:
+     IN_FEATURES: ["layer3", "layer4", "layer5"]
+   PIXEL_MEAN: [123.675, 116.280, 103.530]
+   PIXEL_STD: [58.395, 57.12, 57.375]
+   ROI_HEADS:
+     NAME: GRiTROIHeadsAndTextDecoder
+     IN_FEATURES: ["p3", "p4", "p5"]
+     IOU_THRESHOLDS: [0.6]
+     NUM_CLASSES: 1
+     SCORE_THRESH_TEST: 0.02
+     NMS_THRESH_TEST: 0.5
+     OBJECT_FEAT_POOLER_RES: 14
+   ROI_BOX_CASCADE_HEAD:
+     IOUS: [0.6, 0.7, 0.8]
+   ROI_BOX_HEAD:
+     NAME: "FastRCNNConvFCHead"
+     NUM_FC: 2
+     POOLER_RESOLUTION: 7
+     CLS_AGNOSTIC_BBOX_REG: True
+     MULT_PROPOSAL_SCORE: True
+   ROI_MASK_HEAD:
+     NAME: "MaskRCNNConvUpsampleHead"
+     NUM_CONV: 4
+     POOLER_RESOLUTION: 14
+     CLS_AGNOSTIC_MASK: True
+   CENTERNET:
+     NUM_CLASSES: 1
+     REG_WEIGHT: 1.
+     NOT_NORM_REG: True
+     ONLY_PROPOSAL: True
+     WITH_AGN_HM: True
+     INFERENCE_TH: 0.0001
+     PRE_NMS_TOPK_TRAIN: 4000
+     POST_NMS_TOPK_TRAIN: 2000
+     PRE_NMS_TOPK_TEST: 1000
+     POST_NMS_TOPK_TEST: 256
+     NMS_TH_TRAIN: 0.9
+     NMS_TH_TEST: 0.9
+     POS_WEIGHT: 0.5
+     NEG_WEIGHT: 0.5
+     IGNORE_HIGH_FP: 0.85
+ DATASETS:
+   TRAIN: ("coco_2017_train",)
+   TEST: ("coco_2017_val",)
+ DATALOADER:
+   SAMPLER_TRAIN: "MultiDatasetSampler"
+   DATASET_RATIO: [1]
+   DATASET_INPUT_SIZE: [1024]
+   DATASET_INPUT_SCALE: [[0.1, 2.0]]
+   FILTER_EMPTY_ANNOTATIONS: False
+   NUM_WORKERS: 8
+ TEST:
+   DETECTIONS_PER_IMAGE: 256
+ SOLVER:
+   LR_SCHEDULER_NAME: "WarmupCosineLR"
+   CHECKPOINT_PERIOD: 10000
+   WARMUP_ITERS: 1000
+   WARMUP_FACTOR: 0.001
+   USE_CUSTOM_SOLVER: True
+   OPTIMIZER: "ADAMW"
+   MAX_ITER: 180000
+   IMS_PER_BATCH: 64
+   BASE_LR: 0.00008
+   VIT_LAYER_DECAY: True
+   CLIP_GRADIENTS:
+     ENABLED: True
+ INPUT:
+   FORMAT: RGB
+   CUSTOM_AUG: EfficientDetResizeCrop
+   TRAIN_SIZE: 640
+ USE_ACT_CHECKPOINT: True
+ VERSION: 2
models/grit_src/configs/GRiT_B_DenseCap.yaml ADDED
@@ -0,0 +1,20 @@
+ _BASE_: "Base.yaml"
+ MODEL:
+   TRAIN_TASK: ["DenseCap"]
+   TEST_TASK: "DenseCap"
+   MASK_ON: False
+   ROI_HEADS:
+     SOFT_NMS_ENABLED: False
+   BEAM_SIZE: 1
+   WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
+   BACKBONE:
+     NAME: build_vit_fpn_backbone
+   VIT_LAYERS: 12
+ SOLVER:
+   VIT_LAYER_DECAY_RATE: 0.7
+ DATASETS:
+   TRAIN: ("vg_train",)
+   TEST: ("vg_test",)
+ DATALOADER:
+   DATASET_BS: 2
+ OUTPUT_DIR: "./output/GRiT_B_DenseCap"
models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml ADDED
@@ -0,0 +1,23 @@
+ _BASE_: "Base.yaml"
+ MODEL:
+   TRAIN_TASK: ["ObjectDet", "DenseCap"]
+   TEST_TASK: "DenseCap" # DenseCap or ObjectDet: Choose one for testing
+   MASK_ON: True
+   ROI_HEADS:
+     SOFT_NMS_ENABLED: False
+   BEAM_SIZE: 1
+   WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
+   BACKBONE:
+     NAME: build_vit_fpn_backbone
+   VIT_LAYERS: 12
+ SOLVER:
+   VIT_LAYER_DECAY_RATE: 0.7
+ DATASETS:
+   TRAIN: ("GRiT_coco2017_train", "vg_train")
+   TEST: ("coco_2017_test-dev",)
+ DATALOADER:
+   DATASET_RATIO: [1, 1]
+   DATASET_BS: 2
+   DATASET_INPUT_SIZE: [1024, 1024]
+   DATASET_INPUT_SCALE: [[0.1, 2.0], [0.1, 2.0]]
+ OUTPUT_DIR: "./output/GRiT_B_DenseCap_ObjectDet"
models/grit_src/configs/GRiT_B_ObjectDet.yaml ADDED
@@ -0,0 +1,20 @@
+ _BASE_: "Base.yaml"
+ MODEL:
+   TRAIN_TASK: ["ObjectDet"]
+   TEST_TASK: "ObjectDet"
+   MASK_ON: True
+   ROI_HEADS:
+     SOFT_NMS_ENABLED: True
+   BEAM_SIZE: 3
+   WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
+   BACKBONE:
+     NAME: build_vit_fpn_backbone
+   VIT_LAYERS: 12
+ SOLVER:
+   VIT_LAYER_DECAY_RATE: 0.7
+ DATASETS:
+   TRAIN: ("GRiT_coco2017_train",)
+   TEST: ("coco_2017_val",)
+ DATALOADER:
+   DATASET_BS: 2
+ OUTPUT_DIR: "./output/GRiT_B_ObjectDet"
models/grit_src/configs/GRiT_H_ObjectDet.yaml ADDED
@@ -0,0 +1,21 @@
+ _BASE_: "Base.yaml"
+ MODEL:
+   TRAIN_TASK: ["ObjectDet"]
+   TEST_TASK: "ObjectDet"
+   MASK_ON: True
+   ROI_HEADS:
+     SOFT_NMS_ENABLED: True
+   BEAM_SIZE: 3
+   WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_huge_p14to16.pth"
+   BACKBONE:
+     NAME: build_vit_fpn_backbone_huge
+   VIT_LAYERS: 32
+ SOLVER:
+   MAX_ITER: 135000
+   VIT_LAYER_DECAY_RATE: 0.9
+ DATASETS:
+   TRAIN: ("GRiT_coco2017_train",)
+   TEST: ("coco_2017_val",)
+ DATALOADER:
+   DATASET_BS: 1
+ OUTPUT_DIR: "./output/GRiT_H_ObjectDet"
models/grit_src/configs/GRiT_L_ObjectDet.yaml ADDED
@@ -0,0 +1,20 @@
+ _BASE_: "Base.yaml"
+ MODEL:
+   TRAIN_TASK: ["ObjectDet"]
+   TEST_TASK: "ObjectDet"
+   MASK_ON: True
+   ROI_HEADS:
+     SOFT_NMS_ENABLED: True
+   BEAM_SIZE: 3
+   WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_large.pth"
+   BACKBONE:
+     NAME: build_vit_fpn_backbone_large
+   VIT_LAYERS: 24
+ SOLVER:
+   VIT_LAYER_DECAY_RATE: 0.8
+ DATASETS:
+   TRAIN: ("GRiT_coco2017_train",)
+   TEST: ("coco_2017_val",)
+ DATALOADER:
+   DATASET_BS: 1
+ OUTPUT_DIR: "./output/GRiT_L_ObjectDet"
models/grit_src/grit/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .modeling.meta_arch import grit
+ from .modeling.roi_heads import grit_roi_heads
+ from .modeling.backbone import vit
+
+ from .data.datasets import object365
+ from .data.datasets import vg
+ from .data.datasets import grit_coco
models/grit_src/grit/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (414 Bytes).
models/grit_src/grit/__pycache__/config.cpython-38.pyc ADDED
Binary file (1.41 kB).
models/grit_src/grit/__pycache__/predictor.cpython-38.pyc ADDED
Binary file (2.65 kB).
models/grit_src/grit/config.py ADDED
@@ -0,0 +1,50 @@
+ from detectron2.config import CfgNode as CN
+
+
+ def add_grit_config(cfg):
+     _C = cfg
+
+     _C.MODEL.BEAM_SIZE = 1
+     _C.MODEL.TRAIN_TASK = ["ObjectDet", "DenseCap"]
+     _C.MODEL.TEST_TASK = "DenseCap"  # This can be varied if the model is jointly trained on multiple tasks
+
+     _C.MODEL.ROI_BOX_HEAD.USE_BIAS = 0.0  # >= 0: not use
+     _C.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE = False
+
+     _C.MODEL.ROI_HEADS.MASK_WEIGHT = 1.0
+     _C.MODEL.ROI_HEADS.OBJECT_FEAT_POOLER_RES = 14
+     _C.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False
+
+     # Backbones
+     _C.MODEL.VIT_LAYERS = 12
+
+     # Text Decoder
+     _C.TEXT_DECODER = CN()
+     _C.TEXT_DECODER.VOCAB_SIZE = 30522
+     _C.TEXT_DECODER.HIDDEN_SIZE = 768
+     _C.TEXT_DECODER.NUM_LAYERS = 6
+     _C.TEXT_DECODER.ATTENTION_HEADS = 12
+     _C.TEXT_DECODER.FEEDFORWARD_SIZE = 768 * 4
+
+     # Multi-dataset dataloader
+     _C.DATALOADER.DATASET_RATIO = [1, 1]  # sample ratio
+     _C.DATALOADER.DATASET_BS = 1
+     _C.DATALOADER.DATASET_INPUT_SIZE = [1024, 1024]
+     _C.DATALOADER.DATASET_INPUT_SCALE = [(0.1, 2.0), (0.1, 2.0)]
+     _C.DATALOADER.DATASET_MIN_SIZES = [(640, 800), (640, 800)]
+     _C.DATALOADER.DATASET_MAX_SIZES = [1333, 1333]
+
+     _C.SOLVER.USE_CUSTOM_SOLVER = True
+     _C.SOLVER.OPTIMIZER = 'ADAMW'
+     _C.SOLVER.VIT_LAYER_DECAY = True
+     _C.SOLVER.VIT_LAYER_DECAY_RATE = 0.7
+
+     _C.INPUT.CUSTOM_AUG = 'EfficientDetResizeCrop'
+     _C.INPUT.TRAIN_SIZE = 1024
+     _C.INPUT.TEST_SIZE = 1024
+     _C.INPUT.SCALE_RANGE = (0.1, 2.)
+     # 'default' for fixed short / long edge
+     _C.INPUT.TEST_INPUT_TYPE = 'default'
+
+     _C.FIND_UNUSED_PARAM = True
+     _C.USE_ACT_CHECKPOINT = True
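`add_grit_config` registers the GRiT-specific keys on top of the stock detectron2 defaults; the YAML files earlier in this commit are then merged over them. A minimal sketch of that flow (note: `Base.yaml` also sets `CENTERNET.*` keys, whose config helper lives elsewhere in this repo and is outside this 50-file view, so the merge step is only indicated in a comment):

```python
# Sketch only: extending a detectron2 config with the GRiT keys defined above.
from detectron2.config import get_cfg
from models.grit_src.grit.config import add_grit_config

cfg = get_cfg()        # stock detectron2 defaults
add_grit_config(cfg)   # registers MODEL.BEAM_SIZE, TEXT_DECODER.*, DATALOADER.*, ...
print(cfg.MODEL.TEST_TASK, cfg.TEXT_DECODER.HIDDEN_SIZE, cfg.SOLVER.VIT_LAYER_DECAY_RATE)
# A full setup would also register the CenterNet proposal-generator keys before
# cfg.merge_from_file("models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml").
```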
models/grit_src/grit/custom_solver.py ADDED
@@ -0,0 +1,88 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ # Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob/main/detic/custom_solver.py
+ import itertools
+ from typing import Any, Callable, Dict, Iterable, List, Set, Type, Union
+ import torch
+
+ from detectron2.config import CfgNode
+
+ from detectron2.solver.build import maybe_add_gradient_clipping
+
+
+ def build_custom_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
+     params: List[Dict[str, Any]] = []
+     memo: Set[torch.nn.parameter.Parameter] = set()
+     optimizer_type = cfg.SOLVER.OPTIMIZER
+
+     for key, value in model.named_parameters(recurse=True):
+         if not value.requires_grad:
+             continue
+         # Avoid duplicating parameters
+         if value in memo:
+             continue
+         memo.add(value)
+         lr = cfg.SOLVER.BASE_LR
+         weight_decay = cfg.SOLVER.WEIGHT_DECAY
+
+         if cfg.SOLVER.VIT_LAYER_DECAY:
+             lr = lr * get_vit_lr_decay_rate(key, cfg.SOLVER.VIT_LAYER_DECAY_RATE, cfg.MODEL.VIT_LAYERS)
+
+         param = {"params": [value], "lr": lr}
+         if optimizer_type != 'ADAMW':
+             param['weight_decay'] = weight_decay
+         params += [param]
+
+     def maybe_add_full_model_gradient_clipping(optim):  # optim: the optimizer class
+         # detectron2 doesn't have full model gradient clipping now
+         clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
+         enable = (
+             cfg.SOLVER.CLIP_GRADIENTS.ENABLED
+             and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
+             and clip_norm_val > 0.0
+         )
+
+         class FullModelGradientClippingOptimizer(optim):
+             def step(self, closure=None):
+                 all_params = itertools.chain(*[x["params"] for x in self.param_groups])
+                 torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
+                 super().step(closure=closure)
+
+         return FullModelGradientClippingOptimizer if enable else optim
+
+     if optimizer_type == 'SGD':
+         optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
+             params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
+             nesterov=cfg.SOLVER.NESTEROV
+         )
+     elif optimizer_type == 'ADAMW':
+         optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
+             params, cfg.SOLVER.BASE_LR,
+             weight_decay=cfg.SOLVER.WEIGHT_DECAY
+         )
+     else:
+         raise NotImplementedError(f"no optimizer type {optimizer_type}")
+     if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
+         optimizer = maybe_add_gradient_clipping(cfg, optimizer)
+     return optimizer
+
+
+ def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
+     """
+     Calculate lr decay rate for different ViT blocks.
+     Args:
+         name (string): parameter name.
+         lr_decay_rate (float): base lr decay rate.
+         num_layers (int): number of ViT blocks.
+
+     Returns:
+         lr decay rate for the given parameter.
+     """
+     layer_id = num_layers + 1
+     if name.startswith("backbone"):
+         if ".pos_embed" in name or ".patch_embed" in name:
+             layer_id = 0
+         elif ".blocks." in name and ".residual." not in name:
+             layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
+
+     return lr_decay_rate ** (num_layers + 1 - layer_id)
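`get_vit_lr_decay_rate` gives earlier ViT blocks smaller learning-rate multipliers. A small worked example with hypothetical parameter names that follow the `backbone...blocks.N` pattern the function parses, using the GRiT-B settings (`VIT_LAYER_DECAY_RATE: 0.7`, 12 layers):

```python
# Sketch only: layer-wise LR multipliers from get_vit_lr_decay_rate() above.
from models.grit_src.grit.custom_solver import get_vit_lr_decay_rate

for name in ["backbone.net.patch_embed.proj.weight",   # embedding layers -> layer_id 0
             "backbone.net.blocks.0.attn.qkv.weight",  # first block      -> layer_id 1
             "backbone.net.blocks.11.mlp.fc1.weight",  # last block       -> layer_id 12
             "roi_heads.box_head.fc1.weight"]:         # non-backbone     -> no decay
    print(name, get_vit_lr_decay_rate(name, lr_decay_rate=0.7, num_layers=12))
# Multipliers: 0.7**13 ~ 0.0097, 0.7**12 ~ 0.0138, 0.7**1 = 0.7, 0.7**0 = 1.0
```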
models/grit_src/grit/data/__pycache__/custom_build_augmentation.cpython-38.pyc ADDED
Binary file (1.22 kB).
models/grit_src/grit/data/__pycache__/custom_dataset_mapper.cpython-38.pyc ADDED
Binary file (5.69 kB).