import argparse
import gc
import json
import random

import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from utils import result_writer

SYSTEM_PROMPT_I2V = """
You are an expert in video captioning. You are given a structured video caption and you need to compose it into a caption that is more natural and fluent in English.

## Structured Input
{structured_input}

## Notes
1. If a field is empty, ignore it and do not mention it in the output.
2. Do not make any semantic changes to the original fields. Be sure to follow the original meaning.
3. If the action field is not empty, remove information in the action field that is not related to the temporal action (such as clothing, background, and environment details) to obtain a pure action field.

## Output Principles and Order
1. First, remove static information in the action field that is not related to the temporal action, such as background or environment information.
2. Second, describe each subject with its pure action and expression if these fields exist.

## Output
Please directly output the final composed caption without any additional information.
"""

SYSTEM_PROMPT_T2V = """
You are an expert in video captioning. You are given a structured video caption and you need to compose it into a caption that is more natural and fluent in English.

## Structured Input
{structured_input}

## Notes
1. According to the action field, replace each subject's name field with the appropriate subject pronoun in the action.
2. If a field is empty, ignore it and do not mention it in the output.
3. Do not make any semantic changes to the original fields. Be sure to follow the original meaning.

## Output Principles and Order
1. First, declare the shot_type, then the shot_angle and shot_position fields.
2. Second, if the action field is not empty, remove information in it that is not related to the temporal action, such as background or environment information.
3. Third, describe each subject with its pure action, appearance, expression, and position if these fields exist.
4. Finally, declare the environment and lighting if the environment and lighting fields are not empty.

## Output
Please directly output the final composed caption without any additional information.
"""

SHOT_TYPE_LIST = [
    'close-up shot',
    'extreme close-up shot',
    'medium shot',
    'long shot',
    'full shot',
]


class StructuralCaptionDataset(torch.utils.data.Dataset):
    def __init__(self, input_csv, model_path, task):
        self.meta = pd.read_csv(input_csv)
        # Take the task explicitly instead of reading the global `args`, so the dataset
        # can also be constructed outside the __main__ block.
        self.task = task
        self.system_prompt = SYSTEM_PROMPT_T2V if self.task == 't2v' else SYSTEM_PROMPT_I2V
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    def __len__(self):
        return len(self.meta)
    def __getitem__(self, index):
        row = self.meta.iloc[index]
        real_index = self.meta.index[index]

        struct_caption = json.loads(row["structural_caption"])

        camera_movement = struct_caption.get('camera_motion', '')
        if camera_movement != '':
            camera_movement += '.'
        camera_movement = camera_movement.capitalize()
        
        # Only route the caption through the LLM when at least one usable subject remains
        # after cleaning; otherwise return a placeholder and rely on the camera movement alone.
        fusion_by_llm = False
        cleaned_struct_caption = self.clean_struct_caption(struct_caption, self.task)
        if cleaned_struct_caption.get('num_subjects', 0) > 0:
            new_struct_caption = json.dumps(cleaned_struct_caption, indent=4, ensure_ascii=False)
            conversation = [
                {
                    "role": "system",
                    "content": self.system_prompt.format(structured_input=new_struct_caption),
                },
            ]
            text = self.tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=True
            )
            fusion_by_llm = True
        else:
            text = '-'
        return real_index, fusion_by_llm, text, '-', camera_movement
    
    def clean_struct_caption(self, struct_caption, task):
        raw_subjects = struct_caption.get('subjects', [])
        subjects = []
        for subject in raw_subjects:
            subject_type = subject.get("TYPES", {}).get('type', '')
            subject_sub_type = subject.get("TYPES", {}).get('sub_type', '')
            if subject_type not in ["Human", "Animal"]:
                subject['expression'] = ''
            if subject_type == 'Human' and subject_sub_type == 'Accessory':
                subject['expression'] = ''
            if subject_sub_type != '':
                subject['name'] = subject_sub_type
            if 'TYPES' in subject:
                del subject['TYPES']
            if 'is_main_subject' in subject:
                del subject['is_main_subject']
            subjects.append(subject)

        to_del_subject_ids = []
        for idx, subject in enumerate(subjects):
            action = subject.get('action', '').strip()
            subject['action'] = action
            # Randomly drop appearance/position (roughly 10% of the time each) as light augmentation.
            if random.random() > 0.9 and 'appearance' in subject:
                del subject['appearance']
            if random.random() > 0.9 and 'position' in subject:
                del subject['position']
            if task == 'i2v':
                # just keep name and action, expression in subjects
                dropped_keys = ['appearance', 'position']
                for key in dropped_keys:
                    if key in subject:
                        del subject[key]
                if subject['action'] == '' and ('expression' not in subject or subject['expression'] == ''):
                    to_del_subject_ids.append(idx)
        
        # Delete flagged subjects in reverse index order so earlier indices stay valid.
        for idx in sorted(to_del_subject_ids, reverse=True):
            del subjects[idx]

        # Normalize shot_type (e.g. 'medium_shot' -> 'medium shot') and keep it only if it is
        # one of the supported values; store the normalized form rather than the raw one.
        shot_type = struct_caption.get('shot_type', '').replace('_', ' ')
        struct_caption['shot_type'] = shot_type if shot_type in SHOT_TYPE_LIST else ''
        
        new_struct_caption = {
            'num_subjects': len(subjects),
            'subjects': subjects,
            'shot_type': struct_caption.get('shot_type', ''),
            'shot_angle': struct_caption.get('shot_angle', ''),
            'shot_position': struct_caption.get('shot_position', ''),
            'environment': struct_caption.get('environment', ''),
            'lighting': struct_caption.get('lighting', ''),
        }

        # For t2v, randomly drop the lighting field (roughly 10% of the time) as light augmentation.
        if task == 't2v' and random.random() > 0.9:
            del new_struct_caption['lighting']

        if task == 'i2v':
            drop_keys = ['environment', 'lighting', 'shot_type', 'shot_angle', 'shot_position']
            for drop_key in drop_keys:
                del new_struct_caption[drop_key]
        return new_struct_caption

def custom_collate_fn(batch):
    real_indices, fusion_by_llm, texts, original_texts, camera_movements = zip(*batch)
    return list(real_indices), list(fusion_by_llm), list(texts), list(original_texts), list(camera_movements)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Caption Fusion by LLM")
    parser.add_argument("--input_csv", default="./examples/test_result.csv")
    parser.add_argument("--out_csv", default="./examples/test_result_caption.csv")
    parser.add_argument("--bs", type=int, default=4)
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--model_path", required=True, type=str, help="LLM model path")
    parser.add_argument("--task", default='t2v', help="t2v or i2v")
    
    args = parser.parse_args()

    sampling_params = SamplingParams(
        temperature=0.1,
        max_tokens=512,
        stop=['\n\n']
    )
    # Example model: Qwen2.5-32B-Instruct (pass its local path via --model_path).

    llm = LLM(
        model=args.model_path,
        gpu_memory_utilization=0.9,
        max_model_len=4096,
        tensor_parallel_size=args.tp,
    )

    dataset = StructuralCaptionDataset(input_csv=args.input_csv, model_path=args.model_path, task=args.task)

    dataloader = DataLoader(
        dataset,
        batch_size=args.bs,
        num_workers=8,
        collate_fn=custom_collate_fn,
        shuffle=False,
        drop_last=False,
    )

    indices_list = []
    result_list = []
    for indices, fusion_by_llms, texts, original_texts, camera_movements in tqdm(dataloader):
        # Split each batch into samples that need LLM fusion and samples whose caption is
        # assembled directly from the original text and the camera movement.
        llm_indices, llm_texts, llm_original_texts, llm_camera_movements = [], [], [], []
        for idx, fusion_by_llm, text, original_text, camera_movement in zip(indices, fusion_by_llms, texts, original_texts, camera_movements):
            if fusion_by_llm:
                llm_indices.append(idx)
                llm_texts.append(text)
                llm_original_texts.append(original_text)
                llm_camera_movements.append(camera_movement)
            else:
                indices_list.append(idx)
                caption = original_text + " " + camera_movement
                result_list.append(caption)
        if len(llm_texts) > 0:
            try:
                outputs = llm.generate(llm_texts, sampling_params, use_tqdm=False)
                results = [output.outputs[0].text.strip() for output in outputs]
                indices_list.extend(llm_indices)
            except Exception as e:
                # Fall back to the original (placeholder) texts if generation fails.
                print(f"Error at {llm_indices}: {str(e)}")
                indices_list.extend(llm_indices)
                results = llm_original_texts
            
            for result, camera_movement in zip(results, llm_camera_movements):
                # concat camera movement to fusion_caption
                llm_caption = result + " " + camera_movement
                result_list.append(llm_caption)
    torch.cuda.empty_cache()
    gc.collect()
    meta_new = result_writer(indices_list, result_list, dataset.meta, column=[f"{args.task}_fusion_caption"])
    meta_new.to_csv(args.out_csv, index=False)
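
# Example invocation (illustrative; the script filename below is assumed — use the actual one,
# and point --model_path at a local chat LLM such as a Qwen2.5-Instruct checkpoint):
#
#   python caption_fusion_by_llm.py \
#       --input_csv ./examples/test_result.csv \
#       --out_csv ./examples/test_result_caption.csv \
#       --model_path /path/to/Qwen2.5-32B-Instruct \
#       --task t2v --bs 4 --tp 1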