amzand committed on
Commit
83251ae
·
verified ·
1 Parent(s): 98a25af

Upload OpenAI_interface.py

Browse files
Files changed (1) hide show
  1. OpenAI_interface.py +180 -0
OpenAI_interface.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ import pandas as pd
3
+ import os
4
+ import re
5
+ import tiktoken
6
+ import numpy as np
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+ from sentence_transformers import SentenceTransformer
9
+ from OpenAI_tools import run_report_classifier
10
+
11
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
12
+ # ๐Ÿ” OpenAI setup
13
# OpenAI setup.
# The API key is read from the OPENAI_API_KEY environment variable so that no
# credential is stored in source control. Passing None lets the OpenAI client
# apply its own environment lookup and raise its own error when no key is set.
# NOTE(review): any key that was ever committed to this file must be treated
# as compromised and rotated.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# Load the high-priority agency directory.
# The CSV is expected to contain an "agency_name" column (read below).
AGENCY_CSV = "high_priority_agencies.csv"
df = pd.read_csv(AGENCY_CSV)

# Load the sentence-embedding model and precompute one embedding per agency
# name at import time, so each lookup only has to encode the query string.
model = SentenceTransformer("all-MiniLM-L6-v2")
agency_names = df["agency_name"].tolist()
agency_embeddings = model.encode(agency_names)
25
+
26
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
27
+ # ๐Ÿง  Cosine similarity matcher for agency name
28
def resolve_agency_index(agency_name, threshold=0.7, top_k=3):
    """Find the row of ``df`` whose agency name best matches *agency_name*.

    The query is embedded with the module-level ``model`` and compared to the
    precomputed ``agency_embeddings`` by cosine similarity. The ``top_k``
    closest candidates are printed for diagnostics.

    Args:
        agency_name: Free-text agency name to resolve.
        threshold: Minimum cosine similarity the best candidate must reach
            to be accepted. Defaults to 0.7.
        top_k: Number of top candidates to print. Values below 1 are
            treated as 1. Defaults to 3.

    Returns:
        ``(index, name)`` where ``index`` is the positional row index into
        ``df`` (a plain ``int``) and ``name`` is that row's ``agency_name``;
        ``(None, None)`` when the input is blank or no candidate reaches
        ``threshold``.
    """
    # A blank query has no meaningful embedding; report "no match" directly.
    if not agency_name or not agency_name.strip():
        return None, None

    input_vec = model.encode([agency_name])
    sims = cosine_similarity(input_vec, agency_embeddings)[0]

    # argsort is ascending: take the last top_k entries and reverse them so
    # the best match comes first. Guard against top_k <= 0, where the slice
    # [-0:] would select every row instead of none.
    top_k = max(1, top_k)
    top_indices = sims.argsort()[-top_k:][::-1]
    print("๐Ÿ” Top cosine matches:")
    for idx in top_indices:
        print(f" โ€ข {df.iloc[idx]['agency_name']} (score: {sims[idx]:.2f})")

    # Convert the NumPy integer to a built-in int for callers.
    best_idx = int(top_indices[0])
    best_score = sims[best_idx]
    best_name = df.iloc[best_idx]["agency_name"]

    if best_score >= threshold:
        print(f"๐Ÿง  Cosine match for '{agency_name}' โž '{best_name}' (score={best_score:.2f})")
        return best_idx, best_name

    print(f"โŒ No confident match found for agency: '{agency_name}' (score={best_score:.2f})")
    return None, None
48
+
49
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
50
+ # ๐Ÿ“Š Token counting utility
51
def count_tokens(messages, model="gpt-3.5-turbo"):
    """Estimate how many tokens a list of chat messages will consume.

    Args:
        messages: Iterable of message dicts; every value in each dict is
            tokenized (in this file the values are the role and content
            strings).
        model: Model name used to pick the tiktoken encoding. Unknown names
            fall back to the ``cl100k_base`` encoding.

    Returns:
        The estimated token count: 4 tokens of overhead per message, plus the
        encoded length of every value, plus 2 tokens of reply overhead.
    """
    try:
        tokenizer = tiktoken.encoding_for_model(model)
    except KeyError:
        tokenizer = tiktoken.get_encoding("cl100k_base")

    per_message_overhead = 4
    reply_overhead = 2

    total = 0
    for msg in messages:
        total += per_message_overhead
        total += sum(len(tokenizer.encode(text)) for text in msg.values())
    return total + reply_overhead
64
+
65
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
66
+ # ๐Ÿ“ฌ Query OpenAI for structured extraction or conversation
67
def ask_openai(prompt, chatbot_mode=False):
    """Send *prompt* to the chat model and return the reply text.

    Args:
        prompt: The user's message.
        chatbot_mode: When true, a casual-assistant system prompt is used.
            Otherwise the model is instructed to reply in the fixed
            ``Agency: / Keyword: / Year:`` extraction format.

    Returns:
        The text of the first choice in the response, or an empty string when
        the API returns no text content, so callers can always treat the
        result as a ``str``.

    Side effects:
        Prints the estimated token count and cost of the request.
    """
    # Single source of truth for the model name, shared by the token
    # estimate and the actual request.
    chat_model = "gpt-3.5-turbo"

    if chatbot_mode:
        system_prompt = (
            "You are a helpful assistant that responds casually and explains things clearly."
        )
    else:
        system_prompt = (
            "You are an extraction agent. Extract the following from the userโ€™s prompt. "
            "Respond only in the format:\n"
            "Agency: [agency name]\nKeyword: [keyword]\nYear: [4-digit year or None]"
        )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]

    # The estimate covers the outgoing messages only; tokens in the model's
    # reply are not included.
    num_tokens = count_tokens(messages, model=chat_model)
    # NOTE(review): 0.0015 is presumably the USD price per 1K input tokens
    # for this model -- confirm against current pricing.
    cost = num_tokens / 1000 * 0.0015

    response = client.chat.completions.create(
        model=chat_model,
        messages=messages,
        temperature=0.2,
    )

    print(f"๐Ÿงฎ Tokens used: {num_tokens}")
    print(f"๐Ÿ’ฐ Estimated cost: ${cost:.4f}")

    # message.content is optional in the API response; normalise None to ""
    # so that string concatenation and parsing in callers do not raise.
    content = response.choices[0].message.content
    return content if content is not None else ""
94
+
95
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
96
+ # ๐Ÿ“ค Extract structured values from model response
97
def extract_fields(text):
    """Parse an ``Agency: / Keyword: / Year:`` reply into a dict.

    Each line is split at its first colon and classified by the label on the
    left-hand side. Classifying by label (rather than by searching the whole
    line) keeps a value such as ``Keyword: agency spending`` from being
    mistaken for the agency line.

    Args:
        text: The model's reply. ``None`` or an empty string yields the
            defaults.

    Returns:
        A dict with keys ``"agency"`` (default ``"unknown"``), ``"keyword"``
        (default ``"budget"``) and ``"year"`` (``int`` or ``None``). Agency
        and keyword values are lower-cased. When a label appears more than
        once, the last non-empty value wins.
    """
    agency = "unknown"
    keyword = "budget"
    year = None

    for line in (text or "").lower().splitlines():
        label, sep, value = line.partition(":")
        if not sep:
            # No "label: value" structure on this line.
            continue

        # Tolerate list/markdown decoration around the label and value,
        # e.g. "- agency" or "**agency**".
        label = label.strip(" \t*-_#>")
        value = value.strip(" \t*")

        if label.startswith("agency"):
            if value:
                agency = value
        elif label.startswith("keyword"):
            if value:
                keyword = value
        elif label.startswith("year"):
            # Exactly four digits, not part of a longer digit run; this still
            # accepts forms such as "fy2023".
            match = re.search(r"(?<!\d)\d{4}(?!\d)", value)
            if match:
                year = int(match.group())

    return {"agency": agency, "keyword": keyword, "year": year}
113
+
114
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
115
+ # ๐Ÿงพ Main CLI loop
116
def main():
    """Run the interactive command-line loop.

    The loop has two modes, switched by typing "let's talk" (chatbot mode)
    or "let's search" (extraction/search mode, the default). In search mode
    the model's reply is parsed into agency / keyword / year, the agency is
    resolved against the agency directory, and ``run_report_classifier`` is
    launched for that single agency. Typing "exit" or "quit", or sending
    EOF / Ctrl-C at the prompt, ends the loop.

    Search credentials are read from the ``BRAVE_API_KEY``,
    ``GOOGLE_API_KEY`` and ``GOOGLE_CSE_ID`` environment variables rather
    than being stored in source.
    """
    print("๐Ÿค– OpenAI Agent Online. Ask about agency budgets or reports.")
    print("Say 'let's talk' to switch to chatbot mode.")
    print("Say 'let's search' to return to extraction/search mode.")
    print("Say 'exit' or 'quit' to finish.\n")

    # Credentials must not be committed; pull them from the environment.
    # NOTE(review): any keys previously committed in this file should be
    # treated as compromised and rotated.
    credentials = {
        "BRAVE_API_KEY": os.environ.get("BRAVE_API_KEY", ""),
        "GOOGLE_API_KEY": os.environ.get("GOOGLE_API_KEY", ""),
        "GOOGLE_CSE_ID": os.environ.get("GOOGLE_CSE_ID", ""),
    }
    missing = [name for name, value in credentials.items() if not value]
    if missing:
        # Chatbot mode does not need these, so warn instead of aborting.
        print(
            "Warning: missing environment variable(s): "
            + ", ".join(missing)
            + ". Search mode may fail."
        )

    chatbot_mode = False

    while True:
        try:
            user_input = input("You > ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C at the prompt: leave quietly.
            print("\n๐Ÿ‘‹ Goodbye!")
            break

        if not user_input:
            print("โš ๏ธ Please enter a valid question.")
            continue

        lowered = user_input.lower()
        if lowered in ("exit", "quit"):
            print("๐Ÿ‘‹ Goodbye!")
            break
        if lowered == "let's talk":
            chatbot_mode = True
            print("๐Ÿ—ฃ๏ธ Switched to chatbot mode.")
            continue
        if lowered == "let's search":
            chatbot_mode = False
            print("๐Ÿ” Switched to extraction/search mode.")
            continue

        # Top-level boundary of the REPL: report any failure from the API
        # call or the search and keep the session alive.
        try:
            if chatbot_mode:
                response = ask_openai(user_input, chatbot_mode=True)
                print("\n๐Ÿ’ฌ Chatbot Response:\n" + response + "\n")
                continue

            response = ask_openai(user_input, chatbot_mode=False)
            print("\n๐Ÿง  LLM Response:\n" + response + "\n")

            parsed = extract_fields(response)
            agency, keyword, year = parsed["agency"], parsed["keyword"], parsed["year"]
            print(f"๐Ÿงพ Parsed โ†’ Agency: {agency} | Keyword: {keyword} | Year: {year}")

            index, resolved_agency = resolve_agency_index(agency)
            if index is None:
                print(f"โš ๏ธ Could not resolve agency: {agency}")
                continue

            print(f"๐Ÿš€ Launching search for '{resolved_agency}' (index {index}) with keyword '{keyword}' and FY {year}\n")

            # start_index == end_index restricts the run to the one
            # resolved agency row.
            run_report_classifier(
                agency_df=df,
                search_term=keyword,
                fiscal_year=year if year else "",
                start_index=index,
                end_index=index,
                max_results=15,
                output_filename="openAI_bot_output.csv",
                brave_api_key=credentials["BRAVE_API_KEY"],
                google_api_key=credentials["GOOGLE_API_KEY"],
                google_cse_id=credentials["GOOGLE_CSE_ID"],
            )

        except Exception as e:
            print(f"โŒ Error: {e}")


# Start the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()