MrAl3x0 committed
Commit 3712c2c · 1 Parent(s): 6d05e12

style: make entire codebase PEP 8 compliant using Ruff
.gitignore CHANGED
@@ -1,40 +1,22 @@
-# Python
 __pycache__/
 *.pyc
+*.so
 *.egg-info/
 .venv/
 .pytest_cache/
 .mypy_cache/
-*.so
 build/
 dist/
 htmlcov/
 .coverage.*
 .nox/
 .tox/
-pip-log.txt
-pip-delete-this-directory.txt
 .ipynb_checkpoints/
-
-# Environment
 .env
-
-# VSCode
-.vscode/*.code-workspace
-.vscode/*.log
-.vscode/*.vsix
-.vscode/*.bak
-.vscode/.history/
-.vscode/extensions.json
-.vscode/launch.json
-.vscode/settings.json
-
-# Operating System Files
+.vscode/
 .DS_Store
 Thumbs.db
 desktop.ini
-
-# Logs and temporary files
 *.log
 *.tmp
 *.bak
README.md CHANGED
@@ -1,136 +1,134 @@
-# LexAI Demo
+# LexAI
 
 ## AI-Powered Legal Research Assistant
 
-This repository hosts a demonstration of **LexAI**, an AI-powered legal research assistant designed to provide relevant legal information based on user queries and specified locations. This project serves as a proof of concept, showcasing the integration of large language models (LLMs) with local embedding data for specialized information retrieval.
+LexAI is an AI assistant that delivers jurisdiction-specific legal information by integrating OpenAI's language models with local vector embeddings. The system uses semantic search to surface relevant legal references and provides a web interface for users to query the model interactively.
 
-### Features
-
-![LexAI Demo Screenshot](assets/screenshot.png)
-
-- **AI-Powered Responses**: Utilizes OpenAI's GPT-4 model to generate natural language responses to legal queries.
-- **Location-Specific Information**: Provides legal information tailored to specific jurisdictions (currently Boulder County, Colorado, and Denver, Colorado).
-- **Semantic Search**: Employs embeddings and vector similarity search to find the most relevant legal documents.
-- **Interactive Web Interface**: Built with Gradio for an easy-to-use, browser-based demonstration.
+![LexAI Screenshot](assets/screenshot.png)
+
+---
+
+## Features
+
+- **GPT-4 Integration**: Uses OpenAI's GPT-4 to generate concise, relevant legal responses.
+- **Jurisdiction-Specific Search**: Preloaded embeddings for Boulder County and Denver, Colorado.
+- **Semantic Search Engine**: Uses cosine similarity for embedding-based document retrieval.
+- **Modern Web Interface**: Built with Gradio for real-time interaction.
+- **Modular Design**: Separation of logic for UI, inference, and API handling.
+- **Fully Tested**: Includes unit tests for embedding loading, matching logic, and OpenAI API integration.
 
 ---
 
-### Getting Started
-
-Follow these steps to set up and run the LexAI demo on your local machine.
-
-#### 1. Clone the Repository
+## Getting Started
+
+### 1. Clone the Repository
 
 ```bash
-git clone https://github.com/alexulanch/lexai-demo.git
-cd lexai-demo
+git clone https://github.com/alexulanch/lexai.git
+cd lexai
 ```
 
-This project uses [Git LFS](https://git-lfs.github.com/) to manage the embedding data.
-
-If you’re **not using the provided dev container**, install Git LFS before cloning:
+### 2. Install Git LFS (if needed)
+
+This project uses [Git LFS](https://git-lfs.github.com/) for storing large `.npz` embedding files.
 
 ```bash
 git lfs install
 git lfs pull
 ```
----
-
-#### 2. Install Dependencies
-
-Install the required Python packages using pip. The dependencies are: `pandas`, `numpy`, `openai`, `gradio`, `scipy`, and `python-dotenv`.
+
+### 3. Install Python Dependencies
 
 ```bash
 pip install -r requirements.txt
 ```
 
----
-
-#### 3. Configure Your OpenAI API Key
-
-This application relies on the OpenAI API. You will need an API key to access the models used for embeddings and chat completions (e.g., `text-embedding-ada-002`, `gpt-4`).
-
-**Using a `.env` file (recommended for local development):**
-
-1. Create a file named `.env` in the root directory of the project.
-2. Add your API key to the file like this:
-
-```dotenv
-OPENAI_API_KEY="your_openai_api_key_here"
-```
-
-3. Ensure `.env` is listed in `.gitignore` to avoid committing it by mistake.
+### 4. Configure OpenAI API Key and Embedding Paths
+
+Create a `.env` file in the root directory:
+
+```dotenv
+OPENAI_API_KEY=your_openai_api_key_here
+BOULDER_EMBEDDINGS_PATH=lexai/data/boulder_embeddings.npz
+DENVER_EMBEDDINGS_PATH=lexai/data/denver_embeddings.npz
+```
 
 ---
 
-#### 4. Run the Application
-
-Start the Gradio app:
+## Running the App
 
 ```bash
 python -m lexai
 ```
 
-You’ll see a local URL like `http://127.0.0.1:7860` — open it in your browser to use LexAI Demo.
+Then open `http://127.0.0.1:7860` in your browser.
 
 ---
 
-### Project Structure
+## Project Structure
 
 ```
-lexai-demo/
-├── .devcontainer/            # Dev container config for VS Code
-│   └── devcontainer.json
-├── lexai/                    # Main Python package
+.
+├── LICENSE
+├── README.md
+├── assets
+│   └── screenshot.png
+├── lexai
 │   ├── __init__.py
-│   ├── __main__.py           # Gradio app entry point
-│   ├── config.py             # Global app config and constants
-│   ├── core/                 # Core logic components
-│   │   ├── data_loader.py    # Loads embedding data
-│   │   └── matcher.py        # Semantic search logic
-│   └── services/             # External API integrations
-│       └── openai_client.py  # Interacts with OpenAI API
-├── pyproject.toml            # Project metadata and build config
-├── requirements.txt          # Python dependencies
-└── .gitignore                # Files/directories Git should ignore
+│   ├── __main__.py
+│   ├── config.py
+│   ├── core
+│   │   ├── __init__.py
+│   │   ├── data_loader.py
+│   │   ├── match_engine.py
+│   │   └── matcher.py
+│   ├── data
+│   │   ├── boulder_embeddings.npz
+│   │   └── denver_embeddings.npz
+│   ├── models
+│   │   └── embedding_model.py
+│   ├── services
+│   │   └── openai_client.py
+│   └── ui
+│       ├── __init__.py
+│       └── gradio_interface.py
+├── pyproject.toml
+├── pytest.ini
+├── requirements.txt
+└── tests
+    ├── __init__.py
+    ├── test_data_loader.py
+    ├── test_matcher.py
+    └── test_openai_client.py
 ```
 
 ---
 
-### Usage
-
-1. Enter your legal question in the "Query" textbox.
-2. Select the desired "Location" (Boulder or Denver) from the dropdown.
-3. Click "Submit" to get an AI-generated response and relevant legal references.
-4. Use the "Clear" button to reset.
-5. Explore the example queries provided.
-
----
-
-### Error Handling
-
-The app handles several common error cases:
-
-- `Invalid OpenAI API key...`: Check your `.env` file or environment variable setup.
-- `OpenAI API Error...`: Rate limits, network issues, etc.
-- `File Error...`: Missing or unreadable `.npz` embedding files.
-- `Input Error...`: Malformed or missing user input.
-
----
-
-### Contributing
-
-Contributions are welcome! Open an issue or pull request with ideas or fixes.
+## Testing
+
+LexAI includes a full suite of unit tests using `pytest`.
+
+To run the tests:
+
+```bash
+pytest
+```
+
+Tests are located in the `tests/` directory and cover:
+
+- Embedding data loading
+- Semantic similarity matching
+- OpenAI API interaction
 
 ---
 
-### License
+## License
 
-MIT
+MIT License
 
 ---
 
-### Acknowledgements
+## Acknowledgements
 
 - Built with [Gradio](https://gradio.app)
 - Powered by [OpenAI](https://openai.com)
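The semantic search the README describes ranks documents by cosine similarity between a query embedding and precomputed document embeddings. A minimal sketch of that ranking step, using tiny made-up 3-dimensional vectors in place of real OpenAI embeddings (which are much higher-dimensional):

```python
import numpy as np
from scipy.spatial.distance import cdist

# Toy stand-ins for real embedding vectors.
doc_embeddings = np.array([
    [1.0, 0.0, 0.0],   # doc 0
    [0.0, 1.0, 0.0],   # doc 1
    [0.9, 0.1, 0.0],   # doc 2: nearly parallel to doc 0
])
query = np.array([1.0, 0.05, 0.0])

# Cosine distance = 1 - cosine similarity; smaller means more similar.
distances = cdist(query.reshape(1, -1), doc_embeddings, metric="cosine")[0]
top2 = np.argsort(distances)[:2]
# docs 0 and 2 point almost the same direction as the query, so they rank first
```

The same `cdist(..., metric="cosine")` + `argsort` idiom appears in `lexai/core/matcher.py`.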
assets/screenshot.png CHANGED

Git LFS Details (before)
  • SHA256: dd2bd4dcccaddbeffda844540cc35ad9775c36fabda0ed5afd5d0be0f856552a
  • Pointer size: 131 Bytes
  • Size of remote file: 177 kB

Git LFS Details (after)
  • SHA256: 78f6b31d42be479f16bf8cf654968120c769199bdc9f538dee9e74874126da5d
  • Pointer size: 131 Bytes
  • Size of remote file: 175 kB
lexai/__main__.py CHANGED
@@ -1,165 +1,16 @@
 import logging
-import os
-import openai
-import gradio as gr
-from dotenv import load_dotenv
 
-if not os.getenv("OPENAI_API_KEY"):
-    load_dotenv(override=True)
+from lexai.ui.gradio_interface import build_interface
 
-from lexai.config import (
-    LOCATION_INFO,
-    APP_DESCRIPTION,
-    AI_ROLE_TEMPLATE,
-)
-
-from lexai.core.data_loader import load_embeddings_data
-from lexai.core.matcher import find_top_matches
-from lexai.services.openai_client import get_embedding, get_chat_completion
-
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
-)
-
-
-def generate_matches(query: str, location: str) -> str:
-    """
-    Generate legal information matches based on the user's query and location.
-
-    This function orchestrates the process of generating legal information matches
-    using OpenAI's models and local embedding data. It returns an HTML response
-    containing the AI-generated response and references to relevant legal information.
-
-    Parameters
-    ----------
-    query : str
-        The user's query for legal information.
-    location : str
-        The location for which the user is seeking legal information.
-        Possible values are "Boulder" and "Denver".
-
-    Returns
-    -------
-    str
-        An HTML response containing the AI-generated response and references
-        to relevant legal information. In case of an error, an error message
-        is returned in HTML format.
-    """
-    try:
-        logging.info(f"Generating embedding for query: '{query}'")
-        query_embedding = get_embedding(query)
-
-        location_data = LOCATION_INFO.get(location)
-        if not location_data:
-            logging.error(f"No data found for location '{location}'.")
-            raise ValueError(
-                f"No data found for location '{location}'. Please select a valid location."
-            )
-
-        npz_file = location_data["npz_file"]
-        role_description_base = location_data["role_description"]
-
-        logging.info(f"Loading embeddings data from: {npz_file}")
-        embeddings, jurisdiction_data = load_embeddings_data(npz_file)
-
-        logging.info("Finding top matches...")
-        top_matches = find_top_matches(
-            query_embedding, embeddings, jurisdiction_data, num_matches=3
-        )
-
-        full_ai_role = f"{role_description_base}\n{AI_ROLE_TEMPLATE}"
-        top_matches_str = str(
-            top_matches
-        )
-
-        logging.info("Getting chat completion from OpenAI...")
-        ai_message = get_chat_completion(full_ai_role, top_matches_str, query)
-
-        html_response = "<p><strong>Response:</strong></p><p>" + ai_message + "</p>"
-        html_references = "<p><strong>References:</strong></p><ul>"
-        for match in top_matches:
-            url = match.get("url", "#")
-            title = match.get("title", "No Title")
-            subtitle = match.get("subtitle", "No Subtitle")
-            html_references += (
-                f'<li><a href="{url}" target="_blank">{title}: {subtitle}</a></li>'
-            )
-        html_references += "</ul>"
-
-        logging.info("Successfully generated response and references.")
-        return html_response + html_references
-
-    except openai.AuthenticationError:
-        logging.error("OpenAI Authentication Error: Invalid API key provided.")
-        return """<p style="font-family: Arial, sans-serif; font-size: 16px; color: #d9534f;">
-        <strong>Error:</strong> Invalid OpenAI API key. Please ensure your `OPENAI_API_KEY` environment variable is correctly set.
-        </p>"""
-    except openai.OpenAIError as e:
-        logging.error(f"OpenAI API Error: {e}")
-        return f"""<p style="font-family: Arial, sans-serif; font-size: 16px; color: #d9534f;">
-        <strong>OpenAI API Error:</strong> {str(e)}
-        </p>"""
-    except FileNotFoundError as e:
-        logging.error(f"File Not Found Error: {e}")
-        return f"""<p style="font-family: Arial, sans-serif; font-size: 16px; color: #d9534f;">
-        <strong>File Error:</strong> {str(e)} Please ensure embedding files are correctly placed.
-        </p>"""
-    except ValueError as e:
-        logging.error(f"Value Error: {e}")
-        return f"""<p style="font-family: Arial, sans-serif; font-size: 16px; color: #333;">
-        <strong>Input Error:</strong> {str(e)}
-        </p>"""
-    except Exception as e:
-        logging.exception(
-            "An unexpected error occurred during generate_matches.")
-        return f"""<p style="font-family: Arial, sans-serif; font-size: 16px; color: #d9534f;">
-        <strong>Notice:</strong> An unexpected error occurred while processing your request. Please see the details below:
-        <br>{str(e)}
-        </p>"""
-
-
-with gr.Blocks(title="LexAI") as iface:
-    gr.HTML("<h1 style='text-align: center;'>LexAI</h1>")
-    gr.Markdown(APP_DESCRIPTION)
-
-    with gr.Row():
-        with gr.Column(scale=2):
-            query_input = gr.Textbox(
-                label="Query", lines=3, placeholder="Enter your legal question here...")
-            location_input = gr.Dropdown(choices=list(
-                LOCATION_INFO.keys()), label="Location", value=list(LOCATION_INFO.keys())[0])
-            with gr.Row():
-                clear_btn = gr.Button("Clear", variant="secondary")
-                submit_btn = gr.Button("Submit", variant="primary")
-        with gr.Column(scale=3):
-            response_output = gr.HTML(
-                value="<p><strong>Response:</strong></p>",
-                show_label=False
-            )
-            gr.Button("Flag", variant="secondary")
-
-    def handle_submit(query, location):
-        return gr.update(value=generate_matches(query, location))
-
-    def handle_clear():
-        return gr.update(value="<p><strong>Response:</strong></p>")
-
-    submit_btn.click(fn=handle_submit, inputs=[
-        query_input, location_input], outputs=[response_output])
-    clear_btn.click(fn=handle_clear, outputs=[response_output])
-
-    gr.Examples(
-        examples=[
-            ["Is it legal for me to use rocks to construct a cairn in an outdoor area?", "Boulder"],
-            ["Is it legal to possess a dog and take ownership of it as a pet?", "Denver"],
-            ["Am I allowed to go shirtless in public spaces?", "Boulder"],
-            ["What is the maximum height I can legally build a structure?", "Denver"],
-            ["Is it legal to place indoor furniture on an outdoor porch?", "Boulder"],
-            ["Can I legally graze livestock like llamas on public land?", "Denver"],
-        ],
-        inputs=[query_input, location_input]
-    )
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+    logging.getLogger("httpx").setLevel(logging.WARNING)
+
+    logging.info("Launching LexAI...")
+    iface = build_interface()
+    iface.launch()
+
 
 if __name__ == "__main__":
-    logging.info("Starting LexAI Gradio application...")
-    iface.launch()
+    main()
 
lexai/config.py CHANGED
@@ -1,18 +1,22 @@
-MODEL_ENGINE = "text-embedding-ada-002"
+import os
+
+EMBEDDING_MODEL = "text-embedding-ada-002"
 
 LOCATION_INFO = {
     "Boulder": {
-        "npz_file": "lexai/data/boulder_embeddings.npz",
+        "npz_file": os.getenv(
+            "BOULDER_NPZ_FILE", "lexai/data/boulder_embeddings.npz"
+        ),
         "role_description": (
-            "You are an AI-powered legal assistant specializing in the jurisdiction of "
-            "Boulder County, Colorado."
+            "You are an AI-powered legal assistant specializing in the jurisdiction "
+            "of Boulder County, Colorado."
         ),
     },
     "Denver": {
-        "npz_file": "lexai/data/denver_embeddings.npz",
+        "npz_file": os.getenv("DENVER_NPZ_FILE", "lexai/data/denver_embeddings.npz"),
         "role_description": (
-            "You are an AI-powered legal assistant specializing in the jurisdiction of "
-            "Denver, Colorado."
+            "You are an AI-powered legal assistant specializing in the jurisdiction "
+            "of Denver, Colorado."
         ),
     },
 }
@@ -20,13 +24,12 @@ LOCATION_INFO = {
 APP_DESCRIPTION = (
     "LexAI is an AI-powered legal research app designed to assist individuals, "
     "including law enforcement officers, legal professionals, and the general public, "
-    "in accessing accurate legal information. The app covers various jurisdictions "
-    "and ensures that users can stay informed and confident, regardless of their location. "
-    "This demo is meant to serve as a proof of concept."
+    "in accessing jurisdiction-specific legal information. While LexAI aims to provide "
+    "useful and relevant results, it does not constitute legal advice. Its output may "
+    "not always be accurate or up to date. Users should verify information "
+    "independently and consult qualified legal professionals when needed."
 )
 
-OPENAI_API_KEY_PLACEHOLDER = "Enter your OpenAI API key"
-
 GPT4_MODEL = "gpt-4"
 GPT4_TEMPERATURE = 0.7
 GPT4_MAX_TOKENS = 120
@@ -34,9 +37,11 @@ GPT4_TOP_P = 1
 GPT4_FREQUENCY_PENALTY = 0
 GPT4_PRESENCE_PENALTY = 0
 
-AI_ROLE_TEMPLATE = """
-Your expertise lies in providing accurate and timely information on the laws and regulations specific to your jurisdiction.
-Your role is to assist individuals, including law enforcement officers, legal professionals, and the general public,
-in understanding and applying legal standards within this jurisdiction. You are knowledgeable, precise, and always
-ready to offer guidance on legal matters. Your max_tokens is set to 120 so keep your response below that.
-"""
+AI_ROLE_TEMPLATE = (
+    "Your expertise lies in providing accurate and timely information on the laws and "
+    "regulations specific to your jurisdiction. Your role is to assist individuals, "
+    "including law enforcement officers, legal professionals, and the general public. "
+    "You help them understand and apply legal standards within this jurisdiction. You "
+    "are knowledgeable, precise, and always ready to offer guidance on legal matters. "
+    "Your max_tokens is set to 120, so keep your response below that."
+)
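The refactored config reads each embedding path with `os.getenv(name, default)`, so a deployment can point at different `.npz` files without code changes. The same pattern in isolation (the path values here are illustrative):

```python
import os

DEFAULT_PATH = "lexai/data/boulder_embeddings.npz"

# Start clean so the default branch is exercised first.
os.environ.pop("BOULDER_NPZ_FILE", None)

# Falls back to the bundled file unless the environment overrides it.
assert os.getenv("BOULDER_NPZ_FILE", DEFAULT_PATH) == DEFAULT_PATH

# An override set before the module is imported wins over the default.
os.environ["BOULDER_NPZ_FILE"] = "/mnt/data/boulder.npz"
assert os.getenv("BOULDER_NPZ_FILE", DEFAULT_PATH) == "/mnt/data/boulder.npz"
```

Note that because the lookup happens at module import time, any `.env` loading has to occur before `lexai.config` is first imported.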
lexai/core/data_loader.py CHANGED
@@ -1,9 +1,10 @@
+import os
+
 import numpy as np
 import pandas as pd
-import os
 
 
-def load_embeddings_data(npz_file_path: str) -> tuple[np.ndarray, pd.DataFrame]:
+def load_embeddings(npz_file_path: str) -> tuple[np.ndarray, pd.DataFrame]:
     """
     Loads embeddings and associated jurisdiction data from a .npz file.
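`load_embeddings` reads a `.npz` archive holding the embedding matrix alongside per-row jurisdiction metadata. A self-contained sketch of the underlying NumPy round-trip (the array key `embeddings` is illustrative, not necessarily the key the real files use):

```python
import os
import tempfile

import numpy as np

embeddings = np.random.rand(4, 8).astype(np.float32)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo_embeddings.npz")
    # savez stores each keyword argument as a named array in one archive.
    np.savez(path, embeddings=embeddings)

    with np.load(path) as archive:
        loaded = archive["embeddings"]

assert loaded.shape == (4, 8)
assert np.allclose(loaded, embeddings)
```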
 
lexai/core/match_engine.py CHANGED
@@ -1,62 +1,87 @@
 import logging
-import os
+from html import escape
+
 import openai
-from dotenv import load_dotenv
 
-from lexai.config import LOCATION_INFO, AI_ROLE_TEMPLATE
+from lexai.config import AI_ROLE_TEMPLATE, LOCATION_INFO
 from lexai.core.data_loader import load_embeddings
 from lexai.core.matcher import find_top_matches
-from lexai.services.openai_client import get_embedding, get_chat_completion
+from lexai.services.openai_client import get_chat_completion, get_embedding
 
 logger = logging.getLogger(__name__)
 
-if not os.getenv("OPENAI_API_KEY"):
-    load_dotenv(override=True)
-
 
 def generate_matches(query: str, location: str) -> str:
-    try:
-        location_data = LOCATION_INFO.get(location)
-        if not location_data:
-            raise ValueError(f"Invalid location: '{location}'")
+    if location not in LOCATION_INFO:
+        logger.error(f"Invalid location: {location}")
+        return (
+            "<p><strong>Input Error:</strong> "
+            f"Invalid location: '{escape(location)}'</p>"
+        )
 
+    try:
         query_embedding = get_embedding(query)
-        embeddings, metadata_df = load_embeddings(location_data["npz_file"])
+        location_data = LOCATION_INFO[location]
+        embeddings, metadata = load_embeddings(location_data["npz_file"])
 
-        if embeddings.shape[0] != len(metadata_df):
+        if embeddings.shape[0] != len(metadata):
             raise ValueError(
-                "Mismatch between number of embeddings and metadata entries")
+                "Mismatch between number of embeddings and metadata entries"
+            )
+
+        top_matches = find_top_matches(query_embedding, embeddings, metadata)
 
-        top_matches = find_top_matches(
-            query_embedding, embeddings, metadata_df)
-        system_prompt = f"{location_data['role_description']}\n{AI_ROLE_TEMPLATE}"
+        system_prompt = (
+            f"{location_data['role_description']}\n{AI_ROLE_TEMPLATE}"
+        )
         ai_response = get_chat_completion(
             system_prompt, str(top_matches), query)
 
-        response_html = f"<p><strong>Response:</strong></p><p>{ai_response}</p>"
+        response_html = (
+            "<p><strong>Response:</strong></p>"
+            f"<p>{escape(ai_response)}</p>"
+        )
         reference_html = "<p><strong>References:</strong></p><ul>"
 
        for match in top_matches:
-            url = match.get("url", "#")
-            title = match.get("title", "Untitled")
-            subtitle = match.get("subtitle", "")
-            reference_html += f'<li><a href="{url}" target="_blank">{title}: {subtitle}</a></li>'
+            url = escape(match["url"])
+            title = escape(match["title"])
+            subtitle = escape(match["subtitle"])
+            reference_html += (
+                f'<li><a href="{url}" target="_blank">'
+                f"{title}: {subtitle}</a></li>"
+            )
 
         reference_html += "</ul>"
         return response_html + reference_html
 
     except openai.AuthenticationError:
         logger.error("Invalid OpenAI API key.")
-        return "<p style='color: #d9534f;'><strong>Error:</strong> Invalid OpenAI API key.</p>"
+        return (
+            "<p style='color: #d9534f;'><strong>Error:</strong> "
+            "Invalid OpenAI API key.</p>"
+        )
     except openai.OpenAIError as e:
         logger.error(f"OpenAI API Error: {e}")
-        return f"<p style='color: #d9534f;'><strong>OpenAI Error:</strong> {e}</p>"
+        return (
+            "<p style='color: #d9534f;'><strong>OpenAI Error:</strong> "
+            f"{escape(str(e))}</p>"
+        )
     except FileNotFoundError as e:
         logger.error(f"File not found: {e}")
-        return f"<p style='color: #d9534f;'><strong>File Error:</strong> {e}</p>"
+        return (
+            "<p style='color: #d9534f;'><strong>File Error:</strong> "
+            f"{escape(str(e))}</p>"
+        )
     except ValueError as e:
         logger.error(f"Value error: {e}")
-        return f"<p><strong>Input Error:</strong> {e}</p>"
+        return (
+            "<p><strong>Input Error:</strong> "
+            f"{escape(str(e))}</p>"
+        )
     except Exception as e:
         logger.exception("Unhandled exception during generate_matches.")
-        return f"<p style='color: #d9534f;'><strong>Unexpected error:</strong> {e}</p>"
+        return (
+            "<p style='color: #d9534f;'><strong>Unexpected error:</strong> "
+            f"{escape(str(e))}</p>"
+        )
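A notable change in `match_engine.py` is routing every interpolated value through `html.escape`, so a document title or error message containing markup is rendered as literal text instead of being injected into the response HTML. The behavior in isolation:

```python
from html import escape

title = '<script>alert("x")</script> Municipal Code'
li = f"<li>{escape(title)}</li>"

# Angle brackets and quotes become entities (&lt;, &gt;, &quot;), so the
# browser shows the text verbatim rather than interpreting it as markup.
assert "&lt;script&gt;" in li
assert "<script>" not in li
```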
lexai/core/matcher.py CHANGED
@@ -1,7 +1,9 @@
+from typing import Any
+
 import numpy as np
 import pandas as pd
 from scipy.spatial.distance import cdist
-from typing import Any
+
 
 def find_top_matches(
     query_embedding: np.ndarray,
@@ -30,10 +32,22 @@
     A list of dictionaries, where each dictionary represents a top match
     and contains its 'url', 'title', 'subtitle', and 'content'.
     """
+    if jurisdiction_data.empty or embeddings.shape[0] == 0:
+        return []
+
+    if jurisdiction_data.shape[0] != embeddings.shape[0]:
+        raise ValueError(
+            "Number of embeddings and metadata entries must match.")
+
+    if query_embedding.ndim != 1 or query_embedding.shape[0] != embeddings.shape[1]:
+        raise ValueError(
+            "Query embedding must match the dimensionality of the embeddings."
+        )
 
-    distances = cdist(query_embedding.reshape(1, -1), embeddings, metric="cosine")[0]
-    indices = np.argsort(distances)[:num_matches]
-    subset: pd.DataFrame = jurisdiction_data.loc[indices]
-    top_matches: list[dict[str, Any]] = subset.to_dict("records")
+    distances = cdist(query_embedding.reshape(1, -1),
+                      embeddings, metric="cosine")[0]
+    safe_num_matches = min(num_matches, len(jurisdiction_data))
+    indices = np.argsort(distances)[:safe_num_matches]
+    subset = jurisdiction_data.iloc[indices]
 
-    return top_matches
+    return subset.to_dict("records")
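The guards added to `find_top_matches` (shape checks, `min(num_matches, ...)` clamping, and positional `iloc` instead of label-based `loc`) can be exercised with synthetic data. This sketch mirrors the function's logic rather than importing it:

```python
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist

embeddings = np.eye(3)  # three orthogonal unit vectors
metadata = pd.DataFrame(
    {"title": ["a", "b", "c"]},
    index=[10, 20, 30],  # non-default index: loc[1] would fail, iloc[1] works
)
query = np.array([0.0, 1.0, 0.0])

distances = cdist(query.reshape(1, -1), embeddings, metric="cosine")[0]
k = min(5, len(metadata))            # requesting more matches than rows is safe
indices = np.argsort(distances)[:k]  # positional indices, smallest distance first
records = metadata.iloc[indices].to_dict("records")

assert records[0]["title"] == "b"    # the exact match ranks first
assert len(records) == 3
```

The non-default index shows why the switch to `iloc` matters: `argsort` returns positions, not labels, so `loc` on those values would select the wrong rows or raise a `KeyError`.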
lexai/models/embedding_model.py DELETED
File without changes
lexai/services/openai_client.py CHANGED
@@ -1,14 +1,16 @@
1
- from openai import OpenAI
2
- import numpy as np
3
  import os
 
 
 
 
4
  from lexai.config import (
5
- MODEL_ENGINE,
 
 
6
  GPT4_MODEL,
 
7
  GPT4_TEMPERATURE,
8
- GPT4_MAX_TOKENS,
9
  GPT4_TOP_P,
10
- GPT4_FREQUENCY_PENALTY,
11
- GPT4_PRESENCE_PENALTY,
12
  )
13
 
14
  API_KEY = os.getenv("OPENAI_API_KEY")
@@ -17,10 +19,7 @@ client = OpenAI(api_key=API_KEY)
17
 
18
  def get_embedding(text: str) -> np.ndarray:
19
  """
20
- Generates an embedding for the given text using OpenAI's text-embedding model.
21
-
22
- The OpenAI API key is loaded from the OPENAI_API_KEY environment variable
23
- to authenticate the request.
24
 
25
  Parameters
26
  ----------
@@ -30,67 +29,61 @@ def get_embedding(text: str) -> np.ndarray:
30
  Returns
31
  -------
32
  np.ndarray
33
- A NumPy array representing the embedding of the input text.
34
 
35
  Raises
36
  ------
37
  openai.AuthenticationError
38
- If the OPENAI_API_KEY environment variable is not set or is invalid.
39
  openai.OpenAIError
40
- If there's another issue with the OpenAI API call, such as network problems.
41
  """
42
- response = client.embeddings.create(input=text, model=MODEL_ENGINE)
43
  return np.array(response.data[0].embedding)
44
 
45
 
46
- def get_chat_completion(role_description: str, top_matches_str: str, query: str) -> str:
 
 
 
 
47
  """
48
- Generates a chat completion response using OpenAI's GPT-4 model.
49
-
50
- The OpenAI API key is loaded from the OPENAI_API_KEY environment variable
51
- to authenticate the request. The function constructs a conversation history
52
- with system and user roles to provide context to the language model.
53
 
54
  Parameters
55
  ----------
56
  role_description : str
57
- The system role description for the AI assistant, defining its persona
58
- and limitations.
59
  top_matches_str : str
60
- A string representation of the top legal information matches. This is
61
- provided as system context to help the AI formulate relevant responses.
62
  query : str
63
- The user's direct query or question.
64
 
65
  Returns
66
  -------
67
  str
68
- The AI-generated response message from the chat completion.
69
 
70
  Raises
71
  ------
72
  openai.AuthenticationError
73
- If the OPENAI_API_KEY environment variable is not set or is invalid.
74
  openai.OpenAIError
75
- If there's an issue with the OpenAI API call, such as rate limiting,
76
- or other API-related errors.
77
  """
78
-
79
- response = client.chat.completions.create(model=GPT4_MODEL,
80
- messages=[
81
- {"role": "system",
82
- "content": role_description.strip()},
83
- {"role": "system",
84
- "content": top_matches_str},
85
- {"role": "user",
86
- "content": query},
87
- {"role": "assistant",
88
- "content": ""},
89
- ],
90
- temperature=GPT4_TEMPERATURE,
91
- max_tokens=GPT4_MAX_TOKENS,
92
- top_p=GPT4_TOP_P,
93
- frequency_penalty=GPT4_FREQUENCY_PENALTY,
94
- presence_penalty=GPT4_PRESENCE_PENALTY)
95
 
96
  return response.choices[0].message.content.strip()
 
 
 
 import os
+
+import numpy as np
+from openai import OpenAI
+
 from lexai.config import (
+    EMBEDDING_MODEL,
+    GPT4_FREQUENCY_PENALTY,
+    GPT4_MAX_TOKENS,
     GPT4_MODEL,
+    GPT4_PRESENCE_PENALTY,
     GPT4_TEMPERATURE,
     GPT4_TOP_P,
 )
 
 API_KEY = os.getenv("OPENAI_API_KEY")
 
 def get_embedding(text: str) -> np.ndarray:
     """
+    Generates an embedding for the given text using OpenAI's embedding model.
 
     Parameters
     ----------
 
     Returns
     -------
     np.ndarray
+        A NumPy array representing the embedding.
 
     Raises
     ------
     openai.AuthenticationError
+        If the API key is not set or invalid.
     openai.OpenAIError
+        For other API-related issues.
     """
+    response = client.embeddings.create(input=text, model=EMBEDDING_MODEL)
     return np.array(response.data[0].embedding)
 
 
+def get_chat_completion(
+    role_description: str,
+    top_matches_str: str,
+    query: str,
+) -> str:
     """
+    Generates a chat completion using OpenAI's GPT-4 model.
 
     Parameters
     ----------
     role_description : str
+        Description of the assistant's persona and context.
     top_matches_str : str
+        Summary of top legal matches used to guide the assistant.
     query : str
+        The user's legal query.
 
     Returns
     -------
     str
+        The AI-generated response.
 
     Raises
     ------
     openai.AuthenticationError
+        If the API key is not set or invalid.
     openai.OpenAIError
+        For other API-related issues.
     """
+    response = client.chat.completions.create(
+        model=GPT4_MODEL,
+        messages=[
+            {"role": "system", "content": role_description.strip()},
+            {"role": "system", "content": top_matches_str},
+            {"role": "user", "content": query},
+            {"role": "assistant", "content": ""},
+        ],
+        temperature=GPT4_TEMPERATURE,
+        max_tokens=GPT4_MAX_TOKENS,
+        top_p=GPT4_TOP_P,
+        frequency_penalty=GPT4_FREQUENCY_PENALTY,
+        presence_penalty=GPT4_PRESENCE_PENALTY,
+    )
 
     return response.choices[0].message.content.strip()
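The reformatted `get_chat_completion` above keeps the RAG message layout: the persona and the retrieved matches go in as system messages, followed by the user's question. A minimal offline sketch of that layout, with a hypothetical stub standing in for the OpenAI client (the stub class and the injected `completions` parameter are illustrative, not part of the module):

```python
class StubCompletions:
    """Hypothetical stand-in for client.chat.completions, so this runs offline."""

    @staticmethod
    def create(**kwargs):
        # Report how many messages were assembled instead of calling the API.
        message = type("Msg", (), {"content": f"Received {len(kwargs['messages'])} messages. "})
        choice = type("Choice", (), {"message": message})
        return type("Resp", (), {"choices": [choice]})


def get_chat_completion(role_description, top_matches_str, query, completions):
    # Same message layout as the module; the client is injected here for testability.
    response = completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": role_description.strip()},
            {"role": "system", "content": top_matches_str},
            {"role": "user", "content": query},
            {"role": "assistant", "content": ""},
        ],
    )
    return response.choices[0].message.content.strip()


answer = get_chat_completion(
    "You are LexAI, a legal research assistant.",
    "1. Municipal parking code ...",
    "Can I park overnight on a residential street?",
    StubCompletions,
)
print(answer)  # Received 4 messages.
```

Injecting the client this way is the same seam `tests/test_openai_client.py` exploits with `@patch("lexai.services.openai_client.client")`.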
lexai/ui/gradio_interface.py CHANGED
@@ -1,12 +1,14 @@
-import gradio as gr
 import logging
-from lexai.config import LOCATION_INFO, APP_DESCRIPTION
 from lexai.core.match_engine import generate_matches
 
 logger = logging.getLogger(__name__)
 
 
-def launch_interface():
     with gr.Blocks(title="LexAI") as iface:
         gr.HTML("<h1 style='text-align: center;'>LexAI</h1>")
         gr.Markdown(APP_DESCRIPTION)
@@ -56,5 +58,5 @@ def launch_interface():
         inputs=[query_input, location_input]
     )
 
-    logger.info("Launching LexAI interface...")
-    iface.launch()
 
 import logging
+
+import gradio as gr
+
+from lexai.config import APP_DESCRIPTION, LOCATION_INFO
 from lexai.core.match_engine import generate_matches
 
 logger = logging.getLogger(__name__)
 
 
+def build_interface():
     with gr.Blocks(title="LexAI") as iface:
         gr.HTML("<h1 style='text-align: center;'>LexAI</h1>")
         gr.Markdown(APP_DESCRIPTION)
 
         inputs=[query_input, location_input]
     )
 
+    logger.info("LexAI interface built.")
+    return iface
pyproject.toml CHANGED
@@ -1,31 +1,34 @@
 [project]
 name = "lexai"
 version = "0.1.0"
-description = "A demo of LexAI, an AI legal assistant that delivers accurate legal information."
 readme = "README.md"
 requires-python = ">=3.8"
 license = { text = "MIT" }
 authors = [
     { name = "Alex Ulanch", email = "alexulanch@gmail.com" },
 ]
 keywords = ["AI", "Legal", "Gradio", "OpenAI", "RAG"]
 classifiers = [
-    "Programming Language :: Python :: 3",
-    "License :: OSI Approved :: MIT License",
-    "Operating System :: OS Independent",
-    "Development Status :: 3 - Alpha",
-    "Intended Audience :: Developers",
-    "Topic :: Scientific/Engineering :: Artificial Intelligence",
-    "Topic :: Software Development :: Libraries :: Application Frameworks",
 ]
 
 dependencies = [
-    "pandas",
-    "numpy",
-    "openai",
-    "gradio",
-    "scipy",
-    "python-dotenv",
 ]
 
 [project.urls]
@@ -42,11 +45,11 @@ include = ["lexai*"]
 
 [tool.black]
 line-length = 88
-target-version = ['py38']
 include = '\.pyi?$'
 exclude = '''
 /(
-    \.git
   | \.venv
   | \.mypy_cache
   | \.pytest_cache
@@ -60,20 +63,26 @@ exclude = '''
 '''
 
 [tool.isort]
-known_local_folder = ["lexai"]
 profile = "black"
-line_length = 88
 known_first_party = ["lexai"]
-skip_glob = ["**/data/*"]
 multi_line_output = 3
 include_trailing_comma = true
 force_grid_wrap = 0
 use_parentheses = true
 ensure_newline_before_comments = true
 
 [tool.pytest.ini_options]
 minversion = "6.0"
 addopts = "-ra -q"
-testpaths = [
-    "tests",
-]
 
 [project]
 name = "lexai"
 version = "0.1.0"
+description = "LexAI is an AI legal assistant that provides accurate, location-specific legal information in a clear and accessible format."
 readme = "README.md"
 requires-python = ">=3.8"
 license = { text = "MIT" }
+
 authors = [
     { name = "Alex Ulanch", email = "alexulanch@gmail.com" },
 ]
+
 keywords = ["AI", "Legal", "Gradio", "OpenAI", "RAG"]
+
 classifiers = [
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent",
+    "Development Status :: 3 - Alpha",
+    "Intended Audience :: Developers",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    "Topic :: Software Development :: Libraries :: Application Frameworks"
 ]
 
 dependencies = [
+    "pandas",
+    "numpy",
+    "openai",
+    "gradio",
+    "scipy",
+    "python-dotenv"
 ]
 
 [project.urls]
 
 [tool.black]
 line-length = 88
+target-version = ["py38"]
 include = '\.pyi?$'
 exclude = '''
 /(
+    \.git
   | \.venv
   | \.mypy_cache
   | \.pytest_cache
 '''
 
 [tool.isort]
 profile = "black"
 known_first_party = ["lexai"]
+known_local_folder = ["lexai"]
+line_length = 88
 multi_line_output = 3
 include_trailing_comma = true
 force_grid_wrap = 0
 use_parentheses = true
 ensure_newline_before_comments = true
+skip_glob = ["**/data/*"]
 
 [tool.pytest.ini_options]
 minversion = "6.0"
 addopts = "-ra -q"
+testpaths = ["tests"]
+
+[tool.ruff]
+line-length = 88
+target-version = "py38"
+exclude = ["data", "build", "dist"]
+
+[tool.ruff.lint]
+select = ["E", "F", "W", "I"]
tests/test_data_loader.py CHANGED
@@ -1,7 +1,8 @@
-import pytest
 import numpy as np
 import pandas as pd
-from pathlib import Path
 
 from lexai.core.data_loader import load_embeddings
 
@@ -46,12 +47,12 @@ def test_load_embeddings_success(temp_npz_file):
 
 
 def test_load_embeddings_missing_key(broken_npz_missing_embeddings):
-    with pytest.raises(KeyError, match="Missing key 'embeddings'"):
         load_embeddings(broken_npz_missing_embeddings)
 
 
 def test_load_metadata_missing_key(broken_npz_missing_columns):
-    with pytest.raises(KeyError, match="Missing key 'titles'"):
         load_embeddings(broken_npz_missing_columns)
 
 
+from pathlib import Path
+
 import numpy as np
 import pandas as pd
+import pytest
 
 from lexai.core.data_loader import load_embeddings
 
 
 def test_load_embeddings_missing_key(broken_npz_missing_embeddings):
+    with pytest.raises(KeyError, match="Missing key"):
         load_embeddings(broken_npz_missing_embeddings)
 
 
 def test_load_metadata_missing_key(broken_npz_missing_columns):
+    with pytest.raises(KeyError, match="Missing key"):
         load_embeddings(broken_npz_missing_columns)
 
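The relaxed `match="Missing key"` pattern above only assumes that the loader raises a `KeyError` whose message starts with that prefix. A self-contained sketch of a loader honoring that contract (this `load_embeddings` is illustrative, not the module's implementation):

```python
import tempfile
from pathlib import Path

import numpy as np


def load_embeddings(path: Path) -> np.ndarray:
    """Illustrative loader: require an 'embeddings' key, as the tests expect."""
    with np.load(path) as archive:
        if "embeddings" not in archive:
            raise KeyError(f"Missing key 'embeddings' in {path.name}")
        return archive["embeddings"]


with tempfile.TemporaryDirectory() as tmp:
    # A well-formed archive loads normally.
    good = Path(tmp) / "good.npz"
    np.savez(good, embeddings=np.zeros((2, 3), dtype=np.float32))
    loaded = load_embeddings(good)

    # An archive without the expected key raises the KeyError the tests match on.
    bad = Path(tmp) / "bad.npz"
    np.savez(bad, titles=np.array(["a", "b"]))
    try:
        load_embeddings(bad)
        error_message = None
    except KeyError as err:
        error_message = str(err)

print(loaded.shape)  # (2, 3)
print("Missing key" in error_message)  # True
```

Matching only the stable prefix keeps the tests green even if the loader's message later names a different key.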
tests/test_main.py DELETED
File without changes
tests/test_matcher.py CHANGED
@@ -1,29 +1,46 @@
-import pytest
 import numpy as np
 import pandas as pd
 
 from lexai.core.matcher import find_top_matches
 
 
 @pytest.fixture
 def sample_embeddings():
-    return np.array([
-        [1.0, 0.1, 0.1],
-        [0.8, 0.3, 0.2],
-        [0.5, 0.5, 0.5],
-        [0.1, 0.1, 1.0],
-        [0.0, 0.0, 0.0]
-    ], dtype=np.float32)
 
 
 @pytest.fixture
 def sample_jurisdiction_data():
-    return pd.DataFrame({
-        "url": ["url1", "url2", "url3", "url4", "url5"],
-        "title": ["Title 1", "Title 2", "Title 3", "Title 4", "Title 5"],
-        "subtitle": ["Subtitle A", "Subtitle B", "Subtitle C", "Subtitle D", "Subtitle E"],
-        "content": ["Content X", "Content Y", "Content Z", "Content W", "Content V"]
-    })
 
 
 @pytest.fixture
@@ -41,35 +58,69 @@ def empty_jurisdiction_data():
     return pd.DataFrame(columns=["url", "title", "subtitle", "content"])
 
 
-def test_returns_expected_number_of_matches(sample_query_embedding, sample_embeddings, sample_jurisdiction_data):
     matches = find_top_matches(
-        sample_query_embedding, sample_embeddings, sample_jurisdiction_data, num_matches=3)
     assert len(matches) == 3
-    assert [m["title"] for m in matches] == ["Title 1", "Title 2", "Title 3"]
 
 
-def test_returns_all_available_matches_if_less_than_requested(sample_query_embedding, sample_embeddings, sample_jurisdiction_data):
     matches = find_top_matches(
-        sample_query_embedding, sample_embeddings, sample_jurisdiction_data, num_matches=10)
     assert len(matches) == len(sample_embeddings)
     assert matches[0]["title"] == "Title 1"
 
 
-def test_returns_empty_list_for_empty_embeddings(sample_query_embedding, empty_embeddings, empty_jurisdiction_data):
     matches = find_top_matches(
-        sample_query_embedding, empty_embeddings, empty_jurisdiction_data, num_matches=3)
     assert matches == []
 
 
-def test_returns_empty_list_for_empty_jurisdiction_data(sample_query_embedding, sample_embeddings, empty_jurisdiction_data):
     matches = find_top_matches(
-        sample_query_embedding, sample_embeddings, empty_jurisdiction_data, num_matches=3)
     assert matches == []
 
 
-def test_output_contains_expected_keys(sample_query_embedding, sample_embeddings, sample_jurisdiction_data):
     matches = find_top_matches(
-        sample_query_embedding, sample_embeddings, sample_jurisdiction_data, num_matches=1)
     match = matches[0]
     assert set(match.keys()) == {"url", "title", "subtitle", "content"}
     assert match["url"] == "url1"
@@ -78,24 +129,36 @@ def test_output_contains_expected_keys(sample_query_embedding, sample_embeddings
     assert match["content"] == "Content X"
 
 
-def test_handles_single_embedding_and_row(sample_query_embedding, sample_embeddings, sample_jurisdiction_data):
     matches = find_top_matches(
         sample_query_embedding,
-        sample_embeddings[0:1],
-        sample_jurisdiction_data.iloc[0:1],
-        num_matches=3
     )
     assert len(matches) == 1
     assert matches[0]["title"] == "Title 1"
 
 
-def test_raises_for_invalid_query_embedding_shape(sample_embeddings, sample_jurisdiction_data):
-    bad_embedding = np.array([1.0, 2.0], dtype=np.float32)
-    with pytest.raises(ValueError, match="same number of columns"):
-        find_top_matches(bad_embedding, sample_embeddings,
-                         sample_jurisdiction_data, num_matches=1)
 
-    scalar_embedding = np.array(1.0, dtype=np.float32)
     with pytest.raises(ValueError):
-        find_top_matches(scalar_embedding, sample_embeddings,
-                         sample_jurisdiction_data, num_matches=1)
 
 import numpy as np
 import pandas as pd
+import pytest
 
 from lexai.core.matcher import find_top_matches
 
 
 @pytest.fixture
 def sample_embeddings():
+    return np.array(
+        [
+            [1.0, 0.1, 0.1],
+            [0.8, 0.3, 0.2],
+            [0.5, 0.5, 0.5],
+            [0.1, 0.1, 1.0],
+            [0.0, 0.0, 0.0],
+        ],
+        dtype=np.float32,
+    )
 
 
 @pytest.fixture
 def sample_jurisdiction_data():
+    return pd.DataFrame(
+        {
+            "url": ["url1", "url2", "url3", "url4", "url5"],
+            "title": ["Title 1", "Title 2", "Title 3", "Title 4", "Title 5"],
+            "subtitle": [
+                "Subtitle A",
+                "Subtitle B",
+                "Subtitle C",
+                "Subtitle D",
+                "Subtitle E",
+            ],
+            "content": [
+                "Content X",
+                "Content Y",
+                "Content Z",
+                "Content W",
+                "Content V",
+            ],
+        }
+    )
 
 
 @pytest.fixture
 
     return pd.DataFrame(columns=["url", "title", "subtitle", "content"])
 
 
+def test_returns_expected_number_of_matches(
+    sample_query_embedding, sample_embeddings, sample_jurisdiction_data
+):
     matches = find_top_matches(
+        sample_query_embedding,
+        sample_embeddings,
+        sample_jurisdiction_data,
+        num_matches=3,
+    )
     assert len(matches) == 3
+    assert [match["title"] for match in matches] == [
+        "Title 1",
+        "Title 2",
+        "Title 3",
+    ]
 
 
+def test_returns_all_available_matches_if_less_than_requested(
+    sample_query_embedding, sample_embeddings, sample_jurisdiction_data
+):
     matches = find_top_matches(
+        sample_query_embedding,
+        sample_embeddings,
+        sample_jurisdiction_data,
+        num_matches=10,
+    )
     assert len(matches) == len(sample_embeddings)
     assert matches[0]["title"] == "Title 1"
 
 
+def test_returns_empty_list_for_empty_embeddings(
+    sample_query_embedding, empty_embeddings, empty_jurisdiction_data
+):
     matches = find_top_matches(
+        sample_query_embedding,
+        empty_embeddings,
+        empty_jurisdiction_data,
+        num_matches=3,
+    )
     assert matches == []
 
 
+def test_returns_empty_list_for_empty_jurisdiction_data(
+    sample_query_embedding, sample_embeddings, empty_jurisdiction_data
+):
     matches = find_top_matches(
+        sample_query_embedding,
+        sample_embeddings,
+        empty_jurisdiction_data,
+        num_matches=3,
+    )
     assert matches == []
 
 
+def test_output_contains_expected_keys(
+    sample_query_embedding, sample_embeddings, sample_jurisdiction_data
+):
     matches = find_top_matches(
+        sample_query_embedding,
+        sample_embeddings,
+        sample_jurisdiction_data,
+        num_matches=1,
+    )
     match = matches[0]
     assert set(match.keys()) == {"url", "title", "subtitle", "content"}
     assert match["url"] == "url1"
 
     assert match["content"] == "Content X"
 
 
+def test_handles_single_embedding_and_row(
+    sample_query_embedding, sample_embeddings, sample_jurisdiction_data
+):
     matches = find_top_matches(
         sample_query_embedding,
+        sample_embeddings[:1],
+        sample_jurisdiction_data.iloc[:1],
+        num_matches=3,
+    )
     assert len(matches) == 1
     assert matches[0]["title"] == "Title 1"
 
 
+def test_raises_for_invalid_query_embedding_shape(
+    sample_embeddings, sample_jurisdiction_data
+):
+    invalid_vector = np.array([1.0, 2.0], dtype=np.float32)
+    with pytest.raises(ValueError, match="dimensionality of the embeddings"):
+        find_top_matches(
+            invalid_vector,
+            sample_embeddings,
+            sample_jurisdiction_data,
+            num_matches=1,
+        )
 
+    scalar_value = np.array(1.0, dtype=np.float32)
     with pytest.raises(ValueError):
+        find_top_matches(
+            scalar_value,
+            sample_embeddings,
+            sample_jurisdiction_data,
+            num_matches=1,
+        )
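The fixtures and the new `match="dimensionality of the embeddings"` pattern above imply a cosine-similarity ranking with shape validation. A standalone sketch of that contract (illustrative only; the real `find_top_matches` also returns URL, subtitle, and content metadata per match):

```python
import numpy as np


def top_matches(query: np.ndarray, embeddings: np.ndarray, titles, k: int):
    """Illustrative cosine-similarity ranking, mirroring what the tests assume."""
    if embeddings.size == 0:
        return []  # no corpus -> no matches, as the empty-fixture tests expect
    if query.ndim != 1 or query.shape[0] != embeddings.shape[1]:
        raise ValueError("query must match the dimensionality of the embeddings")
    # Cosine similarity, guarding against zero-norm rows like [0.0, 0.0, 0.0].
    norms = np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query)
    scores = embeddings @ query / np.where(norms == 0, 1.0, norms)
    order = np.argsort(scores)[::-1][:k]
    return [titles[i] for i in order]


embeddings = np.array([[1.0, 0.1, 0.1], [0.8, 0.3, 0.2], [0.1, 0.1, 1.0]])
titles = ["Title 1", "Title 2", "Title 3"]
print(top_matches(np.array([1.0, 0.0, 0.0]), embeddings, titles, k=2))
# ['Title 1', 'Title 2']
```

The row closest in direction to the query ranks first, which is why the fixture's `[1.0, 0.1, 0.1]` row ("Title 1") consistently tops the expected results.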
tests/test_matching.py DELETED
File without changes
tests/test_openai_client.py CHANGED
@@ -1,6 +1,8 @@
 
 
 import numpy as np
-from unittest.mock import patch, MagicMock
-from lexai.services.openai_client import get_embedding, get_chat_completion
 
 
 @patch("lexai.services.openai_client.client")
 
+from unittest.mock import MagicMock, patch
+
 import numpy as np
+
+from lexai.services.openai_client import get_chat_completion, get_embedding
 
 
 @patch("lexai.services.openai_client.client")