diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..11ed3b03004e803a21bc9f30ff5c913f7c909fa7 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.jar filter=lfs diff=lfs merge=lfs -text +*.mp4 filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..fa960a207e89d4c8e33c6f28096d1acc55910b26 --- /dev/null +++ b/.gitignore @@ -0,0 +1,71 @@ +results/ +output_*/ +icl_inference_output/ +.vscode/ +tmp/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class +*.ipynb + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b4f7ee2451c3a2e39e946a33ddaaa04d660180bd --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Yiqin Wang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/README.md b/README.md index 276c9176d6040e1ec61d752991e092f20206ca4f..e5baae608a8729271f37c3b132e23e02abb6af39 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ --- title: ChatVID -emoji: 🐨 -colorFrom: gray -colorTo: blue +emoji: 🎥 +colorFrom: green +colorTo: red sdk: gradio -sdk_version: 3.34.0 +sdk_version: 3.30.0 app_file: app.py pinned: false license: mit diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4fe6b41b3c7782941e50814d3f93929bdd1561 --- /dev/null +++ b/app.py @@ -0,0 +1,97 @@ +import argparse +import time + +import gradio as gr + +from config.config_utils import get_config +from model import Captioner, VicunaHandler + + +def set_example_video(example: list) -> dict: + return gr.Video.update(value=example[0]) + + +def upload_file(files): + file_paths = [file.name for file in files] + return file_paths + + +def upload_video(video): + print(video) + return video + + +def respond(input, chat_history): + bot_response = handler.gr_chat(input) + chat_history.append((input, bot_response)) + time.sleep(0.1) + return "", chat_history + + +def clear_chat(chat_history): + handler.chatbot.clear_conv_() + + return "", [] + + + +config = get_config('config/infer.yaml') + +captioner = Captioner(config) # global + +global handler +handler = VicunaHandler(config['vicuna']) + +with gr.Blocks(theme=gr.themes.Soft()) as demo: + gr.Markdown("##
ChatVID
") + gr.Markdown(""" + ChatVID is a video chatbot that can chat about any video. + """) + with gr.Row(): + with gr.Column(): + video_path = gr.Video(label="Video") + + with gr.Column(): + upload_button = gr.Button( + "Upload & Watch. (Click once and wait 3min )") + chat_button = gr.Button("Let's Chat!", interactive=False) + num_frames = gr.Slider( + minimum=5, + value=12, + maximum=12, + step=1, + label="Number of frames (no more than 12)") + + with gr.Column(): + chatbot = gr.Chatbot() + captions = gr.State("") + with gr.Row(visible=False) as input: + with gr.Column(scale=0.7): + txt = gr.Textbox( + show_label=False, + placeholder="Enter text and press enter").style( + container=False) + with gr.Column(scale=0.15, min_width=0): + run_button = gr.Button("RUN!") + with gr.Column(scale=0.15, min_width=0): + clear_button = gr.Button("CLEAR") + + upload_button.click( + lambda: gr.update(interactive=False), None, chat_button).then( + lambda: gr.update(visible=False), None, + input).then(lambda: [], None, chatbot).then( + captioner.caption_video, [video_path, num_frames], + [captions]).then(lambda: gr.update(interactive=True), None, + chat_button) + + chat_button.click(handler.gr_chatbot_init, [captions], + None).then(lambda: gr.update(visible=True), None, + input) + + txt.submit(respond, inputs=[txt, chatbot], outputs=[txt, chatbot]) + run_button.click( + respond, inputs=[txt, chatbot], outputs=[txt, chatbot]) + clear_button.click( + clear_chat, inputs=[chatbot], outputs=[txt, chatbot]) + +demo.launch(share=True) diff --git a/config/config_utils.py b/config/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5f0d4d9ff58522738e8afcf425454ace3a1886fb --- /dev/null +++ b/config/config_utils.py @@ -0,0 +1,14 @@ +def get_config( + config_path: str +): + import yaml + f = open(config_path, "r") + config = yaml.load(f.read(), yaml.Loader) + f.close() + return config + +def save_config( + config: dict, + file_path: str, +): + pass \ No newline at end of file diff --git a/config/debug.yaml b/config/debug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0b06e564d8562e4f420a3ef231ba0e2a33dc475f --- /dev/null +++ b/config/debug.yaml @@ -0,0 +1,23 @@ +device: 'cuda' +video_path: '/mnt/petrelfs/wangyiqin/vid_cap/examples/videos/' +video_name: 'cook_720p.mp4' +fps: 120 + +vicuna: + model_path: '/mnt/petrelfs/wangyiqin/vid_cap/vicuna-7b' + device: 'cuda' + num_gpus: 1 + max_gpu_memory: '40Gib' + load_8bit: True + conv_template: + temperature: 1.0 + max_new_tokens: 512 + debug: False + output_path: '/mnt/petrelfs/wangyiqin/vid_cap/VideoChatDuplicate/examples/test.json' + +vid2seq: + enable: True + clip_path: '/mnt/petrelfs/wangyiqin/vid_cap/examples/ViT-L-14.pt' + output_path: '/mnt/petrelfs/wangyiqin/vid_cap/examples/' + work_dir: 'vid2seq_workdir' + config_path: '/mnt/petrelfs/wangyiqin/vid_cap/scenic/scenic/projects/vid2seq/configs/youcook2.py' diff --git a/config/infer.yaml b/config/infer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..893a35a82f582e4fc34cf823a558dbca9cc0d70e --- /dev/null +++ b/config/infer.yaml @@ -0,0 +1,16 @@ +device: 'cuda' + +vicuna: + model_path: '/home/user/app/vicuna-7b' + device: 'cuda' + num_gpus: 'auto' + max_gpu_memory: '24Gib' + load_8bit: True + conv_template: + temperature: 1.0 + max_new_tokens: 512 + debug: False + output_path: '/home/user/app/vicuna_out.json' + +vid2seq: + enable: False diff --git a/config/local_infer.yaml b/config/local_infer.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..31c478eecae2285a54e21a4cb09cc7d4be1f9b3e --- /dev/null +++ b/config/local_infer.yaml @@ -0,0 +1,21 @@ +device: 'cuda' + +vicuna: + model_path: '/mnt/petrelfs/wangyiqin/vid_cap/ChatVID/vicuna-7b' + device: 'cuda' + num_gpus: 1 + max_gpu_memory: '24Gib' + load_8bit: True + conv_template: + temperature: 1.0 + max_new_tokens: 512 + debug: False + output_path: '/mnt/petrelfs/wangyiqin/vid_cap/ChatVID/examples/vicuna_out.json' + +vid2seq: + enable: True + clip_path: '/mnt/petrelfs/wangyiqin/vid_cap/ChatVID/clip_ckpt/ViT-L-14.pt' + output_path: '/mnt/petrelfs/wangyiqin/vid_cap/ChatVID/examples/' + work_dir: 'vid2seq_workdir' + config_path: 'config/vid2seq_config.py' + checkpoint_path: '/mnt/petrelfs/wangyiqin/vid_cap/ChatVID/vid2seq_ckpt' #only folder name diff --git a/config/vid2seq_config.py b/config/vid2seq_config.py new file mode 100644 index 0000000000000000000000000000000000000000..885f8ad2b08d3d78422e96e1ea2b9ac5d3baa57f --- /dev/null +++ b/config/vid2seq_config.py @@ -0,0 +1,183 @@ +import ml_collections + +YOUCOOK_TRAIN_SIZE = 1333 # Number of videos + + +def get_config(runlocal=''): + """Returns the base experiment configuration.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.token_loss_coef = 1. + config.runlocal = runlocal + config.experiment_name = 'youcook' + + config.count_flops = False # if runlocal else ml_collections.ConfigDict({'count_flops': True}) + + # dataset + config.dataset_name = 'dense_video_captioning' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.corrupt = 0. + config.dataset_configs.span_len = 3. + config.dataset_configs.preserve = True + config.dataset_configs.corrupt_coef = 0. + config.dataset_configs.proba_corrupt = 0. + notime = ml_collections.config_dict.FieldReference(False) + config.dataset_configs.notime = notime + config.dataset_configs.abs_time_token = False + config.dataset_configs.random_temporal_crop_proba = 0.5 + config.dataset_configs.time_format = 'se' + tmp_only = ml_collections.config_dict.FieldReference(False) + config.dataset_configs.tmp_only = tmp_only + config.dataset_configs.split = False + order = ml_collections.config_dict.FieldReference('ld') + config.dataset_configs.order = order + config.dataset_configs.from_xm = None + + config.data_dtype_str = 'float32' + + config.dataset_configs.base_dir = '/mnt/petrelfs/wangyiqin/vid_cap/examples' + config.dataset_configs.tables = { + 'train': 'train.tfrecord.sst@64', + 'validation': 'test@1', + } + config.dataset_configs.examples_per_subset = { + 'train': 0, + 'validation': 1, + } + + # List of modalities to load, supports `features` only for now. 
+ # Note that it only specifies which modalities to load, not which to use, + # which is controlled by config.model.modality_fusion + config.dataset_configs.modalities = ('features', 'text') + config.dataset_configs.features_dim = 768 + config.dataset_configs.return_as_dict = True + num_frames = ml_collections.config_dict.FieldReference( + 100) # need to change back to 100 in the future -- Yiqin + config.dataset_configs.num_frames = num_frames + num_bins = ml_collections.config_dict.FieldReference(100) + config.dataset_configs.num_bins = num_bins + config.dataset_configs.one_hot_labels = True + config.dataset_configs.zero_centering = True + config.dataset_configs.val_on_test = False + config.dataset_configs.num_eval_clips = 1 + config.dataset_configs.prefetch_to_device = 2 + + # Text params + config.dataset_configs.max_num_output_words = 256 + config.dataset_configs.max_num_input_words = 1000 + config.dataset_configs.tokenizer = ml_collections.ConfigDict() + config.dataset_configs.tokenizer.tokenizer_type = 'sentence_piece' + config.dataset_configs.caption_string = 'caption/string' + config.dataset_configs.train_caption_string = 'caption/string' + config.dataset_configs.input_timestamp_name = 'video/timestamps' + config.dataset_configs.input_duration_name = 'video/duration' + config.dataset_configs.output_raw_timestamp_name = 'timestamp' + config.dataset_configs.output_raw_duration_name = 'duration' + config.dataset_configs.input_feature_name = 'image/clip_embeddings' + config.dataset_configs.output_raw_feature_name = 'features' + config.dataset_configs.vocabulary_size = 32128 + config.dataset_configs.max_events = 20 + config.dataset_configs.asr_notime = False + config.datasets = {'youcook': config.dataset_configs} + + # Decoding + config.decoding = ml_collections.ConfigDict() + config.decoding.decoding_method = 'beamsearch' + # config.decoding.decoding_method = 'temperature_sample' + config.decoding.num_decodes = 4 + config.decoding.alpha = 1 + config.decoding.temperature = 1. + + # Model + config.model_name = 'vid2seq' + config.model = ml_collections.ConfigDict() + config.model.from_xm = None + + # Encoder configs + config.model.encoder = ml_collections.ConfigDict() + config.model.encoder.share_encoder = True + config.model.encoder.encoder_type = 'cat_encoder' + config.model.encoder.cat_encoder = ml_collections.ConfigDict() + config.model.encoder.cat_encoder.dim = 2048 + config.model.encoder.cat_encoder.layers = 12 + config.model.encoder.cat_encoder.heads = 12 + config.model.encoder.cat_encoder.pos_embed = 'learned_1d' + config.model.encoder.cat_encoder.dropout_rate = 0. + config.model.encoder.cat_encoder.t5_dropout_rate = 0.1 + config.model.encoder.cat_encoder.stochastic_depth = 0. 
+ config.model.encoder.cat_encoder.pretrained_config = 't5_1_1_base' + config.model.encoder.from_xm = None + + # Decoder configs + config.model.decoder_type = 't5_decoder' + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.order = order + config.model.decoder.t5_decoder = ml_collections.ConfigDict() + config.model.decoder.t5_decoder.logits_via_embedding = False + config.model.decoder.t5_decoder.dropout_rate = 0.1 + config.model.decoder.t5_decoder.num_frames = num_frames + config.model.decoder.notime = notime + config.model.decoder.num_bins = num_bins + config.model.decoder.tmp_only = tmp_only + config.model.decoder.t5_decoder.pretrained_config = 't5_1_1_base' + + # Initalisation configs + config.init_from = ml_collections.ConfigDict() + # Replace with your checkpoint pretrained on YT-temporal-1bn, assuming it has + # been trained for 200K iterations + config.init_from.checkpoint_path = '/mnt/petrelfs/wangyiqin/vid_cap/vid2seq_model' + # config.init_from.model_config = '/mnt/petrelfs/wangyiqin/vid_cap/scenic/scenic/projects/vid2seq/configs/yttemporal.py' + config.init_from.step = 200001 # ytt 200000, anet 200001 + + config.init_from.encoder = ml_collections.ConfigDict() + config.init_from.encoder.checkpoint_path = None + config.init_from.encoder.init_from_vit = False + config.init_from.encoder = ml_collections.ConfigDict() + config.init_from.encoder.load_pretrained_weights = True + + config.init_from.decoder = ml_collections.ConfigDict() + config.init_from.decoder.load_pretrained_weights = True + + config.init_from.t5 = ml_collections.ConfigDict() + config.init_from.t5.load_pretrained_weights = True + + # Training + config.trainer_name = 'densevidcap_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.weight_decay = 0. + config.l2_decay_factor = 0. + config.max_grad_norm = 1. + config.label_smoothing = 0.1 + epochs = ml_collections.config_dict.FieldReference(0) ### add + config.num_training_epochs = 0 + batch_size = ml_collections.config_dict.FieldReference(1) + config.batch_size = 1 #if runlocal else batch_size # 128 # Minimum is num_devices = 32 + config.eval_batch_size = 1 #if runlocal else 32 # Needs to be num_local_devices + config.rng_seed = 0 + + # Learning schedule. 
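+  # With epochs = 0 and batch_size = 1 set above, total_steps below evaluates to 0,
+  # so warmup_steps and steps_per_cycle are 0 as well; the schedule is presumably
+  # unused, since this config appears to be loaded only for caption inference
+  # (num_training_epochs = 0).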
+ steps_per_epoch = 3 if runlocal else YOUCOOK_TRAIN_SIZE // batch_size + total_steps = epochs * steps_per_epoch + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = total_steps // 10 + config.lr_configs.steps_per_cycle = total_steps + config.lr_configs.total_steps = total_steps + config.lr_configs.base_learning_rate = 3e-4 + + config.eval_metrics = ['cider', 'meteor', 'soda'] + + # Logging + config.log_eval_steps = steps_per_epoch # write TB and/or XM summary + config.log_summary_steps = steps_per_epoch # write TB and/or XM summary + config.write_summary = True # write TB and/or XM summary + config.write_xm_measurements = True # write XM measurements + config.xprof = True # Profile using xprof + config.checkpoint = True # do checkpointing + config.debug_train = False # debug mode during training + config.debug_eval = True # debug mode during eval + return config diff --git a/config/yttemporal.py b/config/yttemporal.py new file mode 100644 index 0000000000000000000000000000000000000000..1e291c18ab8c1b3a6bd3adcbe1a92013ea871783 --- /dev/null +++ b/config/yttemporal.py @@ -0,0 +1,184 @@ + +import ml_collections + + +def get_config(runlocal=''): + """Returns the base experiment configuration.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.token_loss_coef = 1. + config.runlocal = runlocal + config.experiment_name = 'ytt' + + config.count_flops = False if runlocal else ml_collections.ConfigDict( + {'count_flops': True}) + + # dataset + config.dataset_name = 'dense_video_captioning' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.corrupt = 0.25 + config.dataset_configs.span_len = 5. + config.dataset_configs.proba_corrupt = 1. + config.dataset_configs.corrupt_coef = 1. + config.dataset_configs.preserve = False + notime = ml_collections.config_dict.FieldReference(False) + config.dataset_configs.notime = notime + config.dataset_configs.abs_time_token = False + config.dataset_configs.random_temporal_crop_proba = 1. + config.dataset_configs.time_format = 'se' + tmp_only = ml_collections.config_dict.FieldReference(False) + config.dataset_configs.tmp_only = tmp_only + config.dataset_configs.split = not runlocal + order = ml_collections.config_dict.FieldReference('ld') + config.dataset_configs.order = order + config.dataset_configs.from_xm = None + + config.data_dtype_str = 'float32' + + config.dataset_configs.base_dir = '/' + config.dataset_configs.base_dir = '/path/to/yttemporal' + config.dataset_configs.tables = { + 'train': 'train.tfrecord.sst@1024', + } + config.dataset_configs.examples_per_subset = { + 'train': 14780275, + } + + # List of modalities to load, supports `features` only for now. 
+ # Note that it only specifies which modalities to load, not which to use, + # which is controlled by config.model.modality_fusion + config.dataset_configs.modalities = ('features', 'text') + config.dataset_configs.features_dim = 768 + config.dataset_configs.return_as_dict = True + num_frames = ml_collections.config_dict.FieldReference(100) + config.dataset_configs.num_frames = num_frames + num_bins = ml_collections.config_dict.FieldReference(100) + config.dataset_configs.num_bins = num_bins + config.dataset_configs.one_hot_labels = True + config.dataset_configs.zero_centering = True + config.dataset_configs.val_on_test = False + config.dataset_configs.num_eval_clips = 1 + config.dataset_configs.prefetch_to_device = 2 + + # Text params + config.dataset_configs.max_num_output_words = 1000 + config.dataset_configs.max_num_input_words = 1000 + config.dataset_configs.tokenizer = ml_collections.ConfigDict() + config.dataset_configs.tokenizer.tokenizer_type = 'sentence_piece' + config.dataset_configs.caption_string = 'ASR/segment/label/string' + config.dataset_configs.train_caption_string = 'ASR/segment/label/string' + config.dataset_configs.input_timestamp_start_name = 'ASR/segment/start/timestamp' + config.dataset_configs.input_timestamp_end_name = 'ASR/segment/end/timestamp' + config.dataset_configs.input_duration_name = 'video/duration' + config.dataset_configs.output_raw_timestamp_name = 'timestamp' + config.dataset_configs.output_raw_duration_name = 'duration' + config.dataset_configs.input_feature_name = 'image/clip_embeddings' + config.dataset_configs.output_raw_feature_name = 'features' + config.dataset_configs.vocabulary_size = 32128 + config.dataset_configs.max_events = 1100 + config.dataset_configs.max_segments = 0 + config.datasets = {'ytt': config.dataset_configs} + + # Decoding + config.decoding = ml_collections.ConfigDict() + config.decoding.decoding_method = 'beamsearch' + config.decoding.num_decodes = 4 + config.decoding.alpha = 0.6 + config.decoding.temperature = 1. + + # Model + config.model_name = 'vid2seq' + config.model = ml_collections.ConfigDict() + config.model.from_xm = None + + # Encoder configs + config.model.encoder = ml_collections.ConfigDict() + config.model.encoder.share_encoder = True + config.model.encoder.encoder_type = 'cat_encoder' + config.model.encoder.cat_encoder = ml_collections.ConfigDict() + config.model.encoder.cat_encoder.dim = 2048 + config.model.encoder.cat_encoder.layers = 12 + config.model.encoder.cat_encoder.heads = 12 + config.model.encoder.cat_encoder.pos_embed = 'learned_1d' + config.model.encoder.cat_encoder.dropout_rate = 0.1 + config.model.encoder.cat_encoder.t5_dropout_rate = 0.1 + config.model.encoder.cat_encoder.stochastic_depth = 0. + config.model.encoder.cat_encoder.pretrained_config = 't5_1_1_base' + config.model.encoder.from_xm = None + + # Decoder configs + config.model.decoder_type = 't5_decoder' + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.order = order + config.model.decoder.t5_decoder = ml_collections.ConfigDict() + config.model.decoder.t5_decoder.logits_via_embedding = False + config.model.decoder.t5_decoder.dropout_rate = 0.1 + config.model.decoder.t5_decoder.num_frames = num_frames + config.model.decoder.notime = notime + config.model.decoder.num_bins = num_bins + config.model.decoder.tmp_only = tmp_only + # Obtained from scenic/projects/t5/model.py. 
+ config.model.decoder.t5_decoder.pretrained_config = 't5_1_1_base' + + config.model.tmp_decoder_type = 't5_decoder' + config.model.tmp_decoder = ml_collections.ConfigDict() + config.model.tmp_decoder.t5_decoder = ml_collections.ConfigDict() + config.model.tmp_decoder.t5_decoder.logits_via_embedding = False + config.model.tmp_decoder.t5_decoder.dropout_rate = 0. + config.model.tmp_decoder.t5_decoder.pretrained_config = 't5_1_1_base' + config.model.decoder.t5_decoder.local = 5 + + # Initalisation configs + config.init_from = ml_collections.ConfigDict() + config.init_from.step = None + config.init_from.xm = None + + config.init_from.encoder = ml_collections.ConfigDict() + config.init_from.encoder.checkpoint_path = None + config.init_from.encoder.init_from_vit = False + config.init_from.encoder = ml_collections.ConfigDict() + config.init_from.encoder.load_pretrained_weights = True + + config.init_from.decoder = ml_collections.ConfigDict() + config.init_from.decoder.load_pretrained_weights = True + + config.init_from.t5 = ml_collections.ConfigDict() + config.init_from.t5.load_pretrained_weights = True + + # Training + config.trainer_name = 'densevidcap_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.weight_decay = 0. + config.l2_decay_factor = 0. + config.max_grad_norm = 0.1 + config.label_smoothing = 0.1 + epochs = ml_collections.config_dict.FieldReference(10) + config.num_training_epochs = 0 + batch_size = ml_collections.config_dict.FieldReference(512) + config.batch_size = 1 if runlocal else batch_size # 128 # Minimum is num_devices = 32 + config.eval_batch_size = 1 if runlocal else 128 # Needs to be num_local_devices + config.rng_seed = 0 + + # Learning schedule. + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * linear_warmup' + config.lr_configs.warmup_steps = 1000 + config.lr_configs.base_learning_rate = 1e-4 + + config.eval_metrics = ['cider', 'meteor', 'soda'] + + # Logging + config.log_summary_steps = 500 # write TB and/or XM summary + config.checkpoint_steps = 5000 + config.log_eval_steps = 5000 + config.write_summary = True # write TB and/or XM summary + config.write_xm_measurements = True # write XM measurements + config.xprof = True # Profile using xprof + config.checkpoint = True # do checkpointing + config.debug_train = False # debug mode during training + config.debug_eval = False # debug mode during eval + return config diff --git a/model/Captioner.py b/model/Captioner.py new file mode 100644 index 0000000000000000000000000000000000000000..e443612b9eaf64ffbd0354722c93d94d609c2b2b --- /dev/null +++ b/model/Captioner.py @@ -0,0 +1,72 @@ +from mmaction.datasets.transforms import (DecordInit, SampleFrames, Resize, + FormatShape, DecordDecode) +from model.audio import SpeechRecognizer +from model.vision import DenseCaptioner, ImageCaptioner + + +class Captioner: + """ Captioner class for video captioning + """ + + def __init__(self, config): + """ Initialize the captioner + Args: + config: configuration file + """ + self.config = config + self.image_captioner = ImageCaptioner(device=config['device']) + self.dense_captioner = DenseCaptioner(device=config['device']) + self.speech_recognizer = SpeechRecognizer(device=config['device']) + # if self.config['vid2seq']['enable']: + # self.vid2seq_captioner = Vid2SeqCaptioner(config=config['vid2seq']) + + self.src_dir = '' + + def debug_vid2seq(self, video_path, 
num_frames=8): + return self.vid2seq_captioner(video_path=video_path) + + def caption_video(self, video_path, num_frames=8): + print("Watching video ...") + + video_info = {'filename': video_path, 'start_index': 0} + + video_processors = [ + DecordInit(), + SampleFrames(clip_len=1, frame_interval=1, num_clips=num_frames), + DecordDecode(), + Resize(scale=(-1, 720)), + FormatShape(input_format='NCHW'), + ] + for processor in video_processors: + video_info = processor.transform(video_info) + + timestamp_list = [ + round(i / video_info['avg_fps'], 1) + for i in video_info['frame_inds'] + ] + + image_captions = self.image_captioner(imgs=video_info['imgs']) + dense_captions = self.dense_captioner(imgs=video_info['imgs']) + # if self.config['vid2seq']['enable']: + # vid2seq_captions = self.vid2seq_captioner(video_path=video_path) + # else: + vid2seq_captions = [] + try: + speech = self.speech_recognizer(video_path) + except RuntimeError: + speech = "" + + overall_captions = "" + for i in range(num_frames): + overall_captions += "[" + str(timestamp_list[i]) + "s]: " + overall_captions += "You see " + image_captions[i] + overall_captions += "You find " + dense_captions[i] + "\n" + + if speech != "": + overall_captions += "You hear \"" + speech + "\"\n" + + for i in range(len(vid2seq_captions)): + overall_captions += "You notice " + vid2seq_captions[i] + "\n" + print("Captions generated") + + return overall_captions diff --git a/model/Vicuna.py b/model/Vicuna.py new file mode 100644 index 0000000000000000000000000000000000000000..54e0786fe26679b09b71a90aa64cc20a32a519ec --- /dev/null +++ b/model/Vicuna.py @@ -0,0 +1,214 @@ +from model.fastchat.conversation import (Conversation, SeparatorStyle, + compute_skip_echo_len, + get_default_conv_template) +from model.fastchat.serve.inference import (ChatIO, chat_loop, generate_stream, + load_model) + + +class SimpleChatIO(ChatIO): + + def prompt_for_input(self, role) -> str: + return input(f"{role}: ") + + def prompt_for_output(self, role: str): + print(f"{role}: ", end="", flush=True) + + def stream_output(self, output_stream, skip_echo_len: int): + pre = 0 + for outputs in output_stream: + outputs = outputs[skip_echo_len:].strip() + outputs = outputs.split(" ") + now = len(outputs) - 1 + if now > pre: + print(" ".join(outputs[pre:now]), end=" ", flush=True) + pre = now + print(" ".join(outputs[pre:]), flush=True) + return " ".join(outputs) + + +class VicunaChatBot: + + def __init__( + self, + model_path: str, + device: str, + num_gpus: str, + max_gpu_memory: str, + load_8bit: bool, + conv_template, + ChatIO: ChatIO, + debug: bool, + ): + self.model_path = model_path + self.device = device + self.chatio = ChatIO + self.debug = debug + + self.model, self.tokenizer = load_model(self.model_path, device, + num_gpus, max_gpu_memory, + load_8bit, debug) + + if conv_template: + self.conv = conv_template.copy() + else: + self.conv = get_default_conv_template(model_path).copy() + + self.conv_template = self.conv.copy() + + def chat(self, inp: str, temperature: float, max_new_tokens: int): + """ Vicuna as a chatbot. 
""" + self.conv.append_message(self.conv.roles[0], inp) + self.conv.append_message(self.conv.roles[1], None) + + generate_stream_func = generate_stream + prompt = self.conv.get_prompt() + + skip_echo_len = compute_skip_echo_len(self.model_path, self.conv, + prompt) + stop_str = ( + self.conv.sep if self.conv.sep_style + in [SeparatorStyle.SINGLE, SeparatorStyle.BAIZE] else None) + params = { + "model": self.model_path, + "prompt": prompt, + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "stop": stop_str, + } + print(prompt) + self.chatio.prompt_for_output(self.conv.roles[1]) + output_stream = generate_stream_func(self.model, self.tokenizer, + params, self.device) + outputs = self.chatio.stream_output(output_stream, skip_echo_len) + # NOTE: strip is important to align with the training data. + self.conv.messages[-1][-1] = outputs.strip() + return outputs + + def summarise(self, caption: dict, temperature: float, + max_new_tokens: int): + """ Vicuna as a summariser. """ + questions = caption + captions = {} + for id, question in questions.items(): + # Reset the conversation for each iteration + self.conv = get_default_conv_template(self.model_path).copy() + self.conv.append_message(self.conv.roles[0], question) + self.conv.append_message(self.conv.roles[1], None) + + generate_stream_func = generate_stream + prompt = self.conv.get_prompt() + + skip_echo_len = compute_skip_echo_len(self.model_path, self.conv, + prompt) + stop_str = ( + self.conv.sep if self.conv.sep_style + in [SeparatorStyle.SINGLE, SeparatorStyle.BAIZE] else None) + + params = { + "model": self.model_path, + "prompt": prompt, + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "stop": stop_str, + } + + self.chatio.prompt_for_output(self.conv.roles[1]) + output_stream = generate_stream_func(self.model, self.tokenizer, + params, self.device) + outputs = self.chatio.stream_output(output_stream, skip_echo_len) + captions[id] = outputs + + if self.debug: + print("\n", {"prompt": prompt, "outputs": outputs}, "\n") + + print(captions) + return captions + + def clear_conv_(self): + """ Clear the conversation. """ + self.conv = self.conv_template.copy() + + def change_conv_template_(self, conv_template): + self.conv_template = conv_template.copy() + self.conv = conv_template.copy() + + def change_conv_(self, conv_template): + """ Change the conversation. """ + self.conv = conv_template.copy() + + +class VicunaHandler: + """ VicunaHandler is a class that handles the communication between the + frontend and the backend. """ + + def __init__(self, config): + self.config = config + self.chat_io = SimpleChatIO() + self.chatbot = VicunaChatBot( + self.config['model_path'], + self.config['device'], + self.config['num_gpus'], + self.config['max_gpu_memory'], + self.config['load_8bit'], + None, + self.chat_io, + self.config['debug'], + ) + + def chat(self): + """ Chat with the Vicuna. """ + template = self._construct_conversation("") + chat_loop( + self.config['model_path'], + self.config['device'], + self.config['num_gpus'], + self.config['max_gpu_memory'], + self.config['load_8bit'], + template, + self.config['temperature'], + self.config['max_new_tokens'], + self.chat_io, + self.config['debug'], + ) + + def gr_chatbot_init(self, caption: str): + """ Initialise the chatbot for gradio. """ + + template = self._construct_conversation(caption) + self.chatbot.change_conv_template_(template) + print("Chatbot initialised.") + + def gr_chat(self, inp): + """ Chat using gradio as the frontend. 
""" + return self.chatbot.chat(inp, self.config['temperature'], + self.config['max_new_tokens']) + + def _construct_conversation(self, prompt): + """ Construct a conversation template. + Args: + prompt: the prompt for the conversation. + """ + + user_message = "The following text described what you have " +\ + "seen, found, heard and notice from a consecutive video." +\ + " Some of the texts may not be accurate. " +\ + "Try to conclude what happens in the video, " +\ + "then answer my question based on your conclusion.\n" +\ + "