Instructions to use gbyuvd/RougeBERT with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use gbyuvd/RougeBERT with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("fill-mask", model="gbyuvd/RougeBERT")

# Load model directly
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("gbyuvd/RougeBERT", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| # benchmark_roberta.py | |
| # For benchmarking against RoBERTa on 0.15 MASK PROB MLM | |
| # Consider adjusting the data, tokenizer, and anything else | |
| # according to your use case(s). | |
| # It uses train_eval_utils.py for train, eval, logging helpers | |
| # - gbyuvd | |
| from RougeBERTHF import RougeBERTForMaskedLM, RougeBERTConfig | |
| from transformers import RobertaForMaskedLM, RobertaConfig | |
| from FastChemTokenizer import FastChemTokenizer | |
| from train_eval_utils import train_and_eval, prepare_train_val_test_split | |
# ----------------------------
# Tokenizer
# ----------------------------
# One tokenizer instance is shared by both benchmark runs so that the two
# models see identical token ids and the same vocabulary size.
# It is loaded from the local "../smitok" directory (path relative to the
# working directory the script is started from).
tokenizer = FastChemTokenizer.from_pretrained("../smitok")
# ----------------------------
# Training Hyperparams
# ----------------------------
# The same values are handed to train_and_eval() for both models so the
# comparison is like-for-like.
BATCH_SIZE = 16        # forwarded as batch_size
GRAD_ACCUM = 4         # forwarded as grad_accum
NUM_EPOCHS = 10        # forwarded as num_epochs
MAX_SEQ_LEN = 512      # token budget; also sizes each model's position table
LEARNING_RATE = 1e-5   # forwarded as learning_rate
MASK_PROB = 0.15       # forwarded as mask_prob (MLM masking probability)

# Dataset locations, relative to the working directory.
# NOTE(review): FULL_CSV is not referenced in the visible part of this script;
# presumably it is the input to prepare_train_val_test_split -- confirm.
FULL_CSV = "../data/sample_1k_smi_42.csv"
TRAIN_CSV = "../data/train.csv"
VAL_CSV = "../data/val.csv"
TEST_CSV = "../data/test.csv"

# NOTE(review): SAVE_DIR is not referenced in the visible part of this script.
SAVE_DIR = "./pretrained_roguebert"
# ----------------------------
# Helper: pick correct max_seq length
# ----------------------------
def get_seq_len(config, default=512):
    """Return the maximum sequence length declared by a model config.

    The attributes ``max_seq`` and ``max_position_embeddings`` are checked in
    that order; the first one that exists and is not ``None`` is returned.

    Args:
        config: Any object carrying model configuration attributes.
        default: Value returned when neither attribute provides a length.

    Returns:
        The value of the first usable attribute, otherwise ``default``.
    """
    for attr_name in ("max_seq", "max_position_embeddings"):
        # getattr with a None fallback treats "missing" and "explicitly None"
        # the same way, so an unset field never leaks out as the length.
        value = getattr(config, attr_name, None)
        if value is not None:
            return value
    return default
# ----------------------------
# RougeBERT Hybrid
# ----------------------------
# Build the RougeBERT model with the tokenizer's vocabulary size and the
# shared sequence-length budget.
rouge_config = RougeBERTConfig(vocab_size=len(tokenizer), max_seq=MAX_SEQ_LEN)
rouge_model = RougeBERTForMaskedLM(rouge_config)  # 9022400 params

# Keyword options for the run, gathered up front so the call below stays short.
_rouge_run_kwargs = dict(
    run_name="rougebert",
    batch_size=BATCH_SIZE,
    grad_accum=GRAD_ACCUM,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    mask_prob=MASK_PROB,
    max_seq_len=get_seq_len(rouge_config, MAX_SEQ_LEN),
)

# Positional order expected by train_and_eval as used in this script:
# model, tokenizer, train CSV, validation CSV, test CSV, config.
rouge_results = train_and_eval(
    rouge_model,
    tokenizer,
    TRAIN_CSV,
    VAL_CSV,
    TEST_CSV,
    rouge_config,
    **_rouge_run_kwargs,
)
# ----------------------------
# RoBERTa baseline
# ----------------------------
# Baseline sized to land close to RougeBERT's parameter count
# (9017590 vs 9022400 per the figures recorded in this script).
roberta_config = RobertaConfig(
    vocab_size=len(tokenizer),
    hidden_size=282,
    intermediate_size=1300,  # FFN size
    num_attention_heads=6,
    num_hidden_layers=8,
    # RoBERTa numbers positions starting at pad_token_id + 1, so the position
    # table must be larger than the usable token length. "+ 2" is the
    # conventional value and is sufficient when pad_token_id <= 1.
    # NOTE(review): confirm tokenizer.pad_token_id is 0 or 1.
    max_position_embeddings=MAX_SEQ_LEN + 2,
    type_vocab_size=1,  # RoBERTa doesn't use token_type_ids
    # Special-token ids come from the tokenizer when it exposes them;
    # otherwise the fallbacks below apply.
    pad_token_id=getattr(tokenizer, "pad_token_id", 0),
    bos_token_id=getattr(tokenizer, "bos_token_id", 0),
    eos_token_id=getattr(tokenizer, "eos_token_id", 2),
    # RoBERTa-specific
    attention_probs_dropout_prob=0.1,
    hidden_dropout_prob=0.1,
    initializer_range=0.02,
    layer_norm_eps=1e-5,
)
roberta_model = RobertaForMaskedLM(roberta_config)  # 9017590 params

bert_results = train_and_eval(
    roberta_model,
    tokenizer,
    TRAIN_CSV,
    VAL_CSV,
    TEST_CSV,
    roberta_config,
    run_name="roberta_baseline",
    batch_size=BATCH_SIZE,
    grad_accum=GRAD_ACCUM,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    mask_prob=MASK_PROB,
    # Pass the token budget directly. Deriving it from the config would yield
    # max_position_embeddings (MAX_SEQ_LEN + 2), which includes the position
    # offset: the baseline would then get a longer limit than RougeBERT, and
    # full-length inputs could index past the position-embedding table.
    max_seq_len=MAX_SEQ_LEN,
)
# ----------------------------
# Print comparison
# ----------------------------
# Emit one line per run: the label followed by whatever train_and_eval returned.
for _label, _outcome in (
    ("RougeBERT:", rouge_results),
    ("Vanilla BERT:", bert_results),
):
    print(_label, _outcome)