File size: 1,666 Bytes
0a82b18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import pandas as pd
from pathlib import Path
import webdataset as wds



def convert_path_to_key(img_path: Path, root="data/database") -> str:
    """Convert an image path into a flat WebDataset sample key.

    The key is the path relative to *root*, with the final suffix
    (e.g. ``.jpg``) removed, directory separators flattened to ``_``,
    and any remaining ``.`` characters replaced by ``,`` (WebDataset
    treats dots in keys as extension separators).

    Args:
        img_path: Absolute or repo-relative path to the image file.
        root: Dataset root the key is computed relative to. Defaults to
            ``"data/database"`` for backward compatibility.

    Returns:
        The flattened key string, e.g. ``data/database/a/b.c/img.jpg``
        -> ``a_b,c_img``.

    Raises:
        ValueError: If *img_path* is not located under *root*
            (propagated from ``Path.relative_to``).
    """
    # Relative path from the dataset root, minus the image extension.
    relative = img_path.relative_to(root).with_suffix('')

    # Flatten directories, then neutralize leftover dots so WebDataset
    # does not split the key at them.
    flat = relative.as_posix().replace('/', '_')
    return flat.replace('.', ',')

    
def update_mapping_csv(original_csv, webdataset_dir, new_csv_path):
    """Annotate the index->path CSV with the shard holding each sample.

    Reads *original_csv* (must contain a ``local_path`` column), scans
    every ``*.tar`` shard in *webdataset_dir* to learn which shard holds
    each sample key, then writes the table — with added ``key`` and
    ``shard_path`` columns — to *new_csv_path*.

    Raises:
        ValueError: If any derived key is absent from every shard.
    """
    mapping = pd.read_csv(original_csv)

    # One pass over every shard: record which tar file owns each key.
    # A key seen in multiple shards keeps the last one encountered.
    key_to_shard = {}
    for shard_file in Path(webdataset_dir).glob("*.tar"):
        for sample in wds.WebDataset(str(shard_file), empty_check=False):
            key_to_shard[sample["__key__"]] = str(shard_file)

    # Derive each row's key from its local path, then look up its shard.
    mapping["key"] = mapping["local_path"].map(lambda p: convert_path_to_key(Path(p)))
    mapping["shard_path"] = mapping["key"].map(key_to_shard)

    # Fail loudly when any key could not be located in a shard.
    unmatched = mapping.loc[mapping["shard_path"].isna(), "key"]
    if not unmatched.empty:
        missing_keys = unmatched.tolist()
        raise ValueError(f"Missing shard paths for the following keys: {missing_keys[:10]}... (and possibly more)")

    mapping.to_csv(new_csv_path, index=False)



if __name__ == "__main__":
    # Entry point: paths are fixed relative to the repository root.
    src_csv = "faiss_index/faiss_index_to_local_path.csv"
    shard_dir = "data/webdataset_shards"
    out_csv = "faiss_index/faiss_index_webdataset.csv"
    update_mapping_csv(src_csv, shard_dir, out_csv)