SeaWolf-AI's picture
Upload full LiteRT-LM codebase
5f923cd verified
# Copyright 2026 The ODML Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for litert-lm models."""
import dataclasses
import glob
import importlib.util
import inspect
import json
import os
import pathlib
import traceback
import click
import prompt_toolkit
from prompt_toolkit import key_binding
import litert_lm
try:
# pylint: disable=g-import-not-at-top
from litert_lm.adb import adb_benchmark
from litert_lm.adb import adb_engine
_HAS_ADB = True
except ImportError:
_HAS_ADB = False
def load_preset(preset: str):
"""Loads a preset file and returns the tools, messages and extra_context."""
click.echo(click.style(f"Loading preset from {preset}:", dim=True))
if not os.path.exists(preset):
click.echo(click.style(f"Preset file not found: {preset}", fg="red"))
return None, None, None
spec = importlib.util.spec_from_file_location("user_tools", preset)
if not spec or not spec.loader:
click.echo(click.style(f"Failed to load tools from {preset}", fg="red"))
return None, None, None
user_tools = importlib.util.module_from_spec(spec)
spec.loader.exec_module(user_tools)
tools = getattr(user_tools, "tools", None)
if tools is None:
tools = [
obj
for name, obj in inspect.getmembers(user_tools, inspect.isfunction)
if obj.__module__ == "user_tools"
]
messages = None
system_instruction = getattr(user_tools, "system_instruction", None)
if system_instruction:
click.echo(
click.style(f"- System instruction: {system_instruction}", dim=True)
)
messages = [{
"role": "system",
"content": [{"type": "text", "text": system_instruction}],
}]
click.echo(click.style("- Tools:", dim=True))
for tool in tools:
click.echo(
click.style(f" - {getattr(tool, '__name__', str(tool))}", dim=True)
)
extra_context = getattr(user_tools, "extra_context", None)
if extra_context:
click.echo(click.style(f"- Extra context: {extra_context}", dim=True))
return tools, messages, extra_context
class LoggingToolEventHandler(litert_lm.ToolEventHandler):
"""Log tool call and tool response events."""
def __init__(self, model):
self.model = model
def approve_tool_call(self, tool_call):
"""Logs a tool call."""
if self.model.active_channel is not None:
click.echo("\n", nl=False)
self.model.active_channel = None
click.echo(
click.style(
f"[tool_call] {json.dumps(tool_call['function'])}", fg="green"
)
)
return True
def process_tool_response(self, tool_response):
"""Logs a tool response."""
click.echo(
click.style(f"[tool_response] {json.dumps(tool_response)}", fg="green")
)
return tool_response
def _parse_backend(backend: str) -> litert_lm.Backend:
"""Parses the backend string and returns the corresponding Backend enum."""
backend_lower = backend.lower()
if backend_lower == "gpu":
return litert_lm.Backend.GPU
return litert_lm.Backend.CPU
@dataclasses.dataclass
class Model:
"""Represents a LiteRT-LM model.
Attributes:
model_id: The ID of the model.
model_path: The local path to the model file.
active_channel: The name of the currently active channel, or None if default
text is being printed.
"""
model_id: str
model_path: str
active_channel: str | None = None
def exists(self) -> bool:
"""Returns True if the model file exists locally."""
return os.path.isfile(self.model_path)
def to_str(self) -> str:
"""Returns a string representation of the model."""
return self.model_id
def run_interactive(
self,
is_android: bool = False,
backend: str = "cpu",
preset: str | None = None,
prompt: str | None = None,
enable_speculative_decoding: bool | None = None,
no_template: bool = False,
):
"""Runs the model interactively or with a single prompt.
Args:
is_android: Whether to run the model on an Android device via ADB.
backend: The backend to use (cpu or gpu).
preset: Path to a Python file containing tool functions and system
instructions.
prompt: A single prompt to run once and exit.
enable_speculative_decoding: Whether to enable speculative decoding. If
None, use the model's default.
no_template: Interact with the model directly without applying prompt
templates or stripping stop tokens.
"""
if not self.exists():
click.echo(
click.style(
f"Could not find {self.to_str()} locally in {self.model_path}.",
fg="red",
)
)
return
try:
backend_val = _parse_backend(backend)
if is_android:
if not _HAS_ADB:
raise ImportError("litert_lm.adb dependencies are not available.")
engine_cm = adb_engine.AdbEngine(self.model_path, backend=backend_val)
else:
engine_cm = litert_lm.Engine(
self.model_path,
backend=backend_val,
enable_speculative_decoding=enable_speculative_decoding,
)
with engine_cm as engine:
if no_template:
runner_cm = engine.create_session(apply_prompt_template=False)
else:
tools = None
messages = None
extra_context = None
if preset:
tools, messages, extra_context = load_preset(preset)
if tools is None and messages is None and extra_context is None:
return
handler = LoggingToolEventHandler(self) if tools else None
runner_cm = engine.create_conversation(
tools=tools,
messages=messages,
tool_event_handler=handler,
extra_context=extra_context,
)
with runner_cm as runner:
if prompt:
if isinstance(runner, litert_lm.AbstractSession):
self._execute_raw_prompt(runner, prompt)
elif isinstance(runner, litert_lm.AbstractConversation):
self._execute_prompt(runner, prompt)
return
click.echo(
click.style(
"[enter] submit | [ctrl+j] newline | [ctrl+c] clear/exit",
fg="cyan",
)
)
click.echo()
history_path = os.path.join(
os.path.expanduser("~"), ".litert-lm", "history"
)
os.makedirs(os.path.dirname(history_path), exist_ok=True)
prompt_session = prompt_toolkit.PromptSession(
history=prompt_toolkit.history.FileHistory(history_path),
key_bindings=self._create_keybindings(),
)
while True:
try:
user_prompt = prompt_session.prompt(
prompt_toolkit.ANSI(click.style("> ", fg="green", bold=True)),
multiline=True,
# Start the new line in the beginning of line. This makes
# copying respecting the text.
prompt_continuation=lambda width, line_number, is_soft_wrap: (
""
),
)
if not user_prompt:
continue
if isinstance(runner, litert_lm.AbstractSession):
self._execute_raw_prompt(
runner,
user_prompt,
)
elif isinstance(runner, litert_lm.AbstractConversation):
self._execute_prompt(
runner,
user_prompt,
)
except EOFError:
break
except KeyboardInterrupt:
# Catch Ctrl+C at the input prompt
click.echo()
continue
except Exception: # pylint: disable=broad-exception-caught
click.echo(click.style("Error during inference", fg="red"))
traceback.print_exc()
except Exception: # pylint: disable=broad-exception-caught
click.echo(click.style("An error occurred", fg="red"))
traceback.print_exc()
def _execute_prompt(
self, conversation: litert_lm.AbstractConversation, prompt: str
):
"""Executes a single prompt and prints the result."""
self.active_channel = None
stream = conversation.send_message_async(prompt)
try:
for chunk in stream:
# Handle regular content
content_list = chunk.get("content", [])
for item in content_list:
if item.get("type") == "text":
if self.active_channel is not None:
click.echo()
self.active_channel = None
click.echo(click.style(item.get("text", ""), fg="yellow"), nl=False)
# Handle channels
channels = chunk.get("channels", {})
for channel_name, channel_content in channels.items():
if self.active_channel != channel_name:
if self.active_channel is not None:
click.echo()
click.echo(click.style(f"[{channel_name}] ", fg="blue"), nl=False)
self.active_channel = channel_name
click.echo(click.style(channel_content, fg="yellow"), nl=False)
if self.active_channel is not None:
click.echo()
else:
click.echo()
except KeyboardInterrupt:
conversation.cancel_process()
# Empty the iterator queue.
# This ensures we don't throw away StopIteration.
for _ in stream:
pass
click.echo(click.style("\n[Generation cancelled]", dim=True))
def _execute_raw_prompt(
self, session: litert_lm.AbstractSession, prompt: str
):
"""Executes a single raw prompt and prints the result."""
session.run_prefill([prompt])
stream = session.run_decode_async()
try:
for chunk in stream:
if chunk.texts:
click.echo(click.style(chunk.texts[0], fg="yellow"), nl=False)
click.echo()
except KeyboardInterrupt:
# Empty the iterator queue.
for _ in stream:
pass
click.echo(click.style("\n[Generation cancelled]", dim=True))
def _create_keybindings(self) -> key_binding.KeyBindings:
"""Creates keybindings for the interactive prompt."""
kb = key_binding.KeyBindings()
# Key binding for sending the prompt.
@kb.add("enter")
def _handle_enter(event):
buffer = event.current_buffer
if buffer.text.strip():
buffer.validate_and_handle()
# Key binding for new line. Note that terminal cannot take
# "shift+enter", and "ctrl+enter"
@kb.add("c-j") # standard terminal convention.
@kb.add("escape", "enter") # alt+enter and esc+enter
def _handle_newline(event):
event.current_buffer.insert_text("\n")
# Key binding for clearing input or exiting.
@kb.add("c-c")
def _handle_clear_or_exit(event):
buffer = event.current_buffer
if buffer.text:
buffer.text = ""
else:
event.app.exit(exception=EOFError)
return kb
def benchmark(
self,
prefill_tokens: int = 256,
decode_tokens: int = 256,
is_android: bool = False,
backend: str = "cpu",
enable_speculative_decoding: bool | None = None,
):
"""Benchmarks the model.
Args:
prefill_tokens: The number of tokens to prefill.
decode_tokens: The number of tokens to decode.
is_android: Whether to run the benchmark on an Android device via ADB.
backend: The backend to use (cpu or gpu).
enable_speculative_decoding: Whether to enable speculative decoding. If
None, use the model's default.
"""
if not self.exists():
click.echo(
click.style(
f"Could not find {self.to_str()} locally in {self.model_path}.",
fg="red",
)
)
return
try:
backend_val = _parse_backend(backend)
if is_android:
if not _HAS_ADB:
raise ImportError("litert_lm.adb dependencies are not available.")
benchmark_obj = adb_benchmark.AdbBenchmark(
self.model_path,
backend=backend_val,
prefill_tokens=prefill_tokens,
decode_tokens=decode_tokens,
cache_dir=":nocache",
)
else:
benchmark_obj = litert_lm.Benchmark(
self.model_path,
backend=backend_val,
prefill_tokens=prefill_tokens,
decode_tokens=decode_tokens,
cache_dir=":nocache",
enable_speculative_decoding=enable_speculative_decoding,
)
click.echo(f"Benchmarking model: {self.to_str()} ({self.model_path})")
click.echo(f"Number of tokens in prefill: {prefill_tokens}")
click.echo(f"Number of tokens in decode : {decode_tokens}")
click.echo(f"Backend : {backend}")
spec_dec_str = "auto"
if enable_speculative_decoding is True:
spec_dec_str = "true"
elif enable_speculative_decoding is False:
spec_dec_str = "false"
print(f"Speculative decoding : {spec_dec_str}")
if is_android:
click.echo("Target : Android")
result = benchmark_obj.run()
click.echo("----- Results -----")
click.echo(
f"Prefill speed: {result.last_prefill_tokens_per_second:.2f}"
" tokens/s"
)
click.echo(
f"Decode speed: {result.last_decode_tokens_per_second:.2f}"
" tokens/s"
)
click.echo(f"Init time: {result.init_time_in_second:.4f} s")
click.echo(
f"Time to first token: {result.time_to_first_token_in_second:.4f} s"
)
except Exception: # pylint: disable=broad-exception-caught
click.echo(click.style("An error occurred during benchmarking", fg="red"))
traceback.print_exc()
@classmethod
def get_all_models(cls):
"""Returns a list of all locally available models."""
model_paths = glob.glob(
"*/model.litertlm",
root_dir=get_converted_models_base_dir(),
recursive=True,
)
return [
Model.from_model_id(
path.removesuffix("/model.litertlm").replace("--", "/")
)
for path in model_paths
]
@classmethod
def from_model_reference(cls, model_reference):
"""Creates a Model instance from a model reference."""
if os.path.exists(model_reference):
return cls.from_model_path(model_reference)
else:
# assume the reference is model_id
return cls.from_model_id(model_reference)
@classmethod
def from_model_path(cls, model_path):
"""Creates a Model instance from a model path."""
return cls(
model_id=os.path.basename(model_path),
model_path=os.path.abspath(model_path),
)
@classmethod
def from_model_id(cls, model_id):
"""Creates a Model instance from a model ID."""
return cls(
model_id=model_id,
model_path=os.path.join(
get_converted_models_base_dir(),
model_id.replace("/", "--"),
"model.litertlm",
),
)
# Just to use the huggingface convention. Likely to change.
def model_id_dir_name(model_id):
"""Converts a model ID to a directory name."""
return model_id.replace("/", "--")
# ~/.litert-lm/models
def get_converted_models_base_dir():
"""Gets the base directory for all converted models."""
return os.path.join(os.path.expanduser("~"), ".litert-lm", "models")
# ~/.litert-lm/models/<model_id>
def get_model_dir(model_id):
"""Gets the model directory for a given model ID."""
return os.path.join(
get_converted_models_base_dir(),
model_id_dir_name(model_id),
)