Ferdi committed
Commit 3783dce
1 Parent(s): 7a036f4
Files changed (7)
  1. Dockerfile +19 -0
  2. requirements.txt +8 -0
  3. src/app.py +61 -0
  4. src/conversation.py +50 -0
  5. src/setup.py +16 -0
  6. src/utils.py +70 -0
  7. src/vector_index.py +72 -0
Dockerfile ADDED
@@ -0,0 +1,19 @@
+ # Use an official Python runtime as a parent image
+ FROM python:3.9-slim
+
+ # Set the working directory in the container
+ WORKDIR /usr/src/app
+
+
+ # Install any needed packages specified in requirements.txt
+ COPY requirements.txt ./
+ RUN pip install -r requirements.txt
+
+ # Copy the rest of your application's code
+ COPY ./src .
+
+ # Make port 7860 available to the world outside this container
+ EXPOSE 7860
+
+ # Run app.py when the container launches
+ CMD ["python", "./app.py"]
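Note that the image never bakes in an OpenAI key: conversation.py reads OPENAI_API_KEY from the environment, so it has to be supplied at run time (e.g. docker build -t rag-app . then docker run -p 7860:7860 -e OPENAI_API_KEY=... rag-app; the image tag is illustrative). A quick sketch, assuming the container is running locally, to check that the published Gradio port answers:

import urllib.request

# Assumes the container is up with port 7860 published on localhost.
resp = urllib.request.urlopen("http://localhost:7860")
print(resp.status)  # 200 if the Gradio app is serving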
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ docarray==0.39.1
+ faiss-cpu==1.7.4
+ gradio==4.8.0
+ langchain==0.0.348
+ openai==1.3.8
+ pypdf==3.17.2
+ tiktoken==0.5.2
+ transformers==4.36.0
src/app.py ADDED
@@ -0,0 +1,61 @@
+ import gradio as gr
+ from utils import *
+
+ prompt_keys = load_prompts_list_from_json('prompts.json')
+
+ with gr.Blocks(gr.themes.Soft(primary_hue=gr.themes.colors.slate, secondary_hue=gr.themes.colors.purple)) as demo:
+     with gr.Row():
+
+         with gr.Column(scale=1, variant='panel'):
+             # gr.HTML(f"<img src='file/logo.png' width='100' height='100'>")
+             files = gr.File(type="filepath", file_count="multiple")
+             with gr.Row(equal_height=True):
+                 vector_index_btn = gr.Button('Create vector store', variant='primary', scale=1)
+                 vector_index_msg_out = gr.Textbox(show_label=False, lines=1, scale=1, placeholder="Creating vector store ...")
+
+             prompt_dropdown = gr.Dropdown(label="Select a prompt", choices=prompt_keys, value=prompt_keys[0])
+
+             with gr.Accordion(label="Text generation tuning parameters"):
+                 temperature = gr.Slider(label="temperature", minimum=0.1, maximum=1, value=0.1, step=0.05)
+                 max_new_tokens = gr.Slider(label="max_new_tokens", minimum=1, maximum=4096, value=1024, step=1)
+                 k_context = gr.Slider(label="k_context", minimum=1, maximum=15, value=5, step=1)
+
+             vector_index_btn.click(upload_and_create_vector_store, inputs=[files], outputs=vector_index_msg_out)
+
+         with gr.Column(scale=1, variant='panel'):
+             with gr.Row(equal_height=True):
+
+                 with gr.Column(scale=1):
+                     llm = gr.Dropdown(choices=["gpt-3.5-turbo", "gpt-3.5-turbo-instruct", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k"],
+                                       label="Select the model")
+
+                 with gr.Column(scale=1):
+                     model_load_btn = gr.Button('Load model', variant='primary', scale=2)
+                     load_success_msg = gr.Textbox(show_label=False, lines=1, placeholder="Model loading ...")
+             chatbot = gr.Chatbot([], elem_id="chatbot",
+                                  label='Chatbox', height=725)
+
+             txt = gr.Textbox(label="Question", lines=2, placeholder="Enter your question and press shift+enter")
+
+             with gr.Row():
+
+                 with gr.Column(scale=1):
+                     submit_btn = gr.Button('Submit', variant='primary', size='sm')
+
+                 with gr.Column(scale=1):
+                     clear_btn = gr.Button('Clear', variant='stop', size='sm')
+
+             model_load_btn.click(load_models, [llm], load_success_msg, api_name="load_models")
+
+             txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
+                 bot, [chatbot, prompt_dropdown, temperature, max_new_tokens, k_context], chatbot)
+             submit_btn.click(add_text, [chatbot, txt], [chatbot, txt]).then(
+                 bot, [chatbot, prompt_dropdown, temperature, max_new_tokens, k_context], chatbot).then(
+                 clear_cuda_cache, None, None
+             )
+
+             clear_btn.click(lambda: None, None, chatbot, queue=False)
+
+ if __name__ == '__main__':
+     # demo.queue(concurrency_count=3)
+     demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)
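app.py expects a prompts.json next to it, which this commit does not include. From the way load_prompts_list_from_json and load_prompt consume it (see src/utils.py below), it is presumably a flat JSON object mapping a prompt name to an instruction string. A minimal sketch to bootstrap one; the key and wording are illustrative, not from the repository:

import json

# Hypothetical starter prompt; the app only needs {name: instruction} pairs.
sample_prompts = {
    "Default QA": (
        "Use the following pieces of context to answer the question at the end. "
        "If you do not find the answer in the context, say that you don't know."
    )
}

with open("prompts.json", "w") as f:
    json.dump(sample_prompts, f, indent=2)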
src/conversation.py ADDED
@@ -0,0 +1,50 @@
+ from langchain.vectorstores import FAISS
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.chat_models import ChatOpenAI
+ from langchain.embeddings import OpenAIEmbeddings
+ from langchain.prompts import PromptTemplate
+ import os
+
+ openai_api_key = os.environ.get("OPENAI_API_KEY")
+
+ class Conversation_RAG:
+     def __init__(self, model_name="gpt-3.5-turbo"):
+         self.model_name = model_name
+
+     def create_vectordb(self):
+         vectordb = FAISS.load_local("./db/faiss_index", OpenAIEmbeddings())
+
+         return vectordb
+
+     def create_model(self, max_new_tokens=512, temperature=0.1):
+
+         llm = ChatOpenAI(
+             openai_api_key=openai_api_key,
+             model_name=self.model_name,
+             temperature=temperature,
+             max_tokens=max_new_tokens,
+         )
+
+         return llm
+
+     def create_conversation(self, model, vectordb, k_context=5, instruction="Use the following pieces of context to answer the question at the end. Generate the answer based on the given context only. If you do not find any information related to the question in the given context, just say that you don't know; don't try to make up an answer. Keep your answer expressive."):
+
+         print(instruction)
+
+         template = instruction + """
+         context:\n
+         {context}\n
+         data: {question}\n
+         """
+
+         # Only {context} and {question} appear in the template; the instruction
+         # is baked into the template text, so it is not an input variable.
+         QCA_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)
+
+         qa = ConversationalRetrievalChain.from_llm(
+             llm=model,
+             chain_type='stuff',
+             retriever=vectordb.as_retriever(search_kwargs={"k": k_context}),
+             combine_docs_chain_kwargs={"prompt": QCA_PROMPT},
+             get_chat_history=lambda h: h,
+             verbose=True
+         )
+         return qa
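Taken together, the class composes like this. A minimal usage sketch, assuming a FAISS index already exists under ./db/faiss_index and OPENAI_API_KEY is set; the question string is illustrative:

from conversation import Conversation_RAG

conv = Conversation_RAG(model_name="gpt-3.5-turbo")
vectordb = conv.create_vectordb()                    # load the saved FAISS index
llm = conv.create_model(max_new_tokens=256, temperature=0.2)
qa = conv.create_conversation(llm, vectordb, k_context=3)

# ConversationalRetrievalChain takes the new question plus prior chat history
result = qa({"question": "What does the indexed document cover?", "chat_history": ""})
print(result["answer"])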
src/setup.py ADDED
@@ -0,0 +1,16 @@
+ from conversation import Conversation_RAG
+ from vector_index import *
+
+ class ModelSetup:
+     def __init__(self, model_name):
+
+         self.model_name = model_name
+
+     def setup(self):
+
+         conv_rag = Conversation_RAG(self.model_name)
+
+         self.vectordb = conv_rag.create_vectordb()
+         self.pipeline = conv_rag.create_model()
+
+         return "Model Setup Complete"
src/utils.py ADDED
@@ -0,0 +1,70 @@
+ import gc
+ from conversation import Conversation_RAG
+ from vector_index import *
+ from setup import ModelSetup
+ import json
+
+ def load_models(model_name):
+     global conv_qa
+     conv_qa = Conversation_RAG(model_name)
+     global model_setup
+     model_setup = ModelSetup(model_name)
+     success_prompt = model_setup.setup()
+     return success_prompt
+
+ def get_chat_history(inputs):
+
+     res = []
+     for human, ai in inputs:
+         res.append(f"Human:{human}\nAssistant:{ai}")
+     return "\n".join(res)
+
+ def add_text(history, text):
+
+     history = history + [[text, None]]
+     return history, ""
+
+
+ def bot(history,
+         instruction="Use the following pieces of context to answer the question at the end. Generate the answer based on the given context only if you find the answer in the context. If you do not find any information related to the question in the given context, just say that you don't know; don't try to make up an answer. Keep your answer expressive.",
+         temperature=0.1,
+         max_new_tokens=512,
+         k_context=5,
+         ):
+
+     instruction = load_prompt('prompts.json', instruction)
+
+     model = conv_qa.create_model(max_new_tokens=max_new_tokens, temperature=temperature)
+
+     qa = conv_qa.create_conversation(
+         model=model,
+         vectordb=model_setup.vectordb,
+         k_context=k_context,
+         instruction=instruction
+     )
+
+     chat_history_formatted = get_chat_history(history[:-1])
+     res = qa(
+         {
+             'question': history[-1][0],
+             'chat_history': chat_history_formatted
+         }
+     )
+
+     history[-1][1] = res['answer']
+     return history
+
+ def clear_cuda_cache():
+     # No GPU is used here; just trigger Python garbage collection.
+     gc.collect()
+     return None
+
+ def load_prompts_list_from_json(json_filepath):
+     with open(json_filepath, 'r') as file:
+         data = json.load(file)
+     return list(data.keys())
+
+ def load_prompt(json_filepath, key):
+     with open(json_filepath, 'r') as file:
+         data = json.load(file)
+     return data.get(key, key)
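Gradio's Chatbot keeps history as a list of [user, assistant] pairs; add_text appends a pair with the answer slot still empty, and get_chat_history flattens every earlier pair into the Human/Assistant transcript that the chain receives. A small illustration with made-up turns:

from utils import add_text, get_chat_history

history, cleared_box = add_text([], "What is FAISS?")
print(history)  # [['What is FAISS?', None]] -- bot() fills in the None slot

past = [["Hi", "Hello!"], ["What is FAISS?", "A similarity-search library."]]
print(get_chat_history(past))
# Human:Hi
# Assistant:Hello!
# Human:What is FAISS?
# Assistant:A similarity-search library.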
src/vector_index.py ADDED
@@ -0,0 +1,72 @@
+ from langchain.vectorstores import FAISS
+ from langchain.document_loaders.csv_loader import CSVLoader
+ from langchain.document_loaders import PyPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.embeddings import OpenAIEmbeddings
+ import os, shutil
+
+
+ def create_vector_store_index(file_path):
+
+     file_path_split = file_path.split(".")
+     file_type = file_path_split[-1].rstrip('/')
+
+     if file_type == 'csv':
+         print(file_path)
+         loader = CSVLoader(file_path=file_path)
+         documents = loader.load()
+
+     elif file_type == 'pdf':
+         loader = PyPDFLoader(file_path)
+         pages = loader.load()
+
+         text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=512,
+             chunk_overlap=128)
+
+         documents = text_splitter.split_documents(pages)
+
+     else:
+         # Guard against unsupported uploads; `documents` would otherwise be undefined
+         raise ValueError(f"Unsupported file type: {file_type}")
+
+     file_output = "./db/faiss_index"
+
+     try:
+         vectordb = FAISS.load_local(file_output, OpenAIEmbeddings())
+         vectordb.add_documents(documents)
+     except Exception:
+         print("No vector store exists. Creating new one...")
+         vectordb = FAISS.from_documents(documents, OpenAIEmbeddings())
+
+     vectordb.save_local(file_output)
+
+     return "Vector store index is created."
+
+
+ def upload_and_create_vector_store(files):
+     current_folder = os.getcwd()
+     data_folder = os.path.join(current_folder, "data")
+
+     # Create the directory if it doesn't exist
+     if not os.path.exists(data_folder):
+         os.makedirs(data_folder)
+
+     index_success_msg = "No new indices added."
+
+     for file in files:
+         # Save each file to a permanent location
+         file_path = file
+         split_file_name = file_path.split("/")
+         file_name = split_file_name[-1]
+         permanent_file_path = os.path.join(data_folder, file_name)
+
+         if os.path.exists(permanent_file_path):
+             print(f"File {file_name} already exists. Skipping.")
+             continue
+
+         shutil.copy(file, permanent_file_path)
+
+         # Access the path of the saved file
+         print(f"File saved to: {permanent_file_path}")
+
+         # Create an index for each file and keep the latest success message
+         index_success_msg = create_vector_store_index(permanent_file_path)
+
+     return index_success_msg
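The indexing path can be exercised without the UI. A minimal sketch, assuming OPENAI_API_KEY is exported and that the file paths shown exist (both are illustrative):

import os
from vector_index import create_vector_store_index, upload_and_create_vector_store

assert os.environ.get("OPENAI_API_KEY"), "embeddings require an OpenAI key"

# Index a single file directly...
print(create_vector_store_index("./data/report.pdf"))

# ...or copy-and-index a batch, as the 'Create vector store' button does:
print(upload_and_create_vector_store(["/tmp/uploads/notes.pdf"]))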