#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

import argparse
from collections import defaultdict
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import retrievaler
from api.utils import get_uuid
from rag.nlp import tokenize, search
from rag.utils.es_conn import ELASTICSEARCH
from ranx import evaluate


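class benchmark_ndcg10:
    """Index a labelled benchmark file and score retrieval with NDCG@10 using a knowledge base's settings."""

    def __init__(self, kb_id):
        e, kb = KnowledgebaseService.get_by_id(kb_id)
        if not e:
            raise LookupError("Knowledge base not found: %s" % kb_id)
        self.similarity_threshold = kb.similarity_threshold
        self.vector_similarity_weight = kb.vector_similarity_weight
        self.embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)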

    def _get_benchmarks(self, query, count=16):
        req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
        sres = retrievaler.search(req, search.index_name("benchmark"), self.embd_mdl)
        return sres

    def _get_retrieval(self, qrels):
        """Build the run mapping: query -> {chunk id: reranked similarity}."""
        run = defaultdict(dict)
        for query in qrels:
            sres = self._get_benchmarks(query)
            sim, _, _ = retrievaler.rerank(sres, query, 1 - self.vector_similarity_weight,
                                           self.vector_similarity_weight)
            for index, chunk_id in enumerate(sres.ids):
                run[query][chunk_id] = sim[index]
        return run

    def embedding(self, docs, batch_size=16):
        """Attach an embedding to each doc under the key q_<dim>_vec and return the docs."""
        vects = []
        cnts = [d["content_with_weight"] for d in docs]
        for i in range(0, len(cnts), batch_size):
            # encode() returns the vectors plus a second value that is not needed here.
            vts, _ = self.embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
        assert len(docs) == len(vects)
        for d, v in zip(docs, vects):
            d["q_%d_vec" % len(v)] = v
        return docs

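    def __call__(self, file_path):
        """Index the benchmark file, run retrieval for every query, and return NDCG@10."""
        qrels = defaultdict(dict)

        docs = []
        with open(file_path) as f:
            for line in f:
                # Each line must hold exactly three whitespace-separated fields:
                # query, text, relevance. Neither query nor text may contain whitespace.
                query, text, rel = line.strip('\n').split()
                d = {
                    "id": get_uuid()
                }
                tokenize(d, text)
                docs.append(d)
                qrels[query][d["id"]] = float(rel)
                if len(docs) >= 32:
                    # Embed each full batch before indexing so every chunk carries its vector.
                    docs = self.embedding(docs)
                    ELASTICSEARCH.bulk(docs, search.index_name("benchmark"))
                    docs = []
            if docs:
                # Flush the final partial batch.
                docs = self.embedding(docs)
                ELASTICSEARCH.bulk(docs, search.index_name("benchmark"))

        run = self._get_retrieval(qrels)
        return evaluate(qrels, run, "ndcg@10")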


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--filepath', help="path to the benchmark file, one 'query text relevance' triple per line", action='store', required=True)
    parser.add_argument('-k', '--kb_id', help="ID of the knowledge base whose retrieval settings are used", action='store', required=True)
    args = parser.parse_args()

    ex = benchmark_ndcg10(args.kb_id)
    print(ex(args.filepath))