chainyo commited on
Commit
af16e92
1 Parent(s): 90c75b6

fix training loop for push to hub

Browse files
Files changed (1) hide show
  1. training_loop.py +7 -5
training_loop.py CHANGED
@@ -74,12 +74,14 @@ def main(
74
 
75
  if push_to_hub:
76
  config = AutoConfig.from_pretrained(f"nvidia/mit-b{model_flavor}")
77
- config.push_to_hub("segformer-sidewalk", repo_url="https://huggingface.co/ChainYo/segformer-sidewalk")
 
 
 
 
78
  checkpoint_path = checkpoint_callback.best_model_filepath
79
- model = SegformerForSemanticSegmentation.from_pretrained(
80
- checkpoint_path, num_labels=num_labels, id2label=id2label, label2id=id2label, config=config,
81
- )
82
- model.push_to_hub("segformer-sidewalk", repo_url="https://huggingface.co/ChainYo/segformer-sidewalk")
83
 
84
 
85
  if __name__ == "__main__":
 
74
 
75
  if push_to_hub:
76
  config = AutoConfig.from_pretrained(f"nvidia/mit-b{model_flavor}")
77
+ config.num_labels = num_labels
78
+ config.id2label = id2label
79
+ config.label2id = {v: k for k, v in id2label_file.items()}
80
+ config.push_to_hub(".", repo_url="https://huggingface.co/ChainYo/segformer-sidewalk")
81
+
82
  checkpoint_path = checkpoint_callback.best_model_filepath
83
+ model = SegformerForSemanticSegmentation.from_pretrained(checkpoint_path, config=config,)
84
+ model.push_to_hub(".", repo_url="https://huggingface.co/ChainYo/segformer-sidewalk")
 
 
85
 
86
 
87
  if __name__ == "__main__":