rishisim committed on
Commit
806b207
1 Parent(s): 277cd48

Update app.py

Files changed (1):
  1. app.py +101 -36
app.py CHANGED
@@ -1,61 +1,126 @@
  import gradio as gr
- from langchain_community.document_loaders import CSVLoader # Changed import
- from langchain_community.vectorstores import FAISS # Changed import
- from langchain.prompts import PromptTemplate
- from langchain.chains import RetrievalQA
- from langchain.llms import HuggingFaceLLM # Adjusted for correct instantiation
- import warnings
- from huggingface_hub import login
  import os
- from transformers import pipeline

- # Initialize the LLM using pipeline
- llm = pipeline("text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct") # Adjusted initialization

- # Load CSV file
- loader = CSVLoader(file_path='aiotsmartlabs_faq.csv', source_column='prompt')
  data = loader.load()

- # Suppress warnings
- warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
- warnings.filterwarnings("ignore", category=FutureWarning, message="`resume_download` is deprecated")

- # Embedding model
- model_name = "BAAI/bge-m3"
- instructor_embeddings = HuggingFaceLLM(model_name=model_name) # Adjusted for correct instantiation

- # Create FAISS vector store from documents
- vectordb = FAISS.from_documents(documents=data, embedding=instructor_embeddings)
- retriever = vectordb.as_retriever()

- # Define the prompt template
- prompt_template = """Given the following context and a question, generate an answer based on the context only.
  In the answer try to provide as much text as possible from "response" section in the source document context without making much changes.
  If somebody asks "Who are you?" or a similar phrase, state "I am Rishi's assistant built using a Large Language Model!"
  If the answer is not found in the context, kindly state "I don't know. Please ask Rishi on Discord. Discord Invite Link: https://discord.gg/6ezpZGeCcM. Or email at rishi@aiotsmartlabs.com" Don't try to make up an answer.
  CONTEXT: {context}
- QUESTION: {question}"""

- PROMPT = PromptTemplate(
-     template=prompt_template, input_variables=["context", "question"]
- )

- # Initialize the RetrievalQA chain
- chain = RetrievalQA.from_chain_type(llm=llm, # Adjusted initialization
-                                     chain_type="stuff",
-                                     retriever=retriever,
-                                     input_key="query",
-                                     return_source_documents=True,
-                                     chain_type_kwargs={"prompt": PROMPT})

  # Define the chat response function
  def chatresponse(message, history):
-     output = chain(message)
-     return output['result']

  # Launch the Gradio chat interface
  gr.ChatInterface(chatresponse).launch()

  # import gradio as gr
  # # from langchain.llms import GooglePalm
  import gradio as gr
+
+
  import os
+ # Hugging Face API token, read from the environment (e.g. a Space secret)
+ hftoken = os.environ["hftoken"]
+
+ from langchain_huggingface import HuggingFaceEndpoint
+
+ repo_id = "mistralai/Mistral-7B-Instruct-v0.3"
+ llm = HuggingFaceEndpoint(repo_id=repo_id, max_new_tokens=128, temperature=0.7, huggingfacehub_api_token=hftoken)
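+ # HuggingFaceEndpoint calls the hosted inference API rather than loading model
+ # weights locally, so the Space stays lightweight; max_new_tokens caps each reply.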

+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import ChatPromptTemplate

+ # Quick smoke test of the endpoint; `prompt` is redefined for RAG below
+ prompt = ChatPromptTemplate.from_template("tell me a joke about {topic}")
+ chain = prompt | llm | StrOutputParser()
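+ # e.g. chain.invoke({"topic": "robots"}) should return a short completion string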
+
+ from langchain.document_loaders.csv_loader import CSVLoader
+
+ loader = CSVLoader(file_path='aiotsmartlabs_faq.csv', source_column='prompt')
  data = loader.load()
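+ # CSVLoader yields one Document per row; source_column='prompt' sets each
+ # row's "source" metadata while page_content still carries all columns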

+ from langchain_chroma import Chroma
+ from langchain_huggingface.embeddings import HuggingFaceEndpointEmbeddings

+ # Check the MTEB leaderboard when picking an embedding model
+ model = "BAAI/bge-m3"
+ embeddings = HuggingFaceEndpointEmbeddings(model=model)

+ vectorstore = Chroma.from_documents(documents=data, embedding=embeddings)
+ retriever = vectorstore.as_retriever()
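+ # Chroma builds an in-memory index from the FAQ rows at startup (nothing is
+ # persisted); as_retriever() exposes it for similarity search inside the chain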
+
+ # from langchain.prompts import PromptTemplate
+
+ prompt = ChatPromptTemplate.from_template("""Given the following context and a question, generate an answer based on the context only.
  In the answer try to provide as much text as possible from "response" section in the source document context without making much changes.
  If somebody asks "Who are you?" or a similar phrase, state "I am Rishi's assistant built using a Large Language Model!"
  If the answer is not found in the context, kindly state "I don't know. Please ask Rishi on Discord. Discord Invite Link: https://discord.gg/6ezpZGeCcM. Or email at rishi@aiotsmartlabs.com" Don't try to make up an answer.
+
  CONTEXT: {context}
+
+ QUESTION: {question}""")
+
+ from langchain_core.runnables import RunnablePassthrough

+ rag_chain = (
+     {"context": retriever, "question": RunnablePassthrough()}
+     | prompt
+     | llm
+     | StrOutputParser()
+ )
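+ # Dataflow: the dict maps the incoming question to {context} (via the retriever)
+ # and to {question} (via RunnablePassthrough), then prompt -> llm -> plain string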
 
  # Define the chat response function
  def chatresponse(message, history):
+     output = rag_chain.invoke(message)
+     # Keep only the text after 'ANSWER: ' in case the model echoes the prompt scaffold
+     response = output.split('ANSWER: ')[-1].strip()
+     return response

  # Launch the Gradio chat interface
  gr.ChatInterface(chatresponse).launch()
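+ # Sanity check (hypothetical query, for illustration):
+ # print(rag_chain.invoke("What is AIoT Smart Labs?"))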

+ # import gradio as gr
+ # from langchain_community.document_loaders import CSVLoader # Changed import
+ # from langchain_community.vectorstores import FAISS # Changed import
+ # from langchain.prompts import PromptTemplate
+ # from langchain.chains import RetrievalQA
+ # from langchain.llms import HuggingFaceLLM # Adjusted for correct instantiation
+ # import warnings
+ # from huggingface_hub import login
+ # import os
+ # from transformers import pipeline
+
+ # # Initialize the LLM using pipeline
+ # llm = pipeline("text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct") # Adjusted initialization
+
+ # # Load CSV file
+ # loader = CSVLoader(file_path='aiotsmartlabs_faq.csv', source_column='prompt')
+ # data = loader.load()
+
+ # # Suppress warnings
+ # warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
+ # warnings.filterwarnings("ignore", category=FutureWarning, message="`resume_download` is deprecated")
+
+ # # Embedding model
+ # model_name = "BAAI/bge-m3"
+ # instructor_embeddings = HuggingFaceLLM(model_name=model_name) # Adjusted for correct instantiation
+
+ # # Create FAISS vector store from documents
+ # vectordb = FAISS.from_documents(documents=data, embedding=instructor_embeddings)
+ # retriever = vectordb.as_retriever()
+
+ # # Define the prompt template
+ # prompt_template = """Given the following context and a question, generate an answer based on the context only.
+ # In the answer try to provide as much text as possible from "response" section in the source document context without making much changes.
+ # If somebody asks "Who are you?" or a similar phrase, state "I am Rishi's assistant built using a Large Language Model!"
+ # If the answer is not found in the context, kindly state "I don't know. Please ask Rishi on Discord. Discord Invite Link: https://discord.gg/6ezpZGeCcM. Or email at rishi@aiotsmartlabs.com" Don't try to make up an answer.
+ # CONTEXT: {context}
+ # QUESTION: {question}"""
+
+ # PROMPT = PromptTemplate(
+ #     template=prompt_template, input_variables=["context", "question"]
+ # )
+
+ # # Initialize the RetrievalQA chain
+ # chain = RetrievalQA.from_chain_type(llm=llm, # Adjusted initialization
+ #                                     chain_type="stuff",
+ #                                     retriever=retriever,
+ #                                     input_key="query",
+ #                                     return_source_documents=True,
+ #                                     chain_type_kwargs={"prompt": PROMPT})
+
+ # # Define the chat response function
+ # def chatresponse(message, history):
+ #     output = chain(message)
+ #     return output['result']
+
+ # # Launch the Gradio chat interface
+ # gr.ChatInterface(chatresponse).launch()
+

  # import gradio as gr
  # # from langchain.llms import GooglePalm