Spaces: Runtime error
KonradSzafer committed · Commit c69cba4 · 1 Parent(s): 2a53bac
initial commit
Browse files
- .github/workflows/tests-integration.yml +37 -0
- .gitignore +68 -0
- Dockerfile.api +19 -0
- Dockerfile.bot +17 -0
- LICENSE +21 -0
- api/__init__.py +0 -0
- api/__main__.py +39 -0
- app.py +33 -0
- assets/example.png +0 -0
- benchmarker.py +63 -0
- config/.env.example +16 -0
- config/prompt_templates/llama.txt +7 -0
- config/prompt_templates/llama2.txt +6 -0
- config/prompt_templates/sythia_v1.3.txt +3 -0
- data/benchmark/.gitkeep +0 -0
- data/datasets/hf_repositories_urls.json +22 -0
- data/datasets/hf_repositories_urls_scraped.json +92 -0
- data/get_hf_repositories_urls.py +49 -0
- data/hugging_face_docs_dataset.py +190 -0
- data/indexer.ipynb +238 -0
- data/language-codes.csv +190 -0
- data/scrapers/stack_overflow_scraper.py +91 -0
- data/stackoverflow_python_dataset.py +55 -0
- data/upload_csv_dataset.py +24 -0
- discord_bot/__init__.py +0 -0
- discord_bot/__main__.py +28 -0
- discord_bot/client/__init__.py +1 -0
- discord_bot/client/client.py +130 -0
- discord_bot/client/utils.py +54 -0
- docker-compose.yml +23 -0
- models/inference.ipynb +103 -0
- qa_engine/__init__.py +9 -0
- qa_engine/config.py +64 -0
- qa_engine/logger.py +14 -0
- qa_engine/mocks.py +41 -0
- qa_engine/qa_engine.py +286 -0
- qa_engine/response.py +33 -0
- questions.txt +9 -0
- requirements.txt +28 -0
- run_docker.sh +2 -0
- run_tests.sh +2 -0
- tests/__init__.py +0 -0
- tests/discord_bot/__init__.py +0 -0
- tests/discord_bot/client/__init__.py +0 -0
- tests/discord_bot/client/lorem_ipsum.txt +14 -0
- tests/discord_bot/client/test_utils.py +69 -0
- tests/index/test_index.py +48 -0
- tests/qa_engine/__init__.py +0 -0
- tests/qa_engine/test_response.py +30 -0
.github/workflows/tests-integration.yml
ADDED
@@ -0,0 +1,37 @@
+name: CI
+
+on:
+  push:
+    branches:
+      - main
+      - feature/**
+      - issue/**
+
+jobs:
+  build:
+    strategy:
+      matrix:
+        python-version: [3.11]
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Switch to Current Branch
+        run: git checkout ${{ env.BRANCH }}
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v1
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install dependencies
+        run: |
+          pip install --no-cache-dir -r requirements.txt
+          cp config/.env.example config/.env
+      - name: Run unit tests
+        run: |
+          pytest -o "testpaths=tests" --noconftest
.gitignore
ADDED
@@ -0,0 +1,68 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Environments
+.env
+.venv
+env/
+venv/
+
+# OS
+.DS_Store
+
+# IDE
+.vscode
+
+# Project
+wandb/
+indexes/
+*.out
+*.log
+
+# Data
+data/datasets/*
+!data/datasets/hf_repositories_urls.json
+!data/datasets/hf_repositories_urls_scraped.json
+
+# Local models
+qa_engine/local_models
Dockerfile.api
ADDED
@@ -0,0 +1,19 @@
+FROM ubuntu:latest
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get -y update && \
+    apt-get -y upgrade && \
+    apt-get -y install git python3.10 python3-pip
+
+COPY requirements.txt .
+RUN pip install --upgrade pip && \
+    pip install --no-cache-dir -r requirements.txt
+
+WORKDIR /hugging-face-qa-bot
+COPY config/api/ config/api/
+COPY api/ api/
+
+EXPOSE 8000
+
+ENTRYPOINT [ "python3", "-m", "api" ]
Dockerfile.bot
ADDED
@@ -0,0 +1,17 @@
+FROM ubuntu:latest
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get -y update && \
+    apt-get -y upgrade && \
+    apt-get -y install git python3.10 python3-pip
+
+COPY requirements.txt .
+RUN pip install --upgrade pip && \
+    pip install --no-cache-dir -r requirements.txt
+
+WORKDIR /hugging-face-qa-bot
+COPY config/bot/ config/bot/
+COPY discord_bot/ discord_bot/
+
+ENTRYPOINT [ "python3", "-m", "discord_bot" ]
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
api/__init__.py
ADDED
File without changes
api/__main__.py
ADDED
@@ -0,0 +1,39 @@
+import uvicorn
+from fastapi import FastAPI
+
+from qa_engine import logger, Config, QAEngine
+
+
+config = Config()
+app = FastAPI()
+qa_engine = QAEngine(
+    llm_model_id=config.question_answering_model_id,
+    embedding_model_id=config.embedding_model_id,
+    index_repo_id=config.index_repo_id,
+    prompt_template=config.prompt_template,
+    use_docs_for_context=config.use_docs_for_context,
+    num_relevant_docs=config.num_relevant_docs,
+    add_sources_to_response=config.add_sources_to_response,
+    use_messages_for_context=config.use_messages_in_context,
+    debug=config.debug
+)
+
+
+@app.get('/')
+def get_answer(question: str, messages_context: str = ''):
+    logger.info(
+        f'Received request with question: {question} ' \
+        f'and context: {messages_context}'
+    )
+    response = qa_engine.get_response(
+        question=question,
+        messages_context=messages_context
+    )
+    return {
+        'answer': response.get_answer(),
+        'sources': response.get_sources_as_text()
+    }
+
+
+if __name__ == '__main__':
+    uvicorn.run(app, host='0.0.0.0', port=8000)
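For reference, a minimal sketch of querying this endpoint from a client, assuming the service is running on the host and port configured above; the question text is a made-up example:

import requests

# Query the GET endpoint defined in api/__main__.py.
response = requests.get(
    'http://localhost:8000/',
    params={
        'question': 'How do I create a pipeline object?',
        'messages_context': ''  # optional chat-history context
    }
)
payload = response.json()
print(payload['answer'])
print(payload['sources'])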
app.py
ADDED
@@ -0,0 +1,33 @@
+import gradio as gr
+
+from qa_engine import logger, Config, QAEngine
+
+
+config = Config()
+model = QAEngine(
+    llm_model_id=config.question_answering_model_id,
+    embedding_model_id=config.embedding_model_id,
+    index_repo_id=config.index_repo_id,
+    prompt_template=config.prompt_template,
+    use_docs_for_context=config.use_docs_for_context,
+    add_sources_to_response=config.add_sources_to_response,
+    use_messages_for_context=config.use_messages_in_context,
+    debug=config.debug
+)
+
+with gr.Blocks() as demo:
+    chatbot = gr.Chatbot()
+    msg = gr.Textbox()
+    clear = gr.ClearButton([msg, chatbot])
+
+    def respond(message, chat_history):
+        context = ''.join(f'User: {msg} \nBot:{bot_msg}\n' for msg, bot_msg in chat_history)
+        logger.info(f'Context: {context}')
+        response = model.get_response(message, context)
+        bot_message = response.get_answer() + response.get_sources_as_text() + '\n'
+        chat_history.append((message, bot_message))
+        return '', chat_history
+
+    msg.submit(respond, [msg, chatbot], [msg, chatbot])
+
+demo.launch(share=True)
assets/example.png
ADDED
benchmarker.py
ADDED
@@ -0,0 +1,68 @@
+import json
+import time
+
+import wandb
+from dotenv import load_dotenv
+
+from api.config import Config
+from api.logger import logger
+from api.question_answering import QAModel
+
+
+load_dotenv(dotenv_path='config/api/.env')
+
+config = Config()
+model = QAModel(
+    llm_model_id=config.question_answering_model_id,
+    embedding_model_id=config.embedding_model_id,
+    index_repo_id=config.index_repo_id,
+    prompt_template=config.prompt_template,
+    use_docs_for_context=config.use_docs_for_context,
+    add_sources_to_response=config.add_sources_to_response,
+    use_messages_for_context=config.use_messages_in_context,
+    debug=config.debug
+)
+
+QUESTIONS_FILENAME = 'data/benchmark/questions.json'
+ANSWERS_FILENAME = 'data/benchmark/answers.json'
+
+
+def main():
+    benchmark_name = \
+        f'model: {config.question_answering_model_id} ' \
+        f'index: {config.index_repo_id}'
+
+    # log the run name and config to wandb
+    wandb.init(
+        project='HF-Docs-QA',
+        name=benchmark_name,
+        mode='run',  # run/disabled
+        config=config.asdict()
+    )
+
+    with open(QUESTIONS_FILENAME, 'r') as f:
+        questions = json.load(f)
+
+    answers = []
+    for q in questions:
+        question = q['question']
+        messages_context = q['messages_context']
+
+        t_start = time.perf_counter()
+        response = model.get_response(
+            question=question,
+            messages_context=messages_context
+        )
+        t_end = time.perf_counter()
+
+        answers.append({
+            'answer': response.get_answer(),
+            'sources': response.get_sources_as_text(),
+            'time': t_end - t_start
+        })
+
+    # write all benchmark results to json
+    with open(ANSWERS_FILENAME, 'w') as f:
+        json.dump(answers, f, indent=4)
+
+
+if __name__ == '__main__':
+    main()
config/.env.example
ADDED
@@ -0,0 +1,16 @@
+# QA engine settings
+QUESTION_ANSWERING_MODEL_ID=hf-question-answering-model-ID
+EMBEDDING_MODEL_ID=hf-embedding-model-ID
+INDEX_REPO_ID=hf-index-ID
+PROMPT_TEMPLATE_NAME=llama
+USE_DOCS_FOR_CONTEXT=True
+NUM_RELEVANT_DOCS=4
+ADD_SOURCES_TO_RESPONSE=True
+USE_MESSAGES_IN_CONTEXT=True
+DEBUG=True
+
+# Discord settings
+DISCORD_TOKEN=your-bot-token
+NUM_LAST_MESSAGES=1
+USE_NAMES_IN_CONTEXT=False
+ENABLE_COMMANDS=True
config/prompt_templates/llama.txt
ADDED
@@ -0,0 +1,7 @@
+### Instruction:
+Give an answer that contains all the necessary information for the question.
+If the context contains necessary information to answer the question, use it to generate an appropriate response.
+{context}
+### Input:
+{question}
+### Response:
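The {context} and {question} placeholders in this template are presumably filled by the QA engine with str.format; a minimal sketch, with made-up context and question values:

# Sketch only: the context and question strings are invented for illustration.
with open('config/prompt_templates/llama.txt') as f:
    template = f.read()

prompt = template.format(
    context='Pipelines are created with transformers.pipeline(task, model=...).',
    question='How do I create a pipeline object?'
)
print(prompt)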
config/prompt_templates/llama2.txt
ADDED
@@ -0,0 +1,6 @@
+<<SYS>>
+**CONTEXT FOR THE QUESTION**:
+{context}
+<</SYS>>
+
+[INST] Respond to the user question as factually as possible, using the given context. [/INST] User: {question}
config/prompt_templates/sythia_v1.3.txt
ADDED
@@ -0,0 +1,3 @@
+SYSTEM: {context}
+USER: {question}
+ASSISTANT:
data/benchmark/.gitkeep
ADDED
File without changes
data/datasets/hf_repositories_urls.json
ADDED
@@ -0,0 +1,22 @@
+{
+    "urls": [
+        "https://github.com/huggingface/transformers",
+        "https://github.com/huggingface/diffusers",
+        "https://github.com/huggingface/datasets",
+        "https://github.com/gradio-app/gradio",
+        "https://github.com/huggingface/huggingface_hub",
+        "https://github.com/huggingface/peft",
+        "https://github.com/huggingface/blog",
+        "https://github.com/huggingface/optimum",
+        "https://github.com/huggingface/tokenizers",
+        "https://github.com/huggingface/course",
+        "https://github.com/huggingface/deep-rl-class",
+        "https://github.com/huggingface/evaluate",
+        "https://github.com/huggingface/datasets-server",
+        "https://github.com/huggingface/simulate",
+        "https://github.com/huggingface/hub-docs",
+        "https://github.com/huggingface/pytorch-image-models",
+        "https://github.com/huggingface/safetensors",
+        "https://github.com/huggingface/hf-endpoints-documentation"
+    ]
+}
data/datasets/hf_repositories_urls_scraped.json
ADDED
@@ -0,0 +1,92 @@
+{
+    "urls": [
+        "https://github.com/huggingface/tokenizers",
+        "https://github.com/huggingface/datablations",
+        "https://github.com/huggingface/peft",
+        "https://github.com/huggingface/tflite-android-transformers",
+        "https://github.com/huggingface/simulate",
+        "https://github.com/huggingface/transformers",
+        "https://github.com/huggingface/deep-rl-class",
+        "https://github.com/huggingface/awesome-huggingface",
+        "https://github.com/huggingface/datasets-server",
+        "https://github.com/huggingface/setfit",
+        "https://github.com/huggingface/olm-training",
+        "https://github.com/huggingface/huggingface_sb3",
+        "https://github.com/huggingface/optimum-neuron",
+        "https://github.com/huggingface/blog",
+        "https://github.com/huggingface/100-times-faster-nlp",
+        "https://github.com/huggingface/bloom-jax-inference",
+        "https://github.com/huggingface/speechbox",
+        "https://github.com/huggingface/olm-datasets",
+        "https://github.com/huggingface/hub-docs",
+        "https://github.com/huggingface/torchMoji",
+        "https://github.com/huggingface/hffs",
+        "https://github.com/huggingface/trl",
+        "https://github.com/huggingface/text-generation-inference",
+        "https://github.com/huggingface/Mongoku",
+        "https://github.com/huggingface/education-toolkit",
+        "https://github.com/huggingface/datasets",
+        "https://github.com/huggingface/optimum-benchmark",
+        "https://github.com/huggingface/course",
+        "https://github.com/huggingface/accelerate",
+        "https://github.com/huggingface/pytorch-image-models",
+        "https://github.com/huggingface/fuego",
+        "https://github.com/huggingface/diffusion-models-class",
+        "https://github.com/huggingface/disaggregators",
+        "https://github.com/huggingface/unity-api",
+        "https://github.com/huggingface/workshops",
+        "https://github.com/huggingface/llm-ls",
+        "https://github.com/huggingface/llm-vscode",
+        "https://github.com/huggingface/community-events",
+        "https://github.com/huggingface/tune",
+        "https://github.com/huggingface/candle",
+        "https://github.com/huggingface/paper-style-guide",
+        "https://github.com/huggingface/huggingface.js",
+        "https://github.com/huggingface/neuralcoref",
+        "https://github.com/huggingface/hfapi",
+        "https://github.com/huggingface/data-measurements-tool",
+        "https://github.com/huggingface/personas",
+        "https://github.com/huggingface/instruction-tuned-sd",
+        "https://github.com/huggingface/swift-transformers",
+        "https://github.com/huggingface/api-inference-community",
+        "https://github.com/huggingface/diffusers",
+        "https://github.com/huggingface/safetensors",
+        "https://github.com/huggingface/optimum-graphcore",
+        "https://github.com/huggingface/OBELICS",
+        "https://github.com/huggingface/swift-coreml-diffusers",
+        "https://github.com/huggingface/naacl_transfer_learning_tutorial",
+        "https://github.com/huggingface/nn_pruning",
+        "https://github.com/huggingface/awesome-papers",
+        "https://github.com/huggingface/optimum-intel",
+        "https://github.com/huggingface/autotrain-advanced",
+        "https://github.com/huggingface/pytorch-openai-transformer-lm",
+        "https://github.com/huggingface/node-question-answering",
+        "https://github.com/huggingface/optimum",
+        "https://github.com/huggingface/knockknock",
+        "https://github.com/huggingface/optimum-habana",
+        "https://github.com/huggingface/transfer-learning-conv-ai",
+        "https://github.com/huggingface/notebooks",
+        "https://github.com/huggingface/hmtl",
+        "https://github.com/huggingface/block_movement_pruning",
+        "https://github.com/huggingface/huggingface_hub",
+        "https://github.com/huggingface/transformers-bloom-inference",
+        "https://github.com/huggingface/hf_transfer",
+        "https://github.com/huggingface/doc-builder",
+        "https://github.com/huggingface/large_language_model_training_playbook",
+        "https://github.com/huggingface/that_is_good_data",
+        "https://github.com/huggingface/swift-coreml-transformers",
+        "https://github.com/huggingface/datasets-viewer",
+        "https://github.com/huggingface/open-muse",
+        "https://github.com/huggingface/evaluate",
+        "https://github.com/huggingface/llm_training_handbook",
+        "https://github.com/huggingface/pytorch_block_sparse",
+        "https://github.com/huggingface/chat-ui",
+        "https://github.com/huggingface/llm.nvim",
+        "https://github.com/huggingface/swift-chat",
+        "https://github.com/huggingface/pytorch-pretrained-BigGAN",
+        "https://github.com/huggingface/exporters",
+        "https://github.com/huggingface/audio-transformers-course",
+        "https://github.com/huggingface/hf-endpoints-documentation",
+        "https://github.com/gradio-app/gradio"
+    ]
+}
data/get_hf_repositories_urls.py
ADDED
@@ -0,0 +1,49 @@
+import json
+import argparse
+import requests
+from typing import List
+
+
+def get_repositories_names(token: str, min_stars: int) -> List[str]:
+    repos_per_page = 100
+    repo_names = []
+    i = 0
+    while True:
+        url = \
+            f'https://api.github.com/orgs/huggingface/repos?' \
+            f'per_page={repos_per_page}&page={i}'
+        headers = {'Authorization': f'token {token}'}
+        response = requests.get(url, headers=headers)
+        if response.status_code == 200:
+            repos = json.loads(response.content)
+            repo_names += [
+                repo['full_name'] for repo in repos
+                if repo['stargazers_count'] >= min_stars
+            ]
+            if len(repos) < repos_per_page:
+                break
+            i += 1
+        else:
+            raise RuntimeError(f'Request failed with status code: {response.status_code}')
+    return list(set(repo_names))
+
+
+def save_repositories_urls(repositories_names: List[str], output_filename: str):
+    urls = [f'https://github.com/{repo_name}' for repo_name in repositories_names]
+    data = {'urls': urls}
+    with open(output_filename, 'w') as f:
+        json.dump(data, f, indent=4)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--token', type=str)
+    parser.add_argument('--stars', type=str)
+    args = parser.parse_args()
+    repositories = get_repositories_names(token=args.token, min_stars=int(args.stars))
+    repositories += [
+        'huggingface/hf-endpoints-documentation',
+        'gradio-app/gradio'
+    ]
+    print(f'Found {len(repositories)} repositories with at least {args.stars} stars')
+    save_repositories_urls(repositories, 'datasets/hf_repositories_urls_scraped.json')
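For reference, a minimal sketch of driving the two helpers above from Python rather than the command line, assuming it runs from the data/ directory and that '<github-token>' is a placeholder for a real GitHub personal access token:

from get_hf_repositories_urls import get_repositories_names, save_repositories_urls

names = get_repositories_names(token='<github-token>', min_stars=100)
save_repositories_urls(names, 'datasets/hf_repositories_urls_scraped.json')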
data/hugging_face_docs_dataset.py
ADDED
@@ -0,0 +1,190 @@
+import glob
+import json
+import os
+import re
+import subprocess
+from typing import List
+
+import requests
+import pandas as pd
+from bs4 import BeautifulSoup
+from markdown import markdown
+import nbformat
+from nbconvert import MarkdownExporter
+from nbconvert.preprocessors import Preprocessor, ClearOutputPreprocessor
+from tqdm import tqdm
+
+
+VALIDATE_URLS = False
+
+
+def download_repositories(repo_urls_file: str, repo_dir: str):
+    """
+    Downloads the Hugging Face repositories.
+    """
+    if not os.path.exists(repo_dir):
+        os.makedirs(repo_dir)
+    with open(repo_urls_file, "r") as f:
+        repositories_urls = json.load(f)["urls"]
+        print(f'Downloading {len(repositories_urls)} repositories')
+        for url in repositories_urls:
+            try:
+                subprocess.run(["git", "clone", url], cwd=repo_dir)
+            except subprocess.CalledProcessError as e:
+                print("Command failed with error:", e.stderr)
+
+
+class EmptyCellPreprocessor(Preprocessor):
+    def preprocess_cell(self, cell, resources, index):
+        if cell.source.strip() == '':
+            cell.source = ''
+            cell.cell_type = 'raw'
+        return cell, resources
+
+
+def convert_notebook_to_txt(filename: str):
+    """
+    Converts a notebook to a markdown file.
+    """
+    with open(filename) as f:
+        notebook = nbformat.read(f, as_version=4)
+    # id validation error fix
+    for cell in notebook['cells']:
+        cell['id'] = str(cell['id'])
+
+    clear_output = ClearOutputPreprocessor()
+    notebook, resources = clear_output.preprocess(notebook, {})
+
+    exporter = MarkdownExporter()
+    exporter.register_preprocessor(EmptyCellPreprocessor, enabled=True)
+    output_notebook_text, resources = exporter.from_notebook_node(notebook)
+
+    new_filename = filename.replace('.ipynb', '_ipynb.md')
+    with open(new_filename, 'w') as f:
+        f.write(output_notebook_text)
+    return new_filename
+
+
+def extract_files_from_directories(
+    repo_urls_file: str,
+    repo_dir: str,
+    docs_dir: str,
+    files_extensions: List[str]
+) -> None:
+
+    """
+    This function reads markdown and markdownx files from the repositories directory,
+    filters out non-English files, and adds the source GitHub URL as the first line of each file.
+    The resulting files are saved in the docs_dir.
+    """
+    languages = pd.read_csv("language-codes.csv").loc[:,"alpha2"].tolist()
+    languages.remove("en")
+
+    files = [
+        filename
+        for extension in files_extensions
+        for filename in glob.glob(repo_dir + f"**/*{extension}", recursive=True)
+    ]
+    print(f'Used extensions: {", ".join(files_extensions)}')
+    print(f'Found {len(files)} files')
+
+    repo_urls = []
+    with open(repo_urls_file, "r") as f:
+        repo_urls = json.load(f)["urls"]
+
+    # filter out the files that are not in english
+    filtered_files = []
+    for filename in files:
+        sep_file = filename.split("/")
+        for seq in sep_file:
+            if seq in languages:
+                break
+        else:
+            filtered_files.append(filename)
+    print(f'Found {len(filtered_files)} files in English')
+
+    # generate a GitHub URL for a file based on its name and a list of possible repository URLs
+    def get_github_url(filename: str, repo_urls: str, repo_dir: str) -> str:
+        source = filename.replace(repo_dir, '')
+        repo_name, file_path = source.split('/', 1)
+        repo_url_prefix = None
+        for repo_url in repo_urls:
+            if repo_name == repo_url.split('/')[-1]:
+                repo_url_prefix = repo_url
+                break
+        if not repo_url_prefix:
+            raise ValueError(f"Repo URL not found for {repo_name}")
+        url = f'{repo_url_prefix}/blob/main/{file_path}'
+        if VALIDATE_URLS:
+            try:
+                response = requests.get(url)
+                response.raise_for_status()
+            except:
+                print(f'filename: {filename}')
+                print(f'repo: {repo_name}, file: {file_path}')
+                print(f'url: {url}')
+                raise
+        return url
+
+    # creates a valid filename by replacing certain characters and removing the repo_dir path
+    def create_filename_from_path(filename: str, repo_dir: str) -> str:
+        filename = filename.replace(repo_dir, '')
+        chars_to_replace = ['/', '{', '}', '-', '.']
+        filename = ''.join(['_' if c in chars_to_replace else c for c in filename])
+        return filename
+
+    # copy the files with the source added in the first line
+    if not os.path.exists(docs_dir):
+        os.makedirs(docs_dir)
+    copied_files = []
+    for filename in tqdm(filtered_files):
+        source_url = get_github_url(filename, repo_urls, repo_dir)
+        data = f"source: {source_url}\n\n"
+        # convert jupyter notebooks to txt files
+        try:
+            if filename.endswith('.ipynb'):
+                filename = convert_notebook_to_txt(filename)
+            # rename and copy files
+            with open(filename, 'r') as f:
+                data += f.read()
+            output_filename = docs_dir + create_filename_from_path(filename, repo_dir)
+            with open(output_filename, 'w') as f:
+                f.write(data)
+            if not os.path.isfile(output_filename):
+                raise ValueError(f"Failed to create the output file: {output_filename}")
+            copied_files.append(output_filename)
+        except Exception as ex:
+            print(f'Failed to copy file {filename}: {ex}')
+
+    print(f'Successfully copied {len(set(copied_files))}/{len(filtered_files)} files')
+
+
+def markdown_cleaner(data: str):
+    """
+    Clean markdown text.
+
+    Args:
+        data (str): The markdown text to be cleaned.
+
+    Returns:
+        str: The cleaned markdown text.
+    """
+    soupped = BeautifulSoup(markdown(data), "html.parser")
+    raw_text = ''.join(soupped.findAll(string=True))
+    clean_text = re.sub(r"<!--.*?-->", "", raw_text, flags=re.DOTALL)
+    # remove any special tokens e.g <|endoftext|>
+    clean_text = re.sub(r"<\|endoftext\|>", "", clean_text, flags=re.DOTALL)
+    # discard non english text
+    clean_text = re.sub(r"[^a-zA-Z0-9\s]", "", clean_text, flags=re.DOTALL)
+    return "\n".join([t for t in clean_text.split("\n") if t])
+
+
+if __name__ == '__main__':
+    repo_urls_file = "./datasets/hf_repositories_urls.json"
+    repo_dir = "./datasets/huggingface_repositories/"
+    docs_dir = "./datasets/huggingface_docs/"
+    download_repositories(repo_urls_file, repo_dir)
+    extract_files_from_directories(
+        repo_urls_file, repo_dir, docs_dir,
+        files_extensions=['.md', '.mdx', '.ipynb']
+    )
data/indexer.ipynb
ADDED
@@ -0,0 +1,238 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import math\n",
+    "import numpy as np\n",
+    "from pathlib import Path\n",
+    "from tqdm import tqdm\n",
+    "from typing import List, Any\n",
+    "from langchain.chains import RetrievalQA\n",
+    "from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings\n",
+    "from langchain.document_loaders import TextLoader\n",
+    "from langchain.indexes import VectorstoreIndexCreator\n",
+    "from langchain.text_splitter import CharacterTextSplitter\n",
+    "from langchain.vectorstores import FAISS"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "docs = []\n",
+    "metadata = []\n",
+    "for p in Path(\"./datasets/huggingface_docs/\").iterdir():\n",
+    "    if not p.is_dir():\n",
+    "        with open(p) as f:\n",
+    "            # the first line is the source of the text\n",
+    "            source = f.readline().strip().replace('source: ', '')\n",
+    "            docs.append(f.read())\n",
+    "            metadata.append({\"source\": source})\n",
+    "        # break\n",
+    "\n",
+    "print(f'number of documents: {len(docs)}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chunk_size = 512\n",
+    "text_splitter = CharacterTextSplitter(\n",
+    "    separator=\"\",\n",
+    "    chunk_size=chunk_size,\n",
+    "    chunk_overlap=100,\n",
+    "    length_function=len,\n",
+    ")\n",
+    "docs = text_splitter.create_documents(docs, metadata)\n",
+    "print(f'number of chunks: {len(docs)}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_name = \"hkunlp/instructor-large\"\n",
+    "embed_instruction = \"Represent the Hugging Face library documentation\"\n",
+    "query_instruction = \"Query the most relevant piece of information from the Hugging Face documentation\"\n",
+    "\n",
+    "embedding_model = HuggingFaceInstructEmbeddings(\n",
+    "    model_name=model_name,\n",
+    "    embed_instruction=embed_instruction,\n",
+    "    query_instruction=query_instruction,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class AverageInstructEmbeddings(HuggingFaceInstructEmbeddings):\n",
+    "    max_length: int = None\n",
+    "\n",
+    "    def __init__(self, max_length: int = 512, **kwargs: Any):\n",
+    "        super().__init__(**kwargs)\n",
+    "        self.max_length = max_length\n",
+    "        if self.max_length < 0:\n",
+    "            print('max_length is not specified, using model default max_seq_length')\n",
+    "\n",
+    "    def embed_documents(self, texts: List[str]) -> List[List[float]]:\n",
+    "        all_embeddings = []\n",
+    "        for text in tqdm(texts, desc=\"Embedding documents\"):\n",
+    "            if len(text) > self.max_length and self.max_length > -1:\n",
+    "                n_chunks = math.ceil(len(text)/self.max_length)\n",
+    "                chunks = [\n",
+    "                    text[i*self.max_length:(i+1)*self.max_length]\n",
+    "                    for i in range(n_chunks)\n",
+    "                ]\n",
+    "                instruction_pairs = [[self.embed_instruction, chunk] for chunk in chunks]\n",
+    "                chunk_embeddings = self.client.encode(instruction_pairs)\n",
+    "                avg_embedding = np.mean(chunk_embeddings, axis=0)\n",
+    "                all_embeddings.append(avg_embedding.tolist())\n",
+    "            else:\n",
+    "                instruction_pairs = [[self.embed_instruction, text]]\n",
+    "                embeddings = self.client.encode(instruction_pairs)\n",
+    "                all_embeddings.append(embeddings[0].tolist())\n",
+    "\n",
+    "        return all_embeddings\n",
+    "\n",
+    "\n",
+    "# max length fed to the model, if longer than max then chunks + averaging\n",
+    "max_length = 512\n",
+    "embedding_model = AverageInstructEmbeddings(\n",
+    "    model_name=model_name,\n",
+    "    embed_instruction=embed_instruction,\n",
+    "    query_instruction=query_instruction,\n",
+    "    max_length=max_length,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "embeddings = embedding_model.embed_documents(texts=[d.page_content for d in docs[:10]])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "index = FAISS.from_documents(docs, embedding_model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "index_name = f'index-{model_name}-{chunk_size}-m{max_length}-notebooks'\n",
+    "index_name"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "index.save_local(f'../indexes/{index_name}/')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "index = FAISS.load_local(f'../indexes/{index_name}/', embedding_model)\n",
+    "docs = index.similarity_search(query='how to create a pipeline object?', k=5)\n",
+    "docs[0].page_content\n",
+    "docs[0].metadata"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for i, doc in enumerate(docs, start=1):\n",
+    "    print(f\"\\n{'='*100}\\n\")\n",
+    "    print(f\"Document {i} of {len(docs)}\")\n",
+    "    print(\"Page Content:\")\n",
+    "    print(f\"\\n{'-'*100}\\n\")\n",
+    "    print(doc.page_content, '\\n')\n",
+    "    print(doc.metadata)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from huggingface_hub import HfApi\n",
+    "\n",
+    "api = HfApi()\n",
+    "api.create_repo(\n",
+    "    repo_id=f'KonradSzafer/{index_name}',\n",
+    "    repo_type='dataset',\n",
+    "    private=False,\n",
+    "    exist_ok=True\n",
+    ")\n",
+    "api.upload_folder(\n",
+    "    folder_path=f'../indexes/{index_name}',\n",
+    "    repo_id=f'KonradSzafer/{index_name}',\n",
+    "    repo_type='dataset',\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "hf_qa_bot",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
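A toy illustration of the chunk-and-average strategy used by AverageInstructEmbeddings in the notebook above: a text longer than max_length is split into max_length-sized chunks, each chunk is embedded, and the chunk embeddings are averaged into one vector. The 4-dimensional "embeddings" here are made up for the example:

import numpy as np

max_length = 5
text = 'abcdefghij'  # 10 characters -> 2 chunks
n_chunks = -(-len(text) // max_length)  # ceiling division
chunks = [text[i*max_length:(i+1)*max_length] for i in range(n_chunks)]
chunk_embeddings = np.array([
    [1.0, 0.0, 0.0, 0.0],  # pretend embedding of 'abcde'
    [0.0, 1.0, 0.0, 0.0],  # pretend embedding of 'fghij'
])
avg_embedding = np.mean(chunk_embeddings, axis=0)
print(chunks)         # ['abcde', 'fghij']
print(avg_embedding)  # [0.5 0.5 0.  0. ]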
data/language-codes.csv
ADDED
@@ -0,0 +1,190 @@
+alpha2,English
+aa,Afar
+ab,Abkhazian
+ae,Avestan
+af,Afrikaans
+ak,Akan
+am,Amharic
+an,Aragonese
+ar,Arabic
+as,Assamese
+av,Avaric
+ay,Aymara
+az,Azerbaijani
+ba,Bashkir
+be,Belarusian
+bg,Bulgarian
+bh,Bihari languages
+bi,Bislama
+bm,Bambara
+bn,Bengali
+bo,Tibetan
+br,Breton
+bs,Bosnian
+ca,Catalan; Valencian
+ce,Chechen
+ch,Chamorro
+co,Corsican
+cr,Cree
+cs,Czech
+cu,Church Slavic; Old Slavonic; Church Slavonic; Old Bulgarian; Old Church Slavonic
+cv,Chuvash
+cy,Welsh
+da,Danish
+de,German
+dv,Divehi; Dhivehi; Maldivian
+dz,Dzongkha
+ee,Ewe
+el,"Greek, Modern (1453-)"
+en,English
+eo,Esperanto
+es,Spanish; Castilian
+et,Estonian
+eu,Basque
+fa,Persian
+ff,Fulah
+fi,Finnish
+fj,Fijian
+fo,Faroese
+fr,French
+fy,Western Frisian
+ga,Irish
+gd,Gaelic; Scottish Gaelic
+gj,Gujarati
+gl,Galician
+gn,Guarani
+gu,Gujarati
+gv,Manx
+ha,Hausa
+hd,Hindi
+he,Hebrew
+hi,Hindi
+ho,Hiri Motu
+hr,Croatian
+ht,Haitian; Haitian Creole
+hu,Hungarian
+hy,Armenian
+hz,Herero
+ia,Interlingua (International Auxiliary Language Association)
+id,Indonesian
+ie,Interlingue; Occidental
+ig,Igbo
+ii,Sichuan Yi; Nuosu
+ik,Inupiaq
+io,Ido
+is,Icelandic
+it,Italian
+iu,Inuktitut
+ja,Japanese
+jv,Javanese
+ka,Georgian
+kg,Kongo
+ki,Kikuyu; Gikuyu
+kj,Kuanyama; Kwanyama
+kk,Kazakh
+kl,Kalaallisut; Greenlandic
+km,Central Khmer
+kn,Kannada
+ko,Korean
+kr,Kanuri
+ks,Kashmiri
+ku,Kurdish
+kv,Komi
+kw,Cornish
+ky,Kirghiz; Kyrgyz
+la,Latin
+lb,Luxembourgish; Letzeburgesch
+lg,Ganda
+li,Limburgan; Limburger; Limburgish
+ln,Lingala
+lo,Lao
+lt,Lithuanian
+lu,Luba-Katanga
+lv,Latvian
+mg,Malagasy
+mh,Marshallese
+mi,Maori
+mk,Macedonian
+ml,Malayalam
+mn,Mongolian
+mr,Marathi
+ms,Malay
+mt,Maltese
+my,Burmese
+na,Nauru
+nb,"Bokmål, Norwegian; Norwegian Bokmål"
+nd,"Ndebele, North; North Ndebele"
+ne,Nepali
+ng,Ndonga
+nl,Dutch; Flemish
+nn,"Norwegian Nynorsk; Nynorsk, Norwegian"
+no,Norwegian
+nr,"Ndebele, South; South Ndebele"
+nv,Navajo; Navaho
+ny,Chichewa; Chewa; Nyanja
+oc,Occitan (post 1500)
+oj,Ojibwa
+om,Oromo
+or,Oriya
+os,Ossetian; Ossetic
+pa,Panjabi; Punjabi
+pi,Pali
+pl,Polish
+ps,Pushto; Pashto
+pt,Portuguese
+qu,Quechua
+rm,Romansh
+rn,Rundi
+ro,Romanian; Moldavian; Moldovan
+ru,Russian
+rw,Kinyarwanda
+sa,Sanskrit
+sc,Sardinian
+sd,Sindhi
+se,Northern Sami
+sg,Sango
+si,Sinhala; Sinhalese
+sk,Slovak
+sl,Slovenian
+sm,Samoan
+sn,Shona
+so,Somali
+sq,Albanian
+sr,Serbian
+ss,Swati
+st,"Sotho, Southern"
+su,Sundanese
+sv,Swedish
+sw,Swahili
+ta,Tamil
+te,Telugu
+tg,Tajik
+th,Thai
+ti,Tigrinya
+tk,Turkmen
+tl,Tagalog
+tn,Tswana
+to,Tonga (Tonga Islands)
+tr,Turkish
+ts,Tsonga
+tt,Tatar
+tw,Twi
+ty,Tahitian
+ug,Uighur; Uyghur
+uk,Ukrainian
+ur,Urdu
+uz,Uzbek
+ve,Venda
+vi,Vietnamese
+vo,Volapük
+wa,Walloon
+wo,Wolof
+xh,Xhosa
+yi,Yiddish
+yo,Yoruba
+za,Zhuang; Chuang
+zh,Chinese; General
+zh-CN,Chinese; Simplified
+zh-TW,Chinese; Traditional
+zh-hans,Chinese; Simplified
+zu,Zulu
data/scrapers/stack_overflow_scraper.py
ADDED
@@ -0,0 +1,91 @@
+import re
+import csv
+import time
+import requests
+from typing import List
+import pandas as pd
+from tqdm import tqdm
+from bs4 import BeautifulSoup
+
+
+def scrape_question_with_answers(question_url: str) -> List[str]:
+    url = 'https://stackoverflow.com/' + question_url
+    response = requests.get(url)
+    soup = BeautifulSoup(response.content, 'html.parser')
+
+    title = soup.find('title').text.replace(' - Stack Overflow', '')
+    question_div = soup.find('div', {'class': 'postcell post-layout--right'})
+    question = question_div.find('p').text
+    answers_div = soup.find('div', {'class': 'answercell post-layout--right'})
+    answer = answers_div.find('div', {'class': 's-prose js-post-body'}).text
+    return [title, question, answer, url]
+
+
+def scrape_questions_page(url: str, min_votes: int, min_answers: int) -> List[List[str]]:
+    response = requests.get(url)
+    soup = BeautifulSoup(response.content, 'html.parser')
+    posts_summaries = soup.find_all('div', {'class':'s-post-summary js-post-summary'})
+
+    qa_data = []
+    for summary in posts_summaries:
+        stats_div = summary.find('div', {'class': 's-post-summary--stats'})
+        vote_div = stats_div.find('div', {
+            'class': 's-post-summary--stats-item s-post-summary--stats-item__emphasized',
+            'title': re.compile(r'^Score of \d+$')})
+        if vote_div:
+            vote_number = int(vote_div.find('span', {'class': 's-post-summary--stats-item-number'}).text)
+        else:
+            vote_number = 0
+        answer_div = stats_div.find('div', {
+            'class': 's-post-summary--stats-item',
+            'title': re.compile(r'^\d+ answers$')})
+        if answer_div:
+            answer_number = int(answer_div.find('span', {'class': 's-post-summary--stats-item-number'}).text)
+        else:
+            answer_number = 0
+
+        question_href = summary.find('a', {'class': 's-link'})['href']
+        if vote_number >= min_votes and answer_number >= min_answers:
+            try:
+                qa_data.append(scrape_question_with_answers(question_href))
+            except Exception as error:
+                print(error)
+
+        time.sleep(1.5)
+    return qa_data
+
+
+def crawl_and_save_qa(
+    filename: str,
+    base_url: str,
+    start_page: int,
+    n_pages: int=10,
+    min_votes: int=1,
+    min_answers: int=1
+):
+    with open(filename, 'a', newline='') as f:
+        writer = csv.writer(f)
+        if start_page == 1:
+            writer.writerow(['title', 'question', 'answer', 'url'])
+        for page_num in tqdm(range(start_page, start_page+n_pages)):
+            page_data = scrape_questions_page(
+                base_url.format(page_num),
+                min_votes,
+                min_answers
+            )
+            if page_data:
+                for qa_data in page_data:
+                    writer.writerow(qa_data)
+
+
+if __name__ == '__main__':
+    filename = '../datasets/stackoverflow_linux.csv'
+    url = 'https://stackoverflow.com/questions/tagged/linux?tab=votes&page={}&pagesize=15'
+    crawl_and_save_qa(
+        filename=filename,
+        base_url=url,
+        start_page=21,
+        n_pages=10,
+        min_votes=1,
+        min_answers=1
+    )
data/stackoverflow_python_dataset.py
ADDED
@@ -0,0 +1,55 @@
+from datetime import datetime
+from datasets import load_dataset
+from bs4 import BeautifulSoup
+
+
+def preprocess_dataset():
+    """
+    Preprocesses the 'koutch/stackoverflow_python' dataset.
+
+    Returns:
+        datasets.arrow_dataset.Dataset: The preprocessed dataset.
+    """
+    dataset = load_dataset('koutch/stackoverflow_python', split='train')
+    dataset = dataset.filter(
+        lambda example:
+            example['question_score'] > 100 and
+            example['answer_score'] > 5 and
+            datetime.strptime(example['answer_date'], '%Y-%m-%dT%H:%M:%SZ').year > 2010
+    )
+
+    def html2text(example):
+        soup = BeautifulSoup(example, 'html.parser')
+        return ''.join(soup.findAll(string=True))
+
+    def transforms(example):
+        example['answer'] = html2text(example['answer_body'])
+        example['question'] = html2text(example['question_body'])
+        return example
+
+    dataset = dataset.map(lambda example: transforms(example))
+    dataset = dataset.remove_columns([
+        'question_score', 'question_date', 'question_id',
+        'answer_date', 'answer_id', 'answer_score', 'tags',
+        'question_body', 'answer_body'
+    ])
+    return dataset
+
+
+def show_info(dataset):
+    """
+    Print information about the dataset.
+
+    Args:
+        dataset (datasets.arrow_dataset.Dataset): The dataset.
+    """
+    print(dataset.info, '\n')
+    print(f'dataset len: {len(dataset)}')
+    print(f"example question: {dataset[0]['question']}")
+    print(f"example answer: {dataset[0]['answer']}")
+
+
+if __name__ == '__main__':
+    dataset = preprocess_dataset()
+    dataset.push_to_hub('KonradSzafer/stackoverflow_python_preprocessed', private=False)
+    show_info(dataset)
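For reference, a minimal sketch of loading the preprocessed dataset back from the Hub, assuming the push_to_hub call above succeeded and the repository id is unchanged:

from datasets import load_dataset

dataset = load_dataset('KonradSzafer/stackoverflow_python_preprocessed', split='train')
print(dataset[0]['question'])
print(dataset[0]['answer'])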
data/upload_csv_dataset.py
ADDED
@@ -0,0 +1,24 @@
+import sys
+import pandas as pd
+from datasets import Dataset, DatasetDict
+from sklearn.model_selection import train_test_split
+
+
+
+def main():
+    dataset_name = sys.argv[1]
+    test_size = float(sys.argv[2]) if len(sys.argv) > 2 else 0.1
+    print(f'dataset: {dataset_name}, test size: {test_size}')
+
+    filename = f'datasets/{dataset_name}.csv'
+    df = pd.read_csv(filename)
+    dataset = Dataset.from_pandas(df)
+    train_dataset, test_dataset = train_test_split(dataset, test_size=test_size)
+    train_dataset = Dataset.from_dict(train_dataset)
+    test_dataset = Dataset.from_dict(test_dataset)
+    dataset_dict = DatasetDict({'train': train_dataset, 'test': test_dataset})
+    dataset_dict.push_to_hub(f'KonradSzafer/{dataset_name}', private=False)
+
+
+if __name__ == '__main__':
+    main()
discord_bot/__init__.py
ADDED
File without changes
discord_bot/__main__.py
ADDED
@@ -0,0 +1,28 @@
+from qa_engine import logger, Config, QAEngine
+from discord_bot.client import DiscordClient
+
+
+config = Config()
+qa_engine = QAEngine(
+    llm_model_id=config.question_answering_model_id,
+    embedding_model_id=config.embedding_model_id,
+    index_repo_id=config.index_repo_id,
+    prompt_template=config.prompt_template,
+    use_docs_for_context=config.use_docs_for_context,
+    num_relevant_docs=config.num_relevant_docs,
+    add_sources_to_response=config.add_sources_to_response,
+    use_messages_for_context=config.use_messages_in_context,
+    debug=config.debug
+)
+client = DiscordClient(
+    qa_engine=qa_engine,
+    num_last_messages=config.num_last_messages,
+    use_names_in_context=config.use_names_in_context,
+    enable_commands=config.enable_commands,
+    debug=config.debug
+)
+
+
+if __name__ == '__main__':
+    logger.info('Starting Application...')
+    client.run(config.discord_token)
discord_bot/client/__init__.py
ADDED
@@ -0,0 +1 @@
+from .client import DiscordClient
discord_bot/client/client.py
ADDED
@@ -0,0 +1,130 @@
import discord
from typing import List

from qa_engine import logger, QAEngine
from discord_bot.client.utils import split_text_into_chunks


class DiscordClient(discord.Client):
    """
    Discord Client class, used for interacting with a Discord server.

    Args:
        qa_engine (QAEngine): The question answering engine used to generate answers.
        num_last_messages (int, optional): The number of previous messages to use as context
            for generating answers. Defaults to 5.
        use_names_in_context (bool, optional): Whether to include user names in the message context.
            Defaults to True.
        enable_commands (bool, optional): Whether to enable commands for the bot. Defaults to True.
        debug (bool, optional): Whether to log debug information. Defaults to False.

    Attributes:
        qa_engine (QAEngine): The question answering engine used to generate answers.
        num_last_messages (int): The number of previous messages to use as context for generating answers.
        use_names_in_context (bool): Whether to include user names in the message context.
        enable_commands (bool): Whether to enable commands for the bot.
        min_message_len (int): The minimum length of a message chunk.
        max_message_len (int): The maximum length of a message chunk (Discord's per-message limit is 2000 characters).
    """
    def __init__(
        self,
        qa_engine: QAEngine,
        num_last_messages: int = 5,
        use_names_in_context: bool = True,
        enable_commands: bool = True,
        debug: bool = False
    ):
        logger.info('Initializing Discord client...')
        intents = discord.Intents.all()
        intents.message_content = True
        super().__init__(intents=intents, command_prefix='!')

        assert num_last_messages >= 1, \
            'The number of last messages in context should be at least 1'

        self.qa_engine: QAEngine = qa_engine
        self.num_last_messages: int = num_last_messages
        self.use_names_in_context: bool = use_names_in_context
        self.enable_commands: bool = enable_commands
        self.debug: bool = debug
        self.min_message_len: int = 1800
        self.max_message_len: int = 2000


    async def on_ready(self):
        """
        Callback function to be called when the client is ready.
        """
        logger.info('Successfully logged in as: {0.user}'.format(self))
        await self.change_presence(activity=discord.Game(name='Chatting...'))


    async def get_last_messages(self, message) -> List[str]:
        """
        Method to fetch recent messages from a message's channel.

        Args:
            message (discord.Message): The discord Message object used to identify the channel.

        Returns:
            List[str]: Reversed list of recent messages from the channel,
            excluding the input message. Messages may be prefixed with the author's name
            if `self.use_names_in_context` is True.
        """
        last_messages: List[str] = []
        async for msg in message.channel.history(
            limit=self.num_last_messages):
            if self.use_names_in_context:
                last_messages.append(f'{msg.author}: {msg.content}')
            else:
                last_messages.append(msg.content)
        last_messages.reverse()
        last_messages.pop()  # remove the current message from the context
        return last_messages


    async def send_message(self, message, answer: str, sources: str):
        chunks = split_text_into_chunks(
            text=answer,
            split_characters=['. ', ', ', '\n'],
            min_size=self.min_message_len,
            max_size=self.max_message_len
        )
        for chunk in chunks:
            await message.channel.send(chunk)
        await message.channel.send(sources)


    async def on_message(self, message):
        """
        Callback function to be called when a message is received.

        Args:
            message (discord.Message): The received message.
        """
        if message.author == self.user:
            return
        if self.enable_commands and message.content.startswith('!'):
            if message.content == '!clear':
                await message.channel.purge()
                return

        last_messages = await self.get_last_messages(message)
        context = '\n'.join(last_messages)

        logger.info('Received message: {0.content}'.format(message))
        response = self.qa_engine.get_response(
            question=message.content,
            messages_context=context
        )
        logger.info('Sending response: {0}'.format(response))
        try:
            await self.send_message(
                message,
                response.get_answer(),
                response.get_sources_as_text()
            )
        except Exception as e:
            logger.error('Failed to send response: {0}'.format(e))
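Note: Discord rejects messages longer than 2000 characters, which is why `send_message` chunks the answer before posting. A minimal sketch of that step in isolation (the sample answer string is a placeholder, not real engine output):

from discord_bot.client.utils import split_text_into_chunks

# Placeholder long answer, standing in for QAEngine output.
answer = 'Hugging Face is a platform for ML tooling. ' * 200

chunks = split_text_into_chunks(
    text=answer,
    split_characters=['. ', ', ', '\n'],  # same separators DiscordClient.send_message uses
    min_size=1800,   # DiscordClient.min_message_len
    max_size=2000,   # DiscordClient.max_message_len - Discord's per-message cap
)
assert all(len(chunk) <= 2000 for chunk in chunks)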
discord_bot/client/utils.py
ADDED
@@ -0,0 +1,54 @@
from typing import List


def find_max_split_index(text: str, char: str) -> int:
    char_idx = text.rfind(char)
    if char_idx > 0:
        # If the character is found, return the index just after the splitting character
        split_idx = char_idx + len(char)
        if split_idx >= len(text):
            return len(text)
        else:
            return split_idx
    return -1


def find_max_split_index_from_sequence(text: str, split_characters: List[str]) -> int:
    split_index = max((
        find_max_split_index(text, sequence)
        for sequence in split_characters
    ), default=-1)
    return split_index


def split_text_into_chunks(
    text: str,
    split_characters: List[str] = [],
    min_size: int = 20,
    max_size: int = 250,
) -> List[str]:

    chunks = []
    start_idx = 0
    end_idx = max_size
    text_len = len(text)
    while start_idx < text_len:
        search_chunk = text[start_idx+min_size:end_idx]
        split_idx = find_max_split_index_from_sequence(
            text=search_chunk,
            split_characters=split_characters
        )
        # if no splitting element is found, fall back to the maximal size
        if split_idx < 1:
            split_idx = end_idx
        # if found - offset it by the starting index of the chunk
        else:
            split_idx += start_idx + min_size

        chunk = text[start_idx:split_idx]
        chunks.append(chunk)

        start_idx = split_idx
        end_idx = split_idx + max_size

    return chunks
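The helpers above return the index just past the last occurrence of a separator, so chunks end cleanly on sentence or line boundaries. A quick illustration, using the same toy string as the unit tests further below:

from discord_bot.client.utils import (
    find_max_split_index,
    find_max_split_index_from_sequence,
)

chunk = 't. , \n .'
print(find_max_split_index(chunk, char='. '))  # 3 - index just past the last '. '
print(find_max_split_index(chunk, char='\n'))  # 6
# with several separators, the largest split point wins
print(find_max_split_index_from_sequence(chunk, split_characters=['.', ', ', '\n']))  # 8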
docker-compose.yml
ADDED
@@ -0,0 +1,23 @@
version: '3'
services:
  api:
    build:
      context: .
      dockerfile: Dockerfile.api
    ports:
      - 8000:8000
    networks:
      - mynetwork
  bot:
    build:
      context: .
      dockerfile: Dockerfile.bot
    ports:
      - 80:80
    depends_on:
      - api
    networks:
      - mynetwork
networks:
  mynetwork:
    driver: bridge
models/inference.ipynb
ADDED
@@ -0,0 +1,103 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import transformers\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "PROMPT_TEMPLATES_DIR = os.path.dirname(os.path.abspath(os.getcwd()))\n",
    "PROMPT_TEMPLATES_DIR += '/config/prompt_templates/'\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_template = 'sythia_v1.3'\n",
    "with open(PROMPT_TEMPLATES_DIR + f'{prompt_template}.txt', 'r') as f:\n",
    "    prompt_template = f.read()\n",
    "\n",
    "context = ''\n",
    "question = 'How to fix a bike?'\n",
    "\n",
    "prompt = prompt_template.format(context=context, question=question)\n",
    "print(f'prompt len: {len(prompt)}\\n')\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_id = 'migtissera/SynthIA-7B-v1.3'\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_id,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    trust_remote_code=True,\n",
    "    load_in_8bit=False,\n",
    "    device_map='auto',\n",
    "    resume_download=True,\n",
    ")\n",
    "\n",
    "pipeline = transformers.pipeline(\n",
    "    'text-generation',\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    device_map='auto',\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    eos_token_id=tokenizer.eos_token_id,\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    min_new_tokens=64,\n",
    "    max_new_tokens=800,\n",
    "    temperature=0.5,\n",
    "    do_sample=True,\n",
    ")\n",
    "\n",
    "output_text = pipeline(prompt)[0]['generated_text']\n",
    "output_text = output_text.replace(prompt+'\\n', '')\n",
    "print(output_text)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hf_qa_bot",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "e769ac600d1c65682759767682b2a946c0eaa09d353302f712fe4c2e822e15df"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
qa_engine/__init__.py
ADDED
@@ -0,0 +1,9 @@
from dotenv import load_dotenv
from qa_engine.logger import setup_logger

setup_logger()
load_dotenv(dotenv_path='config/.env')

from .logger import setup_logger, logger
from .config import Config
from .qa_engine import QAEngine
qa_engine/config.py
ADDED
@@ -0,0 +1,64 @@
import os
from dataclasses import dataclass, asdict
from typing import Any

from qa_engine import logger


def get_env(env_name: str, default: Any = None, warn: bool = True) -> str:
    env = os.getenv(env_name)
    if not env:
        # check against None so that falsy defaults such as '' are still usable
        if default is not None:
            if warn:
                logger.warning(
                    f'Environment variable {env_name} not found. ' \
                    f'Using the default value: {default}.'
                )
            return default
        else:
            raise ValueError(f'Cannot parse: {env_name}')
    else:
        return env


@dataclass
class Config:
    # QA Engine config
    question_answering_model_id: str = get_env('QUESTION_ANSWERING_MODEL_ID')
    embedding_model_id: str = get_env('EMBEDDING_MODEL_ID')
    index_repo_id: str = get_env('INDEX_REPO_ID')
    prompt_template_name: str = get_env('PROMPT_TEMPLATE_NAME')
    use_docs_for_context: bool = eval(get_env('USE_DOCS_FOR_CONTEXT', 'True'))
    num_relevant_docs: int = int(get_env('NUM_RELEVANT_DOCS', 3))
    add_sources_to_response: bool = eval(get_env('ADD_SOURCES_TO_RESPONSE', 'True'))
    use_messages_in_context: bool = eval(get_env('USE_MESSAGES_IN_CONTEXT', 'True'))
    debug: bool = eval(get_env('DEBUG', 'True'))

    # Discord bot config - optional
    discord_token: str = get_env('DISCORD_TOKEN', '', warn=False)
    num_last_messages: int = int(get_env('NUM_LAST_MESSAGES', 2, warn=False))
    use_names_in_context: bool = eval(get_env('USE_NAMES_IN_CONTEXT', 'False', warn=False))
    enable_commands: bool = eval(get_env('ENABLE_COMMANDS', 'True', warn=False))

    def __post_init__(self):
        prompt_template_file = f'config/prompt_templates/{self.prompt_template_name}.txt'
        with open(prompt_template_file, 'r') as f:
            self.prompt_template = f.read()
        # validate config
        if 'context' not in self.prompt_template:
            raise ValueError("Prompt template does not contain the 'context' field.")
        if 'question' not in self.prompt_template:
            raise ValueError("Prompt template does not contain the 'question' field.")
        if not self.use_docs_for_context and self.add_sources_to_response:
            raise ValueError('Cannot add sources to the response if docs are not used in the context')
        if self.num_relevant_docs < 1:
            raise ValueError('num_relevant_docs must be greater than 0')
        self.log()

    def asdict(self) -> dict:
        return asdict(self)

    def log(self) -> None:
        logger.info('Config:')
        for key, value in self.asdict().items():
            logger.info(f'{key}: {value}')
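Since the `Config` fields call `get_env` at class-definition time, the environment (or `config/.env`) must be populated before `qa_engine` is first imported. A minimal sketch, run from the repository root; the model and index IDs below are illustrative (borrowed from the tests), not required values:

import os

os.environ['QUESTION_ANSWERING_MODEL_ID'] = 'api_models/http://localhost:8000'  # illustrative
os.environ['EMBEDDING_MODEL_ID'] = 'hkunlp/instructor-large'
os.environ['INDEX_REPO_ID'] = 'KonradSzafer/index-large-notebooks'
os.environ['PROMPT_TEMPLATE_NAME'] = 'llama2'

from qa_engine import Config

config = Config()  # loads and validates config/prompt_templates/llama2.txt
print(config.num_relevant_docs)  # 3 - the default, since NUM_RELEVANT_DOCS is unset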
qa_engine/logger.py
ADDED
@@ -0,0 +1,14 @@
import logging


logger = logging.getLogger(__name__)

def setup_logger() -> None:
    """
    Logger setup.
    """
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)
qa_engine/mocks.py
ADDED
@@ -0,0 +1,41 @@
import os
from typing import Mapping, Optional, Any

from langchain.llms.base import LLM


class MockLocalBinaryModel(LLM):
    """
    Mock Local Binary Model class, used for returning a fixed mock string.

    Args:
        model_id (str): The ID of the model to be mocked.

    Attributes:
        model_path (str): The path to the model to be mocked.
        llm (str): The fixed string returned by every call.

    Raises:
        ValueError: If the model_path does not exist.
    """

    model_path: str = None
    llm: str = 'READY TO MOCK'

    def __init__(self, model_id: str = None):
        super().__init__()
        self.model_path = f'bot/question_answering/{model_id}'
        if not os.path.exists(self.model_path):
            raise ValueError(f'{self.model_path} does not exist')


    def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
        return self.llm

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {'name_of_model': self.model_path}

    @property
    def _llm_type(self) -> str:
        return self.model_path
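Because the mock subclasses LangChain's `LLM`, it can be dropped into an `LLMChain` wherever a real model would go. A sketch (hypothetical: it assumes a placeholder file exists at `bot/question_answering/mock.bin`, since the constructor validates the path, and the prompt template here is made up for illustration):

from langchain import PromptTemplate, LLMChain
from qa_engine.mocks import MockLocalBinaryModel

mock_llm = MockLocalBinaryModel(model_id='mock.bin')  # hypothetical placeholder file
prompt = PromptTemplate(
    template='Q: {question}\nContext: {context}\nA:',  # made-up template
    input_variables=['question', 'context'],
)
chain = LLMChain(prompt=prompt, llm=mock_llm)
print(chain.run(question='ping', context=''))  # 'READY TO MOCK'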
qa_engine/qa_engine.py
ADDED
@@ -0,0 +1,286 @@
import os
import json
import requests
from typing import Mapping, Optional, Any

import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import snapshot_download
from urllib.parse import quote
from langchain import PromptTemplate, LLMChain
from langchain.llms.base import LLM
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.vectorstores import FAISS
from sentence_transformers import CrossEncoder

from qa_engine import logger
from qa_engine.response import Response


class LocalBinaryModel(LLM):
    model_id: str = None
    llm: Any = None

    def __init__(self, model_id: str = None):
        super().__init__()
        # pip install llama_cpp_python==0.1.39
        from llama_cpp import Llama

        model_path = f'qa_engine/{model_id}'
        if not os.path.exists(model_path):
            raise ValueError(f'{model_path} does not exist')
        self.model_id = model_id
        self.llm = Llama(model_path=model_path, n_ctx=4096)

    def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
        output = self.llm(
            prompt,
            max_tokens=1024,
            stop=['Q:'],
            echo=False
        )
        return output['choices'][0]['text']

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {'name_of_model': self.model_id}

    @property
    def _llm_type(self) -> str:
        return self.model_id


class TransformersPipelineModel(LLM):
    model_id: str = None
    pipeline: Any = None

    def __init__(self, model_id: str = None):
        super().__init__()
        self.model_id = model_id

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            load_in_8bit=False,
            device_map='auto',
            resume_download=True,
        )
        self.pipeline = transformers.pipeline(
            'text-generation',
            model=model,
            tokenizer=tokenizer,
            torch_dtype=torch.bfloat16,
            device_map='auto',
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            min_new_tokens=64,
            max_new_tokens=800,
            temperature=0.5,
            do_sample=True,
        )

    def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
        output_text = self.pipeline(prompt)[0]['generated_text']
        output_text = output_text.replace(prompt+'\n', '')
        return output_text

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {'name_of_model': self.model_id}

    @property
    def _llm_type(self) -> str:
        return self.model_id


class APIServedModel(LLM):
    model_url: str = None
    debug: bool = None

    def __init__(self, model_url: str = None, debug: bool = None):
        super().__init__()
        if model_url[-1] == '/':
            raise ValueError('URL should not end with a slash - "/"')
        self.model_url = model_url
        self.debug = debug

    def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
        prompt_encoded = quote(prompt, safe='')
        url = f'{self.model_url}/?prompt={prompt_encoded}'
        if self.debug:
            logger.info(f'URL: {url}')
        try:
            response = requests.get(url, timeout=1200, verify=False)
            response.raise_for_status()
            return json.loads(response.content)['output_text']
        except Exception as err:
            logger.error(f'Error: {err}')
            return f'Error: {err}'

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {'name_of_model': f'model url: {self.model_url}'}

    @property
    def _llm_type(self) -> str:
        return 'api_model'



class QAEngine():
    """
    QAEngine class, used for generating answers to questions.

    Args:
        llm_model_id (str): The ID of the LLM model to be used.
        embedding_model_id (str): The ID of the embedding model to be used.
        index_repo_id (str): The ID of the index repository to be used.
        prompt_template (str): The prompt template to be used.
        use_docs_for_context (bool, optional): Whether to use relevant documents as context for generating answers.
            Defaults to True.
        num_relevant_docs (int, optional): The number of documents kept as context after reranking. Defaults to 3.
        add_sources_to_response (bool, optional): Whether to append document sources to the response. Defaults to True.
        use_messages_for_context (bool, optional): Whether to use previous messages as context for generating answers.
            Defaults to True.
        first_stage_docs (int, optional): The number of documents retrieved before reranking. Defaults to 50.
        debug (bool, optional): Whether to log debug information. Defaults to False.

    Attributes:
        use_docs_for_context (bool): Whether to use relevant documents as context for generating answers.
        use_messages_for_context (bool): Whether to use previous messages as context for generating answers.
        debug (bool): Whether to log debug information.
        llm_model (Union[LocalBinaryModel, TransformersPipelineModel, APIServedModel]): The LLM model to be used.
        prompt_template (str): The prompt template to be used.
        llm_chain (LLMChain): The LLM chain to be used.
        knowledge_index (FAISS): The FAISS index to be used.
        reranker (CrossEncoder): The cross-encoder used to rerank retrieved documents.

    """
    def __init__(
        self,
        llm_model_id: str,
        embedding_model_id: str,
        index_repo_id: str,
        prompt_template: str,
        use_docs_for_context: bool = True,
        num_relevant_docs: int = 3,
        add_sources_to_response: bool = True,
        use_messages_for_context: bool = True,
        first_stage_docs: int = 50,
        debug: bool = False
    ):
        super().__init__()
        self.prompt_template = prompt_template
        self.use_docs_for_context = use_docs_for_context
        self.num_relevant_docs = num_relevant_docs
        self.add_sources_to_response = add_sources_to_response
        self.use_messages_for_context = use_messages_for_context
        self.first_stage_docs = first_stage_docs
        self.debug = debug

        if 'local_models/' in llm_model_id:
            logger.info('using local binary model')
            self.llm_model = LocalBinaryModel(
                model_id=llm_model_id
            )
        elif 'api_models/' in llm_model_id:
            logger.info('using api served model')
            self.llm_model = APIServedModel(
                model_url=llm_model_id.replace('api_models/', ''),
                debug=self.debug
            )
        else:
            logger.info('using transformers pipeline model')
            self.llm_model = TransformersPipelineModel(
                model_id=llm_model_id
            )

        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=['question', 'context']
        )
        self.llm_chain = LLMChain(prompt=prompt, llm=self.llm_model)

        if self.use_docs_for_context:
            logger.info(f'Downloading {index_repo_id}')
            snapshot_download(
                repo_id=index_repo_id,
                allow_patterns=['*.faiss', '*.pkl'],
                repo_type='dataset',
                local_dir='indexes/run/'
            )
            logger.info('Loading embedding model')
            embed_instruction = 'Represent the Hugging Face library documentation'
            query_instruction = 'Query the most relevant piece of information from the Hugging Face documentation'
            embedding_model = HuggingFaceInstructEmbeddings(
                model_name=embedding_model_id,
                embed_instruction=embed_instruction,
                query_instruction=query_instruction
            )
            logger.info('Loading index')
            self.knowledge_index = FAISS.load_local('./indexes/run/', embedding_model)
            self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')


    def get_response(self, question: str, messages_context: str = '') -> Response:
        """
        Generate an answer to the specified question.

        Args:
            question (str): The question to be answered.
            messages_context (str, optional): The context to be used for generating the answer. Defaults to ''.

        Returns:
            response (Response): The Response object containing the generated answer and the sources of information
            used to generate the response.
        """

        response = Response()
        context = ''
        relevant_docs = ''
        if self.use_messages_for_context and messages_context:
            messages_context = f'\nPrevious questions and answers:\n{messages_context}'
            context += messages_context
        if self.use_docs_for_context:
            logger.info('Retrieving documents')
            # the messages context is used for better retrieval
            retrieval_query = messages_context + question
            relevant_docs = self.knowledge_index.similarity_search(
                query=retrieval_query,
                k=self.first_stage_docs
            )
            cross_encoding_predictions = self.reranker.predict(
                [(retrieval_query, doc.page_content) for doc in relevant_docs]
            )
            relevant_docs = [
                doc for _, doc in sorted(
                    zip(cross_encoding_predictions, relevant_docs),
                    reverse=True, key=lambda x: x[0]
                )
            ]
            relevant_docs = relevant_docs[:self.num_relevant_docs]
            context += '\nExtracted documents:\n'
            context += ''.join([doc.page_content for doc in relevant_docs])
            metadata = [doc.metadata for doc in relevant_docs]
            response.set_sources(sources=[str(m['source']) for m in metadata])

        logger.info('Running LLM chain')
        answer = self.llm_chain.run(question=question, context=context)
        response.set_answer(answer)
        logger.info('Received answer')

        if self.debug:
            logger.info('\n' + '=' * 100)
            sep = '\n' + '-' * 100
            logger.info(f'question len: {len(question)} {sep}')
            logger.info(f'question: {question} {sep}')
            logger.info(f'answer len: {len(response.get_answer())} {sep}')
            logger.info(f'answer: {response.get_answer()} {sep}')
            logger.info(f'{response.get_sources_as_text()} {sep}')
            logger.info(f'messages_context: {messages_context} {sep}')
            logger.info(f'relevant_docs: {relevant_docs} {sep}')
            logger.info(f'context len: {len(context)} {sep}')
            logger.info(f'context: {context} {sep}')
        return response
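`get_response` performs two-stage retrieval: a FAISS similarity search over `first_stage_docs` candidates, followed by cross-encoder reranking down to `num_relevant_docs`. A minimal sketch of driving the engine end to end, mirroring `discord_bot/__main__.py` (assumes `config/.env` is populated, see `config/.env.example`):

from qa_engine import Config, QAEngine

config = Config()
qa_engine = QAEngine(
    llm_model_id=config.question_answering_model_id,
    embedding_model_id=config.embedding_model_id,
    index_repo_id=config.index_repo_id,
    prompt_template=config.prompt_template,
    use_docs_for_context=config.use_docs_for_context,
    num_relevant_docs=config.num_relevant_docs,
    add_sources_to_response=config.add_sources_to_response,
    use_messages_for_context=config.use_messages_in_context,
    debug=config.debug,
)
response = qa_engine.get_response(question='How to use the tokenizer?')
print(response.get_answer(include_sources=True))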
qa_engine/response.py
ADDED
@@ -0,0 +1,33 @@
from typing import List


class Response:
    def __init__(self):
        self.answer = ''
        self.sources = []

    def set_answer(self, answer: str) -> None:
        self.answer = answer

    def set_sources(self, sources: List) -> None:
        self.sources = list(set(map(str, sources)))

    def get_sources(self) -> List[str]:
        return self.sources

    def get_sources_as_text(self) -> str:
        if not self.sources:
            return ''
        sources_text = '\n\nSources:'
        for i, source in enumerate(self.sources):
            sources_text += f'\n [{i+1}] {source}'
        return sources_text

    def get_answer(self, include_sources: bool = False) -> str:
        answer = self.answer
        if include_sources:
            answer += self.get_sources_as_text()
        return answer

    def __str__(self):
        return self.get_answer(include_sources=True)
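Usage note: `set_sources` deduplicates on assignment (the `test_set_sources` unit test below relies on this), so repeated documents collapse to a single citation. Illustrative (the source strings are placeholders):

from qa_engine.response import Response

r = Response()
r.set_answer('Use AutoTokenizer.from_pretrained(...).')
r.set_sources(['docs/tokenizer', 'docs/tokenizer', 'docs/models'])
print(len(r.get_sources()))  # 2 - the duplicate source is removed
print(r.get_answer(include_sources=True))  # answer followed by a numbered 'Sources:' list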
questions.txt
ADDED
@@ -0,0 +1,9 @@
How to create audio dataset with Hugging Face?
I want to check if 2 sentences are similar semantically. How can I do it?
What are the benefits of Gradio?
How to deploy a text-to-image model?
Does Hugging Face offer any distributed training assistance? followup: Can you give me an example setup of it?
I want to detect cars on video recording. How should I do it and what models do you recommend?
Is there any tool for evaluating models in Hugging Face? followup: Can you give me an example setup of it?
What are some advantages of the Hugging Face Hub?
How would I use a model in 8 bit in transformers?
requirements.txt
ADDED
@@ -0,0 +1,28 @@
torch
torchvision
transformers
accelerate
einops
huggingface_hub
gradio
beautifulsoup4==4.12.0
discord.py==2.2.2
evaluate==0.4.0
fastapi==0.98.0
langchain==0.0.154
nltk==3.8.1
nbconvert==7.6.0
nbformat==5.9.0
numpy==1.24.2
markdown==3.4.4
pandas==1.5.3
python-dotenv==1.0.0
Requests==2.29.0
scikit_learn==1.2.2
sentence-transformers==2.2.2
InstructorEmbedding==1.0.0
faiss_cpu==1.7.3
tqdm==4.64.1
uvicorn==0.22.0
wandb==0.15.0
pytest==7.3.1
run_docker.sh
ADDED
@@ -0,0 +1,2 @@
#!/bin/bash
docker-compose down && docker-compose up --build
run_tests.sh
ADDED
@@ -0,0 +1,2 @@
#!/bin/bash
pytest -o "testpaths=tests/" --noconftest -vv
tests/__init__.py
ADDED
File without changes
tests/discord_bot/__init__.py
ADDED
File without changes
tests/discord_bot/client/__init__.py
ADDED
File without changes
tests/discord_bot/client/lorem_ipsum.txt
ADDED
@@ -0,0 +1,14 @@
Lorem ipsum dolor sit amet, ```consectetur adipiscing elit. Sed vitae arcu``` eros. Sed gravida tellus quis ante luctus, sed scelerisque mi auctor.
Pellentesque consectetur fringilla turpis, in viverra odio posuere ut. In venenatis dui lectus, eget dignissim lectus lacinia et.
Nulla commodo, nunc et vulputate vestibulum, mauris lorem semper nisl, a lacinia nulla urna ac lectus. Fusce pulvinar pulvinar augue vitae pulvinar. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Mauris ut elit nec erat tristique sagittis id in massa.
Nunc ac libero in mauris consectetur bibendum non id nulla. Suspendisse feugiat ipsum metus, et auctor massa lobortis ac.
Sed tincidunt dapibus sapien, ut tincidunt dui venenatis vitae. Praesent fermentum mi quis elit posuere, in consequat dui ultrices. Nullam tempus sapien non gravida faucibus.
Etiam placerat magna risus, in dictum justo commodo eu. Fusce commodo viverra augue, non ullamcorper nunc efficitur et.
Nam convallis dui vel nunc malesuada, at pharetra dui pulvinar. Cras commodo nibh et neque placerat eleifend. Mauris eget ante pharetra, mattis turpis ac, laoreet enim.
Sed non sem sit amet dui venenatis posuere. Proin eget eros at nunc fringilla varius eu nec justo.
Duis tristique consequat metus vitae euismod. Maecenas ullamcorper ullamcorper diam, nec tristique est congue nec. Suspendisse eu elit id massa fermentum pharetra in et mauris.
In finibus metus nec odio aliquam, eu finibus ligula venenatis. Donec ullamcorper turpis eget sapien blandit volutpat. Sed tincidunt malesuada urna vitae cursus.
Phasellus in est in purus lobortis dictum. Vivamus tristique turpis at lectus dapibus, ac volutpat arcu volutpat. Integer ac gravida felis.
Donec laoreet neque non enim fermentum, vitae tincidunt est varius. Phasellus vehicula nisl nec metus dictum malesuada. Vestibulum sit amet erat et tellus euismod congue.
Nulla facilisi. Mauris venenatis arcu vel elementum consectetur. Cras volutpat scelerisque sollicitudin.
Fusce eu justo et magna gravida finibus id ut augue. Curabitur malesuada, justo ut semper dapibus, nulla urna maximus tellus, in vestibulum mauris justo id lectus. Vestibulum sit amet mauris feugiat, rhoncus ligula sed, finibus mi.
tests/discord_bot/client/test_utils.py
ADDED
@@ -0,0 +1,69 @@
import pytest

from discord_bot.client.utils import (
    find_max_split_index,
    find_max_split_index_from_sequence,
    split_text_into_chunks,
)


@pytest.fixture(scope='module')
def test_chunk() -> str:
    return 't. , \n .'


@pytest.fixture(scope='module')
def test_text() -> str:
    with open('tests/discord_bot/client/lorem_ipsum.txt', 'r') as f:
        text = f.read()
    assert text is not None, 'test text is empty'
    return text


def test_find_max_splitting_index(test_chunk: str):
    index = find_max_split_index(test_chunk, char='\n')
    assert index == 6, 'index should be 6'
    index = find_max_split_index(test_chunk, char='. ')
    assert index == 3, 'index should be 3'
    index = find_max_split_index(test_chunk, char='.')
    assert index == 8, 'index should be 8'


def test_find_max_split_index_from_sequence(test_chunk: str):
    index = find_max_split_index_from_sequence(
        test_chunk,
        split_characters=['\n']
    )
    assert index == 6, 'index should be 6'
    index = find_max_split_index_from_sequence(
        test_chunk,
        split_characters=['.', ', ', '\n']
    )
    assert index == 8, 'index should be 8'


def test_split_text_into_chunks_with_split_characters(test_text: str):
    max_chunk_size = 250
    chunks = split_text_into_chunks(
        test_text,
        split_characters=['. ', ', ', '\n'],
        min_size=20,
        max_size=max_chunk_size
    )
    for chunk in chunks:
        assert len(chunk) > 0, 'Chunk length is zero'
        assert len(chunk) <= max_chunk_size, 'Chunk length exceeds maximum limit'


def test_split_text_into_chunks_without_split_characters():
    test_text = 'a' * 1000
    max_chunk_size = 250
    chunks = split_text_into_chunks(
        test_text,
        split_characters=[],
        min_size=20,
        max_size=max_chunk_size
    )
    for chunk in chunks:
        assert len(chunk) == max_chunk_size, \
            'Chunk length should be exactly max_chunk_size'
tests/index/test_index.py
ADDED
@@ -0,0 +1,48 @@
import pytest
from huggingface_hub import snapshot_download
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.vectorstores import FAISS


snapshot_download(
    repo_id='KonradSzafer/index-large-notebooks',
    allow_patterns=['*.faiss', '*.pkl'],
    repo_type='dataset',
    local_dir='indexes/'
)

@pytest.fixture(scope='module')
def embedding_model() -> HuggingFaceInstructEmbeddings:
    model_name = 'hkunlp/instructor-large'
    embed_instruction = 'Represent the Hugging Face library documentation'
    query_instruction = 'Query the most relevant piece of information from the Hugging Face documentation'
    return HuggingFaceInstructEmbeddings(
        model_name=model_name,
        embed_instruction=embed_instruction,
        query_instruction=query_instruction,
    )

@pytest.fixture(scope='module')
def index_path() -> str:
    return 'indexes/'

@pytest.fixture(scope='module')
def index(embedding_model: HuggingFaceInstructEmbeddings, index_path: str):
    return FAISS.load_local(index_path, embedding_model)

@pytest.fixture(scope='module')
def query() -> str:
    return 'How to use the tokenizer?'

def test_load_index(embedding_model: HuggingFaceInstructEmbeddings, index_path: str):
    index = FAISS.load_local(index_path, embedding_model)
    assert index is not None, 'Failed to load index'

def test_index_page_content(index, query: str):
    query_docs = index.similarity_search(query=query, k=3)
    assert isinstance(query_docs[0].page_content, str)

def test_index_metadata(index, query):
    query_docs = index.similarity_search(query=query, k=3)
    assert isinstance(query_docs[0].metadata['source'], str)
tests/qa_engine/__init__.py
ADDED
File without changes
tests/qa_engine/test_response.py
ADDED
@@ -0,0 +1,30 @@
import pytest

from qa_engine.response import Response


def test_set_answer():
    r = Response()
    r.set_answer('Hello, World!')
    assert r.get_answer() == 'Hello, World!'


def test_set_sources():
    r = Response()
    r.set_sources(['source1', 'source1', 'source2'])
    assert len(r.get_sources()) == 2


def test_get_sources_as_text():
    r = Response()
    r.set_sources(['source1', 'source2'])
    assert isinstance(r.get_sources_as_text(), str)


def test_get_response_include_sources():
    r = Response()
    r.set_answer('Hello, World!')
    r.set_sources(['source1', 'source2'])
    assert len(r.get_answer(include_sources=True)) > len('Hello, World!')