# Web file-viewer residue (not part of the script): "Spaces" status "Sleeping";
# file size 1,495 bytes; id 626eca0 (presumably the commit hash); the viewer's
# line-number gutter (1-46) was dropped.
from relik.retriever.trainer import RetrieverTrainer
from relik import GoldenRetriever
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
from relik.retriever.data.datasets import AidaInBatchNegativesDataset
if __name__ == "__main__":
    # Settings that the training and validation datasets have in common.
    shared_dataset_kwargs = dict(
        question_batch_size=64,
        passage_batch_size=400,
        max_passage_length=64,
        use_topics=True,
    )

    # Build the retriever: an in-memory document index over the definitions
    # file, paired with the e5-small-v2 question encoder.
    index = InMemoryDocumentIndex(
        documents="/root/golden-retriever-v2/data/dpr-like/el/definitions.txt",
        device="cuda",
        precision="16",
    )
    retriever = GoldenRetriever(
        question_encoder="intfloat/e5-small-v2",
        document_index=index,
    )

    # Training split; this is the only dataset created with shuffle=True.
    aida_train = AidaInBatchNegativesDataset(
        name="aida_train",
        path="/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/train.jsonl",
        tokenizer=retriever.question_tokenizer,
        **shared_dataset_kwargs,
        shuffle=True,
    )

    # Validation split; no shuffle argument is passed, so the dataset's own
    # default applies.
    aida_val = AidaInBatchNegativesDataset(
        name="aida_val",
        path="/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/val.jsonl",
        tokenizer=retriever.question_tokenizer,
        **shared_dataset_kwargs,
    )

    # Configure the trainer for 25,000 steps and start training.
    # NOTE(review): wandb_offline_mode=True presumably keeps Weights & Biases
    # logging local — confirm against RetrieverTrainer.
    RetrieverTrainer(
        retriever=retriever,
        train_dataset=aida_train,
        val_dataset=aida_val,
        max_steps=25_000,
        wandb_offline_mode=True,
    ).train()
# (stray "|" left by the web file viewer; not part of the script)