# Scraped file-view header, kept as comments so this module parses:
# last commit 58627fa ("add_app_files") by 欧卫 · file size 1.26 kB
from colbert.infra.run import Run
from colbert.infra.launcher import Launcher
from colbert.infra.config import ColBERTConfig, RunConfig
from colbert.training.training import train
class Trainer:
    """Train a ColBERT model from a set of triples, queries, and a collection.

    The training resources are given to the constructor. They are written into
    the config when ``train()`` runs, overwriting any values the config already
    held for them.
    """

    # Overwritten by train(). Defined at class level so best_checkpoint_path()
    # returns None instead of raising AttributeError before any training run.
    _best_checkpoint_path = None

    def __init__(self, triples, queries, collection, config=None):
        """Store the training resources and build the working config.

        Args:
            triples: training triples, passed to the launcher by ``train()``.
            queries: queries, passed to the launcher by ``train()``.
            collection: collection, passed to the launcher by ``train()``.
            config: optional ColBERTConfig, combined with ``Run().config`` via
                ``ColBERTConfig.from_existing`` (see that method for precedence).
        """
        self.config = ColBERTConfig.from_existing(config, Run().config)

        self.triples = triples
        self.queries = queries
        self.collection = collection

    def configure(self, **kw_args):
        """Forward keyword settings to ``self.config.configure``."""
        self.config.configure(**kw_args)

    def train(self, checkpoint='bert-base-uncased'):
        """Launch training and record the resulting best checkpoint path.

        Note that ``config.checkpoint`` is ignored: the ``checkpoint`` argument
        passed here always overwrites it before launching.

        Args:
            checkpoint: value written into ``config.checkpoint`` before launch
                (presumably the initial model weights to train from).

        Returns:
            The value returned by ``Launcher.launch``. It is stored and also
            available afterwards via ``best_checkpoint_path()``.
        """
        # Resources don't come from the config object; they come from the constructor arguments.
        # TODO: After the API stabilizes, make this "self.config.assign()" to emphasize this distinction.
        self.configure(triples=self.triples, queries=self.queries, collection=self.collection)
        self.configure(checkpoint=checkpoint)

        # Inside this method the bare name `train` resolves to the module-level
        # function imported from colbert.training.training, not to this method.
        launcher = Launcher(train)
        self._best_checkpoint_path = launcher.launch(self.config, self.triples, self.queries, self.collection)
        return self._best_checkpoint_path

    def best_checkpoint_path(self):
        """Return the path recorded by the last ``train()`` call, or None if training has not run."""
        return self._best_checkpoint_path