LLM-Brainstorming / tests /test_main.py
ping98k
Handle explain option for plaintext judges
295a884
import sys, os, types, json
from unittest.mock import patch, MagicMock
# Put the parent of this file's directory at the front of sys.path so the
# `import main` further down can be resolved (presumably main.py lives in the
# project root -- confirm against the repository layout).
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# Provide a dummy litellm module so that importing it succeeds.
# `setdefault` only registers the stub when no module named 'litellm' is
# already present in sys.modules; an already-imported module is left in place.
fake_litellm = types.ModuleType('litellm')
fake_litellm.completion = MagicMock()
sys.modules.setdefault('litellm', fake_litellm)
# Provide a dummy dotenv module whose load_dotenv is a mock.
fake_dotenv = types.ModuleType('dotenv')
fake_dotenv.load_dotenv = MagicMock()
sys.modules.setdefault('dotenv', fake_dotenv)
# Dummy gradio module so that importing it succeeds.
# Calling `Interface(...)` returns a mock object exposing a mocked `launch`.
# The widget names are bound to the MagicMock class itself, so instantiating
# one of them with any arguments simply produces a MagicMock.
fake_gradio = types.ModuleType('gradio')
fake_gradio.Interface = MagicMock(return_value=MagicMock(launch=MagicMock()))
fake_gradio.Textbox = MagicMock
fake_gradio.Number = MagicMock
fake_gradio.Checkbox = MagicMock
fake_gradio.Plot = MagicMock
sys.modules.setdefault('gradio', fake_gradio)
# Dummy tqdm module for write method
class FakeTqdmModule(types.ModuleType):
    """Module object named ``'tqdm'`` that can also be called like ``tqdm()``.

    Each instance carries a mocked ``write`` attribute. Calling the instance
    hands back the iterable it was given, unchanged.
    """

    def __init__(self):
        types.ModuleType.__init__(self, 'tqdm')
        self.write = MagicMock()

    def __call__(self, iterable=None, total=None):
        # No progress bar is drawn; ``total`` is accepted and ignored.
        wrapped = iterable
        return wrapped
# One shared stub instance. It is exposed as its own `tqdm` attribute, so when
# this stub is the module registered under 'tqdm', both `import tqdm` and
# `from tqdm import tqdm` yield the same callable object.
fake_tqdm_mod = FakeTqdmModule()
fake_tqdm_mod.tqdm = fake_tqdm_mod
# `setdefault` leaves an already-imported tqdm module untouched.
sys.modules.setdefault('tqdm', fake_tqdm_mod)
# Dummy matplotlib module
# `figure()` returns the sentinel string 'fig'; the other plotting calls are
# plain mocks.
fake_plt = types.ModuleType('matplotlib.pyplot')
fake_plt.figure = MagicMock(return_value='fig')
fake_plt.hist = MagicMock()
fake_plt.bar = MagicMock()
fake_plt.xticks = MagicMock()
fake_matplotlib = types.ModuleType('matplotlib')
fake_matplotlib.pyplot = fake_plt
# Register both the package and the submodule names; each is only installed
# if no module with that name is already present in sys.modules.
sys.modules.setdefault('matplotlib', fake_matplotlib)
sys.modules.setdefault('matplotlib.pyplot', fake_plt)
import main
class DummyFuture:
    """Synchronous stand-in for ``concurrent.futures.Future``.

    The wrapped callable is not run at construction time; it runs lazily on
    the first ``result()`` call. Its outcome (return value or raised
    exception) is then cached, so later ``result()`` calls replay that
    outcome instead of invoking the callable again.

    Args:
        func: Callable to invoke.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.
    """

    def __init__(self, func, *args, **kwargs):
        self._func = func
        self._args = args
        self._kwargs = kwargs
        self._done = False
        self._result = None
        self._exception = None

    def result(self, timeout=None):
        """Return the callable's result, running it on the first call only.

        Args:
            timeout: Accepted so callers may pass it as they would to a real
                future; ignored because the work runs inline.

        Returns:
            Whatever ``func(*args, **kwargs)`` returned.

        Raises:
            Exception: The exception raised by ``func``; the same instance is
                re-raised on every subsequent call.
        """
        if not self._done:
            try:
                self._result = self._func(*self._args, **self._kwargs)
            except Exception as exc:
                # Remember the failure so repeated calls do not re-run func.
                self._exception = exc
            self._done = True
        if self._exception is not None:
            raise self._exception
        return self._result
class DummyExecutor:
    """In-thread replacement for ``ThreadPoolExecutor`` used by these tests.

    Supports the context-manager protocol plus ``submit`` and ``map``; no
    threads are created.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are accepted and discarded.
        return None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Returns None, so exceptions raised in the ``with`` body propagate.
        return None

    def submit(self, func, *args):
        """Wrap ``func`` and its positional arguments in a ``DummyFuture``."""
        future = DummyFuture(func, *args)
        return future

    def map(self, func, iterable):
        """Lazily yield ``func(item)`` for each item, in input order."""
        # Still a generator: nothing runs until the caller iterates.
        yield from map(func, iterable)
class DummyTqdm:
    """Silent replacement for ``tqdm``: pass-through call and no-op ``write``."""

    def __call__(self, iterable=None, total=None):
        # ``total`` is accepted and ignored; the iterable is returned as-is.
        passthrough = iterable
        return passthrough

    def write(self, msg):
        # Discard the message.
        return None
def test_run_tournament_full_loop():
    """Run ``main.run_tournament`` with score and pairwise filters both enabled.

    Player generation, scoring and pairwise judging are mocked; the executor,
    ``as_completed``, ``tqdm`` and the plotting calls are replaced with
    synchronous or inert stand-ins. The final yielded tuple is checked for
    the log text, figures, top picks and usage summary.
    """
    scores = {'p1': 3, 'p2': 2, 'p3': 1, 'p4': 0}

    def fake_score(instr, cl, block, player, **kw):
        # Verdict encodes the fixed score for this player; fresh usage dict per call.
        return (
            f"Final verdict: [{scores[player]}]",
            {'prompt_tokens': 1, 'completion_tokens': 1},
        )

    def fake_pairwise(instr, block, a, b, **kw):
        # The first candidate always wins.
        return ("Final verdict: A", {'prompt_tokens': 1, 'completion_tokens': 1})

    tournament_args = dict(
        api_base='b',
        api_token='k',
        generate_model='gm',
        score_model='sm',
        pairwise_model='pm',
        generate_temperature=1,
        score_temperature=1,
        pairwise_temperature=1,
        instruction_input='instr',
        criteria_input='c1,c2',
        n_gen=4,
        pool_size=2,
        num_top_picks=1,
        max_workers=1,
        enable_score_filter=True,
        enable_pairwise_filter=True,
        score_with_instruction=True,
        pairwise_with_instruction=True,
        generate_thinking=True,
        score_thinking=True,
        pairwise_thinking=True,
    )
    progress_stub = DummyTqdm()

    with patch('main.generate_players') as gen_mock, \
            patch('main.prompt_score') as score_mock, \
            patch('main.prompt_pairwise') as pair_mock, \
            patch('main.ThreadPoolExecutor', return_value=DummyExecutor()), \
            patch('main.as_completed', new=lambda futs: futs), \
            patch('main.tqdm', new=progress_stub), \
            patch('main.plt.figure', return_value='fig'), \
            patch('main.plt.hist'), \
            patch('main.plt.bar'):
        gen_mock.return_value = (
            ['p1', 'p2', 'p3', 'p4'],
            {'prompt_tokens': 1, 'completion_tokens': 1},
        )
        score_mock.side_effect = fake_score
        pair_mock.side_effect = fake_pairwise

        # Exhaust the generator; only the last yielded state is inspected.
        yielded = list(main.run_tournament(**tournament_args))
        process_log, hist_fig, elo_fig, top_picks, usage = yielded[-1]

        assert 'Done' in process_log
        assert hist_fig == 'fig'
        assert elo_fig == 'fig'
        assert any(p in top_picks for p in {'p1', 'p2'})
        gen_mock.assert_called_once_with(
            'instr',
            4,
            model='gm',
            api_base='b',
            api_key='k',
            temperature=1,
            thinking=True,
            return_usage=True,
        )
        assert 'Score completion' in process_log
        assert 'Pairwise completion' in process_log
        assert 'Prompt tokens' in usage
        assert score_mock.call_count == 4
        assert pair_mock.called
def test_run_tournament_pairwise_odd_players():
    """Run ``main.run_tournament`` with three players and only the pairwise filter.

    The score filter is disabled and ``prompt_score`` is not patched. Player
    generation and pairwise judging are mocked, and the executor,
    ``as_completed``, ``tqdm`` and plotting calls are replaced with stand-ins.
    The test expects exactly three pairwise comparisons.
    """

    def fake_pairwise(instr, block, a, b, **kw):
        # The first candidate always wins; fresh usage dict per call.
        return ("Final verdict: A", {'prompt_tokens': 1, 'completion_tokens': 1})

    tournament_args = dict(
        api_base='b',
        api_token='k',
        generate_model='gm',
        score_model='sm',
        pairwise_model='pm',
        generate_temperature=1,
        score_temperature=1,
        pairwise_temperature=1,
        instruction_input='instr',
        criteria_input='c1,c2',
        n_gen=3,
        pool_size=3,
        num_top_picks=1,
        max_workers=1,
        enable_score_filter=False,
        enable_pairwise_filter=True,
        score_with_instruction=True,
        pairwise_with_instruction=True,
        generate_thinking=True,
        score_thinking=True,
        pairwise_thinking=True,
    )
    progress_stub = DummyTqdm()

    with patch('main.generate_players') as gen_mock, \
            patch('main.prompt_pairwise') as pair_mock, \
            patch('main.ThreadPoolExecutor', return_value=DummyExecutor()), \
            patch('main.as_completed', new=lambda futs: futs), \
            patch('main.tqdm', new=progress_stub), \
            patch('main.plt.figure', return_value='fig'), \
            patch('main.plt.hist'), \
            patch('main.plt.bar'):
        gen_mock.return_value = (
            ['p1', 'p2', 'p3'],
            {'prompt_tokens': 1, 'completion_tokens': 1},
        )
        pair_mock.side_effect = fake_pairwise

        # Exhaust the generator; only the last yielded state is inspected.
        yielded = list(main.run_tournament(**tournament_args))
        process_log, hist_fig, elo_fig, top_picks, usage = yielded[-1]

        assert 'Done' in process_log
        assert any(p in top_picks for p in {'p1', 'p2', 'p3'})
        assert pair_mock.call_count == 3