Andron00e commited on
Commit
3a70b14
1 Parent(s): eec6bf2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +75 -1
README.md CHANGED
@@ -248,4 +248,78 @@ The following hyperparameters were used during training:
248
  - Transformers 4.35.2
249
  - Pytorch 2.1.0+cu118
250
  - Datasets 2.15.0
251
- - Tokenizers 0.15.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  - Transformers 4.35.2
249
  - Pytorch 2.1.0+cu118
250
  - Datasets 2.15.0
251
+ - Tokenizers 0.15.0
252
+
253
+ ### Example of usage
254
+
255
+ ```python
256
+ import torch
+ from datasets import load_dataset
257
+ from transformers import TrainingArguments
258
+ from transformers import CLIPProcessor, AutoModelForImageClassification
259
+
260
+ processor = CLIPProcessor.from_pretrained("Andron00e/CLIPForImageClassification-v1")
261
+ model = AutoModelForImageClassification.from_pretrained("Andron00e/CLIPForImageClassification-v1")
262
+
263
+ dataset = load_dataset("Andron00e/CIFAR100-custom")
264
+ dataset = dataset["train"].train_test_split(test_size=0.2)
265
+ from datasets import DatasetDict
266
+
267
+ val_test = dataset["test"].train_test_split(test_size=0.5)
268
+ dataset = DatasetDict({
269
+ "train": dataset["train"],
270
+ "validation": val_test["train"],
271
+ "test": val_test["test"],
272
+ })
273
+
274
+ def transform(example_batch):
275
+ # NOTE: `classes` must be defined beforehand as a mapping from label id to class name,
+ # e.g. classes = dataset["train"].features["labels"].names
+ inputs = processor(text=[classes[x] for x in example_batch['labels']], images=[x for x in example_batch['image']], padding=True, return_tensors='pt')
276
+ inputs['labels'] = example_batch['labels']
277
+ return inputs
278
+
279
+ def collate_fn(batch):
280
+ return {
281
+ 'input_ids': torch.stack([x['input_ids'] for x in batch]),
282
+ 'attention_mask': torch.stack([x['attention_mask'] for x in batch]),
283
+ 'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
284
+ 'labels': torch.tensor([x['labels'] for x in batch])
285
+ }
286
+
287
+ training_args = TrainingArguments(
288
+ output_dir="./outputs",
289
+ per_device_train_batch_size=16,
290
+ evaluation_strategy="steps",
291
+ num_train_epochs=4,
292
+ fp16=False,
293
+ save_steps=100,
294
+ eval_steps=100,
295
+ logging_steps=10,
296
+ learning_rate=2e-4,
297
+ save_total_limit=2,
298
+ remove_unused_columns=False,
299
+ push_to_hub=False,
300
+ report_to='tensorboard',
301
+ load_best_model_at_end=True,
302
+ )
303
+
304
+ from transformers import Trainer
305
+
306
+ trainer = Trainer(
307
+ model=model,
308
+ args=training_args,
309
+ data_collator=collate_fn,
310
+ compute_metrics=compute_metrics,  # define your own metric function (e.g. accuracy) before running
311
+ train_dataset=dataset.with_transform(transform)["train"],
312
+ eval_dataset=dataset.with_transform(transform)["validation"],
313
+ tokenizer=processor,
314
+ )
315
+
316
+ train_results = trainer.train()
317
+ trainer.save_model()
318
+ trainer.log_metrics("train", train_results.metrics)
319
+ trainer.save_metrics("train", train_results.metrics)
320
+ trainer.save_state()
321
+
322
+ metrics = trainer.evaluate(dataset.with_transform(transform)["test"])
323
+ trainer.log_metrics("eval", metrics)
324
+ trainer.save_metrics("eval", metrics)
325
+ ```