rajdeep1337 committed
Commit 8944d2f • 1 Parent(s): d5953dd

Initial commit

Files changed (7)
  1. .gitignore +160 -0
  2. LICENSE +7 -0
  3. app.py +127 -0
  4. knowledgebase.py +43 -0
  5. llms/__init__.py +0 -0
  6. llms/tiny_llama.py +41 -0
  7. requirements.txt +108 -0
.gitignore ADDED
@@ -0,0 +1,160 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,7 @@
+ Copyright <YEAR> <COPYRIGHT HOLDER>
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
app.py ADDED
@@ -0,0 +1,127 @@
+ import gradio as gr
+
+ from llms.tiny_llama import TinyLlama
+ from knowledgebase import KnowledgeBase
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.memory import ConversationBufferMemory
+
+
+ LLM = None
+ KNOWLEDGEBASE = None
+ SYSTEM_PROMPT = (
+     "You are an assistant. Answer the user's query based on the given context."
+ )
+
+ # MEMORY = ConversationBufferMemory(
+ #     memory_key="chat_history", output_key="answer", return_messages=True
+ # )
+ QA_CHAIN = None
+
+
+ # def init_qa_chain():
+ #     QA_CHAIN = ConversationalRetrievalChain.from_llm(
+ #         LLM,
+ #         retriever=KNOWLEDGEBASE,
+ #         chain_type="stuff",
+ #         memory=MEMORY,
+ #         return_source_documents=True,
+ #         verbose=True,
+ #     )
+
+
+ def init_llm(llm="TinyLlama"):
+     # Only TinyLlama is wired up so far; the `llm` argument is a placeholder
+     # for the other dropdown choices.
+     global LLM
+     LLM = TinyLlama()
+
+
+ def chat(message, history):
+     # Retrieve the most relevant chunk for the query and send it to the LLM
+     # together with the system prompt.
+     context = KNOWLEDGEBASE.invoke(message)[0].page_content
+     system = {"role": "system", "content": SYSTEM_PROMPT}
+     user = {
+         "role": "user",
+         "content": f"prompt: ```{message}```\ncontext: ```{context}```",
+     }
+     response = LLM([system, user]).split("<|assistant|>")[-1]
+     return response
+
+
+ def init_rag(system_prompt, url_input, file_input):
+     global SYSTEM_PROMPT
+     if SYSTEM_PROMPT != system_prompt:
+         SYSTEM_PROMPT = system_prompt
+         gr.Info("Saved new system prompt")
+     if url_input and file_input:
+         raise gr.Error("Provide either a URL or a file, not both")
+     # file_input is a list when file_count="multiple"; take the first file.
+     path = url_input if url_input else (file_input[0] if file_input else None)
+     load_knowledgebase(path)
+
+
+ def load_knowledgebase(path):
+     global KNOWLEDGEBASE
+     if not KNOWLEDGEBASE:
+         KNOWLEDGEBASE = KnowledgeBase()
+     print("Loading knowledgebase:", path)
+     if not path:
+         return
+     if "https://" in path:
+         KNOWLEDGEBASE.load_url(path)
+         gr.Info(message="Successfully loaded URL")
+     elif path.split(".")[-1] == "pdf":
+         KNOWLEDGEBASE.load_pdf(path)
+         gr.Info(message="Successfully loaded pdf")
+     else:
+         KNOWLEDGEBASE.load_txt(path)
+         gr.Info(message="Successfully loaded txt")
+
+
+ def show_file(file):
+     print(file)
+     return file
+
+
+ with gr.Blocks(title="d-RAG") as iface:
+     gr.Markdown(
+         """# d-RAG &nbsp;[![Watch on GitHub](https://img.shields.io/github/watchers/rumbleFTW/d-RAG.svg?style=social)](https://github.com/rumbleFTW/d-RAG/watchers) &nbsp; [![Star on GitHub](https://img.shields.io/github/stars/rumbleFTW/d-RAG.svg?style=social)](https://github.com/rumbleFTW/d-RAG/stargazers)
+         """
+     )
+     with gr.Row(equal_height=True):
+         with gr.Column():
+             with gr.Row():
+                 model = gr.Dropdown(
+                     label="Model",
+                     choices=[
+                         "TinyLlama-1.1B-Chat-v1.0",
+                         "Mixtral-8x7B-Instruct-v0.1",
+                         "Mistral-7B-Instruct-v0.2",
+                     ],
+                     value="TinyLlama-1.1B-Chat-v1.0",
+                     scale=1,
+                     interactive=True,
+                 )
+                 system_prompt = gr.Textbox(
+                     label="System prompt",
+                     value="You are an assistant. Answer the user's query based on the given context.",
+                     scale=2,
+                 )
+             with gr.Accordion(label="Knowledge base", open=True):
+                 url_input = gr.Textbox(placeholder="URL", value=None)
+                 gr.Markdown("OR")
+                 file_input = gr.File(
+                     file_count="multiple",
+                     file_types=[".txt", ".pdf"],
+                     show_label=True,
+                     visible=True,
+                 )
+             submit = gr.Button("Submit")
+             submit.click(
+                 fn=init_rag,
+                 inputs=[system_prompt, url_input, file_input],
+             )
+         with gr.Column():
+             demo = gr.ChatInterface(fn=chat, examples=["Namaste!", "Hello!", "Hola!"])
+
+ if __name__ == "__main__":
+     init_llm()
+     iface.launch(debug=True)
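For reference, a minimal sketch of how the committed pieces fit together outside the Gradio UI, mirroring what `chat()` does per message. This is not part of the commit: the URL and query are illustrative, and it assumes the repo root is on `PYTHONPATH`.

```python
# Hypothetical end-to-end run; URL and query are examples only.
from llms.tiny_llama import TinyLlama
from knowledgebase import KnowledgeBase

llm = TinyLlama()
kb = KnowledgeBase()
kb.load_url("https://en.wikipedia.org/wiki/Retrieval-augmented_generation")

query = "What is retrieval-augmented generation?"
context = kb.invoke(query)[0].page_content  # top retrieved chunk
messages = [
    {"role": "system", "content": "You are an assistant. Answer the user's query based on the given context."},
    {"role": "user", "content": f"prompt: {query}\ncontext: {context}"},  # app.py wraps these in triple backticks
]
print(llm(messages).split("<|assistant|>")[-1])  # keep only the assistant's turn
```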
knowledgebase.py ADDED
@@ -0,0 +1,43 @@
+ from langchain_community.vectorstores.faiss import FAISS
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import text, PyPDFLoader, WebBaseLoader
+ from langchain_community.document_transformers import Html2TextTransformer
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+
+
+ class KnowledgeBase:
+     def __init__(self) -> None:
+         self.embeddings = HuggingFaceEmbeddings(
+             model_name="sentence-transformers/all-mpnet-base-v2"
+         )
+         self.text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=100, chunk_overlap=50
+         )
+         self.retriever = None
+
+     def load_txt(self, path):
+         loader = text.TextLoader(file_path=path)
+         documents = loader.load()
+         chunked_docs = self.text_splitter.split_documents(documents=documents)
+         print(len(chunked_docs))
+         db = FAISS.from_documents(embedding=self.embeddings, documents=chunked_docs)
+         self.retriever = db.as_retriever()
+
+     def load_pdf(self, path):
+         loader = PyPDFLoader(file_path=path)
+         documents = loader.load()
+         chunked_docs = self.text_splitter.split_documents(documents=documents)
+         db = FAISS.from_documents(embedding=self.embeddings, documents=chunked_docs)
+         self.retriever = db.as_retriever()
+
+     def load_url(self, path):
+         loader = WebBaseLoader(web_path=path)
+         docs = loader.load()
+         # Strip HTML down to plain text before chunking.
+         html2text = Html2TextTransformer()
+         docs_transformed = html2text.transform_documents(docs)
+         chunked_docs = self.text_splitter.split_documents(docs_transformed)
+         db = FAISS.from_documents(embedding=self.embeddings, documents=chunked_docs)
+         self.retriever = db.as_retriever()
+
+     def invoke(self, query):
+         return self.retriever.invoke(query)
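A minimal usage sketch (not part of the commit; `notes.txt` is an assumed local file): each `load_*` method chunks the source, embeds the chunks, builds a FAISS index, and exposes it as a retriever.

```python
# Hypothetical usage of KnowledgeBase over an assumed local file.
from knowledgebase import KnowledgeBase

kb = KnowledgeBase()
kb.load_txt("notes.txt")             # chunk, embed, and index the file
docs = kb.invoke("What is d-RAG?")   # similarity search over the index
print(docs[0].page_content)          # most relevant chunk
```

Note that calling a second `load_*` replaces the retriever rather than merging indexes, so only the most recently loaded source is searched.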
llms/__init__.py ADDED
(empty file)
llms/tiny_llama.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+ class TinyLlama:
+     def __init__(self) -> None:
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+         )
+         # Load the model quantized to 4 bits so it fits on small GPUs.
+         self.model = AutoModelForCausalLM.from_pretrained(
+             "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+             load_in_4bit=True,
+             device_map="auto",
+             bnb_4bit_compute_dtype=torch.float16,
+         )
+
+         print(f"LLM loaded to {self.model.device}")
+
+         self._messages = []
+
+     def __call__(self, messages, *args, **kwds):
+         # Render the chat messages with the model's chat template, then tokenize.
+         tokenized_chat = self.tokenizer.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+         inputs = self.tokenizer(tokenized_chat, return_tensors="pt").to(
+             self.model.device
+         )
+
+         outputs = self.model.generate(
+             **inputs,
+             use_cache=True,
+             max_length=1000,
+             min_length=10,
+             temperature=0.7,
+             num_return_sequences=1,
+             do_sample=True,
+         )
+
+         generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         return generated_text
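As a usage sketch (hypothetical, not in the commit): the wrapper takes an OpenAI-style message list and returns the full decoded transcript, prompt included, which is why callers in app.py split on the `<|assistant|>` tag to isolate the reply. Assumes a CUDA GPU plus bitsandbytes for the 4-bit load.

```python
# Hypothetical call; the message contents are examples.
from llms.tiny_llama import TinyLlama

llm = TinyLlama()
messages = [
    {"role": "system", "content": "You are a friendly chatbot."},
    {"role": "user", "content": "Say hello in one sentence."},
]
text = llm(messages)                        # full transcript, prompt included
print(text.split("<|assistant|>")[-1])      # keep only the assistant's turn
```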
requirements.txt ADDED
@@ -0,0 +1,108 @@
+ accelerate==0.26.1
+ aiofiles==23.2.1
+ aiohttp==3.9.3
+ aiosignal==1.3.1
+ altair==5.2.0
+ annotated-types==0.6.0
+ anyio==4.2.0
+ attrs==23.2.0
+ bitsandbytes==0.42.0
+ certifi==2023.11.17
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ contourpy==1.2.0
+ cycler==0.12.1
+ dataclasses-json==0.6.3
+ fastapi==0.109.0
+ ffmpy==0.3.1
+ filelock==3.13.1
+ fonttools==4.47.2
+ frozenlist==1.4.1
+ fsspec==2023.12.2
+ gradio==4.16.0
+ gradio_client==0.8.1
+ greenlet==3.0.3
+ h11==0.14.0
+ httpcore==1.0.2
+ httpx==0.26.0
+ huggingface-hub==0.20.3
+ idna==3.6
+ importlib-resources==6.1.1
+ Jinja2==3.1.3
+ jsonpatch==1.33
+ jsonpointer==2.4
+ jsonschema==4.21.1
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.5
+ langchain==0.1.4
+ langchain-community==0.0.16
+ langchain-core==0.1.17
+ langsmith==0.0.84
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.4
+ marshmallow==3.20.2
+ matplotlib==3.8.2
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.0.4
+ mypy-extensions==1.0.0
+ networkx==3.2.1
+ numpy==1.26.3
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==8.9.2.26
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-nccl-cu12==2.18.1
+ nvidia-nvjitlink-cu12==12.3.101
+ nvidia-nvtx-cu12==12.1.105
+ orjson==3.9.12
+ packaging==23.2
+ pandas==2.2.0
+ pillow==10.2.0
+ psutil==5.9.8
+ pydantic==2.6.0
+ pydantic_core==2.16.1
+ pydub==0.25.1
+ Pygments==2.17.2
+ pyparsing==3.1.1
+ python-dateutil==2.8.2
+ python-multipart==0.0.6
+ pytz==2023.4
+ PyYAML==6.0.1
+ referencing==0.33.0
+ regex==2023.12.25
+ requests==2.31.0
+ rich==13.7.0
+ rpds-py==0.17.1
+ ruff==0.1.15
+ safetensors==0.4.2
+ scipy==1.12.0
+ semantic-version==2.10.0
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.0
+ SQLAlchemy==2.0.25
+ starlette==0.35.1
+ sympy==1.12
+ tenacity==8.2.3
+ tokenizers==0.15.1
+ tomlkit==0.12.0
+ toolz==0.12.1
+ torch==2.1.2
+ tqdm==4.66.1
+ transformers==4.37.2
+ triton==2.1.0
+ typer==0.9.0
+ typing-inspect==0.9.0
+ typing_extensions==4.9.0
+ tzdata==2023.4
+ urllib3==2.1.0
+ uvicorn==0.27.0.post1
+ websockets==11.0.3
+ yarl==1.9.4
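The pinned environment can be reproduced with `pip install -r requirements.txt` (standard pip usage; the CUDA wheels among the pins assume a Linux machine with an NVIDIA GPU).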