File size: 1,666 Bytes
cf57696
 
 
 
 
 
 
 
 
 
 
 
988981a
cf57696
 
 
 
 
 
 
 
 
 
 
 
 
16cc722
cf57696
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import time
import json

import wandb
import gradio as gr

from qa_engine import logger, Config, QAEngine


# Path of the JSON file holding the benchmark questions that main() loads.
QUESTIONS_FILENAME = 'benchmark/questions.json'

# Module-level configuration and QA engine, created once at import time and
# used by main() for every benchmark question.
config = Config()
qa_engine = QAEngine(config=config)


def main():
    """Benchmark the QA engine on the stored questions and log results to W&B.

    Reads the question list from ``QUESTIONS_FILENAME``, sends each entry to
    the module-level ``qa_engine``, measures how long every call takes, and
    logs one table row per question (id, question, messages context, answer,
    sources, elapsed seconds) to a Weights & Biases run.

    Raises:
        OSError: If the questions file cannot be opened.
        json.JSONDecodeError: If the questions file is not valid JSON.
        KeyError: If a question entry lacks ``question`` or
            ``messages_context``.
    """
    # Remove the keys below (including the Discord token) from the config
    # dict so they are not attached to the W&B run.
    filtered_config = config.asdict()
    disallowed_config_keys = [
        "DISCORD_TOKEN", "NUM_LAST_MESSAGES", "USE_NAMES_IN_CONTEXT",
        "ENABLE_COMMANDS", "APP_MODE", "DEBUG"
    ]
    for key in disallowed_config_keys:
        # pop() with a default tolerates keys that are absent from the config.
        filtered_config.pop(key, None)

    wandb.init(
        project='HF-Docs-QA',
        entity='hf-qa-bot',
        name=f'{config.question_answering_model_id} - {config.embedding_model_id} - {config.index_repo_id}',
        mode='run', # run/disabled
        config=filtered_config
    )

    # JSON text is UTF-8; decode it explicitly instead of relying on the
    # platform's locale encoding, which breaks non-ASCII questions on
    # systems whose default is not UTF-8.
    with open(QUESTIONS_FILENAME, 'r', encoding='utf-8') as f:
        questions = json.load(f)

    table = wandb.Table(
        columns=[
            "id", "question", "messages_context", "answer", "sources", "time"
        ]
    )
    for i, q in enumerate(questions):
        logger.info(f"Question {i+1}/{len(questions)}")

        question = q['question']
        messages_context = q['messages_context']

        # Time only the engine call so the "time" column excludes the
        # table bookkeeping below.
        time_start = time.perf_counter()
        response = qa_engine.get_response(
            question=question,
            messages_context=messages_context
        )
        time_end = time.perf_counter()

        table.add_data(
            i,
            question,
            messages_context,
            response.get_answer(),
            response.get_sources_as_text(),
            time_end - time_start
        )

    # Log the whole table once, after every question has been answered.
    wandb.log({"answers": table})
    wandb.finish()


# Run the benchmark only when this file is executed as a script.
if __name__ == '__main__':
    main()