Alex Cabrera committed on
Commit
8ceca60
1 Parent(s): 14bac2f
Files changed (1) hide show
  1. config.py +19 -3
config.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  from __future__ import annotations
4
 
 
5
  from dataclasses import dataclass
6
 
7
  from zeno_build.evaluation.text_features.capitalization import input_capital_char_ratio
@@ -25,6 +26,8 @@ from zeno_build.evaluation.text_metrics.critique import (
25
  )
26
  from zeno_build.experiments import search_space
27
 
 
 
28
  lang_pairs: dict[str, list[str]] = {
29
  # All language pairs used in any experiment
30
  "all_lang_pairs": [
@@ -66,13 +69,17 @@ main_space = search_space.CombinatorialSearchSpace(
66
  "lang_pairs": search_space.Constant("all_lang_pairs"),
67
  "model_preset": search_space.Categorical(
68
  [
 
69
  "text-davinci-003-RR-1-shot",
70
  "text-davinci-003-RR-5-shot",
71
  "text-davinci-003-QR-1-shot",
72
  "text-davinci-003-QR-5-shot",
73
- "text-davinci-003-zeroshot",
74
- "wmt-best",
 
75
  "MS-Translator",
 
 
76
  ]
77
  ),
78
  }
@@ -87,6 +94,7 @@ class GptMtConfig:
87
  base_model: str
88
  prompt_strategy: str | None = None
89
  prompt_shots: int | None = None
 
90
 
91
 
92
  # The details of each model
@@ -106,8 +114,16 @@ model_configs = {
106
  "text-davinci-003-zeroshot": GptMtConfig(
107
  "text-davinci-003/zeroshot", "text-davinci-003", None, 0
108
  ),
109
- "wmt-best": GptMtConfig("wmt-best", "wmt-best"),
 
 
 
 
 
 
110
  "MS-Translator": GptMtConfig("MS-Translator", "MS-Translator"),
 
 
111
  }
112
 
113
  sweep_distill_functions = [chrf]
 
2
 
3
  from __future__ import annotations
4
 
5
+ from collections.abc import Callable
6
  from dataclasses import dataclass
7
 
8
  from zeno_build.evaluation.text_features.capitalization import input_capital_char_ratio
 
26
  )
27
  from zeno_build.experiments import search_space
28
 
29
+ from modeling import remove_leading_language
30
+
31
  lang_pairs: dict[str, list[str]] = {
32
  # All language pairs used in any experiment
33
  "all_lang_pairs": [
 
69
  "lang_pairs": search_space.Constant("all_lang_pairs"),
70
  "model_preset": search_space.Categorical(
71
  [
72
+ "text-davinci-003-zeroshot",
73
  "text-davinci-003-RR-1-shot",
74
  "text-davinci-003-RR-5-shot",
75
  "text-davinci-003-QR-1-shot",
76
  "text-davinci-003-QR-5-shot",
77
+ "gpt-3.5-turbo-0301-zeroshot",
78
+ "gpt-4-0314-zeroshot",
79
+ "gpt-4-0314-zeroshot-postprocess",
80
  "MS-Translator",
81
+ "google-cloud",
82
+ "wmt-best",
83
  ]
84
  ),
85
  }
 
94
  base_model: str
95
  prompt_strategy: str | None = None
96
  prompt_shots: int | None = None
97
+ post_processors: list[Callable[[str], str]] | None = None
98
 
99
 
100
  # The details of each model
 
114
  "text-davinci-003-zeroshot": GptMtConfig(
115
  "text-davinci-003/zeroshot", "text-davinci-003", None, 0
116
  ),
117
+ "gpt-3.5-turbo-0301-zeroshot": GptMtConfig(
118
+ "gpt-3.5-turbo-0301/zeroshot", "gpt-3.5-turbo-0301", None, 0
119
+ ),
120
+ "gpt-4-0314-zeroshot": GptMtConfig("gpt-4-0314/zeroshot", "gpt-4-0314", None, 0),
121
+ "gpt-4-0314-zeroshot-postprocess": GptMtConfig(
122
+ "gpt-4-0314/zeroshot", "gpt-4-0314", None, 0, [remove_leading_language]
123
+ ),
124
  "MS-Translator": GptMtConfig("MS-Translator", "MS-Translator"),
125
+ "google-cloud": GptMtConfig("google-cloud", "google-cloud"),
126
+ "wmt-best": GptMtConfig("wmt-best", "wmt-best"),
127
  }
128
 
129
  sweep_distill_functions = [chrf]