diff --git a/common/speculative.cpp b/common/speculative.cpp index 20e38bec4..0834e08ce 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -61,18 +61,26 @@ static bool common_speculative_are_compatible( LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); if (vocab_type_tgt != vocab_type_dft) { - LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__); - LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt); + LOG_WRN("%s: draft model vocab type must match target model to use speculation but " + "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt); return false; } - if ( - llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) || - llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) || - llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) || - llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft) - ) { - LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__); + if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) || + (llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) { + LOG_WRN("%s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n", + __func__, + llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft), + llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft)); + return false; + } + + if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) || + (llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) { + LOG_WRN("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n", + __func__, + llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft), + llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft)); return false; }