Asaad Almutareb committed
Commit 2e6490e
1 Parent(s): 5c0a79d

Added SQLite schema and handling

innovation_pathfinder_ai/database/db_handler.py ADDED
@@ -0,0 +1,109 @@
+from sqlmodel import SQLModel, create_engine, Session, select
+
+from innovation_pathfinder_ai.database.schema import Sources
+from innovation_pathfinder_ai.utils.logger import get_console_logger
+
+sqlite_file_name = "database.db"
+sqlite_url = f"sqlite:///{sqlite_file_name}"
+engine = create_engine(sqlite_url, echo=False)
+
+logger = get_console_logger("db_handler")
+
+SQLModel.metadata.create_all(engine)
+
+
+def read_one(hash_id: str):
+    with Session(engine) as session:
+        statement = select(Sources).where(Sources.hash_id == hash_id)
+        sources = session.exec(statement).first()
+        return sources
+
+
+def add_one(data: dict):
+    with Session(engine) as session:
+        if session.exec(
+            select(Sources).where(Sources.hash_id == data.get("hash_id"))
+        ).first():
+            logger.warning(f"Item with hash_id {data.get('hash_id')} already exists")
+            return None  # or raise an exception, or handle as needed
+        sources = Sources(**data)
+        session.add(sources)
+        session.commit()
+        session.refresh(sources)
+        logger.info(f"Item with hash_id {data.get('hash_id')} added to the database")
+        return sources
+
+
+def update_one(hash_id: str, data: dict):
+    with Session(engine) as session:
+        # Check if the item with the given hash_id exists
+        sources = session.exec(
+            select(Sources).where(Sources.hash_id == hash_id)
+        ).first()
+        if not sources:
+            logger.warning(f"No item with hash_id {hash_id} found for update")
+            return None  # or raise an exception, or handle as needed
+        for key, value in data.items():
+            setattr(sources, key, value)
+        session.commit()
+        logger.info(f"Item with hash_id {hash_id} updated in the database")
+        return sources
+
+
+def delete_one(hash_id: str):
+    with Session(engine) as session:
+        # Check if the item with the given hash_id exists
+        sources = session.exec(
+            select(Sources).where(Sources.hash_id == hash_id)
+        ).first()
+        if not sources:
+            logger.warning(f"No item with hash_id {hash_id} found for deletion")
+            return None  # or raise an exception, or handle as needed
+        session.delete(sources)
+        session.commit()
+        logger.info(f"Item with hash_id {hash_id} deleted from the database")
+        return sources  # returned so callers can distinguish success from a miss
+
+
+def add_many(data: list):
+    # add_one opens its own session and commits, so no outer session is needed
+    for info in data:
+        result = add_one(info)
+        if result is None:
+            logger.warning(
+                f"Item with hash_id {info.get('hash_id')} could not be added"
+            )
+
+
+def delete_many(hash_ids: list):
+    # delete_one opens its own session, commits, and logs, so no outer session is needed
+    for hash_id in hash_ids:
+        delete_one(hash_id)
+
+
+def read_all(query: dict = None):
+    with Session(engine) as session:
+        statement = select(Sources)
+        if query:
+            statement = statement.where(
+                *[getattr(Sources, key) == value for key, value in query.items()]
+            )
+        sources = session.exec(statement).all()
+        return sources
+
+
+def delete_all():
+    with Session(engine) as session:
+        # session.exec(Sources).delete() is not valid SQLModel; delete row by row
+        for sources in session.exec(select(Sources)).all():
+            session.delete(sources)
+        session.commit()
+        logger.info("All items deleted from the database")
innovation_pathfinder_ai/database/schema.py ADDED
@@ -0,0 +1,15 @@
+from sqlmodel import SQLModel, Field
+from typing import Optional
+
+import datetime
+
+
+class Sources(SQLModel, table=True):
+    id: Optional[int] = Field(default=None, primary_key=True)
+    url: str = Field()
+    title: Optional[str] = Field(default="NA", unique=False)
+    hash_id: str = Field(unique=True)
+    # default_factory evaluates at row-creation time; a plain default would
+    # freeze the timestamp at module-import time
+    created_at: float = Field(default_factory=lambda: datetime.datetime.now().timestamp())
+    summary: str = Field(default="")
+    embedded: bool = Field(default=False)
+
+    __table_args__ = {"extend_existing": True}
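
The schema leaves hash_id generation to the caller. One plausible scheme, shown only as an assumption (this commit does not specify it), is deriving the key from the source URL so duplicates collide on the unique constraint:

    import hashlib

    from innovation_pathfinder_ai.database.schema import Sources

    url = "https://en.wikipedia.org/wiki/Example"  # hypothetical
    # assumption: hash_id as an MD5 of the URL makes re-added sources detectable
    row = Sources(url=url, title="Example", hash_id=hashlib.md5(url.encode()).hexdigest())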
innovation_pathfinder_ai/structured_tools/structured_tools.py CHANGED
@@ -6,31 +6,32 @@ from langchain_community.utilities import WikipediaAPIWrapper
 #from langchain.tools import Tool
 from langchain_community.utilities import GoogleSearchAPIWrapper
 import arxiv
-
+import ast
 # hacky and should be replaced with a database
 from innovation_pathfinder_ai.source_container.container import (
     all_sources
 )
-from innovation_pathfinder_ai.utils import create_wikipedia_urls_from_text
+from innovation_pathfinder_ai.utils.utils import (
+    parse_list_to_dicts, format_wiki_summaries, format_arxiv_documents, format_search_results
+)
+from innovation_pathfinder_ai.database.db_handler import (
+    add_many
+)
 
 @tool
 def arxiv_search(query: str) -> str:
     """Search arxiv database for scientific research papers and studies. This is your primary information source.
     always check it first when you search for information, before using any other tool."""
-    # return "LangChain"
     global all_sources
-    arxiv_retriever = ArxivRetriever(load_max_docs=2)
+    arxiv_retriever = ArxivRetriever(load_max_docs=3)
     data = arxiv_retriever.invoke(query)
     meta_data = [i.metadata for i in data]
-    # meta_data += all_sources
-    # all_sources += meta_data
-    all_sources += meta_data
-
-    # formatted_info = format_info(entry_id, published, title, authors)
-
-    # formatted_info = format_info_list(all_sources)
-
-    return meta_data.__str__()
+    formatted_sources = format_arxiv_documents(data)
+    all_sources += formatted_sources
+    parsed_sources = parse_list_to_dicts(formatted_sources)
+    add_many(parsed_sources)
+
+    return data.__str__()
 
 @tool
 def get_arxiv_paper(paper_id:str) -> None:
@@ -52,17 +53,13 @@ def get_arxiv_paper(paper_id:str) -> None:
 @tool
 def google_search(query: str) -> str:
     """Search Google for additional results when you can't answer questions using arxiv search or wikipedia search."""
-    # return "LangChain"
     global all_sources
 
     websearch = GoogleSearchAPIWrapper()
-    search_results:dict = websearch.results(query, 5)
-
-
-    #organic_source = search_results['organic_results']
-    # formatted_string = "Title: {title}, link: {link}, snippet: {snippet}".format(**organic_source)
-    cleaner_sources = ["Title: {title}, link: {link}, snippet: {snippet}".format(**i) for i in search_results]
-
+    search_results: dict = websearch.results(query, 3)
+    cleaner_sources = format_search_results(search_results)
+    parsed_csources = parse_list_to_dicts(cleaner_sources)
+    add_many(parsed_csources)
     all_sources += cleaner_sources
 
     return cleaner_sources.__str__()
@@ -75,5 +72,9 @@ def wikipedia_search(query: str) -> str:
     api_wrapper = WikipediaAPIWrapper()
     wikipedia_search = WikipediaQueryRun(api_wrapper=api_wrapper)
     wikipedia_results = wikipedia_search.run(query)
-    all_sources += create_wikipedia_urls_from_text(wikipedia_results)
-    return wikipedia_results
+    formatted_summaries = format_wiki_summaries(wikipedia_results)
+    all_sources += formatted_summaries
+    parsed_summaries = parse_list_to_dicts(formatted_summaries)
+    add_many(parsed_summaries)
+
+    return wikipedia_results.__str__()
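
The tools now persist their sources through add_many, which implies parse_list_to_dicts returns dicts whose keys line up with the Sources schema. A sketch of the assumed shape (the real helper lives in innovation_pathfinder_ai.utils.utils and is not shown in this commit):

    # assumption: one dict per formatted source, keyed to the Sources columns
    parsed = [
        {
            "url": "https://arxiv.org/abs/1234.56789",  # hypothetical
            "title": "Example paper",
            "hash_id": "abc123",  # hypothetical unique key
        },
    ]
    add_many(parsed)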