VatsalPatel18 committed
Commit 8381e8e · verified · 1 Parent(s): 70884da

Upload 8 files

scripts/.ipynb_checkpoints/PlipDataProcess-checkpoint.py ADDED
@@ -0,0 +1,56 @@
+ import os
+ import random
+ import torch
+ from PIL import Image
+ from concurrent.futures import ThreadPoolExecutor
+
+ class PlipDataProcess(torch.utils.data.Dataset):
+     def __init__(self, root_dir, files, df, img_processor=None, num_tiles_per_patient=128, max_workers=64, save_dir='processed_tile_data'):
+         self.root_dir = root_dir
+         self.files = files
+         self.df = df
+         self.img_processor = img_processor
+         self.num_tiles_per_patient = num_tiles_per_patient
+         self.max_workers = max_workers
+         self.save_dir = save_dir
+         if not os.path.exists(self.save_dir):
+             os.makedirs(self.save_dir)
+
+     def __len__(self):
+         return len(self.files)
+
+     def load_and_process_image(self, tile_path):
+         image = Image.open(tile_path)
+         return self.img_processor.preprocess(image)['pixel_values']
+
+     def save_individual_tile_data(self, tile_data, file_data, file_name, tile_name):
+         save_path = os.path.join(self.save_dir, file_name, f"{tile_name}.pt")
+         os.makedirs(os.path.dirname(save_path), exist_ok=True)
+         torch.save({'tile_data': tile_data, 'file_data': file_data}, save_path)
+
+     def __getitem__(self, idx):
+         file = self.files[idx]
+         tiles_path = os.path.join(self.root_dir, file)
+         tiles = [tile for tile in os.listdir(tiles_path) if tile != '.ipynb_checkpoints']
+         selected_tiles = random.sample(tiles, min(self.num_tiles_per_patient, len(tiles)))
+
+         try:
+             file_data = torch.tensor(self.df.loc[f'{file}-01'].values, dtype=torch.float32)
+         except KeyError:
+             # If the file is not found in the dataframe, create a tensor of zeros
+             # Shape is inferred from the other rows in the dataframe
+             num_features = self.df.shape[1]
+             file_data = torch.zeros(num_features, dtype=torch.float32)
+
+         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+             for tile_name in selected_tiles:
+                 tile_path = os.path.join(tiles_path, tile_name)
+                 executor.submit(self.process_and_save_tile, tile_path, file_data, file, tile_name)
+
+         return idx
+
+     def process_and_save_tile(self, tile_path, file_data, file_name, tile_name):
+         tile_data = self.load_and_process_image(tile_path)
+         self.save_individual_tile_data(tile_data, file_data, file_name, tile_name)
scripts/PlipDataProcess.py ADDED
@@ -0,0 +1,56 @@
+ import os
+ import random
+ import torch
+ from PIL import Image
+ from concurrent.futures import ThreadPoolExecutor
+
+ class PlipDataProcess(torch.utils.data.Dataset):
+     def __init__(self, root_dir, files, df, img_processor=None, num_tiles_per_patient=128, max_workers=64, save_dir='processed_tile_data'):
+         self.root_dir = root_dir
+         self.files = files
+         self.df = df
+         self.img_processor = img_processor
+         self.num_tiles_per_patient = num_tiles_per_patient
+         self.max_workers = max_workers
+         self.save_dir = save_dir
+         if not os.path.exists(self.save_dir):
+             os.makedirs(self.save_dir)
+
+     def __len__(self):
+         return len(self.files)
+
+     def load_and_process_image(self, tile_path):
+         image = Image.open(tile_path)
+         return self.img_processor.preprocess(image)['pixel_values']
+
+     def save_individual_tile_data(self, tile_data, file_data, file_name, tile_name):
+         save_path = os.path.join(self.save_dir, file_name, f"{tile_name}.pt")
+         os.makedirs(os.path.dirname(save_path), exist_ok=True)
+         torch.save({'tile_data': tile_data, 'file_data': file_data}, save_path)
+
+     def __getitem__(self, idx):
+         file = self.files[idx]
+         tiles_path = os.path.join(self.root_dir, file)
+         tiles = [tile for tile in os.listdir(tiles_path) if tile != '.ipynb_checkpoints']
+         selected_tiles = random.sample(tiles, min(self.num_tiles_per_patient, len(tiles)))
+
+         try:
+             file_data = torch.tensor(self.df.loc[f'{file}-01'].values, dtype=torch.float32)
+         except KeyError:
+             # If the file is not found in the dataframe, create a tensor of zeros
+             # Shape is inferred from the other rows in the dataframe
+             num_features = self.df.shape[1]
+             file_data = torch.zeros(num_features, dtype=torch.float32)
+
+         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+             for tile_name in selected_tiles:
+                 tile_path = os.path.join(tiles_path, tile_name)
+                 executor.submit(self.process_and_save_tile, tile_path, file_data, file, tile_name)
+
+         return idx
+
+     def process_and_save_tile(self, tile_path, file_data, file_name, tile_name):
+         tile_data = self.load_and_process_image(tile_path)
+         self.save_individual_tile_data(tile_data, file_data, file_name, tile_name)
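A minimal usage sketch of this dataset class, assuming the repository root is importable and that the checkpoint name, file names, and dataframe layout are illustrative placeholders rather than part of the commit. Each __getitem__ call preprocesses up to num_tiles_per_patient tiles for one patient and writes them to save_dir, so iterating the dataset once materializes the tensors.

# Hedged sketch -- paths, the PLIP image-processor checkpoint, and the dataframe are assumptions.
import pandas as pd
from torch.utils.data import DataLoader
from transformers import CLIPImageProcessor
from scripts.PlipDataProcess import PlipDataProcess

img_processor = CLIPImageProcessor.from_pretrained("vinid/plip")      # assumed PLIP image processor
df = pd.read_csv("genomic_scores.csv", index_col=0)                   # hypothetical table indexed like "<patient>-01"
patients = ["TCGA-XX-0001", "TCGA-XX-0002"]                           # hypothetical tile directories under root_dir

dataset = PlipDataProcess(root_dir="tiles_1024", files=patients, df=df,
                          img_processor=img_processor, save_dir="processed_tile_data")
for _ in DataLoader(dataset, batch_size=1, num_workers=0):
    pass  # each item preprocesses and saves its tiles as .pt files under processed_tile_data/<patient>/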
scripts/__pycache__/slide_processor_parallel.cpython-310.pyc ADDED
Binary file (6.34 kB).
 
scripts/genomic_plip_model.py ADDED
@@ -0,0 +1,17 @@
+ import torch
+ from transformers import CLIPVisionModel
+
+ class GenomicPLIPModel(torch.nn.Module):
+     def __init__(self, original_model):
+         super(GenomicPLIPModel, self).__init__()
+         self.vision_model = original_model.vision_model
+         self.vision_projection = torch.nn.Linear(768, 512)
+         self.fc_layer = torch.nn.Linear(4, 512)  # Fully connected layer for the 4D vector
+
+     def forward(self, pixel_values, score_vector):
+         vision_output = self.vision_model(pixel_values)
+         pooled_output = vision_output.pooler_output
+         vision_features = self.vision_projection(pooled_output)
+         score_features = self.fc_layer(score_vector)
+
+         return vision_features, score_features
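A short, hedged sketch of how this wrapper could be instantiated; the base checkpoint and batch shapes are assumptions, while the 768-to-512 projection comes from the module definition above.

# Hedged sketch -- the checkpoint name and tensor shapes are illustrative assumptions.
import torch
from transformers import CLIPVisionModel
from scripts.genomic_plip_model import GenomicPLIPModel

base = CLIPVisionModel.from_pretrained("vinid/plip")   # assumed PLIP vision backbone
model = GenomicPLIPModel(base)

pixel_values = torch.randn(2, 3, 224, 224)             # two preprocessed tiles
score_vector = torch.randn(2, 4)                       # the 4-dimensional genomic score vector
vision_features, score_features = model(pixel_values, score_vector)
print(vision_features.shape, score_features.shape)     # expected: torch.Size([2, 512]) torch.Size([2, 512])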
scripts/slide_processor.py ADDED
@@ -0,0 +1,157 @@
+ import numpy as np
+ import tensorflow as tf
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import os
+ import openslide
+ from PIL import Image
+ from openslide import OpenSlideError
+ from openslide.deepzoom import DeepZoomGenerator
+ import math
+ import random
+ from pyspark.ml.linalg import Vectors
+ import pyspark.sql.functions as F
+ from scipy.ndimage import binary_fill_holes  # scipy.ndimage.morphology was removed in recent SciPy releases
+ from skimage.color import rgb2gray
+ from skimage.feature import canny
+ from skimage.morphology import binary_closing, binary_dilation, disk
+ from concurrent.futures import ProcessPoolExecutor
+ import tqdm
+
+ class SlideProcessor:
+     def __init__(self, tile_size=1024, overlap=0, tissue_threshold=0.65, max_workers=30):
+         self.tile_size = tile_size
+         self.overlap = overlap
+         self.tissue_threshold = tissue_threshold
+         self.max_workers = max_workers
+
+     def optical_density(self, tile):
+         tile = tile.astype(np.float64)
+         od = -np.log((tile + 1) / 240)
+         return od
+
+     def keep_tile(self, tile, tissue_threshold=None):
+         if tissue_threshold is None:
+             tissue_threshold = self.tissue_threshold
+
+         if tile.shape[0:2] == (self.tile_size, self.tile_size):
+             tile_orig = tile
+             tile = rgb2gray(tile)
+             tile = 1 - tile
+             tile = canny(tile)
+             tile = binary_closing(tile, disk(10))
+             tile = binary_dilation(tile, disk(10))
+             tile = binary_fill_holes(tile)
+             percentage = tile.mean()
+
+             check1 = percentage >= tissue_threshold
+
+             tile = self.optical_density(tile_orig)
+             beta = 0.15
+             tile = np.min(tile, axis=2) >= beta
+             tile = binary_closing(tile, disk(2))
+             tile = binary_dilation(tile, disk(2))
+             tile = binary_fill_holes(tile)
+             percentage = tile.mean()
+
+             check2 = percentage >= tissue_threshold
+
+             return check1 and check2
+         else:
+             return False
+
+     def filter_tiles(self, tile_indices, generator):
+         filtered_tiles = []
+         for i in range(len(tile_indices)):
+             tile_size, overlap, zoom_level, col, row = tile_indices[i]
+             tile = np.asarray(generator.get_tile(zoom_level, (col, row)))
+             if self.keep_tile(tile, self.tissue_threshold):
+                 filtered_tiles.append((col, row))
+         return filtered_tiles
+
+     def get_tiles(self, samples, tile_indices, generator):
+         tiles = []
+         for i in samples:
+             tile_size, overlap, zoom_level, col, row = tile_indices[i]
+             tile = np.asarray(generator.get_tile(zoom_level, (col, row)))
+             tiles.append((i, tile))
+         return tiles
+
+     def save_tiles(self, sample_tiles, slide_num, loc='pDataset/rest'):
+         for sample in sample_tiles:
+             i, tile = sample
+             im = Image.fromarray(tile)
+             fname = f"{slide_num}_{i}"
+             file_path = os.path.join(loc, f"{fname}.jpeg")
+             im.save(file_path)
+
+     def get_save_tiles(self, samples, tile_indices, slide_num, generator, file, loc=None):
+         if loc is None:
+             loc = f'/home/gp7/ml_pni/Dataset/tiles_1024/{file}'
+
+         for i, cord in enumerate(samples):
+             x, y = cord
+             tile_size, overlap, zoom_level, col, row = tile_indices[i]
+             tile = np.asarray(generator.get_tile(zoom_level, (x, y)))
+             im = Image.fromarray(tile)
+             fname = f"{slide_num}_{x}_{y}"
+             file_path = os.path.join(loc, f"{fname}.jpeg")
+             im.save(file_path)
+
+     def process_one_slide(self, file, base_dir='HNSC_DS', output_dir='/home/gp7/ml_pni/Dataset/tiles_1024'):
+         f2p = os.path.join(base_dir, f'{file}.svs')
+
+         if not os.path.exists(output_dir):
+             os.makedirs(output_dir)
+
+         img1 = openslide.open_slide(f2p)
+         generator = DeepZoomGenerator(img1, tile_size=self.tile_size, overlap=self.overlap, limit_bounds=True)
+         highest_zoom_level = generator.level_count - 1
+
+         try:
+             mag = int(img1.properties[openslide.PROPERTY_NAME_OBJECTIVE_POWER])
+             offset = math.floor((mag / 20) / 2)
+             level = highest_zoom_level - offset
+         except (ValueError, KeyError):
+             level = highest_zoom_level
+
+         zoom_level = level
+         cols, rows = generator.level_tiles[zoom_level]
+         tile_indices = [(self.tile_size, self.overlap, zoom_level, col, row) for col in range(cols) for row in range(rows)]
+
+         filter_sname = os.path.join(output_dir, f'{file}_info.npy')
+
+         if os.path.exists(filter_sname):
+             try:
+                 filtered_tiles = np.load(filter_sname)
+                 print(f"Found existing filtered tiles for {file}, skipping tile filtering.")
+             except Exception:
+                 print(f"Error reading {filter_sname}, re-filtering tiles.")
+                 filtered_tiles = self.filter_tiles(tile_indices, generator)
+                 np.save(filter_sname, filtered_tiles)
+         else:
+             print(f"Didn't find existing filtered tiles for {file}, filtering tiles.")
+             filtered_tiles = self.filter_tiles(tile_indices, generator)
+             np.save(filter_sname, filtered_tiles)
+
+         directory = os.path.join(output_dir, file)
+         if not os.path.exists(directory):
+             os.makedirs(directory)
+
+         existing_files_count = len([f for f in os.listdir(directory) if f.endswith('.jpeg')])
+
+         filtered_tiles_count = len(filtered_tiles)
+         threshold = 5
+         if abs(existing_files_count - filtered_tiles_count) <= threshold:
+             print(f"Found approximately the same number of files as filtered tiles for {file}, skipping tile saving.")
+         else:
+             print('Now going to save tiles')
+             # Pass the slide id and target directory explicitly (matches the parallel implementation)
+             self.get_save_tiles(filtered_tiles, tile_indices, file, generator, file, loc=directory)
+
+         return file
+
+     def parallel_process(self, files, base_dir='HNSC_DS', output_dir='/home/gp7/ml_pni/Dataset/tiles_1024'):
+         with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
+             results = list(tqdm.tqdm(executor.map(self.process_one_slide, files, [base_dir]*len(files), [output_dir]*len(files)), total=len(files)))
+         return results
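A hedged usage sketch for this serial processor; the slide ID and directories are placeholders. parallel_process here expects an explicit list of slide IDs whose .svs files live under base_dir.

# Hedged sketch -- slide IDs and directories are placeholders.
from scripts.slide_processor import SlideProcessor

processor = SlideProcessor(tile_size=1024, overlap=0, tissue_threshold=0.65, max_workers=4)
slide_ids = ["TCGA-XX-0001-01Z-00-DX1"]            # hypothetical IDs; <base_dir>/<id>.svs must exist
results = processor.parallel_process(slide_ids, base_dir="HNSC_DS", output_dir="tiles_1024")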
scripts/slide_processor_parallel.py ADDED
@@ -0,0 +1,160 @@
+ import numpy as np
+ from concurrent.futures import ThreadPoolExecutor
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import os
+ import openslide
+ from PIL import Image
+ from openslide import OpenSlideError
+ from openslide.deepzoom import DeepZoomGenerator
+ import math
+ import random
+ from scipy.ndimage import binary_fill_holes  # scipy.ndimage.morphology was removed in recent SciPy releases
+ from skimage.color import rgb2gray
+ from skimage.feature import canny
+ from skimage.morphology import binary_closing, binary_dilation, disk
+ from concurrent.futures import ProcessPoolExecutor
+ import tqdm
+
+ class SlideProcessor:
+     def __init__(self, tile_size=1024, overlap=0, tissue_threshold=0.65, max_workers=30):
+         self.tile_size = tile_size
+         self.overlap = overlap
+         self.tissue_threshold = tissue_threshold
+         self.max_workers = max_workers
+
+     def optical_density(self, tile):
+         tile = tile.astype(np.float64)
+         od = -np.log((tile + 1) / 240)
+         return od
+
+     def keep_tile(self, tile, tissue_threshold=None):
+         if tissue_threshold is None:
+             tissue_threshold = self.tissue_threshold
+
+         if tile.shape[0:2] == (self.tile_size, self.tile_size):
+             tile_orig = tile
+             tile = rgb2gray(tile)
+             tile = 1 - tile
+             tile = canny(tile)
+             tile = binary_closing(tile, disk(10))
+             tile = binary_dilation(tile, disk(10))
+             tile = binary_fill_holes(tile)
+             percentage = tile.mean()
+
+             check1 = percentage >= tissue_threshold
+
+             tile = self.optical_density(tile_orig)
+             beta = 0.15
+             tile = np.min(tile, axis=2) >= beta
+             tile = binary_closing(tile, disk(2))
+             tile = binary_dilation(tile, disk(2))
+             tile = binary_fill_holes(tile)
+             percentage = tile.mean()
+
+             check2 = percentage >= tissue_threshold
+
+             return check1 and check2
+         else:
+             return False
+
+     def filter_tiles(self, tile_indices, generator):
+         def process_tile(tile_index):
+             tile_size, overlap, zoom_level, col, row = tile_index
+             tile = np.asarray(generator.get_tile(zoom_level, (col, row)))
+             if self.keep_tile(tile, self.tissue_threshold):
+                 return col, row
+             return None
+
+         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+             results = executor.map(process_tile, tile_indices)
+
+         # Filter out None results and return the list of tiles to keep
+         return [result for result in results if result is not None]
+
+     def get_tiles(self, samples, tile_indices, generator):
+         tiles = []
+         for i in samples:
+             tile_size, overlap, zoom_level, col, row = tile_indices[i]
+             tile = np.asarray(generator.get_tile(zoom_level, (col, row)))
+             tiles.append((i, tile))
+         return tiles
+
+     def save_tiles(self, sample_tiles, slide_num, loc='pDataset/rest'):
+         for sample in sample_tiles:
+             i, tile = sample
+             im = Image.fromarray(tile)
+             fname = f"{slide_num}_{i}"
+             file_path = os.path.join(loc, f"{fname}.jpeg")
+             im.save(file_path)
+
+     def get_save_tiles(self, samples, tile_indices, slide_num, generator, file, loc):
+
+         def save_tile(cord):
+             x, y = cord
+             tile_index = next((ti for ti in tile_indices if ti[3] == x and ti[4] == y), None)
+             if tile_index:
+                 tile_size, overlap, zoom_level, col, row = tile_index
+                 tile = np.asarray(generator.get_tile(zoom_level, (x, y)))
+                 im = Image.fromarray(tile)
+                 fname = f"{slide_num}_{x}_{y}"
+                 file_path = os.path.join(loc, f"{fname}.jpeg")
+                 im.save(file_path)
+
+         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+             executor.map(save_tile, samples)
+
+     def process_one_slide(self, file_loc, output_dir=None):
+         f2p = file_loc
+
+         if not os.path.exists(output_dir):
+             os.makedirs(output_dir)
+
+         img1 = openslide.open_slide(f2p)
+         generator = DeepZoomGenerator(img1, tile_size=self.tile_size, overlap=self.overlap, limit_bounds=True)
+         highest_zoom_level = generator.level_count - 1
+
+         try:
+             mag = int(img1.properties[openslide.PROPERTY_NAME_OBJECTIVE_POWER])
+             offset = math.floor((mag / 20) / 2)
+             level = highest_zoom_level - offset
+         except (ValueError, KeyError):
+             level = highest_zoom_level
+
+         zoom_level = level
+         cols, rows = generator.level_tiles[zoom_level]
+         tile_indices = [(self.tile_size, self.overlap, zoom_level, col, row) for col in range(cols) for row in range(rows)]
+
+         filtered_tiles = self.filter_tiles(tile_indices, generator)
+
+         if file_loc.endswith('.svs'):
+             file = file_loc[-16:-4]
+             print(file)
+
+         directory = os.path.join(output_dir, file)
+         if not os.path.exists(directory):
+             os.makedirs(directory)
+
+         existing_files_count = len([f for f in os.listdir(directory) if f.endswith('.jpeg')])
+
+         filtered_tiles_count = len(filtered_tiles)
+         threshold = 5
+         if abs(existing_files_count - filtered_tiles_count) <= threshold:
+             print(f"Found approximately the same number of files as filtered tiles for {file}, skipping tile saving.")
+         else:
+             print('Now going to save tiles')
+             self.get_save_tiles(filtered_tiles, tile_indices, file, generator, file, directory)
+
+         return file
+
+     def parallel_process(self, base_dir='HNSC_DS', output_dir=None):
+         # List all .svs files in the base directory
+         files = [os.path.join(base_dir, f) for f in os.listdir(base_dir) if f.endswith('.svs')]
+
+         with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
+             # base_dir is already folded into each path, so only output_dir is repeated per file
+             results = list(tqdm.tqdm(executor.map(self.process_one_slide, files, [output_dir]*len(files)), total=len(files)))
+
+         return results
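And a hedged sketch for the parallel variant, which discovers the .svs files itself; directories are placeholders.

# Hedged sketch -- directories are placeholders; every .svs under base_dir is processed.
from scripts.slide_processor_parallel import SlideProcessor

processor = SlideProcessor(tile_size=1024, max_workers=8)
results = processor.parallel_process(base_dir="HNSC_DS", output_dir="tiles_1024")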
scripts/tile_classifier.py ADDED
@@ -0,0 +1,27 @@
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset
+
+ class SimpleNN(nn.Module):
+     def __init__(self):
+         super(SimpleNN, self).__init__()
+         self.fc1 = nn.Linear(512, 512)
+         self.fc2 = nn.Linear(512, 256)
+         self.fc3 = nn.Linear(256, 1)
+
+     def forward(self, x):
+         x = torch.relu(self.fc1(x))
+         x = torch.relu(self.fc2(x))
+         x = torch.sigmoid(self.fc3(x))
+         return x
+
+ class CustomDataset(Dataset):
+     def __init__(self, X, Y):
+         self.X = torch.tensor(X, dtype=torch.float32)
+         self.Y = torch.tensor(Y, dtype=torch.float32)
+
+     def __len__(self):
+         return len(self.X)
+
+     def __getitem__(self, index):
+         return self.X[index], self.Y[index]
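A minimal, hedged training sketch for this classifier; random arrays stand in for real 512-dimensional tile embeddings and binary labels, and the import path is an assumption.

# Hedged sketch -- random data replaces real tile features and labels.
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from scripts.tile_classifier import SimpleNN, CustomDataset

X = np.random.randn(256, 512).astype("float32")            # 512-d tile feature vectors
Y = np.random.randint(0, 2, size=256).astype("float32")    # binary tile labels

model = SimpleNN()
loader = DataLoader(CustomDataset(X, Y), batch_size=32, shuffle=True)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for x, y in loader:
    optimizer.zero_grad()
    loss = criterion(model(x).squeeze(1), y)
    loss.backward()
    optimizer.step()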
scripts/tile_file_dataloader.py ADDED
@@ -0,0 +1,25 @@
+ import os
+ import torch
+ from torch.utils.data import Dataset
+
+ class FlatTileDataset(Dataset):
+     def __init__(self, data_dir):
+         super().__init__()
+         self.data_dir = data_dir
+         # List all entries in data_dir that are files (not directories)
+         self.files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if os.path.isfile(os.path.join(data_dir, f))]
+
+     def __len__(self):
+         # Return the total number of files
+         return len(self.files)
+
+     def __getitem__(self, idx):
+         # Get the file path for the given index
+         file_path = self.files[idx]
+         # Load the data from the file
+         data = torch.load(file_path)
+         # Assuming the data file is a dictionary with 'tile_data' and 'file_data' keys
+         tile_data = torch.from_numpy(data['tile_data'][0])
+         file_data = data['file_data']
+         # Return the tile data and file data
+         return tile_data, file_data
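A hedged loading sketch; the directory is a placeholder and is assumed to contain the .pt files written by PlipDataProcess (for example one patient's subfolder of processed_tile_data).

# Hedged sketch -- the directory is a placeholder for a folder of .pt tile files.
from torch.utils.data import DataLoader
from scripts.tile_file_dataloader import FlatTileDataset

dataset = FlatTileDataset("processed_tile_data/TCGA-XX-0001")   # hypothetical patient folder
loader = DataLoader(dataset, batch_size=16, shuffle=True)
tile_batch, genomic_batch = next(iter(loader))
print(tile_batch.shape, genomic_batch.shape)   # e.g. (16, 3, 224, 224) and (16, num_features)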