diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index aae14fa4e..f66e51d39 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -4676,14 +4676,27 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_uint32(gguf.Keys.LLM.SAMPLING_RATE.format(arch=arch), sampling_rate)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        if name.startswith("language_model."):
+        # Strip prefixes — Local variant nests under model.language_model / model.embedding_list
+        if name.startswith("model.language_model."):
+            name = name.replace("model.language_model.", "", 1)
+        elif name.startswith("language_model."):
             name = name.replace("language_model.", "", 1)
 
+        # Local variant: embedding_list.0 = text (skip), 1-32 = audio codebooks 0-31
+        if (match := re.fullmatch(r"model\.embedding_list\.(\d+)\.weight", name)) is not None:
+            idx = int(match.group(1))
+            if idx == 0:
+                return  # text embedding — already covered by embed_tokens
+            yield (f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD_AUDIO]}.{idx - 1}.weight", data_torch)
+            return
+
+        # 8B variant: emb_ext.N = audio codebook N (no offset)
         if (match := re.fullmatch(r"emb_ext\.(\d+)\.weight", name)) is not None:
             vq_idx = int(match.group(1))
             yield (f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD_AUDIO]}.{vq_idx}.weight", data_torch)
             return
 
+        # Audio LM heads: 0 = text output, 1-32 = audio output 0-31
         if (match := re.fullmatch(r"lm_heads\.(\d+)\.weight", name)) is not None:
             head_idx = int(match.group(1))
             if head_idx == 0:
@@ -4692,6 +4705,54 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             yield (f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT_AUDIO]}.{head_idx - 1}.weight", data_torch)
             return
 
+        # Layer norms before LM heads (Local variant)
+        if (match := re.fullmatch(r"layer_norm_before_lm_heads\.(\d+)\.weight", name)) is not None:
+            idx = int(match.group(1))
+            yield (f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.AUDIO_LN]}.{idx}.weight", data_torch)
+            return
+
+        # Local transformer (4-layer mini-transformer)
+        _local_map = {
+            "self_attn.q_proj":         gguf.MODEL_TENSOR.LOCAL_ATTN_Q,
+            "self_attn.k_proj":         gguf.MODEL_TENSOR.LOCAL_ATTN_K,
+            "self_attn.v_proj":         gguf.MODEL_TENSOR.LOCAL_ATTN_V,
+            "self_attn.o_proj":         gguf.MODEL_TENSOR.LOCAL_ATTN_OUT,
+            "self_attn.q_norm":         gguf.MODEL_TENSOR.LOCAL_ATTN_Q_NORM,
+            "self_attn.k_norm":         gguf.MODEL_TENSOR.LOCAL_ATTN_K_NORM,
+            "input_layernorm":          gguf.MODEL_TENSOR.LOCAL_ATTN_NORM,
+            "post_attention_layernorm": gguf.MODEL_TENSOR.LOCAL_FFN_NORM,
+            "mlp.gate_proj":            gguf.MODEL_TENSOR.LOCAL_FFN_GATE,
+            "mlp.down_proj":            gguf.MODEL_TENSOR.LOCAL_FFN_DOWN,
+            "mlp.up_proj":              gguf.MODEL_TENSOR.LOCAL_FFN_UP,
+        }
+        if (match := re.fullmatch(r"local_transformer\.layers\.(\d+)\.(.+?)\.weight", name)) is not None:
+            layer_id = int(match.group(1))
+            suffix = match.group(2)
+            if suffix in _local_map:
+                gguf_name = gguf.TENSOR_NAMES[_local_map[suffix]].format(bid=layer_id)
+                yield (f"{gguf_name}.weight", data_torch)
+                return
+        if name == "local_transformer.norm.weight":
+            yield (f"{gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.LOCAL_OUTPUT_NORM]}.weight", data_torch)
+            return
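+
+        # Example mappings (illustrative):
+        #   local_transformer.layers.0.self_attn.q_proj.weight -> local.blk.0.attn_q.weight
+        #   local_transformer.norm.weight                      -> local.output_norm.weight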
+
+        # Local-to-speech bridge MLPs (33 indexed)
+        _bridge_map = {
+            "gate_proj": gguf.MODEL_TENSOR.LOCAL_TO_SPEECH_GATE,
+            "down_proj": gguf.MODEL_TENSOR.LOCAL_TO_SPEECH_DOWN,
+            "up_proj":   gguf.MODEL_TENSOR.LOCAL_TO_SPEECH_UP,
+        }
+        if (match := re.fullmatch(r"local_to_speech_embedding_mlps\.(\d+)\.(.+?)\.weight", name)) is not None:
+            idx = int(match.group(1))
+            proj = match.group(2)
+            if proj in _bridge_map:
+                yield (f"{gguf.TENSOR_NAMES[_bridge_map[proj]]}.{idx}.weight", data_torch)
+                return
+
+        # Speech-to-local bridge MLP (single)
+        _s2l_map = {
+            "gate_proj": gguf.MODEL_TENSOR.SPEECH_TO_LOCAL_GATE,
+            "down_proj": gguf.MODEL_TENSOR.SPEECH_TO_LOCAL_DOWN,
+            "up_proj":   gguf.MODEL_TENSOR.SPEECH_TO_LOCAL_UP,
+        }
+        if (match := re.fullmatch(r"speech_embedding_to_local_mlp\.(.+?)\.weight", name)) is not None:
+            proj = match.group(1)
+            if proj in _s2l_map:
+                yield (f"{gguf.TENSOR_NAMES[_s2l_map[proj]]}.weight", data_torch)
+                return
+
         yield from super().modify_tensors(data_torch, name, bid)
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 8ae6dd3b8..dff780f50 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -517,6 +517,25 @@ class MODEL_TENSOR(IntEnum):
     POS_EMBD     = auto()
     OUTPUT       = auto()
     OUTPUT_AUDIO = auto()  # moss-tts-delay, indexed as output_audio.{id}
+    AUDIO_LN             = auto()
+    LOCAL_ATTN_NORM      = auto()
+    LOCAL_ATTN_Q         = auto()
+    LOCAL_ATTN_Q_NORM    = auto()
+    LOCAL_ATTN_K         = auto()
+    LOCAL_ATTN_K_NORM    = auto()
+    LOCAL_ATTN_V         = auto()
+    LOCAL_ATTN_OUT       = auto()
+    LOCAL_FFN_NORM       = auto()
+    LOCAL_FFN_GATE       = auto()
+    LOCAL_FFN_DOWN       = auto()
+    LOCAL_FFN_UP         = auto()
+    LOCAL_OUTPUT_NORM    = auto()
+    LOCAL_TO_SPEECH_GATE = auto()
+    LOCAL_TO_SPEECH_DOWN = auto()
+    LOCAL_TO_SPEECH_UP   = auto()
+    SPEECH_TO_LOCAL_GATE = auto()
+    SPEECH_TO_LOCAL_DOWN = auto()
+    SPEECH_TO_LOCAL_UP   = auto()
     DENSE_2_OUT = auto()  # embeddinggemma 2_Dense
     DENSE_3_OUT = auto()  # embeddinggemma 3_Dense
     OUTPUT_NORM = auto()
@@ -964,6 +983,25 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.OUTPUT_NORM:  "output_norm",
     MODEL_TENSOR.OUTPUT:       "output",
     MODEL_TENSOR.OUTPUT_AUDIO: "output_audio",
+    MODEL_TENSOR.AUDIO_LN:             "audio_ln",
+    MODEL_TENSOR.LOCAL_ATTN_NORM:      "local.blk.{bid}.attn_norm",
+    MODEL_TENSOR.LOCAL_ATTN_Q:         "local.blk.{bid}.attn_q",
+    MODEL_TENSOR.LOCAL_ATTN_Q_NORM:    "local.blk.{bid}.attn_q_norm",
+    MODEL_TENSOR.LOCAL_ATTN_K:         "local.blk.{bid}.attn_k",
+    MODEL_TENSOR.LOCAL_ATTN_K_NORM:    "local.blk.{bid}.attn_k_norm",
+    MODEL_TENSOR.LOCAL_ATTN_V:         "local.blk.{bid}.attn_v",
+    MODEL_TENSOR.LOCAL_ATTN_OUT:       "local.blk.{bid}.attn_output",
+    MODEL_TENSOR.LOCAL_FFN_NORM:       "local.blk.{bid}.ffn_norm",
+    MODEL_TENSOR.LOCAL_FFN_GATE:       "local.blk.{bid}.ffn_gate",
+    MODEL_TENSOR.LOCAL_FFN_DOWN:       "local.blk.{bid}.ffn_down",
+    MODEL_TENSOR.LOCAL_FFN_UP:         "local.blk.{bid}.ffn_up",
+    MODEL_TENSOR.LOCAL_OUTPUT_NORM:    "local.output_norm",
+    MODEL_TENSOR.LOCAL_TO_SPEECH_GATE: "local_to_speech.ffn_gate",
+    MODEL_TENSOR.LOCAL_TO_SPEECH_DOWN: "local_to_speech.ffn_down",
+    MODEL_TENSOR.LOCAL_TO_SPEECH_UP:   "local_to_speech.ffn_up",
+    MODEL_TENSOR.SPEECH_TO_LOCAL_GATE: "speech_to_local.ffn_gate",
+    MODEL_TENSOR.SPEECH_TO_LOCAL_DOWN: "speech_to_local.ffn_down",
+    MODEL_TENSOR.SPEECH_TO_LOCAL_UP:   "speech_to_local.ffn_up",
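+    # NOTE: names without a "{bid}" placeholder (audio_ln, local_to_speech.ffn_*)
+    # still receive a per-head ".{idx}" suffix, appended by the converter above to
+    # match the "%d" format strings in src/llama-arch.cpp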
     MODEL_TENSOR.DENSE_2_OUT:  "dense_2",  # embeddinggemma 2_Dense
     MODEL_TENSOR.DENSE_3_OUT:  "dense_3",  # embeddinggemma 3_Dense
     MODEL_TENSOR.ROPE_FREQS:   "rope_freqs",
@@ -1812,6 +1850,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.OUTPUT,
         MODEL_TENSOR.OUTPUT_AUDIO,
+        MODEL_TENSOR.AUDIO_LN,
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_NORM,
         MODEL_TENSOR.ATTN_Q,
@@ -1824,6 +1863,24 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.LOCAL_ATTN_NORM,
+        MODEL_TENSOR.LOCAL_ATTN_Q,
+        MODEL_TENSOR.LOCAL_ATTN_Q_NORM,
+        MODEL_TENSOR.LOCAL_ATTN_K,
+        MODEL_TENSOR.LOCAL_ATTN_K_NORM,
+        MODEL_TENSOR.LOCAL_ATTN_V,
+        MODEL_TENSOR.LOCAL_ATTN_OUT,
+        MODEL_TENSOR.LOCAL_FFN_NORM,
+        MODEL_TENSOR.LOCAL_FFN_GATE,
+        MODEL_TENSOR.LOCAL_FFN_DOWN,
+        MODEL_TENSOR.LOCAL_FFN_UP,
+        MODEL_TENSOR.LOCAL_OUTPUT_NORM,
+        MODEL_TENSOR.LOCAL_TO_SPEECH_GATE,
+        MODEL_TENSOR.LOCAL_TO_SPEECH_DOWN,
+        MODEL_TENSOR.LOCAL_TO_SPEECH_UP,
+        MODEL_TENSOR.SPEECH_TO_LOCAL_GATE,
+        MODEL_TENSOR.SPEECH_TO_LOCAL_DOWN,
+        MODEL_TENSOR.SPEECH_TO_LOCAL_UP,
     ],
     MODEL_ARCH.QWEN3MOE: [
         MODEL_TENSOR.TOKEN_EMBD,
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 07800c68a..6867ba219 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -350,6 +350,25 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
     { LLM_TENSOR_OUTPUT_NORM_LFM2,       "token_embd_norm" }, // fix for wrong tensor name
     { LLM_TENSOR_OUTPUT,                 "output" },
     { LLM_TENSOR_OUTPUT_AUDIO,           "output_audio.%d" },
+    { LLM_TENSOR_AUDIO_LN,               "audio_ln.%d" },
+    { LLM_TENSOR_LOCAL_ATTN_NORM,        "local.blk.%d.attn_norm" },
+    { LLM_TENSOR_LOCAL_ATTN_Q,           "local.blk.%d.attn_q" },
+    { LLM_TENSOR_LOCAL_ATTN_Q_NORM,      "local.blk.%d.attn_q_norm" },
+    { LLM_TENSOR_LOCAL_ATTN_K,           "local.blk.%d.attn_k" },
+    { LLM_TENSOR_LOCAL_ATTN_K_NORM,      "local.blk.%d.attn_k_norm" },
+    { LLM_TENSOR_LOCAL_ATTN_V,           "local.blk.%d.attn_v" },
+    { LLM_TENSOR_LOCAL_ATTN_OUT,         "local.blk.%d.attn_output" },
+    { LLM_TENSOR_LOCAL_FFN_NORM,         "local.blk.%d.ffn_norm" },
+    { LLM_TENSOR_LOCAL_FFN_GATE,         "local.blk.%d.ffn_gate" },
+    { LLM_TENSOR_LOCAL_FFN_DOWN,         "local.blk.%d.ffn_down" },
+    { LLM_TENSOR_LOCAL_FFN_UP,           "local.blk.%d.ffn_up" },
+    { LLM_TENSOR_LOCAL_OUTPUT_NORM,      "local.output_norm" },
+    { LLM_TENSOR_LOCAL_TO_SPEECH_GATE,   "local_to_speech.ffn_gate.%d" },
+    { LLM_TENSOR_LOCAL_TO_SPEECH_DOWN,   "local_to_speech.ffn_down.%d" },
+    { LLM_TENSOR_LOCAL_TO_SPEECH_UP,     "local_to_speech.ffn_up.%d" },
+    { LLM_TENSOR_SPEECH_TO_LOCAL_GATE,   "speech_to_local.ffn_gate" },
+    { LLM_TENSOR_SPEECH_TO_LOCAL_DOWN,   "speech_to_local.ffn_down" },
+    { LLM_TENSOR_SPEECH_TO_LOCAL_UP,     "speech_to_local.ffn_up" },
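+    // audio_ln.%d and local_to_speech.ffn_*.%d are per-head (not per-layer) names;
+    // LLM_TN_IMPL::str() formats them with xid (see the switch further below)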
     { LLM_TENSOR_ROPE_FREQS,             "rope_freqs" },
     { LLM_TENSOR_ATTN_NORM,              "blk.%d.attn_norm" },
     { LLM_TENSOR_ATTN_Q,                 "blk.%d.attn_q" },
@@ -990,6 +1009,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
             LLM_TENSOR_OUTPUT_NORM,
             LLM_TENSOR_OUTPUT,
             LLM_TENSOR_OUTPUT_AUDIO,
+            LLM_TENSOR_AUDIO_LN,
             LLM_TENSOR_ATTN_NORM,
             LLM_TENSOR_ATTN_Q,
             LLM_TENSOR_ATTN_Q_NORM,
@@ -1001,6 +1021,24 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
             LLM_TENSOR_FFN_GATE,
             LLM_TENSOR_FFN_DOWN,
             LLM_TENSOR_FFN_UP,
+            LLM_TENSOR_LOCAL_ATTN_NORM,
+            LLM_TENSOR_LOCAL_ATTN_Q,
+            LLM_TENSOR_LOCAL_ATTN_Q_NORM,
+            LLM_TENSOR_LOCAL_ATTN_K,
+            LLM_TENSOR_LOCAL_ATTN_K_NORM,
+            LLM_TENSOR_LOCAL_ATTN_V,
+            LLM_TENSOR_LOCAL_ATTN_OUT,
+            LLM_TENSOR_LOCAL_FFN_NORM,
+            LLM_TENSOR_LOCAL_FFN_GATE,
+            LLM_TENSOR_LOCAL_FFN_DOWN,
+            LLM_TENSOR_LOCAL_FFN_UP,
+            LLM_TENSOR_LOCAL_OUTPUT_NORM,
+            LLM_TENSOR_LOCAL_TO_SPEECH_GATE,
+            LLM_TENSOR_LOCAL_TO_SPEECH_DOWN,
+            LLM_TENSOR_LOCAL_TO_SPEECH_UP,
+            LLM_TENSOR_SPEECH_TO_LOCAL_GATE,
+            LLM_TENSOR_SPEECH_TO_LOCAL_DOWN,
+            LLM_TENSOR_SPEECH_TO_LOCAL_UP,
         };
     case LLM_ARCH_QWEN3MOE:
    case LLM_ARCH_QWEN3VLMOE:
@@ -2597,6 +2635,25 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_TOKEN_EMBD_NORM,        {LLM_TENSOR_LAYER_INPUT,     GGML_OP_MUL}},
     {LLM_TENSOR_OUTPUT,                 {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
     {LLM_TENSOR_OUTPUT_AUDIO,           {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_AUDIO_LN,               {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL}},
+    {LLM_TENSOR_LOCAL_ATTN_NORM,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_LOCAL_ATTN_Q,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LOCAL_ATTN_Q_NORM,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_LOCAL_ATTN_K,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LOCAL_ATTN_K_NORM,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_LOCAL_ATTN_V,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LOCAL_ATTN_OUT,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LOCAL_FFN_NORM,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_LOCAL_FFN_GATE,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LOCAL_FFN_DOWN,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LOCAL_FFN_UP,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LOCAL_OUTPUT_NORM,      {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL}},
+    {LLM_TENSOR_LOCAL_TO_SPEECH_GATE,   {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LOCAL_TO_SPEECH_DOWN,   {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LOCAL_TO_SPEECH_UP,     {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_SPEECH_TO_LOCAL_GATE,   {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_SPEECH_TO_LOCAL_DOWN,   {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_SPEECH_TO_LOCAL_UP,     {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CLS,                    {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CLS_OUT,                {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CLS_NORM,               {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL}},
@@ -2827,6 +2884,10 @@ std::string LLM_TN_IMPL::str() const {
     switch (tensor) {
         case LLM_TENSOR_TOKEN_EMBD_AUDIO:
         case LLM_TENSOR_OUTPUT_AUDIO:
+        case LLM_TENSOR_AUDIO_LN:
+        case LLM_TENSOR_LOCAL_TO_SPEECH_GATE:
+        case LLM_TENSOR_LOCAL_TO_SPEECH_DOWN:
+        case LLM_TENSOR_LOCAL_TO_SPEECH_UP:
             name = ::format(LLM_TENSOR_NAMES.at(tensor), xid);
             break;
         default:
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 9320b01da..72c9d0ae6 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -359,6 +359,25 @@ enum llm_tensor {
     LLM_TENSOR_DENSE_3_OUT,
     LLM_TENSOR_OUTPUT,
     LLM_TENSOR_OUTPUT_AUDIO,
+    LLM_TENSOR_AUDIO_LN,
+    LLM_TENSOR_LOCAL_ATTN_NORM,
+    LLM_TENSOR_LOCAL_ATTN_Q,
+    LLM_TENSOR_LOCAL_ATTN_Q_NORM,
+    LLM_TENSOR_LOCAL_ATTN_K,
+    LLM_TENSOR_LOCAL_ATTN_K_NORM,
+    LLM_TENSOR_LOCAL_ATTN_V,
+    LLM_TENSOR_LOCAL_ATTN_OUT,
+    LLM_TENSOR_LOCAL_FFN_NORM,
+    LLM_TENSOR_LOCAL_FFN_GATE,
+    LLM_TENSOR_LOCAL_FFN_DOWN,
+    LLM_TENSOR_LOCAL_FFN_UP,
+    LLM_TENSOR_LOCAL_OUTPUT_NORM,
+    LLM_TENSOR_LOCAL_TO_SPEECH_GATE,
+    LLM_TENSOR_LOCAL_TO_SPEECH_DOWN,
+    LLM_TENSOR_LOCAL_TO_SPEECH_UP,
+    LLM_TENSOR_SPEECH_TO_LOCAL_GATE,
+    LLM_TENSOR_SPEECH_TO_LOCAL_DOWN,
+    LLM_TENSOR_SPEECH_TO_LOCAL_UP,
     LLM_TENSOR_OUTPUT_NORM,
     LLM_TENSOR_OUTPUT_NORM_LFM2, // fix for wrong tensor name
     LLM_TENSOR_ROPE_FREQS,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index f7b4bd12f..e67366ed9 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -3728,6 +3728,57 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                     layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                 }
+
+                // MOSS-TTS Local variant: optional tensors (TENSOR_NOT_REQUIRED for 8B compat)
+                // Dimensions: local_dim=1536, local_ff=8960, n_embd=2048, n_head_local=16, n_kv_local=8, head_dim=128
+                {
+                    const int64_t local_dim = 1536;
+                    const int64_t local_ff  = 8960;
+                    const int64_t local_kv  = 1024; // n_kv_heads(8) * head_dim(128)
+                    const int64_t local_q   = 2048; // n_heads(16) * head_dim(128)
+                    const int64_t head_dim  = 128;
+                    const int64_t bridge_ff = 2048; // additional_mlp_ffn_hidden_size
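+                    // create_tensor dims use ggml order {n_in, n_out}, so e.g.
+                    // speech_to_local_down {bridge_ff, local_dim} maps 2048 -> 1536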
+                    const uint32_t n_heads = hparams.n_vq + 1;
+
+                    // Audio layer norms before each head [2048]
+                    audio_ln.resize(n_heads);
+                    for (uint32_t i = 0; i < n_heads; ++i) {
+                        audio_ln[i] = create_tensor(tn(LLM_TENSOR_AUDIO_LN, "weight", -1, i), {n_embd}, TENSOR_NOT_REQUIRED);
+                    }
+
+                    // Speech-to-local bridge MLP: 2048 → 1536 (SwiGLU: gate/up [2048,2048], down [2048,1536])
+                    speech_to_local_gate = create_tensor(tn(LLM_TENSOR_SPEECH_TO_LOCAL_GATE, "weight"), {n_embd, bridge_ff},    TENSOR_NOT_REQUIRED);
+                    speech_to_local_down = create_tensor(tn(LLM_TENSOR_SPEECH_TO_LOCAL_DOWN, "weight"), {bridge_ff, local_dim}, TENSOR_NOT_REQUIRED);
+                    speech_to_local_up   = create_tensor(tn(LLM_TENSOR_SPEECH_TO_LOCAL_UP,   "weight"), {n_embd, bridge_ff},    TENSOR_NOT_REQUIRED);
+
+                    // Local-to-speech bridge MLPs: 1536 → 2048 per head (gate/up [1536,2048], down [2048,2048])
+                    local_to_speech_gate.resize(n_heads);
+                    local_to_speech_down.resize(n_heads);
+                    local_to_speech_up.resize(n_heads);
+                    for (uint32_t i = 0; i < n_heads; ++i) {
+                        local_to_speech_gate[i] = create_tensor(tn(LLM_TENSOR_LOCAL_TO_SPEECH_GATE, "weight", -1, i), {local_dim, bridge_ff}, TENSOR_NOT_REQUIRED);
+                        local_to_speech_down[i] = create_tensor(tn(LLM_TENSOR_LOCAL_TO_SPEECH_DOWN, "weight", -1, i), {bridge_ff, n_embd},    TENSOR_NOT_REQUIRED);
+                        local_to_speech_up[i]   = create_tensor(tn(LLM_TENSOR_LOCAL_TO_SPEECH_UP,   "weight", -1, i), {local_dim, bridge_ff}, TENSOR_NOT_REQUIRED);
+                    }
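+                    // 33 bridges (n_vq + 1) are allocated to match
+                    // local_to_speech_embedding_mlps.{0..32} in the HF checkpoint;
+                    // the graph code currently uses the first n_vq of them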
+
+                    // Local transformer (4 layers at local_dim=1536)
+                    local_output_norm = create_tensor(tn(LLM_TENSOR_LOCAL_OUTPUT_NORM, "weight"), {local_dim}, TENSOR_NOT_REQUIRED);
+                    local_layers.resize(4);
+                    for (int i = 0; i < 4; ++i) {
+                        auto & ll = local_layers[i];
+                        ll.attn_norm   = create_tensor(tn(LLM_TENSOR_LOCAL_ATTN_NORM,   "weight", i), {local_dim},           TENSOR_NOT_REQUIRED);
+                        ll.wq          = create_tensor(tn(LLM_TENSOR_LOCAL_ATTN_Q,      "weight", i), {local_dim, local_q},  TENSOR_NOT_REQUIRED);
+                        ll.wk          = create_tensor(tn(LLM_TENSOR_LOCAL_ATTN_K,      "weight", i), {local_dim, local_kv}, TENSOR_NOT_REQUIRED);
+                        ll.wv          = create_tensor(tn(LLM_TENSOR_LOCAL_ATTN_V,      "weight", i), {local_dim, local_kv}, TENSOR_NOT_REQUIRED);
+                        ll.wo          = create_tensor(tn(LLM_TENSOR_LOCAL_ATTN_OUT,    "weight", i), {local_q, local_dim},  TENSOR_NOT_REQUIRED);
+                        ll.attn_q_norm = create_tensor(tn(LLM_TENSOR_LOCAL_ATTN_Q_NORM, "weight", i), {head_dim},            TENSOR_NOT_REQUIRED);
+                        ll.attn_k_norm = create_tensor(tn(LLM_TENSOR_LOCAL_ATTN_K_NORM, "weight", i), {head_dim},            TENSOR_NOT_REQUIRED);
+                        ll.ffn_norm    = create_tensor(tn(LLM_TENSOR_LOCAL_FFN_NORM,    "weight", i), {local_dim},           TENSOR_NOT_REQUIRED);
+                        ll.ffn_gate    = create_tensor(tn(LLM_TENSOR_LOCAL_FFN_GATE,    "weight", i), {local_dim, local_ff}, TENSOR_NOT_REQUIRED);
+                        ll.ffn_down    = create_tensor(tn(LLM_TENSOR_LOCAL_FFN_DOWN,    "weight", i), {local_ff, local_dim}, TENSOR_NOT_REQUIRED);
+                        ll.ffn_up      = create_tensor(tn(LLM_TENSOR_LOCAL_FFN_UP,      "weight", i), {local_dim, local_ff}, TENSOR_NOT_REQUIRED);
+                    }
+                }
             } break;
         case LLM_ARCH_QWEN3MOE:
         case LLM_ARCH_QWEN3VLMOE:
diff --git a/src/llama-model.h b/src/llama-model.h
index 1dfbab09c..c86e72723 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -496,6 +496,32 @@ struct llama_model {
     struct ggml_tensor * output_b = nullptr;
     struct ggml_tensor * output_norm_enc = nullptr;
 
+    // MOSS-TTS Local variant: local transformer + bridge MLPs + audio layer norms
+    std::vector<struct ggml_tensor *> audio_ln;          // layer norm before each audio head (n_vq+1)
+    struct ggml_tensor * speech_to_local_gate = nullptr; // project backbone → local dim
+    struct ggml_tensor * speech_to_local_down = nullptr;
+    struct ggml_tensor * speech_to_local_up   = nullptr;
+    std::vector<struct ggml_tensor *> local_to_speech_gate; // project local → backbone per codebook
+    std::vector<struct ggml_tensor *> local_to_speech_down;
+    std::vector<struct ggml_tensor *> local_to_speech_up;
+    struct ggml_tensor * local_output_norm = nullptr;
+
+    // local transformer layers (stored separately from backbone layers)
+    struct llama_layer_local {
+        struct ggml_tensor * attn_norm   = nullptr;
+        struct ggml_tensor * wq          = nullptr;
+        struct ggml_tensor * wk          = nullptr;
+        struct ggml_tensor * wv          = nullptr;
+        struct ggml_tensor * wo          = nullptr;
+        struct ggml_tensor * attn_q_norm = nullptr;
+        struct ggml_tensor * attn_k_norm = nullptr;
+        struct ggml_tensor * ffn_norm    = nullptr;
+        struct ggml_tensor * ffn_gate    = nullptr;
+        struct ggml_tensor * ffn_down    = nullptr;
+        struct ggml_tensor * ffn_up      = nullptr;
+    };
+    std::vector<llama_layer_local> local_layers;
+
     // classifier
     struct ggml_tensor * cls = nullptr;
     struct ggml_tensor * cls_b = nullptr;
diff --git a/src/models/moss-tts-delay.cpp b/src/models/moss-tts-delay.cpp
index 8212ba059..c55cc2f6b 100644
--- a/src/models/moss-tts-delay.cpp
+++ b/src/models/moss-tts-delay.cpp
@@ -162,15 +162,101 @@ llm_build_moss_tts_delay::llm_build_moss_tts_delay(const llama_model & model, co
 
     GGML_ASSERT(hparams.n_vq == model.output_audio.size());
 
-    ggml_tensor * logits = build_lora_mm(model.output, cur);
-    cb(logits, "result_output_text", -1);
+    // Detect Local variant: speech_to_local bridge MLP present
+    const bool is_local = (model.speech_to_local_gate != nullptr);
+
+    ggml_tensor * logits;
+
+    if (is_local) {
+        // ── Local variant output path ──
+        // 1. Text head: backbone output → text logits (same as 8B)
+        logits = build_lora_mm(model.output, cur);
+        cb(logits, "result_output_text", -1);
+
+        // 2. Bridge backbone(2048) → local space(1536) via SwiGLU MLP
+        ggml_tensor * gate_out = build_lora_mm(model.speech_to_local_gate, cur);
+        gate_out = ggml_silu(ctx0, gate_out);
+        ggml_tensor * up_out = build_lora_mm(model.speech_to_local_up, cur);
+        ggml_tensor * local_cur = ggml_mul(ctx0, gate_out, up_out);
+        local_cur = build_lora_mm(model.speech_to_local_down, local_cur);
+        cb(local_cur, "speech_to_local_out", -1);
+
+        // 3. Run local transformer (4 layers) on the local representation.
+        //    Note: in the full autoregressive version, this runs per-channel
+        //    with accumulating inputs. Here we process the single backbone output
+        //    through all 4 layers as a static pass. This uses the trained weights
+        //    but lacks the sequential channel feedback (TODO: autoregressive loop
+        //    in moss-tts.cpp for proper inter-channel coherence).
+        if (!model.local_layers.empty() && model.local_layers[0].attn_norm != nullptr) {
+            const int n_local_layers = (int) model.local_layers.size();
+
+            for (int ll = 0; ll < n_local_layers; ++ll) {
+                const auto & layer = model.local_layers[ll];
+
+                // FFN-only pass for the static single-token case.
+                // Full attention with GQA and KV cache accumulation is needed
+                // for the autoregressive channel loop (TODO in moss-tts.cpp).
+                // The FFN layers do the bulk of the learned transformation.
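+                // (the pass below is position-wise on local_cur, so skipping
+                //  attention also keeps this branch free of any extra KV state)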
+
+                ggml_tensor * local_ffn_inp = local_cur;
+                local_cur = build_norm(local_cur, layer.ffn_norm, nullptr, LLM_NORM_RMS, -1);
+                cb(local_cur, "local_ffn_norm", ll);
+
+                // SwiGLU FFN
+                ggml_tensor * ffn_gate = build_lora_mm(layer.ffn_gate, local_cur);
+                ffn_gate = ggml_silu(ctx0, ffn_gate);
+                ggml_tensor * ffn_up = build_lora_mm(layer.ffn_up, local_cur);
+                local_cur = ggml_mul(ctx0, ffn_gate, ffn_up);
+                local_cur = build_lora_mm(layer.ffn_down, local_cur);
+                cb(local_cur, "local_ffn_out", ll);
+
+                // Residual
+                local_cur = ggml_add(ctx0, local_cur, local_ffn_inp);
+                cb(local_cur, "local_l_out", ll);
+            }
 
-    for (uint32_t i = 0; i < hparams.n_vq; ++i) {
-        ggml_tensor * audio_logits = build_lora_mm(model.output_audio[i], cur);
-        cb(audio_logits, "result_output_audio", i);
+            // Local output norm
+            if (model.local_output_norm != nullptr) {
+                local_cur = build_norm(local_cur, model.local_output_norm, nullptr, LLM_NORM_RMS, -1);
+                cb(local_cur, "local_output_norm", -1);
+            }
+        }
+
+        // 4. For each audio channel: local_to_speech bridge → audio_ln → head → logits
+        for (uint32_t i = 0; i < hparams.n_vq; ++i) {
+            // Bridge local(1536) → backbone(2048) via SwiGLU MLP
+            ggml_tensor * ch_gate = build_lora_mm(model.local_to_speech_gate[i], local_cur);
+            ch_gate = ggml_silu(ctx0, ch_gate);
+            ggml_tensor * ch_up = build_lora_mm(model.local_to_speech_up[i], local_cur);
+            ggml_tensor * ch_cur = ggml_mul(ctx0, ch_gate, ch_up);
+            ch_cur = build_lora_mm(model.local_to_speech_down[i], ch_cur);
+            cb(ch_cur, "local_to_speech_out", i);
+
+            // Audio layer norm
+            if (model.audio_ln[i] != nullptr) {
+                ch_cur = build_norm(ch_cur, model.audio_ln[i], nullptr, LLM_NORM_RMS, -1);
+                cb(ch_cur, "audio_ln", i);
+            }
 
-        logits = ggml_concat(ctx0, logits, audio_logits, 0);
-        cb(logits, "result_output_concat", i);
+            // Audio head → logits
+            ggml_tensor * audio_logits = build_lora_mm(model.output_audio[i], ch_cur);
+            cb(audio_logits, "result_output_audio", i);
+
+            logits = ggml_concat(ctx0, logits, audio_logits, 0);
+            cb(logits, "result_output_concat", i);
+        }
+    } else {
+        // ── 8B Delay variant output path (original) ──
+        logits = build_lora_mm(model.output, cur);
+        cb(logits, "result_output_text", -1);
+
+        for (uint32_t i = 0; i < hparams.n_vq; ++i) {
+            ggml_tensor * audio_logits = build_lora_mm(model.output_audio[i], cur);
+            cb(audio_logits, "result_output_audio", i);
+
+            logits = ggml_concat(ctx0, logits, audio_logits, 0);
+            cb(logits, "result_output_concat", i);
+        }
     }
 
     logits = ggml_cont(ctx0, logits);
diff --git a/tools/tts/moss-tts.cpp b/tools/tts/moss-tts.cpp
index 2e545297c..736c7a7b0 100644
--- a/tools/tts/moss-tts.cpp
+++ b/tools/tts/moss-tts.cpp
@@ -3,6 +3,10 @@
 #include "log.h"
 #include "llama.h"
 #include "llama-cpp.h"
+#include "../../src/llama-model.h"
+#include "ggml.h"
+#include "ggml-backend.h"
+#include "ggml-alloc.h"
 
 #include
 #include
@@ -637,6 +641,176 @@ static std::vector moss_collect_audio_history_channels(
     return out;
 }
 
+// ── MOSS-TTS Local: autoregressive channel inference ──
+// Given backbone hidden state, runs the local transformer sequentially per channel.
+// Returns audio logits for all n_vq channels as a flat vector [n_vq * audio_vocab].
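+//
+// Per generation step:
+//   1. speech_to_local bridges the backbone hidden state into the 1536-dim local space,
+//   2. for channel ch: the 4-layer local transformer runs over the accumulated
+//      sequence, then local_to_speech[ch], audio_ln[ch], and the audio head produce
+//      that channel's logits,
+//   3. the argmax token is re-embedded and appended, so channel ch+1 conditions
+//      on channels 0..ch (see the feedback block at the end of the loop).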
+
+static void swiglu_cpu(const float * gate_w, const float * up_w, const float * down_w,
+                       const float * input, float * output,
+                       int64_t in_dim, int64_t ff_dim, int64_t out_dim) {
+    // SwiGLU: output = down(silu(gate(input)) * up(input))
+    std::vector<float> gate(ff_dim), hidden(ff_dim);
+
+    // gate = silu(input @ gate_w^T)
+    for (int64_t j = 0; j < ff_dim; ++j) {
+        float sum = 0;
+        for (int64_t k = 0; k < in_dim; ++k) sum += input[k] * gate_w[j * in_dim + k];
+        gate[j] = sum / (1.0f + expf(-sum)); // silu
+    }
+    // hidden = gate * (input @ up_w^T)
+    for (int64_t j = 0; j < ff_dim; ++j) {
+        float sum = 0;
+        for (int64_t k = 0; k < in_dim; ++k) sum += input[k] * up_w[j * in_dim + k];
+        hidden[j] = gate[j] * sum;
+    }
+    // output = hidden @ down_w^T
+    for (int64_t j = 0; j < out_dim; ++j) {
+        float sum = 0;
+        for (int64_t k = 0; k < ff_dim; ++k) sum += hidden[k] * down_w[j * ff_dim + k];
+        output[j] = sum;
+    }
+}
+
+static void rms_norm_cpu(float * x, const float * weight, int64_t dim, float eps = 1e-6f) {
+    float ss = 0;
+    for (int64_t i = 0; i < dim; ++i) ss += x[i] * x[i];
+    ss = 1.0f / sqrtf(ss / (float) dim + eps);
+    for (int64_t i = 0; i < dim; ++i) x[i] = x[i] * ss * weight[i];
+}
+
+static void matmul_cpu(const float * weight, const float * input, float * output,
+                       int64_t out_dim, int64_t in_dim) {
+    for (int64_t j = 0; j < out_dim; ++j) {
+        float sum = 0;
+        for (int64_t k = 0; k < in_dim; ++k) sum += input[k] * weight[j * in_dim + k];
+        output[j] = sum;
+    }
+}
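+
+// (weight matrices above are indexed row-major as [out_dim][in_dim], which matches
+//  the raw element order of a 2-D ggml tensor as read by tensor_to_float below)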
+
+// Apply RoPE to a single head vector at the given position
+// (rotates interleaved pairs, i.e. the GPT-J convention)
+static void rope_cpu(float * vec, int64_t dim, int pos, float theta = 1000000.0f) {
+    for (int64_t i = 0; i < dim; i += 2) {
+        const float freq  = 1.0f / powf(theta, (float) i / (float) dim);
+        const float angle = (float) pos * freq;
+        const float cos_a = cosf(angle);
+        const float sin_a = sinf(angle);
+        const float v0 = vec[i];
+        const float v1 = vec[i + 1];
+        vec[i]     = v0 * cos_a - v1 * sin_a;
+        vec[i + 1] = v0 * sin_a + v1 * cos_a;
+    }
+}
+
+// RMS norm on a single vector (in-place); identical to rms_norm_cpu, kept as a
+// separate name for readability in the attention code
+static void rms_norm_vec(float * x, const float * weight, int64_t dim, float eps = 1e-6f) {
+    rms_norm_cpu(x, weight, dim, eps);
+}
+
+// Multi-head attention with GQA on CPU
+// x: [seq_len, hidden], output: [seq_len, hidden] (separate buffer)
+// Wq {hidden, n_heads*head_dim}, Wk/Wv {hidden, n_kv*head_dim}, Wo {n_heads*head_dim, hidden}
+// (ggml dim order {n_in, n_out}); q_norm, k_norm: [head_dim] per-head RMS norm weights
+static void mha_gqa_cpu(
+        const float * x, float * output, int64_t seq_len,
+        const float * Wq, const float * Wk, const float * Wv, const float * Wo,
+        const float * q_norm, const float * k_norm,
+        int64_t hidden, int n_heads, int n_kv, int64_t head_dim)
+{
+    const int gqa_ratio = n_heads / n_kv;
+
+    // Project Q, K, V for all positions
+    std::vector<float> Q(seq_len * n_heads * head_dim);
+    std::vector<float> K(seq_len * n_kv * head_dim);
+    std::vector<float> V(seq_len * n_kv * head_dim);
+
+    for (int64_t s = 0; s < seq_len; ++s) {
+        const float * xs = x + s * hidden;
+        // Q = x @ Wq^T → [n_heads * head_dim]
+        matmul_cpu(Wq, xs, Q.data() + s * n_heads * head_dim, n_heads * head_dim, hidden);
+        // K = x @ Wk^T → [n_kv * head_dim]
+        matmul_cpu(Wk, xs, K.data() + s * n_kv * head_dim, n_kv * head_dim, hidden);
+        // V = x @ Wv^T → [n_kv * head_dim]
+        matmul_cpu(Wv, xs, V.data() + s * n_kv * head_dim, n_kv * head_dim, hidden);
+
+        // Per-head Q norm + RoPE
+        for (int h = 0; h < n_heads; ++h) {
+            float * qh = Q.data() + s * n_heads * head_dim + h * head_dim;
+            rms_norm_vec(qh, q_norm, head_dim);
+            rope_cpu(qh, head_dim, (int) s);
+        }
+        // Per-head K norm + RoPE
+        for (int h = 0; h < n_kv; ++h) {
+            float * kh = K.data() + s * n_kv * head_dim + h * head_dim;
+            rms_norm_vec(kh, k_norm, head_dim);
+            rope_cpu(kh, head_dim, (int) s);
+        }
+    }
+
+    // Attention per head (with GQA: each KV head serves gqa_ratio Q heads)
+    std::vector<float> attn_out(seq_len * n_heads * head_dim, 0.0f);
+    const float scale = 1.0f / sqrtf((float) head_dim);
+
+    for (int h = 0; h < n_heads; ++h) {
+        const int kv_h = h / gqa_ratio; // which KV head this Q head uses
+
+        for (int64_t qi = 0; qi < seq_len; ++qi) {
+            const float * qvec = Q.data() + qi * n_heads * head_dim + h * head_dim;
+
+            // Compute attention scores (causal: only attend to positions <= qi)
+            std::vector<float> scores(qi + 1);
+            float max_score = -1e30f;
+            for (int64_t ki = 0; ki <= qi; ++ki) {
+                const float * kvec = K.data() + ki * n_kv * head_dim + kv_h * head_dim;
+                float dot = 0;
+                for (int64_t d = 0; d < head_dim; ++d) dot += qvec[d] * kvec[d];
+                scores[ki] = dot * scale;
+                if (scores[ki] > max_score) max_score = scores[ki];
+            }
+
+            // Softmax
+            float sum_exp = 0;
+            for (int64_t ki = 0; ki <= qi; ++ki) {
+                scores[ki] = expf(scores[ki] - max_score);
+                sum_exp += scores[ki];
+            }
+            for (int64_t ki = 0; ki <= qi; ++ki) scores[ki] /= sum_exp;
+
+            // Weighted sum of V
+            float * out_h = attn_out.data() + qi * n_heads * head_dim + h * head_dim;
+            for (int64_t ki = 0; ki <= qi; ++ki) {
+                const float * vvec = V.data() + ki * n_kv * head_dim + kv_h * head_dim;
+                for (int64_t d = 0; d < head_dim; ++d) {
+                    out_h[d] += scores[ki] * vvec[d];
+                }
+            }
+        }
+    }
+
+    // Output projection: attn_out @ Wo^T → [seq_len, hidden]
+    for (int64_t s = 0; s < seq_len; ++s) {
+        matmul_cpu(Wo, attn_out.data() + s * n_heads * head_dim, output + s * hidden,
+                   hidden, n_heads * head_dim);
+    }
+}
+
+// Read float data from a ggml tensor (dequantize if needed)
+static std::vector<float> tensor_to_float(const struct ggml_tensor * t) {
+    const int64_t n = ggml_nelements(t);
+    std::vector<float> out(n);
+    if (t->type == GGML_TYPE_F32) {
+        ggml_backend_tensor_get(t, out.data(), 0, n * sizeof(float));
+    } else {
+        // Dequantize
+        std::vector<uint8_t> buf(ggml_nbytes(t));
+        ggml_backend_tensor_get(t, buf.data(), 0, buf.size());
+        ggml_get_type_traits(t->type)->to_float(buf.data(), out.data(), n);
+    }
+    return out;
+}
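+
+// note: quantized tensors are dequantized once up front by the caller; the
+// per-step channel loop then runs entirely in f32 on the CPU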
 
 static std::vector<llama_token> moss_delay_step(
     moss_delay_state & state,
     const std::vector<float> & text_logits,
@@ -1141,7 +1315,11 @@ static void moss_generate_from_ref(
     cparams.n_batch   = std::max((uint32_t) hdr.prompt_frames, 1u);
     cparams.n_ubatch  = cparams.n_batch;
     cparams.n_seq_max = 1;
-    cparams.embeddings = false;
+    // Detect Local variant
+    const llama_model * mdl = model.get();
+    const bool is_local = (mdl->speech_to_local_gate != nullptr);
+
+    cparams.embeddings = is_local; // need hidden state for Local channel loop
 
     llama_context_ptr ctx(llama_init_from_model(model.get(), cparams));
     if (!ctx) {
@@ -1150,7 +1328,11 @@
     llama_set_warmup(ctx.get(), false);
     llama_set_causal_attn(ctx.get(), true);
-    llama_set_embeddings(ctx.get(), false);
+    llama_set_embeddings(ctx.get(), is_local);
+
+    if (is_local) {
+        LOG("moss-tts: detected Local variant — using autoregressive channel loop\n");
+    }
 
     {
         moss_owned_batch batch = moss_batch_from_packed_rows(
@@ -1169,6 +1351,75 @@
     const size_t audio_vocab = moss_audio_vocab_with_pad(cfg);
     moss_rng rng(seed);
+    // Pre-cache local weights on CPU for the channel loop (only for Local variant)
+    const int64_t n_embd    = llama_model_n_embd(mdl);
+    const int64_t local_dim = 1536; // from config
+    const int64_t bridge_ff = 2048; // additional_mlp_ffn_hidden_size
+    const int64_t local_ff  = 8960; // local_ffn_hidden_size
+    const int n_local_layers = is_local ? (int) mdl->local_layers.size() : 0;
+
+    // Pre-read bridge MLP weights to CPU (slow first time, but reused every step)
+    std::vector<float> s2l_gate_w, s2l_up_w, s2l_down_w;
+    std::vector<std::vector<float>> l2s_gate_w, l2s_up_w, l2s_down_w;
+    std::vector<std::vector<float>> aln_w;
+    std::vector<std::vector<float>> head_w;
+    // Local transformer weights per layer (attention + FFN)
+    struct local_layer_weights {
+        std::vector<float> attn_norm_w;
+        std::vector<float> wq, wk, wv, wo;
+        std::vector<float> q_norm_w, k_norm_w;
+        std::vector<float> ffn_norm_w, ffn_gate_w, ffn_up_w, ffn_down_w;
+    };
+    std::vector<local_layer_weights> local_lyrs;
+    std::vector<float> local_out_norm_w;
+    // Audio embeddings for token re-embedding
+    std::vector<std::vector<float>> audio_emb_w;
+
+    if (is_local) {
+        LOG("moss-tts: pre-caching local weights on CPU...\n");
+        s2l_gate_w = tensor_to_float(mdl->speech_to_local_gate);
+        s2l_up_w   = tensor_to_float(mdl->speech_to_local_up);
+        s2l_down_w = tensor_to_float(mdl->speech_to_local_down);
+
+        l2s_gate_w.resize(cfg.n_vq);
+        l2s_up_w.resize(cfg.n_vq);
+        l2s_down_w.resize(cfg.n_vq);
+        aln_w.resize(cfg.n_vq);
+        head_w.resize(cfg.n_vq);
+        for (uint32_t i = 0; i < cfg.n_vq; ++i) {
+            l2s_gate_w[i] = tensor_to_float(mdl->local_to_speech_gate[i]);
+            l2s_up_w[i]   = tensor_to_float(mdl->local_to_speech_up[i]);
+            l2s_down_w[i] = tensor_to_float(mdl->local_to_speech_down[i]);
+            if (mdl->audio_ln[i]) aln_w[i] = tensor_to_float(mdl->audio_ln[i]);
+            head_w[i] = tensor_to_float(mdl->output_audio[i]);
+        }
+
+        local_lyrs.resize(n_local_layers);
+        for (int ll = 0; ll < n_local_layers; ++ll) {
+            const auto & layer = mdl->local_layers[ll];
+            auto & lw = local_lyrs[ll];
+            if (layer.attn_norm)   lw.attn_norm_w = tensor_to_float(layer.attn_norm);
+            if (layer.wq)          lw.wq          = tensor_to_float(layer.wq);
+            if (layer.wk)          lw.wk          = tensor_to_float(layer.wk);
+            if (layer.wv)          lw.wv          = tensor_to_float(layer.wv);
+            if (layer.wo)          lw.wo          = tensor_to_float(layer.wo);
+            if (layer.attn_q_norm) lw.q_norm_w    = tensor_to_float(layer.attn_q_norm);
+            if (layer.attn_k_norm) lw.k_norm_w    = tensor_to_float(layer.attn_k_norm);
+            if (layer.ffn_norm)    lw.ffn_norm_w  = tensor_to_float(layer.ffn_norm);
+            if (layer.ffn_gate)    lw.ffn_gate_w  = tensor_to_float(layer.ffn_gate);
+            if (layer.ffn_up)      lw.ffn_up_w    = tensor_to_float(layer.ffn_up);
+            if (layer.ffn_down)    lw.ffn_down_w  = tensor_to_float(layer.ffn_down);
+        }
+        if (mdl->local_output_norm) local_out_norm_w = tensor_to_float(mdl->local_output_norm);
+
+        // Audio embeddings for re-embedding sampled tokens
+        audio_emb_w.resize(cfg.n_vq);
+        for (uint32_t i = 0; i < cfg.n_vq && i < mdl->tok_embd_audio.size(); ++i) {
+            audio_emb_w[i] = tensor_to_float(mdl->tok_embd_audio[i]);
+        }
+        LOG("moss-tts: local weights cached (%d FFN layers, %u channels)\n", n_local_layers, cfg.n_vq);
+    }
+
     for (int32_t step = 0; step < max_new_tokens; ++step) {
         const float * logits = llama_get_logits_ith(ctx.get(), -1);
         if (logits == nullptr) {
@@ -1176,9 +1427,126 @@
         }
 
         std::vector<float> text_logits(logits, logits + text_vocab);
-        std::vector<float> audio_logits(
-            logits + text_vocab,
-            logits + text_vocab + cfg.n_vq * audio_vocab);
+        std::vector<float> audio_logits;
+
+        if (is_local) {
+            // ── Local variant: autoregressive channel loop on CPU ──
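+            // (embeddings were enabled on this context, so the call below returns
+            //  the backbone hidden state of the last position, not pooled output)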
+            const float * embd = llama_get_embeddings_ith(ctx.get(), -1);
+            if (embd == nullptr) {
+                throw std::runtime_error("llama_get_embeddings_ith returned null (is embeddings enabled?)");
+            }
+
+            // speech_to_local: backbone(2048) → local(1536)
+            std::vector<float> local_cur(local_dim);
+            swiglu_cpu(s2l_gate_w.data(), s2l_up_w.data(), s2l_down_w.data(),
+                       embd, local_cur.data(), n_embd, bridge_ff, local_dim);
+
+            audio_logits.resize(cfg.n_vq * audio_vocab);
+
+            // Growing sequence of local embeddings for the local transformer.
+            // Starts with the backbone projection, grows by one per channel.
+            std::vector<float> local_seq; // [seq_len * local_dim]
+            local_seq.insert(local_seq.end(), local_cur.begin(), local_cur.end());
+
+            const int     n_q_heads      = 16;
+            const int     n_kv_heads     = 8;
+            const int64_t head_dim_local = 128;
+
+            for (uint32_t ch = 0; ch < cfg.n_vq; ++ch) {
+                const int64_t seq_len = (int64_t)(ch + 1);
+
+                // Run local transformer (attention + FFN) on the full accumulated sequence
+                // Work buffer: [seq_len, local_dim]
+                std::vector<float> lc(local_seq); // copy for processing
+
+                for (int ll = 0; ll < n_local_layers; ++ll) {
+                    const auto & lw = local_lyrs[ll];
+                    if (lw.attn_norm_w.empty()) continue;
+
+                    // --- Attention ---
+                    // Pre-norm
+                    std::vector<float> normed(seq_len * local_dim);
+                    for (int64_t s = 0; s < seq_len; ++s) {
+                        std::copy(lc.begin() + s * local_dim, lc.begin() + (s + 1) * local_dim,
+                                  normed.begin() + s * local_dim);
+                        rms_norm_cpu(normed.data() + s * local_dim, lw.attn_norm_w.data(), local_dim);
+                    }
+
+                    // Multi-head attention with GQA
+                    std::vector<float> attn_out(seq_len * local_dim, 0.0f);
+                    if (!lw.wq.empty()) {
+                        mha_gqa_cpu(normed.data(), attn_out.data(), seq_len,
+                                    lw.wq.data(), lw.wk.data(), lw.wv.data(), lw.wo.data(),
+                                    lw.q_norm_w.data(), lw.k_norm_w.data(),
+                                    local_dim, n_q_heads, n_kv_heads, head_dim_local);
+                    }
+
+                    // Residual
+                    for (int64_t i = 0; i < seq_len * local_dim; ++i) lc[i] += attn_out[i];
+
+                    // --- FFN ---
+                    for (int64_t s = 0; s < seq_len; ++s) {
+                        std::vector<float> fn(local_dim);
+                        std::copy(lc.begin() + s * local_dim, lc.begin() + (s + 1) * local_dim, fn.begin());
+                        rms_norm_cpu(fn.data(), lw.ffn_norm_w.data(), local_dim);
+
+                        std::vector<float> ffn_out(local_dim);
+                        swiglu_cpu(lw.ffn_gate_w.data(), lw.ffn_up_w.data(), lw.ffn_down_w.data(),
+                                   fn.data(), ffn_out.data(), local_dim, local_ff, local_dim);
+
+                        for (int64_t k = 0; k < local_dim; ++k) lc[s * local_dim + k] += ffn_out[k];
+                    }
+                }
+
+                // Take the last position from the local transformer output
+                std::vector<float> last_pos(local_dim);
+                std::copy(lc.begin() + (seq_len - 1) * local_dim,
+                          lc.begin() + seq_len * local_dim, last_pos.begin());
+
+                // Local output norm
+                if (!local_out_norm_w.empty()) {
+                    rms_norm_cpu(last_pos.data(), local_out_norm_w.data(), local_dim);
+                }
+
+                // local_to_speech: local(1536) → backbone(2048)
+                std::vector<float> ch_embd(n_embd);
+                swiglu_cpu(l2s_gate_w[ch].data(), l2s_up_w[ch].data(), l2s_down_w[ch].data(),
+                           last_pos.data(), ch_embd.data(), local_dim, bridge_ff, n_embd);
+
+                // Audio layer norm
+                if (!aln_w[ch].empty()) {
+                    rms_norm_cpu(ch_embd.data(), aln_w[ch].data(), n_embd);
+                }
+
+                // Audio head → logits for this channel
+                std::vector<float> ch_logits(audio_vocab);
+                matmul_cpu(head_w[ch].data(), ch_embd.data(), ch_logits.data(), (int64_t) audio_vocab, n_embd);
+
+                std::copy(ch_logits.begin(), ch_logits.end(), audio_logits.begin() + ch * audio_vocab);
+
+                // ── Autoregressive feedback: embed sampled token → speech_to_local → append to sequence ──
+                const auto max_it = std::max_element(ch_logits.begin(), ch_logits.end());
+                const llama_token sampled = (llama_token)(max_it - ch_logits.begin());
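+
+                // NOTE: the feedback token is the greedy argmax; the token actually
+                // emitted for this channel is sampled later in moss_delay_step, so
+                // the two can differ when sampling is non-greedy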
+                if (ch < audio_emb_w.size() && !audio_emb_w[ch].empty()) {
+                    const int64_t emb_dim = n_embd;
+                    std::vector<float> token_embd(emb_dim);
+                    if ((size_t) sampled * emb_dim + emb_dim <= audio_emb_w[ch].size()) {
+                        std::copy(audio_emb_w[ch].begin() + sampled * emb_dim,
+                                  audio_emb_w[ch].begin() + sampled * emb_dim + emb_dim,
+                                  token_embd.data());
+                    }
+                    // Project to local dim and append to the growing sequence
+                    std::vector<float> new_local(local_dim);
+                    swiglu_cpu(s2l_gate_w.data(), s2l_up_w.data(), s2l_down_w.data(),
+                               token_embd.data(), new_local.data(), n_embd, bridge_ff, local_dim);
+                    local_seq.insert(local_seq.end(), new_local.begin(), new_local.end());
+                }
+            }
+        } else {
+            // ── 8B variant: read parallel audio logits directly ──
+            audio_logits.assign(logits + text_vocab, logits + text_vocab + cfg.n_vq * audio_vocab);
+        }
 
         const std::vector<llama_token> next = moss_delay_step(
             state, text_logits, audio_logits, sampling_cfg, cfg, rng);