yonikremer committed
Commit: a671856
Parent: 6e4f775

removed the option to search the web

Browse files:
- app.py (+0, -7)
- hanlde_form_submit.py (+1, -11)
- prompt_engeneering.py (+0, -86)
- tests.py (+3, -20)
- user_instructions_hebrew.md (+25, -15)
app.py CHANGED
@@ -39,12 +39,6 @@ with st.form("request_form"):
         max_chars=2048,
     )
 
-    web_search: bool = st.checkbox(
-        label="Web search",
-        value=True,
-        help="If checked, the model will get your prompt as well as some web search results."
-    )
-
     submitted: bool = st.form_submit_button(
         label="Generate",
         help="Generate the output text.",
@@ -57,7 +51,6 @@ with st.form("request_form"):
             selected_model_name,
             output_length,
             submitted_prompt,
-            web_search,
         )
     except CudaError as e:
         st.error("Out of memory. Please try a smaller model, shorter prompt, or a smaller output length.")
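For orientation, a minimal sketch of what the request form reduces to after this commit. Only max_chars=2048 and the submit button are visible in the diff, so the prompt widget type and its label below are assumptions:

import streamlit as st

with st.form("request_form"):
    # Prompt field; only max_chars=2048 appears in the diff,
    # the rest of this widget is assumed.
    submitted_prompt: str = st.text_area(
        label="Prompt",
        max_chars=2048,
    )
    # The "Web search" checkbox that used to sit here is removed;
    # the submit button now follows the prompt field directly.
    submitted: bool = st.form_submit_button(
        label="Generate",
        help="Generate the output text.",
    )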
hanlde_form_submit.py CHANGED
@@ -5,7 +5,6 @@ import streamlit as st
 from grouped_sampling import GroupedSamplingPipeLine, is_supported, UnsupportedModelNameException
 
 from download_repo import download_pytorch_model
-from prompt_engeneering import rewrite_prompt
 
 
 def is_downloaded(model_name: str) -> bool:
@@ -51,22 +50,16 @@ def generate_text(
     pipeline: GroupedSamplingPipeLine,
     prompt: str,
     output_length: int,
-    web_search: bool,
 ) -> str:
     """
     Generates text using the given pipeline.
     :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
     :param prompt: The prompt to use. str.
     :param output_length: The size of the text to generate in tokens. int > 0.
-    :param web_search: Whether to use web search or not. bool.
     :return: The generated text. str.
     """
-    if web_search:
-        better_prompt = rewrite_prompt(prompt)
-    else:
-        better_prompt = prompt
     return pipeline(
-        prompt_s=better_prompt,
+        prompt_s=prompt,
         max_new_tokens=output_length,
         return_text=True,
         return_full_text=False,
@@ -77,14 +70,12 @@ def on_form_submit(
     model_name: str,
     output_length: int,
     prompt: str,
-    web_search: bool
 ) -> str:
     """
     Called when the user submits the form.
     :param model_name: The name of the model to use.
     :param output_length: The size of the groups to use.
     :param prompt: The prompt to use.
-    :param web_search: Whether to use web search or not.
     :return: The output of the model.
     :raises ValueError: If the model name is not supported, the output length is <= 0,
     the prompt is empty or longer than
@@ -111,7 +102,6 @@ def on_form_submit(
         pipeline=pipeline,
         prompt=prompt,
         output_length=output_length,
-        web_search=web_search,
     )
     generation_end_time = time()
     generation_time = generation_end_time - generation_start_time
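The net effect is that generate_text now forwards the user's prompt to the pipeline untouched. A minimal sketch of an equivalent direct call, using only the keyword arguments visible in the diff (how the pipeline object is constructed is not shown in this commit):

from grouped_sampling import GroupedSamplingPipeLine

def run_raw_prompt(pipeline: GroupedSamplingPipeLine, prompt: str, output_length: int) -> str:
    # Mirrors generate_text after this commit: there is no rewrite_prompt
    # step; the prompt string goes straight into the pipeline.
    return pipeline(
        prompt_s=prompt,
        max_new_tokens=output_length,
        return_text=True,
        return_full_text=False,
    )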
prompt_engeneering.py DELETED
@@ -1,86 +0,0 @@
-import os
-from dataclasses import dataclass
-from datetime import datetime
-from typing import Generator, Dict, List
-
-from googleapiclient.discovery import build
-from streamlit import secrets
-
-INSTRUCTIONS = "Instructions: " \
-               "Using the provided web search results, " \
-               "write a comprehensive reply to the given query. " \
-               "Make sure to cite results using [[number](URL)] notation after the reference. " \
-               "If the provided search results refer to multiple subjects with the same name, " \
-               "write separate answers for each subject."
-
-
-def get_google_api_key():
-    """Returns the Google API key from streamlit's secrets"""
-    try:
-        return secrets["google_search_api_key"]
-    except (FileNotFoundError, IsADirectoryError):
-        return os.environ["google_search_api_key"]
-
-
-def get_google_cse_id():
-    """Returns the Google CSE ID from streamlit's secrets"""
-    try:
-        return secrets["google_cse_id"]
-    except (FileNotFoundError, IsADirectoryError):
-        return os.environ["google_cse_id"]
-
-
-def google_search(search_term, **kwargs) -> list:
-    service = build("customsearch", "v1", developerKey=get_google_api_key())
-    search_engine = service.cse()
-    res = search_engine.list(q=search_term, cx=get_google_cse_id(), **kwargs).execute()
-    return res['items']
-
-
-@dataclass
-class SearchResult:
-    __slots__ = ["title", "body", "url"]
-    title: str
-    body: str
-    url: str
-
-
-def get_web_search_results(
-    query: str,
-    num_results: int,
-) -> Generator[SearchResult, None, None]:
-    """Gets a list of web search results using the Google search API"""
-    rew_results: List[Dict[str, str]] = google_search(
-        search_term=query,
-        num=num_results
-    )[:num_results]
-    for result in rew_results:
-        if result["snippet"].endswith("\xa0..."):
-            result["snippet"] = result["snippet"][:-4]
-        yield SearchResult(
-            title=result["title"],
-            body=result["snippet"],
-            url=result["link"],
-        )
-
-
-def format_search_result(search_result: Generator[SearchResult, None, None]) -> str:
-    """Formats a search result to be added to the prompt."""
-    ans = ""
-    for i, result in enumerate(search_result):
-        ans += f"[{i}] {result.body}\nURL: {result.url}\n\n"
-    return ans
-
-
-def rewrite_prompt(
-    prompt: str,
-) -> str:
-    """Rewrites the prompt by adding web search results to it."""
-    raw_results = get_web_search_results(
-        query=prompt,
-        num_results=5,
-    )
-    formatted_results = "Web search results:\n" + format_search_result(raw_results)
-    formatted_date = "Current date: " + datetime.now().strftime("%d/%m/%Y")
-    formatted_prompt = f"Query: {prompt}"
-    return "\n".join([formatted_results, formatted_date, INSTRUCTIONS, formatted_prompt])
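Since the deleted rewrite_prompt joined its four parts with newlines in a fixed order, its return value had roughly the shape below; the snippet and link lines are placeholders for illustration, and the date used the %d/%m/%Y format:

Web search results:
[0] first result snippet
URL: first result link

[1] second result snippet
URL: second result link

Current date: <dd/mm/yyyy>
Instructions: Using the provided web search results, write a comprehensive reply to the given query. Make sure to cite results using [[number](URL)] notation after the reference. If the provided search results refer to multiple subjects with the same name, write separate answers for each subject.
Query: <the user's prompt>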
tests.py CHANGED
@@ -5,27 +5,10 @@ from grouped_sampling import GroupedSamplingPipeLine, get_full_models_list, UnsupportedModelNameException
 
 from on_server_start import download_useful_models
 from hanlde_form_submit import create_pipeline, on_form_submit
-from prompt_engeneering import rewrite_prompt
 
 HUGGING_FACE_CACHE_DIR = "/home/yoni/.cache/huggingface/hub"
 
 
-def test_prompt_engineering():
-    example_prompt = "Answer yes or no, is the sky blue?"
-    rewritten_prompt = rewrite_prompt(example_prompt)
-    assert rewritten_prompt.startswith("Web search results:")
-    assert rewritten_prompt.endswith("Query: Answer yes or no, is the sky blue?")
-    assert "Current date: " in rewritten_prompt
-    assert "Instructions: " in rewritten_prompt
-
-
-def test_get_supported_model_names():
-    supported_model_names = get_full_models_list()
-    assert len(supported_model_names) > 0
-    assert "gpt2" in supported_model_names
-    assert all(isinstance(name, str) for name in supported_model_names)
-
-
 def test_on_server_start():
     download_useful_models()
     assert os.path.exists(HUGGING_FACE_CACHE_DIR)
@@ -36,15 +19,15 @@ def test_on_form_submit():
     model_name = "gpt2"
     output_length = 10
     prompt = "Answer yes or no, is the sky blue?"
-    output = on_form_submit(model_name, output_length, prompt
+    output = on_form_submit(model_name, output_length, prompt)
     assert output is not None
     assert len(output) > 0
     empty_prompt = ""
     with pytest.raises(ValueError):
-        on_form_submit(model_name, output_length, empty_prompt
+        on_form_submit(model_name, output_length, empty_prompt)
     unsupported_model_name = "unsupported_model_name"
     with pytest.raises(UnsupportedModelNameException):
-        on_form_submit(unsupported_model_name, output_length, prompt
+        on_form_submit(unsupported_model_name, output_length, prompt)
 
 
 @pytest.mark.parametrize(
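For reference, the updated three-argument call that the surviving test exercises, reusing the test's own values (the first run downloads gpt2, so it needs network access):

from hanlde_form_submit import on_form_submit

output = on_form_submit("gpt2", 10, "Answer yes or no, is the sky blue?")
assert output is not None and len(output) > 0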
user_instructions_hebrew.md CHANGED
@@ -2,16 +2,27 @@
 
 In this demo, you can easily use grouped sampling.
 
-
-Write clear instructions to the model and choose the length of the text you want the model to generate.
+## Step one - choose a model
 
-
+- The models in the demo are
 
-1.
-2.
-
-
-
+1. opt-iml-max-1.3B (the small model)
+2. opt-iml-max-30B (the large model)
+
+The large model produces better texts, but slowly.
+
+The small model produces slightly less good texts, but quickly.
+
+Write clear instructions to the algorithm and choose the length of the text you want the algorithm to generate.
+
+## Explain to the algorithm what to do
+
+### Tips for writing the input to the model
+
+1. Remember that this is only a demo; it cannot run the most powerful models, only models that are open to the public.
+2. Write in English, not in Hebrew. Most models do not support Hebrew at all, and those that do don't produce texts of good enough quality, so the demo does not support them.
+3. Put thought into the length of the text you want the algorithm to generate.
+4. Tell the algorithm exactly what to do - start with an instruction and separate the parts.
 
 For example:
 
@@ -21,7 +32,7 @@
 
 Answer: """
 
-
+5. Use templates with examples - for example:
 
 Instruction: """Label the following sentences to positive or negative"""
 
@@ -31,12 +42,11 @@
 
 Sentence: """your sentence here.""" Sentiment: """
 
-
-
-
-
-
-12. Start the task yourself
+6. Do not end the input to the algorithm with a space
+7. Help the algorithm - provide it with all the required information
+8. For complicated tasks, start the input to the algorithm with the sentence "Let's think step by step"
+9. Be specific
+10. Start the task yourself
 
 Example:
 