palexis3 committed on
Commit
9914661
1 Parent(s): 39fecc7

Implemented RAG to respond to user queries

Files changed (1)
  1. app/service/transactions_query_rag.py +128 -0
app/service/transactions_query_rag.py CHANGED
@@ -0,0 +1,128 @@
+ from pinecone import Pinecone, ServerlessSpec
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+ from langchain_pinecone import PineconeVectorStore
+ from langchain.chains import RetrievalQA
+
+ from fastapi import HTTPException
+
+ import os
+ import time
+ import pandas as pd
+ from uuid import uuid4
+
+ async def answer_query(df: pd.DataFrame, query: str) -> str:
+     """Creates an embedding of the transactions table and then returns the answer for the given query.
+
+     Args:
+         df (pd.DataFrame): DataFrame containing the transactions that a user has entered
+         query (str): The query the user will ask against said embedding
+
+     Returns:
+         str: Response to the query
+     """
+     try:
+         batch_limit = 100
+
+         pinecone_api_key = os.environ['PINECONE_API_KEY']
+         openai_api_key = os.environ['OPENAI_API_KEY']
+         namespace = "transactionsvector"
+
+         pc = Pinecone(api_key=pinecone_api_key)
+
+         # text-embedding-3-small produces 1536-dimensional vectors, matching
+         # the dimension of the index created below.
+         embeddings = OpenAIEmbeddings(
+             model="text-embedding-3-small",
+             openai_api_key=openai_api_key
+         )
+
+         # Pinecone index names may only contain lowercase letters, digits,
+         # and hyphens, so "transactions_rag" would be rejected.
+         index_name = "transactions-rag"
+
+         # Rebuild the index from scratch on every call.
+         if index_name in pc.list_indexes().names():
+             pc.delete_index(index_name)
+
+         pc.create_index(
+             name=index_name,
+             dimension=1536,
+             metric="cosine",
+             spec=ServerlessSpec(
+                 cloud="aws",
+                 region="us-east-1"
+             )
+         )
+
+         # Wait until the new serverless index is ready before upserting.
+         while not pc.describe_index(index_name).status["ready"]:
+             time.sleep(1)
+
+         index = pc.Index(index_name)
+
+         texts = []
+         all_texts = []
+         metadatas = []
+
+         text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=1000,
+             chunk_overlap=100
+         )
+
+         for _, record in df.iterrows():
+             content_texts = text_splitter.split_text(record['content'])
+
+             metadata = {
+                 'user_id': str(record['user_id'])
+             }
+             # "text" is the metadata key PineconeVectorStore reads chunk text from.
+             content_metadata = [{
+                 "chunk": j, "text": text, **metadata
+             } for j, text in enumerate(content_texts)]
+
+             texts.extend(content_texts)
+             all_texts.extend(content_texts)
+             metadatas.extend(content_metadata)
+
+             # If we have reached the batch limit, then add the texts and reset
+             if len(texts) >= batch_limit:
+                 ids = [str(uuid4()) for _ in range(len(texts))]
+                 embeds = embeddings.embed_documents(texts)
+                 index.upsert(vectors=list(zip(ids, embeds, metadatas)))
+                 texts = []
+                 metadatas = []
+
+         # Flush any remaining chunks that did not fill a full batch.
+         if len(texts) > 0:
+             ids = [str(uuid4()) for _ in range(len(texts))]
+             embeds = embeddings.embed_documents(texts)
+             index.upsert(vectors=list(zip(ids, embeds, metadatas)))
+
+         # Build the retrieval-side vector store. from_texts re-embeds the
+         # chunks and upserts them under the "transactionsvector" namespace.
+         transactions_search = PineconeVectorStore.from_texts(
+             texts=all_texts,
+             index_name=index_name,
+             embedding=embeddings,
+             namespace=namespace
+         )
+
+         llm = ChatOpenAI(
+             openai_api_key=openai_api_key,
+             model_name="gpt-3.5-turbo",
+             temperature=0.0
+         )
+
+         qa = RetrievalQA.from_llm(
+             llm=llm,
+             retriever=transactions_search.as_retriever()
+         )
+
+         # RetrievalQA returns a dict; the generated answer is under "result".
+         answer = qa.invoke(query)
+
+         return answer["result"]
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"answer_query error: {str(e)}")
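
For reference, a minimal sketch of how answer_query might be called from a FastAPI route. The route path, the QueryRequest model, and the load_transactions helper are assumptions for illustration only; they are not part of this commit.

import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel

from app.service.transactions_query_rag import answer_query

app = FastAPI()

class QueryRequest(BaseModel):
    user_id: str
    query: str

def load_transactions(user_id: str) -> pd.DataFrame:
    # Hypothetical stand-in for however the app actually loads a user's
    # transactions; answer_query expects 'content' and 'user_id' columns.
    return pd.DataFrame([
        {"user_id": user_id, "content": "2024-01-05 Coffee Shop $4.50"},
        {"user_id": user_id, "content": "2024-01-06 Grocery Store $62.10"},
    ])

@app.post("/transactions/query")
async def query_transactions(request: QueryRequest) -> dict:
    df = load_transactions(request.user_id)
    answer = await answer_query(df, request.query)
    return {"answer": answer}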