Spaces:
Runtime error
Runtime error
File size: 1,257 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
from colbert.infra.run import Run
from colbert.infra.launcher import Launcher
from colbert.infra.config import ColBERTConfig, RunConfig
from colbert.training.training import train
class Trainer:
    """Thin front-end that launches the ``train`` routine with a config.

    The instance keeps the training resources (``triples``, ``queries``,
    ``collection``) next to a ``ColBERTConfig`` and forwards all of them to
    ``Launcher(train).launch(...)`` when :meth:`train` is called.
    """

    # Class-level default: lets ``best_checkpoint_path()`` return ``None``
    # instead of raising ``AttributeError`` when it is called before
    # ``train()`` has stored a result on the instance.
    _best_checkpoint_path = None

    def __init__(self, triples, queries, collection, config=None):
        """Store the training resources and build the working config.

        Args:
            triples: Training triples resource, forwarded unchanged to the
                launcher.
            queries: Queries resource, forwarded unchanged to the launcher.
            collection: Collection resource, forwarded unchanged to the
                launcher.
            config: Optional existing config. It is combined with the current
                ``Run().config`` through ``ColBERTConfig.from_existing``.
                NOTE(review): precedence between the two is decided by
                ``from_existing`` and is not visible here -- confirm.
        """
        self.config = ColBERTConfig.from_existing(config, Run().config)

        self.triples = triples
        self.queries = queries
        self.collection = collection

    def configure(self, **kw_args):
        """Forward keyword settings to ``self.config.configure``."""
        self.config.configure(**kw_args)

    def train(self, checkpoint='bert-base-uncased'):
        """Launch training and remember the result as the best checkpoint path.

        Note that config.checkpoint is ignored. Only the supplied checkpoint
        here is used.

        Args:
            checkpoint: Checkpoint written into the config before launching.
        """
        # Resources don't come from the config object. They come from the input parameters.
        # TODO: After the API stabilizes, make this "self.config.assign()" to emphasize this distinction.
        self.configure(triples=self.triples, queries=self.queries, collection=self.collection)
        self.configure(checkpoint=checkpoint)

        launcher = Launcher(train)

        # Whatever the launcher returns is kept as the best checkpoint path.
        self._best_checkpoint_path = launcher.launch(
            self.config, self.triples, self.queries, self.collection
        )

    def best_checkpoint_path(self):
        """Return the value stored by the last ``train()`` call.

        Returns ``None`` if ``train()`` has not been run on this instance.
        """
        return self._best_checkpoint_path
|