# Spaces: Sleeping
from relik.retriever.trainer import RetrieverTrainer
from relik import GoldenRetriever
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
from relik.retriever.data.datasets import AidaInBatchNegativesDataset

# Directory holding the document file and the dataset folder used below.
DATA_ROOT = "/root/golden-retriever-v2/data/dpr-like/el"
# Directory holding the train/val JSONL splits.
AIDA_DIR = f"{DATA_ROOT}/aida_32_tokens_topic"


def main() -> None:
    """Build the retriever and its datasets, then run training.

    The steps are, in order:

    1. Create an ``InMemoryDocumentIndex`` from ``definitions.txt``
       (constructed with ``device="cuda"`` and ``precision="16"``).
    2. Create a ``GoldenRetriever`` with the ``intfloat/e5-small-v2``
       question encoder, attached to that index.
    3. Create the training and validation ``AidaInBatchNegativesDataset``
       objects, both tokenized with the retriever's question tokenizer.
    4. Create a ``RetrieverTrainer`` and call ``train()``.
    """
    # instantiate retriever
    document_index = InMemoryDocumentIndex(
        documents=f"{DATA_ROOT}/definitions.txt",
        device="cuda",
        precision="16",
    )
    retriever = GoldenRetriever(
        question_encoder="intfloat/e5-small-v2", document_index=document_index
    )

    # Settings that must be identical for the train and validation splits;
    # kept in one place so the two datasets cannot drift apart.
    shared_dataset_kwargs = dict(
        tokenizer=retriever.question_tokenizer,
        question_batch_size=64,
        passage_batch_size=400,
        max_passage_length=64,
        use_topics=True,
    )

    # Only the training split passes shuffle=True; the validation split
    # leaves `shuffle` at the dataset's own default, as before.
    train_dataset = AidaInBatchNegativesDataset(
        name="aida_train",
        path=f"{AIDA_DIR}/train.jsonl",
        shuffle=True,
        **shared_dataset_kwargs,
    )
    val_dataset = AidaInBatchNegativesDataset(
        name="aida_val",
        path=f"{AIDA_DIR}/val.jsonl",
        **shared_dataset_kwargs,
    )

    trainer = RetrieverTrainer(
        retriever=retriever,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        max_steps=25_000,
        wandb_offline_mode=True,
    )
    trainer.train()


if __name__ == "__main__":
    main()