dataprincess committed on
Commit
d32a867
·
verified ·
1 Parent(s): ec4765e

improved code

Browse files
Files changed (1) hide show
  1. app.py +12 -22
app.py CHANGED
@@ -12,16 +12,21 @@ import streamlit as st
12
  FILE_PATH = "anjibot_chunks.json"
13
  BATCH_SIZE = 384
14
  INDEX_NAME = "groq-llama-3-rag"
15
- PINECONE_API_KEY = os.getenv["PINECONE_API_KEY"]
16
- GROQ_API_KEY = os.getenv["GROQ_API_KEY"]
17
  DIMENSIONS = 768
18
 
 
 
 
 
 
 
19
 
20
  def load_data(file_path: str) -> dict:
21
  with open(file_path, 'r') as file:
22
  return json.load(file)
23
 
24
-
25
  def initialize_pinecone(api_key: str, index_name: str, dims: int) -> any:
26
  pc = Pinecone(api_key=api_key)
27
  spec = ServerlessSpec(cloud="aws", region='us-east-1')
@@ -38,10 +43,7 @@ def initialize_pinecone(api_key: str, index_name: str, dims: int) -> any:
38
 
39
  return pc.Index(index_name)
40
 
41
-
42
  def upsert_data_to_pinecone(index: any, data: dict):
43
- encoder = SentenceTransformer('dwzhu/e5-base-4k')
44
-
45
  for i in tqdm(range(0, len(data['id']), BATCH_SIZE)):
46
  # Find end of batch
47
  i_end = min(len(data['id']), i + BATCH_SIZE)
@@ -60,17 +62,15 @@ def upsert_data_to_pinecone(index: any, data: dict):
60
  to_upsert = list(zip(batch["id"], embeds, batch["metadata"]))
61
  index.upsert(vectors=to_upsert)
62
 
63
-
64
  def get_docs(query: str, index: any, encoder: any, top_k: int) -> list[str]:
65
  xq = encoder.encode(query)
66
  res = index.query(vector=xq.tolist(), top_k=top_k, include_metadata=True)
67
  return [x["metadata"]['content'] for x in res["matches"]]
68
 
69
-
70
  def get_response(query: str, docs: list[str], groq_client: any) -> str:
71
  system_message = (
72
- "You are Anjibot, the AI course rep of 400 Level Computer Science department. You are always helpful, jovial, can be sarcastica but still sweet.\n"
73
- "Provide the answer to class related queries using\n"
74
  "context provided below.\n"
75
  "If you don't the answer to the user's question based on your pretrained knowledge and the context provided, just direct the user to Anji the human course rep.\n"
76
  "Anji's phone number: 08145170886.\n\n"
@@ -88,19 +88,11 @@ def get_response(query: str, docs: list[str], groq_client: any) -> str:
88
  )
89
  return chat_response.choices[0].message.content
90
 
91
-
92
  def handle_query(user_query: str):
93
- # Load data
94
- data = load_data(FILE_PATH)
95
-
96
- # Initialize Pinecone
97
- index = initialize_pinecone(PINECONE_API_KEY, INDEX_NAME, DIMENSIONS)
98
-
99
- # Upsert data into Pinecone
100
  upsert_data_to_pinecone(index, data)
101
 
102
- # Initialize encoder and Groq client
103
- encoder = SentenceTransformer('dwzhu/e5-base-4k')
104
  groq_client = Groq(api_key=GROQ_API_KEY)
105
 
106
  # Get relevant documents
@@ -132,5 +124,3 @@ def main():
132
 
133
  if __name__ == "__main__":
134
  main()
135
-
136
-
 
12
  FILE_PATH = "anjibot_chunks.json"
13
  BATCH_SIZE = 384
14
  INDEX_NAME = "groq-llama-3-rag"
15
+ PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
16
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
17
  DIMENSIONS = 768
18
 
19
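+ # Load data once at the start (NOTE(review): load_data and initialize_pinecone are defined further down, so this module-level setup raises NameError on import; it must run after those definitions)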
20
+ data = load_data(FILE_PATH)
21
+
22
+ # Initialize Pinecone and SentenceTransformer once
23
+ index = initialize_pinecone(PINECONE_API_KEY, INDEX_NAME, DIMENSIONS)
24
+ encoder = SentenceTransformer('dwzhu/e5-base-4k')
25
 
26
  def load_data(file_path: str) -> dict:
27
  with open(file_path, 'r') as file:
28
  return json.load(file)
29
 
 
30
  def initialize_pinecone(api_key: str, index_name: str, dims: int) -> any:
31
  pc = Pinecone(api_key=api_key)
32
  spec = ServerlessSpec(cloud="aws", region='us-east-1')
 
43
 
44
  return pc.Index(index_name)
45
 
 
46
  def upsert_data_to_pinecone(index: any, data: dict):
 
 
47
  for i in tqdm(range(0, len(data['id']), BATCH_SIZE)):
48
  # Find end of batch
49
  i_end = min(len(data['id']), i + BATCH_SIZE)
 
62
  to_upsert = list(zip(batch["id"], embeds, batch["metadata"]))
63
  index.upsert(vectors=to_upsert)
64
 
 
65
  def get_docs(query: str, index: any, encoder: any, top_k: int) -> list[str]:
66
  xq = encoder.encode(query)
67
  res = index.query(vector=xq.tolist(), top_k=top_k, include_metadata=True)
68
  return [x["metadata"]['content'] for x in res["matches"]]
69
 
 
70
  def get_response(query: str, docs: list[str], groq_client: any) -> str:
71
  system_message = (
72
+ "You are Anjibot, the AI course rep of 400 Level Computer Science department. You are always helpful, jovial, can be sarcastic but still sweet.\n"
73
+ "Provide the answer to class-related queries using\n"
74
  "context provided below.\n"
75
  "If you don't the answer to the user's question based on your pretrained knowledge and the context provided, just direct the user to Anji the human course rep.\n"
76
  "Anji's phone number: 08145170886.\n\n"
 
88
  )
89
  return chat_response.choices[0].message.content
90
 
 
91
  def handle_query(user_query: str):
92
+ # Upsert data into Pinecone (NOTE(review): this call is unconditional, so the whole dataset is upserted again on every query)
 
 
 
 
 
 
93
  upsert_data_to_pinecone(index, data)
94
 
95
+ # Initialize Groq client
 
96
  groq_client = Groq(api_key=GROQ_API_KEY)
97
 
98
  # Get relevant documents
 
124
 
125
  if __name__ == "__main__":
126
  main()