IMvision12 committed
Commit
ba60257
Parent: 7e967db
Files changed (5)
  1. .gitignore +160 -0
  2. app.py +134 -0
  3. data.py +61 -0
  4. model.py +73 -0
  5. requirements.txt +10 -0
.gitignore ADDED
@@ -0,0 +1,160 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
app.py ADDED
@@ -0,0 +1,134 @@
+ import sys
+
+ # Chroma needs a newer sqlite3 than many hosts ship; swap in pysqlite3 before
+ # data.py (and, through it, chromadb) ever imports sqlite3.
+ __import__("pysqlite3")
+ sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
+
+ import os
+
+ import streamlit as st
+ from langchain.chains import RetrievalQA
+
+ from data import create_retriever
+ from model import initialize_llmchain
+
+ st.set_page_config(page_title="🤗Chat💬")
+
+ embed_model_dict = {
+     "MiniLM-L6": "nreimers/MiniLM-L6-H384-uncased",
+     "Mpnet-Base": "sentence-transformers/all-mpnet-base-v2",
+ }
+
+ llm_model_dict = {
+     "Llama-2 7B (Free)": "daryl149/llama-2-7b-chat-hf",
+     "Gemma 7B": "google/gemma-7b",
+     "Gemma 2B": "google/gemma-2b",
+     "Gemma 7B-it": "google/gemma-7b-it",
+     "Gemma 2B-it": "google/gemma-2b-it",
+     "Llama-2 7B Chat HF": "meta-llama/Llama-2-7b-chat-hf",
+     "Llama-2 70B Chat HF": "meta-llama/Llama-2-70b-chat-hf",
+     "Llama-2 13B Chat HF": "meta-llama/Llama-2-13b-chat-hf",
+     "Llama-2 70B": "meta-llama/Llama-2-70b",
+     "Llama-2 13B": "meta-llama/Llama-2-13b",
+     "Llama-2 7B": "meta-llama/Llama-2-7b",
+ }
+
+
+ def save_uploadedfile(uploadedfile):
+     if not os.path.exists("./tempfolder"):
+         os.makedirs("./tempfolder")
+     full_path = os.path.join("tempfolder", uploadedfile.name)
+     with open(full_path, "wb") as f:
+         f.write(uploadedfile.getbuffer())
+     return st.success("Saved File")
+
+
+ with st.sidebar:
+     st.markdown(
+         """
+         <style>
+         section[data-testid="stSidebar"] .css-ng1t4o {width: 100rem;}
+         </style>
+         """,
+         unsafe_allow_html=True,
+     )
+     st.header("Choose and Configure your Embedding Model", divider="rainbow")
+     uploaded_files = st.file_uploader(
+         "Choose a file", type=["pdf"], accept_multiple_files=True
+     )
+     embed_model = embed_model_dict[
+         st.selectbox("Select Embedding Model", ("MiniLM-L6", "Mpnet-Base"))
+     ]
+     for file in uploaded_files:
+         save_uploadedfile(file)
+
+     chunksize = st.slider("Chunk Size", 256, 1024, 400, 10)
+     chunkoverlap = st.slider("Chunk Overlap", 100, 500, 300, 10)
+
+     st.header("Choose and Configure your LLM Model", divider="rainbow")
+     llm_model = llm_model_dict[
+         st.selectbox("Select LLM Model", list(llm_model_dict.keys()))
+     ]
+     access_token = st.text_input("Enter HuggingFace Access Token")
+     temperature = st.slider("Temperature", 0.0, 1.0, 0.7, 0.05)
+     max_tokens = st.slider("Max Tokens", 256, 1024, 400, 10)
+     top_k = st.slider("top_k", 1, 100, 40, 1)
+     quantization_option = st.radio("Quantization Option", ("8Bit Quant", "4Bit Quant"))
+     load_in_4bit = quantization_option == "4Bit Quant"
+     load_in_8bit = not load_in_4bit
+
+     if st.button("Submit"):
+         with st.spinner("Processing PDFs..."):
+             retriever = create_retriever(
+                 pdf_directory="./tempfolder",
+                 chunk_size=chunksize,
+                 chunk_overlap=chunkoverlap,
+                 embedding_model_name=embed_model,
+             )
+         with st.spinner("Loading LLM model..."):
+             llm = initialize_llmchain(
+                 llm_model=llm_model,
+                 temperature=temperature,
+                 max_tokens=max_tokens,
+                 top_k=top_k,
+                 load_in_4bit=load_in_4bit,
+                 load_in_8bit=load_in_8bit,
+                 access_token=access_token,
+             )
+         # NOTE: retriever and llm are not stored in st.session_state, so they
+         # do not survive Streamlit reruns.
+
+ st.title("💬 Chat With PDFs")
+
+ st.markdown("- Choose 🚀 and configure your Embedding Model.")
+ st.markdown("- Choose 🚀 and configure your LLM Model.")
+ st.markdown("- Enter your HuggingFace Token ❗️ (only Llama-2 7B (Free) works without an HF Token).")
+ st.markdown(
+     """
+     <p align="center">It will take some time <b>⏳</b> to download and load the models.</p>
+     <p align="center">Once the download is complete, you can start chatting!</p>
+     """,
+     unsafe_allow_html=True,
+ )
+
+ st.markdown(
+     """
+     <style>
+     [data-testid="stMarkdownContainer"] ul {
+         padding-left: 40px;
+     }
+     </style>
+     """,
+     unsafe_allow_html=True,
+ )
+
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ if prompt := st.chat_input("What is up?", key="user_input"):
+     st.session_state.messages.append({"role": "user", "content": prompt})
+
+     with st.chat_message("user"):
+         st.markdown(prompt)
+
+     with st.chat_message("assistant"):
+         response = "Hi"  # placeholder reply; the RetrievalQA chain is not wired up yet
+         st.session_state.messages.append({"role": "assistant", "content": response})
+         st.markdown(response)
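
app.py imports RetrievalQA but leaves the assistant reply hard-coded to "Hi". A minimal sketch of the intended wiring, assuming the standard RetrievalQA.from_chain_type API and a hypothetical st.session_state key that this commit does not yet create:

# Hypothetical wiring sketch, not part of this commit.
# In the Submit handler, after retriever and llm are built:
st.session_state.qa_chain = RetrievalQA.from_chain_type(
    llm=llm,              # HuggingFacePipeline from initialize_llmchain(...)
    chain_type="stuff",   # stuff retrieved chunks directly into the prompt
    retriever=retriever,  # Chroma retriever from create_retriever(...)
)

# In the chat handler, replacing the fixed response:
response = st.session_state.qa_chain.run(prompt)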
data.py ADDED
@@ -0,0 +1,61 @@
+ import warnings
+ from typing import Dict, Optional
+
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import PyPDFDirectoryLoader
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import Chroma
+
+ warnings.simplefilter("ignore")
+
+
+ def create_retriever(
+     pdf_directory: str,
+     chunk_size: int = 1000,
+     chunk_overlap: int = 100,
+     embedding_model_name: str = "sentence-transformers/all-mpnet-base-v2",
+     model_kwargs: Optional[Dict[str, str]] = None,
+ ):
+     """
+     Creates and returns a retriever object based on the provided PDF directory and configuration.
+
+     Args:
+     - pdf_directory (str): Path to the directory containing PDF files.
+     - chunk_size (int): Size of each chunk for splitting documents.
+     - chunk_overlap (int): Overlap size between adjacent chunks.
+     - embedding_model_name (str): Name of the HuggingFace embedding model to use.
+     - model_kwargs (dict, optional): Keyword arguments for the embedding model;
+       defaults to {"device": "cpu"}.
+
+     Returns:
+     - retriever (Retriever): Retriever object for retrieving documents.
+
+     Raises:
+     - ValueError: If input values are invalid.
+     """
+     if model_kwargs is None:
+         # Avoid a mutable default argument; embed on CPU unless told otherwise.
+         model_kwargs = {"device": "cpu"}
+     if chunk_size <= 0:
+         raise ValueError("Chunk size must be a positive integer.")
+     if chunk_overlap < 0 or chunk_overlap >= chunk_size:
+         raise ValueError(
+             "Chunk overlap must be a non-negative integer less than the chunk size."
+         )
+
+     # Load documents
+     loader = PyPDFDirectoryLoader(pdf_directory)
+     documents = loader.load()
+
+     # Split documents into small chunks
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_size, chunk_overlap=chunk_overlap
+     )
+     all_splits = text_splitter.split_documents(documents)
+
+     # Specify embedding model
+     embeddings = HuggingFaceEmbeddings(
+         model_name=embedding_model_name, model_kwargs=model_kwargs
+     )
+
+     # Embed document chunks into a persistent Chroma store
+     vectordb = Chroma.from_documents(
+         documents=all_splits, embedding=embeddings, persist_directory="chroma_db"
+     )
+
+     # Create and return retriever
+     retriever = vectordb.as_retriever()
+     return retriever
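
A quick usage sketch of create_retriever (hypothetical values, not part of the commit), showing how app.py's sidebar settings map onto the function and how retrieved chunks come back:

# Index every PDF in ./tempfolder, then fetch the chunks nearest a query.
retriever = create_retriever(
    pdf_directory="./tempfolder",
    chunk_size=400,     # matches the app's default slider value
    chunk_overlap=100,
)
docs = retriever.get_relevant_documents("What is this document about?")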
model.py ADDED
@@ -0,0 +1,73 @@
+ import torch
+ from langchain_community.llms import HuggingFacePipeline
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+     pipeline,
+ )
+
+
+ def initialize_llmchain(
+     llm_model: str,
+     temperature: float,
+     max_tokens: int,
+     top_k: int,
+     access_token: str = None,
+     torch_dtype: str = "auto",
+     load_in_8bit: bool = False,
+     load_in_4bit: bool = False,
+ ) -> HuggingFacePipeline:
+     """
+     Initializes a language model pipeline based on the provided parameters.
+
+     Args:
+     - llm_model (str): The name of the language model to initialize.
+     - temperature (float): The temperature parameter for text generation.
+     - max_tokens (int): The maximum number of new tokens to generate.
+     - top_k (int): The top-k parameter for token selection during generation.
+     - access_token (str, optional): HuggingFace token for gated models such as Llama-2.
+     - torch_dtype (str): The torch dtype to use for model inference (default is "auto").
+     - load_in_8bit (bool): Whether to load the model in 8-bit format (default is False).
+     - load_in_4bit (bool): Whether to load the model in 4-bit format (default is False).
+
+     Returns:
+     - HuggingFacePipeline: Initialized language model pipeline.
+     """
+     if load_in_8bit:
+         bnb_config = BitsAndBytesConfig(load_in_8bit=True)
+     elif load_in_4bit:
+         bnb_config = BitsAndBytesConfig(
+             load_in_8bit=False,
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_compute_dtype=torch.bfloat16,
+         )
+     else:
+         bnb_config = None
+
+     # Initialize model and tokenizer; the token belongs here (not on the
+     # pipeline call) so that gated checkpoints can actually be downloaded.
+     model = AutoModelForCausalLM.from_pretrained(
+         llm_model,
+         low_cpu_mem_usage=True,
+         quantization_config=bnb_config,
+         torch_dtype=torch_dtype,
+         token=access_token,
+     )
+     tokenizer = AutoTokenizer.from_pretrained(llm_model, token=access_token)
+
+     # Initialize pipeline; generation settings are passed directly so they are
+     # forwarded to generate() (model_kwargs would only go to from_pretrained).
+     pipe = pipeline(
+         task="text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         temperature=temperature,
+         max_new_tokens=max_tokens,
+         top_k=top_k,
+         do_sample=True,
+         pad_token_id=tokenizer.eos_token_id,
+     )
+
+     llm = HuggingFacePipeline(pipeline=pipe)
+     return llm
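
A quick usage sketch of initialize_llmchain (hypothetical values, not part of the commit); gated meta-llama checkpoints would additionally need access_token:

# Load the ungated Llama-2 chat mirror in 4-bit and generate once.
llm = initialize_llmchain(
    llm_model="daryl149/llama-2-7b-chat-hf",
    temperature=0.7,
    max_tokens=512,
    top_k=40,
    load_in_4bit=True,
)
print(llm.invoke("Summarize retrieval-augmented generation in one sentence."))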
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ pypdf
+ langchain
+ sentence-transformers
+ peft
+ chromadb
+ accelerate==0.28.0
+ bitsandbytes==0.43.0
+ streamlit
+ streamlit-chat
+ pysqlite3-binary