yonikremer committed on
Commit
a671856
1 Parent(s): 6e4f775

removed the option to search the web

Browse files
Files changed (5) hide show
  1. app.py +0 -7
  2. hanlde_form_submit.py +1 -11
  3. prompt_engeneering.py +0 -86
  4. tests.py +3 -20
  5. user_instructions_hebrew.md +25 -15
app.py CHANGED
@@ -39,12 +39,6 @@ with st.form("request_form"):
39
  max_chars=2048,
40
  )
41
 
42
- web_search: bool = st.checkbox(
43
- label="Web search",
44
- value=True,
45
- help="If checked, the model will get your prompt as well as some web search results."
46
- )
47
-
48
  submitted: bool = st.form_submit_button(
49
  label="Generate",
50
  help="Generate the output text.",
@@ -57,7 +51,6 @@ with st.form("request_form"):
57
  selected_model_name,
58
  output_length,
59
  submitted_prompt,
60
- web_search,
61
  )
62
  except CudaError as e:
63
  st.error("Out of memory. Please try a smaller model, shorter prompt, or a smaller output length.")
 
39
  max_chars=2048,
40
  )
41
 
 
 
 
 
 
 
42
  submitted: bool = st.form_submit_button(
43
  label="Generate",
44
  help="Generate the output text.",
 
51
  selected_model_name,
52
  output_length,
53
  submitted_prompt,
 
54
  )
55
  except CudaError as e:
56
  st.error("Out of memory. Please try a smaller model, shorter prompt, or a smaller output length.")
hanlde_form_submit.py CHANGED
@@ -5,7 +5,6 @@ import streamlit as st
5
  from grouped_sampling import GroupedSamplingPipeLine, is_supported, UnsupportedModelNameException
6
 
7
  from download_repo import download_pytorch_model
8
- from prompt_engeneering import rewrite_prompt
9
 
10
 
11
  def is_downloaded(model_name: str) -> bool:
@@ -51,22 +50,16 @@ def generate_text(
51
  pipeline: GroupedSamplingPipeLine,
52
  prompt: str,
53
  output_length: int,
54
- web_search: bool,
55
  ) -> str:
56
  """
57
  Generates text using the given pipeline.
58
  :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
59
  :param prompt: The prompt to use. str.
60
  :param output_length: The size of the text to generate in tokens. int > 0.
61
- :param web_search: Whether to use web search or not. bool.
62
  :return: The generated text. str.
63
  """
64
- if web_search:
65
- better_prompt = rewrite_prompt(prompt)
66
- else:
67
- better_prompt = prompt
68
  return pipeline(
69
- prompt_s=better_prompt,
70
  max_new_tokens=output_length,
71
  return_text=True,
72
  return_full_text=False,
@@ -77,14 +70,12 @@ def on_form_submit(
77
  model_name: str,
78
  output_length: int,
79
  prompt: str,
80
- web_search: bool
81
  ) -> str:
82
  """
83
  Called when the user submits the form.
84
  :param model_name: The name of the model to use.
85
  :param output_length: The size of the groups to use.
86
  :param prompt: The prompt to use.
87
- :param web_search: Whether to use web search or not.
88
  :return: The output of the model.
89
  :raises ValueError: If the model name is not supported, the output length is <= 0,
90
  the prompt is empty or longer than
@@ -111,7 +102,6 @@ def on_form_submit(
111
  pipeline=pipeline,
112
  prompt=prompt,
113
  output_length=output_length,
114
- web_search=web_search,
115
  )
116
  generation_end_time = time()
117
  generation_time = generation_end_time - generation_start_time
 
5
  from grouped_sampling import GroupedSamplingPipeLine, is_supported, UnsupportedModelNameException
6
 
7
  from download_repo import download_pytorch_model
 
8
 
9
 
10
  def is_downloaded(model_name: str) -> bool:
 
50
  pipeline: GroupedSamplingPipeLine,
51
  prompt: str,
52
  output_length: int,
 
53
  ) -> str:
54
  """
55
  Generates text using the given pipeline.
56
  :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
57
  :param prompt: The prompt to use. str.
58
  :param output_length: The size of the text to generate in tokens. int > 0.
 
59
  :return: The generated text. str.
60
  """
 
 
 
 
61
  return pipeline(
62
+ prompt_s=prompt,
63
  max_new_tokens=output_length,
64
  return_text=True,
65
  return_full_text=False,
 
70
  model_name: str,
71
  output_length: int,
72
  prompt: str,
 
73
  ) -> str:
74
  """
75
  Called when the user submits the form.
76
  :param model_name: The name of the model to use.
77
  :param output_length: The size of the groups to use.
78
  :param prompt: The prompt to use.
 
79
  :return: The output of the model.
80
  :raises ValueError: If the model name is not supported, the output length is <= 0,
81
  the prompt is empty or longer than
 
102
  pipeline=pipeline,
103
  prompt=prompt,
104
  output_length=output_length,
 
105
  )
106
  generation_end_time = time()
107
  generation_time = generation_end_time - generation_start_time
prompt_engeneering.py DELETED
@@ -1,86 +0,0 @@
1
- import os
2
- from dataclasses import dataclass
3
- from datetime import datetime
4
- from typing import Generator, Dict, List
5
-
6
- from googleapiclient.discovery import build
7
- from streamlit import secrets
8
-
9
- INSTRUCTIONS = "Instructions: " \
10
- "Using the provided web search results, " \
11
- "write a comprehensive reply to the given query. " \
12
- "Make sure to cite results using [[number](URL)] notation after the reference. " \
13
- "If the provided search results refer to multiple subjects with the same name, " \
14
- "write separate answers for each subject."
15
-
16
-
17
- def get_google_api_key():
18
- """Returns the Google API key from streamlit's secrets"""
19
- try:
20
- return secrets["google_search_api_key"]
21
- except (FileNotFoundError, IsADirectoryError):
22
- return os.environ["google_search_api_key"]
23
-
24
-
25
- def get_google_cse_id():
26
- """Returns the Google CSE ID from streamlit's secrets"""
27
- try:
28
- return secrets["google_cse_id"]
29
- except (FileNotFoundError, IsADirectoryError):
30
- return os.environ["google_cse_id"]
31
-
32
-
33
- def google_search(search_term, **kwargs) -> list:
34
- service = build("customsearch", "v1", developerKey=get_google_api_key())
35
- search_engine = service.cse()
36
- res = search_engine.list(q=search_term, cx=get_google_cse_id(), **kwargs).execute()
37
- return res['items']
38
-
39
-
40
- @dataclass
41
- class SearchResult:
42
- __slots__ = ["title", "body", "url"]
43
- title: str
44
- body: str
45
- url: str
46
-
47
-
48
- def get_web_search_results(
49
- query: str,
50
- num_results: int,
51
- ) -> Generator[SearchResult, None, None]:
52
- """Gets a list of web search results using the Google search API"""
53
- rew_results: List[Dict[str, str]] = google_search(
54
- search_term=query,
55
- num=num_results
56
- )[:num_results]
57
- for result in rew_results:
58
- if result["snippet"].endswith("\xa0..."):
59
- result["snippet"] = result["snippet"][:-4]
60
- yield SearchResult(
61
- title=result["title"],
62
- body=result["snippet"],
63
- url=result["link"],
64
- )
65
-
66
-
67
- def format_search_result(search_result: Generator[SearchResult, None, None]) -> str:
68
- """Formats a search result to be added to the prompt."""
69
- ans = ""
70
- for i, result in enumerate(search_result):
71
- ans += f"[{i}] {result.body}\nURL: {result.url}\n\n"
72
- return ans
73
-
74
-
75
- def rewrite_prompt(
76
- prompt: str,
77
- ) -> str:
78
- """Rewrites the prompt by adding web search results to it."""
79
- raw_results = get_web_search_results(
80
- query=prompt,
81
- num_results=5,
82
- )
83
- formatted_results = "Web search results:\n" + format_search_result(raw_results)
84
- formatted_date = "Current date: " + datetime.now().strftime("%d/%m/%Y")
85
- formatted_prompt = f"Query: {prompt}"
86
- return "\n".join([formatted_results, formatted_date, INSTRUCTIONS, formatted_prompt])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests.py CHANGED
@@ -5,27 +5,10 @@ from grouped_sampling import GroupedSamplingPipeLine, get_full_models_list, Unsu
5
 
6
  from on_server_start import download_useful_models
7
  from hanlde_form_submit import create_pipeline, on_form_submit
8
- from prompt_engeneering import rewrite_prompt
9
 
10
  HUGGING_FACE_CACHE_DIR = "/home/yoni/.cache/huggingface/hub"
11
 
12
 
13
- def test_prompt_engineering():
14
- example_prompt = "Answer yes or no, is the sky blue?"
15
- rewritten_prompt = rewrite_prompt(example_prompt)
16
- assert rewritten_prompt.startswith("Web search results:")
17
- assert rewritten_prompt.endswith("Query: Answer yes or no, is the sky blue?")
18
- assert "Current date: " in rewritten_prompt
19
- assert "Instructions: " in rewritten_prompt
20
-
21
-
22
- def test_get_supported_model_names():
23
- supported_model_names = get_full_models_list()
24
- assert len(supported_model_names) > 0
25
- assert "gpt2" in supported_model_names
26
- assert all(isinstance(name, str) for name in supported_model_names)
27
-
28
-
29
  def test_on_server_start():
30
  download_useful_models()
31
  assert os.path.exists(HUGGING_FACE_CACHE_DIR)
@@ -36,15 +19,15 @@ def test_on_form_submit():
36
  model_name = "gpt2"
37
  output_length = 10
38
  prompt = "Answer yes or no, is the sky blue?"
39
- output = on_form_submit(model_name, output_length, prompt, web_search=False)
40
  assert output is not None
41
  assert len(output) > 0
42
  empty_prompt = ""
43
  with pytest.raises(ValueError):
44
- on_form_submit(model_name, output_length, empty_prompt, web_search=False)
45
  unsupported_model_name = "unsupported_model_name"
46
  with pytest.raises(UnsupportedModelNameException):
47
- on_form_submit(unsupported_model_name, output_length, prompt, web_search=False)
48
 
49
 
50
  @pytest.mark.parametrize(
 
5
 
6
  from on_server_start import download_useful_models
7
  from hanlde_form_submit import create_pipeline, on_form_submit
 
8
 
9
  HUGGING_FACE_CACHE_DIR = "/home/yoni/.cache/huggingface/hub"
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def test_on_server_start():
13
  download_useful_models()
14
  assert os.path.exists(HUGGING_FACE_CACHE_DIR)
 
19
  model_name = "gpt2"
20
  output_length = 10
21
  prompt = "Answer yes or no, is the sky blue?"
22
+ output = on_form_submit(model_name, output_length, prompt)
23
  assert output is not None
24
  assert len(output) > 0
25
  empty_prompt = ""
26
  with pytest.raises(ValueError):
27
+ on_form_submit(model_name, output_length, empty_prompt)
28
  unsupported_model_name = "unsupported_model_name"
29
  with pytest.raises(UnsupportedModelNameException):
30
+ on_form_submit(unsupported_model_name, output_length, prompt)
31
 
32
 
33
  @pytest.mark.parametrize(
user_instructions_hebrew.md CHANGED
@@ -2,16 +2,27 @@
2
 
3
  讘讚诪讜 讛讝讛, 讗转诐 讬讻讜诇讬诐 诇讛砖转诪砖 讘拽诇讜转 讘讚讙讬诪讛 讘拽讘讜爪讜转.
4
 
5
- 转讘讞专讜 诪讜讚诇 砖讬砖 诇讜 诇驻讞讜转 诇讬讬拽 讗讞讚 诪转讜讱 [讛诪讗讙专](https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch&sort=downloads)
6
- 转讻转讘讜 诇诪讜讚诇 讛讜专讗讜转 讘专讜专讜转 讜转讘讞专讜 讗转 讗讜专讱 讛讟拽住讟 砖讗转诐 专讜爪讬诐 砖讛诪讜讚诇 讬爪讜专.
7
 
8
- ## 注爪讜转 诇讻转讬讘转 讛驻专讜诪驻讟
9
 
10
- 1. 转砖转诪砖讜 讘诪讜讚诇 讛诪转讗讬诐 - 讻诇 诪讜讚诇 谞讜爪专 诇诪砖讬诪讛 诪住讜讬诪转. 转砖讗诇讜 讗讜转讬 讗讬讝讛 诪讜讚诇 诇讘讞讜专 诇诪砖讬诪讛 砖诇讻诐
11
- 2. 转砖拽讬注讜 诪讞砖讘讛 讘讗讜专讱 讛讟拽住讟 砖讗转诐 专讜爪讬诐 砖讛讗诇讙讜专讬转诐 讬爪讜专
12
- 3. 转讝讻专讜 砖讝讛 专拽 讚诪讜, 讛讚诪讜 专抓 注诇 诪注讘讚 讙专驻讬 讗讞讚 讘诇讘讚 讜诇讻谉 诇讗 谞讬转谉 诇讛专讬抓 讘讚诪讜 讗转 讛诪讜讚诇讬诐 讛讞讝拽讬诐 讘讬讜转专 讗诇讗 专拽 诪讜讚诇讬诐 拽讟谞讬诐 讜驻讜诪讘讬讬诐
13
- 4. 转讻转讘讜 讘讗谞讙诇讬转 讜诇讗 讘注讘专讬转. 专讜讘 讛诪讜讚诇讬诐 诇讗 转讜诪讻讬诐 讘注讘专讬转 讻诇诇 讜讗诇讛 砖讻谉, 诪讙讬注讬诐 诇转讜爪讗讜转 讛专讘讛 讬讜转专 讟讜讘讜转 讘讗谞讙诇讬转
14
- 5. 转讙讬讚讜 诇讗诇讙讜专讬转诐 讘讚讬讜拽 诪讛 诇注砖讜转 - 讛转讞讬诇讜 讘讛讜专讗讛 讜讛驻专讬讚讜 讗转 讛讞诇拽讬诐
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  诇讚讜讙诪讛:
17
 
@@ -21,7 +32,7 @@
21
 
22
  Answer: """
23
 
24
- 6. 讛砖转诪砖讜 讘转讘谞讬讜转 注诐 讚讜讙诪讜转 - 诇讚讜讙诪讛:
25
 
26
  Instruction: """Label the following sentences to positive or negative"""
27
 
@@ -31,12 +42,11 @@
31
 
32
  Sentence: """your sentence here.""" Sentiment: """
33
 
34
- 7. 讗诇 转住讬讬诪讜 讗转 讛驻专讜诪驻讟 讘专讜讜讞
35
- 8. 转砖转诪砖讜 讘讞讬驻讜砖 讘讗讬谞讟专谞讟 专拽 讗诐 讛诪讬讚注 讘讗讬谞讟专谞讟 讛讜讗 拽专讬讟讬 诇讛爪诇讞转 讛诪砖讬诪讛 砖讛讟诇转诐 注诇 讛讗诇讙讜专讬转诐. 诪讬讚注 诇讗 专诇讜讜谞讟讬 讬讘诇讘诇 讗转 讛讗诇讙讜专讬转诐
36
- 9. 转注讝专讜 诇讗诇讙讜讬专转诐 - 转住驻拽讜 诇讗诇讙讜专讬转诐 讗转 讻诇 讛诪讬讚注 讛谞讚专砖
37
- 10. 讘诪砖讬诪讜转 诪住讜讘讻讜转, 讛转讞讬诇讜 讗转 讛驻专讜诪驻讟 讘诪砖驻讟 "Let's think step by step"
38
- 11. 转讛讬讜 住驻讬爪驻讬讬诐
39
- 12. 转转讞讬诇讜 讗转 讛诪砖讬诪讛 讘注爪诪讻诐
40
 
41
  讚讜讙诪讛:
42
 
 
2
 
3
  讘讚诪讜 讛讝讛, 讗转诐 讬讻讜诇讬诐 诇讛砖转诪砖 讘拽诇讜转 讘讚讙讬诪讛 讘拽讘讜爪讜转.
4
 
5
+ ## 砖诇讘 专讗砖讜谉 - 转讘讞专讜 诪讜讚诇
 
6
 
7
+ - 讛诪讜讚诇讬诐 讘讚诪讜 讛诐
8
 
9
+ 1. opt-iml-max-1.3B (讛诪讜讚诇 讛拽讟谉)
10
+ 2. opt-iml-max-30B (讛诪讜讚诇 讛讙讚讜诇)
11
+
12
+ 讛诪讜讚诇 讛讙讚讜诇 讬讜爪专 讟拽住讟讬诐 讬讜转专 讟讜讘讬诐, 讗讱 讘讗讬讟讬讜转.
13
+
14
+ 讛诪讜讚诇 讛拽讟谉 讬讜爪专 讟拽住讟讬诐 拽爪转 驻讞讜转 讟讜讘讬诐 讗讱 讘诪讛讬专讜转.
15
+
16
+ 转讻转讘讜 诇讗诇讙讜专讬转诐 讛讜专讗讜转 讘专讜专讜转 讜转讘讞专讜 讗转 讗讜专讱 讛讟拽住讟 砖讗转诐 专讜爪讬诐 砖讛讗诇讙讜专讬转诐 讬爪讜专.
17
+
18
+ ## 转住讘讬专讜 诇讗诇讙讜专讬转诐 诪讛 诇注砖讜转
19
+
20
+ ### 注爪讜转 诇讻转讬讘转 讛拽诇讟 诇诪讜讚诇
21
+
22
+ 1. 转讝讻专讜 砖讝讛 专拽 讚诪讜, 诇讗 谞讬转谉 诇讛专讬抓 讘讚诪讜 讗转 讛诪讜讚诇讬诐 讛讞讝拽讬诐 讘讬讜转专 讗诇讗 专拽 诪讜讚诇讬诐 驻转讜讞讬诐 诇爪讬讘讜专.
23
+ 2. 转讻转讘讜 讘讗谞讙诇讬转 讜诇讗 讘注讘专讬转. 专讜讘 讛诪讜讚诇讬诐 诇讗 转讜诪讻讬诐 讘注讘专讬转 讻诇诇 讜讗诇讛 砖讻谉, 诇讗 讬讜爪专讬诐 讟拽住讟讬诐 讘讗讬讻讜转 诪住驻讬拽 讟讜讘讛 讜诇讻谉 讛讚诪讜 诇讗 转讜诪讱 讘讛诐.
24
+ 3. 转砖拽讬注讜 诪讞砖讘讛 讘讗讜专讱 讛讟拽住讟 砖讗转诐 专讜爪讬诐 砖讛讗诇讙讜专讬转诐 讬爪讜专.
25
+ 4. 转讙讬讚讜 诇讗诇讙讜专讬转诐 讘讚讬讜拽 诪讛 诇注砖讜转 - 讛转讞讬诇讜 讘讛讜专讗讛 讜讛驻专讬讚讜 讗转 讛讞诇拽讬诐
26
 
27
  诇讚讜讙诪讛:
28
 
 
32
 
33
  Answer: """
34
 
35
+ 5. 讛砖转诪砖讜 讘转讘谞讬讜转 注诐 讚讜讙诪讜转 - 诇讚讜讙诪讛:
36
 
37
  Instruction: """Label the following sentences to positive or negative"""
38
 
 
42
 
43
  Sentence: """your sentence here.""" Sentiment: """
44
 
45
+ 6. 讗诇 转住讬讬诪讜 讗转 讛拽诇讟 诇讗诇讙讜专讬转诐 讘专讜讜讞
46
+ 7. 转注讝专讜 诇讗诇讙讜专讬转诐 - 转住驻拽讜 诇讗诇讙讜专讬转诐 讗转 讻诇 讛诪讬讚注 讛谞讚专砖
47
+ 8. 讘诪砖讬诪讜转 诪住讜讘讻讜转, 讛转讞讬诇讜 讗转 讛拽诇讟 诇讗诇讙讜专讬转诐 讘诪砖驻讟 "Let's think step by step"
48
+ 9. 转讛讬讜 住驻爪讬驻讬讬诐
49
+ 10. 转转讞讬诇讜 讗转 讛诪砖讬诪讛 讘注爪诪讻诐
 
50
 
51
  讚讜讙诪讛:
52