KonradSzafer committed
Commit c69cba4
1 Parent(s): 2a53bac

initial commit

Files changed (49)
  1. .github/workflows/tests-integration.yml +37 -0
  2. .gitignore +68 -0
  3. Dockerfile.api +19 -0
  4. Dockerfile.bot +17 -0
  5. LICENSE +21 -0
  6. api/__init__.py +0 -0
  7. api/__main__.py +39 -0
  8. app.py +33 -0
  9. assets/example.png +0 -0
  10. benchmarker.py +63 -0
  11. config/.env.example +16 -0
  12. config/prompt_templates/llama.txt +7 -0
  13. config/prompt_templates/llama2.txt +6 -0
  14. config/prompt_templates/sythia_v1.3.txt +3 -0
  15. data/benchmark/.gitkeep +0 -0
  16. data/datasets/hf_repositories_urls.json +22 -0
  17. data/datasets/hf_repositories_urls_scraped.json +92 -0
  18. data/get_hf_repositories_urls.py +49 -0
  19. data/hugging_face_docs_dataset.py +190 -0
  20. data/indexer.ipynb +238 -0
  21. data/language-codes.csv +190 -0
  22. data/scrapers/stack_overflow_scraper.py +91 -0
  23. data/stackoverflow_python_dataset.py +55 -0
  24. data/upload_csv_dataset.py +24 -0
  25. discord_bot/__init__.py +0 -0
  26. discord_bot/__main__.py +28 -0
  27. discord_bot/client/__init__.py +1 -0
  28. discord_bot/client/client.py +130 -0
  29. discord_bot/client/utils.py +54 -0
  30. docker-compose.yml +23 -0
  31. models/inference.ipynb +103 -0
  32. qa_engine/__init__.py +9 -0
  33. qa_engine/config.py +64 -0
  34. qa_engine/logger.py +14 -0
  35. qa_engine/mocks.py +41 -0
  36. qa_engine/qa_engine.py +286 -0
  37. qa_engine/response.py +33 -0
  38. questions.txt +9 -0
  39. requirements.txt +28 -0
  40. run_docker.sh +2 -0
  41. run_tests.sh +2 -0
  42. tests/__init__.py +0 -0
  43. tests/discord_bot/__init__.py +0 -0
  44. tests/discord_bot/client/__init__.py +0 -0
  45. tests/discord_bot/client/lorem_ipsum.txt +14 -0
  46. tests/discord_bot/client/test_utils.py +69 -0
  47. tests/index/test_index.py +48 -0
  48. tests/qa_engine/__init__.py +0 -0
  49. tests/qa_engine/test_response.py +30 -0
.github/workflows/tests-integration.yml ADDED
@@ -0,0 +1,37 @@
+ name: CI
+
+ on:
+   push:
+     branches:
+       - main
+       - feature/**
+       - issue/**
+
+ jobs:
+   build:
+     strategy:
+       matrix:
+         python-version: [3.11]
+     runs-on: ubuntu-latest
+
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v2
+         with:
+           fetch-depth: 0
+
+       - name: Switch to Current Branch
+         run: git checkout ${{ env.BRANCH }}
+
+       - name: Set up Python ${{ matrix.python-version }}
+         uses: actions/setup-python@v1
+         with:
+           python-version: ${{ matrix.python-version }}
+
+       - name: Install dependencies
+         run: |
+           pip install --no-cache-dir -r requirements.txt
+           cp config/.env.example config/.env
+       - name: Run unit tests
+         run: |
+           pytest -o "testpaths=tests" --noconftest
.gitignore ADDED
@@ -0,0 +1,68 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+
+ # OS
+ .DS_Store
+
+ # IDE
+ .vscode
+
+ # Project
+ wandb/
+ indexes/
+ *.out
+ *.log
+
+ # Data
+ data/datasets/*
+ !data/datasets/hf_repositories_urls.json
+ !data/datasets/hf_repositories_urls_scraped.json
+
+ # Local models
+ qa_engine/local_models
Dockerfile.api ADDED
@@ -0,0 +1,19 @@
+ FROM ubuntu:latest
+
+ ENV DEBIAN_FRONTEND=noninteractive
+
+ RUN apt-get -y update && \
+     apt-get -y upgrade && \
+     apt-get -y install git python3.10 python3-pip
+
+ COPY requirements.txt .
+ RUN pip install --upgrade pip && \
+     pip install --no-cache-dir -r requirements.txt
+
+ WORKDIR /hugging-face-qa-bot
+ COPY config/api/ config/api/
+ COPY api/ api/
+
+ EXPOSE 8000
+
+ ENTRYPOINT [ "python3", "-m", "api" ]
Dockerfile.bot ADDED
@@ -0,0 +1,17 @@
+ FROM ubuntu:latest
+
+ ENV DEBIAN_FRONTEND=noninteractive
+
+ RUN apt-get -y update && \
+     apt-get -y upgrade && \
+     apt-get -y install git python3.10 python3-pip
+
+ COPY requirements.txt .
+ RUN pip install --upgrade pip && \
+     pip install --no-cache-dir -r requirements.txt
+
+ WORKDIR /hugging-face-qa-bot
+ COPY config/bot/ config/bot/
+ COPY bot/ bot/
+
+ ENTRYPOINT [ "python3", "-m", "bot" ]
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
api/__init__.py ADDED
File without changes
api/__main__.py ADDED
@@ -0,0 +1,39 @@
+ import uvicorn
+ from fastapi import FastAPI
+
+ from qa_engine import logger, Config, QAEngine
+
+
+ config = Config()
+ app = FastAPI()
+ qa_engine = QAEngine(
+     llm_model_id=config.question_answering_model_id,
+     embedding_model_id=config.embedding_model_id,
+     index_repo_id=config.index_repo_id,
+     prompt_template=config.prompt_template,
+     use_docs_for_context=config.use_docs_for_context,
+     num_relevant_docs=config.num_relevant_docs,
+     add_sources_to_response=config.add_sources_to_response,
+     use_messages_for_context=config.use_messages_in_context,
+     debug=config.debug
+ )
+
+
+ @app.get('/')
+ def get_answer(question: str, messages_context: str = ''):
+     logger.info(
+         f'Received request with question: {question} ' \
+         f'and context: {messages_context}'
+     )
+     response = qa_engine.get_response(
+         question=question,
+         messages_context=messages_context
+     )
+     return {
+         'answer': response.get_answer(),
+         'sources': response.get_sources_as_text()
+     }
+
+
+ if __name__ == '__main__':
+     uvicorn.run(app, host='0.0.0.0', port=8000)
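For reference, a minimal sketch of how a client could query this endpoint once the server is running; the URL and port mirror the `uvicorn.run` call above, and the response keys come from the `get_answer` handler (the question text itself is just an illustration):

```python
import requests

# Assumes the API from api/__main__.py is running locally on port 8000.
response = requests.get(
    'http://localhost:8000/',
    params={
        'question': 'How do I load a dataset with the datasets library?',
        'messages_context': '',  # optional prior-conversation context
    },
)
payload = response.json()
print(payload['answer'])   # generated answer text
print(payload['sources'])  # formatted source links, if enabled
```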
app.py ADDED
@@ -0,0 +1,33 @@
+ import gradio as gr
+
+ from qa_engine import logger, Config, QAEngine
+
+
+ config = Config()
+ model = QAEngine(
+     llm_model_id=config.question_answering_model_id,
+     embedding_model_id=config.embedding_model_id,
+     index_repo_id=config.index_repo_id,
+     prompt_template=config.prompt_template,
+     use_docs_for_context=config.use_docs_for_context,
+     add_sources_to_response=config.add_sources_to_response,
+     use_messages_for_context=config.use_messages_in_context,
+     debug=config.debug
+ )
+
+ with gr.Blocks() as demo:
+     chatbot = gr.Chatbot()
+     msg = gr.Textbox()
+     clear = gr.ClearButton([msg, chatbot])
+
+     def respond(message, chat_history):
+         context = ''.join(f'User: {user_msg} \nBot: {bot_msg}\n' for user_msg, bot_msg in chat_history)
+         logger.info(f'Context: {context}')
+         response = model.get_response(message, context)
+         bot_message = response.get_answer() + response.get_sources_as_text() + '\n'
+         chat_history.append((message, bot_message))
+         return '', chat_history
+
+     msg.submit(respond, [msg, chatbot], [msg, chatbot])
+
+ demo.launch(share=True)
assets/example.png ADDED
benchmarker.py ADDED
@@ -0,0 +1,63 @@
+ import json
+ import time
+
+ import wandb
+
+ from qa_engine import Config, QAEngine
+
+
+ config = Config()
+ model = QAEngine(
+     llm_model_id=config.question_answering_model_id,
+     embedding_model_id=config.embedding_model_id,
+     index_repo_id=config.index_repo_id,
+     prompt_template=config.prompt_template,
+     use_docs_for_context=config.use_docs_for_context,
+     add_sources_to_response=config.add_sources_to_response,
+     use_messages_for_context=config.use_messages_in_context,
+     debug=config.debug
+ )
+
+ QUESTIONS_FILENAME = 'data/benchmark/questions.json'
+ ANSWERS_FILENAME = 'data/benchmark/answers.json'
+
+
+ def main():
+     benchmark_name = \
+         f'model: {config.question_answering_model_id} ' \
+         f'index: {config.index_repo_id}'
+
+     # log the benchmark name and full config to wandb
+     wandb.init(
+         project='HF-Docs-QA',
+         name=benchmark_name,
+         mode='run',  # run/disabled
+         config=config.asdict()
+     )
+
+     with open(QUESTIONS_FILENAME, 'r') as f:
+         questions = json.load(f)
+
+     answers = []
+     for q in questions:
+         question = q['question']
+         messages_context = q['messages_context']
+
+         t_start = time.perf_counter()
+         response = model.get_response(
+             question=question,
+             messages_context=messages_context
+         )
+         t_end = time.perf_counter()
+         answers.append({
+             'answer': response.get_answer(),
+             'sources': response.get_sources_as_text(),
+             'time': t_end - t_start
+         })
+
+     with open(ANSWERS_FILENAME, 'w') as f:
+         json.dump(answers, f, indent=4)
+
+
+ if __name__ == '__main__':
+     main()
config/.env.example ADDED
@@ -0,0 +1,16 @@
+ # QA engine settings
+ QUESTION_ANSWERING_MODEL_ID=hf-question-answering-model-ID
+ EMBEDDING_MODEL_ID=hf-embedding-model-ID
+ INDEX_REPO_ID=hf-index-ID
+ PROMPT_TEMPLATE_NAME=llama
+ USE_DOCS_FOR_CONTEXT=True
+ NUM_RELEVANT_DOCS=4
+ ADD_SOURCES_TO_RESPONSE=True
+ USE_MESSAGES_IN_CONTEXT=True
+ DEBUG=True
+
+ # Discord settings
+ DISCORD_TOKEN=your-bot-token
+ NUM_LAST_MESSAGES=1
+ USE_NAMES_IN_CONTEXT=False
+ ENABLE_COMMANDS=True
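These variables are consumed through `python-dotenv`: `qa_engine/__init__.py` (below) calls `load_dotenv` on `config/.env`, and `qa_engine/config.py` reads each value with `os.getenv`. A minimal sketch of that flow, assuming the example file has been copied to `config/.env` as the CI workflow does:

```python
import os
from dotenv import load_dotenv

# Mirrors qa_engine/__init__.py; the CI step runs
# `cp config/.env.example config/.env` beforehand.
load_dotenv(dotenv_path='config/.env')

print(os.getenv('QUESTION_ANSWERING_MODEL_ID'))  # 'hf-question-answering-model-ID'
print(os.getenv('USE_DOCS_FOR_CONTEXT'))         # 'True' -- still a string until parsed
```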
config/prompt_templates/llama.txt ADDED
@@ -0,0 +1,7 @@
+ ### Instruction:
+ Give an answer that contains all the necessary information for the question.
+ If the context contains the necessary information to answer the question, use it to generate an appropriate response.
+ {context}
+ ### Input:
+ {question}
+ ### Response:
config/prompt_templates/llama2.txt ADDED
@@ -0,0 +1,6 @@
+ <<SYS>>
+ **CONTEXT FOR THE QUESTION**:
+ {context}
+ <</SYS>>
+
+ [INST] Respond to the user question as factually as possible, using the given context. [/INST] User: {question}
config/prompt_templates/sythia_v1.3.txt ADDED
@@ -0,0 +1,3 @@
+ SYSTEM: {context}
+ USER: {question}
+ ASSISTANT:
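Each of these templates exposes `{context}` and `{question}` placeholders, which `Config.__post_init__` validates and the engine fills via `str.format` (the same substitution `models/inference.ipynb` performs). A small sketch:

```python
# Fill a prompt template the way the inference notebook does.
with open('config/prompt_templates/llama.txt') as f:
    template = f.read()

prompt = template.format(
    context='Retrieved documentation snippets would go here.',
    question='How do I create a pipeline object?',
)
print(prompt)
```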
data/benchmark/.gitkeep ADDED
File without changes
data/datasets/hf_repositories_urls.json ADDED
@@ -0,0 +1,22 @@
+ {
+     "urls": [
+         "https://github.com/huggingface/transformers",
+         "https://github.com/huggingface/diffusers",
+         "https://github.com/huggingface/datasets",
+         "https://github.com/gradio-app/gradio",
+         "https://github.com/huggingface/huggingface_hub",
+         "https://github.com/huggingface/peft",
+         "https://github.com/huggingface/blog",
+         "https://github.com/huggingface/optimum",
+         "https://github.com/huggingface/tokenizers",
+         "https://github.com/huggingface/course",
+         "https://github.com/huggingface/deep-rl-class",
+         "https://github.com/huggingface/evaluate",
+         "https://github.com/huggingface/datasets-server",
+         "https://github.com/huggingface/simulate",
+         "https://github.com/huggingface/hub-docs",
+         "https://github.com/huggingface/pytorch-image-models",
+         "https://github.com/huggingface/safetensors",
+         "https://github.com/huggingface/hf-endpoints-documentation"
+     ]
+ }
data/datasets/hf_repositories_urls_scraped.json ADDED
@@ -0,0 +1,92 @@
+ {
+     "urls": [
+         "https://github.com/huggingface/tokenizers",
+         "https://github.com/huggingface/datablations",
+         "https://github.com/huggingface/peft",
+         "https://github.com/huggingface/tflite-android-transformers",
+         "https://github.com/huggingface/simulate",
+         "https://github.com/huggingface/transformers",
+         "https://github.com/huggingface/deep-rl-class",
+         "https://github.com/huggingface/awesome-huggingface",
+         "https://github.com/huggingface/datasets-server",
+         "https://github.com/huggingface/setfit",
+         "https://github.com/huggingface/olm-training",
+         "https://github.com/huggingface/huggingface_sb3",
+         "https://github.com/huggingface/optimum-neuron",
+         "https://github.com/huggingface/blog",
+         "https://github.com/huggingface/100-times-faster-nlp",
+         "https://github.com/huggingface/bloom-jax-inference",
+         "https://github.com/huggingface/speechbox",
+         "https://github.com/huggingface/olm-datasets",
+         "https://github.com/huggingface/hub-docs",
+         "https://github.com/huggingface/torchMoji",
+         "https://github.com/huggingface/hffs",
+         "https://github.com/huggingface/trl",
+         "https://github.com/huggingface/text-generation-inference",
+         "https://github.com/huggingface/Mongoku",
+         "https://github.com/huggingface/education-toolkit",
+         "https://github.com/huggingface/datasets",
+         "https://github.com/huggingface/optimum-benchmark",
+         "https://github.com/huggingface/course",
+         "https://github.com/huggingface/accelerate",
+         "https://github.com/huggingface/pytorch-image-models",
+         "https://github.com/huggingface/fuego",
+         "https://github.com/huggingface/diffusion-models-class",
+         "https://github.com/huggingface/disaggregators",
+         "https://github.com/huggingface/unity-api",
+         "https://github.com/huggingface/workshops",
+         "https://github.com/huggingface/llm-ls",
+         "https://github.com/huggingface/llm-vscode",
+         "https://github.com/huggingface/community-events",
+         "https://github.com/huggingface/tune",
+         "https://github.com/huggingface/candle",
+         "https://github.com/huggingface/paper-style-guide",
+         "https://github.com/huggingface/huggingface.js",
+         "https://github.com/huggingface/neuralcoref",
+         "https://github.com/huggingface/hfapi",
+         "https://github.com/huggingface/data-measurements-tool",
+         "https://github.com/huggingface/personas",
+         "https://github.com/huggingface/instruction-tuned-sd",
+         "https://github.com/huggingface/swift-transformers",
+         "https://github.com/huggingface/api-inference-community",
+         "https://github.com/huggingface/diffusers",
+         "https://github.com/huggingface/safetensors",
+         "https://github.com/huggingface/optimum-graphcore",
+         "https://github.com/huggingface/OBELICS",
+         "https://github.com/huggingface/swift-coreml-diffusers",
+         "https://github.com/huggingface/naacl_transfer_learning_tutorial",
+         "https://github.com/huggingface/nn_pruning",
+         "https://github.com/huggingface/awesome-papers",
+         "https://github.com/huggingface/optimum-intel",
+         "https://github.com/huggingface/autotrain-advanced",
+         "https://github.com/huggingface/pytorch-openai-transformer-lm",
+         "https://github.com/huggingface/node-question-answering",
+         "https://github.com/huggingface/optimum",
+         "https://github.com/huggingface/knockknock",
+         "https://github.com/huggingface/optimum-habana",
+         "https://github.com/huggingface/transfer-learning-conv-ai",
+         "https://github.com/huggingface/notebooks",
+         "https://github.com/huggingface/hmtl",
+         "https://github.com/huggingface/block_movement_pruning",
+         "https://github.com/huggingface/huggingface_hub",
+         "https://github.com/huggingface/transformers-bloom-inference",
+         "https://github.com/huggingface/hf_transfer",
+         "https://github.com/huggingface/doc-builder",
+         "https://github.com/huggingface/large_language_model_training_playbook",
+         "https://github.com/huggingface/that_is_good_data",
+         "https://github.com/huggingface/swift-coreml-transformers",
+         "https://github.com/huggingface/datasets-viewer",
+         "https://github.com/huggingface/open-muse",
+         "https://github.com/huggingface/evaluate",
+         "https://github.com/huggingface/llm_training_handbook",
+         "https://github.com/huggingface/pytorch_block_sparse",
+         "https://github.com/huggingface/chat-ui",
+         "https://github.com/huggingface/llm.nvim",
+         "https://github.com/huggingface/swift-chat",
+         "https://github.com/huggingface/pytorch-pretrained-BigGAN",
+         "https://github.com/huggingface/exporters",
+         "https://github.com/huggingface/audio-transformers-course",
+         "https://github.com/huggingface/hf-endpoints-documentation",
+         "https://github.com/gradio-app/gradio"
+     ]
+ }
data/get_hf_repositories_urls.py ADDED
@@ -0,0 +1,49 @@
+ import json
+ import argparse
+ import requests
+ from typing import List
+
+
+ def get_repositories_names(token: str, min_stars: int) -> List[str]:
+     repos_per_page = 100
+     repo_names = []
+     i = 0
+     while True:
+         url = \
+             f'https://api.github.com/orgs/huggingface/repos?' \
+             f'per_page={repos_per_page}&page={i}'
+         headers = {'Authorization': f'token {token}'}
+         response = requests.get(url, headers=headers)
+         if response.status_code == 200:
+             repos = json.loads(response.content)
+             repo_names += [
+                 repo['full_name'] for repo in repos
+                 if repo['stargazers_count'] >= min_stars
+             ]
+             if len(repos) < repos_per_page:
+                 break
+             i += 1
+         else:
+             raise RuntimeError(f'GitHub API request failed with status: {response.status_code}')
+     return list(set(repo_names))
+
+
+ def save_repositories_urls(repositories_names: List[str], output_filename: str):
+     urls = [f'https://github.com/{repo_name}' for repo_name in repositories_names]
+     data = {'urls': urls}
+     with open(output_filename, 'w') as f:
+         json.dump(data, f, indent=4)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--token', type=str)
+     parser.add_argument('--stars', type=str)
+     args = parser.parse_args()
+     repositories = get_repositories_names(token=args.token, min_stars=int(args.stars))
+     repositories += [
+         'huggingface/hf-endpoints-documentation',
+         'gradio-app/gradio'
+     ]
+     print(f'Found {len(repositories)} repositories with at least {args.stars} stars')
+     save_repositories_urls(repositories, 'datasets/hf_repositories_urls_scraped.json')
data/hugging_face_docs_dataset.py ADDED
@@ -0,0 +1,190 @@
+ import glob
+ import json
+ import os
+ import re
+ import subprocess
+ from typing import List
+
+ import requests
+ import pandas as pd
+ from bs4 import BeautifulSoup
+ from markdown import markdown
+ import nbformat
+ from nbconvert import MarkdownExporter
+ from nbconvert.preprocessors import Preprocessor, ClearOutputPreprocessor
+ from tqdm import tqdm
+
+
+ VALIDATE_URLS = False
+
+
+ def download_repositories(repo_urls_file: str, repo_dir: str):
+     """
+     Downloads the Hugging Face repositories.
+     """
+     if not os.path.exists(repo_dir):
+         os.makedirs(repo_dir)
+     with open(repo_urls_file, "r") as f:
+         repositories_urls = json.load(f)["urls"]
+     print(f'Downloading {len(repositories_urls)} repositories')
+     for url in repositories_urls:
+         try:
+             subprocess.run(["git", "clone", url], cwd=repo_dir, check=True)
+         except subprocess.CalledProcessError as e:
+             print("Command failed with error:", e.stderr)
+
+
+ class EmptyCellPreprocessor(Preprocessor):
+     def preprocess_cell(self, cell, resources, index):
+         if cell.source.strip() == '':
+             cell.source = ''
+             cell.cell_type = 'raw'
+         return cell, resources
+
+
+ def convert_notebook_to_txt(filename: str):
+     """
+     Converts a notebook to a markdown file.
+     """
+     with open(filename) as f:
+         notebook = nbformat.read(f, as_version=4)
+     # id validation error fix
+     for cell in notebook['cells']:
+         cell['id'] = str(cell['id'])
+
+     clear_output = ClearOutputPreprocessor()
+     notebook, resources = clear_output.preprocess(notebook, {})
+
+     exporter = MarkdownExporter()
+     exporter.register_preprocessor(EmptyCellPreprocessor, enabled=True)
+     output_notebook_text, resources = exporter.from_notebook_node(notebook)
+
+     new_filename = filename.replace('.ipynb', '_ipynb.md')
+     with open(new_filename, 'w') as f:
+         f.write(output_notebook_text)
+     return new_filename
+
+
+ def extract_files_from_directories(
+     repo_urls_file: str,
+     repo_dir: str,
+     docs_dir: str,
+     files_extensions: List[str]
+ ) -> None:
+     """
+     This function reads markdown and markdownx files from the repositories directory,
+     filters out non-English files, and adds the source GitHub URL as the first line of each file.
+     The resulting files are saved in the docs_dir.
+     """
+     languages = pd.read_csv("language-codes.csv").loc[:, "alpha2"].tolist()
+     languages.remove("en")
+
+     files = [
+         filename
+         for extension in files_extensions
+         for filename in glob.glob(repo_dir + f"**/*{extension}", recursive=True)
+     ]
+     print(f'Used extensions: {", ".join(files_extensions)}')
+     print(f'Found {len(files)} files')
+
+     repo_urls = []
+     with open(repo_urls_file, "r") as f:
+         repo_urls = json.load(f)["urls"]
+
+     # filter out the files that are not in English
+     filtered_files = []
+     for filename in files:
+         sep_file = filename.split("/")
+         for seq in sep_file:
+             if seq in languages:
+                 break
+         else:
+             filtered_files.append(filename)
+     print(f'Found {len(filtered_files)} files in English')
+
+     # generate a GitHub URL for a file based on its name and a list of possible repository URLs
+     def get_github_url(filename: str, repo_urls: str, repo_dir: str) -> str:
+         source = filename.replace(repo_dir, '')
+         repo_name, file_path = source.split('/', 1)
+         repo_url_prefix = None
+         for repo_url in repo_urls:
+             if repo_name == repo_url.split('/')[-1]:
+                 repo_url_prefix = repo_url
+                 break
+         if not repo_url_prefix:
+             raise ValueError(f"Repo URL not found for {repo_name}")
+         url = f'{repo_url_prefix}/blob/main/{file_path}'
+         if VALIDATE_URLS:
+             try:
+                 response = requests.get(url)
+                 response.raise_for_status()
+             except Exception:
+                 print(f'filename: {filename}')
+                 print(f'repo: {repo_name}, file: {file_path}')
+                 print(f'url: {url}')
+                 raise
+         return url
+
+     # creates a valid filename by replacing certain characters and removing the repo_dir path
+     def create_filename_from_path(filename: str, repo_dir: str) -> str:
+         filename = filename.replace(repo_dir, '')
+         chars_to_replace = ['/', '{', '}', '-', '.']
+         filename = ''.join(['_' if c in chars_to_replace else c for c in filename])
+         return filename
+
+     # copy the files with the source added in the first line
+     if not os.path.exists(docs_dir):
+         os.makedirs(docs_dir)
+     copied_files = []
+     for filename in tqdm(filtered_files):
+         source_url = get_github_url(filename, repo_urls, repo_dir)
+         data = f"source: {source_url}\n\n"
+         # convert jupyter notebooks to txt files
+         try:
+             if filename.endswith('.ipynb'):
+                 filename = convert_notebook_to_txt(filename)
+             # rename and copy files
+             with open(filename, 'r') as f:
+                 data += f.read()
+             output_filename = docs_dir + create_filename_from_path(filename, repo_dir)
+             with open(output_filename, 'w') as f:
+                 f.write(data)
+             if not os.path.isfile(output_filename):
+                 raise ValueError(f"Failed to create the output file: {output_filename}")
+             copied_files.append(output_filename)
+         except Exception as ex:
+             print(f'Failed to copy file {filename}: {ex}')
+
+     print(f'Successfully copied {len(set(copied_files))}/{len(filtered_files)} files')
+
+
+ def markdown_cleaner(data: str):
+     """
+     Clean markdown text.
+
+     Args:
+         data (str): The markdown text to be cleaned.
+
+     Returns:
+         str: The cleaned markdown text.
+     """
+     soupped = BeautifulSoup(markdown(data), "html.parser")
+     raw_text = ''.join(soupped.findAll(string=True))
+     clean_text = re.sub(r"<!--.*?-->", "", raw_text, flags=re.DOTALL)
+     # remove any special tokens, e.g. <|endoftext|>
+     clean_text = re.sub(r"<\|endoftext\|>", "", clean_text, flags=re.DOTALL)
+     # discard non-alphanumeric characters
+     clean_text = re.sub(r"[^a-zA-Z0-9\s]", "", clean_text, flags=re.DOTALL)
+     return "\n".join([t for t in clean_text.split("\n") if t])
+
+
+ if __name__ == '__main__':
+     repo_urls_file = "./datasets/hf_repositories_urls.json"
+     repo_dir = "./datasets/huggingface_repositories/"
+     docs_dir = "./datasets/huggingface_docs/"
+     download_repositories(repo_urls_file, repo_dir)
+     extract_files_from_directories(
+         repo_urls_file, repo_dir, docs_dir,
+         files_extensions=['.md', '.mdx', '.ipynb']
+     )
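The language filter in `extract_files_from_directories` relies on Python's `for`/`else`: the `else` branch runs only if the inner loop finishes without hitting `break`, i.e. when no path segment matches a non-English language code. An isolated sketch of that idiom (the paths and codes here are illustrative):

```python
languages = ['de', 'fr', 'ja']  # non-English codes, as loaded from language-codes.csv

for filename in ['docs/de/quicktour.md', 'docs/quicktour.md']:
    for segment in filename.split('/'):
        if segment in languages:
            break  # a language directory was found -> skip this file
    else:
        print(f'{filename} kept as English')  # prints only for 'docs/quicktour.md'
```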
data/indexer.ipynb ADDED
@@ -0,0 +1,238 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import math\n",
+     "import numpy as np\n",
+     "from pathlib import Path\n",
+     "from tqdm import tqdm\n",
+     "from typing import List, Any\n",
+     "from langchain.chains import RetrievalQA\n",
+     "from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings\n",
+     "from langchain.document_loaders import TextLoader\n",
+     "from langchain.indexes import VectorstoreIndexCreator\n",
+     "from langchain.text_splitter import CharacterTextSplitter\n",
+     "from langchain.vectorstores import FAISS"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "docs = []\n",
+     "metadata = []\n",
+     "for p in Path(\"./datasets/huggingface_docs/\").iterdir():\n",
+     "    if not p.is_dir():\n",
+     "        with open(p) as f:\n",
+     "            # the first line is the source of the text\n",
+     "            source = f.readline().strip().replace('source: ', '')\n",
+     "            docs.append(f.read())\n",
+     "            metadata.append({\"source\": source})\n",
+     "    # break\n",
+     "\n",
+     "print(f'number of documents: {len(docs)}')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "chunk_size = 512\n",
+     "text_splitter = CharacterTextSplitter(\n",
+     "    separator=\"\",\n",
+     "    chunk_size=chunk_size,\n",
+     "    chunk_overlap=100,\n",
+     "    length_function=len,\n",
+     ")\n",
+     "docs = text_splitter.create_documents(docs, metadata)\n",
+     "print(f'number of chunks: {len(docs)}')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "model_name = \"hkunlp/instructor-large\"\n",
+     "embed_instruction = \"Represent the Hugging Face library documentation\"\n",
+     "query_instruction = \"Query the most relevant piece of information from the Hugging Face documentation\"\n",
+     "\n",
+     "embedding_model = HuggingFaceInstructEmbeddings(\n",
+     "    model_name=model_name,\n",
+     "    embed_instruction=embed_instruction,\n",
+     "    query_instruction=query_instruction,\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "class AverageInstructEmbeddings(HuggingFaceInstructEmbeddings):\n",
+     "    max_length: int = None\n",
+     "\n",
+     "    def __init__(self, max_length: int = 512, **kwargs: Any):\n",
+     "        super().__init__(**kwargs)\n",
+     "        self.max_length = max_length\n",
+     "        if self.max_length < 0:\n",
+     "            print('max_length is not specified, using model default max_seq_length')\n",
+     "\n",
+     "    def embed_documents(self, texts: List[str]) -> List[List[float]]:\n",
+     "        all_embeddings = []\n",
+     "        for text in tqdm(texts, desc=\"Embedding documents\"):\n",
+     "            if len(text) > self.max_length and self.max_length > -1:\n",
+     "                n_chunks = math.ceil(len(text)/self.max_length)\n",
+     "                chunks = [\n",
+     "                    text[i*self.max_length:(i+1)*self.max_length]\n",
+     "                    for i in range(n_chunks)\n",
+     "                ]\n",
+     "                instruction_pairs = [[self.embed_instruction, chunk] for chunk in chunks]\n",
+     "                chunk_embeddings = self.client.encode(instruction_pairs)\n",
+     "                avg_embedding = np.mean(chunk_embeddings, axis=0)\n",
+     "                all_embeddings.append(avg_embedding.tolist())\n",
+     "            else:\n",
+     "                instruction_pairs = [[self.embed_instruction, text]]\n",
+     "                embeddings = self.client.encode(instruction_pairs)\n",
+     "                all_embeddings.append(embeddings[0].tolist())\n",
+     "\n",
+     "        return all_embeddings\n",
+     "\n",
+     "\n",
+     "# max length fed to the model, if longer than max then chunks + averaging\n",
+     "max_length = 512\n",
+     "embedding_model = AverageInstructEmbeddings(\n",
+     "    model_name=model_name,\n",
+     "    embed_instruction=embed_instruction,\n",
+     "    query_instruction=query_instruction,\n",
+     "    max_length=max_length,\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "embeddings = embedding_model.embed_documents(texts=[d.page_content for d in docs[:10]])"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "index = FAISS.from_documents(docs, embedding_model)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "index_name = f'index-{model_name}-{chunk_size}-m{max_length}-notebooks'\n",
+     "index_name"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "index.save_local(f'../indexes/{index_name}/')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "index = FAISS.load_local(f'../indexes/{index_name}/', embedding_model)\n",
+     "docs = index.similarity_search(query='how to create a pipeline object?', k=5)\n",
+     "docs[0].page_content\n",
+     "docs[0].metadata"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "for i, doc in enumerate(docs, start=1):\n",
+     "    print(f\"\\n{'='*100}\\n\")\n",
+     "    print(f\"Document {i} of {len(docs)}\")\n",
+     "    print(\"Page Content:\")\n",
+     "    print(f\"\\n{'-'*100}\\n\")\n",
+     "    print(doc.page_content, '\\n')\n",
+     "    print(doc.metadata)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from huggingface_hub import HfApi\n",
+     "\n",
+     "api = HfApi()\n",
+     "api.create_repo(\n",
+     "    repo_id=f'KonradSzafer/{index_name}',\n",
+     "    repo_type='dataset',\n",
+     "    private=False,\n",
+     "    exist_ok=True\n",
+     ")\n",
+     "api.upload_folder(\n",
+     "    folder_path=f'../indexes/{index_name}',\n",
+     "    repo_id=f'KonradSzafer/{index_name}',\n",
+     "    repo_type='dataset',\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "hf_qa_bot",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.12"
+   },
+   "orig_nbformat": 4
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
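The `AverageInstructEmbeddings` subclass above handles texts longer than `max_length` by embedding fixed-size character chunks and averaging the resulting vectors into a single document embedding. The core of that reduction, isolated with plain numpy (the random 768-dimensional stand-in for `self.client.encode` is illustrative):

```python
import math
import numpy as np

max_length = 512
text = 'x' * 1300  # a document longer than max_length

n_chunks = math.ceil(len(text) / max_length)
chunks = [text[i * max_length:(i + 1) * max_length] for i in range(n_chunks)]

# stand-in for self.client.encode: one vector per chunk
chunk_embeddings = np.random.rand(len(chunks), 768)
avg_embedding = chunk_embeddings.mean(axis=0)  # single vector for the whole document
print(avg_embedding.shape)  # (768,)
```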
data/language-codes.csv ADDED
@@ -0,0 +1,190 @@
+ alpha2,English
+ aa,Afar
+ ab,Abkhazian
+ ae,Avestan
+ af,Afrikaans
+ ak,Akan
+ am,Amharic
+ an,Aragonese
+ ar,Arabic
+ as,Assamese
+ av,Avaric
+ ay,Aymara
+ az,Azerbaijani
+ ba,Bashkir
+ be,Belarusian
+ bg,Bulgarian
+ bh,Bihari languages
+ bi,Bislama
+ bm,Bambara
+ bn,Bengali
+ bo,Tibetan
+ br,Breton
+ bs,Bosnian
+ ca,Catalan; Valencian
+ ce,Chechen
+ ch,Chamorro
+ co,Corsican
+ cr,Cree
+ cs,Czech
+ cu,Church Slavic; Old Slavonic; Church Slavonic; Old Bulgarian; Old Church Slavonic
+ cv,Chuvash
+ cy,Welsh
+ da,Danish
+ de,German
+ dv,Divehi; Dhivehi; Maldivian
+ dz,Dzongkha
+ ee,Ewe
+ el,"Greek, Modern (1453-)"
+ en,English
+ eo,Esperanto
+ es,Spanish; Castilian
+ et,Estonian
+ eu,Basque
+ fa,Persian
+ ff,Fulah
+ fi,Finnish
+ fj,Fijian
+ fo,Faroese
+ fr,French
+ fy,Western Frisian
+ ga,Irish
+ gd,Gaelic; Scottish Gaelic
+ gj,Gujarati
+ gl,Galician
+ gn,Guarani
+ gu,Gujarati
+ gv,Manx
+ ha,Hausa
+ hd,Hindi
+ he,Hebrew
+ hi,Hindi
+ ho,Hiri Motu
+ hr,Croatian
+ ht,Haitian; Haitian Creole
+ hu,Hungarian
+ hy,Armenian
+ hz,Herero
+ ia,Interlingua (International Auxiliary Language Association)
+ id,Indonesian
+ ie,Interlingue; Occidental
+ ig,Igbo
+ ii,Sichuan Yi; Nuosu
+ ik,Inupiaq
+ io,Ido
+ is,Icelandic
+ it,Italian
+ iu,Inuktitut
+ ja,Japanese
+ jv,Javanese
+ ka,Georgian
+ kg,Kongo
+ ki,Kikuyu; Gikuyu
+ kj,Kuanyama; Kwanyama
+ kk,Kazakh
+ kl,Kalaallisut; Greenlandic
+ km,Central Khmer
+ kn,Kannada
+ ko,Korean
+ kr,Kanuri
+ ks,Kashmiri
+ ku,Kurdish
+ kv,Komi
+ kw,Cornish
+ ky,Kirghiz; Kyrgyz
+ la,Latin
+ lb,Luxembourgish; Letzeburgesch
+ lg,Ganda
+ li,Limburgan; Limburger; Limburgish
+ ln,Lingala
+ lo,Lao
+ lt,Lithuanian
+ lu,Luba-Katanga
+ lv,Latvian
+ mg,Malagasy
+ mh,Marshallese
+ mi,Maori
+ mk,Macedonian
+ ml,Malayalam
+ mn,Mongolian
+ mr,Marathi
+ ms,Malay
+ mt,Maltese
+ my,Burmese
+ na,Nauru
+ nb,"Bokmål, Norwegian; Norwegian Bokmål"
+ nd,"Ndebele, North; North Ndebele"
+ ne,Nepali
+ ng,Ndonga
+ nl,Dutch; Flemish
+ nn,"Norwegian Nynorsk; Nynorsk, Norwegian"
+ no,Norwegian
+ nr,"Ndebele, South; South Ndebele"
+ nv,Navajo; Navaho
+ ny,Chichewa; Chewa; Nyanja
+ oc,Occitan (post 1500)
+ oj,Ojibwa
+ om,Oromo
+ or,Oriya
+ os,Ossetian; Ossetic
+ pa,Panjabi; Punjabi
+ pi,Pali
+ pl,Polish
+ ps,Pushto; Pashto
+ pt,Portuguese
+ qu,Quechua
+ rm,Romansh
+ rn,Rundi
+ ro,Romanian; Moldavian; Moldovan
+ ru,Russian
+ rw,Kinyarwanda
+ sa,Sanskrit
+ sc,Sardinian
+ sd,Sindhi
+ se,Northern Sami
+ sg,Sango
+ si,Sinhala; Sinhalese
+ sk,Slovak
+ sl,Slovenian
+ sm,Samoan
+ sn,Shona
+ so,Somali
+ sq,Albanian
+ sr,Serbian
+ ss,Swati
+ st,"Sotho, Southern"
+ su,Sundanese
+ sv,Swedish
+ sw,Swahili
+ ta,Tamil
+ te,Telugu
+ tg,Tajik
+ th,Thai
+ ti,Tigrinya
+ tk,Turkmen
+ tl,Tagalog
+ tn,Tswana
+ to,Tonga (Tonga Islands)
+ tr,Turkish
+ ts,Tsonga
+ tt,Tatar
+ tw,Twi
+ ty,Tahitian
+ ug,Uighur; Uyghur
+ uk,Ukrainian
+ ur,Urdu
+ uz,Uzbek
+ ve,Venda
+ vi,Vietnamese
+ vo,Volapük
+ wa,Walloon
+ wo,Wolof
+ xh,Xhosa
+ yi,Yiddish
+ yo,Yoruba
+ za,Zhuang; Chuang
+ zh,Chinese; General
+ zh-CN,Chinese; Simplified
+ zh-TW,Chinese; Traditional
+ zh-hans,Chinese; Simplified
+ zu,Zulu
data/scrapers/stack_overflow_scraper.py ADDED
@@ -0,0 +1,91 @@
+ import re
+ import csv
+ import time
+ import requests
+ from typing import List
+ import pandas as pd
+ from tqdm import tqdm
+ from bs4 import BeautifulSoup
+
+
+ def scrape_question_with_answers(question_url: str) -> List[str]:
+     url = 'https://stackoverflow.com/' + question_url
+     response = requests.get(url)
+     soup = BeautifulSoup(response.content, 'html.parser')
+
+     title = soup.find('title').text.replace(' - Stack Overflow', '')
+     question_div = soup.find('div', {'class': 'postcell post-layout--right'})
+     question = question_div.find('p').text
+     answers_div = soup.find('div', {'class': 'answercell post-layout--right'})
+     answer = answers_div.find('div', {'class': 's-prose js-post-body'}).text
+     return [title, question, answer, url]
+
+
+ def scrape_questions_page(url: str, min_votes: int, min_answers: int) -> List[List[str]]:
+     response = requests.get(url)
+     soup = BeautifulSoup(response.content, 'html.parser')
+     posts_summaries = soup.find_all('div', {'class': 's-post-summary js-post-summary'})
+
+     qa_data = []
+     for summary in posts_summaries:
+         stats_div = summary.find('div', {'class': 's-post-summary--stats'})
+         vote_div = stats_div.find('div', {
+             'class': 's-post-summary--stats-item s-post-summary--stats-item__emphasized',
+             'title': re.compile(r'^Score of \d+$')})
+         if vote_div:
+             vote_number = int(vote_div.find('span', {'class': 's-post-summary--stats-item-number'}).text)
+         else:
+             vote_number = 0
+         answer_div = stats_div.find('div', {
+             'class': 's-post-summary--stats-item',
+             'title': re.compile(r'^\d+ answers$')})
+         if answer_div:
+             answer_number = int(answer_div.find('span', {'class': 's-post-summary--stats-item-number'}).text)
+         else:
+             answer_number = 0
+
+         question_href = summary.find('a', {'class': 's-link'})['href']
+         if vote_number >= min_votes and answer_number >= min_answers:
+             try:
+                 qa_data.append(scrape_question_with_answers(question_href))
+             except Exception as error:
+                 print(error)
+
+         time.sleep(1.5)
+     return qa_data
+
+
+ def crawl_and_save_qa(
+     filename: str,
+     base_url: str,
+     start_page: int,
+     n_pages: int = 10,
+     min_votes: int = 1,
+     min_answers: int = 1
+ ):
+     with open(filename, 'a', newline='') as f:
+         writer = csv.writer(f)
+         if start_page == 1:
+             writer.writerow(['title', 'question', 'answer', 'url'])
+         for page_num in tqdm(range(start_page, start_page + n_pages)):
+             page_data = scrape_questions_page(
+                 base_url.format(page_num),
+                 min_votes,
+                 min_answers
+             )
+             if page_data:
+                 for qa_data in page_data:
+                     writer.writerow(qa_data)
+
+
+ if __name__ == '__main__':
+     filename = '../datasets/stackoverflow_linux.csv'
+     url = 'https://stackoverflow.com/questions/tagged/linux?tab=votes&page={}&pagesize=15'
+     crawl_and_save_qa(
+         filename=filename,
+         base_url=url,
+         start_page=21,
+         n_pages=10,
+         min_votes=1,
+         min_answers=1
+     )
data/stackoverflow_python_dataset.py ADDED
@@ -0,0 +1,55 @@
+ from datetime import datetime
+ from datasets import load_dataset
+ from bs4 import BeautifulSoup
+
+
+ def preprocess_dataset():
+     """
+     Preprocesses the 'koutch/stackoverflow_python' dataset.
+
+     Returns:
+         datasets.arrow_dataset.Dataset: The preprocessed dataset.
+     """
+     dataset = load_dataset('koutch/stackoverflow_python', split='train')
+     dataset = dataset.filter(
+         lambda example:
+             example['question_score'] > 100 and
+             example['answer_score'] > 5 and
+             datetime.strptime(example['answer_date'], '%Y-%m-%dT%H:%M:%SZ').year > 2010
+     )
+
+     def html2text(example):
+         soup = BeautifulSoup(example, 'html.parser')
+         return ''.join(soup.findAll(string=True))
+
+     def transforms(example):
+         example['answer'] = html2text(example['answer_body'])
+         example['question'] = html2text(example['question_body'])
+         return example
+
+     dataset = dataset.map(lambda example: transforms(example))
+     dataset = dataset.remove_columns([
+         'question_score', 'question_date', 'question_id',
+         'answer_date', 'answer_id', 'answer_score', 'tags',
+         'question_body', 'answer_body'
+     ])
+     return dataset
+
+
+ def show_info(dataset):
+     """
+     Print information about the dataset.
+
+     Args:
+         dataset (datasets.arrow_dataset.Dataset): The dataset.
+     """
+     print(dataset.info, '\n')
+     print(f'dataset len: {len(dataset)}')
+     print(f"example question: {dataset[0]['question']}")
+     print(f"example answer: {dataset[0]['answer']}")
+
+
+ if __name__ == '__main__':
+     dataset = preprocess_dataset()
+     dataset.push_to_hub('KonradSzafer/stackoverflow_python_preprocessed', private=False)
+     show_info(dataset)
data/upload_csv_dataset.py ADDED
@@ -0,0 +1,24 @@
+ import sys
+ import pandas as pd
+ from datasets import Dataset, DatasetDict
+ from sklearn.model_selection import train_test_split
+
+
+ def main():
+     dataset_name = sys.argv[1]
+     test_size = float(sys.argv[2]) if len(sys.argv) > 2 else 0.1
+     print(f'dataset: {dataset_name}, test size: {test_size}')
+
+     filename = f'datasets/{dataset_name}.csv'
+     df = pd.read_csv(filename)
+     dataset = Dataset.from_pandas(df)
+     train_dataset, test_dataset = train_test_split(dataset, test_size=test_size)
+     train_dataset = Dataset.from_dict(train_dataset)
+     test_dataset = Dataset.from_dict(test_dataset)
+     dataset_dict = DatasetDict({'train': train_dataset, 'test': test_dataset})
+     dataset_dict.push_to_hub(f'KonradSzafer/{dataset_name}', private=False)
+
+
+ if __name__ == '__main__':
+     main()
discord_bot/__init__.py ADDED
File without changes
discord_bot/__main__.py ADDED
@@ -0,0 +1,28 @@
+ from qa_engine import logger, Config, QAEngine
+ from discord_bot.client import DiscordClient
+
+
+ config = Config()
+ qa_engine = QAEngine(
+     llm_model_id=config.question_answering_model_id,
+     embedding_model_id=config.embedding_model_id,
+     index_repo_id=config.index_repo_id,
+     prompt_template=config.prompt_template,
+     use_docs_for_context=config.use_docs_for_context,
+     num_relevant_docs=config.num_relevant_docs,
+     add_sources_to_response=config.add_sources_to_response,
+     use_messages_for_context=config.use_messages_in_context,
+     debug=config.debug
+ )
+ client = DiscordClient(
+     qa_engine=qa_engine,
+     num_last_messages=config.num_last_messages,
+     use_names_in_context=config.use_names_in_context,
+     enable_commands=config.enable_commands,
+     debug=config.debug
+ )
+
+
+ if __name__ == '__main__':
+     logger.info('Starting Application...')
+     client.run(config.discord_token)
discord_bot/client/__init__.py ADDED
@@ -0,0 +1 @@
+ from .client import DiscordClient
discord_bot/client/client.py ADDED
@@ -0,0 +1,130 @@
+ import discord
+ from typing import List
+
+ from qa_engine import logger, QAEngine
+ from discord_bot.client.utils import split_text_into_chunks
+
+
+ class DiscordClient(discord.Client):
+     """
+     Discord Client class, used for interacting with a Discord server.
+
+     Args:
+         qa_engine (QAEngine): The question answering engine used to generate responses.
+         num_last_messages (int, optional): The number of previous messages to use as context for generating answers.
+             Defaults to 5.
+         use_names_in_context (bool, optional): Whether to include user names in the message context. Defaults to True.
+         enable_commands (bool, optional): Whether to enable commands for the bot. Defaults to True.
+
+     Attributes:
+         qa_engine (QAEngine): The question answering engine used to generate responses.
+         num_last_messages (int): The number of previous messages to use as context for generating answers.
+         use_names_in_context (bool): Whether to include user names in the message context.
+         enable_commands (bool): Whether to enable commands for the bot.
+         min_message_len (int): The minimum length of a message chunk.
+         max_message_len (int): The maximum length of a message.
+     """
+     def __init__(
+         self,
+         qa_engine: QAEngine,
+         num_last_messages: int = 5,
+         use_names_in_context: bool = True,
+         enable_commands: bool = True,
+         debug: bool = False
+     ):
+         logger.info('Initializing Discord client...')
+         intents = discord.Intents.all()
+         intents.message_content = True
+         super().__init__(intents=intents, command_prefix='!')
+
+         assert num_last_messages >= 1, \
+             'The number of last messages in context should be at least 1'
+
+         self.qa_engine: QAEngine = qa_engine
+         self.num_last_messages: int = num_last_messages
+         self.use_names_in_context: bool = use_names_in_context
+         self.enable_commands: bool = enable_commands
+         self.debug: bool = debug
+         self.min_message_len: int = 1800
+         self.max_message_len: int = 2000
+
+     async def on_ready(self):
+         """
+         Callback function to be called when the client is ready.
+         """
+         logger.info('Successfully logged in as: {0.user}'.format(self))
+         await self.change_presence(activity=discord.Game(name='Chatting...'))
+
+     async def get_last_messages(self, message) -> List[str]:
+         """
+         Method to fetch recent messages from a message's channel.
+
+         Args:
+             message (Message): The discord Message object used to identify the channel.
+
+         Returns:
+             List[str]: Reversed list of recent messages from the channel,
+             excluding the input message. Messages may be prefixed with the author's name
+             if `self.use_names_in_context` is True.
+         """
+         last_messages: List[str] = []
+         async for msg in message.channel.history(limit=self.num_last_messages):
+             if self.use_names_in_context:
+                 last_messages.append(f'{msg.author}: {msg.content}')
+             else:
+                 last_messages.append(msg.content)
+         last_messages.reverse()
+         last_messages.pop()  # remove last message from context
+         return last_messages
+
+     async def send_message(self, message, answer: str, sources: str):
+         chunks = split_text_into_chunks(
+             text=answer,
+             split_characters=['. ', ', ', '\n'],
+             min_size=self.min_message_len,
+             max_size=self.max_message_len
+         )
+         for chunk in chunks:
+             await message.channel.send(chunk)
+         await message.channel.send(sources)
+
+     async def on_message(self, message):
+         """
+         Callback function to be called when a message is received.
+
+         Args:
+             message (discord.Message): The received message.
+         """
+         if message.author == self.user:
+             return
+         if self.enable_commands and message.content.startswith('!'):
+             if message.content == '!clear':
+                 await message.channel.purge()
+                 return
+
+         last_messages = await self.get_last_messages(message)
+         context = '\n'.join(last_messages)
+
+         logger.info('Received message: {0.content}'.format(message))
+         response = self.qa_engine.get_response(
+             question=message.content,
+             messages_context=context
+         )
+         logger.info('Sending response: {0}'.format(response))
+         try:
+             await self.send_message(
+                 message,
+                 response.get_answer(),
+                 response.get_sources_as_text()
+             )
+         except Exception as e:
+             logger.error('Failed to send response: {0}'.format(e))
discord_bot/client/utils.py ADDED
@@ -0,0 +1,54 @@
+ from typing import List
+
+
+ def find_max_split_index(text: str, char: str) -> int:
+     char_idx = text.rfind(char)
+     if char_idx > 0:
+         # if the character is found, return the index after the splitting character
+         split_idx = char_idx + len(char)
+         if split_idx >= len(text):
+             return len(text)
+         else:
+             return split_idx
+     return -1
+
+
+ def find_max_split_index_from_sequence(text: str, split_characters: List[str]) -> int:
+     split_index = max((
+         find_max_split_index(text, sequence)
+         for sequence in split_characters
+     ), default=-1)
+     return split_index
+
+
+ def split_text_into_chunks(
+     text: str,
+     split_characters: List[str] = [],
+     min_size: int = 20,
+     max_size: int = 250,
+ ) -> List[str]:
+
+     chunks = []
+     start_idx = 0
+     end_idx = max_size
+     text_len = len(text)
+     while start_idx < text_len:
+         search_chunk = text[start_idx+min_size:end_idx]
+         split_idx = find_max_split_index_from_sequence(
+             text=search_chunk,
+             split_characters=split_characters
+         )
+         # if no splitting element is found, use the maximal size
+         if split_idx < 1:
+             split_idx = end_idx
+         # if found, offset it by the starting index of the chunk
+         else:
+             split_idx += start_idx + min_size
+
+         chunk = text[start_idx:split_idx]
+         chunks.append(chunk)
+
+         start_idx = split_idx
+         end_idx = split_idx + max_size
+
+     return chunks
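A short usage sketch of `split_text_into_chunks`, in the spirit of `tests/discord_bot/client/test_utils.py` (the text and size limits here are arbitrary; the Discord client itself uses `min_size=1800`, `max_size=2000`):

```python
text = (
    'Discord caps messages at 2000 characters. '
    'Long answers are therefore split on natural boundaries, '
    'such as sentence ends, before being sent.'
)
chunks = split_text_into_chunks(
    text,
    split_characters=['. ', ', ', '\n'],
    min_size=40,
    max_size=80,
)
for chunk in chunks:
    print(repr(chunk))  # each chunk ends at '. ' or ', ' where possible
```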
docker-compose.yml ADDED
@@ -0,0 +1,23 @@
+ version: '3'
+ services:
+   api:
+     build:
+       context: .
+       dockerfile: Dockerfile.api
+     ports:
+       - 8000:8000
+     networks:
+       - mynetwork
+   bot:
+     build:
+       context: .
+       dockerfile: Dockerfile.bot
+     ports:
+       - 80:80
+     depends_on:
+       - api
+     networks:
+       - mynetwork
+ networks:
+   mynetwork:
+     driver: bridge
models/inference.ipynb ADDED
@@ -0,0 +1,103 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import os\n",
+     "import torch\n",
+     "import transformers\n",
+     "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+     "\n",
+     "PROMPT_TEMPLATES_DIR = os.path.dirname(os.path.abspath(os.getcwd()))\n",
+     "PROMPT_TEMPLATES_DIR += '/config/api/prompt_templates/'\n",
+     "\n",
+     "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "prompt_template = 'sythia_v1.3'\n",
+     "with open(PROMPT_TEMPLATES_DIR + f'{prompt_template}.txt', 'r') as f:\n",
+     "    prompt_template = f.read()\n",
+     "\n",
+     "context = ''\n",
+     "question = 'How to fix a bike?'\n",
+     "\n",
+     "prompt = prompt_template.format(context=context, question=question)\n",
+     "print(f'prompt len: {len(prompt)}\\n')\n",
+     "print(prompt)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "model_id = 'migtissera/SynthIA-7B-v1.3'\n",
+     "\n",
+     "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
+     "model = AutoModelForCausalLM.from_pretrained(\n",
+     "    model_id,\n",
+     "    torch_dtype=torch.bfloat16,\n",
+     "    trust_remote_code=True,\n",
+     "    load_in_8bit=False,\n",
+     "    device_map='auto',\n",
+     "    resume_download=True,\n",
+     ")\n",
+     "\n",
+     "pipeline = transformers.pipeline(\n",
+     "    'text-generation',\n",
+     "    model=model,\n",
+     "    tokenizer=tokenizer,\n",
+     "    device_map='auto',\n",
+     "    torch_dtype=torch.bfloat16,\n",
+     "    eos_token_id=tokenizer.eos_token_id,\n",
+     "    pad_token_id=tokenizer.eos_token_id,\n",
+     "    min_new_tokens=64,\n",
+     "    max_new_tokens=800,\n",
+     "    temperature=0.5,\n",
+     "    do_sample=True,\n",
+     ")\n",
+     "\n",
+     "output_text = pipeline(prompt)[0]['generated_text']\n",
+     "output_text = output_text.replace(prompt+'\\n', '')\n",
+     "print(output_text)"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "hf_qa_bot",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.11.5"
+   },
+   "orig_nbformat": 4,
+   "vscode": {
+    "interpreter": {
+     "hash": "e769ac600d1c65682759767682b2a946c0eaa09d353302f712fe4c2e822e15df"
+    }
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
qa_engine/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from dotenv import load_dotenv
+ from qa_engine.logger import setup_logger
+
+ setup_logger()
+ load_dotenv(dotenv_path='config/.env')
+
+ from .logger import setup_logger, logger
+ from .config import Config
+ from .qa_engine import QAEngine
qa_engine/config.py ADDED
@@ -0,0 +1,64 @@
+ import os
+ from dataclasses import dataclass, asdict
+ from typing import Any
+
+ from qa_engine import logger
+
+
+ def get_env(env_name: str, default: Any = None, warn: bool = True) -> Any:
+     env = os.getenv(env_name)
+     if not env:
+         if default:
+             if warn:
+                 logger.warning(
+                     f'Environment variable {env_name} not found. '
+                     f'Using the default value: {default}.'
+                 )
+             return default
+         else:
+             raise ValueError(f'Cannot parse: {env_name}')
+     else:
+         return env
+
+
+ @dataclass
+ class Config:
+     # QA Engine config
+     question_answering_model_id: str = get_env('QUESTION_ANSWERING_MODEL_ID')
+     embedding_model_id: str = get_env('EMBEDDING_MODEL_ID')
+     index_repo_id: str = get_env('INDEX_REPO_ID')
+     prompt_template_name: str = get_env('PROMPT_TEMPLATE_NAME')
+     # eval() parses 'True'/'False' strings from the environment into booleans
+     use_docs_for_context: bool = eval(get_env('USE_DOCS_FOR_CONTEXT', 'True'))
+     num_relevant_docs: int = int(get_env('NUM_RELEVANT_DOCS', 3))
+     add_sources_to_response: bool = eval(get_env('ADD_SOURCES_TO_RESPONSE', 'True'))
+     use_messages_in_context: bool = eval(get_env('USE_MESSAGES_IN_CONTEXT', 'True'))
+     debug: bool = eval(get_env('DEBUG', 'True'))
+
+     # Discord bot config - optional
+     discord_token: str = get_env('DISCORD_TOKEN', '', warn=False)
+     num_last_messages: int = int(get_env('NUM_LAST_MESSAGES', 2, warn=False))
+     use_names_in_context: bool = eval(get_env('USE_NAMES_IN_CONTEXT', 'False', warn=False))
+     enable_commands: bool = eval(get_env('ENABLE_COMMANDS', 'True', warn=False))
+
+     def __post_init__(self):
+         prompt_template_file = f'config/prompt_templates/{self.prompt_template_name}.txt'
+         with open(prompt_template_file, 'r') as f:
+             self.prompt_template = f.read()
+         # validate config
+         if 'context' not in self.prompt_template:
+             raise ValueError("Prompt template does not contain the 'context' field.")
+         if 'question' not in self.prompt_template:
+             raise ValueError("Prompt template does not contain the 'question' field.")
+         if not self.use_docs_for_context and self.add_sources_to_response:
+             raise ValueError('Cannot add sources to response if not using docs in context')
+         if self.num_relevant_docs < 1:
+             raise ValueError('num_relevant_docs must be greater than 0')
+         self.log()
+
+     def asdict(self) -> dict:
+         return asdict(self)
+
+     def log(self) -> None:
+         logger.info('Config:')
+         for key, value in self.asdict().items():
+             logger.info(f'{key}: {value}')
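
Config is populated entirely from environment variables, normally supplied through config/.env. A minimal sketch of constructing it programmatically, run from the repository root — the model and index repo IDs below are placeholders, not values shipped with the repository:

    import os

    os.environ['QUESTION_ANSWERING_MODEL_ID'] = 'some-org/some-model'  # placeholder
    os.environ['EMBEDDING_MODEL_ID'] = 'hkunlp/instructor-large'
    os.environ['INDEX_REPO_ID'] = 'some-org/some-index'  # placeholder
    os.environ['PROMPT_TEMPLATE_NAME'] = 'llama2'  # config/prompt_templates/llama2.txt

    # the variables must be set before this import, since the dataclass
    # defaults read the environment at import time
    from qa_engine import Config

    config = Config()  # __post_init__ loads and validates the prompt template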
qa_engine/logger.py ADDED
@@ -0,0 +1,14 @@
+ import logging
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def setup_logger() -> None:
+     """
+     Logger setup.
+     """
+     logger.setLevel(logging.DEBUG)
+     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+     handler = logging.StreamHandler()
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
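
For reference, the intended wiring (already performed in qa_engine/__init__.py) is simply:

    from qa_engine.logger import setup_logger, logger

    setup_logger()
    logger.info('logging configured')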
qa_engine/mocks.py ADDED
@@ -0,0 +1,41 @@
+ import os
+ from typing import Mapping, Optional, Any
+
+ from langchain.llms.base import LLM
+
+
+ class MockLocalBinaryModel(LLM):
+     """
+     Mock local binary model, which always responds with the string 'READY TO MOCK'.
+
+     Args:
+         model_id (str): The ID of the model to be mocked.
+
+     Attributes:
+         model_path (str): The path to the model to be mocked.
+         llm (str): The canned response string 'READY TO MOCK'.
+
+     Raises:
+         ValueError: If the model_path does not exist.
+     """
+
+     model_path: str = None
+     llm: str = 'READY TO MOCK'
+
+     def __init__(self, model_id: str = None):
+         super().__init__()
+         self.model_path = f'bot/question_answering/{model_id}'
+         if not os.path.exists(self.model_path):
+             raise ValueError(f'{self.model_path} does not exist')
+
+     def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
+         return self.llm
+
+     @property
+     def _identifying_params(self) -> Mapping[str, Any]:
+         return {'name_of_model': self.model_path}
+
+     @property
+     def _llm_type(self) -> str:
+         return self.model_path
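
A sketch of how the mock might be used in a test. The only filesystem dependency is the path check, so a dummy file at the expected bot/question_answering/<model_id> location is enough (the file name below is hypothetical):

    import os
    from qa_engine.mocks import MockLocalBinaryModel

    # satisfy the constructor's os.path.exists check
    os.makedirs('bot/question_answering', exist_ok=True)
    open('bot/question_answering/dummy.bin', 'w').close()

    mock = MockLocalBinaryModel(model_id='dummy.bin')
    assert mock('any prompt') == 'READY TO MOCK'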
qa_engine/qa_engine.py ADDED
@@ -0,0 +1,286 @@
+ import os
+ import json
+ import requests
+ from typing import Mapping, Optional, Any
+
+ import torch
+ import transformers
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from huggingface_hub import snapshot_download
+ from urllib.parse import quote
+ from langchain import PromptTemplate, LLMChain
+ from langchain.llms.base import LLM
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
+ from langchain.vectorstores import FAISS
+ from sentence_transformers import CrossEncoder
+
+ from qa_engine import logger
+ from qa_engine.response import Response
+
+
+ class LocalBinaryModel(LLM):
+     model_id: str = None
+     llm: Any = None
+
+     def __init__(self, model_id: str = None):
+         super().__init__()
+         # requires: pip install llama_cpp_python==0.1.39
+         from llama_cpp import Llama
+
+         model_path = f'qa_engine/{model_id}'
+         if not os.path.exists(model_path):
+             raise ValueError(f'{model_path} does not exist')
+         self.model_id = model_id
+         self.llm = Llama(model_path=model_path, n_ctx=4096)
+
+     def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
+         output = self.llm(
+             prompt,
+             max_tokens=1024,
+             stop=['Q:'],
+             echo=False
+         )
+         return output['choices'][0]['text']
+
+     @property
+     def _identifying_params(self) -> Mapping[str, Any]:
+         return {'name_of_model': self.model_id}
+
+     @property
+     def _llm_type(self) -> str:
+         return self.model_id
+
+
+ class TransformersPipelineModel(LLM):
+     model_id: str = None
+     pipeline: Any = None
+
+     def __init__(self, model_id: str = None):
+         super().__init__()
+         self.model_id = model_id
+
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+         model = AutoModelForCausalLM.from_pretrained(
+             model_id,
+             torch_dtype=torch.bfloat16,
+             trust_remote_code=True,
+             load_in_8bit=False,
+             device_map='auto',
+             resume_download=True,
+         )
+         self.pipeline = transformers.pipeline(
+             'text-generation',
+             model=model,
+             tokenizer=tokenizer,
+             torch_dtype=torch.bfloat16,
+             device_map='auto',
+             eos_token_id=tokenizer.eos_token_id,
+             pad_token_id=tokenizer.eos_token_id,
+             min_new_tokens=64,
+             max_new_tokens=800,
+             temperature=0.5,
+             do_sample=True,
+         )
+
+     def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
+         output_text = self.pipeline(prompt)[0]['generated_text']
+         # keep only the completion by stripping the prompt from the generated text
+         output_text = output_text.replace(prompt + '\n', '')
+         return output_text
+
+     @property
+     def _identifying_params(self) -> Mapping[str, Any]:
+         return {'name_of_model': self.model_id}
+
+     @property
+     def _llm_type(self) -> str:
+         return self.model_id
+
+
+ class APIServedModel(LLM):
+     model_url: str = None
+     debug: bool = False
+
+     def __init__(self, model_url: str = None, debug: bool = False):
+         super().__init__()
+         if model_url[-1] == '/':
+             raise ValueError('URL should not end with a slash - "/"')
+         self.model_url = model_url
+         self.debug = debug
+
+     def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
+         prompt_encoded = quote(prompt, safe='')
+         url = f'{self.model_url}/?prompt={prompt_encoded}'
+         if self.debug:
+             logger.info(f'URL: {url}')
+         try:
+             response = requests.get(url, timeout=1200, verify=False)
+             response.raise_for_status()
+             return json.loads(response.content)['output_text']
+         except Exception as err:
+             logger.error(f'Error: {err}')
+             return f'Error: {err}'
+
+     @property
+     def _identifying_params(self) -> Mapping[str, Any]:
+         return {'name_of_model': f'model url: {self.model_url}'}
+
+     @property
+     def _llm_type(self) -> str:
+         return 'api_model'
+
+
+ class QAEngine:
+     """
+     QAEngine class, used for generating answers to questions.
+
+     Args:
+         llm_model_id (str): The ID of the LLM model to be used.
+         embedding_model_id (str): The ID of the embedding model to be used.
+         index_repo_id (str): The ID of the index repository to be used.
+         prompt_template (str): The prompt template to be used.
+         use_docs_for_context (bool, optional): Whether to use relevant documents as context for generating answers.
+             Defaults to True.
+         num_relevant_docs (int, optional): The number of reranked documents to include in the context. Defaults to 3.
+         add_sources_to_response (bool, optional): Whether to append document sources to the response. Defaults to True.
+         use_messages_for_context (bool, optional): Whether to use previous messages as context for generating answers.
+             Defaults to True.
+         first_stage_docs (int, optional): The number of documents retrieved before cross-encoder reranking. Defaults to 50.
+         debug (bool, optional): Whether to log debug information. Defaults to False.
+
+     Attributes:
+         use_docs_for_context (bool): Whether to use relevant documents as context for generating answers.
+         use_messages_for_context (bool): Whether to use previous messages as context for generating answers.
+         debug (bool): Whether to log debug information.
+         llm_model (Union[LocalBinaryModel, APIServedModel, TransformersPipelineModel]): The LLM model to be used.
+         prompt_template (str): The prompt template to be used.
+         llm_chain (LLMChain): The LLM chain to be used.
+         knowledge_index (FAISS): The FAISS index to be used.
+         reranker (CrossEncoder): The cross-encoder used to rerank retrieved documents.
+     """
+     def __init__(
+         self,
+         llm_model_id: str,
+         embedding_model_id: str,
+         index_repo_id: str,
+         prompt_template: str,
+         use_docs_for_context: bool = True,
+         num_relevant_docs: int = 3,
+         add_sources_to_response: bool = True,
+         use_messages_for_context: bool = True,
+         first_stage_docs: int = 50,
+         debug: bool = False
+     ):
+         super().__init__()
+         self.prompt_template = prompt_template
+         self.use_docs_for_context = use_docs_for_context
+         self.num_relevant_docs = num_relevant_docs
+         self.add_sources_to_response = add_sources_to_response
+         self.use_messages_for_context = use_messages_for_context
+         self.first_stage_docs = first_stage_docs
+         self.debug = debug
+
+         if 'local_models/' in llm_model_id:
+             logger.info('using local binary model')
+             self.llm_model = LocalBinaryModel(
+                 model_id=llm_model_id
+             )
+         elif 'api_models/' in llm_model_id:
+             logger.info('using api served model')
+             self.llm_model = APIServedModel(
+                 model_url=llm_model_id.replace('api_models/', ''),
+                 debug=self.debug
+             )
+         else:
+             logger.info('using transformers pipeline model')
+             self.llm_model = TransformersPipelineModel(
+                 model_id=llm_model_id
+             )
+
+         prompt = PromptTemplate(
+             template=prompt_template,
+             input_variables=['question', 'context']
+         )
+         self.llm_chain = LLMChain(prompt=prompt, llm=self.llm_model)
+
+         if self.use_docs_for_context:
+             logger.info(f'Downloading {index_repo_id}')
+             snapshot_download(
+                 repo_id=index_repo_id,
+                 allow_patterns=['*.faiss', '*.pkl'],
+                 repo_type='dataset',
+                 local_dir='indexes/run/'
+             )
+             logger.info('Loading embedding model')
+             embed_instruction = 'Represent the Hugging Face library documentation'
+             query_instruction = 'Query the most relevant piece of information from the Hugging Face documentation'
+             embedding_model = HuggingFaceInstructEmbeddings(
+                 model_name=embedding_model_id,
+                 embed_instruction=embed_instruction,
+                 query_instruction=query_instruction
+             )
+             logger.info('Loading index')
+             self.knowledge_index = FAISS.load_local('./indexes/run/', embedding_model)
+             self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
+
+     def get_response(self, question: str, messages_context: str = '') -> Response:
+         """
+         Generate an answer to the specified question.
+
+         Args:
+             question (str): The question to be answered.
+             messages_context (str, optional): The context to be used for generating the answer. Defaults to ''.
+
+         Returns:
+             response (Response): The Response object containing the generated answer and the sources of information
+             used to generate the response.
+         """
+         response = Response()
+         context = ''
+         relevant_docs = []
+         if self.use_messages_for_context and messages_context:
+             messages_context = f'\nPrevious questions and answers:\n{messages_context}'
+             context += messages_context
+         if self.use_docs_for_context:
+             logger.info('Retrieving documents')
+             # the messages context is included in the query for better retrieval
+             retrieval_query = messages_context + question
+             relevant_docs = self.knowledge_index.similarity_search(
+                 query=retrieval_query,
+                 k=self.first_stage_docs
+             )
+             cross_encoding_predictions = self.reranker.predict(
+                 [(retrieval_query, doc.page_content) for doc in relevant_docs]
+             )
+             # rerank the first-stage documents by cross-encoder score, descending
+             relevant_docs = [
+                 doc for _, doc in sorted(
+                     zip(cross_encoding_predictions, relevant_docs),
+                     reverse=True, key=lambda x: x[0]
+                 )
+             ]
+             relevant_docs = relevant_docs[:self.num_relevant_docs]
+             context += '\nExtracted documents:\n'
+             context += ''.join([doc.page_content for doc in relevant_docs])
+             metadata = [doc.metadata for doc in relevant_docs]
+             response.set_sources(sources=[str(m['source']) for m in metadata])
+
+         logger.info('Running LLM chain')
+         answer = self.llm_chain.run(question=question, context=context)
+         response.set_answer(answer)
+         logger.info('Received answer')
+
+         if self.debug:
+             logger.info('\n' + '=' * 100)
+             sep = '\n' + '-' * 100
+             logger.info(f'question len: {len(question)} {sep}')
+             logger.info(f'question: {question} {sep}')
+             logger.info(f'answer len: {len(response.get_answer())} {sep}')
+             logger.info(f'answer: {response.get_answer()} {sep}')
+             logger.info(f'{response.get_sources_as_text()} {sep}')
+             logger.info(f'messages_context: {messages_context} {sep}')
+             logger.info(f'relevant_docs: {relevant_docs} {sep}')
+             logger.info(f'context len: {len(context)} {sep}')
+             logger.info(f'context: {context} {sep}')
+         return response
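
End to end, the engine is meant to be driven by Config; a minimal sketch, mirroring what the application entry points presumably do:

    from qa_engine import Config, QAEngine

    config = Config()
    qa_engine = QAEngine(
        llm_model_id=config.question_answering_model_id,
        embedding_model_id=config.embedding_model_id,
        index_repo_id=config.index_repo_id,
        prompt_template=config.prompt_template,
        use_docs_for_context=config.use_docs_for_context,
        num_relevant_docs=config.num_relevant_docs,
        add_sources_to_response=config.add_sources_to_response,
        use_messages_for_context=config.use_messages_in_context,
        debug=config.debug,
    )
    response = qa_engine.get_response('How to create audio dataset with Hugging Face?')
    print(response.get_answer(include_sources=True))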
qa_engine/response.py ADDED
@@ -0,0 +1,33 @@
+ from typing import List
+
+
+ class Response:
+     def __init__(self):
+         self.answer = ''
+         self.sources = []
+
+     def set_answer(self, answer: str) -> None:
+         self.answer = answer
+
+     def set_sources(self, sources: List) -> None:
+         # deduplicate sources, keeping one entry per unique string
+         self.sources = list(set(map(str, sources)))
+
+     def get_sources(self) -> List[str]:
+         return self.sources
+
+     def get_sources_as_text(self) -> str:
+         if not self.sources:
+             return ''
+         sources_text = '\n\nSources:'
+         for i, source in enumerate(self.sources):
+             sources_text += f'\n [{i+1}] {source}'
+         return sources_text
+
+     def get_answer(self, include_sources: bool = False) -> str:
+         answer = self.answer
+         if include_sources:
+             answer += self.get_sources_as_text()
+         return answer
+
+     def __str__(self):
+         return self.get_answer(include_sources=True)
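
Note that set_sources deduplicates via a set, so the order of stored sources is not guaranteed. For example (the source paths below are illustrative):

    from qa_engine.response import Response

    r = Response()
    r.set_answer('Use AutoTokenizer.from_pretrained.')
    r.set_sources(['docs/tokenizer.md', 'docs/tokenizer.md', 'docs/models.md'])
    print(r.get_sources())  # two unique sources, arbitrary order
    print(r.get_answer(include_sources=True))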
questions.txt ADDED
@@ -0,0 +1,9 @@
+ How to create audio dataset with Hugging Face?
+ I want to check if 2 sentences are similar semantically. How can I do it?
+ What are the benefits of Gradio?
+ How to deploy a text-to-image model?
+ Does Hugging Face offer any distributed training assistance? followup: Can you give me an example setup of it?
+ I want to detect cars on video recording. How should I do it and what models do you recommend?
+ Is there any tool for evaluating models in Hugging Face? followup: Can you give me an example setup of it?
+ What are some advantages of the Hugging Face Hub?
+ How would I use a model in 8 bit in transformers?
requirements.txt ADDED
@@ -0,0 +1,28 @@
+ torch
+ torchvision
+ transformers
+ accelerate
+ einops
+ huggingface_hub
+ gradio
+ beautifulsoup4==4.12.0
+ discord.py==2.2.2
+ evaluate==0.4.0
+ fastapi==0.98.0
+ langchain==0.0.154
+ nltk==3.8.1
+ nbconvert==7.6.0
+ nbformat==5.9.0
+ numpy==1.24.2
+ markdown==3.4.4
+ pandas==1.5.3
+ python-dotenv==1.0.0
+ Requests==2.29.0
+ scikit_learn==1.2.2
+ sentence-transformers==2.2.2
+ InstructorEmbedding==1.0.0
+ faiss_cpu==1.7.3
+ tqdm==4.64.1
+ uvicorn==0.22.0
+ wandb==0.15.0
+ pytest==7.3.1
run_docker.sh ADDED
@@ -0,0 +1,2 @@
+ #!/bin/bash
+ docker-compose down && docker-compose up --build
run_tests.sh ADDED
@@ -0,0 +1,2 @@
+ #!/bin/bash
+ pytest -o "testpaths=tests/" --noconftest -vv
tests/__init__.py ADDED
File without changes
tests/discord_bot/__init__.py ADDED
File without changes
tests/discord_bot/client/__init__.py ADDED
File without changes
tests/discord_bot/client/lorem_ipsum.txt ADDED
@@ -0,0 +1,14 @@
+ Lorem ipsum dolor sit amet, ```consectetur adipiscing elit. Sed vitae arcu``` eros. Sed gravida tellus quis ante luctus, sed scelerisque mi auctor.
+ Pellentesque consectetur fringilla turpis, in viverra odio posuere ut. In venenatis dui lectus, eget dignissim lectus lacinia et.
+ Nulla commodo, nunc et vulputate vestibulum, mauris lorem semper nisl, a lacinia nulla urna ac lectus. Fusce pulvinar pulvinar augue vitae pulvinar. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Mauris ut elit nec erat tristique sagittis id in massa.
+ Nunc ac libero in mauris consectetur bibendum non id nulla. Suspendisse feugiat ipsum metus, et auctor massa lobortis ac.
+ Sed tincidunt dapibus sapien, ut tincidunt dui venenatis vitae. Praesent fermentum mi quis elit posuere, in consequat dui ultrices. Nullam tempus sapien non gravida faucibus.
+ Etiam placerat magna risus, in dictum justo commodo eu. Fusce commodo viverra augue, non ullamcorper nunc efficitur et.
+ Nam convallis dui vel nunc malesuada, at pharetra dui pulvinar. Cras commodo nibh et neque placerat eleifend. Mauris eget ante pharetra, mattis turpis ac, laoreet enim.
+ Sed non sem sit amet dui venenatis posuere. Proin eget eros at nunc fringilla varius eu nec justo.
+ Duis tristique consequat metus vitae euismod. Maecenas ullamcorper ullamcorper diam, nec tristique est congue nec. Suspendisse eu elit id massa fermentum pharetra in et mauris.
+ In finibus metus nec odio aliquam, eu finibus ligula venenatis. Donec ullamcorper turpis eget sapien blandit volutpat. Sed tincidunt malesuada urna vitae cursus.
+ Phasellus in est in purus lobortis dictum. Vivamus tristique turpis at lectus dapibus, ac volutpat arcu volutpat. Integer ac gravida felis.
+ Donec laoreet neque non enim fermentum, vitae tincidunt est varius. Phasellus vehicula nisl nec metus dictum malesuada. Vestibulum sit amet erat et tellus euismod congue.
+ Nulla facilisi. Mauris venenatis arcu vel elementum consectetur. Cras volutpat scelerisque sollicitudin.
+ Fusce eu justo et magna gravida finibus id ut augue. Curabitur malesuada, justo ut semper dapibus, nulla urna maximus tellus, in vestibulum mauris justo id lectus. Vestibulum sit amet mauris feugiat, rhoncus ligula sed, finibus mi.
tests/discord_bot/client/test_utils.py ADDED
@@ -0,0 +1,69 @@
+ import pytest
+
+ from discord_bot.client.utils import (
+     find_max_split_index,
+     find_max_split_index_from_sequence,
+     split_text_into_chunks
+ )
+
+
+ @pytest.fixture(scope='module')
+ def test_chunk() -> str:
+     return 't. , \n .'
+
+
+ @pytest.fixture(scope='module')
+ def test_text() -> str:
+     with open('tests/discord_bot/client/lorem_ipsum.txt', 'r') as f:
+         text = f.read()
+     assert text is not None, 'test text is empty'
+     return text
+
+
+ def test_find_max_splitting_index(test_chunk: str):
+     index = find_max_split_index(test_chunk, char='\n')
+     assert index == 6, 'index should be 6'
+     index = find_max_split_index(test_chunk, char='. ')
+     assert index == 3, 'index should be 3'
+     index = find_max_split_index(test_chunk, char='.')
+     assert index == 8, 'index should be 8'
+
+
+ def test_find_max_split_index_from_sequence(test_chunk: str):
+     index = find_max_split_index_from_sequence(
+         test_chunk,
+         split_characters=['\n']
+     )
+     assert index == 6, 'index should be 6'
+     index = find_max_split_index_from_sequence(
+         test_chunk,
+         split_characters=['.', ', ', '\n']
+     )
+     assert index == 8, 'index should be 8'
+
+
+ def test_split_text_into_chunks_with_split_characters(test_text: str):
+     max_chunk_size = 250
+     chunks = split_text_into_chunks(
+         test_text,
+         split_characters=['. ', ', ', '\n'],
+         min_size=20,
+         max_size=max_chunk_size
+     )
+     for chunk in chunks:
+         assert len(chunk) > 0, 'Chunk length is zero'
+         assert len(chunk) <= max_chunk_size, 'Chunk length exceeds maximum limit'
+
+
+ def test_split_text_into_chunks_without_split_characters():
+     test_text = 'a' * 1000
+     max_chunk_size = 250
+     chunks = split_text_into_chunks(
+         test_text,
+         split_characters=[],
+         min_size=20,
+         max_size=max_chunk_size
+     )
+     for chunk in chunks:
+         assert len(chunk) == max_chunk_size, \
+             'Chunk length should equal max_chunk_size'
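
The utilities under test live in discord_bot/client/utils.py. Judging purely from the expected values above, find_max_split_index behaves like "split just after the last occurrence of the separator". A sketch consistent with these tests, as an illustration rather than the repository's actual implementation:

    def find_max_split_index(text: str, char: str) -> int:
        # index just past the last occurrence of `char`, or 0 if absent
        position = text.rfind(char)
        return position + len(char) if position != -1 else 0

For 't. , \n .': rfind('\n') is 5, so the result is 6; rfind('. ') is 1, giving 3; rfind('.') is 7, giving 8 — matching the three assertions.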
tests/index/test_index.py ADDED
@@ -0,0 +1,48 @@
+ import pytest
+ from huggingface_hub import snapshot_download
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
+ from langchain.vectorstores import FAISS
+
+
+ # download the test index once, at module import time
+ snapshot_download(
+     repo_id='KonradSzafer/index-large-notebooks',
+     allow_patterns=['*.faiss', '*.pkl'],
+     repo_type='dataset',
+     local_dir='indexes/'
+ )
+
+
+ @pytest.fixture(scope='module')
+ def embedding_model() -> HuggingFaceInstructEmbeddings:
+     model_name = 'hkunlp/instructor-large'
+     embed_instruction = 'Represent the Hugging Face library documentation'
+     query_instruction = 'Query the most relevant piece of information from the Hugging Face documentation'
+     return HuggingFaceInstructEmbeddings(
+         model_name=model_name,
+         embed_instruction=embed_instruction,
+         query_instruction=query_instruction,
+     )
+
+
+ @pytest.fixture(scope='module')
+ def index_path() -> str:
+     return 'indexes/'
+
+
+ @pytest.fixture(scope='module')
+ def index(embedding_model: HuggingFaceInstructEmbeddings, index_path: str) -> FAISS:
+     return FAISS.load_local(index_path, embedding_model)
+
+
+ @pytest.fixture(scope='module')
+ def query() -> str:
+     return 'How to use the tokenizer?'
+
+
+ def test_load_index(embedding_model: HuggingFaceInstructEmbeddings, index_path: str):
+     index = FAISS.load_local(index_path, embedding_model)
+     assert index is not None, 'Failed to load index'
+
+
+ def test_index_page_content(index: FAISS, query: str):
+     query_docs = index.similarity_search(query=query, k=3)
+     assert isinstance(query_docs[0].page_content, str)
+
+
+ def test_index_metadata(index: FAISS, query: str):
+     query_docs = index.similarity_search(query=query, k=3)
+     assert isinstance(query_docs[0].metadata['source'], str)
tests/qa_engine/__init__.py ADDED
File without changes
tests/qa_engine/test_response.py ADDED
@@ -0,0 +1,30 @@
+ import pytest
+
+ from qa_engine.response import Response
+
+
+ def test_set_answer():
+     r = Response()
+     r.set_answer('Hello, World!')
+     assert r.get_answer() == 'Hello, World!'
+
+
+ def test_set_sources():
+     r = Response()
+     r.set_sources(['source1', 'source1', 'source2'])
+     assert len(r.get_sources()) == 2
+
+
+ def test_get_sources_as_text():
+     r = Response()
+     r.set_sources(['source1', 'source2'])
+     assert isinstance(r.get_sources_as_text(), str)
+
+
+ def test_get_response_include_sources():
+     r = Response()
+     r.set_answer('Hello, World!')
+     r.set_sources(['source1', 'source2'])
+     assert len(r.get_answer(include_sources=True)) > len('Hello, World!')