Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
## Unreleased

* **LiteRT-LM speculative decoding opt-in**:
* Added `GenerationParams.speculativeDecoding` and wired it through the
native LiteRT-LM backend to
`litert_lm_engine_settings_set_enable_speculative_decoding`. The
`LlamaEngine` default remains disabled for stable/parity behavior;
llama.cpp, WebGPU, and LiteRT-LM web reject the option until their
speculative paths are implemented.
* Updated the LiteRT-LM benchmark app so its speculative toggle now affects
native LiteRT-LM generation and is recorded in per-run/final metrics.
* **LiteRT-LM Gemma 4 function calling + thinking fix**:
* Fixed Gemma 4 `.litertlm` models not calling tools and producing unreliable
thinking. The backend supplied a hand-written stub chat template that
Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,9 @@ thread counts, LoRA load configs, and rope overrides are rejected instead of
being silently ignored. `.litertlm` generation honors `GenerationParams`
`maxTokens`, `temp`, `topK`, `topP`, and `seed` on native and web, with
`stopSequences` enforced by llamadart. Native LiteRT-LM also honors stream
batching thresholds. llama.cpp-only sampling and constrained-decoding controls
batching thresholds and the opt-in `speculativeDecoding` flag; Web LiteRT-LM
rejects speculative decoding until the browser runtime exposes an equivalent
control. llama.cpp-only sampling and constrained-decoding controls
such as Min-P, repeat penalty overrides, grammar/lazy grammar triggers,
preserved tokens, custom grammar roots, and web stream batching thresholds are
rejected until LiteRT-LM exposes equivalent runtime controls.
Expand Down
42 changes: 33 additions & 9 deletions example/chat_app/lib/litert_lm_benchmark_app.dart
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class _LiteRtLmBenchmarkAppState extends State<LiteRtLmBenchmarkApp> {
);
bool _speculative = const bool.fromEnvironment(
'LITERT_LM_SPECULATIVE',
defaultValue: true,
defaultValue: false,
);
int _maxTokens = const int.fromEnvironment(
'LITERT_LM_MAX_TOKENS',
Expand Down Expand Up @@ -236,7 +236,7 @@ class _LiteRtLmBenchmarkAppState extends State<LiteRtLmBenchmarkApp> {
_append('Initializing LiteRT-LM:');
_append(' model: $modelPath');
_append(' backend: $_backend');
_append(' speculative: ignored by backend API');
_append(' speculative: $_speculative');
if (_cacheDir.isNotEmpty) {
await Directory(_cacheDir).create(recursive: true);
_append(' cache override ignored by backend API: $_cacheDir');
Expand All @@ -247,11 +247,8 @@ class _LiteRtLmBenchmarkAppState extends State<LiteRtLmBenchmarkApp> {
modelPath,
modelParams: ModelParams(
contextSize: _maxTokens,
preferredBackend: _backend == 'cpu'
? GpuBackend.cpu
: Platform.isMacOS
? GpuBackend.metal
: GpuBackend.vulkan,
preferredBackend: _preferredGpuBackendForLiteRt(_backend),
liteRtLmBackend: _liteRtLmBackendPreference(_backend),
),
);
loadSw.stop();
Expand All @@ -263,7 +260,11 @@ class _LiteRtLmBenchmarkAppState extends State<LiteRtLmBenchmarkApp> {
await engine
.generate(
_promptController.text,
params: GenerationParams(maxTokens: _outputTokens, seed: 1),
params: GenerationParams(
maxTokens: _outputTokens,
seed: 1,
speculativeDecoding: _speculative,
),
Comment thread
leehack marked this conversation as resolved.
)
.drain<void>();
}
Expand All @@ -277,7 +278,11 @@ class _LiteRtLmBenchmarkAppState extends State<LiteRtLmBenchmarkApp> {
final sw = Stopwatch()..start();
await for (final chunk in engine.generate(
_promptController.text,
params: GenerationParams(maxTokens: _outputTokens, seed: 1),
params: GenerationParams(
maxTokens: _outputTokens,
seed: 1,
speculativeDecoding: _speculative,
),
)) {
buffer.write(chunk);
}
Expand All @@ -288,6 +293,7 @@ class _LiteRtLmBenchmarkAppState extends State<LiteRtLmBenchmarkApp> {
final runMetrics = {
'index': i,
'wallMilliseconds': wallMs,
'speculativeDecoding': _speculative,
'promptEvalTokens': perf?.promptEvalTokens,
'evalTokens': perf?.evalTokens,
'hitEosBeforeTarget': perf == null
Expand Down Expand Up @@ -319,6 +325,7 @@ class _LiteRtLmBenchmarkAppState extends State<LiteRtLmBenchmarkApp> {
'wallMilliseconds': wallMs,
'backendName': await engine.getBackendName(),
'targetDecodeTokens': _outputTokens,
'speculativeDecoding': _speculative,
'backendInitMilliseconds': perf?.loadMs,
'promptEvalTokens': perf?.promptEvalTokens,
'evalTokens': perf?.evalTokens,
Expand Down Expand Up @@ -356,6 +363,23 @@ class _LiteRtLmBenchmarkAppState extends State<LiteRtLmBenchmarkApp> {
}
}

GpuBackend _preferredGpuBackendForLiteRt(String backend) {
return switch (backend) {
'cpu' => GpuBackend.cpu,
'gpu' => Platform.isMacOS ? GpuBackend.metal : GpuBackend.vulkan,
_ => GpuBackend.auto,
};
}

LiteRtLmBackendPreference _liteRtLmBackendPreference(String backend) {
return switch (backend) {
'cpu' => LiteRtLmBackendPreference.cpu,
'gpu' => LiteRtLmBackendPreference.gpu,
'npu' => LiteRtLmBackendPreference.npu,
_ => LiteRtLmBackendPreference.auto,
};
}

Future<void> _runLlamaDartBenchmark(String modelPath) async {
final engine = LlamaEngine(LlamaBackend());
try {
Expand Down
3 changes: 3 additions & 0 deletions lib/src/backends/litert_lm/litert_lm_backend_web.dart
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,9 @@ class LiteRtLmBackend
if (params.grammarRoot != defaults.grammarRoot) {
unsupported.add('grammarRoot');
}
if (params.speculativeDecoding) {
unsupported.add('speculativeDecoding');
}
if (params.streamBatchTokenThreshold !=
defaults.streamBatchTokenThreshold) {
unsupported.add('streamBatchTokenThreshold');
Expand Down
22 changes: 18 additions & 4 deletions lib/src/backends/litert_lm/litert_lm_service.dart
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class LiteRtLmService {
String? _modelPath;
String? _activeBackend;
int? _activeOutputTokens;
bool? _activeSpeculativeDecoding;
int _nextModelHandle = 1;
int _nextContextHandle = 1;
int? _modelHandle;
Expand Down Expand Up @@ -77,6 +78,7 @@ class LiteRtLmService {
_modelParams = params;
_activeBackend = resolvedBackend;
_activeOutputTokens = null;
_activeSpeculativeDecoding = null;
_modelHandle = _nextModelHandle++;
_contextHandle = null;
_lastMetrics = null;
Expand All @@ -95,6 +97,7 @@ class LiteRtLmService {
_modelParams = null;
_activeBackend = null;
_activeOutputTokens = null;
_activeSpeculativeDecoding = null;
_modelHandle = null;
_contextHandle = null;
_lastMetrics = null;
Expand Down Expand Up @@ -398,18 +401,23 @@ class LiteRtLmService {
_client?.dispose();
_client = null;
_activeOutputTokens = null;
_activeSpeculativeDecoding = null;
_lastMetrics = null;
_cancelRequested = false;
}

Future<LiteRtLmRuntimeClient> _ensureClientForGeneration(
GenerationParams params,
) {
return _ensureClientForRuntime(outputTokens: params.maxTokens);
return _ensureClientForRuntime(
outputTokens: params.maxTokens,
speculativeDecoding: params.speculativeDecoding,
);
}

Future<LiteRtLmRuntimeClient> _ensureClientForRuntime({
int? outputTokens,
bool? speculativeDecoding,
}) async {
final modelPath = _modelPath;
final modelParams = _modelParams;
Expand All @@ -419,17 +427,22 @@ class LiteRtLmService {

final resolvedOutputTokens =
outputTokens ?? _activeOutputTokens ?? GenerationParams().maxTokens;
final resolvedSpeculativeDecoding =
speculativeDecoding ?? _activeSpeculativeDecoding ?? false;
final backend = _activeBackend ?? _backendNameFor(modelParams);
final existing = _client;
if (existing != null &&
(outputTokens == null || _activeOutputTokens == resolvedOutputTokens) &&
(speculativeDecoding == null ||
_activeSpeculativeDecoding == resolvedSpeculativeDecoding) &&
_activeBackend == backend) {
return existing;
}

existing?.dispose();
_client = null;
_activeOutputTokens = null;
_activeSpeculativeDecoding = null;
final client = _clientFactory();
final responseThinkingTags = _responseThinkingTagsForModel(modelPath);
client.configureResponseThinkingTags(
Expand All @@ -443,7 +456,7 @@ class LiteRtLmService {
maxTokens: modelParams.contextSize,
outputTokens: resolvedOutputTokens,
cacheDir: _defaultCacheDir(),
speculativeDecoding: false,
speculativeDecoding: resolvedSpeculativeDecoding,
minLogLevel: _liteRtLmMinLogLevel(_logLevel),
);
} catch (_) {
Expand All @@ -456,6 +469,7 @@ class LiteRtLmService {
}
_client = client;
_activeOutputTokens = resolvedOutputTokens;
_activeSpeculativeDecoding = resolvedSpeculativeDecoding;
_activeBackend = backend;
return client;
}
Expand Down Expand Up @@ -714,8 +728,8 @@ class LiteRtLmService {
throw UnsupportedError(
'LiteRtLmBackend does not support llama.cpp-specific GenerationParams: '
'${unsupported.join(', ')}. Supported LiteRT-LM generation options are '
'maxTokens, temp, topK, topP, seed, stopSequences, and native stream '
'batching thresholds.',
'maxTokens, temp, topK, topP, seed, stopSequences, '
'speculativeDecoding, and native stream batching thresholds.',
);
}

Expand Down
8 changes: 8 additions & 0 deletions lib/src/backends/llama_cpp/llama_cpp_service.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2711,6 +2711,14 @@ class LlamaCppService {
int cancelTokenAddress, {
List<LlamaContentPart>? parts,
}) async* {
if (params.speculativeDecoding) {
throw UnsupportedError(
'llama.cpp speculative decoding is not exposed by llamadart yet. '
'Use the LiteRT-LM native backend or track llama.cpp support in '
'issues #168/#190.',
);
}

var ctx = _contexts[contextHandle];
if (ctx == null) throw Exception("Invalid context handle");
_generatingContexts.update(
Expand Down
6 changes: 6 additions & 0 deletions lib/src/backends/webgpu/webgpu_backend.dart
Original file line number Diff line number Diff line change
Expand Up @@ -1649,6 +1649,12 @@ class WebGpuLlamaBackend
GenerationParams params, {
List<LlamaContentPart>? parts,
}) {
if (params.speculativeDecoding) {
throw UnsupportedError(
'WebGPU speculative decoding is not supported yet.',
);
}

final mediaParts = _buildMultimodalParts(parts);
if (mediaParts != null && !_mmContextActive) {
throw StateError(
Expand Down
10 changes: 10 additions & 0 deletions lib/src/core/models/inference/generation_params.dart
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ class GenerationParams {
/// Grammar start symbol. Defaults to "root".
final String grammarRoot;

/// Enables backend-native speculative decoding when supported.
///
/// Native LiteRT-LM currently honors this flag by forwarding it to the
/// runtime's speculative decoding setting. llama.cpp, WebGPU, and LiteRT-LM
/// web reject this option until their speculative paths are implemented.
final bool speculativeDecoding;

/// Reuses matching prompt prefixes from previous requests in the same native
/// context to reduce prompt ingestion latency.
///
Expand Down Expand Up @@ -125,6 +132,7 @@ class GenerationParams {
this.grammarTriggers = const [],
this.preservedTokens = const [],
this.grammarRoot = 'root',
this.speculativeDecoding = false,
this.reusePromptPrefix = defaultReusePromptPrefix,
this.streamBatchTokenThreshold = defaultStreamBatchTokenThreshold,
this.streamBatchByteThreshold = defaultStreamBatchByteThreshold,
Expand All @@ -145,6 +153,7 @@ class GenerationParams {
List<GenerationGrammarTrigger>? grammarTriggers,
List<String>? preservedTokens,
String? grammarRoot,
bool? speculativeDecoding,
bool? reusePromptPrefix,
int? streamBatchTokenThreshold,
int? streamBatchByteThreshold,
Expand All @@ -163,6 +172,7 @@ class GenerationParams {
grammarTriggers: grammarTriggers ?? this.grammarTriggers,
preservedTokens: preservedTokens ?? this.preservedTokens,
grammarRoot: grammarRoot ?? this.grammarRoot,
speculativeDecoding: speculativeDecoding ?? this.speculativeDecoding,
reusePromptPrefix: reusePromptPrefix ?? this.reusePromptPrefix,
streamBatchTokenThreshold:
streamBatchTokenThreshold ?? this.streamBatchTokenThreshold,
Expand Down
33 changes: 33 additions & 0 deletions test/unit/backends/litert_lm/litert_lm_backend_web_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,39 @@ void main() {
}
});

test('rejects speculative decoding on LiteRT-LM web', () async {
_installFakeEngine(chunks: <JSAny?>[]);

final backend = LiteRtLmBackend();
try {
final modelHandle = await backend.modelLoadFromUrl(
'https://example.com/model.litertlm',
const ModelParams(),
);
final contextHandle = await backend.contextCreate(
modelHandle,
const ModelParams(),
);

await expectLater(
backend.generate(
contextHandle,
'hello',
const GenerationParams(speculativeDecoding: true),
),
emitsError(
isA<UnsupportedError>().having(
(error) => error.message.toString(),
'message',
contains('speculativeDecoding'),
),
),
);
} finally {
await backend.dispose();
}
});

test('rejects unsupported context-time model params', () async {
_installFakeEngine(chunks: <JSAny?>[]);

Expand Down
Loading
Loading