yonikremer committed on
Commit
a671856
•
1 Parent(s): 6e4f775

removed the option to search the web

Files changed (5):
  1. app.py +0 -7
  2. hanlde_form_submit.py +1 -11
  3. prompt_engeneering.py +0 -86
  4. tests.py +3 -20
  5. user_instructions_hebrew.md +25 -15
app.py CHANGED
@@ -39,12 +39,6 @@ with st.form("request_form"):
         max_chars=2048,
     )
 
-    web_search: bool = st.checkbox(
-        label="Web search",
-        value=True,
-        help="If checked, the model will get your prompt as well as some web search results."
-    )
-
     submitted: bool = st.form_submit_button(
         label="Generate",
         help="Generate the output text.",
@@ -57,7 +51,6 @@ with st.form("request_form"):
             selected_model_name,
             output_length,
             submitted_prompt,
-            web_search,
         )
     except CudaError as e:
         st.error("Out of memory. Please try a smaller model, shorter prompt, or a smaller output length.")
hanlde_form_submit.py CHANGED
@@ -5,7 +5,6 @@ import streamlit as st
 from grouped_sampling import GroupedSamplingPipeLine, is_supported, UnsupportedModelNameException
 
 from download_repo import download_pytorch_model
-from prompt_engeneering import rewrite_prompt
 
 
 def is_downloaded(model_name: str) -> bool:
@@ -51,22 +50,16 @@ def generate_text(
     pipeline: GroupedSamplingPipeLine,
     prompt: str,
     output_length: int,
-    web_search: bool,
 ) -> str:
     """
     Generates text using the given pipeline.
     :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
     :param prompt: The prompt to use. str.
     :param output_length: The size of the text to generate in tokens. int > 0.
-    :param web_search: Whether to use web search or not. bool.
     :return: The generated text. str.
     """
-    if web_search:
-        better_prompt = rewrite_prompt(prompt)
-    else:
-        better_prompt = prompt
     return pipeline(
-        prompt_s=better_prompt,
+        prompt_s=prompt,
         max_new_tokens=output_length,
         return_text=True,
         return_full_text=False,
@@ -77,14 +70,12 @@ def on_form_submit(
     model_name: str,
     output_length: int,
     prompt: str,
-    web_search: bool
 ) -> str:
     """
     Called when the user submits the form.
     :param model_name: The name of the model to use.
     :param output_length: The size of the groups to use.
     :param prompt: The prompt to use.
-    :param web_search: Whether to use web search or not.
     :return: The output of the model.
     :raises ValueError: If the model name is not supported, the output length is <= 0,
     the prompt is empty or longer than
@@ -111,7 +102,6 @@ def on_form_submit(
         pipeline=pipeline,
         prompt=prompt,
         output_length=output_length,
-        web_search=web_search,
     )
     generation_end_time = time()
     generation_time = generation_end_time - generation_start_time
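After this change, `on_form_submit` takes exactly three arguments and the prompt reaches the pipeline unchanged. A minimal usage sketch, using the same values the test suite uses:

```python
from hanlde_form_submit import on_form_submit

# The web_search flag is gone; no prompt rewriting happens anymore.
output = on_form_submit(
    model_name="gpt2",
    output_length=10,  # tokens to generate, must be > 0
    prompt="Answer yes or no, is the sky blue?",
)
print(output)
```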
prompt_engeneering.py DELETED
@@ -1,86 +0,0 @@
-import os
-from dataclasses import dataclass
-from datetime import datetime
-from typing import Generator, Dict, List
-
-from googleapiclient.discovery import build
-from streamlit import secrets
-
-INSTRUCTIONS = "Instructions: " \
-               "Using the provided web search results, " \
-               "write a comprehensive reply to the given query. " \
-               "Make sure to cite results using [[number](URL)] notation after the reference. " \
-               "If the provided search results refer to multiple subjects with the same name, " \
-               "write separate answers for each subject."
-
-
-def get_google_api_key():
-    """Returns the Google API key from streamlit's secrets"""
-    try:
-        return secrets["google_search_api_key"]
-    except (FileNotFoundError, IsADirectoryError):
-        return os.environ["google_search_api_key"]
-
-
-def get_google_cse_id():
-    """Returns the Google CSE ID from streamlit's secrets"""
-    try:
-        return secrets["google_cse_id"]
-    except (FileNotFoundError, IsADirectoryError):
-        return os.environ["google_cse_id"]
-
-
-def google_search(search_term, **kwargs) -> list:
-    service = build("customsearch", "v1", developerKey=get_google_api_key())
-    search_engine = service.cse()
-    res = search_engine.list(q=search_term, cx=get_google_cse_id(), **kwargs).execute()
-    return res['items']
-
-
-@dataclass
-class SearchResult:
-    __slots__ = ["title", "body", "url"]
-    title: str
-    body: str
-    url: str
-
-
-def get_web_search_results(
-    query: str,
-    num_results: int,
-) -> Generator[SearchResult, None, None]:
-    """Gets a list of web search results using the Google search API"""
-    rew_results: List[Dict[str, str]] = google_search(
-        search_term=query,
-        num=num_results
-    )[:num_results]
-    for result in rew_results:
-        if result["snippet"].endswith("\xa0..."):
-            result["snippet"] = result["snippet"][:-4]
-        yield SearchResult(
-            title=result["title"],
-            body=result["snippet"],
-            url=result["link"],
-        )
-
-
-def format_search_result(search_result: Generator[SearchResult, None, None]) -> str:
-    """Formats a search result to be added to the prompt."""
-    ans = ""
-    for i, result in enumerate(search_result):
-        ans += f"[{i}] {result.body}\nURL: {result.url}\n\n"
-    return ans
-
-
-def rewrite_prompt(
-    prompt: str,
-) -> str:
-    """Rewrites the prompt by adding web search results to it."""
-    raw_results = get_web_search_results(
-        query=prompt,
-        num_results=5,
-    )
-    formatted_results = "Web search results:\n" + format_search_result(raw_results)
-    formatted_date = "Current date: " + datetime.now().strftime("%d/%m/%Y")
-    formatted_prompt = f"Query: {prompt}"
-    return "\n".join([formatted_results, formatted_date, INSTRUCTIONS, formatted_prompt])
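For reference, the deleted `rewrite_prompt` prepended Google search context to the user's query. A prompt it built had roughly this shape (reconstructed from the removed code; the snippet text and URL below are made up):

```text
Web search results:
[0] The sky appears blue because molecules in the air scatter blue light...
URL: https://example.com/sky-color

Current date: 01/01/2023
Instructions: Using the provided web search results, write a comprehensive reply to the given query. Make sure to cite results using [[number](URL)] notation after the reference. If the provided search results refer to multiple subjects with the same name, write separate answers for each subject.
Query: Answer yes or no, is the sky blue?
```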
tests.py CHANGED
@@ -5,27 +5,10 @@ from grouped_sampling import GroupedSamplingPipeLine, get_full_models_list, Unsu
 
 from on_server_start import download_useful_models
 from hanlde_form_submit import create_pipeline, on_form_submit
-from prompt_engeneering import rewrite_prompt
 
 HUGGING_FACE_CACHE_DIR = "/home/yoni/.cache/huggingface/hub"
 
 
-def test_prompt_engineering():
-    example_prompt = "Answer yes or no, is the sky blue?"
-    rewritten_prompt = rewrite_prompt(example_prompt)
-    assert rewritten_prompt.startswith("Web search results:")
-    assert rewritten_prompt.endswith("Query: Answer yes or no, is the sky blue?")
-    assert "Current date: " in rewritten_prompt
-    assert "Instructions: " in rewritten_prompt
-
-
-def test_get_supported_model_names():
-    supported_model_names = get_full_models_list()
-    assert len(supported_model_names) > 0
-    assert "gpt2" in supported_model_names
-    assert all(isinstance(name, str) for name in supported_model_names)
-
-
 def test_on_server_start():
     download_useful_models()
     assert os.path.exists(HUGGING_FACE_CACHE_DIR)
@@ -36,15 +19,15 @@ def test_on_form_submit():
     model_name = "gpt2"
     output_length = 10
     prompt = "Answer yes or no, is the sky blue?"
-    output = on_form_submit(model_name, output_length, prompt, web_search=False)
+    output = on_form_submit(model_name, output_length, prompt)
     assert output is not None
     assert len(output) > 0
     empty_prompt = ""
     with pytest.raises(ValueError):
-        on_form_submit(model_name, output_length, empty_prompt, web_search=False)
+        on_form_submit(model_name, output_length, empty_prompt)
     unsupported_model_name = "unsupported_model_name"
     with pytest.raises(UnsupportedModelNameException):
-        on_form_submit(unsupported_model_name, output_length, prompt, web_search=False)
+        on_form_submit(unsupported_model_name, output_length, prompt)
 
 
 @pytest.mark.parametrize(
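To check the new signature locally, the affected test can be run on its own (assuming pytest is installed and the environment can load the pipeline):

```bash
pytest tests.py -k test_on_form_submit
```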
user_instructions_hebrew.md CHANGED
@@ -2,16 +2,27 @@
 
 In this demo, you can easily use grouped sampling.
 
-Choose a model that has at least one like from [the model hub](https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch&sort=downloads)
-Write clear instructions for the model and choose the length of the text you want the model to generate.
-
-## Tips for writing the prompt
-
-1. Use the right model - every model was built for a specific task. Ask me which model to choose for your task
-2. Put thought into the length of the text you want the algorithm to generate
-3. Remember that this is only a demo. It runs on a single GPU, so it cannot run the most powerful models, only small public ones
-4. Write in English, not Hebrew. Most models do not support Hebrew at all, and those that do get much better results in English
-5. Tell the algorithm exactly what to do - start with an instruction and separate the parts
+## Step one - choose a model
+
+The models in this demo are:
+
+1. opt-iml-max-1.3B (the small model)
+2. opt-iml-max-30B (the large model)
+
+The large model generates better texts, but slowly.
+
+The small model generates slightly worse texts, but quickly.
+
+Write clear instructions for the algorithm and choose the length of the text you want it to generate.
+
+## Explain to the algorithm what to do
+
+### Tips for writing the model's input
+
+1. Remember that this is only a demo. It cannot run the most powerful models, only publicly available ones.
+2. Write in English, not Hebrew. Most models do not support Hebrew at all, and those that do generate texts of too low quality, so the demo does not support them.
+3. Put thought into the length of the text you want the algorithm to generate.
+4. Tell the algorithm exactly what to do - start with an instruction and separate the parts
 
 For example:
 
@@ -21,7 +32,7 @@
 
     Answer: """
 
-6. Use templates with examples - for example:
+5. Use templates with examples - for example:
 
     Instruction: """Label the following sentences to positive or negative"""
 
@@ -31,12 +42,11 @@
 
     Sentence: """your sentence here.""" Sentiment: """
 
-7. Do not end the prompt with a space
-8. Use web search only when online information is critical to the task you give the algorithm. Irrelevant information will confuse the algorithm
-9. Help the algorithm - give it all the information it needs
-10. For complicated tasks, start the prompt with the sentence "Let's think step by step"
-11. Be specific
-12. Start the task yourself
+6. Do not end the algorithm's input with a space
+7. Help the algorithm - give it all the information it needs
+8. For complicated tasks, start the algorithm's input with the sentence "Let's think step by step"
+9. Be specific
+10. Start the task yourself
 
 Example: