import logging
import traceback

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

# 1. Load a model to finetune and 2. (optionally) attach model card data
model = SentenceTransformer(
    "google/embeddinggemma-300M",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="EmbeddingGemma-300M trained on the Medical Instruction and RetrIeval Dataset (MIRIAD)",
    ),
)

# 3. Load a dataset to finetune on
train_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="train").select(range(100_000))
eval_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="eval").select(range(1_000))
test_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="test").select(range(1_000))
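
# (Illustrative, not part of the original script) Each split provides "question" and
# "passage_text" columns, which the loss below consumes as (anchor, positive) pairs.
logging.info(f"Train columns: {train_dataset.column_names}, train size: {len(train_dataset):,}")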

# 4. Define a loss function. CachedMultipleNegativesRankingLoss (CMNRL) is a memory-efficient variant of MNRL
# (a.k.a. InfoNCE), which takes question-answer pairs (or triplets, etc.) as input. It treats the answers to the
# other questions in the batch as wrong answers, decreasing the embedding distance between each question and its
# true answer while increasing the distance to the wrong answers.
# The (C)MNRL losses benefit from a larger `per_device_train_batch_size` in the Training Arguments, as they can
# leverage more in-batch negative samples. The `mini_batch_size`, on the other hand, does not affect the model
# quality; it only caps peak memory usage (smaller values are slower but lighter on memory). A good trick is
# setting a high `per_device_train_batch_size` while keeping `mini_batch_size` small.
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=8)
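
# A rough worked example (illustrative note, not from the original script): with the
# `per_device_train_batch_size=128` configured below and the NO_DUPLICATES batch sampler, each
# question is contrasted against its 1 true answer and the 127 answers of the other questions in
# the batch. CMNRL still only embeds `mini_batch_size=8` texts at a time, so peak GPU memory
# stays roughly flat as the batch size grows.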

# 5. (Optional) Specify training arguments
run_name = "embeddinggemma-300M-medical-100k"
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # (Cached)MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    prompts={  # Map training column names to model prompts
        "question": model.prompts["query"],
        "passage_text": model.prompts["document"],
    },
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=20,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)
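
# (Illustrative, not part of the original script) EmbeddingGemma ships with built-in prompts; the
# `prompts` mapping above makes the trainer prepend them to the matching dataset columns. Inspect
# the exact strings like so:
logging.info(f"Model prompts: {model.prompts}")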

# 6. (Optional) Create an evaluator using the evaluation queries and 31k answers & evaluate the base model
queries = dict(enumerate(eval_dataset["question"]))
corpus = dict(enumerate(eval_dataset["passage_text"] + train_dataset["passage_text"][:30_000]))
relevant_docs = {idx: [idx] for idx in queries}
dev_evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="miriad-eval-1kq-31kd",  # 1k questions, 31k passages
    show_progress_bar=True,
)
dev_evaluator(model)
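
# (Illustrative note) In recent sentence-transformers versions, calling the evaluator returns a
# dictionary of retrieval metrics (accuracy@k, MRR@k, NDCG@k, ...); capture the return value,
# e.g. `base_metrics = dev_evaluator(model)`, to compare base and finetuned scores programmatically.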

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# (Optional) Evaluate the trained model on the evaluation set once more; this also logs the results
# and includes them in the model card
dev_evaluator(model)

queries = dict(enumerate(test_dataset["question"]))
corpus = dict(enumerate(test_dataset["passage_text"] + train_dataset["passage_text"][:30_000]))
relevant_docs = {idx: [idx] for idx in queries}
test_evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="miriad-test-1kq-31kd",  # 1k questions, 31k passages
    show_progress_bar=True,
)
test_evaluator(model)

# 8. Save the trained model
final_output_dir = f"models/{run_name}/final"
model.save_pretrained(final_output_dir)

# 9. (Optional) Push it to the Hugging Face Hub
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
try:
    model.push_to_hub(run_name)
except Exception:
    logging.error(
        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
        f"and saving it using `model.push_to_hub('{run_name}')`."
    )
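
# 10. (Illustrative usage sketch, not part of the original script; the query and passages below are
# made up) Load the saved model and run a small retrieval check. `prompt_name` selects the same
# built-in prompts that were applied during training.
finetuned = SentenceTransformer(final_output_dir)
query_embedding = finetuned.encode("What are first-line treatments for hypertension?", prompt_name="query")
doc_embeddings = finetuned.encode(
    [
        "Thiazide diuretics, ACE inhibitors, ARBs, and calcium channel blockers are common first-line antihypertensives.",
        "The patella is a sesamoid bone embedded in the quadriceps tendon.",
    ],
    prompt_name="document",
)
# The relevant passage (index 0) should receive the higher similarity score
logging.info(f"Similarities: {finetuned.similarity(query_embedding, doc_embeddings)}")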