|
from injector import Injector |
|
|
|
from taskweaver.config.config_mgt import AppConfigSource |
|
from taskweaver.logging import LoggingModule |
|
from taskweaver.memory import RoundCompressor |
|
|
|
|
|
def test_round_compressor():
    """Exercise RoundCompressor on conversations of one to four rounds.

    The compressor is resolved through the injector with both
    ``round_compressor.rounds_to_compress`` and
    ``round_compressor.rounds_to_retain`` set to 2. For every history
    length from 1 to 4 the test expects ``compress_rounds`` to return the
    summary ``"None"`` and to hand back as many rounds as were passed in.
    """
    from taskweaver.memory import Post, Round

    settings = {
        "llm.api_key": "test_key",
        "round_compressor.rounds_to_compress": 2,
        "round_compressor.rounds_to_retain": 2,
    }
    container = Injector([LoggingModule])
    config_source = AppConfigSource(config=settings)
    container.binder.bind(AppConfigSource, to=config_source)
    compressor = container.get(RoundCompressor)

    # The configured values must be visible on the resolved compressor.
    assert compressor.rounds_to_compress == 2
    assert compressor.rounds_to_retain == 2

    # The same two posts (User -> Planner, Planner -> User) are attached to
    # every round.
    exchange = [
        Post.create(
            message="hello",
            send_from=sender,
            send_to=receiver,
            attachment_list=[],
        )
        for sender, receiver in (("User", "Planner"), ("Planner", "User"))
    ]

    history = []
    round_ids = ("round-1", "round-2", "round-3", "round-4")
    for expected_count, round_id in enumerate(round_ids, start=1):
        current = Round.create(user_query="hello", id=round_id)
        for post in exchange:
            current.add_post(post)
        history.append(current)

        # Pass a fresh list on each call so the compressor never receives
        # the list this loop keeps appending to.
        summary, retained = compressor.compress_rounds(
            list(history),
            lambda x: x,
            use_back_up_engine=False,
        )
        assert summary == "None"
        assert len(retained) == expected_count
|
|