#include "c/engine.h"
#include <algorithm>
#include <cstring>
#include <filesystem>
#include <memory>
#include <string>
#include <variant>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/status_matchers.h" // from @com_google_absl
#include "absl/synchronization/notification.h" // from @com_google_absl
#include "nlohmann/json.hpp" // from @nlohmann_json
#include "runtime/conversation/conversation.h"
#include "runtime/conversation/io_types.h"
#include "runtime/engine/engine_settings.h"
#include "runtime/executor/executor_settings_base.h"
#include "runtime/executor/llm_executor_settings.h"
struct LiteRtLmEngineSettings {
std::unique_ptr<litert::lm::EngineSettings> settings;
};
struct LiteRtLmSessionConfig {
std::unique_ptr<litert::lm::SessionConfig> config;
};
struct LiteRtLmConversationConfig {
std::unique_ptr<litert::lm::ConversationConfig> config;
};
namespace {
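// Returns the path of a test data file under the test source directory.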
std::string GetTestdataPath(const std::string& filename) {
std::string srcdir = ::testing::SrcDir();
// On Windows, SrcDir() may return paths with backslashes. The LiteRT LM C API
// expects forward slashes.
std::replace(srcdir.begin(), srcdir.end(), '\\', '/');
return srcdir + "/" + filename;
}
// Use unique_ptr for automatic resource management of C API objects.
using EngineSettingsPtr =
std::unique_ptr<LiteRtLmEngineSettings,
decltype(&litert_lm_engine_settings_delete)>;
using EnginePtr =
std::unique_ptr<LiteRtLmEngine, decltype(&litert_lm_engine_delete)>;
using SessionPtr =
std::unique_ptr<LiteRtLmSession, decltype(&litert_lm_session_delete)>;
using ResponsesPtr =
std::unique_ptr<LiteRtLmResponses, decltype(&litert_lm_responses_delete)>;
using ConversationPtr =
std::unique_ptr<LiteRtLmConversation,
decltype(&litert_lm_conversation_delete)>;
using JsonResponsePtr =
std::unique_ptr<LiteRtLmJsonResponse,
decltype(&litert_lm_json_response_delete)>;
using SessionConfigPtr =
std::unique_ptr<LiteRtLmSessionConfig,
decltype(&litert_lm_session_config_delete)>;
using ConversationConfigPtr =
std::unique_ptr<LiteRtLmConversationConfig,
decltype(&litert_lm_conversation_config_delete)>;
TEST(EngineCTest, CreateSettingsWithNoVisionAndAudioBackend) {
const std::string task_path = "test_model_path_1";
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
EXPECT_FALSE(settings->settings->GetVisionExecutorSettings().has_value());
EXPECT_FALSE(settings->settings->GetAudioExecutorSettings().has_value());
}
TEST(EngineCTest, CreateSettingsWithVisionAndAudioBackend) {
const std::string task_path = "test_model_path_1";
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ "gpu",
/* audio_backend_str */ "cpu"),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
EXPECT_TRUE(settings->settings->GetVisionExecutorSettings().has_value());
EXPECT_TRUE(settings->settings->GetAudioExecutorSettings().has_value());
EXPECT_EQ(settings->settings->GetVisionExecutorSettings()->GetBackend(),
litert::lm::Backend::GPU);
EXPECT_EQ(settings->settings->GetAudioExecutorSettings()->GetBackend(),
litert::lm::Backend::CPU);
}
TEST(EngineCTest, CreateSettingsWithInvalidVisionBackend) {
const std::string task_path = "test_model_path_1";
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ "dummy_backend",
/* audio_backend_str */ "cpu"),
&litert_lm_engine_settings_delete);
ASSERT_EQ(settings, nullptr);
}
TEST(EngineCTest, SetCacheDir) {
const std::string task_path = "test_model_path_1";
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
const std::string cache_dir = "test_cache_dir";
litert_lm_engine_settings_set_cache_dir(settings.get(), cache_dir.c_str());
EXPECT_EQ(settings->settings->GetMainExecutorSettings().GetCacheDir(),
cache_dir);
}
TEST(EngineCTest, SetPrefillChunkSize) {
const std::string task_path = "test_model_path_1";
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
int prefill_chunk_size = 128;
litert_lm_engine_settings_set_prefill_chunk_size(settings.get(),
prefill_chunk_size);
auto config = settings->settings->GetMainExecutorSettings()
.GetBackendConfig<litert::lm::CpuConfig>();
ASSERT_TRUE(config.ok());
EXPECT_EQ(config->prefill_chunk_size, prefill_chunk_size);
}
TEST(EngineCTest, BenchmarkSettings) {
const std::string task_path = "test_model_path_1";
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_enable_benchmark(settings.get());
litert_lm_engine_settings_set_num_prefill_tokens(settings.get(), 100);
litert_lm_engine_settings_set_num_decode_tokens(settings.get(), 200);
const auto& params = settings->settings->GetBenchmarkParams();
EXPECT_EQ(params->num_prefill_tokens(), 100);
EXPECT_EQ(params->num_decode_tokens(), 200);
}
TEST(EngineCTest, CreateSessionConfigWithSamplerParams) {
LiteRtLmSamplerParams sampler_params;
sampler_params.type = kTopP;
sampler_params.top_k = 10;
sampler_params.top_p = 0.5f;
sampler_params.temperature = 0.1f;
sampler_params.seed = 1234;
SessionConfigPtr config(litert_lm_session_config_create(),
&litert_lm_session_config_delete);
ASSERT_NE(config, nullptr);
litert_lm_session_config_set_sampler_params(config.get(), &sampler_params);
const auto& params = config->config->GetSamplerParams();
EXPECT_EQ(params.k(), 10);
EXPECT_FLOAT_EQ(params.p(), 0.5f);
EXPECT_FLOAT_EQ(params.temperature(), 0.1f);
EXPECT_EQ(params.seed(), 1234);
}
TEST(EngineCTest, CreateSessionConfigWithNoSamplerParams) {
SessionConfigPtr config(litert_lm_session_config_create(),
&litert_lm_session_config_delete);
ASSERT_NE(config, nullptr);
// Verify that the default sampler parameters are used.
const auto& params = config->config->GetSamplerParams();
EXPECT_EQ(params.type(),
litert::lm::proto::SamplerParameters::TYPE_UNSPECIFIED);
}
TEST(EngineCTest, CreateConversationConfig) {
// 1. Create an engine.
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm_new_metadata.task");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
// 2. Create Sampler Params.
LiteRtLmSamplerParams sampler_params;
sampler_params.type = kTopP;
sampler_params.top_k = 10;
sampler_params.top_p = 0.5f;
sampler_params.temperature = 0.1f;
sampler_params.seed = 1234;
SessionConfigPtr session_config(litert_lm_session_config_create(),
&litert_lm_session_config_delete);
ASSERT_NE(session_config, nullptr);
litert_lm_session_config_set_sampler_params(session_config.get(),
&sampler_params);
// 3. Create a Conversation Config with the Engine Handle, Session Config
// and System Message.
const std::string system_message =
R"({"type":"text","text":"You are a helpful assistant."})";
ConversationConfigPtr conversation_config(
litert_lm_conversation_config_create(
engine.get(), session_config.get(), system_message.c_str(),
/*tools_json=*/nullptr, /*messages_json=*/nullptr,
/*enable_constrained_decoding=*/false),
&litert_lm_conversation_config_delete);
ASSERT_NE(conversation_config, nullptr);
// 4. Test to see if the Conversation Config has the Sampler Params.
const auto& params =
conversation_config->config->GetSessionConfig().GetSamplerParams();
EXPECT_EQ(params.k(), 10);
EXPECT_FLOAT_EQ(params.p(), 0.5f);
EXPECT_FLOAT_EQ(params.temperature(), 0.1f);
EXPECT_EQ(params.seed(), 1234);
// 5. Test to see if the Conversation Config has the correct System Message.
const auto& preface = std::get<litert::lm::JsonPreface>(
conversation_config->config->GetPreface());
nlohmann::ordered_json message;
message["role"] = "system";
message["content"] = nlohmann::ordered_json::parse(system_message);
nlohmann::ordered_json expected_messages =
nlohmann::ordered_json::array({message});
EXPECT_EQ(preface.messages, expected_messages);
}
TEST(EngineCTest, CreateConversationConfigWithNoSamplerParams) {
// 1. Create an engine.
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm_new_metadata.task");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
// 2. Create a Conversation Config with the Engine Handle and System Message.
const std::string system_message =
R"({"type":"text","text":"You are a helpful assistant."})";
SessionConfigPtr session_config(litert_lm_session_config_create(),
&litert_lm_session_config_delete);
ASSERT_NE(session_config, nullptr);
ConversationConfigPtr conversation_config(
litert_lm_conversation_config_create(
engine.get(), session_config.get(), system_message.c_str(),
/*tools_json=*/nullptr, /*messages_json=*/nullptr,
/*enable_constrained_decoding=*/false),
&litert_lm_conversation_config_delete);
ASSERT_NE(conversation_config, nullptr);
// 3. Test to see if the Conversation Config has the correct System Message.
const auto& preface = std::get<litert::lm::JsonPreface>(
conversation_config->config->GetPreface());
nlohmann::ordered_json message;
message["role"] = "system";
message["content"] = nlohmann::ordered_json::parse(system_message);
nlohmann::ordered_json expected_messages =
nlohmann::ordered_json::array({message});
EXPECT_EQ(preface.messages, expected_messages);
}
TEST(EngineCTest, CreateConversationConfigWithNoSamplerParamsNoSystemMessage) {
// 1. Create an engine.
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm_new_metadata.task");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
// 2. Create a Conversation Config with the Engine Handle but no System
// Message.
SessionConfigPtr session_config(litert_lm_session_config_create(),
&litert_lm_session_config_delete);
ASSERT_NE(session_config, nullptr);
ConversationConfigPtr conversation_config(
litert_lm_conversation_config_create(
engine.get(),
/*session_config=*/session_config.get(),
/*system_message_json=*/nullptr,
/*tools_json=*/nullptr,
/*messages_json=*/nullptr,
/*enable_constrained_decoding=*/false),
&litert_lm_conversation_config_delete);
ASSERT_NE(conversation_config, nullptr);
// 3. Test to see if the Conversation Config has no System Message.
const auto& preface = std::get<litert::lm::JsonPreface>(
conversation_config->config->GetPreface());
EXPECT_EQ(preface.messages, nullptr);
}
TEST(EngineCTest, CreateConversationConfigWithTools) {
// 1. Create an engine.
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm_new_metadata.task");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
// 2. Create a Conversation Config with tools.
const std::string tools_json = R"([
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
},
"required": ["location"]
}
}
}
])";
ConversationConfigPtr conversation_config(
litert_lm_conversation_config_create(
engine.get(), /*session_config=*/nullptr,
/*system_message_json=*/nullptr, tools_json.c_str(),
/*messages_json=*/nullptr,
/*enable_constrained_decoding=*/false),
&litert_lm_conversation_config_delete);
ASSERT_NE(conversation_config, nullptr);
// 3. Test to see if the Conversation Config has the correct tools.
const auto& preface = std::get<litert::lm::JsonPreface>(
conversation_config->config->GetPreface());
EXPECT_EQ(preface.tools, nlohmann::ordered_json::parse(tools_json));
}
TEST(EngineCTest, CreateConversationConfigWithInvalidTools) {
// 1. Create an engine.
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm_new_metadata.task");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
// 2. Create a Conversation Config with an invalid tools json.
const std::string tools_json = R"({"type": "function"})"; // Not an array
ConversationConfigPtr conversation_config(
litert_lm_conversation_config_create(
engine.get(), /*session_config=*/nullptr,
/*system_message_json=*/nullptr, tools_json.c_str(),
/*messages_json=*/nullptr,
/*enable_constrained_decoding=*/false),
&litert_lm_conversation_config_delete);
ASSERT_NE(conversation_config, nullptr);
// 3. Test to see if the Conversation Config has no tools.
const auto& preface = std::get<litert::lm::JsonPreface>(
conversation_config->config->GetPreface());
EXPECT_TRUE(preface.tools.is_null());
}
TEST(EngineCTest, CreateConversationConfigWithEmptyToolsArray) {
// 1. Create an engine.
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm_new_metadata.task");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
// 2. Create a Conversation Config with an empty tools array.
const std::string tools_json = R"([])";
ConversationConfigPtr conversation_config(
litert_lm_conversation_config_create(
engine.get(), /*session_config=*/nullptr,
/*system_message_json=*/nullptr, tools_json.c_str(),
/*messages_json=*/nullptr,
/*enable_constrained_decoding=*/false),
&litert_lm_conversation_config_delete);
ASSERT_NE(conversation_config, nullptr);
// 3. Test to see if the Conversation Config has empty tools.
const auto& preface = std::get<litert::lm::JsonPreface>(
conversation_config->config->GetPreface());
EXPECT_TRUE(preface.tools.is_array());
EXPECT_TRUE(preface.tools.empty());
}
TEST(EngineCTest, CreateConversationConfigWithMalformedToolsJson) {
// 1. Create an engine.
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm_new_metadata.task");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
// 2. Create a Conversation Config with malformed tools json.
const std::string tools_json = R"([{"type": "function", ...}])";
ConversationConfigPtr conversation_config(
litert_lm_conversation_config_create(
engine.get(), /*session_config=*/nullptr,
/*system_message_json=*/nullptr, tools_json.c_str(),
/*messages_json=*/nullptr,
/*enable_constrained_decoding=*/false),
&litert_lm_conversation_config_delete);
ASSERT_NE(conversation_config, nullptr);
// 3. Test to see if the Conversation Config has no tools.
const auto& preface = std::get<litert::lm::JsonPreface>(
conversation_config->config->GetPreface());
EXPECT_TRUE(preface.tools.is_null());
}
TEST(EngineCTest, CreateConversationConfigWithNoSystemMessage) {
// 1. Create an engine.
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm_new_metadata.task");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
// 2. Create Sampler Params.
LiteRtLmSamplerParams sampler_params;
sampler_params.type = kTopP;
sampler_params.top_k = 10;
sampler_params.top_p = 0.5f;
sampler_params.temperature = 0.1f;
sampler_params.seed = 1234;
SessionConfigPtr session_config(litert_lm_session_config_create(),
&litert_lm_session_config_delete);
ASSERT_NE(session_config, nullptr);
litert_lm_session_config_set_sampler_params(session_config.get(),
&sampler_params);
// 3. Create a Conversation Config with the Engine Handle and Session Config.
ConversationConfigPtr conversation_config(
litert_lm_conversation_config_create(
engine.get(), session_config.get(), /*system_message_json=*/nullptr,
/*tools_json=*/nullptr, /*messages_json=*/nullptr,
/*enable_constrained_decoding=*/false),
&litert_lm_conversation_config_delete);
ASSERT_NE(conversation_config, nullptr);
// 4. Test to see if the Conversation Config has the provided Sampler Params.
const auto& params =
conversation_config->config->GetSessionConfig().GetSamplerParams();
EXPECT_EQ(params.k(), 10);
EXPECT_FLOAT_EQ(params.p(), 0.5f);
EXPECT_FLOAT_EQ(params.temperature(), 0.1f);
EXPECT_EQ(params.seed(), 1234);
// 5. Test to see if the Conversation Config has no System Message.
const auto& preface = std::get<litert::lm::JsonPreface>(
conversation_config->config->GetPreface());
EXPECT_EQ(preface.messages, nullptr);
}
TEST(EngineCTest, GenerateContent) {
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm_new_metadata.task");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
SessionPtr session(litert_lm_engine_create_session(
engine.get(), /* session_config */ nullptr),
&litert_lm_session_delete);
ASSERT_NE(session, nullptr);
const char* prompt = "Hello world!";
InputData input_data;
input_data.type = kInputText;
input_data.data = prompt;
input_data.size = strlen(prompt);
ResponsesPtr responses(
litert_lm_session_generate_content(session.get(), &input_data, 1),
&litert_lm_responses_delete);
ASSERT_NE(responses, nullptr);
EXPECT_EQ(litert_lm_responses_get_num_candidates(responses.get()), 1);
const char* response_text =
litert_lm_responses_get_response_text_at(responses.get(), 0);
ASSERT_NE(response_text, nullptr);
EXPECT_GT(strlen(response_text), 0);
}
TEST(EngineCTest, CreateSessionWithMaxOutputTokens) {
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm_new_metadata.task");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
// Test with max_output_tokens=1. The response length should be short (<10).
{
SessionConfigPtr session_config(litert_lm_session_config_create(),
&litert_lm_session_config_delete);
ASSERT_NE(session_config, nullptr);
litert_lm_session_config_set_max_output_tokens(session_config.get(), 1);
SessionPtr session(
litert_lm_engine_create_session(engine.get(), session_config.get()),
&litert_lm_session_delete);
ASSERT_NE(session, nullptr);
const char* prompt = "Hello world!";
InputData input_data;
input_data.type = kInputText;
input_data.data = prompt;
input_data.size = strlen(prompt);
ResponsesPtr responses(
litert_lm_session_generate_content(session.get(), &input_data, 1),
&litert_lm_responses_delete);
ASSERT_NE(responses, nullptr);
EXPECT_EQ(litert_lm_responses_get_num_candidates(responses.get()), 1);
const char* response_text =
litert_lm_responses_get_response_text_at(responses.get(), 0);
ASSERT_NE(response_text, nullptr);
EXPECT_GT(strlen(response_text), 0);
EXPECT_LT(strlen(response_text), 10);
}
// Test without max_output_tokens. The response length should be long (>10).
{
SessionConfigPtr session_config(litert_lm_session_config_create(),
&litert_lm_session_config_delete);
ASSERT_NE(session_config, nullptr);
SessionPtr session(
litert_lm_engine_create_session(engine.get(), session_config.get()),
&litert_lm_session_delete);
ASSERT_NE(session, nullptr);
const char* prompt = "Hello world!";
InputData input_data;
input_data.type = kInputText;
input_data.data = prompt;
input_data.size = strlen(prompt);
ResponsesPtr responses(
litert_lm_session_generate_content(session.get(), &input_data, 1),
&litert_lm_responses_delete);
ASSERT_NE(responses, nullptr);
EXPECT_EQ(litert_lm_responses_get_num_candidates(responses.get()), 1);
const char* response_text =
litert_lm_responses_get_response_text_at(responses.get(), 0);
ASSERT_NE(response_text, nullptr);
EXPECT_GT(strlen(response_text), 10);
}
}
TEST(EngineCTest, ConversationSendMessage) {
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm_new_metadata.task");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
ConversationPtr conversation(
litert_lm_conversation_create(engine.get(),
/*conversation_config=*/nullptr),
&litert_lm_conversation_delete);
ASSERT_NE(conversation, nullptr);
const char* message_json =
R"({"role": "user", "content": [{"type": "text", "text": "Hello"}]})";
JsonResponsePtr response(
litert_lm_conversation_send_message(conversation.get(), message_json,
/*extra_context=*/nullptr),
&litert_lm_json_response_delete);
ASSERT_NE(response, nullptr);
const char* response_str = litert_lm_json_response_get_string(response.get());
ASSERT_NE(response_str, nullptr);
EXPECT_GT(strlen(response_str), 0);
}
TEST(EngineCTest, ConversationSendMessageWithConfig) {
// 1. Create an engine.
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm.litertlm");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
// 2. Create Sampler Params.
LiteRtLmSamplerParams sampler_params;
sampler_params.type = kTopP;
sampler_params.top_k = 10;
sampler_params.top_p = 0.5f;
sampler_params.temperature = 0.1f;
sampler_params.seed = 1234;
SessionConfigPtr session_config(litert_lm_session_config_create(),
&litert_lm_session_config_delete);
ASSERT_NE(session_config, nullptr);
litert_lm_session_config_set_sampler_params(session_config.get(),
&sampler_params);
// 3. Create a Conversation Config with the Engine Handle, Session Config
// and System Message.
const std::string system_message =
R"({"type":"text","text":"You are a helpful assistant."})";
ConversationConfigPtr conversation_config(
litert_lm_conversation_config_create(
engine.get(), session_config.get(), system_message.c_str(),
/*tools_json=*/nullptr, /*messages_json=*/nullptr,
/*enable_constrained_decoding=*/false),
&litert_lm_conversation_config_delete);
ASSERT_NE(conversation_config, nullptr);
// 4. Create a Conversation with the Conversation Config.
ConversationPtr conversation(
litert_lm_conversation_create(engine.get(), conversation_config.get()),
&litert_lm_conversation_delete);
ASSERT_NE(conversation, nullptr);
// 5. Send a message to the conversation.
const char* message_json =
R"({"role": "user", "content": [{"type": "text", "text": "Hello"}]})";
JsonResponsePtr response(
litert_lm_conversation_send_message(conversation.get(), message_json,
/*extra_context=*/nullptr),
&litert_lm_json_response_delete);
ASSERT_NE(response, nullptr);
const char* response_str = litert_lm_json_response_get_string(response.get());
ASSERT_NE(response_str, nullptr);
EXPECT_GT(strlen(response_str), 0);
}
TEST(EngineCTest, ConversationSendMessageWithExtraContext) {
// 1. Create an engine.
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm.litertlm");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
// 2. Create a Conversation Config.
ConversationConfigPtr conversation_config(
litert_lm_conversation_config_create(
engine.get(), /*session_config=*/nullptr,
/*system_message_json=*/nullptr, /*tools_json=*/nullptr,
/*messages_json=*/nullptr, /*enable_constrained_decoding=*/false),
&litert_lm_conversation_config_delete);
ASSERT_NE(conversation_config, nullptr);
// 3. Create a Conversation with the Conversation Config.
ConversationPtr conversation(
litert_lm_conversation_create(engine.get(), conversation_config.get()),
&litert_lm_conversation_delete);
ASSERT_NE(conversation, nullptr);
// 4. Send a message to the conversation with extra context.
const char* message_json =
R"({"role": "user", "content": [{"type": "text", "text": "Hello"}]})";
const char* extra_context = R"({"key": "value"})";
JsonResponsePtr response(
litert_lm_conversation_send_message(conversation.get(), message_json,
/*extra_context=*/extra_context),
&litert_lm_json_response_delete);
ASSERT_NE(response, nullptr);
const char* response_str = litert_lm_json_response_get_string(response.get());
ASSERT_NE(response_str, nullptr);
EXPECT_GT(strlen(response_str), 0);
}
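// Collects streaming output for the *Stream tests below: chunks are appended
// to `response`, any error is recorded in `status`, and `done` is notified
// once the final chunk arrives.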
struct StreamCallbackData {
std::string response;
absl::Notification done;
absl::Status status;
};
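// Callback passed to the streaming C APIs; forwards each chunk into the
// StreamCallbackData owned by the test.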
void StreamCallback(void* callback_data, const char* chunk, bool is_final,
const char* error_msg) {
auto* data = static_cast<StreamCallbackData*>(callback_data);
if (error_msg) {
data->status = absl::InternalError(error_msg);
}
if (chunk) {
data->response.append(chunk);
}
if (is_final) {
data->done.Notify();
}
}
TEST(EngineCTest, GenerateContentStream) {
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm_new_metadata.task");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
SessionPtr session(litert_lm_engine_create_session(
engine.get(), /* session_config */ nullptr),
&litert_lm_session_delete);
ASSERT_NE(session, nullptr);
const char* prompt = "Hello world!";
InputData input_data;
input_data.type = kInputText;
input_data.data = prompt;
input_data.size = strlen(prompt);
StreamCallbackData callback_data;
int result = litert_lm_session_generate_content_stream(
session.get(), &input_data, 1, &StreamCallback, &callback_data);
ASSERT_EQ(result, 0);
callback_data.done.WaitForNotification();
// This model is very small and generates random output, so the result may be
// either success or failure when the maximum number of tokens is reached.
EXPECT_THAT(
callback_data.status,
testing::AnyOf(absl_testing::IsOk(),
absl_testing::StatusIs(
absl::StatusCode::kInternal,
testing::HasSubstr("Max number of tokens reached."))));
EXPECT_GT(callback_data.response.length(), 0);
}
TEST(EngineCTest, ConversationSendMessageStream) {
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm_new_metadata.task");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
ConversationPtr conversation(
litert_lm_conversation_create(engine.get(),
/*conversation_config=*/nullptr),
&litert_lm_conversation_delete);
ASSERT_NE(conversation, nullptr);
const char* message_json =
R"({"role": "user", "content": [{"type": "text", "text": "Hello"}]})";
StreamCallbackData callback_data;
int result = litert_lm_conversation_send_message_stream(
conversation.get(), message_json, /*extra_context=*/nullptr,
&StreamCallback, &callback_data);
ASSERT_EQ(result, 0);
callback_data.done.WaitForNotification();
EXPECT_GT(callback_data.response.length(), 0);
}
TEST(EngineCTest, ConversationSendMessageStreamWithExtraContext) {
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm_new_metadata.task");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
ConversationPtr conversation(
litert_lm_conversation_create(engine.get(),
/*conversation_config=*/nullptr),
&litert_lm_conversation_delete);
ASSERT_NE(conversation, nullptr);
const char* message_json =
R"({"role": "user", "content": [{"type": "text", "text": "Hello"}]})";
const char* extra_context = R"({"key": "value"})";
StreamCallbackData callback_data;
int result = litert_lm_conversation_send_message_stream(
conversation.get(), message_json, /*extra_context=*/extra_context,
&StreamCallback, &callback_data);
ASSERT_EQ(result, 0);
callback_data.done.WaitForNotification();
EXPECT_GT(callback_data.response.length(), 0);
}
TEST(EngineCTest, ConversationSendMessageStreamAndCancel) {
const std::string task_path = GetTestdataPath(
"litert_lm/runtime/testdata/test_lm_new_metadata.task");
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 512);
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
ConversationPtr conversation(
litert_lm_conversation_create(engine.get(),
/*conversation_config=*/nullptr),
&litert_lm_conversation_delete);
ASSERT_NE(conversation, nullptr);
const char* message_json =
R"({"role": "user", "content": [{"type": "text", "text": "Hello"}]})";
StreamCallbackData callback_data;
int result = litert_lm_conversation_send_message_stream(
conversation.get(), message_json, /*extra_context=*/nullptr,
&StreamCallback, &callback_data);
ASSERT_EQ(result, 0);
litert_lm_conversation_cancel_process(conversation.get());
callback_data.done.WaitForNotification();
EXPECT_THAT(callback_data.status,
absl_testing::StatusIs(absl::StatusCode::kInternal,
testing::HasSubstr("CANCELLED")));
}
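// Smart pointer alias for automatic cleanup of benchmark info handles.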
using BenchmarkInfoPtr =
std::unique_ptr<LiteRtLmBenchmarkInfo,
decltype(&litert_lm_benchmark_info_delete)>;
TEST(EngineCTest, Benchmark) {
auto task_path =
std::filesystem::path(::testing::SrcDir()) /
"litert_lm/runtime/testdata/test_lm_new_metadata.task";
EngineSettingsPtr settings(
litert_lm_engine_settings_create(task_path.string().c_str(), "cpu",
/* vision_backend_str */ nullptr,
/* audio_backend_str */ nullptr),
&litert_lm_engine_settings_delete);
ASSERT_NE(settings, nullptr);
litert_lm_engine_settings_set_max_num_tokens(settings.get(), 16);
litert_lm_engine_settings_enable_benchmark(settings.get());
EnginePtr engine(litert_lm_engine_create(settings.get()),
&litert_lm_engine_delete);
ASSERT_NE(engine, nullptr);
SessionPtr session(litert_lm_engine_create_session(
engine.get(), /* session_config */ nullptr),
&litert_lm_session_delete);
ASSERT_NE(session, nullptr);
const char* prompt = "Hello world!";
InputData input_data;
input_data.type = kInputText;
input_data.data = prompt;
input_data.size = strlen(prompt);
ResponsesPtr responses(
litert_lm_session_generate_content(session.get(), &input_data, 1),
&litert_lm_responses_delete);
ASSERT_NE(responses, nullptr);
BenchmarkInfoPtr benchmark_info(
litert_lm_session_get_benchmark_info(session.get()),
&litert_lm_benchmark_info_delete);
ASSERT_NE(benchmark_info, nullptr);
EXPECT_GT(
litert_lm_benchmark_info_get_time_to_first_token(benchmark_info.get()),
0.0);
EXPECT_GT(litert_lm_benchmark_info_get_total_init_time_in_second(
benchmark_info.get()),
0.0);
int num_prefill_turns =
litert_lm_benchmark_info_get_num_prefill_turns(benchmark_info.get());
EXPECT_GT(num_prefill_turns, 0);
for (int i = 0; i < num_prefill_turns; ++i) {
EXPECT_GT(litert_lm_benchmark_info_get_prefill_token_count_at(
benchmark_info.get(), i),
0);
EXPECT_GT(litert_lm_benchmark_info_get_prefill_tokens_per_sec_at(
benchmark_info.get(), i),
0.0);
}
int num_decode_turns =
litert_lm_benchmark_info_get_num_decode_turns(benchmark_info.get());
EXPECT_GT(num_decode_turns, 0);
for (int i = 0; i < num_decode_turns; ++i) {
EXPECT_GT(litert_lm_benchmark_info_get_decode_token_count_at(
benchmark_info.get(), i),
0);
EXPECT_GT(litert_lm_benchmark_info_get_decode_tokens_per_sec_at(
benchmark_info.get(), i),
0.0);
}
}
} // namespace