EchaRz committed on
Commit
56783f6
·
1 Parent(s): 20ba4c8

Rename backend.py to app.py

Browse files
Files changed (1) hide show
  1. app.py +596 -0
app.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import contextlib
import json
import os
import sqlite3
import traceback
from collections import defaultdict
from http.server import BaseHTTPRequestHandler
from typing import Dict, List, Optional, Tuple
from urllib.parse import parse_qs, urlparse

import numpy as np
from fastapi import Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from pydantic import BaseModel, Field
from together import Together
16
+
17
# Application object; the route decorators further down register against it.
app = FastAPI(title="Knowledge Graph API")

# Enable CORS for frontend access
# NOTE(review): every origin, method and header is allowed while credentials
# are enabled as well -- confirm this is intended before exposing the API
# publicly, and consider listing the frontend origin(s) explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
27
+
28
# Database configuration - UPDATE THESE PATHS
# Central place for the SQLite file names and the table/column names used by
# the query helpers and endpoints in this module.
DATABASE_CONFIG = {
    # SQLite database files
    "triplets_db": "triplets_new.db",
    "definitions_db": "relations_new.db",
    "news_db": "cnnhealthnews2.db",
    # Table names
    "triplets_table": "triplets",
    "definitions_table": "relations",
    # Columns of the triplets / relations tables
    "head_column": "head_entity",
    "relation_column": "relation",
    "tail_column": "tail_entity",
    "definition_column": "definition",
    # Columns of the news table; these match the names selected by the news
    # preview query (link, title, content).
    "link_column": "link",
    "title_column": "title",
    "content_column": "content",
}
43
+
44
# One entity of the knowledge graph as returned to the client. The endpoints
# in this file fill both `id` and `label` with the entity name.
class GraphNode(BaseModel):
    id: str
    label: str
    type: str = "entity"  # defaults to "entity" when not supplied
48
+
49
# One directed edge of the knowledge graph: `source` --relation--> `target`.
class GraphEdge(BaseModel):
    source: str
    target: str
    relation: str
    # Human-readable meaning of `relation`; the endpoints in this file pass
    # "No definition available" when none is found.
    definition: Optional[str] = None
54
+
55
# Graph payload: the node list plus the edge list connecting those nodes.
class GraphData(BaseModel):
    nodes: List[GraphNode]
    edges: List[GraphEdge]
58
+
59
# A (head, relation, tail) triplet in response form.
class TripletData(BaseModel):
    head: str
    relation: str
    tail: str
63
+
64
# A relation name paired with its textual definition.
class RelationDefinition(BaseModel):
    relation: str
    definition: str
67
+
68
# Triplets together with the definitions of their relations.
# NOTE(review): not referenced by any endpoint visible in this file.
class RetrieveTripletsResponse(BaseModel):
    triplets: List[TripletData]
    relations: List[RelationDefinition]
71
+
72
# A news article attached to a query response.
class NewsItem(BaseModel):
    url: str
    content: str  # full article text
    preview: str  # shortened form of `content` for display
    title: str
77
+
78
# Request body of POST /api/query.
class QueryRequest(BaseModel):
    query: str  # the user's question as typed
80
+
81
# Response body of POST /api/query: the generated answer plus the supporting
# triplets, relation definitions, news articles and a graph of the triplets.
class QueryResponse(BaseModel):
    answer: str
    triplets: List[TripletData]
    relations: List[RelationDefinition]
    news_items: List[NewsItem]
    graph_data: GraphData
87
+
88
# Structured-output shape carrying extracted text plus citation links.
# NOTE(review): not referenced anywhere in the visible code -- confirm whether
# it is still needed.
class ExtractedInformationNews(BaseModel):
    extracted_information: str = Field(description="Extracted information")
    links: list = Field(description="citation links")
91
+
92
# Structured-output shape for the chat-completion calls: its JSON schema is
# passed as `response_format` and the reply is parsed for this single key.
class ExtractedInformation(BaseModel):
    extracted_information: str = Field(description="Extracted information")
94
+
95
@contextlib.contextmanager
def get_triplets_db():
    """Yield a connection to the triplets database and close it on exit.

    The connection is closed whether the ``with`` body finishes normally or
    raises. A failure to connect propagates to the caller.
    """
    connection = sqlite3.connect(DATABASE_CONFIG["triplets_db"])
    try:
        yield connection
    finally:
        connection.close()
104
+
105
@contextlib.contextmanager
def get_news_db():
    """Yield a connection to the news database and close it on exit.

    The connection is closed whether the ``with`` body finishes normally or
    raises. A failure to connect propagates to the caller.
    """
    news_connection = sqlite3.connect(DATABASE_CONFIG["news_db"])
    try:
        yield news_connection
    finally:
        news_connection.close()
114
+
115
@contextlib.contextmanager
def get_definitions_db():
    """Yield a connection to the relation-definitions database and close it on exit.

    The connection is opened with ``sqlite3.connect`` (the previously called
    ``safe_connect`` is not defined anywhere in this module) and is closed
    whether the ``with`` body finishes normally or raises.
    """
    conn = None
    try:
        conn = sqlite3.connect(DATABASE_CONFIG["definitions_db"])
        yield conn
    finally:
        if conn is not None:
            conn.close()
124
+
125
def retrieve_triplets(query: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
    """Retrieve triplets whose head entity matches the query, plus relation definitions.

    The query is embedded through the Together embeddings API and searched
    against the local ``triplets_index_compressed`` FAISS index. Every hit
    with a score above 0.7 is used as a head entity to look up in the triplets
    database; the relations of the rows found are then looked up in the
    definitions database.

    Args:
        query (str): User query

    Returns:
        Tuple containing:
        - List of triplets: [(head, relation, tail), ...]
        - List of relations with definitions: [(relation, definition), ...]
        Both lists are empty when the database lookups fail.
    """
    client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))

    # load_local needs an embeddings object, but the query vector is computed
    # separately below, so a placeholder is passed.
    # NOTE(review): allow_dangerous_deserialization=True unpickles data stored
    # with the index -- only load index files from a trusted source.
    placeholder_embeddings = FakeEmbeddings(size=768)
    triplets_store = FAISS.load_local(
        "triplets_index_compressed", placeholder_embeddings, allow_dangerous_deserialization=True
    )
    triplets_store.index.nprobe = 100
    triplets_store._normalize_L2 = True
    triplets_store.distance_strategy = DistanceStrategy.COSINE

    response = client.embeddings.create(
        model="Alibaba-NLP/gte-modernbert-base",
        input=query,
    )

    # Scale the query vector to unit length before searching.
    emb = np.array(response.data[0].embedding)
    emb = emb / np.linalg.norm(emb)

    hits = triplets_store.similarity_search_with_score_by_vector(emb, k=100)
    # Several hits can carry the same head entity; keep each one once, in hit
    # order, so the same rows are not fetched and returned repeatedly.
    head_entities = list(dict.fromkeys(doc.page_content for doc, score in hits if score > 0.7))

    try:
        head_col = DATABASE_CONFIG["head_column"]
        rel_col = DATABASE_CONFIG["relation_column"]
        tail_col = DATABASE_CONFIG["tail_column"]
        def_col = DATABASE_CONFIG["definition_column"]

        # Name the columns explicitly instead of relying on the physical column
        # order of the tables. Identifiers come from DATABASE_CONFIG (not from
        # user input); values are always bound as parameters.
        triplet_sql = (
            f"SELECT {head_col}, {rel_col}, {tail_col} "
            f"FROM {DATABASE_CONFIG['triplets_table']} WHERE {head_col} = ?"
        )
        definition_sql = (
            f"SELECT {rel_col}, {def_col} "
            f"FROM {DATABASE_CONFIG['definitions_table']} WHERE {rel_col} = ?"
        )

        all_triplets = []
        with get_triplets_db() as conn:
            cursor = conn.cursor()
            for head_entity in head_entities:
                cursor.execute(triplet_sql, (head_entity,))
                all_triplets.extend(
                    (str(row[0]), str(row[1]), str(row[2])) for row in cursor.fetchall()
                )

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # relation list is stable from run to run (a set is not).
        unique_relations = dict.fromkeys(relation for _, relation, _ in all_triplets)

        all_relations = []
        with get_definitions_db() as conn:
            cursor = conn.cursor()
            for rel in unique_relations:
                cursor.execute(definition_sql, (rel,))
                all_relations.extend((str(row[0]), str(row[1])) for row in cursor.fetchall())

        return all_triplets, all_relations

    except Exception as e:
        # Best effort: a failed lookup yields an empty result rather than
        # failing the whole request.
        print(f"Error in retrieve_triplets: {e}")
        return [], []
193
+
194
def retrieve_news(query: str) -> Tuple[List[str], List[str]]:
    """Retrieve news passages related to the query, grouped per article link.

    The query is embedded through the Together embeddings API and searched
    against the local ``news_index_compressed`` FAISS index. Hits with a score
    above 0.7 are grouped by the ``link`` stored in their metadata.

    Args:
        query (str): User query

    Returns:
        Tuple containing:
        - Related content: one string per article link, made of that
          article's matching passages joined with ". "
        - Links of the related content: each link once, in the same order as
          the content list
    """
    client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))

    # load_local needs an embeddings object, but the query vector is computed
    # separately below, so a placeholder is passed.
    # NOTE(review): allow_dangerous_deserialization=True unpickles data stored
    # with the index -- only load index files from a trusted source.
    placeholder_embeddings = FakeEmbeddings(size=768)
    news_store = FAISS.load_local(
        "news_index_compressed", placeholder_embeddings, allow_dangerous_deserialization=True
    )
    news_store.index.nprobe = 100
    news_store._normalize_L2 = True
    news_store.distance_strategy = DistanceStrategy.COSINE

    response = client.embeddings.create(
        model="Alibaba-NLP/gte-modernbert-base",
        input=query,
    )

    # Scale the query vector to unit length before searching.
    emb = np.array(response.data[0].embedding)
    emb = emb / np.linalg.norm(emb)

    # Group matching passages by article. The dict keeps first-seen order, so
    # its keys and values line up index by index.
    passages_by_link = defaultdict(list)
    for doc, score in news_store.similarity_search_with_score_by_vector(emb, k=500):
        if score > 0.7:
            passages_by_link[doc.metadata["link"]].append(doc.page_content)

    content_only = [". ".join(passages) for passages in passages_by_link.values()]
    # Return every link once: the previous per-passage list repeated a link for
    # each matching passage, which duplicated articles downstream.
    links = list(passages_by_link)

    return content_only, links
240
+
241
+
242
def extract_information_from_triplets(query: str,
                                      triplets: List[Tuple[str, str, str]],
                                      relations: List[Tuple[str, str]]) -> str:
    """Ask the chat model to answer the query from the retrieved triplets.

    Args:
        query: Question to answer
        triplets: List of triplets from retrieve_triplets
        relations: List of relation definitions from retrieve_triplets

    Returns:
        str: Message content of the first completion choice
    """
    instructions = '''Given a a list of relational triplets and a list of relation and its definition. Extract the information from the triplets to answer query question.
If there is no related or useful information can be extracted from the triplets to answer the query question, inform "No related information found."
Give the output in paragraphs form narratively, you can explain the reason behind your answer in detail."
'''

    question_block = f'''
query question: {query}
list of triplets: {triplets}
list of relations and their definition: {relations}
extracted information:
'''

    # System message carries the task description, user message the data.
    messages = [
        {"role": "system", "content": [{"type": "text", "text": instructions}]},
        {"role": "user", "content": [{"type": "text", "text": question_block}]},
    ]

    together_client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))
    completion = together_client.chat.completions.create(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        temperature=0,
        messages=messages,
    )
    return completion.choices[0].message.content
288
+
289
def extract_information_from_news(query: str,
                                  news_list: List[str]) -> str:
    """Ask the chat model to answer the query from the retrieved news passages.

    Args:
        query: Question to answer
        news_list: Content list from retrieve_news

    Returns:
        Extracted information string (the ``extracted_information`` field of
        the model's JSON reply)

    Raises:
        json.JSONDecodeError: If the reply is not valid JSON.
        KeyError: If the reply lacks the ``extracted_information`` key.
    """
    system_prompt = f'''Given a list of some information related to the query, extract all important information from the list to answer query question.
Every item in the list represent one information, if the information is ambiguous (e.g. contains unknown pronoun to which it refers), do not use that information to answer the query.
You don't have to use all the information, only use the information that has clarity and a good basis, but try to use as many information as possible.
If there is no related or useful information can be extracted from the news information to answer the query question, write "No related information found." as the extracted_information output.
Give the extracted_information output in paragraphs form detailedly.
The output must be in this form: {{"extracted_information": <output paragraphs>}}
'''

    user_prompt = f'''
query: {query}
news list: {news_list}
output:
'''

    # The client must be created here: no `client` name exists at module
    # level, so the call below previously raised NameError.
    client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))

    response = client.chat.completions.create(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        # Constrain the reply to the ExtractedInformation JSON shape.
        response_format={
            "type": "json_schema",
            "schema": ExtractedInformation.model_json_schema(),
        },
        temperature=0,
        messages=[{
            "role": "system",
            "content": [
                {"type": "text", "text": system_prompt}
            ]
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": user_prompt},
            ]
        }]
    )
    payload = json.loads(response.choices[0].message.content)
    return payload['extracted_information']
336
+
337
def extract_information(query: str, triplet_info: str, news_info: str, language: str) -> str:
    """Combine the triplet-based and news-based findings into the final answer.

    Args:
        query: Question to answer
        triplet_info: Information extracted from triplets
        news_info: Information extracted from news
        language: Language the answer should be written in

    Returns:
        str: Final answer for the user (the ``extracted_information`` field of
        the model's JSON reply)

    Raises:
        json.JSONDecodeError: If the reply is not valid JSON.
        KeyError: If the reply lacks the ``extracted_information`` key.
    """
    # Read the key here: `API_KEY` is only ever a local of other functions, so
    # referencing it in this scope previously raised NameError.
    client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))
    system_prompt = f'''Given information from two sources, combine the information and make a comprehensive and informative paragraph that answer the query.
Make sure the output paragraph includes all crucial information and given in detail.
If there is no related or useful information can be extracted from the triplets to answer the query question, inform "No related information found."
Remember this paragraph will be shown to user, so make sure it is based on facts and data, also use appropriate language.
The output must be in this form and in {language} language: {{"extracted_information": <output paragraphs>}}
'''

    user_prompt = f'''
query: {query}
first source: {triplet_info}
second source: {news_info}
extracted information:
'''

    response = client.chat.completions.create(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        # Constrain the reply to the ExtractedInformation JSON shape.
        response_format={
            "type": "json_schema",
            "schema": ExtractedInformation.model_json_schema(),
        },
        temperature=0,
        messages=[{
            "role": "system",
            "content": [
                {"type": "text", "text": system_prompt}
            ]
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": user_prompt},
            ]
        }]
    )

    payload = json.loads(response.choices[0].message.content)
    return payload["extracted_information"]
385
+
386
def news_preview(links: list[str]) -> List[Tuple[str, str, str]]:
    """Fetch link, title and content of the news articles behind the given links.

    Args:
        links: Article links to look up; repeated links are queried once.

    Returns:
        List of ``(link, title, content)`` tuples, one per matching row. The
        list is empty when there are no links or when the lookup fails, so
        callers can always iterate and unpack three values per item.
    """
    # Query each link once, keeping first-seen order.
    unique_links = list(dict.fromkeys(links or []))
    if not unique_links:
        return []

    try:
        preview_contents = []
        with get_news_db() as conn:
            cursor = conn.cursor()
            for link in unique_links:
                cursor.execute(
                    "SELECT link, title, content FROM CNNHEALTHNEWS2 WHERE link = ?",
                    (link,),
                )
                preview_contents.extend(
                    (str(row[0]), str(row[1]), str(row[2])) for row in cursor.fetchall()
                )

        return preview_contents

    except Exception as e:
        # Best effort: report the problem and return no previews. An empty
        # list (not a tuple of empty strings) keeps the return shape uniform.
        print(f"Error in news_preview: {e}")
        return []
402
+
403
# Structured-output shape for query_language: its JSON schema is passed as
# `response_format` to the chat-completion call.
class Language(BaseModel):
    query: str = Field(description="Translated query")
    language: str = Field(description="Query's language")
406
+
407
def query_language(query: str) -> Dict[str, str]:
    """Detect the language of the query and translate it to English.

    Args:
        query: User query in any language

    Returns:
        The model's JSON reply parsed into a dict; the reply is requested in
        the ``Language`` shape (keys ``query`` and ``language``).

    Raises:
        json.JSONDecodeError: If the reply is not valid JSON.
    """
    system_prompt = f'''Your task is to determine what language the question is written in and translate it to english if it is not in English.
The output must be in this form: {{query: <translated query>, language: <query's language>}}
'''

    user_prompt = f'''
query: {query}
output:
'''
    # Read the key here: `API_KEY` is only ever a local of other functions, so
    # referencing it in this scope previously raised NameError.
    client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))

    response = client.chat.completions.create(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        # Constrain the reply to the Language JSON shape.
        response_format={
            "type": "json_schema",
            "schema": Language.model_json_schema(),
        },
        temperature=0,
        messages=[{
            "role": "system",
            "content": [
                {"type": "text", "text": system_prompt}
            ]
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": user_prompt},
            ]
        }])

    return json.loads(response.choices[0].message.content)
439
+
440
+ #API ENDPOINTS
441
+
442
@app.post("/api/query", response_model=QueryResponse)
def process_query(request: QueryRequest):
    """Process user query and return comprehensive response.

    Pipeline: detect the query language and translate to English, retrieve
    triplets and news, extract information from each source, merge both into
    the final answer (written in the query's language), then assemble the
    supporting triplets, relation definitions, news items and graph data.

    Raises:
        HTTPException: Status 500 when any step of the pipeline fails.
    """
    try:
        # Step 0: Detect the language and get the English form of the query
        language_info = query_language(request.query)
        english_query = language_info['query']

        # Step 1: Retrieve triplets
        triplets_data, relations_data = retrieve_triplets(english_query)

        # Step 2: Retrieve news
        news_list, news_links = retrieve_news(english_query)

        # Step 3: Extract information from triplets
        triplet_info = extract_information_from_triplets(english_query, triplets_data, relations_data)

        # Step 4: Extract information from news
        news_info = extract_information_from_news(english_query, news_list)

        # Step 5: Generate final answer
        final_answer = extract_information(english_query, triplet_info, news_info, language_info['language'])

        # Convert triplets to response format
        triplets = [TripletData(head=t[0], relation=t[1], tail=t[2]) for t in triplets_data]
        relations = [RelationDefinition(relation=r[0], definition=r[1]) for r in relations_data]

        # Convert news to response format with previews
        news_items = []
        seen_urls = set()
        for row in news_preview(news_links):
            # Skip anything that is not a (url, title, content) row so a
            # malformed preview result cannot break the whole response.
            if len(row) != 3:
                continue
            url, title, content = row
            # Show each article once even if its link was matched repeatedly.
            if url in seen_urls:
                continue
            seen_urls.add(url)
            preview = content[:300] + "..." if len(content) > 300 else content
            news_items.append(NewsItem(
                url=url,
                content=content,
                preview=preview,
                title=title
            ))

        # Index definitions by relation once instead of scanning the list for
        # every triplet; setdefault keeps the first definition per relation.
        definitions = {}
        for rel, def_text in relations_data:
            definitions.setdefault(rel, def_text)

        # Create mini graph data for visualization. A dict is used as an
        # ordered set so nodes come out in first-seen order.
        node_names = {}
        edges = []
        for head, relation, tail in triplets_data:
            node_names[head] = None
            node_names[tail] = None
            edges.append(GraphEdge(
                source=head,
                target=tail,
                relation=relation,
                definition=definitions.get(relation, "No definition available")
            ))

        nodes = [GraphNode(id=node, label=node) for node in node_names]
        graph_data = GraphData(nodes=nodes, edges=edges)

        return QueryResponse(
            answer=final_answer,
            triplets=triplets,
            relations=relations,
            news_items=news_items,
            graph_data=graph_data
        )

    except Exception as e:
        print(f"Error in process_query: {e}")
        raise HTTPException(status_code=500, detail=f"Query processing failed: {str(e)}")
517
+
518
@app.get("/api/graph", response_model=GraphData)
def get_graph_data(
    search: Optional[str] = None,
):
    """Get complete graph data with nodes and edges.

    Connections are opened inside the handler through the module's context
    managers. The former ``Depends(get_triplets_connection)`` and
    ``Depends(get_definitions_connection)`` parameters referenced functions
    that are not defined in this module, and the injected definitions
    connection was never used.

    Args:
        search: Optional substring; when given, only triplets whose head,
            tail or relation contains it are returned.

    Returns:
        GraphData built from at most 1000 triplets.

    Raises:
        HTTPException: Status 500 when a database query fails.
    """

    try:
        # Build dynamic query based on configuration. Identifiers come from
        # DATABASE_CONFIG; the user-supplied search term is bound as a parameter.
        table = DATABASE_CONFIG["triplets_table"]
        head_col = DATABASE_CONFIG["head_column"]
        rel_col = DATABASE_CONFIG["relation_column"]
        tail_col = DATABASE_CONFIG["tail_column"]

        base_query = f"SELECT {head_col}, {rel_col}, {tail_col} FROM {table}"
        params = []

        if search:
            base_query += f" WHERE {head_col} LIKE ? OR {tail_col} LIKE ? OR {rel_col} LIKE ?"
            search_term = f"%{search}%"
            params = [search_term, search_term, search_term]

        # Cap the result size
        base_query += " LIMIT 1000"

        # Get triplets
        with get_triplets_db() as triplets_conn:
            triplets = triplets_conn.execute(base_query, params).fetchall()

        # Get definitions
        def_table = DATABASE_CONFIG["definitions_table"]
        def_col = DATABASE_CONFIG["definition_column"]
        with get_definitions_db() as definitions_conn:
            def_cursor = definitions_conn.execute(f"SELECT {rel_col}, {def_col} FROM {def_table}")
            definitions = {row[0]: row[1] for row in def_cursor.fetchall()}

        # Build nodes and edges
        nodes_set = set()
        edges = []

        for head, relation, tail in triplets:
            # Add entities to nodes set
            nodes_set.add(head)
            nodes_set.add(tail)

            # Create edge with definition
            edges.append(GraphEdge(
                source=head,
                target=tail,
                relation=relation,
                definition=definitions.get(relation, "No definition available")
            ))

        # Convert nodes set to list of GraphNode objects
        nodes = [GraphNode(id=node, label=node) for node in nodes_set]

        return GraphData(nodes=nodes, edges=edges)

    except Exception as e:
        print(f"Error in get_graph_data: {e}")
        raise HTTPException(status_code=500, detail=f"Database query failed: {str(e)}")
586
+
587
# Script entry point: print a short startup banner, then serve the app with
# uvicorn on all interfaces. The port comes from $PORT, falling back to 8000.
if __name__ == "__main__":
    print("Starting Knowledge Graph API...")
    triplets_path = DATABASE_CONFIG['triplets_db']
    print(f"Triplets DB: {triplets_path}")
    definitions_path = DATABASE_CONFIG['definitions_db']
    print(f"Definitions DB: {definitions_path}")

    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 8000)))
595
+
596
+