include/llama.h (1 addition, 0 deletions)
@@ -463,6 +463,7 @@ extern "C" {

     // NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
     // In some cases the requested values via llama_context_params may differ from the actual values used by the context
+    // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
     LLAMA_API uint32_t llama_n_ctx     (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch   (const struct llama_context * ctx);
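The NOTE above is the contract this change relies on: after this PR, a requested n_ctx that is not a multiple of 256 is rounded up, so callers should read the value back rather than trust what they passed in. A minimal sketch of that pattern, assuming the current llama.h API names (llama_model_load_from_file, llama_init_from_model); the model path and parameter values are illustrative, not from this PR:

#include "llama.h"

int main(void) {
    llama_backend_init();

    // hypothetical model file, for illustration only
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == NULL) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = 1000; // not a multiple of 256

    llama_context * ctx = llama_init_from_model(model, cparams);

    // after this PR the context rounds n_ctx up: llama_n_ctx(ctx) == 1024
    const uint32_t n_ctx_actual = llama_n_ctx(ctx);
    (void) n_ctx_actual;

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}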
src/llama-context.cpp (4 additions, 0 deletions)
@@ -114,10 +114,14 @@ llama_context::llama_context(
         }
     }
 
+    // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
+
     if (cparams.kv_unified) {
         cparams.n_ctx_seq = cparams.n_ctx;
     } else {
         cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
+        cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256);
 
         if (cparams.n_ctx_seq == 0) {
             throw std::runtime_error("n_ctx_seq == 0");
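The new calls above rely on GGML_PAD, which rounds its first argument up to the next multiple of the second (a power of two); in ggml.h it is defined along the lines of ((x) + (n) - 1) & ~((n) - 1). A self-contained sketch of the resulting arithmetic, with illustrative numbers that are not taken from the PR:

#include <assert.h>
#include <stdint.h>

// same rounding behavior as ggml's GGML_PAD (n must be a power of two)
#define PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main(void) {
    // requested context of 1000 is rounded up to a multiple of 256
    uint32_t n_ctx = PAD(1000, 256);          // -> 1024

    // non-unified KV cache path: split across sequences, then pad again
    uint32_t n_seq_max = 4;
    uint32_t n_ctx_seq = n_ctx / n_seq_max;   // 1024 / 4 = 256
    n_ctx_seq = PAD(n_ctx_seq, 256);          // already aligned, stays 256

    assert(n_ctx == 1024 && n_ctx_seq == 256);
    return 0;
}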
src/llama-kv-cache-iswa.cpp (3 additions, 1 deletion)
@@ -45,7 +45,9 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(

     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
+    // note: the SWA cache is always padded to 256 for performance
+    // https://github.com/ggml-org/llama.cpp/issues/17037
+    uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
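Two things change in the expression above: the pad moves outside the std::min, and the pad size becomes a fixed 256 instead of n_pad, so size_swa is now always a multiple of 256. A sketch of the before/after arithmetic with illustrative values (not from the PR):

#include <assert.h>
#include <stdint.h>

#define PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    // illustrative values: n_swa = 300, non-unified (factor 1), n_ubatch = 32
    uint32_t size_base = 4096;
    uint32_t inner     = 300 * 1 + 32;                 // 332

    // new expression: clamp first, then pad to a fixed 256
    uint32_t size_swa = PAD(MIN(size_base, inner), 256);
    assert(size_swa == 512);                           // 332 rounded up

    // the old expression padded `inner` to n_pad *before* the clamp,
    // e.g. MIN(4096, PAD(332, 32)) == 352, which is not a multiple of 256
    return 0;
}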
tools/server/tests/unit/test_speculative.py (6 additions, 6 deletions)
@@ -77,10 +77,10 @@ def test_different_draft_min_draft_max():

 def test_slot_ctx_not_exceeded():
     global server
-    server.n_ctx = 64
+    server.n_ctx = 256
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "prompt": "Hello " * 56,
+        "prompt": "Hello " * 248,
         "temperature": 0.0,
         "top_k": 1,
         "speculative.p_min": 0.0,
@@ -91,19 +91,19 @@ def test_slot_ctx_not_exceeded():

 def test_with_ctx_shift():
     global server
-    server.n_ctx = 64
+    server.n_ctx = 256
     server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "prompt": "Hello " * 56,
+        "prompt": "Hello " * 248,
         "temperature": 0.0,
         "top_k": 1,
-        "n_predict": 64,
+        "n_predict": 256,
         "speculative.p_min": 0.0,
     })
     assert res.status_code == 200
     assert len(res.body["content"]) > 0
-    assert res.body["tokens_predicted"] == 64
+    assert res.body["tokens_predicted"] == 256
     assert res.body["truncated"] == True


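The test updates follow directly from the context padding: a requested n_ctx of 64 would now be silently rounded up to 256, so the tests set n_ctx = 256 explicitly and scale the prompt ("Hello " * 248), n_predict, and the expected tokens_predicted to match. The factor 248 presumably makes the prompt tokenize to just under the 256-token slot, as 56 did for the old 64-token context; that reading is inferred from the diff rather than stated in it.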