diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4bf8b2789ae7b..c833a4b3997cd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -89,6 +89,7 @@ option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
 option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT})
+option(LLAMA_MOE_ENABLE "llama: enable experimental MoE runtime" OFF)
 
 # 3rd party libs
 option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
@@ -111,6 +112,10 @@ set(LLAMA_INSTALL_VERSION 0.0.${LLAMA_BUILD_NUMBER})
 set(GGML_ALL_WARNINGS ${LLAMA_ALL_WARNINGS})
 set(GGML_FATAL_WARNINGS ${LLAMA_FATAL_WARNINGS})
 
+if (LLAMA_MOE_ENABLE)
+    add_compile_definitions(LLAMA_MOE_ENABLE)
+endif()
+
 # change the default for these ggml options
 if (NOT DEFINED GGML_LLAMAFILE)
     set(GGML_LLAMAFILE_DEFAULT ON)
@@ -176,6 +181,9 @@ if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML)
     set(GGML_BUILD_COMMIT ${LLAMA_BUILD_COMMIT})
     add_subdirectory(ggml)
     # ... otherwise assume ggml is added by a parent CMakeLists.txt
+    if (GGML_CUDA)
+        enable_language(CUDA)
+    endif()
 endif()
 
 if (MINGW)
diff --git a/README.md b/README.md
index 258963ac16d7c..9faae1e86859e 100644
--- a/README.md
+++ b/README.md
@@ -305,6 +305,20 @@ The Hugging Face platform provides a variety of online tools for converting, qua
 - Use the [GGUF-my-repo space](https://huggingface.co/spaces/ggml-org/gguf-my-repo) to convert to GGUF format and quantize model weights to smaller sizes
 - Use the [GGUF-my-LoRA space](https://huggingface.co/spaces/ggml-org/gguf-my-lora) to convert LoRA adapters to GGUF format (more info: https://github.com/ggml-org/llama.cpp/discussions/10123)
+
+### Converting MoE models
+
+Models with Mixture-of-Experts layers should be exported with the new `GGUF_MOE` metadata so that llama.cpp can route and cache experts lazily. The high-level steps are:
+
+1. Convert the base model with `convert_hf_to_gguf.py --moe` (see the updated script usage below).
+2. Ensure the converter emits router tensors (`blk.N.router.*`) and per-expert tensor groups (`blk.N.expert.K.W1`, `W2`, `W3`, …).
+3. Provide per-layer metadata keys:
+   - `moe.layer.N.num_experts`
+   - `moe.layer.N.top_k`
+   - optionally, `moe.layer.N.router_type`
+4. Run `python examples/moe_loader.py --validate path/to/model.gguf` to verify expert handles before inference.
+
+With these fields populated, llama.cpp will mmap each expert independently and hydrate them into GPU memory only when the router selects them.
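Once a model is converted this way, the runtime side can be exercised from the C API. The following is a minimal sketch (not part of this PR's examples) that assumes a build configured with `-DLLAMA_MOE_ENABLE=ON`; the `moe_*` fields are the `llama_context_params` additions introduced later in this diff, and the cache/prefetch values are only illustrative:

```cpp
// Minimal sketch: enable the experimental MoE runtime via the C API.
// Requires a build with -DLLAMA_MOE_ENABLE=ON; cache/prefetch values are illustrative.
#include "llama.h"

int main(int argc, char ** argv) {
    if (argc < 2) {
        return 1;
    }

    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file(argv[1], mparams);
    if (model == nullptr) {
        llama_backend_free();
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.moe_enable             = true; // route through the on-demand expert cache
    cparams.moe_cache_size         = 16;   // experts pinned in VRAM per device (0 = auto)
    cparams.moe_prefetch           = true; // overlap expert DMA with compute
    cparams.moe_prefetch_lookahead = 2;    // micro-batches to prefetch ahead

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (ctx != nullptr) {
        // ... tokenize and call llama_decode() as usual; experts are hydrated lazily ...
        llama_free(ctx);
    }

    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```

After a run, `llama_moe_cache_get_stats()` and `llama_moe_prefetch_get_stats()` (declared later in this diff) report cache hit rates and prefetch activity.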
- Use the [GGUF-editor space](https://huggingface.co/spaces/CISCai/gguf-editor) to edit GGUF meta data in the browser (more info: https://github.com/ggml-org/llama.cpp/discussions/9268) - Use the [Inference Endpoints](https://ui.endpoints.huggingface.co/) to directly host `llama.cpp` in the cloud (more info: https://github.com/ggml-org/llama.cpp/discussions/9669) diff --git a/common/arg.cpp b/common/arg.cpp index 4316917d74595..03bee56e17455 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1951,6 +1951,36 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.kv_unified = true; } ).set_env("LLAMA_ARG_KV_SPLIT")); +#ifdef LLAMA_MOE_ENABLE + add_opt(common_arg( + {"--moe-enable"}, + "enable dynamic Mixture-of-Experts routing with on-demand expert caching", + [](common_params & params) { + params.moe_enable = true; + } + ).set_env("LLAMA_ARG_MOE_ENABLE")); + add_opt(common_arg( + {"--moe-cache-size"}, "N", + string_format("number of experts pinned in VRAM per device (default: %d, 0 = auto)", params.moe_cache_size), + [](common_params & params, int value) { + params.moe_cache_size = value; + } + ).set_env("LLAMA_ARG_MOE_CACHE")); + add_opt(common_arg( + {"--moe-prefetch"}, + string_format("overlap expert DMA with compute (default: %s)", params.moe_prefetch ? "true" : "false"), + [](common_params & params) { + params.moe_prefetch = true; + } + ).set_env("LLAMA_ARG_MOE_PREFETCH")); + add_opt(common_arg( + {"--moe-prefetch-lookahead"}, "N", + string_format("number of micro-batches to prefetch ahead (default: %d)", params.moe_prefetch_lookahead), + [](common_params & params, int value) { + params.moe_prefetch_lookahead = value; + } + ).set_env("LLAMA_ARG_MOE_PREFETCH_LOOKAHEAD")); +#endif add_opt(common_arg( {"--no-context-shift"}, string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"), diff --git a/common/common.cpp b/common/common.cpp index b0591e84b0668..2b51a4ac2caf8 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1180,10 +1180,16 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; - cparams.no_perf = params.no_perf; - cparams.op_offload = !params.no_op_offload; - cparams.swa_full = params.swa_full; - cparams.kv_unified = params.kv_unified; + cparams.no_perf = params.no_perf; + cparams.op_offload = !params.no_op_offload; + cparams.swa_full = params.swa_full; + cparams.kv_unified = params.kv_unified; +#ifdef LLAMA_MOE_ENABLE + cparams.moe_enable = params.moe_enable; + cparams.moe_cache_size = params.moe_cache_size > 0 ? (uint32_t) params.moe_cache_size : 0; + cparams.moe_prefetch = params.moe_prefetch; + cparams.moe_prefetch_lookahead = params.moe_prefetch_lookahead > 0 ? 
        (uint32_t) params.moe_prefetch_lookahead : 1;
+#endif
 
     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;
diff --git a/common/common.h b/common/common.h
index 54b7849b17448..ce072520696cb 100644
--- a/common/common.h
+++ b/common/common.h
@@ -302,6 +302,13 @@ struct common_params {
     enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
 
+#ifdef LLAMA_MOE_ENABLE
+    bool    moe_enable             = false; // enable dynamic MoE routing
+    int32_t moe_cache_size         = 0;     // number of experts kept resident per device
+    bool    moe_prefetch           = false; // enable async prefetch
+    int32_t moe_prefetch_lookahead = 1;     // number of micro-batches to prefetch
+#endif
+
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;
diff --git a/docs/development/HOWTO-add-model.md b/docs/development/HOWTO-add-model.md
index 5989b873a611d..2a1e7158410f9 100644
--- a/docs/development/HOWTO-add-model.md
+++ b/docs/development/HOWTO-add-model.md
@@ -117,6 +117,22 @@ Note: to debug the inference graph: you can use [llama-eval-callback](/examples/
 
 https://github.com/ggml-org/ggml/blob/master/docs/gguf.md
+
+### GGUF_MOE (provisional)
+
+The `GGUF_MOE` extension introduces explicit router tensors and per-expert tensor groups that can be dynamically paged in at runtime. When exporting a model with mixture-of-experts layers, populate the following metadata keys and tensor groups:
+
+- Metadata keys (per MoE layer):
+  - `moe.layer.{i}.num_experts` – total number of experts in the layer.
+  - `moe.layer.{i}.top_k` – active experts per token.
+  - `moe.layer.{i}.router_type` – optional string describing router activation (e.g. `softmax`).
+- Router tensors:
+  - `blk.{i}.router.w1`, `blk.{i}.router.w2` (plus bias variants when present).
+- Expert tensor groups:
+  - `blk.{i}.expert.{e}.w1`, `blk.{i}.expert.{e}.w2`, `blk.{i}.expert.{e}.w3`, etc., matching the FFN projections for each expert `e`.
+  - Shared expert tensors continue to use the existing `ffn_*_shexp` names.
+
+All expert tensors must be stored as standalone GGUF entries (not packed in the last dimension). This allows llama.cpp to mmap each expert independently and back the CUDA ExpertCache with fine-grained handles.
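As a sanity check after export, these keys and tensor names can be read back with the existing `gguf_*` C API. A minimal sketch follows; it assumes the converter stores the per-layer counts as `u32` values, and the header name `gguf.h` follows current ggml trees (older trees expose the same calls via `ggml.h`):

```cpp
// Sketch: read back the provisional GGUF_MOE metadata for layer 0 and check that the
// per-expert tensors exist. Assumes the per-layer counts are stored as u32 values.
#include "gguf.h"

#include <cstdint>
#include <cstdio>
#include <string>

int main(int argc, char ** argv) {
    if (argc < 2) {
        std::fprintf(stderr, "usage: %s model.gguf\n", argv[0]);
        return 1;
    }

    gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr };
    gguf_context * gctx = gguf_init_from_file(argv[1], params);
    if (gctx == nullptr) {
        std::fprintf(stderr, "failed to open %s\n", argv[1]);
        return 1;
    }

    const int64_t k_num = gguf_find_key(gctx, "moe.layer.0.num_experts");
    const int64_t k_top = gguf_find_key(gctx, "moe.layer.0.top_k");
    if (k_num < 0 || k_top < 0) {
        std::fprintf(stderr, "layer 0 is missing GGUF_MOE metadata\n");
    } else {
        const uint32_t n_expert = gguf_get_val_u32(gctx, k_num);
        std::printf("layer 0: num_experts=%u top_k=%u\n", n_expert, gguf_get_val_u32(gctx, k_top));
        for (uint32_t e = 0; e < n_expert; ++e) {
            const std::string name = "blk.0.expert." + std::to_string(e) + ".w1";
            if (gguf_find_tensor(gctx, name.c_str()) < 0) {
                std::fprintf(stderr, "missing expert tensor %s\n", name.c_str());
            }
        }
    }

    gguf_free(gctx);
    return 0;
}
```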
+ ## Resources - YaRN RoPE scaling https://github.com/ggml-org/llama.cpp/pull/2268 diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index dab795fb90a0a..d86cd6cabfd2f 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -33,6 +33,7 @@ else() add_subdirectory(gen-docs) add_subdirectory(training) add_subdirectory(diffusion) + add_subdirectory(moe) add_subdirectory(model-conversion) if (NOT GGML_BACKEND_DL) add_subdirectory(convert-llama2c-to-ggml) diff --git a/examples/moe/CMakeLists.txt b/examples/moe/CMakeLists.txt new file mode 100644 index 0000000000000..b71a5df310129 --- /dev/null +++ b/examples/moe/CMakeLists.txt @@ -0,0 +1,3 @@ +add_executable(moe-loader main.cpp) +target_link_libraries(moe-loader PRIVATE llama Threads::Threads) +target_include_directories(moe-loader PRIVATE ${CMAKE_SOURCE_DIR}/src) diff --git a/examples/moe/main.cpp b/examples/moe/main.cpp new file mode 100644 index 0000000000000..b885f3448972f --- /dev/null +++ b/examples/moe/main.cpp @@ -0,0 +1,250 @@ +#include "llama.h" +#ifdef LLAMA_MOE_ENABLE +#include "llama-moe.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef LLAMA_MOE_ENABLE + +struct options { + std::string model_path; + std::string prompt = "Hello world"; + int steps = -1; + std::string json_path; +}; + +static bool parse_args(int argc, char ** argv, options & out) { + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "--prompt" && i + 1 < argc) { + out.prompt = argv[++i]; + } else if (arg == "--steps" && i + 1 < argc) { + out.steps = std::max(1, std::atoi(argv[++i])); + } else if (arg == "--json" && i + 1 < argc) { + out.json_path = argv[++i]; + } else if (arg.rfind("--", 0) == 0) { + std::cerr << "Unknown option: " << arg << "\n"; + return false; + } else if (out.model_path.empty()) { + out.model_path = arg; + } else { + std::cerr << "Unexpected argument: " << arg << "\n"; + return false; + } + } + if (out.model_path.empty()) { + return false; + } + return true; +} + +static std::vector tokenize_prompt(const llama_vocab * vocab, const std::string & prompt) { + const int32_t max_tokens = prompt.size() + 16; + std::vector tokens(max_tokens); + const int32_t n = llama_tokenize(vocab, prompt.c_str(), prompt.length(), tokens.data(), max_tokens, true, true); + if (n < 0) { + return {}; + } + tokens.resize(n); + return tokens; +} + +struct run_result { + std::vector logits; + llama_perf_context_data perf{}; + llama_moe_cache_stats stats{}; +}; + +static run_result run_trace(const llama_model * model, llama_context * ctx, const std::vector & prompt_tokens, int steps, bool collect_stats) { + const llama_vocab * vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_tokens(vocab); + + run_result res; + res.logits.resize(static_cast(steps) * n_vocab); + + llama_batch batch = llama_batch_init(1, 0, 1); + + for (int step = 0; step < steps; ++step) { + llama_token token = prompt_tokens[std::min(step, static_cast(prompt_tokens.size() - 1))]; + batch.n_tokens = 1; + batch.token[0] = token; + batch.pos[0] = step; + batch.n_seq_id[0] = 1; + batch.seq_id[0][0] = 0; + batch.logits[0] = 1; + + const int32_t rc = llama_decode(ctx, batch); + if (rc != 0) { + std::cerr << "llama_decode failed with code " << rc << " at step " << step << "\n"; + break; + } + + const float * logits = llama_get_logits(ctx); + std::copy(logits, logits + n_vocab, res.logits.begin() + static_cast(step) * n_vocab); + } + + llama_batch_free(batch); + + res.perf = 
llama_perf_context(ctx); +#ifdef LLAMA_MOE_ENABLE + if (collect_stats) { + llama_moe_cache_get_stats(ctx, &res.stats); + } +#else + (void) collect_stats; +#endif + + return res; +} + +static void emit_json(const std::string & path, + int steps, + int vocab, + double max_diff, + double mean_diff, + const llama_perf_context_data & perf_moe, + const llama_perf_context_data & perf_dense, + const llama_moe_cache_stats & stats) { + std::ofstream ofs(path); + if (!ofs.is_open()) { + std::cerr << "Failed to open JSON output file: " << path << "\n"; + return; + } + ofs << "{\n"; + ofs << " \"steps\": " << steps << ",\n"; + ofs << " \"vocab\": " << vocab << ",\n"; + ofs << " \"max_abs_diff\": " << max_diff << ",\n"; + ofs << " \"mean_abs_diff\": " << mean_diff << ",\n"; + ofs << " \"moe_perf\": {\n"; + ofs << " \"tokens\": " << perf_moe.n_eval << ",\n"; + ofs << " \"time_ms\": " << perf_moe.t_eval_ms << ",\n"; + ofs << " \"tok_per_s\": " << (perf_moe.t_eval_ms > 0 ? (perf_moe.n_eval / (perf_moe.t_eval_ms / 1000.0)) : 0.0) << "\n"; + ofs << " },\n"; + ofs << " \"dense_perf\": {\n"; + ofs << " \"tokens\": " << perf_dense.n_eval << ",\n"; + ofs << " \"time_ms\": " << perf_dense.t_eval_ms << ",\n"; + ofs << " \"tok_per_s\": " << (perf_dense.t_eval_ms > 0 ? (perf_dense.n_eval / (perf_dense.t_eval_ms / 1000.0)) : 0.0) << "\n"; + ofs << " },\n"; + ofs << " \"cache_stats\": {\n"; + ofs << " \"resident\": " << stats.resident << ",\n"; + ofs << " \"capacity_bytes\": " << stats.capacity_bytes << ",\n"; + ofs << " \"loads\": " << stats.loads << ",\n"; + ofs << " \"hits\": " << stats.hits << ",\n"; + ofs << " \"evictions\": " << stats.evictions << ",\n"; + ofs << " \"prefetch_requests\": " << stats.prefetch_requests << "\n"; + ofs << " }\n"; + ofs << "}\n"; +} +#endif // LLAMA_MOE_ENABLE + +int main(int argc, char ** argv) { +#ifdef LLAMA_MOE_ENABLE + options opts; + if (!parse_args(argc, argv, opts)) { + std::cerr << "Usage: moe-validate [--prompt \"text\"] [--steps N] [--json path]\n"; + return 1; + } + + llama_backend_init(); + + llama_model_params mparams = llama_model_default_params(); + llama_model * model = llama_model_load_from_file(opts.model_path.c_str(), mparams); + if (model == nullptr) { + std::cerr << "Failed to load model: " << opts.model_path << "\n"; + llama_backend_free(); + return 1; + } + + llama_context_params cparams = llama_context_default_params(); + cparams.moe_enable = true; + cparams.moe_prefetch = true; + llama_context * moe_ctx = llama_init_from_model(model, cparams); + if (moe_ctx == nullptr) { + std::cerr << "Failed to create MoE context\n"; + llama_model_free(model); + llama_backend_free(); + return 1; + } + + llama_context_params dense_params = cparams; + dense_params.moe_enable = false; + llama_context * dense_ctx = llama_init_from_model(model, dense_params); + if (dense_ctx == nullptr) { + std::cerr << "Failed to create dense fallback context\n"; + llama_free(moe_ctx); + llama_model_free(model); + llama_backend_free(); + return 1; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + auto tokens = tokenize_prompt(vocab, opts.prompt); + if (tokens.empty()) { + std::cerr << "Failed to tokenize prompt\n"; + llama_free(dense_ctx); + llama_free(moe_ctx); + llama_model_free(model); + llama_backend_free(); + return 1; + } + + const int steps = (opts.steps > 0) ? 
std::min(opts.steps, static_cast(tokens.size())) : static_cast(tokens.size()); + const int32_t n_vocab = llama_vocab_n_tokens(vocab); + + auto moe_run = run_trace(model, moe_ctx, tokens, steps, true); + auto dense_run = run_trace(model, dense_ctx, tokens, steps, false); + + double max_diff = 0.0; + double sum_diff = 0.0; + size_t count = static_cast(steps) * n_vocab; + for (size_t i = 0; i < count; ++i) { + double diff = std::abs(moe_run.logits[i] - dense_run.logits[i]); + sum_diff += diff; + if (diff > max_diff) { + max_diff = diff; + } + } + double mean_diff = count > 0 ? sum_diff / static_cast(count) : 0.0; + + auto tok_per_s = [](const llama_perf_context_data & perf) -> double { + return perf.t_eval_ms > 0 ? (perf.n_eval / (perf.t_eval_ms / 1000.0)) : 0.0; + }; + + std::cout << "Prompt tokens: " << tokens.size() << " | Steps evaluated: " << steps << "\n"; + std::cout << "Max abs diff: " << max_diff << " | Mean abs diff: " << mean_diff << "\n"; + std::cout << "Dense : tokens=" << dense_run.perf.n_eval << " time_ms=" << dense_run.perf.t_eval_ms + << " tok/s=" << tok_per_s(dense_run.perf) << "\n"; + std::cout << "MoE : tokens=" << moe_run.perf.n_eval << " time_ms=" << moe_run.perf.t_eval_ms + << " tok/s=" << tok_per_s(moe_run.perf) << "\n"; + std::cout << "MoE cache: resident=" << moe_run.stats.resident + << " loads=" << moe_run.stats.loads + << " hits=" << moe_run.stats.hits + << " evictions=" << moe_run.stats.evictions + << " prefetch_requests=" << moe_run.stats.prefetch_requests << "\n"; + + if (!opts.json_path.empty()) { + emit_json(opts.json_path, steps, n_vocab, max_diff, mean_diff, moe_run.perf, dense_run.perf, moe_run.stats); + std::cout << "Wrote JSON summary to " << opts.json_path << "\n"; + } + + llama_free(dense_ctx); + llama_free(moe_ctx); + llama_model_free(model); + llama_backend_free(); + + return 0; +#else + std::cerr << "This build of llama.cpp was compiled without LLAMA_MOE_ENABLE.\n"; + (void) argc; + (void) argv; + return 1; +#endif +} diff --git a/include/llama-moe.h b/include/llama-moe.h new file mode 100644 index 0000000000000..8ca62f0de841b --- /dev/null +++ b/include/llama-moe.h @@ -0,0 +1,232 @@ +#pragma once + +#ifdef LLAMA_MOE_ENABLE + +#include "ggml.h" +#include "ggml-backend.h" +#include "llama-graph.h" + +#include +#include +#include +#include +#include +#include +#include + +#ifdef GGML_USE_CUDA +struct CUstream_st; +using cudaStream_t = CUstream_st *; +#endif + +struct llm_graph_result; + +enum class llama_moe_weight_kind : uint8_t { + GATE_WEIGHT = 0, + UP_WEIGHT = 1, + DOWN_WEIGHT = 2, + GATE_BIAS = 3, + UP_BIAS = 4, + DOWN_BIAS = 5, + UNKNOWN = 255, +}; + +struct llama_moe_expert_handle { + int32_t id = -1; + int32_t layer = -1; + int32_t expert_index = -1; + llama_moe_weight_kind kind = llama_moe_weight_kind::UNKNOWN; + ggml_tensor * tensor = nullptr; + size_t bytes = 0; + size_t offset = 0; + void * host_ptr = nullptr; + ggml_type type = GGML_TYPE_F32; + int32_t n_dims = 0; + std::array ne = {0, 0, 0, 0}; + std::array nb = {0, 0, 0, 0}; + int64_t rows = 0; + int64_t cols = 0; + int32_t slice_axis = -1; + bool is_quantized = false; + bool is_contiguous = false; + bool is_view = false; +}; + +static inline int32_t llama_moe_compose_id(int32_t layer, int32_t expert_index, llama_moe_weight_kind kind) { + constexpr int32_t KIND_FACTOR = 10; + constexpr int32_t EXPERT_FACTOR = 1000; + + return layer * EXPERT_FACTOR * KIND_FACTOR + + expert_index * KIND_FACTOR + + static_cast(kind); +} + +struct llama_moe_router_handle { + int32_t layer = -1; + ggml_tensor 
* tensor = nullptr; + size_t bytes = 0; +}; + +struct llama_context; +class ExpertCache; + +struct llama_moe_dispatch_desc { + llama_context * ctx = nullptr; + ExpertCache * cache = nullptr; + ggml_backend_t backend = nullptr; + int32_t layer = -1; + int32_t n_expert = 0; + int32_t n_expert_used = 0; + int32_t n_embd = 0; + int32_t n_tokens = 0; + int32_t n_ff = 0; + llm_ffn_op_type activation = LLM_FFN_SILU; + bool has_gate = false; + bool has_gate_in = false; + bool has_gate_bias = false; + bool has_up_bias = false; + bool has_down_bias = false; + bool weight_before_ffn = false; + bool allow_quantized = false; + bool use_cuda = false; +}; + +struct llama_moe_cache_stats { + size_t resident = 0; + size_t capacity_bytes = 0; + uint64_t loads = 0; + uint64_t hits = 0; + uint64_t evictions = 0; + uint64_t prefetch_requests = 0; + struct device_stats { + int device = -1; + size_t resident = 0; + size_t capacity_bytes = 0; + uint64_t loads = 0; + uint64_t hits = 0; + uint64_t evictions = 0; + }; + std::vector per_device; +}; + +struct llama_moe_prefetch_stats { + uint64_t updates = 0; + uint64_t prefetch_calls = 0; + uint64_t tokens_observed = 0; +}; + +ggml_tensor * llama_moe_build_dispatch( + ggml_context * ctx, + ggml_tensor * input, + ggml_tensor * selected_experts, + ggml_tensor * weights, + const llama_moe_dispatch_desc & desc, + llm_graph_result * owner = nullptr); + +class ExpertCache { +public: + struct Config { + size_t vram_pool_bytes = 0; + uint32_t max_resident_experts = 0; + bool enable_prefetch = false; + uint32_t prefetch_lookahead = 1; + struct DevicePolicy { + int device = -1; + size_t capacity_bytes = 0; + uint32_t max_resident_experts = 0; + float weight = 0.0f; + }; + std::vector device_policies; + bool auto_assign_devices = true; + }; + + ExpertCache(); + explicit ExpertCache(const Config & cfg); + ~ExpertCache(); + + ExpertCache(const ExpertCache &) = delete; + ExpertCache & operator=(const ExpertCache &) = delete; + + void configure(const Config & cfg); + + void register_expert(const llama_moe_expert_handle & handle); + void register_experts(const std::vector & handles); + + void clear(); + + bool has_resident(int32_t expert_id) const; + const llama_moe_expert_handle * find(int32_t expert_id) const; + llama_moe_cache_stats stats() const; + void reset_stats(); + +#ifdef GGML_USE_CUDA + void attach_stream(cudaStream_t stream, int device); + cudaStream_t stream() const; +#endif + + // Ensures the expert is present on the active device. Returns device pointer or nullptr on failure. 
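+    // A typical call pattern, sketched here as assumed usage rather than taken from this PR:
+    //
+    //     for (int32_t id : selected_ids) {           // ids built with llama_moe_compose_id()
+    //         void * dev = cache.ensure_loaded(id);   // cache hit or synchronous upload
+    //         if (dev == nullptr) { /* fall back to the host copy of the expert */ }
+    //     }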
+ void * ensure_loaded(int32_t expert_id, int device = -1, ggml_backend_buffer_t device_buffer = nullptr); + + // Prefetch a list of experts asynchronously + void prefetch(const std::vector & expert_ids); + + size_t resident_count() const; + size_t capacity_bytes() const; + +private: + struct DeviceSlot { + int32_t expert_id = -1; + void * device_ptr = nullptr; + size_t bytes = 0; + uint64_t last_used = 0; + uint64_t hits = 0; + void * host_staging = nullptr; + size_t staging_capacity = 0; + }; + + struct DevicePool { + std::vector slots; + size_t pool_bytes = 0; + struct Stats { + uint64_t loads = 0; + uint64_t hits = 0; + uint64_t evictions = 0; + } stats; + }; + + struct ExpertRecord { + llama_moe_expert_handle handle; + std::unordered_map slot_by_device; + }; + + using ExpertMap = std::unordered_map; + + void allocate_pool(); + void release_pool(); + DeviceSlot * find_lru(int device); + DevicePool & get_or_create_pool(int device); + size_t capacity_for_device(int device) const; + size_t max_slots_for_device(int device) const; + int select_device_for_expert(int32_t expert_id, int device_hint) const; + + Config config_; + ExpertMap experts_; + std::unordered_map device_pools_; + size_t pool_bytes_ = 0; + uint64_t timestamp_ = 0; + mutable llama_moe_cache_stats stats_; + +#ifdef GGML_USE_CUDA + cudaStream_t stream_ = nullptr; + std::unordered_map device_streams_; + int current_device_ = -1; +#endif + std::vector device_policies_; + std::unordered_map device_policy_by_id_; + double device_policy_total_weight_ = 0.0; + bool auto_assign_devices_ = true; + + mutable std::mutex mutex_; +}; + +#endif // LLAMA_MOE_ENABLE diff --git a/include/llama.h b/include/llama.h index 98bed9d6150a0..5c3ddbcb6df65 100644 --- a/include/llama.h +++ b/include/llama.h @@ -337,6 +337,12 @@ extern "C" { ggml_abort_callback abort_callback; void * abort_callback_data; +#ifdef LLAMA_MOE_ENABLE + // Mixture-of-Experts runtime controls + uint32_t moe_cache_size; // number of experts to keep resident per device + uint32_t moe_prefetch_lookahead; // micro-batches to prefetch ahead +#endif + // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. 
bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU @@ -348,6 +354,10 @@ extern "C" { bool kv_unified; // use a unified buffer across the input sequences when computing the attention // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix // ref: https://github.com/ggml-org/llama.cpp/pull/14363 +#ifdef LLAMA_MOE_ENABLE + bool moe_enable; // enable dynamic MoE routing and caching + bool moe_prefetch; // overlap DMA with compute using auxiliary streams +#endif }; // model quantization parameters @@ -1410,3 +1420,9 @@ extern "C" { #endif #endif // LLAMA_H +#ifdef LLAMA_MOE_ENABLE +struct llama_moe_cache_stats; +struct llama_moe_prefetch_stats; +LLAMA_API void llama_moe_cache_get_stats(const struct llama_context * ctx, struct llama_moe_cache_stats * out_stats); +LLAMA_API void llama_moe_prefetch_get_stats(const struct llama_context * ctx, struct llama_moe_prefetch_stats * out_stats); +#endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 630b2cddf67e8..75309cdd04055 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -26,6 +26,7 @@ add_library(llama llama-memory-hybrid.cpp llama-memory-recurrent.cpp llama-mmap.cpp + llama-moe.cpp llama-model-loader.cpp llama-model-saver.cpp llama-model.cpp @@ -132,12 +133,29 @@ add_library(llama models/graph-context-mamba.cpp ) -target_include_directories(llama PRIVATE .) +target_include_directories(llama PRIVATE . ../ggml/src ../ggml/include) target_include_directories(llama PUBLIC ../include) target_compile_features (llama PRIVATE cxx_std_17) # don't bump target_link_libraries(llama PUBLIC ggml) +if (GGML_CUDA AND LLAMA_MOE_ENABLE) + target_sources(llama PRIVATE llama-moe-cuda.cu) + if (DEFINED CMAKE_CUDA_ARCHITECTURES) + set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES}) + endif() + set_property(TARGET llama PROPERTY CUDA_SEPARABLE_COMPILATION ON) +endif() + +if (GGML_CUDA) + if (CUDAToolkit_INCLUDE_DIRS) + target_include_directories(llama PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) + elseif (CUDAToolkit_BIN_DIR) + get_filename_component(_cuda_root "${CUDAToolkit_BIN_DIR}/.." 
ABSOLUTE) + target_include_directories(llama PRIVATE "${_cuda_root}/include") + endif() +endif() + if (BUILD_SHARED_LIBS) set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_definitions(llama PRIVATE LLAMA_BUILD) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index b7642b568dffb..e4ccbaaef54d4 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -2477,6 +2477,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ENC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_GATE_INP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_GATE_INP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE_EXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_DOWN_EXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_UP_EXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_SSM_IN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 2b39366271ff9..4238c13498a97 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -7,9 +7,11 @@ #include "llama-mmap.h" #include "llama-model.h" +#include #include #include #include +#include #include // @@ -102,6 +104,12 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; +#ifdef LLAMA_MOE_ENABLE + cparams.moe_enable = params.moe_enable; + cparams.moe_cache_size = params.moe_cache_size; + cparams.moe_prefetch = params.moe_prefetch; + cparams.moe_prefetch_lookahead = params.moe_prefetch_lookahead; +#endif { const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); @@ -149,6 +157,36 @@ llama_context::llama_context( } if (!hparams.vocab_only) { +#ifdef LLAMA_MOE_ENABLE + if (cparams.moe_enable) { + expert_cache = std::make_unique(); + ExpertCache::Config cache_cfg{}; + cache_cfg.max_resident_experts = cparams.moe_cache_size; + cache_cfg.vram_pool_bytes = 0; + cache_cfg.enable_prefetch = cparams.moe_prefetch; + cache_cfg.prefetch_lookahead = cparams.moe_prefetch_lookahead; + if (!model.devices.empty()) { + cache_cfg.device_policies.clear(); + cache_cfg.auto_assign_devices = false; + for (size_t di = 0; di < model.devices.size(); ++di) { + ggml_backend_dev_t dev = model.devices[di]; + if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) { + continue; + } + ExpertCache::Config::DevicePolicy policy{}; + policy.device = static_cast(di); + policy.max_resident_experts = cparams.moe_cache_size; + policy.weight = 1.0f; + cache_cfg.device_policies.push_back(policy); + } + if (cache_cfg.device_policies.empty()) { + cache_cfg.auto_assign_devices = true; + } + } + expert_cache->configure(cache_cfg); + moe_initialize(); + } +#endif // GPU backends for (auto * dev : model.devices) { ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); @@ -598,6 +636,23 @@ float * llama_context::get_logits_ith(int32_t i) { } } +#ifdef LLAMA_MOE_ENABLE +ExpertCache * llama_context::get_expert_cache() const { + return expert_cache ? 
expert_cache.get() : nullptr; +} + +llama_moe_cache_stats llama_context::get_moe_cache_stats() const { + if (!expert_cache) { + return {}; + } + return expert_cache->stats(); +} + +llama_moe_prefetch_stats llama_context::get_moe_prefetch_stats() const { + return moe_prefetch_.stats; +} +#endif + float * llama_context::get_embeddings() { output_reorder(); @@ -1102,6 +1157,12 @@ int llama_context::decode(const llama_batch & batch_inp) { n_outputs = n_outputs_new; } +#ifdef LLAMA_MOE_ENABLE + if (cparams.moe_enable) { + moe_prefetch_for_batch(ubatch); + } +#endif + ggml_status status; const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); @@ -1148,6 +1209,12 @@ int llama_context::decode(const llama_batch & batch_inp) { t_embd = res->get_embd_pooled(); } +#ifdef LLAMA_MOE_ENABLE + if (cparams.moe_enable) { + moe_update_prefetch_state(res, ubatch); + } +#endif + // extract logits if (t_logits && n_outputs > 0) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); @@ -1453,6 +1520,9 @@ llm_graph_params llama_context::graph_params( /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, +#ifdef LLAMA_MOE_ENABLE + /*.expert_cache =*/ expert_cache.get(), +#endif }; } @@ -1518,6 +1588,263 @@ llm_graph_cb llama_context::graph_get_cb() const { }; } +#ifdef LLAMA_MOE_ENABLE +void llama_context::moe_initialize() { + if (!expert_cache) { + return; + } + + const auto & hparams = model.hparams; + if (hparams.n_expert == 0) { + LLAMA_LOG_WARN("%s: MoE enabled but model reports zero experts\n", __func__); + return; + } + + const size_t n_layers = model.layers.size(); + moe_prefetch_.initialized = true; + moe_prefetch_.score.assign(n_layers, std::vector(hparams.n_expert, 0.0f)); + moe_prefetch_.top_experts.assign(n_layers, {}); + uint32_t base_width = hparams.n_expert_used > 0 ? hparams.n_expert_used : 1; + uint32_t lookahead = std::max(1, cparams.moe_prefetch_lookahead); + uint32_t desired = base_width * lookahead; + desired = std::max(desired, base_width); + desired = std::min(desired, hparams.n_expert); + moe_prefetch_.width = desired > 0 ? desired : std::min(hparams.n_expert, 4); + moe_prefetch_.stats = {}; + + for (size_t il = 0; il < n_layers; ++il) { + const auto & layer = model.layers[il]; + std::vector handles; + handles.reserve(hparams.n_expert * 6); + + auto fill_metadata = [&](llama_moe_expert_handle & h, ggml_tensor * tensor, int axis) { + if (!tensor) { + return; + } + h.type = tensor->type; + h.n_dims = ggml_n_dims(tensor); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + h.ne[i] = tensor->ne[i]; + h.nb[i] = tensor->nb[i]; + } + h.slice_axis = axis; + h.is_quantized = ggml_is_quantized(tensor->type); + h.is_contiguous = ggml_is_contiguous(tensor); + h.is_view = tensor->view_src != nullptr; + if (h.n_dims >= 2) { + h.cols = tensor->ne[0]; + h.rows = tensor->ne[1]; + } else if (h.n_dims == 1) { + h.cols = 1; + h.rows = tensor->ne[0]; + } else { + h.cols = 0; + h.rows = 0; + } + }; + + auto register_combined = [&](ggml_tensor * tensor, llama_moe_weight_kind kind, int axis) { + if (!tensor) { + return; + } + const size_t stride = tensor->nb[axis]; + if (stride == 0) { + return; + } + char * base = static_cast(tensor->data); + for (uint32_t ie = 0; ie < hparams.n_expert; ++ie) { + llama_moe_expert_handle h{}; + h.layer = static_cast(il); + h.expert_index = static_cast(ie); + h.kind = kind; + h.tensor = tensor; + h.bytes = stride; + h.offset = stride * ie; + h.host_ptr = base ? 
base + h.offset : nullptr; + h.id = llama_moe_compose_id(h.layer, h.expert_index, h.kind); + fill_metadata(h, tensor, axis); + handles.push_back(h); + } + }; + + auto register_split = [&](ggml_tensor * tensor, llama_moe_weight_kind kind, uint32_t expert_index) { + if (!tensor) { + return; + } + llama_moe_expert_handle h{}; + h.layer = static_cast(il); + h.expert_index = static_cast(expert_index); + h.kind = kind; + h.tensor = tensor; + h.bytes = ggml_nbytes(tensor); + h.offset = 0; + h.host_ptr = tensor->data; + h.id = llama_moe_compose_id(h.layer, h.expert_index, h.kind); + fill_metadata(h, tensor, -1); + handles.push_back(h); + }; + + // combined tensors + register_combined(layer.ffn_gate_exps, llama_moe_weight_kind::GATE_WEIGHT, 2); + register_combined(layer.ffn_up_exps, llama_moe_weight_kind::UP_WEIGHT, 2); + register_combined(layer.ffn_down_exps, llama_moe_weight_kind::DOWN_WEIGHT, 2); + register_combined(layer.ffn_gate_exps_b, llama_moe_weight_kind::GATE_BIAS, 1); + register_combined(layer.ffn_up_exps_b, llama_moe_weight_kind::UP_BIAS, 1); + register_combined(layer.ffn_down_exps_b, llama_moe_weight_kind::DOWN_BIAS, 1); + + // split tensors per expert + for (uint32_t ie = 0; ie < hparams.n_expert; ++ie) { + register_split(layer.ffn_gate_exp_splits[ie], llama_moe_weight_kind::GATE_WEIGHT, ie); + register_split(layer.ffn_up_exp_splits[ie], llama_moe_weight_kind::UP_WEIGHT, ie); + register_split(layer.ffn_down_exp_splits[ie], llama_moe_weight_kind::DOWN_WEIGHT, ie); + register_split(layer.ffn_gate_bias_splits[ie], llama_moe_weight_kind::GATE_BIAS, ie); + register_split(layer.ffn_up_bias_splits[ie], llama_moe_weight_kind::UP_BIAS, ie); + register_split(layer.ffn_down_bias_splits[ie], llama_moe_weight_kind::DOWN_BIAS, ie); + } + + expert_cache->register_experts(handles); + } +} + +void llama_context::moe_prefetch_for_batch(const llama_ubatch & ubatch) { + GGML_UNUSED(ubatch); + + if (!expert_cache || !cparams.moe_prefetch) { + return; + } + + const auto & hparams = model.hparams; + if (hparams.n_expert == 0) { + return; + } + + std::vector ids; + ids.reserve(moe_prefetch_.width * 6); + + if (moe_prefetch_.initialized && moe_prefetch_.stats.updates > 0) { + const size_t layers = std::min(moe_prefetch_.top_experts.size(), model.layers.size()); + for (size_t il = 0; il < layers; ++il) { + const auto & top = moe_prefetch_.top_experts[il]; + if (top.empty()) { + continue; + } + const auto & layer = model.layers[il]; + const bool has_gate = layer.ffn_gate_exps != nullptr || layer.ffn_gate_exp_splits[0] != nullptr; + const bool has_up_bias = layer.ffn_up_exps_b != nullptr || layer.ffn_up_bias_splits[0] != nullptr; + const bool has_down_bias = layer.ffn_down_exps_b != nullptr || layer.ffn_down_bias_splits[0] != nullptr; + const bool has_gate_bias = layer.ffn_gate_exps_b != nullptr || layer.ffn_gate_bias_splits[0] != nullptr; + + const size_t width = std::min(moe_prefetch_.width, top.size()); + for (size_t idx = 0; idx < width; ++idx) { + const int32_t expert = top[idx]; + ids.push_back(llama_moe_compose_id(static_cast(il), expert, llama_moe_weight_kind::UP_WEIGHT)); + ids.push_back(llama_moe_compose_id(static_cast(il), expert, llama_moe_weight_kind::DOWN_WEIGHT)); + if (has_gate) { + ids.push_back(llama_moe_compose_id(static_cast(il), expert, llama_moe_weight_kind::GATE_WEIGHT)); + } + if (has_up_bias) { + ids.push_back(llama_moe_compose_id(static_cast(il), expert, llama_moe_weight_kind::UP_BIAS)); + } + if (has_down_bias) { + ids.push_back(llama_moe_compose_id(static_cast(il), expert, 
llama_moe_weight_kind::DOWN_BIAS)); + } + if (has_gate_bias) { + ids.push_back(llama_moe_compose_id(static_cast(il), expert, llama_moe_weight_kind::GATE_BIAS)); + } + } + } + } + + if (ids.empty()) { + const uint32_t fallback_prefetch = std::min( + hparams.n_expert_used > 0 ? hparams.n_expert_used : 1, + hparams.n_expert); + ids.reserve(fallback_prefetch * model.layers.size() * 2); + for (size_t il = 0; il < model.layers.size(); ++il) { + for (uint32_t i = 0; i < fallback_prefetch; ++i) { + ids.push_back(llama_moe_compose_id(static_cast(il), static_cast(i), llama_moe_weight_kind::UP_WEIGHT)); + ids.push_back(llama_moe_compose_id(static_cast(il), static_cast(i), llama_moe_weight_kind::DOWN_WEIGHT)); + } + } + } + + if (ids.empty()) { + return; + } + + std::sort(ids.begin(), ids.end()); + ids.erase(std::unique(ids.begin(), ids.end()), ids.end()); + + expert_cache->prefetch(ids); + moe_prefetch_.stats.prefetch_calls++; +} + +void llama_context::moe_update_prefetch_state(const llm_graph_result * res, const llama_ubatch & ubatch) { + GGML_UNUSED(ubatch); + + if (!expert_cache || !cparams.moe_enable || !moe_prefetch_.initialized || res == nullptr) { + return; + } + + const auto & states = res->get_moe_states(); + const size_t layers = std::min(states.size(), moe_prefetch_.score.size()); + + for (size_t il = 0; il < layers; ++il) { + const auto & state = states[il]; + if (state.selected == nullptr || state.weights == nullptr) { + continue; + } + + const int64_t top_k = state.selected->ne[0]; + const int64_t n_tokens = state.selected->ne[1]; + if (top_k <= 0 || n_tokens <= 0) { + continue; + } + + const size_t total = static_cast(top_k) * static_cast(n_tokens); + + std::vector selected_host(total); + std::vector weights_host(total); + + ggml_backend_tensor_get(state.selected, selected_host.data(), 0, total * sizeof(int32_t)); + ggml_backend_tensor_get(state.weights, weights_host.data(), 0, total * sizeof(float)); + + auto & score = moe_prefetch_.score[il]; + const float decay = moe_prefetch_.decay; + for (float & s : score) { + s *= decay; + } + + for (size_t idx = 0; idx < total; ++idx) { + const int32_t expert = selected_host[idx]; + if (expert < 0 || static_cast(expert) >= score.size()) { + continue; + } + score[expert] += weights_host[idx]; + } + + auto & top = moe_prefetch_.top_experts[il]; + top.clear(); + std::vector indices(score.size()); + std::iota(indices.begin(), indices.end(), 0); + + const size_t width = std::min(moe_prefetch_.width, indices.size()); + if (width == 0) { + continue; + } + + std::nth_element(indices.begin(), indices.begin() + width, indices.end(), + [&](int a, int b) { return score[a] > score[b]; }); + indices.resize(width); + std::sort(indices.begin(), indices.end(), [&](int a, int b) { return score[a] > score[b]; }); + top = std::move(indices); + } + + moe_prefetch_.stats.updates++; + moe_prefetch_.stats.tokens_observed += ubatch.n_tokens; +} +#endif + // // state save/load // @@ -2313,6 +2640,12 @@ llama_context_params llama_context_default_params() { /*.op_offload =*/ true, /*.swa_full =*/ true, /*.kv_unified =*/ false, +#ifdef LLAMA_MOE_ENABLE + /*.moe_cache_size =*/ 0, + /*.moe_prefetch_lookahead =*/ 1, + /*.moe_enable =*/ false, + /*.moe_prefetch =*/ false, +#endif }; return result; @@ -2803,6 +3136,54 @@ void llama_perf_context_print(const llama_context * ctx) { __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval); LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - 
data.t_start_ms), (data.n_p_eval + data.n_eval)); LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused); +#ifdef LLAMA_MOE_ENABLE + if (ctx) { + llama_moe_cache_stats cache_stats = ctx->get_moe_cache_stats(); + if (cache_stats.loads > 0 || cache_stats.hits > 0 || cache_stats.prefetch_requests > 0 || cache_stats.resident > 0) { + const double denom = static_cast(cache_stats.loads + cache_stats.hits); + const double hit_rate = denom > 0.0 ? (static_cast(cache_stats.hits) / denom) * 100.0 : 0.0; + LLAMA_LOG_INFO( + "%s: moe cache = resident=%zu, cap=%.2f MiB, loads=%" PRIu64 ", hits=%" PRIu64 ", evictions=%" PRIu64 ", hit_rate=%.2f%%, prefetch_req=%" PRIu64 "\n", + __func__, + cache_stats.resident, + cache_stats.capacity_bytes / (1024.0 * 1024.0), + cache_stats.loads, + cache_stats.hits, + cache_stats.evictions, + hit_rate, + cache_stats.prefetch_requests); + } + + for (const auto & dev : cache_stats.per_device) { + const double denom_dev = static_cast(dev.loads + dev.hits); + const double hit_rate_dev = denom_dev > 0.0 ? (static_cast(dev.hits) / denom_dev) * 100.0 : 0.0; + LLAMA_LOG_INFO( + "%s: device %d => resident=%zu, cap=%.2f MiB, loads=%" PRIu64 ", hits=%" PRIu64 ", evictions=%" PRIu64 ", hit_rate=%.2f%%\n", + __func__, + dev.device, + dev.resident, + dev.capacity_bytes / (1024.0 * 1024.0), + dev.loads, + dev.hits, + dev.evictions, + hit_rate_dev); + } + + const auto prefetch_stats = ctx->get_moe_prefetch_stats(); + if (prefetch_stats.prefetch_calls > 0 || prefetch_stats.updates > 0) { + const double avg_tokens = prefetch_stats.prefetch_calls > 0 + ? static_cast(prefetch_stats.tokens_observed) / static_cast(prefetch_stats.prefetch_calls) + : 0.0; + LLAMA_LOG_INFO( + "%s: moe prefetch = updates=%" PRIu64 ", calls=%" PRIu64 ", tokens=%" PRIu64 ", avg_tokens_per_call=%.2f\n", + __func__, + prefetch_stats.updates, + prefetch_stats.prefetch_calls, + prefetch_stats.tokens_observed, + avg_tokens); + } + } +#endif } void llama_perf_context_reset(llama_context * ctx) { diff --git a/src/llama-context.h b/src/llama-context.h index 20cbd78955412..ef7f0b1d2554d 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -4,6 +4,9 @@ #include "llama-cparams.h" #include "llama-graph.h" #include "llama-adapter.h" +#ifdef LLAMA_MOE_ENABLE +#include "llama-moe.h" +#endif #include "ggml-cpp.h" #include "ggml-opt.h" @@ -62,6 +65,12 @@ struct llama_context { float * get_logits(); float * get_logits_ith(int32_t i); +#ifdef LLAMA_MOE_ENABLE + ExpertCache * get_expert_cache() const; + llama_moe_cache_stats get_moe_cache_stats() const; + llama_moe_prefetch_stats get_moe_prefetch_stats() const; +#endif + float * get_embeddings(); float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); @@ -217,6 +226,12 @@ struct llama_context { llm_graph_cb graph_get_cb() const; +#ifdef LLAMA_MOE_ENABLE + void moe_initialize(); + void moe_prefetch_for_batch(const llama_ubatch & ubatch); + void moe_update_prefetch_state(const llm_graph_result * res, const llama_ubatch & ubatch); +#endif + // TODO: read/write lora adapters and cvec size_t state_write_data(llama_io_write_i & io); size_t state_read_data (llama_io_read_i & io); @@ -237,6 +252,17 @@ struct llama_context { llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably std::unique_ptr memory; +#ifdef LLAMA_MOE_ENABLE + std::unique_ptr expert_cache; + struct moe_prefetch_state { + bool initialized = false; + float decay = 0.8f; + uint32_t width = 0; + std::vector> score; + std::vector> 
top_experts; + llama_moe_prefetch_stats stats; + } moe_prefetch_; +#endif // decode output (2-dimensional array: [n_outputs][n_vocab]) size_t logits_size = 0; // capacity (of floats) for logits diff --git a/src/llama-cparams.h b/src/llama-cparams.h index fcef8fa976038..03ea395d91e00 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -34,6 +34,12 @@ struct llama_cparams { bool warmup; bool op_offload; bool kv_unified; +#ifdef LLAMA_MOE_ENABLE + bool moe_enable; + uint32_t moe_cache_size; + bool moe_prefetch; + uint32_t moe_prefetch_lookahead; +#endif enum llama_pooling_type pooling_type; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index f9751b3183694..128bf4eaae39c 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -8,6 +8,12 @@ #include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" #include "llama-memory-recurrent.h" +#ifdef LLAMA_MOE_ENABLE +#include "llama-moe.h" +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +#endif +#endif #include #include @@ -473,16 +479,37 @@ llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) { debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0; } +llm_graph_result::~llm_graph_result() { +#ifdef LLAMA_MOE_ENABLE + for (auto & fn : cleanups) { + if (fn) { + fn(); + } + } + cleanups.clear(); +#endif +} + int64_t llm_graph_result::get_max_nodes() const { return max_nodes; } void llm_graph_result::reset() { +#ifdef LLAMA_MOE_ENABLE + for (auto & fn : cleanups) { + if (fn) { + fn(); + } + } + cleanups.clear(); +#endif t_tokens = nullptr; t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; + clear_moe_states(); + params = {}; inputs.clear(); @@ -500,6 +527,14 @@ void llm_graph_result::reset() { gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false); } +#ifdef LLAMA_MOE_ENABLE +void llm_graph_result::add_cleanup(std::function fn) { + if (fn) { + cleanups.push_back(std::move(fn)); + } +} +#endif + void llm_graph_result::set_inputs(const llama_ubatch * ubatch) { for (auto & input : inputs) { input->set_input(ubatch); @@ -547,6 +582,17 @@ void llm_graph_result::set_params(const llm_graph_params & params) { this->params = params; } +void llm_graph_result::set_moe_state(size_t layer, moe_state_view state) { + if (moe_states.size() < layer + 1) { + moe_states.resize(layer + 1); + } + moe_states[layer] = state; +} + +void llm_graph_result::clear_moe_states() { + moe_states.clear(); +} + // // llm_graph_context // @@ -590,7 +636,11 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : cb_func (params.cb), res (params.res), ctx0 (res->get_ctx()), - gf (res->get_gf()) { + gf (res->get_gf()) +#ifdef LLAMA_MOE_ENABLE + , expert_cache (params.expert_cache) +#endif +{ res->set_params(params); } @@ -900,6 +950,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn( const int64_t n_embd = cur->ne[0]; const int64_t n_tokens = cur->ne[1]; const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN +#ifdef LLAMA_MOE_ENABLE + const bool moe_runtime_enabled = cparams.moe_enable && expert_cache != nullptr; +#else + const bool moe_runtime_enabled = false; +#endif ggml_tensor * logits = nullptr; @@ -995,6 +1050,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn( ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens] cb(weights, "ffn_moe_weights", il); + if (res) { + res->set_moe_state(il, {selected_experts, weights, probs}); + } + if (gating_op == 
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) { weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); @@ -1026,6 +1085,42 @@ ggml_tensor * llm_graph_context::build_moe_ffn( //call early so that topk-moe can be used ggml_build_forward_expand(gf, weights); +#ifdef LLAMA_MOE_ENABLE + if (moe_runtime_enabled) { + llama_moe_dispatch_desc desc{}; + desc.ctx = nullptr; + desc.cache = expert_cache; + desc.layer = il; + desc.n_expert = static_cast(n_expert); + desc.n_expert_used = static_cast(n_expert_used); + if (up_exps) { + desc.n_ff = static_cast(up_exps->ne[1]); + } else if (down_exps) { + desc.n_ff = static_cast(down_exps->ne[0]); + } else { + desc.n_ff = 0; + } + desc.activation = type_op; + desc.has_gate = gate_exps != nullptr; + desc.has_gate_in = gate_inp != nullptr; + desc.has_gate_bias = gate_exps_b != nullptr; + desc.has_up_bias = up_exps_b != nullptr; + desc.has_down_bias = down_exps_b != nullptr; + desc.weight_before_ffn = weight_before_ffn; + desc.allow_quantized = false; + ggml_backend_t tensor_backend = sched ? ggml_backend_sched_get_tensor_backend(sched, cur) : nullptr; + if (tensor_backend && ggml_backend_is_cuda(tensor_backend)) { + desc.use_cuda = true; + desc.backend = tensor_backend; + } + + ggml_tensor * dispatch_input = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); + ggml_tensor * moe_out = llama_moe_build_dispatch(ctx0, dispatch_input, selected_experts, weights, desc, res); + cb(moe_out, "ffn_moe_dispatch", il); + return moe_out; + } +#endif + cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); if (weight_before_ffn) { diff --git a/src/llama-graph.h b/src/llama-graph.h index d0c3934f67927..94d469f5102e7 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -24,6 +24,10 @@ class llama_kv_cache_iswa_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; +#ifdef LLAMA_MOE_ENABLE +class ExpertCache; +#endif + // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { LLM_GRAPH_TYPE_DEFAULT, @@ -421,6 +425,9 @@ struct llm_graph_params { llm_graph_cb cb; llm_graph_result * res; +#ifdef LLAMA_MOE_ENABLE + ExpertCache * expert_cache = nullptr; +#endif // return true if the "other" params would result in a graph with the same topology as with the current params // having the same topology allows us to reuse the graph in some cases @@ -471,13 +478,23 @@ class llm_graph_result { public: llm_graph_result(int64_t max_nodes); - virtual ~llm_graph_result() = default; + virtual ~llm_graph_result(); ggml_tensor * get_tokens() const { return t_tokens; } ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + struct moe_state_view { + ggml_tensor * selected = nullptr; // i32 [n_expert_used, n_tokens] + ggml_tensor * weights = nullptr; // f32 [1, n_expert_used, n_tokens] + ggml_tensor * probs = nullptr; // optional probabilities tensor + }; + + const std::vector & get_moe_states() const { return moe_states; } + void set_moe_state(size_t layer, moe_state_view state); + void clear_moe_states(); + ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -498,6 +515,10 @@ class llm_graph_result { void set_params(const llm_graph_params & params); +#ifdef LLAMA_MOE_ENABLE + void add_cleanup(std::function fn); +#endif + // important graph nodes ggml_tensor * t_tokens = nullptr; ggml_tensor * t_logits = nullptr; @@ -511,10 +532,16 @@ class 
llm_graph_result { // memory buffers used to evaluate the model std::vector buf_compute_meta; +#ifdef LLAMA_MOE_ENABLE + std::vector> cleanups; +#endif + ggml_cgraph * gf; int64_t max_nodes; + std::vector moe_states; + private: // keep a copy of the previous graph parameters // we will use this to determine whether the graph can be reused by comparing them with the new parameters @@ -585,6 +612,9 @@ struct llm_graph_context { ggml_context * ctx0 = nullptr; ggml_cgraph * gf = nullptr; +#ifdef LLAMA_MOE_ENABLE + ExpertCache * expert_cache = nullptr; +#endif llm_graph_context(const llm_graph_params & params); virtual ~llm_graph_context() = default; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1987135ca6a2e..e52aea0b3c398 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2548,9 +2548,76 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); } else { layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); +#ifdef LLAMA_MOE_ENABLE + auto has_tensor = [&](const LLM_TN_IMPL & tni) -> bool { + return ml.get_tensor_meta(tni.str().c_str()) != nullptr; + }; + + std::fill(layer.ffn_gate_exp_splits.begin(), layer.ffn_gate_exp_splits.end(), nullptr); + std::fill(layer.ffn_down_exp_splits.begin(), layer.ffn_down_exp_splits.end(), nullptr); + std::fill(layer.ffn_up_exp_splits.begin(), layer.ffn_up_exp_splits.end(), nullptr); + std::fill(layer.ffn_gate_bias_splits.begin(), layer.ffn_gate_bias_splits.end(), nullptr); + std::fill(layer.ffn_down_bias_splits.begin(), layer.ffn_down_bias_splits.end(), nullptr); + std::fill(layer.ffn_up_bias_splits.begin(), layer.ffn_up_bias_splits.end(), nullptr); + + if (has_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i))) { + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_exps = nullptr; + for (int64_t ie = 0; ie < n_expert; ++ie) { + layer.ffn_gate_exp_splits[ie] = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, ie), {n_embd, n_ff}, 0); + } + } + + if (has_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i))) { + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + } else { + layer.ffn_down_exps = nullptr; + for (int64_t ie = 0; ie < n_expert; ++ie) { + layer.ffn_down_exp_splits[ie] = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, ie), {n_ff, n_embd}, 0); + } + } + + if (has_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i))) { + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } else { + layer.ffn_up_exps = nullptr; + for (int64_t ie = 0; ie < n_expert; ++ie) { + layer.ffn_up_exp_splits[ie] = create_tensor(tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, ie), {n_embd, n_ff}, 0); + } + } + + if (has_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i))) { + layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff, n_expert}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_exps_b = nullptr; + for (int64_t ie = 0; ie < n_expert; ++ie) { + layer.ffn_gate_bias_splits[ie] = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXP, "bias", i, ie), {n_ff}, TENSOR_NOT_REQUIRED); + } + } + + if (has_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i))) { + layer.ffn_down_exps_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + } else { + 
layer.ffn_down_exps_b = nullptr; + for (int64_t ie = 0; ie < n_expert; ++ie) { + layer.ffn_down_bias_splits[ie] = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXP, "bias", i, ie), {n_embd}, TENSOR_NOT_REQUIRED); + } + } + + if (has_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i))) { + layer.ffn_up_exps_b = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff, n_expert}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_up_exps_b = nullptr; + for (int64_t ie = 0; ie < n_expert; ++ie) { + layer.ffn_up_bias_splits[ie] = create_tensor(tn(LLM_TENSOR_FFN_UP_EXP, "bias", i, ie), {n_ff}, TENSOR_NOT_REQUIRED); + } + } +#else layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); +#endif // For Granite MoE Shared if (hparams.n_ff_shexp > 0) { diff --git a/src/llama-model.h b/src/llama-model.h index 71ff148e07dae..5ee1677b04773 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -12,6 +12,9 @@ #include #include #include +#ifdef LLAMA_MOE_ENABLE +#include +#endif struct llama_cparams; struct llama_ubatch; @@ -291,6 +294,14 @@ struct llama_layer { struct ggml_tensor * ffn_up_b = nullptr; // b3 struct ggml_tensor * ffn_act = nullptr; struct ggml_tensor * ffn_exp_probs_b = nullptr; +#ifdef LLAMA_MOE_ENABLE + std::array ffn_gate_exp_splits = {}; + std::array ffn_down_exp_splits = {}; + std::array ffn_up_exp_splits = {}; + std::array ffn_gate_bias_splits = {}; + std::array ffn_down_bias_splits = {}; + std::array ffn_up_bias_splits = {}; +#endif // mamba proj struct ggml_tensor * ssm_in = nullptr; diff --git a/src/llama-moe-cuda.cu b/src/llama-moe-cuda.cu new file mode 100644 index 0000000000000..953bcd600904e --- /dev/null +++ b/src/llama-moe-cuda.cu @@ -0,0 +1,493 @@ +#include "llama-moe.h" + +#if defined(LLAMA_MOE_ENABLE) && defined(GGML_USE_CUDA) + +#include "llama-impl.h" +#include "ggml-backend-impl.h" +#include "ggml-cuda.h" +#include "ggml-cuda/common.cuh" +#include "ggml-cuda/convert.cuh" + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace { + +static inline void llama_cuda_try(cudaError_t result, const char * expr) { + if (result != cudaSuccess) { + LLAMA_LOG_ERROR("%s: CUDA call failed: %s (%d)\n", __func__, cudaGetErrorString(result), result); + GGML_ABORT("%s", expr); + } +} + +constexpr float kInvSqrt2 = 0.70710678118654752440f; +constexpr float kSwigluOaiAlpha = 1.702f; +constexpr float kSwigluOaiLimit = 7.0f; + +struct device_buffer { + float * ptr = nullptr; + + void allocate(size_t count) { + if (count == 0) { + return; + } + if (ptr != nullptr) { + return; + } + llama_cuda_try(cudaMalloc(&ptr, count * sizeof(float)), "cudaMalloc device_buffer"); + } + + ~device_buffer() { + if (ptr != nullptr) { + cudaFree(ptr); + } + } + + device_buffer() = default; + device_buffer(const device_buffer &) = delete; + device_buffer & operator=(const device_buffer &) = delete; +}; + +struct pinned_buffer { + void * ptr = nullptr; + size_t bytes = 0; + + void allocate(size_t nbytes) { + if (nbytes == 0) { + return; + } + if (ptr != nullptr && bytes >= nbytes) { + return; + } + release(); + llama_cuda_try(cudaMallocHost(&ptr, nbytes), "cudaMallocHost pinned_buffer"); + bytes = nbytes; + } + + template + T * data() { + return reinterpret_cast(ptr); + } + + ~pinned_buffer() { + release(); 
+ } + + void release() { + if (ptr != nullptr) { + cudaFreeHost(ptr); + ptr = nullptr; + bytes = 0; + } + } + + pinned_buffer() = default; + pinned_buffer(const pinned_buffer &) = delete; + pinned_buffer & operator=(const pinned_buffer &) = delete; +}; + +static inline void llama_cublas_try(cublasStatus_t result, const char * expr) { + if (result == CUBLAS_STATUS_SUCCESS) { + return; + } + LLAMA_LOG_ERROR("%s: cuBLAS call failed: status=%d\n", __func__, (int) result); + GGML_ABORT("%s", expr); +} + +__device__ inline float act_silu(float x) { + return x / (1.0f + __expf(-x)); +} + +__device__ inline float act_gelu(float x) { + const float cdf = 0.5f * (1.0f + erff(x * kInvSqrt2)); + return x * cdf; +} + +__device__ inline float act_relu(float x) { + return x > 0.0f ? x : 0.0f; +} + +__global__ void add_bias_kernel(float * data, const float * bias, int64_t n) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + data[idx] += bias[idx]; + } +} + +__global__ void apply_activation_kernel( + llm_ffn_op_type type, + const float * gate, + const float * up, + float * hidden, + int64_t n, + bool has_gate) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= n) { + return; + } + + const float up_val = up[idx]; + const float gate_val = has_gate ? gate[idx] : up_val; + + switch (type) { + case LLM_FFN_SILU: + case LLM_FFN_SWIGLU: + hidden[idx] = has_gate ? act_silu(gate_val) * up_val : act_silu(up_val); + break; + case LLM_FFN_SWIGLU_OAI_MOE: { + if (has_gate) { + const float x = fminf(gate_val, kSwigluOaiLimit); + const float y = fminf(fmaxf(up_val, -kSwigluOaiLimit), kSwigluOaiLimit); + const float out_glu = x / (1.0f + __expf(kSwigluOaiAlpha * (-x))); + hidden[idx] = out_glu * (y + 1.0f); + } else { + hidden[idx] = act_silu(up_val); + } + break; + } + case LLM_FFN_GELU: + case LLM_FFN_GEGLU: + hidden[idx] = has_gate ? act_gelu(gate_val) * up_val : act_gelu(up_val); + break; + case LLM_FFN_RELU: + case LLM_FFN_REGLU: + hidden[idx] = has_gate ? 
act_relu(gate_val) * up_val : act_relu(up_val);
+            break;
+        case LLM_FFN_RELU_SQR: {
+            const float r = act_relu(up_val);
+            hidden[idx] = r * r;
+            break;
+        }
+        default:
+            hidden[idx] = up_val;
+            break;
+    }
+}
+
+__global__ void scale_and_accumulate_kernel(float * dst, const float * src, float scale, int64_t n) {
+    int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < n) {
+        dst[idx] += scale * src[idx];
+    }
+}
+
+inline int compute_grid(int64_t n, int block_size = 256) {
+    return static_cast<int>((n + block_size - 1) / block_size);
+}
+
+static void add_bias(cudaStream_t stream, float * data, const float * bias, int64_t n) {
+    if (bias == nullptr || n == 0) {
+        return;
+    }
+    const int block = 256;
+    const int grid  = compute_grid(n, block);
+    add_bias_kernel<<<grid, block, 0, stream>>>(data, bias, n);
+}
+
+static void apply_activation(
+        cudaStream_t stream,
+        llm_ffn_op_type type,
+        const float * gate,
+        const float * up,
+        float * hidden,
+        int64_t n,
+        bool has_gate) {
+    if (n == 0) {
+        return;
+    }
+    const int block = 256;
+    const int grid  = compute_grid(n, block);
+    apply_activation_kernel<<<grid, block, 0, stream>>>(type, gate, up, hidden, n, has_gate);
+}
+
+static void accumulate(cudaStream_t stream, float * dst, const float * src, float scale, int64_t n) {
+    if (n == 0) {
+        return;
+    }
+    const int block = 256;
+    const int grid  = compute_grid(n, block);
+    scale_and_accumulate_kernel<<<grid, block, 0, stream>>>(dst, src, scale, n);
+}
+
+static void copy_and_scale_input(
+        cublasHandle_t handle,
+        cudaStream_t stream,
+        const float * input,
+        float * tmp,
+        int64_t n,
+        float scale) {
+    if (tmp == nullptr || input == nullptr || n == 0) {
+        return;
+    }
+    llama_cuda_try(cudaMemcpyAsync(tmp, input, n * sizeof(float), cudaMemcpyDeviceToDevice, stream),
+                   "cudaMemcpyAsync expert input scale");
+    llama_cublas_try(cublasSetStream(handle, stream), "cublasSetStream");
+    llama_cublas_try(cublasSscal(handle, static_cast<int>(n), &scale, tmp, 1), "cublasSscal expert input");
+}
+
+static void run_matvec(
+        cublasHandle_t handle,
+        const float * weight,
+        const float * input,
+        float * output,
+        int64_t rows,
+        int64_t cols) {
+    if (weight == nullptr || input == nullptr || output == nullptr || rows == 0 || cols == 0) {
+        LLAMA_LOG_ERROR("%s: invalid matvec args rows=%" PRId64 " cols=%" PRId64 "\n", __func__, rows, cols);
+        GGML_ABORT("invalid matvec args");
+    }
+    const float alpha = 1.0f;
+    const float beta  = 0.0f;
+    // Treat row-major weight as column-major transpose.
+    llama_cublas_try(
+        cublasSgemv(handle,
+            CUBLAS_OP_T,
+            static_cast<int>(cols),
+            static_cast<int>(rows),
+            &alpha,
+            weight,
+            static_cast<int>(cols),
+            input,
+            1,
+            &beta,
+            output,
+            1),
+        "cublasSgemv expert matvec");
+}
+
+static void validate_handle(const llama_moe_expert_handle * handle, int64_t expected_cols, int64_t expected_rows = -1) {
+    if (handle == nullptr) {
+        LLAMA_LOG_ERROR("%s: expert handle missing\n", __func__);
+        GGML_ABORT("missing expert handle");
+    }
+    if (handle->type != GGML_TYPE_F32 && !ggml_is_quantized(handle->type)) {
+        LLAMA_LOG_ERROR("%s: expert handle type %d unsupported\n", __func__, (int) handle->type);
+        GGML_ABORT("unsupported expert handle type");
+    }
+    if (handle->cols != expected_cols) {
+        LLAMA_LOG_ERROR("%s: expert columns mismatch (expected=%" PRId64 " got=%" PRId64 ")\n",
+            __func__, expected_cols, handle->cols);
+        GGML_ABORT("expert columns mismatch");
+    }
+    if (expected_rows != -1 && handle->rows != expected_rows) {
+        LLAMA_LOG_ERROR("%s: expert rows mismatch (expected=%" PRId64 " got=%" PRId64 ")\n",
+            __func__, expected_rows, handle->rows);
+        GGML_ABORT("expert rows mismatch");
+    }
+}
+
+} // namespace
+
+void llama_moe_dispatch_cuda(
+        const llama_moe_dispatch_desc & desc,
+        ggml_tensor * dst,
+        ggml_tensor * input,
+        ggml_tensor * selected,
+        ggml_tensor * weights) {
+    GGML_ASSERT(desc.cache   != nullptr);
+    GGML_ASSERT(desc.backend != nullptr);
+
+    auto * cuda_ctx = static_cast<ggml_backend_cuda_context *>(desc.backend->context);
+    GGML_ASSERT(cuda_ctx != nullptr);
+
+    cudaStream_t   stream = cuda_ctx->stream();
+    cublasHandle_t handle = cuda_ctx->cublas_handle();
+    llama_cublas_try(cublasSetStream(handle, stream), "cublasSetStream");
+
+    ExpertCache * cache = desc.cache;
+    cache->attach_stream(stream, cuda_ctx->device);
+    const int device_id = cuda_ctx->device;
+
+    const int64_t n_embd   = desc.n_embd;
+    const int64_t n_tokens = desc.n_tokens;
+    const int64_t n_ff     = desc.n_ff;
+    const int64_t top_k    = selected->ne[0];
+
+    const float   * input_d    = static_cast<const float *>(input->data);
+    const float   * weights_d  = static_cast<const float *>(weights->data);
+    const int32_t * selected_d = static_cast<const int32_t *>(selected->data);
+    float         * output_d   = static_cast<float *>(dst->data);
+
+    const size_t input_stride  = input->nb[1] / sizeof(float);
+    const size_t output_stride = dst->nb[1]   / sizeof(float);
+    // zero output
+    llama_cuda_try(cudaMemsetAsync(output_d, 0, n_embd * n_tokens * sizeof(float), stream),
+                   "cudaMemsetAsync moe dst");
+
+    pinned_buffer selected_h;
+    pinned_buffer weights_h;
+    selected_h.allocate(top_k * n_tokens * sizeof(int32_t));
+    weights_h.allocate(top_k * n_tokens * sizeof(float));
+
+    llama_cuda_try(cudaMemcpyAsync(
+            selected_h.data<int32_t>(),
+            selected_d,
+            top_k * n_tokens * sizeof(int32_t),
+            cudaMemcpyDeviceToHost,
+            stream),
+        "cudaMemcpyAsync selected experts");
+
+    llama_cuda_try(cudaMemcpyAsync(
+            weights_h.data<float>(),
+            weights_d,
+            top_k * n_tokens * sizeof(float),
+            cudaMemcpyDeviceToHost,
+            stream),
+        "cudaMemcpyAsync expert weights");
+
+    llama_cuda_try(cudaStreamSynchronize(stream), "cudaStreamSynchronize moe selection copy");
+
+    device_buffer input_scaled;
+    device_buffer up_buf;
+    device_buffer gate_buf;
+    device_buffer hidden_buf;
+    device_buffer down_buf;
+    device_buffer up_weight_deq;
+    device_buffer gate_weight_deq;
+    device_buffer down_weight_deq;
+    device_buffer up_bias_deq;
+    device_buffer gate_bias_deq;
+    device_buffer down_bias_deq;
+
+    if (desc.weight_before_ffn) {
+        input_scaled.allocate(n_embd);
+    }
+    up_buf.allocate(n_ff);
+    hidden_buf.allocate(n_ff);
+    if (desc.has_gate) {
+        gate_buf.allocate(n_ff);
+    }
+    down_buf.allocate(n_embd);
+
+    for (int64_t t = 0; t < n_tokens; ++t) {
+        const float * token_input  = input_d  + t * input_stride;
+        float       * token_output = output_d + t * output_stride;
+
+        for (int64_t k = 0; k < top_k; ++k) {
+            const int32_t expert_index = selected_h.data<int32_t>()[t * top_k + k];
+            if (expert_index < 0 || expert_index >= desc.n_expert) {
+                continue;
+            }
+
+            const float router_weight = weights_h.data<float>()[t * top_k + k];
+            if (!desc.weight_before_ffn && std::abs(router_weight) < std::numeric_limits<float>::min()) {
+                continue;
+            }
+
+            const auto fetch_handle = [&](llama_moe_weight_kind kind) -> const llama_moe_expert_handle * {
+                const int32_t composed_id = llama_moe_compose_id(desc.layer, expert_index, kind);
+                return cache->find(composed_id);
+            };
+
+            const llama_moe_expert_handle * up_h     = fetch_handle(llama_moe_weight_kind::UP_WEIGHT);
+            const llama_moe_expert_handle * gate_h   = desc.has_gate ? fetch_handle(llama_moe_weight_kind::GATE_WEIGHT) : nullptr;
+            const llama_moe_expert_handle * down_h   = fetch_handle(llama_moe_weight_kind::DOWN_WEIGHT);
+            const llama_moe_expert_handle * up_b_h   = desc.has_up_bias   ? fetch_handle(llama_moe_weight_kind::UP_BIAS)   : nullptr;
+            const llama_moe_expert_handle * gate_b_h = desc.has_gate_bias ? fetch_handle(llama_moe_weight_kind::GATE_BIAS) : nullptr;
+            const llama_moe_expert_handle * down_b_h = desc.has_down_bias ? fetch_handle(llama_moe_weight_kind::DOWN_BIAS) : nullptr;
+
+            validate_handle(up_h,   n_embd, n_ff);
+            validate_handle(down_h, n_ff,   n_embd);
+            if (desc.has_gate) {
+                validate_handle(gate_h, n_embd, n_ff);
+            }
+
+            auto load_tensor = [&](const llama_moe_expert_handle * handle,
+                                   llama_moe_weight_kind kind,
+                                   device_buffer & scratch) -> const float * {
+                if (handle == nullptr) {
+                    return nullptr;
+                }
+                const int32_t composed_id = llama_moe_compose_id(desc.layer, expert_index, kind);
+                void * raw = cache->ensure_loaded(composed_id, device_id);
+                if (raw == nullptr) {
+                    return nullptr;
+                }
+                if (!ggml_is_quantized(handle->type)) {
+                    return static_cast<const float *>(raw);
+                }
+
+                const to_fp32_cuda_t to_fp32 = handle->is_contiguous && handle->nb[1] == 0
+                    ? ggml_get_to_fp32_cuda(handle->type)
+                    : nullptr;
+                const to_fp32_nc_cuda_t to_fp32_nc = (!handle->is_contiguous || handle->nb[1] != 0)
+                    ? ggml_get_to_fp32_nc_cuda(handle->type)
+                    : nullptr;
+
+                if (to_fp32 == nullptr && to_fp32_nc == nullptr) {
+                    LLAMA_LOG_ERROR("%s: no CUDA dequantizer for tensor type %d\n", __func__, (int) handle->type);
+                    GGML_ABORT("missing CUDA dequantizer");
+                }
+
+                size_t elems = 1;
+                for (int i = 0; i < handle->n_dims; ++i) {
+                    const int64_t dim = handle->ne[i];
+                    if (dim <= 0) {
+                        break;
+                    }
+                    elems *= static_cast<size_t>(dim);
+                }
+                scratch.allocate(elems);
+                if (to_fp32) {
+                    to_fp32(raw, scratch.ptr, static_cast<int64_t>(elems), stream);
+                } else {
+                    const int64_t ne0 = handle->ne[0];
+                    const int64_t ne1 = handle->ne[1];
+                    const int64_t ne2 = handle->ne[2];
+                    const int64_t ne3 = handle->ne[3];
+                    const int64_t nb1 = handle->nb[1];
+                    const int64_t nb2 = handle->nb[2];
+                    const int64_t nb3 = handle->nb[3];
+                    to_fp32_nc(raw, scratch.ptr, ne0, ne1, ne2, ne3, nb1, nb2, nb3, stream);
+                }
+                return scratch.ptr;
+            };
+
+            const float * up_w   = load_tensor(up_h, llama_moe_weight_kind::UP_WEIGHT, up_weight_deq);
+            const float * gate_w = desc.has_gate ? load_tensor(gate_h, llama_moe_weight_kind::GATE_WEIGHT, gate_weight_deq) : nullptr;
+            const float * down_w = load_tensor(down_h, llama_moe_weight_kind::DOWN_WEIGHT, down_weight_deq);
+            const float * up_b   = desc.has_up_bias ? load_tensor(up_b_h, llama_moe_weight_kind::UP_BIAS, up_bias_deq) : nullptr;
+            const float * gate_b = desc.has_gate_bias ? load_tensor(gate_b_h, llama_moe_weight_kind::GATE_BIAS, gate_bias_deq) : nullptr;
+            const float * down_b = desc.has_down_bias ? load_tensor(down_b_h, llama_moe_weight_kind::DOWN_BIAS, down_bias_deq) : nullptr;
+
+            if (up_w == nullptr || down_w == nullptr || (desc.has_gate && gate_w == nullptr)) {
+                LLAMA_LOG_ERROR("%s: missing expert weights for layer %d expert %d\n", __func__, desc.layer, expert_index);
+                GGML_ABORT("missing expert weights");
+            }
+
+            float * expert_input = const_cast<float *>(token_input);
+            if (desc.weight_before_ffn) {
+                expert_input = input_scaled.ptr;
+                copy_and_scale_input(handle, stream, token_input, input_scaled.ptr, n_embd, router_weight);
+            }
+
+            run_matvec(handle, up_w, expert_input, up_buf.ptr, n_ff, n_embd);
+            add_bias(stream, up_buf.ptr, up_b, n_ff);
+
+            if (desc.has_gate) {
+                run_matvec(handle, gate_w, expert_input, gate_buf.ptr, n_ff, n_embd);
+                add_bias(stream, gate_buf.ptr, gate_b, n_ff);
+            }
+
+            apply_activation(stream, desc.activation, gate_buf.ptr, up_buf.ptr, hidden_buf.ptr, n_ff, desc.has_gate);
+
+            run_matvec(handle, down_w, hidden_buf.ptr, down_buf.ptr, n_embd, n_ff);
+            add_bias(stream, down_buf.ptr, down_b, n_embd);
+
+            if (desc.weight_before_ffn) {
+                accumulate(stream, token_output, down_buf.ptr, 1.0f, n_embd);
+            } else {
+                accumulate(stream, token_output, down_buf.ptr, router_weight, n_embd);
+            }
+        }
+    }
+}
+
+#endif // defined(LLAMA_MOE_ENABLE) && defined(GGML_USE_CUDA)
diff --git a/src/llama-moe.cpp b/src/llama-moe.cpp
new file mode 100644
index 0000000000000..603752a666de8
--- /dev/null
+++ b/src/llama-moe.cpp
@@ -0,0 +1,880 @@
+#include "llama-moe.h"
+
+#ifdef LLAMA_MOE_ENABLE
+
+#include "llama-impl.h"
+#include "llama-context.h"
+
+#ifdef GGML_USE_CUDA
+#include "ggml-backend-impl.h"
+#include "ggml-cuda.h"
+#include <cuda_runtime.h>
+
+static inline void llama_cuda_try(cudaError_t result, const char * expr) {
+    if (result != cudaSuccess) {
+        LLAMA_LOG_ERROR("%s: CUDA call failed: %s (%d)\n", __func__, cudaGetErrorString(result), result);
+        GGML_ABORT("%s", expr);
+    }
+}
+#endif
+
+#ifdef GGML_USE_CUDA
+void llama_moe_dispatch_cuda(const llama_moe_dispatch_desc & desc,
+                             ggml_tensor * dst,
+                             ggml_tensor * input,
+                             ggml_tensor * selected,
+                             ggml_tensor * weights);
+#endif
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <limits>
+#include <vector>
+
+namespace {
+constexpr uint32_t kDefaultMaxExperts = 16;
+constexpr float kInvSqrt2       = 0.70710678118654752440f;
+constexpr float kSwigluOaiAlpha = 1.702f;
+constexpr float kSwigluOaiLimit = 7.0f;
+
+[[nodiscard]] size_t llama_moe_required_device_bytes(const llama_moe_expert_handle & handle) {
+    return handle.bytes;
+}
+
+static void llama_moe_copy_to_staging(const llama_moe_expert_handle & handle, void * dst) {
+    if (handle.host_ptr == nullptr) {
+        GGML_ABORT("expert handle host data unavailable");
+    }
+
+    std::memcpy(dst, handle.host_ptr, handle.bytes);
+}
+}
+
+struct llama_moe_dispatch_userdata {
+    llama_moe_dispatch_desc desc;
+    std::vector<float> input_scaled;
+    std::vector<float> up_buf;
+    std::vector<float> gate_buf;
+    std::vector<float> hidden_buf;
+    std::vector<float> down_buf;
+};
+namespace {
+
+[[nodiscard]] inline float act_silu(float x) {
+    return x / (1.0f + std::exp(-x));
+}
+
+[[nodiscard]] inline float act_gelu(float x) {
+    const float cdf = 0.5f * (1.0f + std::erf(x * kInvSqrt2));
+    return x * cdf;
+}
+
+[[nodiscard]] inline float act_relu(float x) {
+    return x > 0.0f ?
x : 0.0f; +} + +bool llama_moe_is_supported_handle(const llama_moe_expert_handle * h) { + if (h == nullptr) { + return false; + } + if (h->host_ptr == nullptr) { + LLAMA_LOG_ERROR("%s: expert handle missing host data (id=%d)\n", __func__, h->id); + return false; + } + if (h->is_quantized) { + LLAMA_LOG_ERROR("%s: quantized experts not yet supported in CPU dispatcher (id=%d)\n", __func__, h->id); + return false; + } + if (h->type != GGML_TYPE_F32) { + LLAMA_LOG_ERROR("%s: unsupported tensor type %d for expert id=%d (only F32 supported)\n", __func__, (int) h->type, h->id); + return false; + } + if (h->rows <= 0 || h->cols <= 0) { + LLAMA_LOG_ERROR("%s: invalid tensor shape for expert id=%d (rows=%lld cols=%lld)\n", __func__, h->id, (long long) h->rows, (long long) h->cols); + return false; + } + GGML_ASSERT(h->nb[0] == sizeof(float)); + return true; +} + +void llama_moe_add_bias(const llama_moe_expert_handle * bias, float * vec, int64_t len) { + if (!bias) { + return; + } + if (!llama_moe_is_supported_handle(bias)) { + GGML_ABORT("unsupported bias handle"); + } + GGML_ASSERT(bias->rows == len || (bias->rows == 1 && len == bias->cols)); + const float * data = reinterpret_cast(bias->host_ptr); + for (int64_t i = 0; i < len; ++i) { + vec[i] += data[i]; + } +} + +bool llama_moe_matvec(const llama_moe_expert_handle * weight, int64_t cols, const float * input, float * output) { + if (!llama_moe_is_supported_handle(weight)) { + return false; + } + GGML_ASSERT(weight->cols == cols); + const char * base = static_cast(weight->host_ptr); + const size_t row_stride = weight->nb[1]; + for (int64_t row = 0; row < weight->rows; ++row) { + const float * row_ptr = reinterpret_cast(base + row_stride * row); + float sum = 0.0f; + for (int64_t col = 0; col < cols; ++col) { + sum += row_ptr[col] * input[col]; + } + output[row] = sum; + } + return true; +} + +void llama_moe_apply_activation( + llm_ffn_op_type type, + const float * gate, + const float * up, + float * hidden, + int64_t n, + bool has_gate) { + switch (type) { + case LLM_FFN_SILU: + case LLM_FFN_SWIGLU: + if (has_gate) { + for (int64_t i = 0; i < n; ++i) { + hidden[i] = act_silu(gate[i]) * up[i]; + } + } else { + for (int64_t i = 0; i < n; ++i) { + hidden[i] = act_silu(up[i]); + } + } + break; + case LLM_FFN_SWIGLU_OAI_MOE: + if (has_gate) { + for (int64_t i = 0; i < n; ++i) { + const float x = std::min(gate[i], kSwigluOaiLimit); + const float y = std::clamp(up[i], -kSwigluOaiLimit, kSwigluOaiLimit); + const float out_glu = x / (1.0f + std::exp(kSwigluOaiAlpha * (-x))); + hidden[i] = out_glu * (y + 1.0f); + } + } else { + for (int64_t i = 0; i < n; ++i) { + hidden[i] = act_silu(up[i]); + } + } + break; + case LLM_FFN_GELU: + case LLM_FFN_GEGLU: + if (has_gate) { + for (int64_t i = 0; i < n; ++i) { + hidden[i] = act_gelu(gate[i]) * up[i]; + } + } else { + for (int64_t i = 0; i < n; ++i) { + hidden[i] = act_gelu(up[i]); + } + } + break; + case LLM_FFN_RELU: + case LLM_FFN_REGLU: + if (has_gate) { + for (int64_t i = 0; i < n; ++i) { + hidden[i] = act_relu(gate[i]) * up[i]; + } + } else { + for (int64_t i = 0; i < n; ++i) { + hidden[i] = act_relu(up[i]); + } + } + break; + case LLM_FFN_RELU_SQR: + for (int64_t i = 0; i < n; ++i) { + const float r = act_relu(up[i]); + hidden[i] = r * r; + } + break; + default: + LLAMA_LOG_ERROR("%s: unsupported activation type %d\n", __func__, (int) type); + GGML_ABORT("unsupported activation type"); + } +} + +} // namespace + +static void llama_moe_dispatch_kernel( + ggml_tensor * dst, + int ith, + int nth, + void * user) 
{
+    GGML_UNUSED(nth);
+
+    if (ith != 0) {
+        return;
+    }
+
+    auto * ud = static_cast<llama_moe_dispatch_userdata *>(user);
+    GGML_ASSERT(ud != nullptr);
+    const auto & desc = ud->desc;
+
+    ExpertCache * cache = desc.cache;
+    if (cache == nullptr && desc.ctx != nullptr) {
+        cache = desc.ctx->get_expert_cache();
+    }
+
+    if (cache == nullptr) {
+        LLAMA_LOG_ERROR("%s: ExpertCache unavailable\n", __func__);
+        GGML_ABORT("missing expert cache");
+    }
+
+    ggml_tensor * input    = dst->src[0];
+    ggml_tensor * selected = dst->src[1];
+    ggml_tensor * weights  = dst->src[2];
+
+    GGML_ASSERT(input != nullptr && selected != nullptr && weights != nullptr);
+    GGML_ASSERT(input->type    == GGML_TYPE_F32);
+    GGML_ASSERT(weights->type  == GGML_TYPE_F32);
+    GGML_ASSERT(selected->type == GGML_TYPE_I32);
+
+#ifdef GGML_USE_CUDA
+    if (desc.use_cuda) {
+        llama_moe_dispatch_cuda(desc, dst, input, selected, weights);
+        return;
+    }
+#endif
+
+    const int64_t n_embd   = input->ne[0];
+    const int64_t n_tokens = input->ne[1];
+    const int64_t top_k    = selected->ne[0];
+
+    GGML_ASSERT(dst->ne[0] == n_embd);
+    GGML_ASSERT(dst->ne[1] == n_tokens);
+    GGML_ASSERT(weights->ne[1] == top_k);
+    GGML_ASSERT(weights->ne[2] == n_tokens);
+
+    const char * input_data    = static_cast<const char *>(input->data);
+    const char * selected_data = static_cast<const char *>(selected->data);
+    const char * weight_data   = static_cast<const char *>(weights->data);
+    float      * output_data   = static_cast<float *>(dst->data);
+
+    auto & input_scaled = ud->input_scaled;
+    auto & up_buf       = ud->up_buf;
+    auto & gate_buf     = ud->gate_buf;
+    auto & hidden_buf   = ud->hidden_buf;
+    auto & down_buf     = ud->down_buf;
+
+    input_scaled.resize(n_embd);
+
+    for (int64_t t = 0; t < n_tokens; ++t) {
+        float * out = reinterpret_cast<float *>(reinterpret_cast<char *>(output_data) + t * dst->nb[1]);
+        std::fill(out, out + n_embd, 0.0f);
+
+        const float * input_vec = reinterpret_cast<const float *>(input_data + t * input->nb[1]);
+
+        for (int64_t k = 0; k < top_k; ++k) {
+            const int32_t expert_index = *reinterpret_cast<const int32_t *>(selected_data + k * selected->nb[0] + t * selected->nb[1]);
+            if (expert_index < 0) {
+                continue;
+            }
+
+            const float weight = *reinterpret_cast<const float *>(weight_data + k * weights->nb[1] + t * weights->nb[2]);
+            if (!desc.weight_before_ffn && std::abs(weight) < std::numeric_limits<float>::min()) {
+                continue;
+            }
+
+            const float * expert_input = input_vec;
+            if (desc.weight_before_ffn) {
+                for (int64_t i = 0; i < n_embd; ++i) {
+                    input_scaled[i] = input_vec[i] * weight;
+                }
+                expert_input = input_scaled.data();
+            }
+
+            const auto fetch = [&](llama_moe_weight_kind kind) -> const llama_moe_expert_handle * {
+                const int32_t id = llama_moe_compose_id(desc.layer, expert_index, kind);
+                return cache->find(id);
+            };
+
+            const llama_moe_expert_handle * up_w   = fetch(llama_moe_weight_kind::UP_WEIGHT);
+            const llama_moe_expert_handle * gate_w = fetch(llama_moe_weight_kind::GATE_WEIGHT);
+            const llama_moe_expert_handle * down_w = fetch(llama_moe_weight_kind::DOWN_WEIGHT);
+            const llama_moe_expert_handle * up_b   = desc.has_up_bias   ? fetch(llama_moe_weight_kind::UP_BIAS)   : nullptr;
+            const llama_moe_expert_handle * gate_b = desc.has_gate_bias ? fetch(llama_moe_weight_kind::GATE_BIAS) : nullptr;
+            const llama_moe_expert_handle * down_b = desc.has_down_bias ?
fetch(llama_moe_weight_kind::DOWN_BIAS) : nullptr; + + if (!llama_moe_is_supported_handle(up_w) || !llama_moe_is_supported_handle(down_w)) { + LLAMA_LOG_ERROR("%s: missing expert weights for layer %d expert %d\n", __func__, desc.layer, expert_index); + GGML_ABORT("missing expert weights"); + } + + const int64_t n_ff = up_w->rows; + if ((int64_t) up_w->cols != n_embd || (int64_t) down_w->cols != n_ff || (int64_t) down_w->rows != n_embd) { + LLAMA_LOG_ERROR("%s: unexpected expert dimension mismatch (layer %d expert %d)\n", __func__, desc.layer, expert_index); + GGML_ABORT("expert dim mismatch"); + } + + if ((int64_t) up_buf.size() < n_ff) up_buf.resize(n_ff); + if ((int64_t) hidden_buf.size() < n_ff) hidden_buf.resize(n_ff); + if ((int64_t) gate_buf.size() < n_ff) gate_buf.resize(n_ff); + if ((int64_t) down_buf.size() < n_embd) down_buf.resize(n_embd); + + std::fill_n(up_buf.begin(), n_ff, 0.0f); + std::fill_n(hidden_buf.begin(), n_ff, 0.0f); + std::fill_n(gate_buf.begin(), n_ff, 0.0f); + std::fill_n(down_buf.begin(), n_embd, 0.0f); + + if (!llama_moe_matvec(up_w, n_embd, expert_input, up_buf.data())) { + GGML_ABORT("failed to compute up branch"); + } + llama_moe_add_bias(up_b, up_buf.data(), n_ff); + + bool has_gate = gate_w != nullptr && llama_moe_is_supported_handle(gate_w); + if (has_gate) { + if (gate_w->rows != n_ff) { + LLAMA_LOG_ERROR("%s: gate rows mismatch (layer %d expert %d)\n", __func__, desc.layer, expert_index); + GGML_ABORT("gate dim mismatch"); + } + if (!llama_moe_matvec(gate_w, n_embd, expert_input, gate_buf.data())) { + GGML_ABORT("failed to compute gate branch"); + } + llama_moe_add_bias(gate_b, gate_buf.data(), n_ff); + } + + llama_moe_apply_activation(desc.activation, gate_buf.data(), up_buf.data(), hidden_buf.data(), n_ff, has_gate); + + if (!llama_moe_matvec(down_w, n_ff, hidden_buf.data(), down_buf.data())) { + GGML_ABORT("failed to compute down branch"); + } + llama_moe_add_bias(down_b, down_buf.data(), n_embd); + + if (desc.weight_before_ffn) { + for (int64_t i = 0; i < n_embd; ++i) { + out[i] += down_buf[i]; + } + } else { + for (int64_t i = 0; i < n_embd; ++i) { + out[i] += weight * down_buf[i]; + } + } + } + } +} + +ggml_tensor * llama_moe_build_dispatch( + ggml_context * ctx, + ggml_tensor * input, + ggml_tensor * selected_experts, + ggml_tensor * weights, + const llama_moe_dispatch_desc & desc, + llm_graph_result * owner) { + GGML_ASSERT(ctx != nullptr); + GGML_ASSERT(input != nullptr); + GGML_ASSERT(selected_experts != nullptr); + GGML_ASSERT(weights != nullptr); + + const int64_t ne0 = input->ne[0]; + const int64_t ne1 = input->ne[1]; + + GGML_ASSERT(owner != nullptr); + + llama_moe_dispatch_desc local = desc; + local.n_embd = static_cast(ne0); + local.n_tokens = static_cast(ne1); + local.n_expert_used = static_cast(selected_experts->ne[0]); + + if (local.n_expert == 0) { + local.n_expert = local.n_expert_used; + } + + auto * userdata = new llama_moe_dispatch_userdata{}; + userdata->desc = local; + + owner->add_cleanup([userdata]() { delete userdata; }); + + ggml_tensor * srcs[3] = {input, selected_experts, weights}; + + ggml_tensor * out = ggml_custom_4d( + ctx, + GGML_TYPE_F32, + ne0, + ne1, + 1, + 1, + srcs, + 3, + llama_moe_dispatch_kernel, + 1, + userdata); + + return out; +} + +ExpertCache::ExpertCache() { + configure(Config{}); +} + +ExpertCache::ExpertCache(const Config & cfg) { + configure(cfg); +} + +ExpertCache::~ExpertCache() { + clear(); + release_pool(); +} + +void ExpertCache::configure(const Config & cfg) { + std::lock_guard lock(mutex_); + 
config_ = cfg; + if (config_.max_resident_experts == 0) { + config_.max_resident_experts = kDefaultMaxExperts; + } + device_policies_ = cfg.device_policies; + device_policy_by_id_.clear(); + device_policy_total_weight_ = 0.0; + for (const auto & policy : device_policies_) { + if (policy.device < 0) { + continue; + } + device_policy_by_id_[policy.device] = policy; + const double weight = policy.weight > 0.0f ? policy.weight : 1.0; + device_policy_total_weight_ += weight; + } + auto_assign_devices_ = cfg.auto_assign_devices; + auto_assign_devices_ = cfg.auto_assign_devices; + release_pool(); + allocate_pool(); + reset_stats(); +} + +void ExpertCache::register_expert(const llama_moe_expert_handle & handle) { + if (handle.tensor == nullptr || handle.bytes == 0) { + return; + } + + std::lock_guard lock(mutex_); + llama_moe_expert_handle stored = handle; + if (stored.id < 0 && stored.layer >= 0 && stored.expert_index >= 0 && stored.kind != llama_moe_weight_kind::UNKNOWN) { + stored.id = llama_moe_compose_id(stored.layer, stored.expert_index, stored.kind); + } + + if (stored.id < 0) { + LLAMA_LOG_WARN("%s: skipping expert registration with invalid id (layer=%d expert=%d kind=%d)\n", + __func__, stored.layer, stored.expert_index, static_cast(stored.kind)); + return; + } + + auto & slot = experts_[stored.id]; + slot.handle = stored; + slot.slot_by_device.clear(); +} + +void ExpertCache::register_experts(const std::vector & handles) { + for (const auto & h : handles) { + register_expert(h); + } +} + +void ExpertCache::clear() { + std::lock_guard lock(mutex_); + release_pool(); + for (auto & kv : experts_) { + kv.second.slot_by_device.clear(); + } + timestamp_ = 0; + allocate_pool(); +} + +bool ExpertCache::has_resident(int32_t expert_id) const { + std::lock_guard lock(mutex_); + auto it = experts_.find(expert_id); + if (it == experts_.end()) { + return false; + } + return !it->second.slot_by_device.empty(); +} + +const llama_moe_expert_handle * ExpertCache::find(int32_t expert_id) const { + std::lock_guard lock(mutex_); + auto it = experts_.find(expert_id); + if (it == experts_.end()) { + return nullptr; + } + return &it->second.handle; +} + +llama_moe_cache_stats ExpertCache::stats() const { + std::lock_guard lock(mutex_); + llama_moe_cache_stats result = stats_; + result.per_device.clear(); + size_t resident = 0; + size_t capacity = 0; + for (const auto & pool_entry : device_pools_) { + const int device = pool_entry.first; + const DevicePool & pool = pool_entry.second; + size_t device_resident = 0; + for (const auto & slot : pool_entry.second.slots) { + if (slot.expert_id != -1) { + ++resident; + ++device_resident; + } + } + capacity += pool.pool_bytes; + llama_moe_cache_stats::device_stats dev_stats{}; + dev_stats.device = device; + dev_stats.resident = device_resident; + dev_stats.capacity_bytes = pool.pool_bytes; + dev_stats.loads = pool.stats.loads; + dev_stats.hits = pool.stats.hits; + dev_stats.evictions = pool.stats.evictions; + result.per_device.push_back(dev_stats); + } + result.resident = resident; + result.capacity_bytes = capacity; + return result; +} + +void ExpertCache::reset_stats() { + std::lock_guard lock(mutex_); + stats_ = {}; + for (auto & pool_entry : device_pools_) { + pool_entry.second.stats = {}; + } +} + +#ifdef GGML_USE_CUDA +void ExpertCache::attach_stream(cudaStream_t stream, int device) { + std::lock_guard lock(mutex_); + stream_ = stream; + device_streams_[device] = stream; + current_device_ = device; +} + +cudaStream_t ExpertCache::stream() const { + return stream_; 
+} +#endif + +void * ExpertCache::ensure_loaded(int32_t expert_id, int device, ggml_backend_buffer_t device_buffer) { + GGML_UNUSED(device_buffer); +#ifndef GGML_USE_CUDA + GGML_UNUSED(expert_id); + GGML_UNUSED(device); + std::lock_guard lock(mutex_); + stats_.hits++; + return nullptr; +#else + std::lock_guard lock(mutex_); + + if (device < 0) { + device = current_device_; + } + device = select_device_for_expert(expert_id, device); + + if (device < 0) { + LLAMA_LOG_ERROR("%s: device id not set for expert cache load\n", __func__); + return nullptr; + } + + auto it = experts_.find(expert_id); + if (it == experts_.end()) { + LLAMA_LOG_WARN("%s: expert %d not registered\n", __func__, expert_id); + return nullptr; + } + + auto & record = it->second; + timestamp_++; + + DevicePool & pool = get_or_create_pool(device); + + if (record.slot_by_device.count(device) != 0) { + size_t slot_idx = record.slot_by_device[device]; + DeviceSlot & slot = pool.slots[slot_idx]; + slot.last_used = timestamp_; + slot.hits++; + stats_.hits++; + pool.stats.hits++; + return slot.device_ptr; + } + + if (record.handle.host_ptr == nullptr && record.handle.tensor) { + record.handle.host_ptr = record.handle.tensor->data; + } + + if (record.handle.host_ptr == nullptr) { + LLAMA_LOG_WARN("%s: expert %d host data unavailable\n", __func__, expert_id); + return nullptr; + } + + DeviceSlot * target = nullptr; + size_t target_idx = 0; + for (size_t i = 0; i < pool.slots.size(); ++i) { + if (pool.slots[i].expert_id == -1) { + target = &pool.slots[i]; + target_idx = i; + break; + } + } + if (!target) { + target = find_lru(device); + if (target) { + target_idx = static_cast(target - pool.slots.data()); + } + } + + if (!target) { + LLAMA_LOG_ERROR("%s: unable to evict expert for expert_id=%d\n", __func__, expert_id); + return nullptr; + } + + const size_t required_bytes = llama_moe_required_device_bytes(record.handle); + + if (target->device_ptr == nullptr || target->bytes < required_bytes) { + if (target->device_ptr != nullptr) { + cudaFree(target->device_ptr); + } + llama_cuda_try(cudaMalloc(&target->device_ptr, required_bytes), "cudaMalloc expert cache"); + target->bytes = required_bytes; + } + + if (target->host_staging == nullptr || target->staging_capacity < required_bytes) { + if (target->host_staging != nullptr) { + cudaFreeHost(target->host_staging); + } + llama_cuda_try(cudaMallocHost(&target->host_staging, required_bytes), "cudaMallocHost expert staging"); + target->staging_capacity = required_bytes; + } + + llama_moe_copy_to_staging(record.handle, target->host_staging); + + bool evicting = target->expert_id != -1 && target->expert_id != expert_id; + if (evicting) { + auto evicted_it = experts_.find(target->expert_id); + if (evicted_it != experts_.end()) { + evicted_it->second.slot_by_device.erase(device); + } + stats_.evictions++; + pool.stats.evictions++; + } + + cudaStream_t stream = stream_; + auto stream_it = device_streams_.find(device); + if (stream_it != device_streams_.end()) { + stream = stream_it->second; + } + + llama_cuda_try(cudaMemcpyAsync( + target->device_ptr, + target->host_staging, + required_bytes, + cudaMemcpyHostToDevice, + stream), + "cudaMemcpyAsync expert upload"); + + stats_.loads++; + pool.stats.loads++; + + target->expert_id = expert_id; + target->last_used = timestamp_; + target->hits = 1; + + record.slot_by_device[device] = target_idx; + + return target->device_ptr; +#endif +} + +void ExpertCache::prefetch(const std::vector & expert_ids) { +#ifdef GGML_USE_CUDA + if (!config_.enable_prefetch 
|| expert_ids.empty()) { + return; + } + int device = -1; + { + std::lock_guard lock(mutex_); + device = current_device_; + if (device < 0) { + return; + } + stats_.prefetch_requests += expert_ids.size(); + } + for (int32_t id : expert_ids) { + ensure_loaded(id, device); + } +#else + GGML_UNUSED(expert_ids); +#endif +} + +size_t ExpertCache::resident_count() const { + std::lock_guard lock(mutex_); + size_t count = 0; + for (const auto & pool_entry : device_pools_) { + for (const auto & slot : pool_entry.second.slots) { + if (slot.expert_id != -1) { + ++count; + } + } + } + return count; +} + +size_t ExpertCache::capacity_bytes() const { + return pool_bytes_; +} + +void ExpertCache::allocate_pool() { + pool_bytes_ = 0; + if (!device_policies_.empty()) { + for (const auto & policy : device_policies_) { + if (policy.device < 0) { + continue; + } + DevicePool pool; + pool.pool_bytes = capacity_for_device(policy.device); + const size_t slot_count = max_slots_for_device(policy.device); + pool.slots.resize(slot_count); + for (auto & slot : pool.slots) { + slot.expert_id = -1; + slot.device_ptr = nullptr; + slot.bytes = 0; + slot.last_used = 0; + slot.hits = 0; + slot.host_staging = nullptr; + slot.staging_capacity = 0; + } + pool.stats = {}; + device_pools_.emplace(policy.device, std::move(pool)); + pool_bytes_ += capacity_for_device(policy.device); + } + } else { + pool_bytes_ = config_.vram_pool_bytes; + } +} + +void ExpertCache::release_pool() { +#ifdef GGML_USE_CUDA + for (auto & pool_entry : device_pools_) { + for (auto & slot : pool_entry.second.slots) { + if (slot.device_ptr != nullptr) { + cudaFree(slot.device_ptr); + slot.device_ptr = nullptr; + } + if (slot.host_staging != nullptr) { + cudaFreeHost(slot.host_staging); + slot.host_staging = nullptr; + slot.staging_capacity = 0; + } + slot.expert_id = -1; + slot.bytes = 0; + slot.last_used = 0; + slot.hits = 0; + } + pool_entry.second.stats = {}; + } +#endif + device_pools_.clear(); + pool_bytes_ = 0; +} + +ExpertCache::DeviceSlot * ExpertCache::find_lru(int device) { +#ifdef GGML_USE_CUDA + auto it = device_pools_.find(device); + if (it == device_pools_.end()) { + return nullptr; + } + DeviceSlot * candidate = nullptr; + for (auto & slot : it->second.slots) { + if (slot.expert_id == -1) { + return &slot; + } + if (candidate == nullptr || slot.last_used < candidate->last_used) { + candidate = &slot; + } + } + return candidate; +#else + GGML_UNUSED(device); + return nullptr; +#endif +} + +ExpertCache::DevicePool & ExpertCache::get_or_create_pool(int device) { + auto it = device_pools_.find(device); + if (it == device_pools_.end()) { + DevicePool pool; + pool.pool_bytes = capacity_for_device(device); + const size_t slot_count = max_slots_for_device(device); + pool.slots.resize(slot_count); + for (auto & slot : pool.slots) { + slot.expert_id = -1; + slot.device_ptr = nullptr; + slot.bytes = 0; + slot.last_used = 0; + slot.hits = 0; + slot.host_staging = nullptr; + slot.staging_capacity = 0; + } + pool.stats = {}; + it = device_pools_.emplace(device, std::move(pool)).first; + pool_bytes_ += capacity_for_device(device); + } + return it->second; +} + +size_t ExpertCache::capacity_for_device(int device) const { + auto it = device_policy_by_id_.find(device); + if (it != device_policy_by_id_.end() && it->second.capacity_bytes > 0) { + return it->second.capacity_bytes; + } + return config_.vram_pool_bytes; +} + +size_t ExpertCache::max_slots_for_device(int device) const { + auto it = device_policy_by_id_.find(device); + uint32_t slots = 
config_.max_resident_experts;
+    if (it != device_policy_by_id_.end() && it->second.max_resident_experts > 0) {
+        slots = it->second.max_resident_experts;
+    }
+    if (slots == 0) {
+        slots = kDefaultMaxExperts;
+    }
+    return slots;
+}
+
+int ExpertCache::select_device_for_expert(int32_t expert_id, int device_hint) const {
+    if (device_hint >= 0) {
+        return device_hint;
+    }
+    if (!device_policies_.empty() && device_policy_total_weight_ > 0.0) {
+        const uint64_t hash = std::hash<int32_t>{}(expert_id);
+        const double normalized = static_cast<double>(hash) / static_cast<double>(std::numeric_limits<uint64_t>::max());
+        const double target = normalized * device_policy_total_weight_;
+        double accum = 0.0;
+        for (const auto & policy : device_policies_) {
+            if (policy.device < 0) {
+                continue;
+            }
+            const double weight = policy.weight > 0.0f ? policy.weight : 1.0;
+            accum += weight;
+            if (target <= accum) {
+                return policy.device;
+            }
+        }
+        return device_policies_.back().device;
+    }
+
+    if (!device_pools_.empty()) {
+        const size_t index = static_cast<size_t>(std::abs(expert_id)) % device_pools_.size();
+        auto it = device_pools_.begin();
+        std::advance(it, index);
+        return it->first;
+    }
+
+    if (current_device_ >= 0) {
+        return current_device_;
+    }
+
+    return device_hint;
+}
+
+#endif // LLAMA_MOE_ENABLE
diff --git a/src/llama.cpp b/src/llama.cpp
index ab2e9868af468..d31080c21119f 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1,11 +1,15 @@
 #include "llama-impl.h"
 
 #include "llama-chat.h"
+#include "llama-context.h"
 #include "llama-mmap.h"
 #include "llama-vocab.h"
 #include "llama-model-loader.h"
 #include "llama-model-saver.h"
 #include "llama-model.h"
+#ifdef LLAMA_MOE_ENABLE
+#include "llama-moe.h"
+#endif
 
 #include "ggml.h"
 #include "ggml-backend.h"
@@ -415,3 +419,26 @@ const char * llama_print_system_info(void) {
 
     return s.c_str();
 }
+#ifdef LLAMA_MOE_ENABLE
+void llama_moe_cache_get_stats(const llama_context * ctx, llama_moe_cache_stats * out_stats) {
+    if (out_stats == nullptr) {
+        return;
+    }
+    if (ctx == nullptr) {
+        *out_stats = {};
+        return;
+    }
+    *out_stats = ctx->get_moe_cache_stats();
+}
+
+void llama_moe_prefetch_get_stats(const llama_context * ctx, llama_moe_prefetch_stats * out_stats) {
+    if (out_stats == nullptr) {
+        return;
+    }
+    if (ctx == nullptr) {
+        *out_stats = {};
+        return;
+    }
+    *out_stats = ctx->get_moe_prefetch_stats();
+}
+#endif
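
Illustrative usage (not part of the patch): the minimal sketch below shows how an application might poll the new `llama_moe_cache_get_stats()` / `llama_moe_prefetch_get_stats()` entry points after a generation pass. It assumes these declarations and the `llama_moe_cache_stats` fields populated by `ExpertCache::stats()` above (`resident`, `capacity_bytes`, `hits`, `loads`, `evictions`, `per_device`) are visible to the caller through the public headers when the library is built with `LLAMA_MOE_ENABLE`; the layout of `llama_moe_prefetch_stats` is not shown in this patch, so it is only fetched here. The `report_moe_cache` helper name is hypothetical.

    // sketch: report expert-cache behaviour for a context built with --moe-enable
    #include "llama.h"
    #include <cstdio>

    static void report_moe_cache(const llama_context * ctx) {
        llama_moe_cache_stats cache_stats = {};
        llama_moe_cache_get_stats(ctx, &cache_stats);

        llama_moe_prefetch_stats prefetch_stats = {};
        llama_moe_prefetch_get_stats(ctx, &prefetch_stats); // fields intentionally not inspected here

        const double lookups  = (double) (cache_stats.hits + cache_stats.loads);
        const double hit_rate = lookups > 0.0 ? (double) cache_stats.hits / lookups : 0.0;

        printf("MoE cache: %zu experts resident, %zu bytes reserved, hit rate %.1f%%, %llu evictions\n",
               (size_t) cache_stats.resident,
               (size_t) cache_stats.capacity_bytes,
               100.0 * hit_rate,
               (unsigned long long) cache_stats.evictions);

        for (const auto & dev : cache_stats.per_device) {
            printf("  device %d: %zu resident, %llu loads, %llu hits, %llu evictions\n",
                   dev.device,
                   (size_t) dev.resident,
                   (unsigned long long) dev.loads,
                   (unsigned long long) dev.hits,
                   (unsigned long long) dev.evictions);
        }
    }

A low hit rate together with a high eviction count would suggest raising `--moe-cache-size` or enabling `--moe-prefetch` so that expert uploads overlap with compute.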