Spaces:
Sleeping
Sleeping
File size: 6,344 Bytes
249284d e6b4ba0 249284d e6b4ba0 3bd9ad6 249284d 295a884 249284d 02aebba 249284d 02aebba 3404ee0 249284d e6b4ba0 249284d e6b4ba0 b936324 6d50a75 3bd9ad6 249284d 2b53c20 e6b4ba0 2b53c20 295a884 2b53c20 02aebba 2b53c20 02aebba 3404ee0 2b53c20 e6b4ba0 2b53c20 b936324 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import sys, os, types, json
from unittest.mock import patch, MagicMock
# Put the parent of this file's directory (the project root) first on the
# import path so that the later ``import main`` resolves.
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)),
)

# Stand-in for litellm: only ``completion`` is provided, as a mock.
fake_litellm = types.ModuleType('litellm')
vars(fake_litellm).update(completion=MagicMock())

# Stand-in for dotenv: ``load_dotenv`` is a mock, so calling it does nothing.
fake_dotenv = types.ModuleType('dotenv')
vars(fake_dotenv).update(load_dotenv=MagicMock())

# Stand-in for gradio: ``Interface(...)`` returns a mock whose ``launch`` is
# itself a mock; every widget name is bound to the MagicMock class.
fake_gradio = types.ModuleType('gradio')
vars(fake_gradio).update(
    Interface=MagicMock(return_value=MagicMock(launch=MagicMock())),
)
vars(fake_gradio).update(
    dict.fromkeys(('Textbox', 'Number', 'Checkbox', 'Plot'), MagicMock)
)

# ``setdefault`` registers a stub only when nothing is already loaded under
# that name, so an already-imported real module is left in place.
sys.modules.setdefault('litellm', fake_litellm)
sys.modules.setdefault('dotenv', fake_dotenv)
sys.modules.setdefault('gradio', fake_gradio)
# Dummy tqdm module: callable like ``tqdm(...)`` and exposing a ``write`` mock
class FakeTqdmModule(types.ModuleType):
    """Module object named ``tqdm`` that can also be called like ``tqdm()``.

    Calling the instance hands back the iterable untouched, and ``write`` is a
    ``MagicMock`` so messages are recorded instead of printed.
    """

    def __init__(self):
        super().__init__('tqdm')
        self.write = MagicMock()

    def __call__(self, iterable=None, total=None, **kwargs):
        """Return *iterable* unchanged.

        ``total`` and any further keyword options (for example ``desc``) are
        accepted and ignored so that calls written for the real progress bar
        do not fail against this stub.
        """
        return iterable


fake_tqdm_mod = FakeTqdmModule()
# Expose the module as its own ``tqdm`` attribute so that both
# ``import tqdm`` and ``from tqdm import tqdm`` yield the same callable.
fake_tqdm_mod.tqdm = fake_tqdm_mod
# Registered only if no module is already loaded under the name 'tqdm'.
sys.modules.setdefault('tqdm', fake_tqdm_mod)
# Stand-in for matplotlib: ``pyplot.figure()`` returns the string 'fig' and
# the drawing helpers are independent mocks.
fake_plt = types.ModuleType('matplotlib.pyplot')
vars(fake_plt).update(
    figure=MagicMock(return_value='fig'),
    hist=MagicMock(),
    bar=MagicMock(),
    xticks=MagicMock(),
)

# Parent package exposing the stub as its ``pyplot`` attribute.
fake_matplotlib = types.ModuleType('matplotlib')
vars(fake_matplotlib).update(pyplot=fake_plt)

# Register both names; entries already present in sys.modules are kept.
sys.modules.setdefault(fake_matplotlib.__name__, fake_matplotlib)
sys.modules.setdefault(fake_plt.__name__, fake_plt)
import main
class DummyFuture:
    """Synchronous stand-in for ``concurrent.futures.Future``.

    The wrapped call is deferred until ``result()`` is first requested. Its
    outcome (return value or raised exception) is then remembered, so asking
    again does not re-run the callable.
    """

    def __init__(self, func, *args, **kwargs):
        self._func = func
        self._args = args
        self._kwargs = kwargs
        self._done = False
        self._result = None
        self._exception = None

    def result(self, timeout=None):
        """Return the value of the wrapped call, running it on first use.

        ``timeout`` is accepted for signature compatibility with the real
        future and ignored: the call runs inline in the current thread.
        An exception raised by the call is re-raised on every request.
        """
        if not self._done:
            try:
                self._result = self._func(*self._args, **self._kwargs)
            except Exception as exc:
                # Stored rather than swallowed: it is re-raised just below.
                self._exception = exc
            self._done = True
        if self._exception is not None:
            raise self._exception
        return self._result
class DummyExecutor:
    """Single-threaded stand-in for ``ThreadPoolExecutor``.

    Constructor arguments are accepted and ignored. Work handed to
    ``submit`` or ``map`` runs inline in the calling thread.
    """

    def __init__(self, *args, **kwargs):
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Explicitly falsy: exceptions raised inside the ``with`` body propagate.
        return False

    def submit(self, func, *args, **kwargs):
        """Wrap the call in a ``DummyFuture``; nothing runs until ``result()``.

        Keyword arguments are bound to the callable up front so that the
        future only ever has to forward positional arguments.
        """
        call = (lambda *a: func(*a, **kwargs)) if kwargs else func
        return DummyFuture(call, *args)

    def map(self, func, *iterables, timeout=None, chunksize=1):
        """Lazily yield ``func(*items)`` for items taken in lockstep.

        ``timeout`` and ``chunksize`` exist only for signature compatibility
        with ``Executor.map`` and are ignored.
        """
        for items in zip(*iterables):
            yield func(*items)
class DummyTqdm:
    """Silent replacement for the ``tqdm`` callable used when patching ``main.tqdm``."""

    def __call__(self, iterable=None, total=None, **kwargs):
        """Return *iterable* unchanged.

        ``total`` and any other keyword options are accepted and ignored.
        """
        return iterable

    def write(self, msg):
        """Discard *msg*; nothing is printed."""
        pass
def test_run_tournament_full_loop():
    """Drive ``main.run_tournament`` with both score and pairwise filters on.

    Generation, scoring and pairwise prompting are mocked: four players
    'p1'..'p4' are produced, scored 3, 2, 1 and 0 respectively, and every
    pairwise comparison answers "Final verdict: A". The thread pool,
    ``as_completed``, tqdm and the plotting calls are replaced with inline or
    inert stand-ins so the run is synchronous.
    """
    dummy_tqdm = DummyTqdm()
    with patch('main.generate_players') as mock_gen, \
            patch('main.prompt_score') as mock_score, \
            patch('main.prompt_pairwise') as mock_pair, \
            patch('main.ThreadPoolExecutor', return_value=DummyExecutor()), \
            patch('main.as_completed', new=lambda futs: futs), \
            patch('main.tqdm', new=dummy_tqdm), \
            patch('main.plt.figure', return_value='fig'), \
            patch('main.plt.hist'), \
            patch('main.plt.bar'):
        mock_gen.return_value = (
            ['p1', 'p2', 'p3', 'p4'],
            {'prompt_tokens': 1, 'completion_tokens': 1},
        )
        scores = {'p1': 3, 'p2': 2, 'p3': 1, 'p4': 0}
        mock_score.side_effect = lambda instr, cl, block, player, **kw: (
            f"Final verdict: [{scores[player]}]",
            {'prompt_tokens': 1, 'completion_tokens': 1},
        )
        mock_pair.side_effect = lambda instr, block, a, b, **kw: (
            "Final verdict: A",
            {'prompt_tokens': 1, 'completion_tokens': 1},
        )

        # run_tournament is consumed as an iterable of progress tuples.
        results = list(main.run_tournament(
            api_base='b',
            api_token='k',
            generate_model='gm',
            score_model='sm',
            pairwise_model='pm',
            generate_temperature=1,
            score_temperature=1,
            pairwise_temperature=1,
            instruction_input='instr',
            criteria_input='c1,c2',
            n_gen=4,
            pool_size=2,
            num_top_picks=1,
            max_workers=1,
            enable_score_filter=True,
            enable_pairwise_filter=True,
            score_with_instruction=True,
            pairwise_with_instruction=True,
            generate_thinking=True,
            score_thinking=True,
            pairwise_thinking=True,
        ))

        # Fail with a clear message rather than an IndexError on results[-1].
        assert results, 'run_tournament yielded no updates'
        process_log, hist_fig, elo_fig, top_picks, usage = results[-1]

        assert 'Done' in process_log
        assert hist_fig == 'fig'
        assert elo_fig == 'fig'
        assert any(p in top_picks for p in {'p1', 'p2'})
        mock_gen.assert_called_once_with(
            'instr', 4, model='gm', api_base='b', api_key='k',
            temperature=1, thinking=True, return_usage=True,
        )
        assert 'Score completion' in process_log
        assert 'Pairwise completion' in process_log
        assert 'Prompt tokens' in usage
        # One scoring call per generated player.
        assert mock_score.call_count == 4
        assert mock_pair.called
def test_run_tournament_pairwise_odd_players():
    """Drive ``main.run_tournament`` with three players and pairwise filtering only.

    The score filter is disabled and ``pool_size`` equals the number of
    generated players (3). Every mocked pairwise comparison answers
    "Final verdict: A". The thread pool, ``as_completed``, tqdm and plotting
    calls are replaced with inline or inert stand-ins.
    """
    dummy_tqdm = DummyTqdm()
    with patch('main.generate_players') as mock_gen, \
            patch('main.prompt_pairwise') as mock_pair, \
            patch('main.ThreadPoolExecutor', return_value=DummyExecutor()), \
            patch('main.as_completed', new=lambda futs: futs), \
            patch('main.tqdm', new=dummy_tqdm), \
            patch('main.plt.figure', return_value='fig'), \
            patch('main.plt.hist'), \
            patch('main.plt.bar'):
        mock_gen.return_value = (
            ['p1', 'p2', 'p3'],
            {'prompt_tokens': 1, 'completion_tokens': 1},
        )
        mock_pair.side_effect = lambda instr, block, a, b, **kw: (
            "Final verdict: A",
            {'prompt_tokens': 1, 'completion_tokens': 1},
        )

        # run_tournament is consumed as an iterable of progress tuples.
        results = list(main.run_tournament(
            api_base='b',
            api_token='k',
            generate_model='gm',
            score_model='sm',
            pairwise_model='pm',
            generate_temperature=1,
            score_temperature=1,
            pairwise_temperature=1,
            instruction_input='instr',
            criteria_input='c1,c2',
            n_gen=3,
            pool_size=3,
            num_top_picks=1,
            max_workers=1,
            enable_score_filter=False,
            enable_pairwise_filter=True,
            score_with_instruction=True,
            pairwise_with_instruction=True,
            generate_thinking=True,
            score_thinking=True,
            pairwise_thinking=True,
        ))

        # Fail with a clear message rather than an IndexError on results[-1].
        assert results, 'run_tournament yielded no updates'
        process_log, hist_fig, elo_fig, top_picks, usage = results[-1]

        assert 'Done' in process_log
        assert any(p in top_picks for p in {'p1', 'p2', 'p3'})
        assert mock_pair.call_count == 3
|