| | from .vss import VSS |
| | from .llm_reranker import LLMReranker |
| | from .multi_vss import MultiVSS |
| | from .bm25 import BM25 |
| | from .colbertv2 import Colbertv2 |
| |
|
def get_model(args, skb, **kwargs):
    """Instantiate the retrieval model selected by ``args.model``.

    Args:
        args: Namespace-like object. ``args.model`` selects the model; the
            remaining attributes that are read depend on the selected
            branch (e.g. ``emb_model``, ``query_emb_dir``, ``node_emb_dir``
            and ``device`` for the embedding-based models).
        skb: Passed unchanged as the first positional argument to every
            model constructor.
        **kwargs: Extra keyword arguments. They are forwarded to
            ``Colbertv2`` only; every other branch ignores them.

    Returns:
        An instance of ``BM25``, ``Colbertv2``, ``VSS``, ``MultiVSS`` or
        ``LLMReranker``, depending on ``args.model``.

    Raises:
        ImportError: If constructing ``Colbertv2`` raises ``ImportError``;
            the original error is attached as the cause.
        NotImplementedError: If ``args.model`` matches none of the names
            above.
    """
    model_name = args.model

    if model_name == 'BM25':
        return BM25(skb)

    if model_name == 'Colbertv2':
        try:
            return Colbertv2(
                skb,
                dataset_name=args.dataset,
                save_dir=args.output_dir,
                download_dir=args.download_dir,
                human_generated_eval=(args.split == 'human_generated_eval'),
                **kwargs
            )
        except ImportError as err:
            # Chain explicitly so the traceback shows which import failed
            # underneath the installation hint.
            raise ImportError(
                "Please install the colbert package using `pip install colbert-ai`."
            ) from err

    if model_name == 'VSS':
        return VSS(
            skb,
            emb_model=args.emb_model,
            query_emb_dir=args.query_emb_dir,
            candidates_emb_dir=args.node_emb_dir,
            device=args.device
        )

    if model_name == 'MultiVSS':
        return MultiVSS(
            skb,
            emb_model=args.emb_model,
            query_emb_dir=args.query_emb_dir,
            candidates_emb_dir=args.node_emb_dir,
            chunk_emb_dir=args.chunk_emb_dir,
            aggregate=args.aggregate,
            chunk_size=args.chunk_size,
            max_k=args.multi_vss_topk,
            device=args.device
        )

    if model_name == 'LLMReranker':
        return LLMReranker(
            skb,
            emb_model=args.emb_model,
            llm_model=args.llm_model,
            query_emb_dir=args.query_emb_dir,
            candidates_emb_dir=args.node_emb_dir,
            max_cnt=args.max_retry,
            max_k=args.llm_topk,
            device=args.device
        )

    # No branch matched: the requested model name is unknown.
    raise NotImplementedError(f'{model_name} not implemented')