Marina Pliusnina committed
Commit c8bd9ca
1 Parent(s): 1823861

change generation parameters values

__pycache__/rag.cpython-311.pyc ADDED
Binary file (3.09 kB)
 
__pycache__/utils.cpython-311.pyc ADDED
Binary file (1.71 kB)
 
app.py CHANGED
@@ -7,7 +7,7 @@ from urllib.error import HTTPError
 from rag import RAG
 from utils import setup
 
-MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", default=100))
+MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", default=200))
 SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True") == "True"
 
 setup()
@@ -44,13 +44,13 @@ def submit_input(input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sa
 
 
     model_parameters = {
-        "max_new_tokens": max_new_tokens,
-        "repetition_penalty": repetition_penalty,
-        "top_k": top_k,
-        "top_p": top_p,
-        "do_sample": do_sample,
-        "num_beams": num_beams,
-        "temperature": temperature
+        "MAX_NEW_TOKENS": max_new_tokens,
+        "REPETITION_PENALTY": repetition_penalty,
+        "TOP_K": top_k,
+        "TOP_P": top_p,
+        "DO_SAMPLE": do_sample,
+        "NUM_BEAMS": num_beams,
+        "TEMPERATURE": temperature
     }
 
     output = generate(input_, model_parameters)
@@ -110,17 +110,17 @@ def gradio_app():
         with gr.Row(variant="panel"):
             with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
                 max_new_tokens = Slider(
-                    minimum=1,
-                    maximum=200,
+                    minimum=50,
+                    maximum=1000,
                     step=1,
                     value=MAX_NEW_TOKENS,
                     label="Max tokens"
                 )
                 repetition_penalty = Slider(
                     minimum=0.1,
-                    maximum=10,
+                    maximum=2.0,
                     step=0.1,
-                    value=1.2,
+                    value=1.0,
                     label="Repetition penalty"
                 )
                 top_k = Slider(
@@ -132,25 +132,25 @@ def gradio_app():
                 )
                 top_p = Slider(
                     minimum=0.01,
-                    maximum=0.99,
-                    value=0.95,
+                    maximum=1.0,
+                    value=1.0,
                     label="Top p"
                 )
                 do_sample = Checkbox(
-                    value=True,
+                    value=False,
                     label="Do sample"
                 )
                 num_beams = Slider(
                     minimum=1,
-                    maximum=8,
+                    maximum=4,
                     step=1,
-                    value=4,
+                    value=1,
                     label="Beams"
                 )
                 temperature = Slider(
-                    minimum=0,
+                    minimum=0.1,
                     maximum=1,
-                    value=0.5,
+                    value=0.35,
                     label="Temperature"
                 )
rag.py CHANGED
@@ -38,7 +38,7 @@ class RAG:
 
         return context
 
-    def predict(self, instruction, context):
+    def predict(self, instruction, context, model_parameters):
 
         api_key = os.getenv("HF_TOKEN")
 
@@ -55,18 +55,18 @@ class RAG:
 
         payload = {
             "inputs": query,
-            "parameters": {"MAX_NEW_TOKENS": 1000, "TEMPERATURE": 0.25}
+            "parameters": model_parameters
         }
 
         response = requests.post(self.model_name, headers=headers, json=payload)
 
         return response.json()[0]["generated_text"].split("###")[-1][8:-1]
 
-    def get_response(self, prompt: str) -> str:
+    def get_response(self, prompt: str, model_parameters: dict) -> str:
 
         context = self.get_context(prompt)
 
-        response = self.predict(prompt, context)
+        response = self.predict(prompt, context, model_parameters)
 
         if not response:
             return self.NO_ANSWER_MESSAGE
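On the rag.py side, predict() now forwards whatever dict it receives instead of the hard-coded {"MAX_NEW_TOKENS": 1000, "TEMPERATURE": 0.25}. A standalone sketch of the resulting request follows, with a placeholder endpoint URL; note, as a caveat, that the Hugging Face Inference API documents lowercase parameter names (max_new_tokens, temperature, ...), so the uppercase keys used in this repo depend on the serving side accepting them.

import os
import requests

# Placeholder endpoint; rag.py posts to self.model_name instead.
API_URL = "https://api-inference.huggingface.co/models/<model-id>"

def predict(query: str, model_parameters: dict) -> str:
    headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}
    payload = {"inputs": query, "parameters": model_parameters}
    response = requests.post(API_URL, headers=headers, json=payload)
    response.raise_for_status()
    # Same post-processing as rag.py: take the text after the last "###"
    # delimiter and trim a fixed-width prefix/suffix.
    return response.json()[0]["generated_text"].split("###")[-1][8:-1]

print(predict("What is RAG?", {"MAX_NEW_TOKENS": 200, "TEMPERATURE": 0.35}))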