Jae-Won Chung committed
Commit: e38f79f
1 Parent(s): 6393815

Do not standardize system prompts

spitfight/colosseum/controller/controller.py CHANGED
@@ -44,9 +44,10 @@ class RequestState(BaseModel):
     This model is also serialized as is and logged.
     """
     request_id: str
-    prompt: str
     model_names: list[str]
-    responses: list[str] = ["EMPTY", "EMPTY"]
+    raw_prompt: str
+    responses: list[str] = ["UNSET", "UNSET"]
+    model_prompts: list[str] = ["UNSET", "UNSET"]
     energy_consumptions: list[float] = [0.0, 0.0]
     response_victory_index: Optional[Literal[0, 1]] = None
     extra_energy_was_worth: Optional[bool] = None
@@ -172,7 +173,7 @@ class Controller:
         model_names = [worker.model_name for worker in workers]
         self.request_states[request_id] = RequestState(
             request_id=request_id,
-            prompt=prompt,
+            raw_prompt=prompt,
             model_names=model_names,
         )
         request_state = self.request_states[request_id]
@@ -185,11 +186,13 @@ class Controller:
         except RuntimeError:
             controller_logger.error("Worker %s is dead.", model_name)
             raise
+
+        # Models have different prompt formatting requirements and stopping criteria.
         prompt, stop_str, stop_token_ids = apply_model_characteristics(
-            system_prompt=get_system_prompt("chat"),
             prompt=prompt,
            model_name=worker.model_id,
         )
+        request_state.model_prompts[model_index] = prompt
 
         # Request the model worker to stream the response to the user's prompt.
         response = ""
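Taken together, the controller now records both the user's raw prompt and the formatted prompt each model actually consumed, which can differ per model. A minimal, self-contained sketch of the reworked state (field names and defaults come from the diff above; pydantic is implied by BaseModel there; the formatted prompt string is hypothetical):

# A sketch, not the repository's full code.
from typing import Literal, Optional
from pydantic import BaseModel

class RequestState(BaseModel):
    request_id: str
    model_names: list[str]
    raw_prompt: str                                # exactly what the user typed
    responses: list[str] = ["UNSET", "UNSET"]
    model_prompts: list[str] = ["UNSET", "UNSET"]  # per-model formatted prompts
    energy_consumptions: list[float] = [0.0, 0.0]
    response_victory_index: Optional[Literal[0, 1]] = None
    extra_energy_was_worth: Optional[bool] = None

# Pydantic copies mutable defaults per instance, so the list literals are safe.
state = RequestState(request_id="r0", raw_prompt="Hi!", model_names=["a", "b"])
state.model_prompts[0] = "<s>[INST] Hi! [/INST]"   # hypothetical Llama-style formatting

Since this model is serialized as is and logged, storing model_prompts means the logs capture exactly what each model saw, not just the raw user input.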
spitfight/prompt.py CHANGED
@@ -45,14 +45,15 @@ def get_system_prompt(task: Task | str) -> str:
 
 
 def apply_model_characteristics(
-    system_prompt: str,
     prompt: str,
     model_name: str,
+    system_prompt: str | None = None,
 ) -> tuple[str, str | None, list[int]]:
     """Apply and return model-specific differences."""
     conv = get_conversation_template(model_name)
 
-    conv.system_message = system_prompt
+    if system_prompt is not None:
+        conv.system_message = system_prompt
     conv.messages = []
     conv.offset = 0
 
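With system_prompt now an optional parameter, callers decide whether to standardize: by default, each model keeps whatever system message its own conversation template (apparently FastChat's, going by get_conversation_template) ships with. A sketch of both call styles after this change; the prompt text and model name are illustrative, and get_system_prompt("chat") is the call the controller previously hard-coded:

# Default path: leave the template's stock system message untouched.
prompt, stop_str, stop_token_ids = apply_model_characteristics(
    prompt="Explain GPU energy measurement.",
    model_name="vicuna-13b",  # illustrative model name
)

# Opt-in path: standardize the system prompt explicitly, as the
# controller used to do for the "chat" task.
prompt, stop_str, stop_token_ids = apply_model_characteristics(
    prompt="Explain GPU energy measurement.",
    model_name="vicuna-13b",
    system_prompt=get_system_prompt("chat"),
)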