sachin committed on
Commit
35352c6
1 Parent(s): 2e2cb86

Adding training data module

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. src/config.py +40 -0
  3. src/data.py +109 -0
  4. models.py → src/models.py +0 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .vscode/
src/config.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pydantic
2
+
3
+
4
class DataConfig(pydantic.BaseModel):
    """Settings describing which dataset is streamed and how much of it is used."""

    # Buffer size handed to the streaming dataset's shuffle.
    buffer_size: int = 1_000
    # Total number of examples taken from the shuffled stream.
    data_len: int = 100_000
    # Number of leading examples used for training; the rest form validation.
    train_len: int = 90_000

    # Names of the two known datasets.
    small_dataset: str = "laion/220k-gpt4vision-captions-from-livis"
    large_dataset: str = "laion/laion400m"
    # Dataset that is actually loaded; defaults to the small one.
    dataset: str = small_dataset
11
+
12
+
13
class ModelConfig(pydantic.BaseModel):
    """Model hyper-parameters (text/vision backbone names and projection sizes)."""

    # Backbone identifiers.
    text_model: str = "microsoft/xtremedistil-l6-h256-uncased"  # 51 mb
    vision_model: str = "edgenext_small"  # 20 mb

    # Projection / embedding sizes.
    projection_layers: int = 3
    embed_dim: int = 256
    transformer_embed_dim: int = 768

    max_len: int = 77  # maximum length of text in CLIP
    cls_type: bool = True

    # Whether the pretrained backbones are frozen
    # (presumably applied by the model code -- not visible here).
    freeze_vision_base: bool = False
    freeze_text_base: bool = False
23
+
24
+
25
class TrainerConfig(pydantic.BaseModel):
    """Top-level training configuration bundling model and data settings."""

    # Optimisation schedule.
    epochs: int = 20
    batch_size: int = 256
    learning_rate: float = 0.0005
    accumulate_grad_batches: int = 1

    # Loss-related knobs (presumably consumed by the training code -- TODO confirm).
    temperature: float = 1.0
    vision_freeze_layers: int = 2
    lambda_1: float = 1.0
    lambda_2: float = 1.0

    val_check_interval: int = 1_000

    run_openai_clip: bool = False

    # NOTE(review): pydantic v2 reserves the attribute name ``model_config`` for
    # model configuration and may reject it as a field name -- confirm which
    # pydantic version this project pins before upgrading.
    model_config: ModelConfig = ModelConfig()
    data_config: DataConfig = DataConfig()
src/data.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import multiprocessing as mp
3
+ from typing import Optional, Union
4
+
5
+ import datasets
6
+ from PIL import Image
7
+ import requests
8
+ import torch
9
+ from torch.utils.data import Dataset, DataLoader
10
+ from torchvision import transforms
11
+ from transformers import AutoTokenizer
12
+
13
+ from src import config
14
+
15
+
16
class Tokenizer:
    """Wrapper around a pretrained tokenizer that applies a fixed maximum length."""

    def __init__(self, model_name: str, max_len: int) -> None:
        """Load the pretrained tokenizer for ``model_name`` and store ``max_len``."""
        self.max_len = max_len
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def __call__(self, x: Union[str, list[str]]) -> dict[str, torch.LongTensor]:
        """Tokenize ``x`` with truncation and padding, returning PyTorch tensors."""
        encode_kwargs = {
            "max_length": self.max_len,
            "truncation": True,
            "padding": True,
            "return_tensors": "pt",
        }
        return self.tokenizer(x, **encode_kwargs)

    def decode(self, x: dict[str, torch.LongTensor]) -> list[str]:
        """Decode each row of ``x["input_ids"]`` back to text.

        Each row is cut to the number of positions flagged in its attention
        mask before decoding.
        """
        token_rows = x["input_ids"]
        row_lengths = x["attention_mask"].sum(axis=-1)
        decoded = []
        for row, n_tokens in zip(token_rows, row_lengths):
            decoded.append(self.tokenizer.decode(row[:n_tokens]))
        return decoded
31
+
32
+
33
def _get_image_and_caption(
    item: dict[str, str], timeout: float = 1
) -> Optional[tuple[Image.Image, str]]:
    """Download the image referenced by ``item`` and pair it with its caption.

    Args:
        item: Dataset row; must contain the keys ``"url"`` and ``"caption"``.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        ``(image, caption)`` with the image fully decoded and converted to RGB,
        or ``None`` when the download fails or the payload is not a usable image.
    """
    image_url = item["url"]
    caption = item["caption"]
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()  # Raise HTTPError for bad responses (4xx and 5xx)
        # Image.open only reads the header; convert() forces a full decode so
        # truncated/corrupt payloads fail here instead of later in the transform.
        # Converting to RGB also gives every image the same channel count, which
        # is required for the tensors to be stackable into one batch.
        with Image.open(io.BytesIO(response.content)) as raw_image:
            image = raw_image.convert("RGB")
    except (requests.RequestException, OSError, Image.DecompressionBombError):
        # URLs are untrusted: skip anything unreachable, undecodable or oversized.
        return None
    return image, caption
43
+
44
+
45
class CollateFn:
    """Collate function that downloads images and builds a (text, image) batch.

    Each incoming dataset row is resolved to an ``(image, caption)`` pair via
    ``_get_image_and_caption``; rows whose download fails are dropped.
    """

    def __init__(self, tokenizer: Tokenizer, transform: transforms.Compose):
        """Store the caption tokenizer and the per-image transform."""
        self.tokenizer = tokenizer
        self.transform = transform

    def __call__(
        self, batch: list[dict[str, str]]
    ) -> tuple[dict[str, torch.LongTensor], torch.FloatTensor]:
        """Turn raw dataset rows into tokenized captions and a stacked image tensor.

        Args:
            batch: Dataset rows, each holding the image URL and its caption.

        Returns:
            ``(tokenized_text, images)`` where ``images`` stacks the transformed
            image of every row that could be downloaded.

        Raises:
            ValueError: If no image in the batch could be downloaded.
        """
        fetched = [pair for pair in map(_get_image_and_caption, batch) if pair is not None]
        if not fetched:
            raise ValueError("No image in the batch could be downloaded; cannot build a batch.")
        # _get_image_and_caption yields (image, caption) -- unpack in that order.
        images, captions = zip(*fetched)
        tokenized_text = self.tokenizer(list(captions))
        image_batch = torch.stack([self.transform(image) for image in images])
        return tokenized_text, image_batch
57
+
58
+
59
def _get_dataloaders(
    train_ds: Dataset,
    valid_ds: Dataset,
    training_config: config.TrainerConfig,
    collate_fn: CollateFn,
    num_workers: Optional[int] = None,
) -> tuple[DataLoader, DataLoader]:
    """Build the training and validation ``DataLoader`` objects.

    Args:
        train_ds: Dataset used for training.
        valid_ds: Dataset used for validation.
        training_config: Provides ``batch_size``.
        collate_fn: Callable that turns a list of rows into a batch.
        num_workers: Worker processes per loader; defaults to the CPU count.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    if num_workers is None:
        num_workers = mp.cpu_count()
    common_params = {
        "batch_size": training_config.batch_size,
        "pin_memory": True,
        "num_workers": num_workers,
        "collate_fn": collate_fn,
    }
    # DataLoader raises ValueError when ``shuffle`` is set for an iterable-style
    # dataset, so only request loader-level shuffling for map-style datasets
    # (streamed data is already shuffled upstream in get_dataset).
    shuffle_train = not isinstance(train_ds, torch.utils.data.IterableDataset)
    train_loader = DataLoader(
        train_ds,
        shuffle=shuffle_train,
        drop_last=True,
        **common_params,
    )
    valid_loader = DataLoader(
        valid_ds,
        shuffle=False,
        drop_last=False,
        **common_params,
    )
    return train_loader, valid_loader


def get_dataset(
    transform: transforms.Compose,
    tokenizer: Tokenizer,
    hyper_parameters: config.TrainerConfig,
    num_workers: int,
) -> tuple[DataLoader, DataLoader]:
    """Stream the configured dataset and return train/validation loaders.

    The stream is shuffled with a fixed seed, limited to ``data_len`` examples,
    and split so the first ``train_len`` examples are used for training and the
    remainder for validation.

    Args:
        transform: Image transform applied inside the collate function.
        tokenizer: Caption tokenizer applied inside the collate function.
        hyper_parameters: Training configuration (batch size and data settings).
        num_workers: Worker processes used by each ``DataLoader``.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    data_cfg = hyper_parameters.data_config
    stream = datasets.load_dataset(data_cfg.dataset, split="train", streaming=True)
    full_dataset = stream.shuffle(seed=42, buffer_size=data_cfg.buffer_size).take(
        data_cfg.data_len
    )
    train_dataset = full_dataset.take(data_cfg.train_len)
    valid_dataset = full_dataset.skip(data_cfg.train_len)

    return _get_dataloaders(
        train_ds=train_dataset,
        valid_ds=valid_dataset,
        training_config=hyper_parameters,
        collate_fn=CollateFn(tokenizer, transform),
        num_workers=num_workers,
    )
models.py → src/models.py RENAMED
File without changes