File size: 2,154 Bytes
12f2e48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import argparse
import json
import os
import random
import re
from typing import List

from datasets import Dataset

from multi_token.constants import ROLE_ASSISTANT, ROLE_USER


# Modalities a conversation can be re-worded for. A modality's position in
# this list is the column index used in every REPLACEMENTS value list.
TYPES = [
    "audio",
    "image",
    "text",
]

# Image-specific vocabulary and its substitute for each TYPES column, in the
# order (audio, image, text). The "image" column maps every word to itself,
# so picking the image modality leaves the wording unchanged.
REPLACEMENTS = {
    "image": ["audio", "image", "document"],
    "picture": ["audio file", "picture", "text snippet"],
    "photo": ["sound", "photo", "text"],
    "visual": ["audio", "visual", "textual"],
    "see": ["hear", "see", "read"],
    "look": ["sound", "look", "read"],
    "visible": ["audible", "visible", "readable"],
}

# Sentinel swapped in for the literal "<image>" tag while words are being
# rewritten, so the tag itself is never altered; it is later turned into
# "<imagebind>".
TEMP_TOKEN = "<<<TEMP-TOKEN>>>"


# A "word" for substitution purposes: a maximal run of ASCII letters.
_WORD_RE = re.compile(r"[A-Za-z]+")

# Inflections re-attached after substitution ("images" -> "audios",
# "seeing" -> "hearing"). Tried in order; "" matches the bare word.
_WORD_SUFFIXES = ("", "s", "ing")


def _swap_modality_words(text: str, type_idx: int, replacements: dict) -> str:
    """Rewrite whole-word image vocabulary in *text* for modality *type_idx*.

    A word is replaced only when it equals a key of *replacements*, or a key
    followed by one suffix from ``_WORD_SUFFIXES``; the suffix is kept on the
    substitute. Words that merely contain a key (e.g. "seen", "photograph",
    "visually") are left untouched instead of being mangled into "hearn",
    "soundgraph" or "audioly". Matching is case-sensitive.

    Args:
        text: message text to rewrite.
        type_idx: column to read from each value list in *replacements*.
        replacements: maps a word to its per-modality substitutes.

    Returns:
        The rewritten text.
    """

    def _swap(match):
        word = match.group(0)
        for suffix in _WORD_SUFFIXES:
            if not word.endswith(suffix):
                continue
            stem = word[: len(word) - len(suffix)]
            options = replacements.get(stem)
            if options is not None:
                return options[type_idx] + suffix
        return word

    return _WORD_RE.sub(_swap, text)


def _convert_convo(convo) -> List:
    """Convert a LLaVA conversation into role/content chat messages.

    One modality is drawn at random from ``TYPES`` per conversation and is
    applied to every turn. Image vocabulary listed in ``REPLACEMENTS`` is
    reworded for that modality, and each ``<image>`` tag becomes
    ``<imagebind>``.

    Args:
        convo: iterable of turns, each a dict with ``"from"`` (``"human"`` or
            ``"gpt"``) and ``"value"`` (the message text).

    Returns:
        List of ``{"role": ..., "content": ...}`` dicts in the input order.

    Raises:
        KeyError: if a turn's ``"from"`` is neither ``"human"`` nor ``"gpt"``.
    """
    # Uniform column index into TYPES (same draw as picking a TYPES entry).
    type_idx = random.randrange(len(TYPES))
    role_map = {"gpt": ROLE_ASSISTANT, "human": ROLE_USER}
    msgs = []
    for m in convo:
        # Shield the tag so the word "image" inside it is not rewritten.
        content = m["value"].replace("<image>", TEMP_TOKEN)
        content = _swap_modality_words(content, type_idx, REPLACEMENTS)
        content = content.replace(TEMP_TOKEN, "<imagebind>")
        msgs.append({"role": role_map[m["from"]], "content": content})
    return msgs


def main(args):
    """Convert LLaVA-format JSON files into an ImageBind chat dataset on disk.

    All files in ``args.llava_json`` are loaded and their rows concatenated.
    A row becomes one example when its image exists under
    ``args.image_folder``. Each example has:

    - ``id``: the row id as a string;
    - ``imagebinds``: a one-element list holding the image path;
    - ``messages``: the converted conversation.

    Rows without an ``"image"`` entry, or whose file is missing, are reported
    and skipped. The dataset is built with ``Dataset.from_generator`` and
    written to ``args.output_folder`` via ``save_to_disk``.

    Args:
        args: parsed CLI namespace with ``llava_json`` (list of JSON paths),
            ``image_folder``, ``output_folder`` and ``num_proc``.
    """
    rows = []
    for json_fn in args.llava_json:
        # Read as UTF-8 explicitly rather than relying on the platform default.
        with open(json_fn, encoding="utf-8") as f:
            rows.extend(json.load(f))

    def gen(rows):
        for row in rows:
            img_path = row.get("image")
            if not img_path:
                # Nothing to embed (e.g. a text-only conversation). An empty
                # path would otherwise resolve to image_folder itself, which
                # exists and would slip past the existence check below.
                print("Skipping", row.get("id"), "(no image)")
                continue
            fn = os.path.join(args.image_folder, img_path)
            if not os.path.exists(fn):
                print("Skipping", fn)
                continue
            yield {
                "id": str(row["id"]),
                "imagebinds": [fn],
                "messages": _convert_convo(row["conversations"]),
            }

    ds = Dataset.from_generator(gen, gen_kwargs={"rows": rows}, num_proc=args.num_proc)
    ds.save_to_disk(args.output_folder)


if __name__ == "__main__":
    # -i/-f/-o are required: without them main() fails later with an opaque
    # TypeError (iterating None / joining a None folder / saving to None).
    parser = argparse.ArgumentParser(
        description="Convert LLaVA JSON conversations into an ImageBind chat dataset."
    )
    parser.add_argument(
        "-i",
        "--llava_json",
        type=str,
        action="append",
        required=True,
        help="LLaVA-format JSON file; repeat the flag to merge several files.",
    )
    parser.add_argument(
        "-f",
        "--image_folder",
        type=str,
        required=True,
        help="Directory that the JSON rows' image paths are relative to.",
    )
    parser.add_argument(
        "-o",
        "--output_folder",
        type=str,
        required=True,
        help="Destination directory for Dataset.save_to_disk.",
    )
    parser.add_argument(
        "-n",
        "--num_proc",
        type=int,
        default=1,
        help="Number of processes passed to Dataset.from_generator.",
    )
    args = parser.parse_args()
    main(args)