Ubuntu commited on
Commit
de77f70
·
1 Parent(s): dbcdf98
Files changed (2) hide show
  1. .ipynb_checkpoints/app-checkpoint.py +16 -14
  2. app.py +16 -14
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -84,23 +84,25 @@ def video_identity(video,user_name,class_name,trainortest,ready):
84
  train_loader = DataLoader(train_ds, batch_size=2, collate_fn=collator, num_workers=8, shuffle=True)
85
  test_loader = DataLoader(test_ds, batch_size=2, collate_fn=collator, num_workers=7)
86
 
87
-
88
- for name, param in model.named_parameters():
89
- param.requires_grad = False
90
- if name.startswith("classifier"): # choose whatever you like here
91
- param.requires_grad = True
 
 
92
 
93
- pl.seed_everything(42)
94
- classifier = Classifier(model, lr=2e-5)
95
- trainer = pl.Trainer(accelerator='gpu', devices=1, precision=16, max_epochs=3)
96
 
97
- trainer.fit(classifier, train_loader, test_loader)
98
 
99
- for batch_idx, data in enumerate(test_loader):
100
- outputs = model(**data)
101
- img=data['pixel_values'][0][0]
102
- preds=str(outputs.logits.softmax(1).argmax(1))
103
- labels=str(data['labels'])
104
 
105
  return img, preds, labels
106
 
 
84
  train_loader = DataLoader(train_ds, batch_size=2, collate_fn=collator, num_workers=8, shuffle=True)
85
  test_loader = DataLoader(test_ds, batch_size=2, collate_fn=collator, num_workers=7)
86
 
87
+ val_batch = next(iter(test_loader))
88
+ outputs = model(**val_batch)
89
+ preds=outputs.logits.softmax(1).argmax(1)
90
+ # for name, param in model.named_parameters():
91
+ # param.requires_grad = False
92
+ # if name.startswith("classifier"): # choose whatever you like here
93
+ # param.requires_grad = True
94
 
95
+ # pl.seed_everything(42)
96
+ # classifier = Classifier(model, lr=2e-5)
97
+ # trainer = pl.Trainer(accelerator='gpu', devices=1, precision=16, max_epochs=3)
98
 
99
+ # trainer.fit(classifier, train_loader, test_loader)
100
 
101
+ # for batch_idx, data in enumerate(test_loader):
102
+ # outputs = model(**data)
103
+ # img=data['pixel_values'][0][0]
104
+ # preds=str(outputs.logits.softmax(1).argmax(1))
105
+ # labels=str(data['labels'])
106
 
107
  return img, preds, labels
108
 
app.py CHANGED
@@ -84,23 +84,25 @@ def video_identity(video,user_name,class_name,trainortest,ready):
84
  train_loader = DataLoader(train_ds, batch_size=2, collate_fn=collator, num_workers=8, shuffle=True)
85
  test_loader = DataLoader(test_ds, batch_size=2, collate_fn=collator, num_workers=7)
86
 
87
-
88
- for name, param in model.named_parameters():
89
- param.requires_grad = False
90
- if name.startswith("classifier"): # choose whatever you like here
91
- param.requires_grad = True
 
 
92
 
93
- pl.seed_everything(42)
94
- classifier = Classifier(model, lr=2e-5)
95
- trainer = pl.Trainer(accelerator='gpu', devices=1, precision=16, max_epochs=3)
96
 
97
- trainer.fit(classifier, train_loader, test_loader)
98
 
99
- for batch_idx, data in enumerate(test_loader):
100
- outputs = model(**data)
101
- img=data['pixel_values'][0][0]
102
- preds=str(outputs.logits.softmax(1).argmax(1))
103
- labels=str(data['labels'])
104
 
105
  return img, preds, labels
106
 
 
84
  train_loader = DataLoader(train_ds, batch_size=2, collate_fn=collator, num_workers=8, shuffle=True)
85
  test_loader = DataLoader(test_ds, batch_size=2, collate_fn=collator, num_workers=7)
86
 
87
+ val_batch = next(iter(test_loader))
88
+ outputs = model(**val_batch)
89
+ preds=outputs.logits.softmax(1).argmax(1)
90
+ # for name, param in model.named_parameters():
91
+ # param.requires_grad = False
92
+ # if name.startswith("classifier"): # choose whatever you like here
93
+ # param.requires_grad = True
94
 
95
+ # pl.seed_everything(42)
96
+ # classifier = Classifier(model, lr=2e-5)
97
+ # trainer = pl.Trainer(accelerator='gpu', devices=1, precision=16, max_epochs=3)
98
 
99
+ # trainer.fit(classifier, train_loader, test_loader)
100
 
101
+ # for batch_idx, data in enumerate(test_loader):
102
+ # outputs = model(**data)
103
+ # img=data['pixel_values'][0][0]
104
+ # preds=str(outputs.logits.softmax(1).argmax(1))
105
+ # labels=str(data['labels'])
106
 
107
  return img, preds, labels
108