Spaces:
Paused
Paused
Upload 2 files
Browse files- InnovationHub/llm/chain.py +127 -0
- InnovationHub/llm/vector_store.py +179 -0
InnovationHub/llm/chain.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio
|
2 |
+
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
3 |
+
from langchain.vectorstores import FAISS
|
4 |
+
from langchain import OpenAI, ConversationChain, LLMChain, PromptTemplate
|
5 |
+
from langchain.chains.conversation.memory import ConversationBufferMemory
|
6 |
+
|
7 |
+
def chat(question, vehicle, k=10, temperature=0.01):
|
8 |
+
chatgpt_chain = create_chatgpt_chain(temperature=temperature)
|
9 |
+
response = ask_question(question=question, vehicle=vehicle, k=k, embeddings=model_norm, chatgpt_chain=chatgpt_chain)
|
10 |
+
return response
|
11 |
+
|
12 |
+
def create_chatgpt_chain(temperature):
|
13 |
+
template = """
|
14 |
+
{chat_history}
|
15 |
+
Human: {question}
|
16 |
+
AI:
|
17 |
+
"""
|
18 |
+
prompt_template = PromptTemplate(input_variables=["chat_history", "question"], template=template)
|
19 |
+
return LLMChain(llm=OpenAI(temperature=temperature,model_name="gpt-3.5-turbo"),prompt=prompt_template,verbose=True,memory=ConversationBufferMemory(memory_key="chat_history"))
|
20 |
+
|
21 |
+
def ask_question(question, vehicle, k, embeddings, chatgpt_chain):
|
22 |
+
index = FAISS.load_local(folder_path=db_paths[vehicle], embeddings=embeddings)
|
23 |
+
prompt = get_prompt(question=question, vehicle=vehicle, k=k)
|
24 |
+
response = chatgpt_chain.run(question=prompt)
|
25 |
+
return response
|
26 |
+
|
27 |
+
def get_prompt(question, vehicle, k):
|
28 |
+
prompt = f"""
|
29 |
+
I need information from my {vehicle} manual.
|
30 |
+
I will provide an excerpt from the manual. Use the excerpt and nothing else to answer the question.
|
31 |
+
You must refer to the excerpt as "{vehicle} Manual" in your response. Here is the excerpt:
|
32 |
+
"""
|
33 |
+
|
34 |
+
index = FAISS.load_local(folder_path=db_paths[vehicle], embeddings=model_norm)
|
35 |
+
similar_docs = index.similarity_search(query=question, k=k)
|
36 |
+
context = []
|
37 |
+
for d in similar_docs:
|
38 |
+
content = d.page_content
|
39 |
+
context.append(content)
|
40 |
+
|
41 |
+
user_input = prompt + '\n[EXCERPT]' + '\n'.join(context[:k]) + '\nQuestion: ' + question
|
42 |
+
return user_input
|
43 |
+
|
44 |
+
db_paths = {
|
45 |
+
"2023 AMG C-Coupe-Cab": "data/amg_c_coupe_cab",
|
46 |
+
"2023 AMG C-Sedan": "data/amg_c_sedan",
|
47 |
+
"2023 AMG E-Coupe-Cab": "data/amg_e_coupe_cab",
|
48 |
+
"2023 AMG E-Sedan_wagon": "data/amg_e_sedan_wagon",
|
49 |
+
"2023 AMG_EQE-Sedan": "data/amg_eqe_sedan",
|
50 |
+
"2023 AMG_GLE-suv": "data/amg_gle_suv",
|
51 |
+
"2023 AMG_GLS SUV": "data/amg_gls_suv",
|
52 |
+
"2023 C-Cab": "data/c_cab",
|
53 |
+
"2023 C-Coupe": "data/c_coupe",
|
54 |
+
"2023 C-Sedan": "data/c_sedan",
|
55 |
+
"2023 CLA": "data/cla",
|
56 |
+
"2023 E-Cab": "data/e_cab",
|
57 |
+
"2023 E-Coupe": "data/e_coupe",
|
58 |
+
"2023 E-Sedan": "data/e_sedan",
|
59 |
+
"2023 E-wagon": "data/e_wagon",
|
60 |
+
"2023 eqb SUV": "data/eqb_suv",
|
61 |
+
"2023 EQE-Sedan": "data/eqe_sedan",
|
62 |
+
"2023 EQS_Sedan": "data/eqs_sedan",
|
63 |
+
"2023 EQS SUV": "data/eqs_suv",
|
64 |
+
"2023 GLA": "data/gla",
|
65 |
+
"2023 GLB": "data/glb",
|
66 |
+
"2023 GLC-Coupe": "data/glc_coupe",
|
67 |
+
"2023 GLE-Coupe": "data/gle_coupe",
|
68 |
+
"2023 GLE-suv": "data/gle_suv",
|
69 |
+
"2023 GLS SUV": "data/gls_suv"
|
70 |
+
}
|
71 |
+
|
72 |
+
vehicle_options = [
|
73 |
+
"2023 AMG C-Coupe-Cab",
|
74 |
+
"2023 AMG C-Sedan",
|
75 |
+
"2023 AMG E-Coupe-Cab",
|
76 |
+
"2023 AMG E-Sedan_wagon",
|
77 |
+
"2023 AMG_EQE-Sedan",
|
78 |
+
"2023 AMG_GLE-suv",
|
79 |
+
"2023 AMG_GLS SUV",
|
80 |
+
"2023 C-Cab",
|
81 |
+
"2023 C-Coupe",
|
82 |
+
"2023 C-Sedan",
|
83 |
+
"2023 CLA",
|
84 |
+
"2023 E-Cab",
|
85 |
+
"2023 E-Coupe",
|
86 |
+
"2023 E-Sedan",
|
87 |
+
"2023 E-wagon",
|
88 |
+
"2023 eqb SUV",
|
89 |
+
"2023 EQE-Sedan",
|
90 |
+
"2023 EQS SUV",
|
91 |
+
"2023 EQS_Sedan",
|
92 |
+
"2023 GLA",
|
93 |
+
"2023 GLB",
|
94 |
+
"2023 GLC-Coupe",
|
95 |
+
"2023 GLE-Coupe",
|
96 |
+
"2023 GLE-suv",
|
97 |
+
"2023 GLS SUV",
|
98 |
+
]
|
99 |
+
|
100 |
+
model_name = "BAAI/bge-large-en"
|
101 |
+
model_kwargs = {'device': 'cpu'}
|
102 |
+
encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
|
103 |
+
model_norm = HuggingFaceBgeEmbeddings(
|
104 |
+
model_name=model_name,
|
105 |
+
model_kwargs=model_kwargs,
|
106 |
+
encode_kwargs=encode_kwargs
|
107 |
+
)
|
108 |
+
|
109 |
+
def start_ui():
|
110 |
+
chatbot_interface = gradio.Interface(
|
111 |
+
fn=chat,
|
112 |
+
inputs=["text",
|
113 |
+
gradio.inputs.Dropdown(vehicle_options, label="Select Mercedes-Benz Owner's Manual")
|
114 |
+
#gradio.inputs.Slider(minimum=1, maximum=10, step=1, label="k")
|
115 |
+
],
|
116 |
+
outputs="text",
|
117 |
+
title="Mercedes-Benz Owner's Manual",
|
118 |
+
description="Ask a question and get answers from Mercedes-Benz Owner's Manual.<u>Disclaimer:</u> THIS IS NOT OFFICIAL AND MAY NOT BE AVAILABLE ALL THE TIME. ALWAYS LOOK AT THE OFFICIAL DOCUMENTATION at https://www.mbusa.com/en/owners/manuals",
|
119 |
+
examples=[["What are the different features of the dashboard console?", "2023 S-Class", 10, 0.01],
|
120 |
+
["What is flacon? Which page has that information? Show me all the exact content from that page", "2023 S-Class", 10, 0.01],
|
121 |
+
["What is hyperscreen?", "2023 EQS", 10, 0.01],
|
122 |
+
["Where can I find my vin?", "2023 EQS", 10, 0.01],
|
123 |
+
["Does it take more than 30 minutes to charge? Which page has that information? Show me all the exact content from that page", "2023 EQE", 10, 0.01]],
|
124 |
+
article = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=kaushikdatta.owner-manual" alt="visitor badge"/></center>'
|
125 |
+
)
|
126 |
+
|
127 |
+
chatbot_interface.launch()
|
InnovationHub/llm/vector_store.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import plotly.graph_objs as go
|
2 |
+
from sklearn.cluster import KMeans
|
3 |
+
from sklearn.decomposition import PCA
|
4 |
+
import plotly.express as px
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
import pprint
|
8 |
+
import codecs
|
9 |
+
import chardet
|
10 |
+
import gradio as gr
|
11 |
+
from langchain.llms import HuggingFacePipeline
|
12 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
13 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
14 |
+
from langchain.vectorstores import FAISS
|
15 |
+
from langchain import OpenAI, ConversationChain, LLMChain, PromptTemplate
|
16 |
+
from langchain.memory import ConversationBufferWindowMemory
|
17 |
+
|
18 |
+
|
19 |
+
def get_content(input_file):
|
20 |
+
# Read the input file in binary mode
|
21 |
+
with open(input_file, 'rb') as f:
|
22 |
+
raw_data = f.read()
|
23 |
+
|
24 |
+
# Detect the encoding of the file
|
25 |
+
result = chardet.detect(raw_data)
|
26 |
+
encoding = result['encoding']
|
27 |
+
|
28 |
+
# Decode the contents using the detected encoding
|
29 |
+
with codecs.open(input_file, 'r', encoding=encoding) as f:
|
30 |
+
raw_text = f.read()
|
31 |
+
|
32 |
+
# Return the content of the input file
|
33 |
+
return raw_text
|
34 |
+
|
35 |
+
|
36 |
+
def split_text(input_file, chunk_size=1000, chunk_overlap=0):
|
37 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
38 |
+
chunk_size=chunk_size,
|
39 |
+
chunk_overlap=chunk_overlap,
|
40 |
+
length_function=len,
|
41 |
+
)
|
42 |
+
|
43 |
+
basename = os.path.basename(input_file)
|
44 |
+
basename = os.path.splitext(basename)[0]
|
45 |
+
raw_text = get_content(input_file=input_file)
|
46 |
+
|
47 |
+
texts = text_splitter.split_text(text=raw_text)
|
48 |
+
metadatas = [{"source": f"{basename}[{i}]"} for i in range(len(texts))]
|
49 |
+
docs = text_splitter.create_documents(texts=texts, metadatas=metadatas)
|
50 |
+
|
51 |
+
return texts, metadatas, docs
|
52 |
+
|
53 |
+
|
54 |
+
def create_docs(input_file):
|
55 |
+
# Create a text splitter object with a separator character
|
56 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
57 |
+
chunk_size=1000,
|
58 |
+
chunk_overlap=0,
|
59 |
+
length_function=len,
|
60 |
+
)
|
61 |
+
|
62 |
+
basename = os.path.basename(input_file)
|
63 |
+
basename = os.path.splitext(basename)[0]
|
64 |
+
texts = get_content(input_file=input_file)
|
65 |
+
metadatas = {'source': basename}
|
66 |
+
docs = text_splitter.create_documents(texts=[texts], metadatas=[metadatas])
|
67 |
+
return docs
|
68 |
+
|
69 |
+
|
70 |
+
def get_similar_docs(query, index, k=5):
|
71 |
+
similar_docs = index.similarity_search(query=query, k=k)
|
72 |
+
result = [(d.summary, d.metadata) for d in similar_docs]
|
73 |
+
return result
|
74 |
+
|
75 |
+
|
76 |
+
def convert_to_html(similar_docs):
|
77 |
+
result = []
|
78 |
+
for summary, metadata in similar_docs:
|
79 |
+
record = '<tr><td>' + summary + '</td><td>' + \
|
80 |
+
metadata['source'] + '</td></tr>'
|
81 |
+
result.append(record)
|
82 |
+
html = '<table><thead><th>Page Content</th><th>Source</th></thead><tbody>' + \
|
83 |
+
'\n'.join(result) + '</tbody></table>'
|
84 |
+
return html
|
85 |
+
|
86 |
+
|
87 |
+
def create_similarity_plot(embeddings, labels, query, n_clusters=3):
|
88 |
+
# Only include embeddings that have corresponding labels
|
89 |
+
embeddings_with_labels = [
|
90 |
+
embedding for i, embedding in enumerate(embeddings) if i < len(labels)]
|
91 |
+
|
92 |
+
# Reduce the dimensionality of the embeddings using PCA
|
93 |
+
pca = PCA(n_components=3)
|
94 |
+
pca_embeddings = pca.fit_transform(embeddings_with_labels)
|
95 |
+
|
96 |
+
# Cluster the embeddings using k-means
|
97 |
+
kmeans = KMeans(n_clusters=n_clusters)
|
98 |
+
kmeans.fit(embeddings_with_labels)
|
99 |
+
|
100 |
+
# Create a trace for the query point
|
101 |
+
query_trace = go.Scatter3d(
|
102 |
+
x=[pca_embeddings[-1, 0]],
|
103 |
+
y=[pca_embeddings[-1, 1]],
|
104 |
+
z=[pca_embeddings[-1, 2]],
|
105 |
+
mode='markers',
|
106 |
+
marker=dict(
|
107 |
+
color='black',
|
108 |
+
symbol='diamond',
|
109 |
+
size=10
|
110 |
+
),
|
111 |
+
name=f"Query: '{query}'"
|
112 |
+
)
|
113 |
+
|
114 |
+
# Create a trace for the other points
|
115 |
+
points_trace = go.Scatter3d(
|
116 |
+
x=pca_embeddings[:, 0],
|
117 |
+
y=pca_embeddings[:, 1],
|
118 |
+
z=pca_embeddings[:, 2],
|
119 |
+
mode='markers',
|
120 |
+
marker=dict(
|
121 |
+
color=kmeans.labels_,
|
122 |
+
colorscale=px.colors.qualitative.Alphabet,
|
123 |
+
size=5
|
124 |
+
),
|
125 |
+
text=labels,
|
126 |
+
name='Points'
|
127 |
+
)
|
128 |
+
|
129 |
+
# Create the figure
|
130 |
+
fig = go.Figure(data=[query_trace, points_trace])
|
131 |
+
|
132 |
+
# Add a title and legend
|
133 |
+
fig.update_layout(
|
134 |
+
title="3D Similarity Plot",
|
135 |
+
legend_title_text="Cluster"
|
136 |
+
)
|
137 |
+
|
138 |
+
# Show the plot
|
139 |
+
fig.show()
|
140 |
+
|
141 |
+
|
142 |
+
def plot_similarities(query, index, embeddings=HuggingFaceEmbeddings(), k=5):
|
143 |
+
query_embeddings = embeddings.embed_query(text=query)
|
144 |
+
|
145 |
+
similar_docs = get_similar_docs(query=query, index=index, k=k)
|
146 |
+
texts = []
|
147 |
+
for d in similar_docs:
|
148 |
+
texts.append(d[0])
|
149 |
+
|
150 |
+
embeddings_array = embeddings.embed_documents(texts=texts)
|
151 |
+
|
152 |
+
# Get the index of the query point
|
153 |
+
query_index = len(embeddings_array) - 1
|
154 |
+
|
155 |
+
create_similarity_plot(
|
156 |
+
embeddings=embeddings_array,
|
157 |
+
labels=texts,
|
158 |
+
query_index=query_index,
|
159 |
+
n_clusters=3
|
160 |
+
)
|
161 |
+
|
162 |
+
|
163 |
+
def start_ui(index):
|
164 |
+
def query_index(query):
|
165 |
+
similar_docs = get_similar_docs(query=query, index=index)
|
166 |
+
formatted_output = convert_to_html(similar_docs=similar_docs)
|
167 |
+
return formatted_output
|
168 |
+
|
169 |
+
# Define input and output types
|
170 |
+
input = gr.inputs.Textbox(lines=2)
|
171 |
+
output = gr.outputs.HTML()
|
172 |
+
|
173 |
+
# Create interface object
|
174 |
+
iface = gr.Interface(fn=query_index,
|
175 |
+
inputs=input,
|
176 |
+
outputs=output)
|
177 |
+
|
178 |
+
# Launch interface
|
179 |
+
iface.launch()
|