Spaces:
Sleeping
Sleeping
Eddie Pick
commited on
Changed provider/model delimiter to ':'
Browse files- README.md +7 -3
- models.py +9 -5
- search_agent.py +7 -0
README.md
CHANGED
@@ -72,7 +72,7 @@ python search_agent.py [OPTIONS] SEARCH_QUERY
|
|
72 |
- `-c`, `--copywrite`: First produce a draft, review it, and rewrite for a final text.
|
73 |
- `-d DOMAIN`, `--domain=DOMAIN`: Limit search to a specific domain.
|
74 |
- `-t TEMP`, `--temperature=TEMP`: Set the temperature of the LLM [default: 0.0].
|
75 |
-
- `-m MODEL`, `--model=MODEL`: Use a specific model [default: openai
|
76 |
- `-e MODEL`, `--embedding_model=MODEL`: Use a specific embedding model [default: same provider as model].
|
77 |
- `-n NUM`, `--max_pages=NUM`: Max number of pages to retrieve [default: 10].
|
78 |
- `-x NUM`, `--max_extracts=NUM`: Max number of page extracts to consider [default: 7].
|
@@ -82,11 +82,15 @@ python search_agent.py [OPTIONS] SEARCH_QUERY
|
|
82 |
### Examples
|
83 |
|
84 |
```bash
|
85 |
-
python search_agent.py -m openai
|
86 |
```
|
87 |
|
88 |
```bash
|
89 |
-
python search_agent.py -m
|
|
|
|
|
|
|
|
|
90 |
```
|
91 |
|
92 |
## License
|
|
|
72 |
- `-c`, `--copywrite`: First produce a draft, review it, and rewrite for a final text.
|
73 |
- `-d DOMAIN`, `--domain=DOMAIN`: Limit search to a specific domain.
|
74 |
- `-t TEMP`, `--temperature=TEMP`: Set the temperature of the LLM [default: 0.0].
|
75 |
+
- `-m MODEL`, `--model=MODEL`: Use a specific model [default: openai:gpt-4o-mini].
|
76 |
- `-e MODEL`, `--embedding_model=MODEL`: Use a specific embedding model [default: same provider as model].
|
77 |
- `-n NUM`, `--max_pages=NUM`: Max number of pages to retrieve [default: 10].
|
78 |
- `-x NUM`, `--max_extracts=NUM`: Max number of page extracts to consider [default: 7].
|
|
|
82 |
### Examples
|
83 |
|
84 |
```bash
|
85 |
+
python search_agent.py -m openai:gpt-4o-mini "Write a linked post about the current state of M&A for startups. Write in the style of Russ from Silicon Valley TV show."
|
86 |
```
|
87 |
|
88 |
```bash
|
89 |
+
python search_agent.py -m groq:llama-3.1-70b-versatile -e ollama:nomic-embed-text:latest -t 0.7 -n 20 -x 15 "Write a linked post about the state of M&A for startups in 2024. Write in the style of Russ from TV show Silicon Valley" -s
|
90 |
+
```
|
91 |
+
|
92 |
+
```bash
|
93 |
+
python search_agent.py -m groq -e openai "Write an engaging long linked post about the state of M&A for startups in 2024"
|
94 |
```
|
95 |
|
96 |
## License
|
models.py
CHANGED
@@ -28,10 +28,14 @@ from langchain_community.chat_models import ChatPerplexity
|
|
28 |
from langchain_together import ChatTogether
|
29 |
from langchain_together.embeddings import TogetherEmbeddings
|
30 |
|
31 |
-
|
|
|
|
|
|
|
|
|
32 |
|
33 |
def get_model(provider_model, temperature=0.0):
|
34 |
-
provider, model = (provider_model
|
35 |
match provider:
|
36 |
case 'bedrock':
|
37 |
if model is None:
|
@@ -76,8 +80,8 @@ def get_model(provider_model, temperature=0.0):
|
|
76 |
return chat_llm
|
77 |
|
78 |
|
79 |
-
def get_embedding_model(
|
80 |
-
provider, model = (
|
81 |
match provider:
|
82 |
case 'bedrock':
|
83 |
if model is None:
|
@@ -224,7 +228,7 @@ class TestGetModel(unittest.TestCase):
|
|
224 |
@patch('models.ChatGroq')
|
225 |
def test_groq_model(self, mock_groq):
|
226 |
result = get_model('groq')
|
227 |
-
mock_groq.assert_called_once_with(model_name='
|
228 |
self.assertEqual(result, mock_groq.return_value)
|
229 |
|
230 |
@patch('models.ChatOllama')
|
|
|
28 |
from langchain_together import ChatTogether
|
29 |
from langchain_together.embeddings import TogetherEmbeddings
|
30 |
|
31 |
+
def split_provider_model(provider_model):
|
32 |
+
parts = provider_model.split(':', 1)
|
33 |
+
provider = parts[0]
|
34 |
+
model = parts[1] if len(parts) > 1 else None
|
35 |
+
return provider, model
|
36 |
|
37 |
def get_model(provider_model, temperature=0.0):
|
38 |
+
provider, model = split_provider_model(provider_model)
|
39 |
match provider:
|
40 |
case 'bedrock':
|
41 |
if model is None:
|
|
|
80 |
return chat_llm
|
81 |
|
82 |
|
83 |
+
def get_embedding_model(provider_model):
|
84 |
+
provider, model = split_provider_model(provider_model)
|
85 |
match provider:
|
86 |
case 'bedrock':
|
87 |
if model is None:
|
|
|
228 |
@patch('models.ChatGroq')
|
229 |
def test_groq_model(self, mock_groq):
|
230 |
result = get_model('groq')
|
231 |
+
mock_groq.assert_called_once_with(model_name='llama-3.1-8b-instant', temperature=0.0)
|
232 |
self.assertEqual(result, mock_groq.return_value)
|
233 |
|
234 |
@patch('models.ChatOllama')
|
search_agent.py
CHANGED
@@ -12,6 +12,7 @@ Usage:
|
|
12 |
[--max_extracts=num]
|
13 |
[--use_selenium]
|
14 |
[--output=text]
|
|
|
15 |
SEARCH_QUERY
|
16 |
search_agent.py --version
|
17 |
|
@@ -27,6 +28,7 @@ Options:
|
|
27 |
-x num --max_extracts=num Max number of page extract to consider [default: 7]
|
28 |
-s --use_selenium Use selenium to fetch content from the web [default: False]
|
29 |
-o text --output=text Output format (choices: text, markdown) [default: markdown]
|
|
|
30 |
|
31 |
"""
|
32 |
|
@@ -80,6 +82,7 @@ if os.getenv("LANGCHAIN_API_KEY"):
|
|
80 |
)
|
81 |
@traceable(run_type="tool", name="search_agent")
|
82 |
def main(arguments):
|
|
|
83 |
copywrite_mode = arguments["--copywrite"]
|
84 |
model = arguments["--model"]
|
85 |
embedding_model = arguments["--embedding_model"]
|
@@ -98,6 +101,10 @@ def main(arguments):
|
|
98 |
else:
|
99 |
embedding_model = md.get_embedding_model(embedding_model)
|
100 |
|
|
|
|
|
|
|
|
|
101 |
with console.status(f"[bold green]Optimizing query for search: {query}"):
|
102 |
optimize_search_query = wr.optimize_search_query(chat, query)
|
103 |
if len(optimize_search_query) < 3:
|
|
|
12 |
[--max_extracts=num]
|
13 |
[--use_selenium]
|
14 |
[--output=text]
|
15 |
+
[--verbose]
|
16 |
SEARCH_QUERY
|
17 |
search_agent.py --version
|
18 |
|
|
|
28 |
-x num --max_extracts=num Max number of page extract to consider [default: 7]
|
29 |
-s --use_selenium Use selenium to fetch content from the web [default: False]
|
30 |
-o text --output=text Output format (choices: text, markdown) [default: markdown]
|
31 |
+
-v --verbose Print verbose output [default: False]
|
32 |
|
33 |
"""
|
34 |
|
|
|
82 |
)
|
83 |
@traceable(run_type="tool", name="search_agent")
|
84 |
def main(arguments):
|
85 |
+
verbose = arguments["--verbose"]
|
86 |
copywrite_mode = arguments["--copywrite"]
|
87 |
model = arguments["--model"]
|
88 |
embedding_model = arguments["--embedding_model"]
|
|
|
101 |
else:
|
102 |
embedding_model = md.get_embedding_model(embedding_model)
|
103 |
|
104 |
+
if verbose:
|
105 |
+
console.log(f"Using model: {chat.model_name}")
|
106 |
+
console.log(f"Using embedding model: { embedding_model.model}")
|
107 |
+
|
108 |
with console.status(f"[bold green]Optimizing query for search: {query}"):
|
109 |
optimize_search_query = wr.optimize_search_query(chat, query)
|
110 |
if len(optimize_search_query) < 3:
|