mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-09 20:48:52 +02:00
Compare commits
21 Commits
v1.7.6
...
sync-ggml-
Author | SHA1 | Date | |
---|---|---|---|
bff8dc248a | |||
69753804ed | |||
89970b9aaa | |||
79fb43e252 | |||
926e06dbfd | |||
43a59eccf6 | |||
fe0d52b9a2 | |||
cb90cb0992 | |||
8264872b5d | |||
882d975729 | |||
c426829771 | |||
0b1962a181 | |||
86dece9c7c | |||
04445664b4 | |||
22f4997dd8 | |||
b493e03b90 | |||
aef59f4851 | |||
f8c75dc43e | |||
00c8056715 | |||
19d8d9a928 | |||
0c4a229154 |
@ -20,6 +20,7 @@ if (WHISPER_SDL2)
|
||||
llama-memory.cpp
|
||||
llama-mmap.cpp
|
||||
llama-model-loader.cpp
|
||||
llama-model-saver.cpp
|
||||
llama-model.cpp
|
||||
llama-quant.cpp
|
||||
llama-sampling.cpp
|
||||
|
@ -253,6 +253,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
|
||||
std::vector<ggml_backend_buffer_type_t> buft_extra;
|
||||
{
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!cpu_dev) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
|
||||
|
||||
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
|
||||
@ -291,6 +294,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
|
||||
LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
|
||||
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!cpu_dev) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
buft = ggml_backend_dev_buffer_type(cpu_dev);
|
||||
|
||||
break;
|
||||
|
@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
|
||||
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
|
||||
GGML_ASSERT(batch.n_tokens >= 0);
|
||||
this->batch = &batch;
|
||||
this->n_embd = n_embd;
|
||||
@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
|
||||
for (size_t i = 0; i < n_tokens; ++i) {
|
||||
ids[i] = i;
|
||||
}
|
||||
|
||||
if (simple_split) {
|
||||
seq.resize(1);
|
||||
llama_sbatch_seq & s = seq[0];
|
||||
@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
|
||||
s.length = n_tokens;
|
||||
return;
|
||||
}
|
||||
|
||||
std::sort(ids.begin(), ids.end(),
|
||||
[&batch](size_t a, size_t b) {
|
||||
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
|
||||
@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
|
||||
return n_seq_a > n_seq_b;
|
||||
}
|
||||
);
|
||||
|
||||
// init seq
|
||||
llama_sbatch_seq * last_seq = nullptr;
|
||||
|
||||
@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
|
||||
seq.push_back(new_seq);
|
||||
last_seq = &seq.back();
|
||||
}
|
||||
|
||||
// keep shared prompts first at the end, then sort by length descending.
|
||||
std::sort(seq.begin(), seq.end(),
|
||||
[](llama_sbatch_seq & a, llama_sbatch_seq & b) {
|
||||
|
@ -70,7 +70,8 @@ struct llama_sbatch {
|
||||
// sequence-wise split
|
||||
llama_ubatch split_seq(size_t n_ubatch);
|
||||
|
||||
void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
|
||||
llama_sbatch() = default;
|
||||
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
|
||||
};
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
|
@ -35,6 +35,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 },
|
||||
{ "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
|
||||
{ "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 },
|
||||
{ "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN },
|
||||
{ "phi3", LLM_CHAT_TEMPLATE_PHI_3 },
|
||||
{ "phi4", LLM_CHAT_TEMPLATE_PHI_4 },
|
||||
{ "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 },
|
||||
@ -202,19 +203,20 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "<|im_start|>assistant\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) {
|
||||
// Official mistral 'v7' template
|
||||
// See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
|
||||
// https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken
|
||||
const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : "";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
std::string content(message->content);
|
||||
if (role == "system") {
|
||||
ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
|
||||
ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]";
|
||||
} else if (role == "user") {
|
||||
ss << "[INST] " << content << "[/INST]";
|
||||
}
|
||||
else {
|
||||
ss << " " << content << "</s>";
|
||||
ss << "[INST]" << trailing_space << content << "[/INST]";
|
||||
} else {
|
||||
ss << trailing_space << content << "</s>";
|
||||
}
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
|
||||
@ -447,8 +449,16 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4 || tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4) {
|
||||
ss << "[gMASK]" << "<sop>";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|" << role << "|>" << "\n" << message->content;
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|" << role << "|>" << "\n" << message->content;
|
||||
|
@ -14,6 +14,7 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V3,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V7,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
|
||||
LLM_CHAT_TEMPLATE_PHI_3,
|
||||
LLM_CHAT_TEMPLATE_PHI_4,
|
||||
LLM_CHAT_TEMPLATE_FALCON_3,
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -7,6 +7,7 @@
|
||||
#include "llama-adapter.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
#include "ggml-opt.h"
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
@ -27,7 +28,12 @@ struct llama_context {
|
||||
|
||||
void synchronize();
|
||||
|
||||
const llama_model & get_model() const;
|
||||
const llama_model & get_model() const;
|
||||
const llama_cparams & get_cparams() const;
|
||||
|
||||
ggml_backend_sched_t get_sched() const;
|
||||
|
||||
ggml_context * get_ctx_compute() const;
|
||||
|
||||
uint32_t n_ctx() const;
|
||||
uint32_t n_ctx_per_seq() const;
|
||||
@ -128,6 +134,32 @@ struct llama_context {
|
||||
llama_perf_context_data perf_get_data() const;
|
||||
void perf_reset();
|
||||
|
||||
//
|
||||
// training
|
||||
//
|
||||
|
||||
void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
|
||||
|
||||
void opt_epoch(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result_train,
|
||||
ggml_opt_result_t result_eval,
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval);
|
||||
|
||||
void opt_epoch_iter(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result,
|
||||
const std::vector<llama_token> & tokens,
|
||||
const std::vector<llama_token> & labels_sparse,
|
||||
llama_batch & batch,
|
||||
ggml_opt_epoch_callback callback,
|
||||
bool train,
|
||||
int64_t idata_in_loop,
|
||||
int64_t ndata_in_loop,
|
||||
int64_t t_loop_start);
|
||||
|
||||
private:
|
||||
//
|
||||
// output
|
||||
@ -137,49 +169,30 @@ private:
|
||||
// Returns max number of outputs for which space was reserved.
|
||||
int32_t output_reserve(int32_t n_outputs);
|
||||
|
||||
// make the outputs have the same order they had in the user-provided batch
|
||||
// TODO: maybe remove this
|
||||
void output_reorder();
|
||||
|
||||
//
|
||||
// graph
|
||||
//
|
||||
|
||||
public:
|
||||
int32_t graph_max_nodes() const;
|
||||
|
||||
// zero-out inputs and create the ctx_compute for the compute graph
|
||||
ggml_cgraph * graph_init();
|
||||
|
||||
llm_graph_result_ptr graph_build(
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype);
|
||||
|
||||
// returns the result of ggml_backend_sched_graph_compute_async execution
|
||||
ggml_status graph_compute(
|
||||
ggml_cgraph * gf,
|
||||
bool batched);
|
||||
|
||||
private:
|
||||
llm_graph_result_ptr graph_build(
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype);
|
||||
|
||||
llm_graph_cb graph_get_cb() const;
|
||||
|
||||
// used by kv_self_update()
|
||||
ggml_tensor * build_rope_shift(
|
||||
ggml_context * ctx0,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale) const;
|
||||
|
||||
llm_graph_result_ptr build_kv_self_shift(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf) const;
|
||||
|
||||
llm_graph_result_ptr build_kv_self_defrag(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf) const;
|
||||
|
||||
// TODO: read/write lora adapters and cvec
|
||||
size_t state_write_data(llama_io_write_i & io);
|
||||
size_t state_read_data (llama_io_read_i & io);
|
||||
@ -196,14 +209,10 @@ private:
|
||||
llama_cparams cparams;
|
||||
llama_adapter_cvec cvec;
|
||||
llama_adapter_loras loras;
|
||||
llama_sbatch sbatch;
|
||||
|
||||
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
|
||||
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_self;
|
||||
|
||||
// TODO: remove
|
||||
bool logits_all = false;
|
||||
std::unique_ptr<llama_memory_i> memory;
|
||||
|
||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||
size_t logits_size = 0; // capacity (of floats) for logits
|
||||
@ -230,6 +239,9 @@ private:
|
||||
|
||||
ggml_context_ptr ctx_compute;
|
||||
|
||||
// training
|
||||
ggml_opt_context_t opt_ctx = nullptr;
|
||||
|
||||
ggml_threadpool_t threadpool = nullptr;
|
||||
ggml_threadpool_t threadpool_batch = nullptr;
|
||||
|
||||
|
@ -30,6 +30,7 @@ struct llama_cparams {
|
||||
bool flash_attn;
|
||||
bool no_perf;
|
||||
bool warmup;
|
||||
bool op_offload;
|
||||
|
||||
enum llama_pooling_type pooling_type;
|
||||
|
||||
|
@ -284,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
||||
|
||||
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
const uint32_t cell_id = i + kv_self->head;
|
||||
|
||||
//////////////////////////////////////////////
|
||||
// TODO: this should not mutate the KV cache !
|
||||
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
|
||||
|
||||
// prevent out-of-bound sources
|
||||
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
|
||||
kv_cell.src = cell_id;
|
||||
}
|
||||
|
||||
data[i] = kv_cell.src;
|
||||
|
||||
// TODO: do not mutate the KV cache
|
||||
// ensure copy only happens once
|
||||
if (kv_cell.src != (int32_t) cell_id) {
|
||||
kv_cell.src = cell_id;
|
||||
}
|
||||
data[i] = kv_self->s_copy(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -317,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
|
||||
|
||||
// clear unused states
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
const uint32_t cell_id = i + kv_self->head;
|
||||
|
||||
//////////////////////////////////////////////
|
||||
// TODO: this should not mutate the KV cache !
|
||||
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
|
||||
|
||||
data[i] = (float) (kv_cell.src >= 0);
|
||||
|
||||
// only clear once
|
||||
if (kv_cell.src < 0) {
|
||||
kv_cell.src = cell_id;
|
||||
}
|
||||
data[i] = kv_self->s_mask(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -810,7 +782,7 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||
} break;
|
||||
}
|
||||
|
||||
if (type_gate == LLM_FFN_PAR) {
|
||||
if (gate && type_gate == LLM_FFN_PAR) {
|
||||
cur = ggml_mul(ctx0, cur, tmp);
|
||||
cb(cur, "ffn_gate_par", il);
|
||||
}
|
||||
@ -999,6 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
//cb(inp->tokens, "inp_tokens", -1);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_tokens = inp->tokens;
|
||||
|
||||
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
|
||||
|
||||
@ -1105,7 +1078,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
|
||||
|
||||
@ -1122,7 +1095,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_s_mask() const {
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
|
||||
|
||||
@ -1255,8 +1228,19 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
||||
|
||||
if (v_mla) {
|
||||
#if 0
|
||||
// v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
|
||||
// However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
|
||||
cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
|
||||
cur = ggml_mul_mat(ctx0, v_mla, cur);
|
||||
#else
|
||||
// It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
|
||||
// The permutations are noops and only change how the tensor data is interpreted.
|
||||
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
cur = ggml_mul_mat(ctx0, v_mla, cur);
|
||||
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
|
||||
#endif
|
||||
}
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
@ -1436,8 +1420,6 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
GGML_ASSERT(!kv_self->recurrent);
|
||||
|
||||
const auto kv_head = kv_self->head;
|
||||
|
||||
GGML_ASSERT(kv_self->size == n_ctx);
|
||||
@ -1587,7 +1569,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
|
||||
ggml_tensor * state_mask,
|
||||
int32_t n_state,
|
||||
int32_t n_seqs) const {
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
|
||||
const auto n_kv = kv_self->n;
|
||||
const auto kv_head = kv_self->head;
|
||||
@ -1619,7 +1601,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
||||
ggml_tensor * state_mask,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
|
||||
const auto token_shift_count = hparams.token_shift_count;
|
||||
|
||||
@ -1640,7 +1622,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
||||
ggml_tensor * token_shift,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
|
||||
const auto token_shift_count = hparams.token_shift_count;
|
||||
const auto n_embd = hparams.n_embd;
|
||||
|
@ -19,6 +19,7 @@ struct llama_cparams;
|
||||
|
||||
class llama_memory_i;
|
||||
class llama_kv_cache_unified;
|
||||
class llama_kv_cache_recurrent;
|
||||
|
||||
// certain models (typically multi-modal) can produce different types of graphs
|
||||
enum llm_graph_type {
|
||||
@ -186,26 +187,26 @@ public:
|
||||
|
||||
class llm_graph_input_s_copy : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
||||
llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
|
||||
virtual ~llm_graph_input_s_copy() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * s_copy; // I32 [kv_size]
|
||||
|
||||
const llama_kv_cache_unified * kv_self;
|
||||
const llama_kv_cache_recurrent * kv_self;
|
||||
};
|
||||
|
||||
class llm_graph_input_s_mask : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
||||
llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
|
||||
virtual ~llm_graph_input_s_mask() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * s_mask; // F32 [1, n_kv]
|
||||
|
||||
const llama_kv_cache_unified * kv_self;
|
||||
const llama_kv_cache_recurrent * kv_self;
|
||||
};
|
||||
|
||||
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
||||
@ -297,6 +298,7 @@ class llm_graph_result_i {
|
||||
public:
|
||||
virtual ~llm_graph_result_i() = default;
|
||||
|
||||
virtual ggml_tensor * get_tokens() = 0;
|
||||
virtual ggml_tensor * get_logits() = 0;
|
||||
virtual ggml_tensor * get_embd() = 0;
|
||||
virtual ggml_tensor * get_embd_pooled() = 0;
|
||||
@ -311,6 +313,7 @@ class llm_graph_result : public llm_graph_result_i {
|
||||
public:
|
||||
virtual ~llm_graph_result() = default;
|
||||
|
||||
ggml_tensor * get_tokens() override { return t_tokens; }
|
||||
ggml_tensor * get_logits() override { return t_logits; }
|
||||
ggml_tensor * get_embd() override { return t_embd; }
|
||||
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
|
||||
@ -327,6 +330,7 @@ public:
|
||||
}
|
||||
|
||||
// important graph nodes
|
||||
ggml_tensor * t_tokens = nullptr;
|
||||
ggml_tensor * t_logits = nullptr;
|
||||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
@ -350,8 +354,8 @@ struct llm_graph_params {
|
||||
const llama_cparams & cparams;
|
||||
const llama_ubatch & ubatch;
|
||||
|
||||
ggml_backend_sched * sched;
|
||||
ggml_backend * backend_cpu;
|
||||
ggml_backend_sched_t sched;
|
||||
ggml_backend_t backend_cpu;
|
||||
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
@ -402,9 +406,9 @@ struct llm_graph_context {
|
||||
|
||||
ggml_context * ctx0 = nullptr;
|
||||
|
||||
ggml_backend_sched * sched;
|
||||
ggml_backend_sched_t sched;
|
||||
|
||||
ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
||||
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
||||
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -2,32 +2,72 @@
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-io.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-memory.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
#include <functional>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
struct llama_cparams;
|
||||
struct llama_hparams;
|
||||
struct llama_ubatch;
|
||||
struct llama_sbatch;
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
|
||||
struct llama_kv_cache : public llama_memory_i {
|
||||
using llama_memory_i::llama_memory_i;
|
||||
virtual ~llama_kv_cache() = default;
|
||||
|
||||
virtual void restore() = 0; // call if batch processing fails - restores the cache state
|
||||
virtual void commit() = 0; // call after successful batch processing - clears any pending state
|
||||
// call if batch processing fails - restores the cache state
|
||||
virtual void restore() = 0;
|
||||
|
||||
virtual int32_t get_n_tokens() const = 0;
|
||||
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
|
||||
// call after successful batch processing - clears any pending state
|
||||
virtual void commit() = 0;
|
||||
|
||||
virtual bool get_can_shift() const = 0;
|
||||
// process any pending defrag/shift/etc. operations
|
||||
// optionally call once before processing a new batch
|
||||
virtual bool update(llama_context & lctx) = 0;
|
||||
|
||||
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
|
||||
virtual void defrag_sched(float thold) = 0;
|
||||
|
||||
// simulate full cache, used for allocating worst-case compute buffers
|
||||
virtual void set_full() = 0;
|
||||
|
||||
//
|
||||
// batch processing
|
||||
//
|
||||
|
||||
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
|
||||
|
||||
// different KV caches require different batch splitting strategies
|
||||
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
|
||||
|
||||
// find an empty slot of size "n_tokens" in the cache
|
||||
virtual bool find_slot(const llama_ubatch & batch) = 0;
|
||||
|
||||
// getters
|
||||
virtual int32_t get_n_tokens() const = 0;
|
||||
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
|
||||
virtual llama_pos get_pos_max() const = 0;
|
||||
virtual bool get_can_shift() const = 0;
|
||||
|
||||
bool get_can_edit() const override { return get_can_shift(); }
|
||||
|
||||
//
|
||||
// state write/read
|
||||
//
|
||||
|
||||
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
|
||||
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
|
||||
};
|
||||
|
||||
//
|
||||
// llama_kv_cache_guard
|
||||
//
|
||||
|
||||
struct llama_kv_cache_guard {
|
||||
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
|
||||
|
||||
@ -43,65 +83,50 @@ private:
|
||||
llama_kv_cache * kv;
|
||||
};
|
||||
|
||||
struct llama_kv_cell {
|
||||
llama_pos pos = -1;
|
||||
llama_pos delta = 0;
|
||||
int32_t src = -1; // used by recurrent state models to copy states
|
||||
int32_t tail = -1;
|
||||
//
|
||||
// llama_kv_cache_unified
|
||||
//
|
||||
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
||||
bool has_seq_id(const llama_seq_id & id) const {
|
||||
return seq_id.find(id) != seq_id.end();
|
||||
}
|
||||
|
||||
bool is_empty() const {
|
||||
return seq_id.empty();
|
||||
}
|
||||
|
||||
bool is_same_seq(const llama_kv_cell & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
};
|
||||
|
||||
// ring-buffer of cached KV data
|
||||
// TODO: pimpl
|
||||
// TODO: add notion of max sequences
|
||||
class llama_kv_cache_unified : public llama_kv_cache {
|
||||
public:
|
||||
// can be used to query data from the model if needed
|
||||
struct callbacks {
|
||||
std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
|
||||
struct kv_cell {
|
||||
llama_pos pos = -1;
|
||||
llama_pos delta = 0;
|
||||
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
||||
bool has_seq_id(const llama_seq_id & id) const {
|
||||
return seq_id.find(id) != seq_id.end();
|
||||
}
|
||||
|
||||
bool is_empty() const {
|
||||
return seq_id.empty();
|
||||
}
|
||||
|
||||
bool is_same_seq(const kv_cell & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
};
|
||||
|
||||
static uint32_t get_padding(const llama_cparams & cparams);
|
||||
|
||||
llama_kv_cache_unified(
|
||||
const llama_hparams & hparams,
|
||||
callbacks cbs);
|
||||
|
||||
virtual ~llama_kv_cache_unified() = default;
|
||||
|
||||
// TODO: become constructor
|
||||
bool init(
|
||||
const llama_model & model, // TODO: do not reference the model
|
||||
const llama_cparams & cparams,
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
uint32_t kv_size,
|
||||
bool offload);
|
||||
uint32_t padding);
|
||||
|
||||
int32_t get_n_tokens() const override;
|
||||
int32_t get_used_cells() const override;
|
||||
~llama_kv_cache_unified() = default;
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
// TODO: better data structures to reduce the cost of this operation
|
||||
llama_pos pos_max() const;
|
||||
//
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
void clear() override;
|
||||
void defrag() override;
|
||||
|
||||
virtual void restore() override;
|
||||
virtual void commit() override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
@ -111,63 +136,40 @@ public:
|
||||
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
//
|
||||
// llama_kv_cache
|
||||
//
|
||||
|
||||
void restore() override;
|
||||
void commit() override;
|
||||
|
||||
bool update(llama_context & ctx) override;
|
||||
|
||||
void defrag_sched(float thold) override;
|
||||
|
||||
void set_full() override;
|
||||
|
||||
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
||||
|
||||
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
||||
|
||||
// find an empty slot of size "n_tokens" in the cache
|
||||
// updates the cache head
|
||||
// Note: On success, it's important that cache.head points
|
||||
// to the first cell of the slot.
|
||||
bool find_slot(const llama_ubatch & batch);
|
||||
bool find_slot(const llama_ubatch & batch) override;
|
||||
|
||||
// TODO: maybe not needed
|
||||
uint32_t get_padding(const llama_cparams & cparams) const;
|
||||
int32_t get_n_tokens() const override;
|
||||
int32_t get_used_cells() const override;
|
||||
|
||||
// find how many cells are currently in use
|
||||
uint32_t cell_max() const;
|
||||
// TODO: better data structures to reduce the cost of this operation
|
||||
llama_pos get_pos_max() const override;
|
||||
|
||||
size_t size_k_bytes() const;
|
||||
size_t size_v_bytes() const;
|
||||
|
||||
// defrag
|
||||
|
||||
struct {
|
||||
std::vector<uint32_t> ids;
|
||||
} defrag_info;
|
||||
|
||||
// return true if cells have been moved
|
||||
bool defrag_prepare(int32_t n_max_nodes);
|
||||
|
||||
// commit/restore cache
|
||||
|
||||
struct slot_range {
|
||||
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
|
||||
uint32_t c1 = 0;
|
||||
};
|
||||
|
||||
// pending cell updates that are not yet committed
|
||||
struct {
|
||||
std::vector<slot_range> ranges;
|
||||
} pending;
|
||||
bool get_can_shift() const override;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
|
||||
|
||||
// members
|
||||
|
||||
const llama_hparams & hparams;
|
||||
|
||||
callbacks cbs;
|
||||
|
||||
bool has_shift = false;
|
||||
bool do_defrag = false;
|
||||
|
||||
// TODO: remove this and implement llama_kv_cache_recurrent instead
|
||||
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
|
||||
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
bool can_shift = false;
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
||||
@ -179,18 +181,213 @@ public:
|
||||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
|
||||
std::vector<llama_kv_cell> cells;
|
||||
std::vector<kv_cell> cells;
|
||||
|
||||
std::vector<ggml_tensor *> k_l; // per layer
|
||||
std::vector<ggml_tensor *> v_l;
|
||||
|
||||
private:
|
||||
const llama_model & model;
|
||||
const llama_hparams & hparams;
|
||||
|
||||
bool has_shift = false;
|
||||
bool do_defrag = false;
|
||||
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
bool can_shift = false;
|
||||
|
||||
// required padding
|
||||
uint32_t padding = 1;
|
||||
|
||||
ggml_type type_k = GGML_TYPE_F16;
|
||||
ggml_type type_v = GGML_TYPE_F16;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
// defrag
|
||||
struct {
|
||||
std::vector<uint32_t> ids;
|
||||
} defrag_info;
|
||||
|
||||
// return true if cells have been moved
|
||||
bool defrag_prepare(int32_t n_max_nodes);
|
||||
|
||||
// commit/restore cache
|
||||
struct slot_range {
|
||||
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
|
||||
uint32_t c1 = 0;
|
||||
};
|
||||
|
||||
// pending cell updates that are not yet committed
|
||||
struct {
|
||||
std::vector<slot_range> ranges;
|
||||
} pending;
|
||||
|
||||
// find how many cells are currently in use
|
||||
uint32_t cell_max() const;
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
size_t size_k_bytes() const;
|
||||
size_t size_v_bytes() const;
|
||||
|
||||
ggml_tensor * build_rope_shift(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale) const;
|
||||
|
||||
llm_graph_result_ptr build_graph_shift(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf) const;
|
||||
|
||||
llm_graph_result_ptr build_graph_defrag(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf) const;
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
//
|
||||
// llama_kv_cache_recurrent
|
||||
//
|
||||
|
||||
class llama_kv_cache_recurrent : public llama_kv_cache {
|
||||
public:
|
||||
struct kv_cell {
|
||||
llama_pos pos = -1;
|
||||
int32_t src = -1; // used to copy states
|
||||
int32_t tail = -1;
|
||||
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
||||
bool has_seq_id(const llama_seq_id & id) const {
|
||||
return seq_id.find(id) != seq_id.end();
|
||||
}
|
||||
|
||||
bool is_empty() const {
|
||||
return seq_id.empty();
|
||||
}
|
||||
|
||||
bool is_same_seq(const kv_cell & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
};
|
||||
|
||||
llama_kv_cache_recurrent(
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool offload,
|
||||
uint32_t kv_size);
|
||||
|
||||
~llama_kv_cache_recurrent() = default;
|
||||
|
||||
//
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
void clear() override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache
|
||||
//
|
||||
|
||||
void restore() override;
|
||||
void commit() override;
|
||||
|
||||
bool update(llama_context & lctx) override;
|
||||
|
||||
void defrag_sched(float thold) override;
|
||||
|
||||
void set_full() override;
|
||||
|
||||
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
||||
|
||||
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
||||
|
||||
bool find_slot(const llama_ubatch & batch) override;
|
||||
|
||||
int32_t get_n_tokens() const override;
|
||||
int32_t get_used_cells() const override;
|
||||
|
||||
// TODO: better data structures to reduce the cost of this operation
|
||||
llama_pos get_pos_max() const override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
// TODO: temporary methods - they are not really const as they do const_cast<>, fix this
|
||||
int32_t s_copy(int i) const;
|
||||
float s_mask(int i) const;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
||||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
|
||||
std::vector<kv_cell> cells;
|
||||
|
||||
std::vector<ggml_tensor *> k_l; // per layer
|
||||
std::vector<ggml_tensor *> v_l;
|
||||
|
||||
private:
|
||||
//const llama_model & model;
|
||||
const llama_hparams & hparams;
|
||||
|
||||
// commit/restore cache
|
||||
// TODO: rework for recurrent cache
|
||||
struct slot_range {
|
||||
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
|
||||
uint32_t c1 = 0;
|
||||
};
|
||||
|
||||
// pending cell updates that are not yet committed
|
||||
struct {
|
||||
std::vector<slot_range> ranges;
|
||||
} pending;
|
||||
|
||||
ggml_type type_k = GGML_TYPE_F16;
|
||||
ggml_type type_v = GGML_TYPE_F16;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
// find how many cells are currently in use
|
||||
uint32_t cell_max() const;
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
size_t size_k_bytes() const;
|
||||
size_t size_v_bytes() const;
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||
|
||||
@ -198,11 +395,6 @@ private:
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
|
||||
//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
|
||||
//public:
|
||||
// using llama_kv_cache_unified::llama_kv_cache_unified;
|
||||
//};
|
||||
|
||||
//
|
||||
// kv cache view
|
||||
|
@ -2,12 +2,22 @@
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
struct llama_memory_params {
|
||||
// kv cache
|
||||
ggml_type type_k;
|
||||
ggml_type type_v;
|
||||
|
||||
// parameters for other types of memory
|
||||
// ...
|
||||
};
|
||||
|
||||
// general concept of LLM memory
|
||||
// the KV cache is a type of LLM memory, but there can be other types
|
||||
class llama_memory_i {
|
||||
public:
|
||||
virtual ~llama_memory_i() = default;
|
||||
|
||||
virtual void clear() = 0;
|
||||
virtual void defrag() = 0;
|
||||
|
||||
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
||||
|
@ -301,12 +301,12 @@ namespace GGUFMeta {
|
||||
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
|
||||
|
||||
switch (arr_info.gt) {
|
||||
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
|
||||
case GGUF_TYPE_INT32: GGML_ASSERT(
|
||||
(std::is_same<T, int32_t>::value) ||
|
||||
(std::is_same<T, uint32_t>::value)); break;
|
||||
case GGUF_TYPE_UINT32:
|
||||
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
|
||||
(std::is_same<T, uint32_t>::value)); break;
|
||||
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
|
||||
default:
|
||||
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
|
||||
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
|
||||
}
|
||||
|
||||
result.resize(arr_info.length);
|
||||
@ -330,12 +330,12 @@ namespace GGUFMeta {
|
||||
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
|
||||
|
||||
switch (arr_info.gt) {
|
||||
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
|
||||
case GGUF_TYPE_INT32: GGML_ASSERT(
|
||||
(std::is_same<T, int32_t>::value) ||
|
||||
(std::is_same<T, uint32_t>::value)); break;
|
||||
case GGUF_TYPE_UINT32:
|
||||
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
|
||||
(std::is_same<T, uint32_t>::value)); break;
|
||||
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
|
||||
default:
|
||||
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
|
||||
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
|
||||
}
|
||||
|
||||
if (arr_info.length > N_MAX) {
|
||||
@ -823,6 +823,10 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
|
||||
mmaps_used.reserve(files.size());
|
||||
for (const auto & file : files) {
|
||||
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
|
||||
if (!reg) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
|
||||
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
|
||||
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
|
||||
mmaps_used.emplace_back(mapping->size(), 0);
|
||||
|
281
examples/talk-llama/llama-model-saver.cpp
Normal file
281
examples/talk-llama/llama-model-saver.cpp
Normal file
@ -0,0 +1,281 @@
|
||||
#include "llama-model-saver.h"
|
||||
|
||||
#include "gguf.h"
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-hparams.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-vocab.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
llama_model_saver::llama_model_saver(const struct llama_model & model) : model(model), llm_kv(model.arch) {
|
||||
gguf_ctx = gguf_init_empty();
|
||||
}
|
||||
|
||||
llama_model_saver::~llama_model_saver() {
|
||||
gguf_free(gguf_ctx);
|
||||
}
|
||||
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const uint32_t value) {
|
||||
gguf_set_val_u32(gguf_ctx, llm_kv(key).c_str(), value);
|
||||
}
|
||||
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const int32_t value) {
|
||||
gguf_set_val_i32(gguf_ctx, llm_kv(key).c_str(), value);
|
||||
}
|
||||
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const float value) {
|
||||
gguf_set_val_f32(gguf_ctx, llm_kv(key).c_str(), value);
|
||||
}
|
||||
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const bool value) {
|
||||
gguf_set_val_bool(gguf_ctx, llm_kv(key).c_str(), value);
|
||||
}
|
||||
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const char * value) {
|
||||
gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), value);
|
||||
}
|
||||
|
||||
[[noreturn]]
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const char value) {
|
||||
GGML_UNUSED(key);
|
||||
GGML_UNUSED(value);
|
||||
GGML_ABORT("fatal error"); // this should never be called, only needed to make the template below compile
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) {
|
||||
const size_t n_values = per_layer ? size_t(model.hparams.n_layer) : value.size();
|
||||
GGML_ASSERT(n_values <= value.size());
|
||||
|
||||
if (n_values == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (per_layer) {
|
||||
bool all_values_the_same = true;
|
||||
for (size_t i = 1; i < n_values; ++i) {
|
||||
if (value[i] != value[0]) {
|
||||
all_values_the_same = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (all_values_the_same) {
|
||||
add_kv(key, value[0]);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (std::is_same<typename Container::value_type, uint8_t>::value) {
|
||||
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT8, value.data(), n_values);
|
||||
} else if (std::is_same<typename Container::value_type, int8_t>::value) {
|
||||
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values);
|
||||
} else if (std::is_same<typename Container::value_type, uint32_t>::value) {
|
||||
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values);
|
||||
} else if (std::is_same<typename Container::value_type, int32_t>::value) {
|
||||
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values);
|
||||
} else if (std::is_same<typename Container::value_type, float>::value) {
|
||||
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_FLOAT32, value.data(), n_values);
|
||||
} else if (std::is_same<Container, std::string>::value) {
|
||||
gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), reinterpret_cast<const char *>(value.data()));
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const std::vector<std::string> & value) {
|
||||
std::vector<const char *> tmp(value.size());
|
||||
for (size_t i = 0; i < value.size(); ++i) {
|
||||
tmp[i] = value[i].c_str();
|
||||
}
|
||||
gguf_set_arr_str(gguf_ctx, llm_kv(key).c_str(), tmp.data(), tmp.size());
|
||||
}
|
||||
|
||||
void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) {
|
||||
if (!tensor) {
|
||||
return;
|
||||
}
|
||||
if (gguf_find_tensor(gguf_ctx, tensor->name) >= 0) {
|
||||
GGML_ASSERT(std::string(tensor->name) == "rope_freqs.weight"); // FIXME
|
||||
return;
|
||||
}
|
||||
gguf_add_tensor(gguf_ctx, tensor);
|
||||
}
|
||||
|
||||
void llama_model_saver::add_kv_from_model() {
|
||||
const llama_hparams & hparams = model.hparams;
|
||||
const llama_vocab & vocab = model.vocab;
|
||||
|
||||
const int32_t n_vocab = vocab.n_tokens();
|
||||
std::vector<std::string> tokens(n_vocab);
|
||||
std::vector<float> scores(n_vocab);
|
||||
std::vector<int32_t> token_types(n_vocab);
|
||||
|
||||
for (int32_t id = 0; id < n_vocab; ++id) {
|
||||
const llama_vocab::token_data & token_data = vocab.get_token_data(id);
|
||||
|
||||
tokens[id] = token_data.text;
|
||||
scores[id] = token_data.score;
|
||||
|
||||
switch(token_data.attr) {
|
||||
case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break;
|
||||
case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break;
|
||||
case LLAMA_TOKEN_ATTR_NORMAL: token_types[id] = LLAMA_TOKEN_TYPE_NORMAL; break;
|
||||
case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break;
|
||||
case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break;
|
||||
case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break;
|
||||
case LLAMA_TOKEN_ATTR_UNDEFINED:
|
||||
default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break;
|
||||
}
|
||||
}
|
||||
|
||||
// add_kv(LLM_KV_GENERAL_TYPE, ???);
|
||||
add_kv(LLM_KV_GENERAL_ARCHITECTURE, model.arch_name());
|
||||
// add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION, ???);
|
||||
// add_kv(LLM_KV_GENERAL_ALIGNMENT, ???);
|
||||
add_kv(LLM_KV_GENERAL_NAME, model.name);
|
||||
// add_kv(LLM_KV_GENERAL_AUTHOR, ???);
|
||||
// add_kv(LLM_KV_GENERAL_VERSION, ???);
|
||||
// add_kv(LLM_KV_GENERAL_URL, ???);
|
||||
// add_kv(LLM_KV_GENERAL_DESCRIPTION, ???);
|
||||
// add_kv(LLM_KV_GENERAL_LICENSE, ???);
|
||||
// add_kv(LLM_KV_GENERAL_SOURCE_URL, ???);
|
||||
// add_kv(LLM_KV_GENERAL_SOURCE_HF_REPO, ???);
|
||||
|
||||
add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens());
|
||||
add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
|
||||
add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
|
||||
add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer);
|
||||
add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
|
||||
add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true);
|
||||
add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
add_kv(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res);
|
||||
// add_kv(LLM_KV_TENSOR_DATA_LAYOUT, ???);
|
||||
add_kv(LLM_KV_EXPERT_COUNT, hparams.n_expert);
|
||||
add_kv(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used);
|
||||
add_kv(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
||||
add_kv(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
|
||||
add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type));
|
||||
add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
|
||||
add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id);
|
||||
add_kv(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping);
|
||||
add_kv(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping);
|
||||
add_kv(LLM_KV_SWIN_NORM, hparams.swin_norm);
|
||||
add_kv(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers);
|
||||
add_kv(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
|
||||
add_kv(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
|
||||
add_kv(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
|
||||
add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
|
||||
|
||||
add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true);
|
||||
add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true);
|
||||
add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
|
||||
add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv);
|
||||
add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k);
|
||||
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v);
|
||||
add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
||||
add_kv(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
|
||||
add_kv(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
|
||||
add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
|
||||
add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
||||
add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
|
||||
|
||||
const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train;
|
||||
|
||||
add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
|
||||
add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train);
|
||||
// add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name
|
||||
add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train));
|
||||
add_kv(LLM_KV_ROPE_SCALING_FACTOR, rope_scaling_factor);
|
||||
add_kv(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor);
|
||||
add_kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn);
|
||||
add_kv(LLM_KV_ROPE_SCALING_FINETUNED, hparams.rope_finetuned);
|
||||
add_kv(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
|
||||
|
||||
// TODO: implement split file support
|
||||
// add_kv(LLM_KV_SPLIT_NO, ???);
|
||||
// add_kv(LLM_KV_SPLIT_COUNT, ???);
|
||||
// add_kv(LLM_KV_SPLIT_TENSORS_COUNT, ???);
|
||||
|
||||
add_kv(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
|
||||
add_kv(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
|
||||
add_kv(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
|
||||
add_kv(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
|
||||
add_kv(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms);
|
||||
|
||||
add_kv(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
|
||||
|
||||
add_kv(LLM_KV_TOKENIZER_MODEL, vocab.get_tokenizer_model());
|
||||
add_kv(LLM_KV_TOKENIZER_PRE, vocab.get_tokenizer_pre());
|
||||
add_kv(LLM_KV_TOKENIZER_LIST, tokens);
|
||||
add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE, token_types);
|
||||
add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, vocab.n_token_types());
|
||||
add_kv(LLM_KV_TOKENIZER_SCORES, scores);
|
||||
add_kv(LLM_KV_TOKENIZER_MERGES, vocab.get_bpe_merges());
|
||||
// FIXME llama_token is type i32 but when reading in a GGUF file u32 is expected, not an issue for writing though
|
||||
add_kv(LLM_KV_TOKENIZER_BOS_ID, uint32_t(vocab.token_bos()));
|
||||
add_kv(LLM_KV_TOKENIZER_EOS_ID, uint32_t(vocab.token_eos()));
|
||||
add_kv(LLM_KV_TOKENIZER_EOT_ID, uint32_t(vocab.token_eot()));
|
||||
add_kv(LLM_KV_TOKENIZER_EOM_ID, uint32_t(vocab.token_eom()));
|
||||
add_kv(LLM_KV_TOKENIZER_UNK_ID, uint32_t(vocab.token_unk()));
|
||||
add_kv(LLM_KV_TOKENIZER_SEP_ID, uint32_t(vocab.token_sep()));
|
||||
add_kv(LLM_KV_TOKENIZER_PAD_ID, uint32_t(vocab.token_pad()));
|
||||
// add_kv(LLM_KV_TOKENIZER_CLS_ID, uint32_t(vocab.token_bos())); // deprecated
|
||||
// add_kv(LLM_KV_TOKENIZER_MASK_ID, ???);
|
||||
add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos());
|
||||
add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos());
|
||||
add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix());
|
||||
add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces());
|
||||
add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap());
|
||||
// add_kv(LLM_KV_TOKENIZER_HF_JSON, ???);
|
||||
// add_kv(LLM_KV_TOKENIZER_RWKV, ???);
|
||||
add_kv(LLM_KV_TOKENIZER_FIM_PRE_ID, uint32_t(vocab.token_fim_pre()));
|
||||
add_kv(LLM_KV_TOKENIZER_FIM_SUF_ID, uint32_t(vocab.token_fim_suf()));
|
||||
add_kv(LLM_KV_TOKENIZER_FIM_MID_ID, uint32_t(vocab.token_fim_mid()));
|
||||
add_kv(LLM_KV_TOKENIZER_FIM_PAD_ID, uint32_t(vocab.token_fim_pad()));
|
||||
add_kv(LLM_KV_TOKENIZER_FIM_REP_ID, uint32_t(vocab.token_fim_rep()));
|
||||
add_kv(LLM_KV_TOKENIZER_FIM_SEP_ID, uint32_t(vocab.token_fim_sep()));
|
||||
|
||||
// TODO: implement LoRA support
|
||||
// add_kv(LLM_KV_ADAPTER_TYPE, ???);
|
||||
// add_kv(LLM_KV_ADAPTER_LORA_ALPHA, ???);
|
||||
|
||||
// deprecated
|
||||
// add_kv(LLM_KV_TOKENIZER_PREFIX_ID, ???);
|
||||
// add_kv(LLM_KV_TOKENIZER_SUFFIX_ID, ???);
|
||||
// add_kv(LLM_KV_TOKENIZER_MIDDLE_ID, ???);
|
||||
}
|
||||
|
||||
void llama_model_saver::add_tensors_from_model() {
|
||||
if (std::string(model.output->name) != std::string(model.tok_embd->name)) {
|
||||
add_tensor(model.tok_embd); // some models use the same tensor for tok_embd and output
|
||||
}
|
||||
add_tensor(model.type_embd);
|
||||
add_tensor(model.pos_embd);
|
||||
add_tensor(model.tok_norm);
|
||||
add_tensor(model.tok_norm_b);
|
||||
add_tensor(model.output_norm);
|
||||
add_tensor(model.output_norm_b);
|
||||
add_tensor(model.output);
|
||||
add_tensor(model.output_b);
|
||||
add_tensor(model.output_norm_enc);
|
||||
add_tensor(model.cls);
|
||||
add_tensor(model.cls_b);
|
||||
add_tensor(model.cls_out);
|
||||
add_tensor(model.cls_out_b);
|
||||
|
||||
for (const struct llama_layer & layer : model.layers) {
|
||||
for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
|
||||
add_tensor(reinterpret_cast<const struct ggml_tensor * const *>(&layer)[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_saver::save(const std::string & path_model) {
|
||||
gguf_write_to_file(gguf_ctx, path_model.c_str(), false);
|
||||
}
|
||||
|
37
examples/talk-llama/llama-model-saver.h
Normal file
37
examples/talk-llama/llama-model-saver.h
Normal file
@ -0,0 +1,37 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-arch.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
struct llama_model_saver {
|
||||
struct gguf_context * gguf_ctx = nullptr;
|
||||
const struct llama_model & model;
|
||||
const struct LLM_KV llm_kv;
|
||||
|
||||
llama_model_saver(const struct llama_model & model);
|
||||
~llama_model_saver();
|
||||
|
||||
void add_kv(enum llm_kv key, uint32_t value);
|
||||
void add_kv(enum llm_kv key, int32_t value);
|
||||
void add_kv(enum llm_kv key, float value);
|
||||
void add_kv(enum llm_kv key, bool value);
|
||||
void add_kv(enum llm_kv key, const char * value);
|
||||
|
||||
[[noreturn]]
|
||||
void add_kv(enum llm_kv key, char value); // needed to make the template below compile
|
||||
|
||||
template <typename Container>
|
||||
void add_kv(enum llm_kv key, const Container & value, bool per_layer = false);
|
||||
|
||||
void add_kv(enum llm_kv key, const std::vector<std::string> & value);
|
||||
|
||||
void add_tensor(const struct ggml_tensor * tensor);
|
||||
|
||||
void add_kv_from_model();
|
||||
|
||||
void add_tensors_from_model();
|
||||
|
||||
void save(const std::string & path_model);
|
||||
};
|
@ -80,6 +80,7 @@ const char * llm_type_name(llm_type type) {
|
||||
case LLM_TYPE_236B: return "236B";
|
||||
case LLM_TYPE_290B: return "290B";
|
||||
case LLM_TYPE_314B: return "314B";
|
||||
case LLM_TYPE_405B: return "405B";
|
||||
case LLM_TYPE_671B: return "671B";
|
||||
case LLM_TYPE_SMALL: return "0.1B";
|
||||
case LLM_TYPE_MEDIUM: return "0.4B";
|
||||
@ -116,6 +117,10 @@ static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_
|
||||
{ LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" },
|
||||
};
|
||||
|
||||
std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type) {
|
||||
return LLAMA_ROPE_SCALING_TYPES.at(rope_scaling_type);
|
||||
}
|
||||
|
||||
static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
|
||||
for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
|
||||
if (kv.second == name) {
|
||||
@ -298,6 +303,10 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & de
|
||||
// add extra buffer types, only if no GPU device is present
|
||||
// ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (cpu_dev == nullptr) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
|
||||
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
|
||||
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
|
||||
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
|
||||
@ -582,6 +591,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
switch (hparams.n_layer) {
|
||||
case 32: type = LLM_TYPE_7B; break;
|
||||
case 80: type = LLM_TYPE_70B; break;
|
||||
case 162: type = LLM_TYPE_405B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
@ -773,6 +783,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
// fall through
|
||||
case LLM_ARCH_QWEN2:
|
||||
{
|
||||
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
switch (hparams.n_layer) {
|
||||
case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break;
|
||||
@ -1481,6 +1492,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
}
|
||||
|
||||
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (cpu_dev == nullptr) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
|
||||
const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
|
||||
auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
|
||||
@ -1648,8 +1662,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
|
||||
std::regex pattern(overrides->pattern);
|
||||
if (std::regex_search(tensor_name, pattern)) {
|
||||
LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft));
|
||||
buft = overrides->buft;
|
||||
LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n",
|
||||
tensor_name.c_str(),
|
||||
ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type),
|
||||
ggml_backend_buft_name(buft));
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -1666,6 +1683,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
auto * buft_dev = ggml_backend_buft_get_device(buft);
|
||||
if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!cpu_dev) {
|
||||
throw std::runtime_error("no CPU backend found");
|
||||
}
|
||||
buft = ggml_backend_dev_buffer_type(cpu_dev);
|
||||
}
|
||||
|
||||
@ -1847,7 +1867,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
|
||||
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
if (n_ff > 0) {
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
}
|
||||
|
||||
if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
|
||||
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
@ -1857,9 +1879,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
}
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
if (n_ff > 0) {
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
|
||||
// optional MLP bias
|
||||
layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
|
||||
@ -3503,7 +3527,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
@ -4108,6 +4136,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
if (!dev) {
|
||||
// FIXME: workaround for CPU backend buft having a NULL device
|
||||
dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!dev) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
}
|
||||
ggml_backend_dev_props props;
|
||||
ggml_backend_dev_get_props(dev, &props);
|
||||
@ -4237,7 +4268,7 @@ uint64_t llama_model::n_elements() const {
|
||||
}
|
||||
|
||||
void llama_model::print_info() const {
|
||||
const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
|
||||
const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train);
|
||||
|
||||
auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) {
|
||||
bool is_var = false;
|
||||
@ -4298,7 +4329,7 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn);
|
||||
LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
|
||||
LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
|
||||
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type);
|
||||
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str());
|
||||
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
|
||||
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
|
||||
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
|
||||
@ -4445,6 +4476,19 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const {
|
||||
// choose long/short freq factors based on the context size
|
||||
if (layers[il].rope_freqs != nullptr) {
|
||||
return layers[il].rope_freqs;
|
||||
}
|
||||
|
||||
if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
|
||||
return layers[il].rope_long;
|
||||
}
|
||||
|
||||
return layers[il].rope_short;
|
||||
}
|
||||
|
||||
struct llm_build_llama : public llm_graph_context {
|
||||
llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
@ -4485,7 +4529,7 @@ struct llm_build_llama : public llm_graph_context {
|
||||
// self-attention
|
||||
{
|
||||
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
|
||||
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
|
||||
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
@ -4691,6 +4735,7 @@ struct llm_build_deci : public llm_graph_context {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
const int64_t n_head_kv = hparams.n_head_kv(il);
|
||||
const int64_t n_head = hparams.n_head(il);
|
||||
const int64_t n_ff = hparams.n_ff(il);
|
||||
|
||||
if (n_head == 0) {
|
||||
// attention-free layer of Llama-3_1-Nemotron-51B
|
||||
@ -4710,7 +4755,7 @@ struct llm_build_deci : public llm_graph_context {
|
||||
} else if (n_head > 0) {
|
||||
// self-attention
|
||||
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
|
||||
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
|
||||
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
@ -4766,6 +4811,11 @@ struct llm_build_deci : public llm_graph_context {
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
// FFN-free layer of Llama-3_1-Nemotron-Ultra-253B
|
||||
if (n_ff == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// For Granite architecture
|
||||
if (hparams.f_residual_scale) {
|
||||
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
||||
@ -7192,7 +7242,7 @@ struct llm_build_phi3 : public llm_graph_context {
|
||||
// self-attention
|
||||
{
|
||||
// rope freq factors for 128k context
|
||||
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
|
||||
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
|
||||
|
||||
ggml_tensor* attn_norm_output = build_norm(inpL,
|
||||
model.layers[il].attn_norm,
|
||||
@ -7944,7 +7994,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
|
||||
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,
|
||||
@ -8711,7 +8761,7 @@ struct llm_build_mamba : public llm_graph_context {
|
||||
ggml_tensor * state_mask,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
|
||||
const auto kv_head = kv_self->head;
|
||||
|
||||
@ -9012,7 +9062,7 @@ struct llm_build_cohere2 : public llm_graph_context {
|
||||
// self-attention
|
||||
{
|
||||
// rope freq factors for 128k context
|
||||
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
|
||||
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
|
||||
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
@ -9950,7 +10000,7 @@ struct llm_build_deepseek : public llm_graph_context {
|
||||
// self-attention
|
||||
{
|
||||
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
|
||||
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
|
||||
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
@ -11314,7 +11364,7 @@ struct llm_build_exaone : public llm_graph_context {
|
||||
// self-attention
|
||||
{
|
||||
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
|
||||
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
|
||||
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
@ -11459,7 +11509,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
||||
ggml_tensor * state_mask,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
@ -11855,7 +11905,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
||||
ggml_tensor *& first_layer_value,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
@ -12695,7 +12745,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
|
||||
// self-attention
|
||||
{
|
||||
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
|
||||
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
|
||||
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
@ -12815,36 +12865,46 @@ struct llm_build_bailingmoe : public llm_graph_context {
|
||||
}
|
||||
};
|
||||
|
||||
llama_memory_i * llama_model::create_memory() const {
|
||||
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
|
||||
llama_memory_i * res;
|
||||
|
||||
switch (arch) {
|
||||
case LLM_ARCH_BERT:
|
||||
case LLM_ARCH_JINA_BERT_V2:
|
||||
case LLM_ARCH_NOMIC_BERT:
|
||||
case LLM_ARCH_NOMIC_BERT_MOE:
|
||||
{
|
||||
res = nullptr;
|
||||
} break;
|
||||
case LLM_ARCH_MAMBA:
|
||||
case LLM_ARCH_RWKV6:
|
||||
case LLM_ARCH_RWKV6QWEN2:
|
||||
case LLM_ARCH_RWKV7:
|
||||
case LLM_ARCH_ARWKV7:
|
||||
{
|
||||
res = new llama_kv_cache_unified(hparams, {
|
||||
/*.get_rope_factors =*/ nullptr
|
||||
});
|
||||
res = new llama_kv_cache_recurrent(
|
||||
*this,
|
||||
GGML_TYPE_F32,
|
||||
GGML_TYPE_F32,
|
||||
cparams.offload_kqv,
|
||||
std::max((uint32_t) 1, cparams.n_seq_max));
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
res = new llama_kv_cache_unified(hparams, {
|
||||
/*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) {
|
||||
// choose long/short freq factors based on the context size
|
||||
if (layers[il].rope_freqs != nullptr) {
|
||||
return layers[il].rope_freqs;
|
||||
}
|
||||
const auto padding = llama_kv_cache_unified::get_padding(cparams);
|
||||
|
||||
if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
|
||||
return layers[il].rope_long;
|
||||
}
|
||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
|
||||
|
||||
return layers[il].rope_short;
|
||||
}
|
||||
});
|
||||
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
||||
|
||||
res = new llama_kv_cache_unified(
|
||||
*this,
|
||||
params.type_k,
|
||||
params.type_v,
|
||||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
cparams.n_ctx,
|
||||
padding);
|
||||
}
|
||||
}
|
||||
|
||||
@ -13226,8 +13286,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_DECI:
|
||||
case LLM_ARCH_BAICHUAN:
|
||||
case LLM_ARCH_STARCODER:
|
||||
case LLM_ARCH_PLAMO:
|
||||
case LLM_ARCH_ORION:
|
||||
case LLM_ARCH_INTERNLM2:
|
||||
case LLM_ARCH_MINICPM:
|
||||
case LLM_ARCH_XVERSE:
|
||||
@ -13265,6 +13323,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_PHI2:
|
||||
case LLM_ARCH_PHI3:
|
||||
case LLM_ARCH_PHIMOE:
|
||||
case LLM_ARCH_PLAMO:
|
||||
case LLM_ARCH_GEMMA:
|
||||
case LLM_ARCH_GEMMA2:
|
||||
case LLM_ARCH_GEMMA3:
|
||||
@ -13272,6 +13331,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_OPENELM:
|
||||
case LLM_ARCH_GPTNEOX:
|
||||
case LLM_ARCH_CODESHELL:
|
||||
case LLM_ARCH_ORION:
|
||||
case LLM_ARCH_NEMOTRON:
|
||||
case LLM_ARCH_EXAONE:
|
||||
case LLM_ARCH_MINICPM3:
|
||||
@ -13344,6 +13404,14 @@ const char * llama_model_chat_template(const llama_model * model, const char * n
|
||||
: LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
|
||||
const auto & it = model->gguf_kv.find(key);
|
||||
if (it == model->gguf_kv.end()) {
|
||||
// one-off fix for very popular models (so we are not flooded with issues)
|
||||
// do not extend this list unless absolutely necessary
|
||||
// Mistral-Small-2503 does not have built-in chat template
|
||||
llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
|
||||
if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
|
||||
return "mistral-v7-tekken";
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -76,6 +76,7 @@ enum llm_type {
|
||||
LLM_TYPE_236B,
|
||||
LLM_TYPE_290B,
|
||||
LLM_TYPE_314B,
|
||||
LLM_TYPE_405B,
|
||||
LLM_TYPE_671B,
|
||||
LLM_TYPE_SMALL,
|
||||
LLM_TYPE_MEDIUM,
|
||||
@ -95,6 +96,8 @@ enum llm_type {
|
||||
LLM_TYPE_235B_A22B,
|
||||
};
|
||||
|
||||
std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);
|
||||
|
||||
struct llama_layer_posnet {
|
||||
// resnet
|
||||
struct ggml_tensor * norm1 = nullptr;
|
||||
@ -395,8 +398,11 @@ struct llama_model {
|
||||
|
||||
const struct ggml_tensor * get_tensor(const char * name) const;
|
||||
|
||||
ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const;
|
||||
|
||||
// note: can mutate `cparams`
|
||||
// TODO: move this to new llm_arch_model_i interface
|
||||
llama_memory_i * create_memory() const; // TODO: params
|
||||
llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
|
||||
|
||||
// TODO: move this to new llm_arch_model_i interface
|
||||
llm_graph_result_ptr build_graph(
|
||||
|
@ -519,7 +519,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
nthread = std::thread::hardware_concurrency();
|
||||
}
|
||||
|
||||
// mmap consistently increases speed Linux, and also increases speed on Windows with
|
||||
// mmap consistently increases speed on Linux, and also increases speed on Windows with
|
||||
// hot cache. It may cause a slowdown on macOS, possibly related to free memory.
|
||||
#if defined(__linux__) || defined(_WIN32)
|
||||
constexpr bool use_mmap = true;
|
||||
@ -529,7 +529,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
|
||||
llama_model_kv_override * kv_overrides = nullptr;
|
||||
if (params->kv_overrides) {
|
||||
auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
|
||||
auto * v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
|
||||
kv_overrides = v->data();
|
||||
}
|
||||
|
||||
|
@ -1750,23 +1750,35 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler *
|
||||
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
|
||||
|
||||
if (ctx->n <= 0.0f || cur_p->size <= 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
// find max logit and calculate mean
|
||||
float max = cur_p->data[0].logit;
|
||||
float logits_sum = 0;
|
||||
size_t valid_count = 0;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if (cur_p->data[i].logit > max) {
|
||||
max = cur_p->data[i].logit;
|
||||
// Only count non-negative infinity values
|
||||
if (cur_p->data[i].logit != -INFINITY) {
|
||||
if (cur_p->data[i].logit > max) {
|
||||
max = cur_p->data[i].logit;
|
||||
}
|
||||
logits_sum += cur_p->data[i].logit;
|
||||
valid_count++;
|
||||
}
|
||||
logits_sum += cur_p->data[i].logit;
|
||||
}
|
||||
float mean = logits_sum/cur_p->size;
|
||||
float mean = valid_count > 0 ? logits_sum/valid_count : 0;
|
||||
|
||||
// calculate standard deviation
|
||||
float acc = 0;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
acc += pow(cur_p->data[i].logit - mean, 2);
|
||||
// Skip -infinity in std calculation
|
||||
if (cur_p->data[i].logit != -INFINITY) {
|
||||
acc += pow(cur_p->data[i].logit - mean, 2);
|
||||
}
|
||||
}
|
||||
float std = sqrt(acc/cur_p->size);
|
||||
float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
|
||||
|
||||
//apply mask
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
|
@ -1,5 +1,7 @@
|
||||
#include "llama-vocab.h"
|
||||
|
||||
#include "ggml.h"
|
||||
#include "gguf.h"
|
||||
#include "llama-impl.h"
|
||||
#include "llama-model-loader.h"
|
||||
|
||||
@ -415,6 +417,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
||||
"'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_SEED_CODER:
|
||||
regex_exprs = {
|
||||
// original regex from tokenizer.json
|
||||
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\r\n]+|\\s*[\r\n]+|\\s+(?!\\S)|\\s+"
|
||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
default:
|
||||
// default regex for BPE tokenization pre-processing
|
||||
regex_exprs = {
|
||||
@ -1227,6 +1236,9 @@ struct fragment_buffer_variant {
|
||||
struct llama_vocab::impl {
|
||||
uint32_t n_token_types = 0; // for BERT-style token types
|
||||
|
||||
std::string tokenizer_model;
|
||||
std::string tokenizer_pre;
|
||||
|
||||
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
|
||||
enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||
|
||||
@ -1362,9 +1374,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|
||||
// determine vocab type
|
||||
{
|
||||
std::string tokenizer_model;
|
||||
std::string tokenizer_pre;
|
||||
|
||||
ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
|
||||
ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
|
||||
|
||||
@ -1459,7 +1468,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|
||||
const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
|
||||
if (precompiled_charsmap_keyidx != -1) {
|
||||
size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
|
||||
const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx);
|
||||
GGML_ASSERT(pc_type == GGUF_TYPE_INT8 || pc_type == GGUF_TYPE_UINT8);
|
||||
|
||||
const size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
|
||||
const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
|
||||
precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
|
||||
#ifdef IS_BIG_ENDIAN
|
||||
@ -1634,6 +1646,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
tokenizer_pre == "bailingmoe") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "seed-coder") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
|
||||
clean_spaces = false;
|
||||
} else {
|
||||
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
||||
}
|
||||
@ -2778,6 +2794,14 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
pimpl->load(ml, kv);
|
||||
}
|
||||
|
||||
std::string llama_vocab::get_tokenizer_model() const {
|
||||
return pimpl->tokenizer_model;
|
||||
}
|
||||
|
||||
std::string llama_vocab::get_tokenizer_pre() const {
|
||||
return pimpl->tokenizer_pre;
|
||||
}
|
||||
|
||||
enum llama_vocab_type llama_vocab::get_type() const {
|
||||
return pimpl->type;
|
||||
}
|
||||
@ -3000,6 +3024,20 @@ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> llama_vocab::get_bpe_merges() const {
|
||||
std::vector<std::string> result(pimpl->bpe_ranks.size());
|
||||
|
||||
for (const auto & pair : pimpl->bpe_ranks) {
|
||||
result[pair.second] = pair.first.first + " " + pair.first.second;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<char> llama_vocab::get_precompiled_charsmap() const {
|
||||
return pimpl->precompiled_charsmap;
|
||||
}
|
||||
|
||||
int32_t llama_vocab::tokenize(
|
||||
const char * text,
|
||||
int32_t text_len,
|
||||
|
@ -21,6 +21,9 @@ struct llama_vocab {
|
||||
|
||||
void load(llama_model_loader & ml, const LLM_KV & kv);
|
||||
|
||||
std::string get_tokenizer_model() const;
|
||||
std::string get_tokenizer_pre() const;
|
||||
|
||||
enum llama_vocab_type get_type() const;
|
||||
enum llama_vocab_pre_type get_pre_type() const;
|
||||
|
||||
@ -80,6 +83,9 @@ struct llama_vocab {
|
||||
int max_token_len() const;
|
||||
|
||||
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
|
||||
std::vector<std::string> get_bpe_merges() const;
|
||||
|
||||
std::vector<char> get_precompiled_charsmap() const;
|
||||
|
||||
int32_t tokenize(
|
||||
const char * text,
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include "llama-mmap.h"
|
||||
#include "llama-vocab.h"
|
||||
#include "llama-model-loader.h"
|
||||
#include "llama-model-saver.h"
|
||||
#include "llama-model.h"
|
||||
|
||||
#include "ggml.h"
|
||||
@ -16,6 +17,10 @@
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
//
|
||||
// interface implementation
|
||||
//
|
||||
@ -249,6 +254,13 @@ struct llama_model * llama_model_load_from_splits(
|
||||
return llama_model_load_from_file_impl(splits.front(), splits, params);
|
||||
}
|
||||
|
||||
void llama_model_save_to_file(const struct llama_model * model, const char * path_model) {
|
||||
llama_model_saver ms(*model);
|
||||
ms.add_kv_from_model();
|
||||
ms.add_tensors_from_model();
|
||||
ms.save(path_model);
|
||||
}
|
||||
|
||||
//
|
||||
// chat templates
|
||||
//
|
||||
@ -334,3 +346,4 @@ const char * llama_print_system_info(void) {
|
||||
|
||||
return s.c_str();
|
||||
}
|
||||
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include "ggml.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-opt.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
@ -112,6 +113,7 @@ extern "C" {
|
||||
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
|
||||
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
|
||||
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
|
||||
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
|
||||
};
|
||||
|
||||
enum llama_rope_type {
|
||||
@ -343,7 +345,7 @@ extern "C" {
|
||||
float yarn_beta_fast; // YaRN low correction dim
|
||||
float yarn_beta_slow; // YaRN high correction dim
|
||||
uint32_t yarn_orig_ctx; // YaRN original context size
|
||||
float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default)
|
||||
float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default)
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
void * cb_eval_user_data;
|
||||
@ -351,19 +353,18 @@ extern "C" {
|
||||
enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
|
||||
enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
|
||||
|
||||
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
|
||||
// TODO: move at the end of the struct
|
||||
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
||||
bool embeddings; // if true, extract embeddings (together with logits)
|
||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
|
||||
bool no_perf; // whether to measure performance timings
|
||||
|
||||
// Abort callback
|
||||
// if it returns true, execution of llama_decode() will be aborted
|
||||
// currently works only with CPU execution
|
||||
ggml_abort_callback abort_callback;
|
||||
void * abort_callback_data;
|
||||
|
||||
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
|
||||
bool embeddings; // if true, extract embeddings (together with logits)
|
||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
|
||||
bool no_perf; // whether to measure performance timings
|
||||
bool op_offload; // whether to offload host tensor operations to device
|
||||
};
|
||||
|
||||
// model quantization parameters
|
||||
@ -445,6 +446,10 @@ extern "C" {
|
||||
size_t n_paths,
|
||||
struct llama_model_params params);
|
||||
|
||||
LLAMA_API void llama_model_save_to_file(
|
||||
const struct llama_model * model,
|
||||
const char * path_model);
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
|
||||
"use llama_model_free instead");
|
||||
|
||||
@ -924,14 +929,19 @@ extern "C" {
|
||||
// Frees a batch of tokens allocated with llama_batch_init()
|
||||
LLAMA_API void llama_batch_free(struct llama_batch batch);
|
||||
|
||||
// Processes a batch of tokens with the ecoder part of the encoder-decoder model.
|
||||
// Stores the encoder output internally for later use by the decoder cross-attention layers.
|
||||
// Process a batch of tokens.
|
||||
// In contrast to llama_decode() - this call does not use KV cache.
|
||||
// For encode-decoder contexts, processes the batch using the encoder.
|
||||
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
|
||||
// 0 - success
|
||||
// < 0 - error. the KV cache state is restored to the state before this call
|
||||
LLAMA_API int32_t llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
|
||||
// Process a batch of tokens.
|
||||
// Requires KV cache.
|
||||
// For encode-decoder contexts, processes the batch using the decoder.
|
||||
// Positive return values does not mean a fatal error, but rather a warning.
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
@ -1428,6 +1438,37 @@ extern "C" {
|
||||
LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
|
||||
LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
|
||||
|
||||
//
|
||||
// training
|
||||
//
|
||||
|
||||
// function that returns whether or not a given tensor contains trainable parameters
|
||||
typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
// always returns true
|
||||
LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
struct llama_opt_params {
|
||||
uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
|
||||
|
||||
llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters
|
||||
void * param_filter_ud; // userdata for determining which tensors contain trainable parameters
|
||||
|
||||
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
|
||||
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
|
||||
};
|
||||
|
||||
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
|
||||
|
||||
LLAMA_API void llama_opt_epoch(
|
||||
struct llama_context * lctx,
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result_train,
|
||||
ggml_opt_result_t result_eval,
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@ -248,7 +248,7 @@ extern "C" {
|
||||
// preferrably to run on the same backend as the buffer
|
||||
ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
||||
|
||||
sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false);
|
||||
sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true);
|
||||
|
||||
// initialize buffers from a max size graph (optional)
|
||||
reserve_graph = build_graph(sched, max_batch_size);
|
||||
@ -289,7 +289,7 @@ extern "C" {
|
||||
typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
|
||||
|
||||
// Initialize a backend scheduler, backends with low index are given priority over backends with high index
|
||||
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
|
||||
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload);
|
||||
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
|
||||
|
||||
// Initialize backend buffers from a measure graph
|
||||
|
@ -37,13 +37,16 @@ extern "C" {
|
||||
// ====== Dataset ======
|
||||
|
||||
GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
|
||||
int64_t ne_datapoint, // number of elements per datapoint
|
||||
int64_t ne_label, // number of elements per label
|
||||
int64_t ndata, // total number of datapoints/labels
|
||||
int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
|
||||
enum ggml_type type_data, // the type for the internal data tensor
|
||||
enum ggml_type type_label, // the type for the internal labels tensor
|
||||
int64_t ne_datapoint, // number of elements per datapoint
|
||||
int64_t ne_label, // number of elements per label
|
||||
int64_t ndata, // total number of datapoints/labels
|
||||
int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
|
||||
GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
|
||||
|
||||
// get underlying tensors that store the data
|
||||
GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset);
|
||||
GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
|
||||
GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata]
|
||||
|
||||
@ -56,13 +59,19 @@ extern "C" {
|
||||
struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch]
|
||||
struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
|
||||
int64_t ibatch);
|
||||
GGML_API void ggml_opt_dataset_get_batch_host(
|
||||
ggml_opt_dataset_t dataset,
|
||||
void * data_batch,
|
||||
size_t nb_data_batch,
|
||||
void * labels_batch,
|
||||
int64_t ibatch);
|
||||
|
||||
// ====== Model / Context ======
|
||||
|
||||
enum ggml_opt_build_type {
|
||||
GGML_OPT_BUILD_TYPE_FORWARD,
|
||||
GGML_OPT_BUILD_TYPE_GRAD,
|
||||
GGML_OPT_BUILD_TYPE_OPT,
|
||||
GGML_OPT_BUILD_TYPE_FORWARD = 10,
|
||||
GGML_OPT_BUILD_TYPE_GRAD = 20,
|
||||
GGML_OPT_BUILD_TYPE_OPT = 30,
|
||||
};
|
||||
|
||||
// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
|
||||
@ -81,20 +90,22 @@ extern "C" {
|
||||
// userdata can be used to pass arbitrary data
|
||||
typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
|
||||
|
||||
// returns the default optimizer params (constant)
|
||||
// returns the default optimizer params (constant, hard-coded values)
|
||||
// userdata is not used
|
||||
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);
|
||||
|
||||
// casts userdata to ggml_opt_optimizer_params and returns it
|
||||
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata);
|
||||
|
||||
// parameters for initializing a new optimization context
|
||||
struct ggml_opt_params {
|
||||
ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
|
||||
|
||||
struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
|
||||
|
||||
// the forward graph is defined by inputs and outputs
|
||||
// those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
|
||||
struct ggml_tensor * inputs;
|
||||
struct ggml_tensor * outputs;
|
||||
// by default the forward graph needs to be reconstructed for each eval
|
||||
// if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
|
||||
struct ggml_context * ctx_compute;
|
||||
struct ggml_tensor * inputs;
|
||||
struct ggml_tensor * outputs;
|
||||
|
||||
enum ggml_opt_loss_type loss_type;
|
||||
enum ggml_opt_build_type build_type;
|
||||
@ -107,12 +118,9 @@ extern "C" {
|
||||
|
||||
// get parameters for an optimization context with defaults set where possible
|
||||
// parameters for which no sensible defaults exist are supplied as arguments to this function
|
||||
GGML_API ggml_opt_params ggml_opt_default_params(
|
||||
ggml_backend_sched_t backend_sched,
|
||||
struct ggml_context * ctx_compute,
|
||||
struct ggml_tensor * inputs,
|
||||
struct ggml_tensor * outputs,
|
||||
enum ggml_opt_loss_type loss_type);
|
||||
GGML_API struct ggml_opt_params ggml_opt_default_params(
|
||||
ggml_backend_sched_t backend_sched,
|
||||
enum ggml_opt_loss_type loss_type);
|
||||
|
||||
GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
|
||||
GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
|
||||
@ -121,6 +129,7 @@ extern "C" {
|
||||
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
|
||||
|
||||
// get underlying tensors that store data
|
||||
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
|
||||
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
|
||||
GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
|
||||
GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against
|
||||
@ -128,11 +137,12 @@ extern "C" {
|
||||
GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs
|
||||
GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
|
||||
|
||||
// get the gradient accumulator for a node from the forward graph
|
||||
GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
|
||||
|
||||
// ====== Optimization Result ======
|
||||
|
||||
GGML_API ggml_opt_result_t ggml_opt_result_init();
|
||||
GGML_API ggml_opt_result_t ggml_opt_result_init(void);
|
||||
GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
|
||||
GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);
|
||||
|
||||
@ -144,11 +154,20 @@ extern "C" {
|
||||
|
||||
// ====== Computation ======
|
||||
|
||||
// do forward pass, increment result if not NULL
|
||||
GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
|
||||
// if not using static graphs, this function must be called prior to ggml_opt_alloc
|
||||
GGML_API void ggml_opt_prepare_alloc(
|
||||
ggml_opt_context_t opt_ctx,
|
||||
struct ggml_context * ctx_compute,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_tensor * inputs,
|
||||
struct ggml_tensor * outputs);
|
||||
|
||||
// do forward pass, increment result if not NULL, do backward pass
|
||||
GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
|
||||
// allocate the next graph for evaluation, either forward or forward + backward
|
||||
// must be called exactly once prior to calling ggml_opt_eval
|
||||
GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);
|
||||
|
||||
// do forward pass, increment result if not NULL, do backward pass if allocated
|
||||
GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
|
||||
|
||||
// ############################################################################
|
||||
// ## The high-level functions start here. They do not depend on any private ##
|
||||
@ -200,9 +219,9 @@ extern "C" {
|
||||
// fit model defined by inputs and outputs to dataset
|
||||
GGML_API void ggml_opt_fit(
|
||||
ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs
|
||||
ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
|
||||
ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
|
||||
ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
|
||||
struct ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
|
||||
struct ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
|
||||
struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
|
||||
ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
|
||||
enum ggml_opt_loss_type loss_type, // loss to minimize
|
||||
ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
|
||||
|
@ -768,7 +768,7 @@ extern "C" {
|
||||
// Tensor flags
|
||||
GGML_API void ggml_set_input(struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_set_output(struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_set_param(struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
|
||||
|
||||
//
|
||||
@ -938,7 +938,7 @@ extern "C" {
|
||||
GGML_API struct ggml_tensor * ggml_repeat_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
struct ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride
|
||||
|
||||
// concat a and b along dim
|
||||
// used in stable-diffusion
|
||||
@ -2049,15 +2049,14 @@ extern "C" {
|
||||
|
||||
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_build_backward_expand(
|
||||
struct ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation)
|
||||
struct ggml_context * ctx_compute, // context for gradient computation
|
||||
struct ggml_cgraph * cgraph,
|
||||
bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static
|
||||
struct ggml_context * ctx, // context for gradient computation
|
||||
struct ggml_cgraph * cgraph,
|
||||
struct ggml_tensor ** grad_accs);
|
||||
|
||||
// graph allocation in a context
|
||||
GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
|
||||
GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads);
|
||||
GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
|
||||
GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads);
|
||||
GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
|
||||
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
|
||||
GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
|
||||
|
@ -674,6 +674,8 @@ struct ggml_backend_sched {
|
||||
char * context_buffer;
|
||||
size_t context_buffer_size;
|
||||
|
||||
bool op_offload;
|
||||
|
||||
int debug;
|
||||
};
|
||||
|
||||
@ -766,7 +768,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
|
||||
if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
|
||||
int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
|
||||
// check if a backend with higher prio wants to offload the op
|
||||
if (src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
|
||||
if (sched->op_offload && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
|
||||
for (int b = 0; b < src_backend_id; b++) {
|
||||
if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
|
||||
SET_CAUSE(tensor, "1.off");
|
||||
@ -1109,7 +1111,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
|
||||
|
||||
const int node_backend_id = tensor_backend_id(node);
|
||||
|
||||
assert(node_backend_id != -1); // all nodes should be assigned by now
|
||||
assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
|
||||
|
||||
// check if we should start a new split based on the sources of the current node
|
||||
bool need_new_split = false;
|
||||
@ -1452,7 +1454,8 @@ ggml_backend_sched_t ggml_backend_sched_new(
|
||||
ggml_backend_buffer_type_t * bufts,
|
||||
int n_backends,
|
||||
size_t graph_size,
|
||||
bool parallel) {
|
||||
bool parallel,
|
||||
bool op_offload) {
|
||||
GGML_ASSERT(n_backends > 0);
|
||||
GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
|
||||
GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
@ -1497,6 +1500,7 @@ ggml_backend_sched_t ggml_backend_sched_new(
|
||||
}
|
||||
|
||||
sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
|
||||
sched->op_offload = op_offload;
|
||||
|
||||
ggml_backend_sched_reset(sched);
|
||||
|
||||
|
@ -428,6 +428,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
${KLEIDIAI_SRC}/kai/ukernels/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
|
||||
|
||||
set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
|
||||
@ -438,17 +439,19 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
|
||||
string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
|
||||
|
||||
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
|
||||
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP})
|
||||
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
|
||||
|
||||
if (NOT DOTPROD_ENABLED MATCHES -1)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
|
||||
endif()
|
||||
|
||||
if (NOT I8MM_ENABLED MATCHES -1)
|
||||
@ -456,9 +459,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
endif()
|
||||
|
||||
if (NOT SME_ENABLED MATCHES -1)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c)
|
||||
set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2")
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c)
|
||||
set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
|
||||
endif()
|
||||
|
||||
set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
|
||||
|
@ -4,16 +4,22 @@
|
||||
|
||||
// KleidiAI micro-kernels
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
||||
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
||||
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
|
||||
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
|
||||
|
||||
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
||||
|
||||
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
|
||||
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
||||
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
||||
|
||||
#include "kai_common.h"
|
||||
|
||||
#include "kernels.h"
|
||||
@ -61,6 +67,53 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_SME,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
{
|
||||
/* SME GEMM */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
},
|
||||
/* SME GEMV */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
},
|
||||
/* .lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
|
||||
/* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_SME,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_F16,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
#if defined(__APPLE__)
|
||||
@ -105,6 +158,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
@ -148,6 +204,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
#else
|
||||
@ -192,6 +251,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
@ -235,12 +297,33 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
#endif
|
||||
};
|
||||
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) {
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
|
||||
ggml_kleidiai_kernels * kernel = nullptr;
|
||||
|
||||
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
||||
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
|
||||
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
|
||||
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
|
||||
gemm_gemv_kernels[i].op_type == tensor->type) {
|
||||
kernel = &gemm_gemv_kernels[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
|
||||
ggml_kleidiai_kernels * kernels = nullptr;
|
||||
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
||||
|
@ -4,6 +4,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include "ggml.h"
|
||||
|
||||
enum cpu_feature {
|
||||
CPU_FEATURE_NONE = 0,
|
||||
CPU_FEATURE_DOTPROD = 1,
|
||||
@ -26,26 +29,53 @@ struct kernel_info {
|
||||
size_t (*get_nr)(void);
|
||||
size_t (*get_kr)(void);
|
||||
size_t (*get_sr)(void);
|
||||
size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl);
|
||||
size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl);
|
||||
std::variant<
|
||||
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
|
||||
std::function<size_t(size_t m_idx, size_t k)>
|
||||
> get_lhs_offset;
|
||||
std::variant<
|
||||
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
|
||||
std::function<size_t(size_t n_idx, size_t k)>
|
||||
> get_rhs_packed_offset;
|
||||
size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
|
||||
size_t (*get_dst_size)(size_t m, size_t n);
|
||||
void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
|
||||
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max);
|
||||
std::variant<
|
||||
std::function<void(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
|
||||
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max)>,
|
||||
std::function<void(size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
|
||||
size_t dst_stride_col, float clamp_min, float clamp_max)>
|
||||
> run_kernel;
|
||||
};
|
||||
|
||||
struct lhs_packing_info {
|
||||
size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
|
||||
size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
|
||||
size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
|
||||
void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
|
||||
size_t lhs_stride, void* lhs_packed);
|
||||
std::variant<
|
||||
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
|
||||
std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)>
|
||||
> get_packed_offset;
|
||||
std::variant<
|
||||
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
|
||||
std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)>
|
||||
> packed_size;
|
||||
std::variant<
|
||||
std::function<void(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
|
||||
size_t lhs_stride, void* lhs_packed)>,
|
||||
std::function<void(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
|
||||
void* lhs_packed)>
|
||||
> pack_func;
|
||||
};
|
||||
|
||||
struct rhs_packing_info {
|
||||
size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
|
||||
void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
|
||||
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params);
|
||||
std::variant<
|
||||
std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
|
||||
std::function<size_t(size_t n, size_t k)>
|
||||
> packed_size;
|
||||
std::variant<
|
||||
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
|
||||
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
|
||||
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
|
||||
const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
|
||||
> pack_func;
|
||||
};
|
||||
|
||||
struct ggml_kleidiai_kernels {
|
||||
@ -55,6 +85,10 @@ struct ggml_kleidiai_kernels {
|
||||
rhs_packing_info rhs_info;
|
||||
|
||||
cpu_feature required_cpu;
|
||||
ggml_type lhs_type;
|
||||
ggml_type rhs_type;
|
||||
ggml_type op_type;
|
||||
};
|
||||
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features);
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
|
||||
|
@ -34,8 +34,9 @@
|
||||
#include "ggml-common.h"
|
||||
|
||||
struct ggml_kleidiai_context {
|
||||
cpu_feature features;
|
||||
ggml_kleidiai_kernels * kernels;
|
||||
} static ctx = { NULL };
|
||||
} static ctx = { CPU_FEATURE_NONE, NULL };
|
||||
|
||||
static void init_kleidiai_context(void) {
|
||||
|
||||
@ -47,18 +48,18 @@ static void init_kleidiai_context(void) {
|
||||
const char *env_var = getenv("GGML_KLEIDIAI_SME");
|
||||
int sme_enabled = 0;
|
||||
|
||||
cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
|
||||
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
|
||||
(ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
|
||||
ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
|
||||
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
|
||||
(ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
|
||||
|
||||
if (env_var) {
|
||||
sme_enabled = atoi(env_var);
|
||||
}
|
||||
|
||||
if (sme_enabled != 0) {
|
||||
features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
|
||||
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
|
||||
}
|
||||
ctx.kernels = ggml_kleidiai_select_kernels(features);
|
||||
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
|
||||
}
|
||||
ggml_critical_section_end();
|
||||
}
|
||||
@ -68,95 +69,275 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
|
||||
return tensor->ne[dim];
|
||||
}
|
||||
|
||||
template<typename Ret, typename Variant, typename... Args>
|
||||
static Ret variant_call(const Variant & var, Args&&... args) {
|
||||
return std::visit([&](auto&& func) -> Ret {
|
||||
if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
|
||||
return func(std::forward<Args>(args)...);
|
||||
} else {
|
||||
throw std::runtime_error("Invalid function type in variant_call");
|
||||
}
|
||||
}, var);
|
||||
}
|
||||
|
||||
namespace ggml::cpu::kleidiai {
|
||||
|
||||
static size_t round_down(size_t x, size_t y) {
|
||||
return y == 0 ? x : x - (x % y);
|
||||
}
|
||||
|
||||
static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
|
||||
size_t src_stride = rhs_stride / sizeof(uint16_t);
|
||||
size_t dst_stride = n;
|
||||
|
||||
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||
for (size_t n_idx = 0; n_idx < n; ++n_idx) {
|
||||
uint16_t v = *(src + k_idx + n_idx * src_stride);
|
||||
*(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
|
||||
GGML_ASSERT(ctx.kernels);
|
||||
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
|
||||
GGML_ASSERT(kernels);
|
||||
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
||||
|
||||
size_t k = op->src[0]->ne[0];
|
||||
size_t n = op->src[0]->ne[1];
|
||||
size_t m = op->src[1]->ne[1];
|
||||
|
||||
size_t mr = kernel->get_mr();
|
||||
size_t kr = kernel->get_kr();
|
||||
size_t sr = kernel->get_sr();
|
||||
|
||||
size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr);
|
||||
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
|
||||
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
|
||||
} else if (kernels->rhs_type == GGML_TYPE_F16) {
|
||||
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
|
||||
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
|
||||
k * n * sizeof(float) + n * sizeof(float);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
|
||||
if (dst->op == GGML_OP_MUL_MAT) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
||||
return compute_forward_q4_0(params, dst);
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
return compute_forward_kv_cache(params, dst);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
|
||||
|
||||
GGML_ASSERT(ctx.kernels);
|
||||
kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
|
||||
lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(kernel);
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
||||
GGML_ASSERT(kernels);
|
||||
|
||||
const size_t k = ne00;
|
||||
const size_t m = ne11;
|
||||
const size_t n = ne01;
|
||||
kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
||||
GGML_ASSERT(kernel);
|
||||
|
||||
const size_t n_step = kernel->get_n_step();
|
||||
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
|
||||
const size_t n_start = ith * num_n_per_thread;
|
||||
const int nth = params->nth;
|
||||
const int ith = params->ith;
|
||||
|
||||
size_t n_to_process = num_n_per_thread;
|
||||
if ((n_start + n_to_process) > n) {
|
||||
n_to_process = n - n_start;
|
||||
const int64_t lhs_batch_size0 = ne12;
|
||||
const int64_t rhs_batch_size0 = ne02;
|
||||
const int64_t batch_size = rhs_batch_size0;
|
||||
|
||||
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
|
||||
|
||||
const int64_t m = ne11 * r;
|
||||
const int64_t n = ne01;
|
||||
const int64_t k = ne00;
|
||||
|
||||
const size_t lhs_stride = src1->nb[1];
|
||||
const size_t rhs_stride = src0->nb[1];
|
||||
const size_t dst_stride = dst->nb[1];
|
||||
|
||||
const int64_t mr = static_cast<int64_t>(kernel->get_mr());
|
||||
const int64_t nr = static_cast<int64_t>(kernel->get_nr());
|
||||
const int64_t kr = static_cast<int64_t>(kernel->get_kr());
|
||||
const int64_t sr = static_cast<int64_t>(kernel->get_sr());
|
||||
|
||||
const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
|
||||
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
|
||||
const size_t kxn_size = k * n * sizeof(float);
|
||||
const size_t bias_size = n * sizeof(float);
|
||||
|
||||
const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
|
||||
GGML_ASSERT(wsize_required <= params->wsize);
|
||||
|
||||
uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
|
||||
uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
|
||||
uint8_t * rhs_kxn = rhs_packed + rhs_packed_size;
|
||||
uint8_t * bias = rhs_kxn + kxn_size;
|
||||
|
||||
for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
|
||||
const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
|
||||
const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
|
||||
uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
|
||||
|
||||
// LHS packing
|
||||
{
|
||||
const int64_t m_roundup_mr = kai_roundup(m, mr);
|
||||
const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
|
||||
|
||||
if (ith < num_threads) {
|
||||
const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
|
||||
const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
|
||||
|
||||
const int64_t m_start = ith * num_m_per_thread0;
|
||||
const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
|
||||
|
||||
const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
|
||||
|
||||
const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
|
||||
void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
|
||||
|
||||
variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
|
||||
uint8_t * lhs_packed = (uint8_t*)params->wdata;
|
||||
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
|
||||
// RHS packing
|
||||
if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
|
||||
// First thread to reach this point handles RHS packing
|
||||
memset(bias, 0, n * sizeof(float));
|
||||
transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
|
||||
reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
|
||||
|
||||
size_t mr = kernel->get_mr();
|
||||
size_t kr = kernel->get_kr();
|
||||
size_t sr = kernel->get_sr();
|
||||
|
||||
// Calculate number of columns to be processed per thread
|
||||
const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
|
||||
const size_t m_start = ith * num_m_per_thread;
|
||||
size_t m_to_process = num_m_per_thread;
|
||||
if ((m_start + m_to_process) > m) {
|
||||
m_to_process = m - m_start;
|
||||
}
|
||||
|
||||
if(m_start < m) {
|
||||
// Transform LHS
|
||||
const size_t src_stride = src1->nb[1];
|
||||
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
|
||||
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
||||
|
||||
lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
|
||||
variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
|
||||
rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
// Perform the operation
|
||||
const size_t dst_stride = dst->nb[1];
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr);
|
||||
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0);
|
||||
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
|
||||
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
|
||||
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
|
||||
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
|
||||
first_to_arrive.clear(std::memory_order_release);
|
||||
|
||||
kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr,
|
||||
dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
|
||||
return true;
|
||||
// Perform the matmul
|
||||
{
|
||||
const int64_t m_to_process = m;
|
||||
const int64_t m_start = 0;
|
||||
|
||||
const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
|
||||
const int64_t num_threads = KAI_MIN(n / n_step, nth);
|
||||
|
||||
if (ith < num_threads) {
|
||||
const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
|
||||
const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
|
||||
|
||||
const int64_t n_start = ith * num_n_per_thread0;
|
||||
const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
|
||||
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
|
||||
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
|
||||
const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
|
||||
|
||||
const void * lhs_ptr = lhs_packed + lhs_packed_offset;
|
||||
const void * rhs_ptr = rhs_packed + rhs_packed_offset;
|
||||
float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
|
||||
|
||||
variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
|
||||
}
|
||||
}
|
||||
|
||||
if (batch_idx != batch_size - 1) {
|
||||
// This barrier is necessary when the batch size is larger than 1. While processing a batch,
|
||||
// the work data buffer (params->wdata) is used as temporary storage which means that only
|
||||
// a single batch can be processed at any given time. No barrier is needed for the last
|
||||
// batch since GGML inserts a barrier between the execution of every operator.
|
||||
ggml_barrier(params->threadpool);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
||||
GGML_ASSERT(kernels);
|
||||
|
||||
kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
||||
lhs_packing_info * lhs_info = &kernels->lhs_info;
|
||||
|
||||
GGML_ASSERT(kernel);
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const size_t k = ne00;
|
||||
const size_t m = ne11;
|
||||
const size_t n = ne01;
|
||||
|
||||
size_t mr = kernel->get_mr();
|
||||
size_t kr = kernel->get_kr();
|
||||
size_t sr = kernel->get_sr();
|
||||
|
||||
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
|
||||
uint8_t * lhs_packed = (uint8_t*)params->wdata;
|
||||
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
|
||||
|
||||
const size_t n_step = kernel->get_n_step();
|
||||
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
|
||||
const size_t n_start = ith * num_n_per_thread;
|
||||
|
||||
size_t n_to_process = num_n_per_thread;
|
||||
if ((n_start + n_to_process) > n) {
|
||||
n_to_process = n - n_start;
|
||||
}
|
||||
|
||||
// Calculate number of columns to be processed per thread
|
||||
const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
|
||||
const size_t m_start = ith * num_m_per_thread;
|
||||
size_t m_to_process = num_m_per_thread;
|
||||
if ((m_start + m_to_process) > m) {
|
||||
m_to_process = m - m_start;
|
||||
}
|
||||
|
||||
if (m_start < m) {
|
||||
// Transform LHS
|
||||
const size_t src_stride = src1->nb[1];
|
||||
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
|
||||
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
||||
|
||||
variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
// Perform the operation
|
||||
const size_t dst_stride = dst->nb[1];
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
|
||||
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
|
||||
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
|
||||
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
|
||||
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
|
||||
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
|
||||
|
||||
variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
|
||||
sizeof(float), -FLT_MAX, FLT_MAX);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
public:
|
||||
@ -169,13 +350,13 @@ public:
|
||||
size_t sr = ctx.kernels->gemm.get_sr();
|
||||
|
||||
#ifndef NDEBUG
|
||||
const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0);
|
||||
const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
|
||||
GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
|
||||
#endif
|
||||
struct kai_rhs_pack_qs4cxs1s0_param params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, ¶ms);
|
||||
variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms);
|
||||
|
||||
return 0;
|
||||
|
||||
@ -189,7 +370,7 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
|
||||
}
|
||||
} // namespace ggml::cpu::kleidiai
|
||||
|
||||
GGML_API enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
|
||||
static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
|
||||
tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
@ -238,12 +419,11 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
|
||||
namespace ggml::cpu::kleidiai {
|
||||
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
||||
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
||||
if ( op->op == GGML_OP_MUL_MAT &&
|
||||
op->src[0]->type == GGML_TYPE_Q4_0 &&
|
||||
op->src[0]->buffer &&
|
||||
(ggml_n_dims(op->src[0]) == 2) &&
|
||||
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels
|
||||
) {
|
||||
if (op->op == GGML_OP_MUL_MAT &&
|
||||
op->src[0]->type == GGML_TYPE_Q4_0 &&
|
||||
op->src[0]->buffer &&
|
||||
(ggml_n_dims(op->src[0]) == 2) &&
|
||||
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
|
||||
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
||||
return false;
|
||||
}
|
||||
@ -260,6 +440,19 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
||||
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
||||
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
||||
}
|
||||
else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
|
||||
op->src[0]->op == GGML_OP_VIEW &&
|
||||
(op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
|
||||
op->src[1]->ne[1] > 1) {
|
||||
if ((op->src[0]->nb[0] != 2) ||
|
||||
(op->src[1]->nb[0] != 4) ||
|
||||
(op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
|
||||
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -118,7 +118,7 @@ if (CUDAToolkit_FOUND)
|
||||
|
||||
set(CUDA_CXX_FLAGS "")
|
||||
|
||||
set(CUDA_FLAGS -use_fast_math)
|
||||
set(CUDA_FLAGS -use_fast_math -extended-lambda)
|
||||
|
||||
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
|
||||
# Options are:
|
||||
|
@ -1,47 +1,61 @@
|
||||
#include "acc.cuh"
|
||||
|
||||
static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
||||
const int ne10, const int ne11, const int ne12,
|
||||
const int nb1, const int nb2, int offset) {
|
||||
const int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
static __global__ void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
|
||||
const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
|
||||
const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
int src1_idx = i - offset;
|
||||
int oz = src1_idx / nb2;
|
||||
int oy = (src1_idx - (oz * nb2)) / nb1;
|
||||
int ox = src1_idx % nb1;
|
||||
if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
|
||||
dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
|
||||
} else {
|
||||
dst[i] = x[i];
|
||||
|
||||
int64_t src1_idx = i - offset;
|
||||
|
||||
int64_t tmp = src1_idx;
|
||||
const int64_t i13 = tmp / s13;
|
||||
tmp -= i13 * s13;
|
||||
const int64_t i12 = tmp / s12;
|
||||
tmp -= i12 * s12;
|
||||
const int64_t i11 = tmp / s11;
|
||||
tmp -= i11 * s11;
|
||||
const int64_t i10 = tmp;
|
||||
|
||||
float val = x[i];
|
||||
if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
|
||||
val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
|
||||
}
|
||||
dst[i] = val;
|
||||
}
|
||||
|
||||
static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
|
||||
const int ne10, const int ne11, const int ne12,
|
||||
const int nb1, const int nb2, const int offset, cudaStream_t stream) {
|
||||
int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
|
||||
acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
|
||||
static void acc_f32_cuda(const float * x, const float * y, float * dst, const int64_t n_elements,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
|
||||
const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) {
|
||||
const int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
|
||||
acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
const float * src1_d = (const float *)src1->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
|
||||
|
||||
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
|
||||
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
|
||||
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
|
||||
int offset = dst->op_params[3] / 4; // offset in bytes
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(dst));
|
||||
|
||||
acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, stream);
|
||||
const int64_t s1 = dst->op_params[0] / sizeof(float);
|
||||
const int64_t s2 = dst->op_params[1] / sizeof(float);
|
||||
const int64_t s3 = dst->op_params[2] / sizeof(float);
|
||||
const int64_t offset = dst->op_params[3] / sizeof(float);
|
||||
|
||||
acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], s1, s2, s3, offset, stream);
|
||||
}
|
||||
|
@ -296,6 +296,25 @@ static __device__ void no_device_code(
|
||||
#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
|
||||
#endif // __CUDA_ARCH__
|
||||
|
||||
// The compiler is always able to unroll loops if they contain continue expressions.
|
||||
// In such cases loop unrolling can still be achieved via recursion:
|
||||
template <int n>
|
||||
struct ggml_cuda_unroll {
|
||||
template <typename Func, typename... Args>
|
||||
__device__ void operator()(const Func & f, Args... args) const {
|
||||
f(n - 1, args...);
|
||||
ggml_cuda_unroll<n - 1>{}(f, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ggml_cuda_unroll<1> {
|
||||
template <typename Func, typename... Args>
|
||||
__device__ void operator()(const Func & f, Args... args) const {
|
||||
f(0, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ int warp_reduce_sum(int x) {
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
|
@ -2,6 +2,17 @@
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
|
||||
static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
|
||||
#ifdef CP_ASYNC_AVAILABLE
|
||||
return __cvta_generic_to_shared(generic_ptr);
|
||||
#else
|
||||
GGML_UNUSED(generic_ptr);
|
||||
NO_DEVICE_CODE;
|
||||
return 0;
|
||||
#endif // CP_ASYNC_AVAILABLE
|
||||
}
|
||||
|
||||
// Copies data from global to shared memory, cg == cache global.
|
||||
// Both the src and dst pointers must be aligned to 16 bit.
|
||||
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
|
||||
|
@ -516,7 +516,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
|
||||
nullptr;
|
||||
}
|
||||
|
||||
template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
|
||||
template<int D, int ncols1, int ncols2> // D == head size
|
||||
__launch_bounds__(D, 1)
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
||||
@ -665,13 +665,13 @@ static void on_no_fattn_vec_case(const int D) {
|
||||
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
|
||||
GGML_ABORT("fatal error");
|
||||
} else {
|
||||
fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
|
||||
fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
|
||||
fprintf(stderr, "Only f16 is supported.\n");
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int ncols1, int ncols2, int KQ_stride>
|
||||
template <int DV, int ncols1, int ncols2>
|
||||
void launch_fattn(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
||||
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
||||
@ -691,7 +691,7 @@ void launch_fattn(
|
||||
|
||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||
|
||||
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
||||
|
||||
@ -754,10 +754,13 @@ void launch_fattn(
|
||||
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
||||
|
||||
const dim3 block_dim(warp_size, nwarps, 1);
|
||||
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
||||
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
||||
|
||||
dim3 blocks_num;
|
||||
if (stream_k) {
|
||||
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
||||
const int max_blocks = 2*nsm;
|
||||
const int max_blocks = max_blocks_per_sm*nsm;
|
||||
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
|
||||
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
|
||||
|
||||
@ -769,14 +772,11 @@ void launch_fattn(
|
||||
blocks_num.y = 1;
|
||||
blocks_num.z = 1;
|
||||
|
||||
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
|
||||
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
|
||||
} else {
|
||||
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
|
||||
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
|
||||
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
||||
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
||||
|
||||
// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
|
||||
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
|
||||
|
||||
@ -853,19 +853,19 @@ void launch_fattn(
|
||||
|
||||
if (stream_k) {
|
||||
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
|
||||
|
||||
flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
|
||||
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
||||
}
|
||||
} else if (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
|
||||
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
|
||||
|
||||
flash_attn_combine_results<D>
|
||||
flash_attn_combine_results<DV>
|
||||
<<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -307,7 +307,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, -1>
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
|
||||
} break;
|
||||
case 128: {
|
||||
@ -315,7 +315,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, -1>
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
|
||||
} break;
|
||||
default: {
|
||||
|
@ -318,7 +318,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, -1>
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
|
||||
} break;
|
||||
case 128: {
|
||||
@ -326,7 +326,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, -1>
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
|
||||
} break;
|
||||
default: {
|
||||
|
@ -168,6 +168,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ[j*D + tid] = -HALF_MAX_HALF;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
||||
|
||||
@ -315,7 +316,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
|
@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
|
@ -490,7 +490,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
@ -8,58 +8,32 @@
|
||||
#include "fattn-wmma-f16.cuh"
|
||||
#include "fattn.cuh"
|
||||
|
||||
template <int D, int ncols2>
|
||||
template <int DKQ, int DV, int ncols2>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
if (Q->ne[1] <= 8/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
if constexpr (ncols2 <= 8) {
|
||||
if (Q->ne[1] <= 8/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 16/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
|
||||
}
|
||||
|
||||
template <int ncols2>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
template <int DKQ, int DV>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
@ -68,27 +42,79 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
const float use_gqa_opt = mask && max_bias == 0.0f;
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f;
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio == 4) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
|
||||
if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio == 2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
|
||||
if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
GGML_ASSERT(V->ne[0] == 64);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
GGML_ASSERT(V->ne[0] == 80);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
GGML_ASSERT(V->ne[0] == 96);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
GGML_ASSERT(V->ne[0] == 112);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
GGML_ASSERT(V->ne[0] == 128);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
GGML_ASSERT(V->ne[0] == 256);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 576: {
|
||||
// For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f;
|
||||
GGML_ASSERT(use_gqa_opt);
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
||||
@ -299,7 +325,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
|
||||
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
|
||||
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
|
||||
const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
|
||||
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
|
@ -10,10 +10,11 @@ static __global__ void k_get_rows(
|
||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||
|
||||
const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
|
||||
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
|
||||
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
|
||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||
const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2;
|
||||
const int i10 = blockIdx.x;
|
||||
const int i11 = blockIdx.z / ne12;
|
||||
const int i12 = blockIdx.z % ne12;
|
||||
|
||||
if (i00 >= ne00) {
|
||||
return;
|
||||
@ -46,10 +47,11 @@ static __global__ void k_get_rows_float(
|
||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||
|
||||
const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
|
||||
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
|
||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||
const int i00 = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
const int i10 = blockIdx.x;
|
||||
const int i11 = blockIdx.z / ne12;
|
||||
const int i12 = blockIdx.z % ne12;
|
||||
|
||||
if (i00 >= ne00) {
|
||||
return;
|
||||
@ -94,8 +96,8 @@ static void get_rows_cuda_q(
|
||||
const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
cudaStream_t stream) {
|
||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||
const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
|
||||
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
|
||||
const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
|
||||
const dim3 block_nums(ne10, block_num_y, ne11*ne12);
|
||||
|
||||
// strides in elements
|
||||
// const size_t s0 = nb0 / sizeof(dst_t);
|
||||
@ -127,8 +129,8 @@ static void get_rows_cuda_float(
|
||||
const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
cudaStream_t stream) {
|
||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||
const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
|
||||
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
|
||||
const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
|
||||
const dim3 block_nums(ne10, block_num_y, ne11*ne12);
|
||||
|
||||
// strides in elements
|
||||
// const size_t s0 = nb0 / sizeof(dst_t);
|
||||
|
@ -1909,13 +1909,19 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
||||
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
|
||||
|
||||
// If src0 is a temporary compute buffer it may have some padding that needs to be cleared for mul_mat_vec_q or mul_mat_q.
|
||||
// But if src0 is also a view of another tensor then this cannot be done safely because it may overwrite valid tensor data.
|
||||
// Therefore, in such cases use cuBLAS.
|
||||
const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
|
||||
&& ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
|
||||
|
||||
bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
|
||||
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
|
||||
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||
bool use_mul_mat_q = ggml_is_quantized(src0->type)
|
||||
bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||||
|
||||
bool any_gpus_with_slow_fp16 = false;
|
||||
@ -3215,16 +3221,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return false;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
||||
// different head sizes of K and V are not supported yet
|
||||
return false;
|
||||
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
|
||||
if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
|
||||
return false;
|
||||
}
|
||||
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
|
||||
return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 192) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 576) {
|
||||
// DeepSeek MLA
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
@ -91,11 +91,11 @@ void ggml_cuda_mul_mat_q(
|
||||
|
||||
// If src0 is a temporary compute buffer, clear any potential padding.
|
||||
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
|
||||
GGML_ASSERT(!src0->view_src);
|
||||
const size_t size_data = ggml_nbytes(src0);
|
||||
const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
|
||||
if (size_alloc > size_data) {
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
|
||||
GGML_ASSERT(!src0->view_src);
|
||||
CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
|
||||
}
|
||||
}
|
||||
|
@ -515,11 +515,11 @@ void ggml_cuda_mul_mat_vec_q(
|
||||
|
||||
// If src0 is a temporary compute buffer, clear any potential padding.
|
||||
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
|
||||
GGML_ASSERT(!src0->view_src);
|
||||
const size_t size_data = ggml_nbytes(src0);
|
||||
const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
|
||||
if (size_alloc > size_data) {
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
|
||||
GGML_ASSERT(!src0->view_src);
|
||||
CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
|
||||
}
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1);
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2);
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
||||
|
@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1);
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2);
|
||||
|
@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2);
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1);
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1);
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2);
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);
|
||||
|
@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
|
||||
|
||||
"""
|
||||
|
||||
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n"
|
||||
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"
|
||||
|
||||
TYPES_MMQ = [
|
||||
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
|
||||
@ -57,18 +57,21 @@ for vkq_size in [16, 32]:
|
||||
with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
|
||||
|
||||
for ncols in [8, 16, 32, 64, 128]:
|
||||
for ncols2 in [1, 2, 4, 8]:
|
||||
for ncols in [8, 16, 32, 64]:
|
||||
for ncols2 in [1, 2, 4, 8, 16]:
|
||||
if ncols2 > ncols:
|
||||
continue
|
||||
ncols1 = ncols // ncols2
|
||||
if ncols == 128:
|
||||
continue # Too much register pressure.
|
||||
with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_MMA_START)
|
||||
|
||||
for head_size in [64, 80, 96, 112, 128, 256]:
|
||||
if ncols == 128 and head_size == 256:
|
||||
continue # Needs too much shared memory.
|
||||
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size))
|
||||
for head_size_kq in [64, 80, 96, 112, 128, 256, 576]:
|
||||
if head_size_kq != 576 and ncols2 == 16:
|
||||
continue
|
||||
if head_size_kq == 576 and ncols2 != 16:
|
||||
continue
|
||||
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
||||
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
|
||||
|
||||
for type in TYPES_MMQ:
|
||||
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
|
||||
|
@ -207,6 +207,10 @@ typedef struct {
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
int32_t sect_0;
|
||||
int32_t sect_1;
|
||||
int32_t sect_2;
|
||||
int32_t sect_3;
|
||||
} ggml_metal_kargs_rope;
|
||||
|
||||
typedef struct {
|
||||
@ -299,21 +303,42 @@ typedef struct {
|
||||
} ggml_metal_kargs_mul_mv_ext;
|
||||
|
||||
typedef struct {
|
||||
int32_t nei0;
|
||||
int32_t nei1;
|
||||
uint64_t nbi1;
|
||||
int32_t ne10;
|
||||
int32_t ne11; // n_expert_used (bcast)
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
int32_t neh11; // n_tokens
|
||||
uint64_t nbh11;
|
||||
int32_t ne20; // n_expert_used
|
||||
uint64_t nb21;
|
||||
} ggml_metal_kargs_mul_mm_id_map0;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne20; // n_expert_used
|
||||
int32_t neh0;
|
||||
int32_t neh1;
|
||||
uint64_t nbh1;
|
||||
uint64_t nbh2;
|
||||
int32_t ne0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
} ggml_metal_kargs_mul_mm_id_map1;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne02;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
int32_t ne11;
|
||||
int32_t ne12;
|
||||
int32_t ne13;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
uint64_t nb03;
|
||||
int32_t neh12;
|
||||
uint64_t nbh10;
|
||||
uint64_t nbh11;
|
||||
uint64_t nbh12;
|
||||
uint64_t nbh13;
|
||||
int32_t neh0;
|
||||
int32_t neh1;
|
||||
int16_t r2;
|
||||
int16_t r3;
|
||||
} ggml_metal_kargs_mul_mm_id;
|
||||
|
||||
typedef struct {
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -2713,8 +2713,148 @@ kernel void kernel_rope_neox(
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_multi(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < args.n_dims) {
|
||||
const int ic = i0/2;
|
||||
|
||||
// mrope theta calculations
|
||||
// note: the rest is the same as kernel_rope_neox
|
||||
const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
|
||||
const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
|
||||
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float theta_base;
|
||||
if (sector < args.sect_0) {
|
||||
theta_base = (float) pos[i2];
|
||||
} else if (sector < sec_w01) {
|
||||
theta_base = (float) pos[i2 + args.ne02];
|
||||
} else if (sector < sec_w012) {
|
||||
theta_base = (float) pos[i2 + args.ne02 * 2];
|
||||
} else {
|
||||
theta_base = (float) pos[i2 + args.ne02 * 3];
|
||||
}
|
||||
// end of mrope
|
||||
|
||||
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[args.n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
} else {
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_vision(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
|
||||
const int ic = i0/2;
|
||||
|
||||
// mrope theta calculations (only support 2 dimensions)
|
||||
const int sect_dims = args.sect_0 + args.sect_1;
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float p;
|
||||
float theta_base;
|
||||
if (sector < args.sect_1) {
|
||||
p = (float) sector;
|
||||
theta_base = (float) pos[i2];
|
||||
} else {
|
||||
p = (float) sector - args.sect_0;
|
||||
theta_base = (float) pos[i2 + args.ne02];
|
||||
}
|
||||
|
||||
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
|
||||
// end of mrope
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[args.n_dims]; // different from kernel_rope_multi
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
|
||||
} else {
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
||||
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
||||
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
|
||||
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
|
||||
|
||||
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
||||
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
||||
@ -2722,6 +2862,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
|
||||
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
||||
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
||||
|
||||
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
|
||||
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
|
||||
|
||||
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
|
||||
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
|
||||
|
||||
typedef void (im2col_t)(
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
@ -6336,127 +6482,219 @@ kernel void kernel_mul_mm(
|
||||
}
|
||||
}
|
||||
|
||||
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
|
||||
// TODO: this kernel needs to be reimplemented from scratch for better performance
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||
void kernel_mul_mm_id_impl(
|
||||
int32_t ne00,
|
||||
int32_t ne02,
|
||||
uint64_t nb01,
|
||||
uint64_t nb02,
|
||||
int32_t ne11,
|
||||
int32_t ne12,
|
||||
uint64_t nb10,
|
||||
uint64_t nb11,
|
||||
uint64_t nb12,
|
||||
int32_t ne0,
|
||||
int32_t ne1,
|
||||
int64_t ne0ne1,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
threadgroup ushort2 * rowids,
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
template<typename T4>
|
||||
kernel void kernel_mul_mm_id_map0(
|
||||
constant ggml_metal_kargs_mul_mm_id_map0 & args,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * hsrc1,
|
||||
device char * htpe,
|
||||
device char * hids,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int ide = tgpig[0]; // expert id
|
||||
|
||||
int n_all = 0;
|
||||
|
||||
device int32_t * ids_i32 = (device int32_t *) (hids);
|
||||
|
||||
for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
|
||||
device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
|
||||
|
||||
for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
|
||||
if (src2_i32[i20] != ide) {
|
||||
continue;
|
||||
}
|
||||
|
||||
device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
|
||||
device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
|
||||
|
||||
for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
|
||||
hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
|
||||
}
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
|
||||
}
|
||||
|
||||
++n_all;
|
||||
}
|
||||
}
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
device int32_t * tpe_i32 = (device int32_t *) (htpe);
|
||||
tpe_i32[ide] = n_all;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_mul_mm_id_map1(
|
||||
constant ggml_metal_kargs_mul_mm_id_map1 & args,
|
||||
device const char * hdst,
|
||||
device const char * hids,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i20 = tgpig[0]; // used expert
|
||||
const int i21 = tgpig[1]; // token
|
||||
|
||||
device const int32_t * ids_i32 = (device const int32_t *) (hids);
|
||||
device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
|
||||
|
||||
const int id = ids_i32[i21*args.ne20 + i20];
|
||||
|
||||
const int ide = id / args.neh1;
|
||||
const int idt = id % args.neh1;
|
||||
|
||||
device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
|
||||
|
||||
for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
|
||||
dst_f32x4[i0] = hdst_f32x4[i0];
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
|
||||
|
||||
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
||||
kernel void kernel_mul_mm_id(
|
||||
constant ggml_metal_kargs_mul_mm_id & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * tpe,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup half * sa = (threadgroup half *)(shmem);
|
||||
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
|
||||
threadgroup T * sa = (threadgroup T *)(shmem);
|
||||
threadgroup half * sb = (threadgroup half *)(shmem + 4096);
|
||||
|
||||
const int r0 = tgpig.y;
|
||||
const int r1 = tgpig.x;
|
||||
const int im = tgpig.z;
|
||||
|
||||
if (r1*BLOCK_SIZE_N >= ne1) return;
|
||||
device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
|
||||
|
||||
const int neh1 = tpe_i32[im];
|
||||
|
||||
if (r1*BLOCK_SIZE_N >= neh1) {
|
||||
return;
|
||||
}
|
||||
|
||||
// if this block is of 64x32 shape or smaller
|
||||
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
||||
const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||
const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
||||
|
||||
// a thread shouldn't load data outside of the matrix
|
||||
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||
|
||||
simdgroup_half8x8 ma[4];
|
||||
simdgroup_float8x8 mb[2];
|
||||
simdgroup_T8x8 ma[4];
|
||||
simdgroup_half8x8 mb[2];
|
||||
simdgroup_float8x8 mc[8];
|
||||
for (int i = 0; i < 8; i++){
|
||||
|
||||
for (short i = 0; i < 8; i++){
|
||||
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
||||
}
|
||||
|
||||
short il = (tiitg % THREAD_PER_ROW);
|
||||
|
||||
ushort offset1 = il/nl;
|
||||
const int i12 = im%args.neh12;
|
||||
const int i13 = im/args.neh12;
|
||||
|
||||
threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
|
||||
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const short offset1 = il/nl;
|
||||
|
||||
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
|
||||
device const float * y = (device const float *)(src1
|
||||
+ nb12 * id[1]
|
||||
+ nb11 * (id[0] % ne11)
|
||||
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
device const block_q * x = (device const block_q *)(src0
|
||||
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
|
||||
|
||||
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
||||
device const half * y = (device const half *)(src1
|
||||
+ args.nbh13*i13
|
||||
+ args.nbh12*i12
|
||||
+ args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
|
||||
+ args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
|
||||
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
|
||||
// load data and store to threadgroup memory
|
||||
half4x4 temp_a;
|
||||
T4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
||||
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
||||
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
||||
#pragma unroll(16)
|
||||
for (short i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
||||
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
||||
}
|
||||
|
||||
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
|
||||
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||
y += BLOCK_SIZE_K;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
||||
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
||||
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
||||
threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
||||
|
||||
#pragma unroll(BLOCK_SIZE_K/8)
|
||||
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
||||
#pragma unroll(4)
|
||||
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (short i = 0; i < 4; i++) {
|
||||
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
#pragma unroll(2)
|
||||
for (int i = 0; i < 2; i++) {
|
||||
for (short i = 0; i < 2; i++) {
|
||||
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
|
||||
}
|
||||
|
||||
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
||||
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
||||
|
||||
#pragma unroll(8)
|
||||
for (int i = 0; i < 8; i++){
|
||||
for (short i = 0; i < 8; i++){
|
||||
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
||||
}
|
||||
|
||||
lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
|
||||
lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
|
||||
device float * C = (device float *) dst +
|
||||
(BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
|
||||
(BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
|
||||
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
|
||||
}
|
||||
} else {
|
||||
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup float * temp_str = ((threadgroup float *) shmem) \
|
||||
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
||||
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sgitg == 0) {
|
||||
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
||||
threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
|
||||
int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
|
||||
|
||||
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
|
||||
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
|
||||
device float4 * D4 = (device float4 *) D;
|
||||
|
||||
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
|
||||
@ -6476,66 +6714,6 @@ void kernel_mul_mm_id_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||
kernel void kernel_mul_mm_id(
|
||||
constant ggml_metal_kargs_mul_mm_id & args,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
device const char * ids,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
const int32_t i02 = tgpig.z;
|
||||
|
||||
tgpig.z = 0;
|
||||
|
||||
device const char * src0 = src0s + i02*args.nb02;
|
||||
|
||||
// row indices
|
||||
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
|
||||
|
||||
// TODO: parallelize this loop
|
||||
int32_t _ne1 = 0;
|
||||
for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
|
||||
for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
|
||||
int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
|
||||
if (id == i02) {
|
||||
if (tiitg == 0) {
|
||||
rowids[_ne1] = ushort2(ii0, ii1);
|
||||
}
|
||||
_ne1++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
||||
args.ne00,
|
||||
args.ne02,
|
||||
args.nb01,
|
||||
args.nb02,
|
||||
args.ne11,
|
||||
args.ne12,
|
||||
args.nb10,
|
||||
args.nb11,
|
||||
args.nb12,
|
||||
args.ne0,
|
||||
_ne1,
|
||||
(int64_t)args.ne0*args.ne1,
|
||||
src0,
|
||||
src1,
|
||||
rowids,
|
||||
dst,
|
||||
shmem,
|
||||
tgpig,
|
||||
tiitg,
|
||||
sgitg);
|
||||
}
|
||||
|
||||
#define QK_NL 16
|
||||
|
||||
//
|
||||
@ -6576,63 +6754,64 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get
|
||||
// matrix-matrix multiplication
|
||||
//
|
||||
|
||||
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
|
||||
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
|
||||
//
|
||||
// indirect matrix-matrix multiplication
|
||||
//
|
||||
|
||||
typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
|
||||
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_id;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
|
||||
|
||||
//
|
||||
// matrix-vector multiplication
|
||||
|
@ -4855,8 +4855,6 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
func = ggml_cl_add;
|
||||
break;
|
||||
case GGML_OP_MUL:
|
||||
|
@ -28,16 +28,19 @@ struct ggml_opt_dataset {
|
||||
};
|
||||
|
||||
struct ggml_opt_context {
|
||||
ggml_backend_sched_t backend_sched = nullptr;
|
||||
ggml_cgraph * allocated_graph = nullptr;
|
||||
ggml_cgraph * allocated_graph_copy = nullptr;
|
||||
struct ggml_context * ctx_static = nullptr;
|
||||
struct ggml_context * ctx_static_cpu = nullptr;
|
||||
struct ggml_context * ctx_compute = nullptr;
|
||||
struct ggml_context * ctx_copy = nullptr;
|
||||
ggml_backend_buffer_t buf_static = nullptr;
|
||||
ggml_backend_buffer_t buf_static_cpu = nullptr;
|
||||
std::mt19937 rng;
|
||||
ggml_backend_sched_t backend_sched = nullptr;
|
||||
ggml_cgraph * allocated_graph = nullptr;
|
||||
ggml_cgraph * allocated_graph_copy = nullptr;
|
||||
struct ggml_context * ctx_static = nullptr;
|
||||
struct ggml_context * ctx_cpu = nullptr;
|
||||
struct ggml_context * ctx_compute = nullptr;
|
||||
struct ggml_context * ctx_copy = nullptr;
|
||||
ggml_backend_buffer_t buf_static = nullptr;
|
||||
ggml_backend_buffer_t buf_cpu = nullptr;
|
||||
std::mt19937 rng;
|
||||
enum ggml_opt_loss_type loss_type;
|
||||
enum ggml_opt_build_type build_type;
|
||||
enum ggml_opt_build_type build_type_alloc;
|
||||
|
||||
struct ggml_tensor * inputs = nullptr;
|
||||
struct ggml_tensor * outputs = nullptr;
|
||||
@ -50,6 +53,11 @@ struct ggml_opt_context {
|
||||
struct ggml_cgraph * gf = nullptr;
|
||||
struct ggml_cgraph * gb_grad = nullptr;
|
||||
struct ggml_cgraph * gb_opt = nullptr;
|
||||
bool static_graphs = false;
|
||||
bool eval_ready = false;
|
||||
std::vector<struct ggml_tensor *> grad_accs;
|
||||
std::vector<struct ggml_tensor *> grad_m;
|
||||
std::vector<struct ggml_tensor *> grad_v;
|
||||
|
||||
int64_t iter = 1;
|
||||
int32_t opt_period = 1;
|
||||
@ -73,7 +81,13 @@ struct ggml_opt_result {
|
||||
|
||||
// ====== Dataset ======
|
||||
|
||||
ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) {
|
||||
ggml_opt_dataset_t ggml_opt_dataset_init(
|
||||
enum ggml_type type_data,
|
||||
enum ggml_type type_label,
|
||||
int64_t ne_datapoint,
|
||||
int64_t ne_label,
|
||||
int64_t ndata,
|
||||
int64_t ndata_shard) {
|
||||
GGML_ASSERT(ne_datapoint > 0);
|
||||
GGML_ASSERT(ne_label >= 0);
|
||||
GGML_ASSERT(ndata > 0);
|
||||
@ -92,11 +106,11 @@ ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label,
|
||||
result->ctx = ggml_init(params);
|
||||
}
|
||||
|
||||
result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata);
|
||||
result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata);
|
||||
result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata;
|
||||
|
||||
if (ne_label > 0) {
|
||||
result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata);
|
||||
result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata);
|
||||
result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata;
|
||||
} else {
|
||||
result->labels = nullptr;
|
||||
@ -119,6 +133,10 @@ void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {
|
||||
delete dataset;
|
||||
}
|
||||
|
||||
int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) {
|
||||
return dataset->ndata;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {
|
||||
return dataset->data;
|
||||
}
|
||||
@ -144,6 +162,8 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
|
||||
GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch));
|
||||
GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));
|
||||
GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
|
||||
GGML_ASSERT( data_batch->type == dataset->data->type);
|
||||
GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type);
|
||||
|
||||
const size_t nb_data_batch = ggml_nbytes(data_batch);
|
||||
GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
|
||||
@ -171,6 +191,31 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) {
|
||||
GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
|
||||
GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
|
||||
|
||||
const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;
|
||||
|
||||
GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));
|
||||
|
||||
for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
|
||||
const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];
|
||||
|
||||
const char * ptr_data = (const char *) dataset->data->data + ishard *dataset->nbs_data;
|
||||
char * ptr_data_batch = (char *) data_batch + ishard_batch*dataset->nbs_data;
|
||||
memcpy(ptr_data_batch, ptr_data, dataset->nbs_data);
|
||||
|
||||
if (!labels_batch) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const char * ptr_labels = (const char *) dataset->labels->data + ishard *dataset->nbs_labels;
|
||||
char * ptr_labels_batch = (char *) labels_batch + ishard_batch*dataset->nbs_labels;
|
||||
memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels);
|
||||
}
|
||||
}
|
||||
|
||||
// ====== Model / Context ======
|
||||
|
||||
struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
|
||||
@ -187,17 +232,18 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
|
||||
return *((struct ggml_opt_optimizer_params *) userdata);
|
||||
}
|
||||
|
||||
struct ggml_opt_params ggml_opt_default_params(
|
||||
ggml_backend_sched_t backend_sched,
|
||||
struct ggml_context * ctx_compute,
|
||||
struct ggml_tensor * inputs,
|
||||
struct ggml_tensor * outputs,
|
||||
enum ggml_opt_loss_type loss_type) {
|
||||
return {
|
||||
/*backend_sched =*/ backend_sched,
|
||||
/*ctx_compute =*/ ctx_compute,
|
||||
/*inputs =*/ inputs,
|
||||
/*logits =*/ outputs,
|
||||
/*ctx_compute =*/ nullptr,
|
||||
/*inputs =*/ nullptr,
|
||||
/*logits =*/ nullptr,
|
||||
/*loss_type =*/ loss_type,
|
||||
/*build_type =*/ GGML_OPT_BUILD_TYPE_OPT,
|
||||
/*opt_period =*/ 1,
|
||||
@ -266,195 +312,246 @@ static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {
|
||||
return dst;
|
||||
}
|
||||
|
||||
static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
|
||||
GGML_ASSERT(graph);
|
||||
if (opt_ctx->allocated_graph == graph) {
|
||||
return;
|
||||
}
|
||||
static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
|
||||
GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
|
||||
GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
|
||||
|
||||
ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
|
||||
const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
|
||||
!(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
|
||||
|
||||
{
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE,
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ggml_free(opt_ctx->ctx_copy);
|
||||
opt_ctx->ctx_copy = ggml_init(params);
|
||||
}
|
||||
|
||||
opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
|
||||
|
||||
ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
|
||||
opt_ctx->allocated_graph = graph;
|
||||
}
|
||||
|
||||
ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
|
||||
ggml_opt_context_t result = new struct ggml_opt_context;
|
||||
result->backend_sched = params.backend_sched;
|
||||
result->ctx_compute = params.ctx_compute;
|
||||
result->inputs = params.inputs;
|
||||
result->outputs = params.outputs;
|
||||
result->opt_period = params.opt_period;
|
||||
result->get_opt_pars = params.get_opt_pars;
|
||||
result->get_opt_pars_ud = params.get_opt_pars_ud;
|
||||
|
||||
GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
|
||||
GGML_ASSERT(result->opt_period >= 1);
|
||||
|
||||
const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD ||
|
||||
(params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1);
|
||||
|
||||
ggml_set_input(result->inputs);
|
||||
ggml_set_output(result->outputs);
|
||||
|
||||
result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
|
||||
ggml_build_forward_expand(result->gf, result->outputs);
|
||||
ggml_set_input(opt_ctx->inputs);
|
||||
ggml_set_output(opt_ctx->outputs);
|
||||
|
||||
int n_param = 0;
|
||||
for (int i = 0; i < result->gf->n_nodes; ++i) {
|
||||
if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
|
||||
for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) {
|
||||
const struct ggml_tensor * node = opt_ctx->gf->nodes[i];
|
||||
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
|
||||
n_param++;
|
||||
}
|
||||
GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented");
|
||||
}
|
||||
|
||||
{
|
||||
if (!opt_ctx->ctx_static) {
|
||||
// The static context is used for:
|
||||
// - gradients (1 tensor per param if using gradient accumulation)
|
||||
// - gradients (1 per loss, 1 tensor per param if using gradient accumulation)
|
||||
// - optimizer momenta (2 tensors per param)
|
||||
// - labels
|
||||
// - loss + its gradient (up to 5 tensors)
|
||||
// - pred
|
||||
// - ncorrect (2 tensors).
|
||||
const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
|
||||
const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead();
|
||||
// - labels (if using static graphs)
|
||||
// - loss (if using static graphs, up to 5 tensors)
|
||||
// - pred (if using static graphs)
|
||||
// - ncorrect (if using static graphs, 2 tensors).
|
||||
constexpr size_t n_loss = 1;
|
||||
const size_t tensors_per_param = (accumulate ? 1 : 0) +
|
||||
(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
|
||||
const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
|
||||
const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ size_meta,
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
result->ctx_static = ggml_init(params);
|
||||
opt_ctx->ctx_static = ggml_init(params);
|
||||
}
|
||||
GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc);
|
||||
|
||||
{
|
||||
// The static cpu context is used for:
|
||||
// - optimizer parameters (1 for the entire context)
|
||||
// The cpu context is allocated statically if using static graphs, dynamically otherwise.
|
||||
// It is used for:
|
||||
// - optimizer parameters (1 shared for all optimizer invocations)
|
||||
const size_t size_meta = 1 * ggml_tensor_overhead();
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ size_meta,
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
result->ctx_static_cpu = ggml_init(params);
|
||||
ggml_free(opt_ctx->ctx_cpu);
|
||||
opt_ctx->ctx_cpu = ggml_init(params);
|
||||
|
||||
ggml_backend_buffer_free(opt_ctx->buf_cpu);
|
||||
opt_ctx->buf_cpu = nullptr;
|
||||
}
|
||||
|
||||
struct ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute;
|
||||
|
||||
switch (params.loss_type) {
|
||||
switch (opt_ctx->loss_type) {
|
||||
case GGML_OPT_LOSS_TYPE_MEAN: {
|
||||
result->loss = ggml_sum(result->ctx_static, result->outputs);
|
||||
ggml_set_name(result->loss, "loss_sum");
|
||||
const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
|
||||
result->loss = ggml_scale(result->ctx_static, result->loss, scale);
|
||||
ggml_set_name(result->loss, "loss_mean");
|
||||
result->loss_per_datapoint = true;
|
||||
opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
|
||||
ggml_set_name(opt_ctx->loss, "loss_sum");
|
||||
const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
|
||||
opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
|
||||
ggml_set_name(opt_ctx->loss, "loss_mean");
|
||||
opt_ctx->loss_per_datapoint = true;
|
||||
break;
|
||||
}
|
||||
case GGML_OPT_LOSS_TYPE_SUM: {
|
||||
result->loss = ggml_sum(result->ctx_static, result->outputs);
|
||||
ggml_set_name(result->loss, "loss_sum");
|
||||
result->loss_per_datapoint = false;
|
||||
opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
|
||||
ggml_set_name(opt_ctx->loss, "loss_sum");
|
||||
opt_ctx->loss_per_datapoint = false;
|
||||
break;
|
||||
}
|
||||
case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
|
||||
result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
|
||||
ggml_set_input(result->labels);
|
||||
ggml_set_name(result->labels, "labels");
|
||||
result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels);
|
||||
ggml_set_name(result->loss, "loss_cross_entropy");
|
||||
if (result->opt_period > 1) {
|
||||
result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period);
|
||||
ggml_set_name(result->loss, "loss_cross_entropy_scaled");
|
||||
opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
|
||||
ggml_set_input(opt_ctx->labels);
|
||||
ggml_set_name(opt_ctx->labels, "labels");
|
||||
opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels);
|
||||
ggml_set_name(opt_ctx->loss, "loss_cross_entropy");
|
||||
if (opt_ctx->opt_period > 1) {
|
||||
opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period);
|
||||
ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled");
|
||||
}
|
||||
result->loss_per_datapoint = true;
|
||||
opt_ctx->loss_per_datapoint = true;
|
||||
break;
|
||||
}
|
||||
case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
|
||||
result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
|
||||
ggml_set_input(result->labels);
|
||||
ggml_set_name(result->labels, "labels");
|
||||
result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels);
|
||||
ggml_set_name(result->loss, "loss_error");
|
||||
result->loss = ggml_sqr(result->ctx_static, result->loss);
|
||||
ggml_set_name(result->loss, "loss_squared_error");
|
||||
result->loss = ggml_sum(result->ctx_static, result->loss);
|
||||
ggml_set_name(result->loss, "loss_sum_squared_error");
|
||||
const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
|
||||
result->loss = ggml_scale(result->ctx_static, result->loss, scale);
|
||||
ggml_set_name(result->loss, "loss_mean_squared_error");
|
||||
result->loss_per_datapoint = true;
|
||||
opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
|
||||
ggml_set_input(opt_ctx->labels);
|
||||
ggml_set_name(opt_ctx->labels, "labels");
|
||||
opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels);
|
||||
ggml_set_name(opt_ctx->loss, "loss_error");
|
||||
opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss);
|
||||
ggml_set_name(opt_ctx->loss, "loss_squared_error");
|
||||
opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss);
|
||||
ggml_set_name(opt_ctx->loss, "loss_sum_squared_error");
|
||||
const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
|
||||
opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
|
||||
ggml_set_name(opt_ctx->loss, "loss_mean_squared_error");
|
||||
opt_ctx->loss_per_datapoint = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
ggml_set_output(result->loss);
|
||||
ggml_set_loss(result->loss);
|
||||
ggml_build_forward_expand(result->gf, result->loss);
|
||||
ggml_set_output(opt_ctx->loss);
|
||||
ggml_set_loss(opt_ctx->loss);
|
||||
ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss);
|
||||
|
||||
result->pred = ggml_argmax(result->ctx_static, result->outputs);
|
||||
ggml_set_name(result->pred, "pred");
|
||||
ggml_set_output(result->pred);
|
||||
ggml_build_forward_expand(result->gf, result->pred);
|
||||
if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) {
|
||||
opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs);
|
||||
ggml_set_name(opt_ctx->pred, "pred");
|
||||
ggml_set_output(opt_ctx->pred);
|
||||
ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred);
|
||||
|
||||
if (result->labels) {
|
||||
result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels));
|
||||
ggml_set_name(result->ncorrect, "ncorrect");
|
||||
ggml_set_output(result->ncorrect);
|
||||
ggml_build_forward_expand(result->gf, result->ncorrect);
|
||||
} else {
|
||||
result->ncorrect = nullptr;
|
||||
opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels));
|
||||
ggml_set_name(opt_ctx->ncorrect, "ncorrect");
|
||||
ggml_set_output(opt_ctx->ncorrect);
|
||||
ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect);
|
||||
}
|
||||
|
||||
if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
|
||||
result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
|
||||
return result;
|
||||
if (opt_ctx->buf_static) {
|
||||
if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
|
||||
return;
|
||||
}
|
||||
} else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) {
|
||||
opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
|
||||
opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
|
||||
return;
|
||||
}
|
||||
|
||||
// gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
|
||||
result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf);
|
||||
ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
|
||||
if (opt_ctx->grad_accs.empty()) {
|
||||
GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD);
|
||||
|
||||
if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
|
||||
result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
|
||||
ggml_graph_reset(result->gb_grad);
|
||||
return result;
|
||||
}
|
||||
const int n_nodes = opt_ctx->gf->n_nodes;
|
||||
opt_ctx->grad_accs.resize(n_nodes);
|
||||
for (int i = 0; i < n_nodes; ++i) {
|
||||
ggml_tensor * node = opt_ctx->gf->nodes[i];
|
||||
if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
|
||||
opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
|
||||
} else {
|
||||
opt_ctx->grad_accs[i] = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT);
|
||||
|
||||
// gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
|
||||
result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad);
|
||||
|
||||
result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7);
|
||||
ggml_set_input(result->adamw_params);
|
||||
ggml_set_name(result->adamw_params, "adamw_params");
|
||||
|
||||
for (int i = result->gf->n_nodes-1; i >= 0; --i) {
|
||||
struct ggml_tensor * node = result->gb_opt->nodes[i];
|
||||
struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node);
|
||||
|
||||
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
|
||||
struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node);
|
||||
struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node);
|
||||
struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params);
|
||||
ggml_build_forward_expand(result->gb_opt, opt_step);
|
||||
if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
|
||||
opt_ctx->grad_m.resize(n_nodes);
|
||||
opt_ctx->grad_v.resize(n_nodes);
|
||||
for (int i = 0; i < n_nodes; ++i) {
|
||||
ggml_tensor * node = opt_ctx->gf->nodes[i];
|
||||
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
|
||||
opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
|
||||
opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
|
||||
} else {
|
||||
opt_ctx->grad_m[i] = nullptr;
|
||||
opt_ctx->grad_v[i] = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result->buf_static = ggml_backend_alloc_ctx_tensors(
|
||||
result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
|
||||
// gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
|
||||
opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
|
||||
ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());
|
||||
|
||||
result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());
|
||||
if (opt_ctx->buf_static) {
|
||||
if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) {
|
||||
return;
|
||||
}
|
||||
} else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) {
|
||||
opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
|
||||
ggml_graph_reset(opt_ctx->gb_grad);
|
||||
}
|
||||
|
||||
ggml_graph_reset(result->gb_opt);
|
||||
GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT);
|
||||
|
||||
// gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
|
||||
opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
|
||||
|
||||
opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
|
||||
ggml_set_input(opt_ctx->adamw_params);
|
||||
ggml_set_name(opt_ctx->adamw_params, "adamw_params");
|
||||
|
||||
for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
|
||||
struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
|
||||
struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
|
||||
|
||||
if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
|
||||
struct ggml_tensor * m = opt_ctx->grad_m[i];
|
||||
struct ggml_tensor * v = opt_ctx->grad_v[i];
|
||||
struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
|
||||
|
||||
ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
|
||||
ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
|
||||
ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
|
||||
|
||||
ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
|
||||
}
|
||||
}
|
||||
|
||||
if (!opt_ctx->buf_static) {
|
||||
opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
|
||||
opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
|
||||
ggml_graph_reset(opt_ctx->gb_opt);
|
||||
}
|
||||
|
||||
opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type());
|
||||
}
|
||||
|
||||
ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
|
||||
ggml_opt_context_t result = new struct ggml_opt_context;
|
||||
result->backend_sched = params.backend_sched;
|
||||
result->ctx_compute = params.ctx_compute;
|
||||
result->loss_type = params.loss_type;
|
||||
result->build_type = params.build_type;
|
||||
result->build_type_alloc = params.build_type;
|
||||
result->inputs = params.inputs;
|
||||
result->outputs = params.outputs;
|
||||
result->opt_period = params.opt_period;
|
||||
result->get_opt_pars = params.get_opt_pars;
|
||||
result->get_opt_pars_ud = params.get_opt_pars_ud;
|
||||
|
||||
GGML_ASSERT(result->opt_period >= 1);
|
||||
|
||||
result->static_graphs = result->ctx_compute;
|
||||
|
||||
if (!result->static_graphs) {
|
||||
GGML_ASSERT(!result->inputs);
|
||||
GGML_ASSERT(!result->outputs);
|
||||
return result;
|
||||
}
|
||||
|
||||
GGML_ASSERT(result->inputs);
|
||||
GGML_ASSERT(result->outputs);
|
||||
|
||||
result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
|
||||
ggml_build_forward_expand(result->gf, result->outputs);
|
||||
|
||||
ggml_opt_build(result);
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -464,9 +561,9 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
|
||||
return;
|
||||
}
|
||||
ggml_backend_buffer_free(opt_ctx->buf_static);
|
||||
ggml_backend_buffer_free(opt_ctx->buf_static_cpu);
|
||||
ggml_backend_buffer_free(opt_ctx->buf_cpu);
|
||||
ggml_free(opt_ctx->ctx_static);
|
||||
ggml_free(opt_ctx->ctx_static_cpu);
|
||||
ggml_free(opt_ctx->ctx_cpu);
|
||||
delete opt_ctx;
|
||||
}
|
||||
|
||||
@ -582,8 +679,79 @@ void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, doubl
|
||||
|
||||
// ====== Computation ======
|
||||
|
||||
static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) {
|
||||
if (graph != opt_ctx->gf) {
|
||||
void ggml_opt_prepare_alloc(
|
||||
ggml_opt_context_t opt_ctx,
|
||||
struct ggml_context * ctx_compute,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_tensor * inputs,
|
||||
struct ggml_tensor * outputs) {
|
||||
GGML_ASSERT(!opt_ctx->static_graphs);
|
||||
opt_ctx->ctx_compute = ctx_compute;
|
||||
opt_ctx->gf = gf;
|
||||
opt_ctx->inputs = inputs;
|
||||
opt_ctx->outputs = outputs;
|
||||
}
|
||||
|
||||
void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
|
||||
GGML_ASSERT(!opt_ctx->eval_ready);
|
||||
if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {
|
||||
ggml_graph_reset(opt_ctx->gb_grad);
|
||||
}
|
||||
if (backward) {
|
||||
const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
|
||||
opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD;
|
||||
} else {
|
||||
opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD;
|
||||
}
|
||||
|
||||
if (!opt_ctx->static_graphs) {
|
||||
ggml_opt_build(opt_ctx);
|
||||
}
|
||||
|
||||
struct ggml_cgraph * graph = nullptr;
|
||||
switch (opt_ctx->build_type) {
|
||||
case GGML_OPT_BUILD_TYPE_FORWARD: {
|
||||
graph = opt_ctx->gf;
|
||||
} break;
|
||||
case GGML_OPT_BUILD_TYPE_GRAD: {
|
||||
graph = opt_ctx->gb_grad;
|
||||
} break;
|
||||
case GGML_OPT_BUILD_TYPE_OPT: {
|
||||
graph = opt_ctx->gb_opt;
|
||||
} break;
|
||||
}
|
||||
GGML_ASSERT(graph);
|
||||
|
||||
if (opt_ctx->allocated_graph == graph) {
|
||||
opt_ctx->eval_ready = true;
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
|
||||
|
||||
if (opt_ctx->static_graphs) {
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads),
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ggml_free(opt_ctx->ctx_copy);
|
||||
opt_ctx->ctx_copy = ggml_init(params);
|
||||
|
||||
opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
|
||||
} else {
|
||||
opt_ctx->allocated_graph_copy = graph;
|
||||
}
|
||||
|
||||
ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
|
||||
opt_ctx->allocated_graph = graph;
|
||||
|
||||
opt_ctx->eval_ready = true;
|
||||
}
|
||||
|
||||
void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
|
||||
GGML_ASSERT(opt_ctx->eval_ready);
|
||||
if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
|
||||
struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
|
||||
|
||||
GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
|
||||
@ -609,9 +777,19 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
|
||||
adamw_par_data[6] = beta2h;
|
||||
}
|
||||
|
||||
ggml_opt_alloc_graph(opt_ctx, graph);
|
||||
ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
|
||||
opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
|
||||
opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
|
||||
|
||||
if (!opt_ctx->static_graphs) {
|
||||
opt_ctx->gf = nullptr;
|
||||
opt_ctx->gb_grad = nullptr;
|
||||
opt_ctx->gb_opt = nullptr;
|
||||
opt_ctx->allocated_graph = nullptr;
|
||||
opt_ctx->allocated_graph_copy = nullptr;
|
||||
}
|
||||
|
||||
opt_ctx->eval_ready = false;
|
||||
|
||||
if (!result) {
|
||||
return;
|
||||
@ -635,12 +813,14 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
|
||||
ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss));
|
||||
result->loss.push_back(loss);
|
||||
|
||||
GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
|
||||
std::vector<int32_t> pred(ndata);
|
||||
ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
|
||||
result->pred.insert(result->pred.end(), pred.begin(), pred.end());
|
||||
if (opt_ctx->pred) {
|
||||
GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
|
||||
std::vector<int32_t> pred(ndata);
|
||||
ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
|
||||
result->pred.insert(result->pred.end(), pred.begin(), pred.end());
|
||||
}
|
||||
|
||||
if (!opt_ctx->labels || result->ncorrect < 0) {
|
||||
if (!opt_ctx->ncorrect || result->ncorrect < 0) {
|
||||
result->ncorrect = -1;
|
||||
return;
|
||||
}
|
||||
@ -652,26 +832,6 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
|
||||
result->ncorrect += ncorrect;
|
||||
}
|
||||
|
||||
void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
|
||||
ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
|
||||
}
|
||||
|
||||
void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
|
||||
if (opt_ctx->opt_period == 1) {
|
||||
ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
|
||||
if (opt_i_next == 0) {
|
||||
ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
|
||||
ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
|
||||
} else {
|
||||
ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
|
||||
}
|
||||
opt_ctx->opt_i = opt_i_next;
|
||||
}
|
||||
|
||||
// ====== High-Level Functions ======
|
||||
|
||||
void ggml_opt_epoch(
|
||||
@ -700,16 +860,18 @@ void ggml_opt_epoch(
|
||||
int64_t ibatch = 0;
|
||||
int64_t t_loop_start = ggml_time_us();
|
||||
for (; ibatch < ibatch_split; ++ibatch) {
|
||||
ggml_opt_alloc(opt_ctx, /*backward =*/ true);
|
||||
ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
|
||||
ggml_opt_forward_backward(opt_ctx, result_train);
|
||||
ggml_opt_eval(opt_ctx, result_train);
|
||||
if (callback_train) {
|
||||
callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
|
||||
}
|
||||
}
|
||||
t_loop_start = ggml_time_us();
|
||||
for (; ibatch < nbatches; ++ibatch) {
|
||||
ggml_opt_alloc(opt_ctx, /*backward =*/ false);
|
||||
ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
|
||||
ggml_opt_forward(opt_ctx, result_eval);
|
||||
ggml_opt_eval(opt_ctx, result_eval);
|
||||
if (callback_eval) {
|
||||
callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
|
||||
}
|
||||
@ -726,13 +888,26 @@ void ggml_opt_epoch_callback_progress_bar(
|
||||
int64_t t_start_us) {
|
||||
fprintf(stderr, "%s[", train ? "train: " : "val: ");
|
||||
|
||||
constexpr int64_t bar_length = 25;
|
||||
// The progress bar consists of partially filled blocks, unicode has 8 separate fill levels.
|
||||
constexpr int64_t bar_length = 8;
|
||||
const int64_t ibatch8 = 8 * ibatch;
|
||||
for (int64_t j = 0; j < bar_length; ++j) {
|
||||
const int64_t ibatch_j = ibatch_max * j/bar_length;
|
||||
if (ibatch_j < ibatch) {
|
||||
fprintf(stderr, "=");
|
||||
} else if (ibatch_max * (j - 1)/bar_length < ibatch) {
|
||||
fprintf(stderr, ">");
|
||||
if (ibatch_max * (8*j + 8) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u2588"); // full block
|
||||
} else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u2589"); // 7/8 filled
|
||||
} else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u258A"); // 6/8 filled
|
||||
} else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u258B"); // 5/8 filled
|
||||
} else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u258C"); // 4/8 filled
|
||||
} else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u258D"); // 3/8 filled
|
||||
} else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u258E"); // 2/8 filled
|
||||
} else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u258F"); // 1/8 filled
|
||||
} else {
|
||||
fprintf(stderr, " ");
|
||||
}
|
||||
@ -764,8 +939,8 @@ void ggml_opt_epoch_callback_progress_bar(
|
||||
const int64_t t_eta_m = t_eta_s / 60;
|
||||
t_eta_s -= t_eta_m * 60;
|
||||
|
||||
fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, "
|
||||
"t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
|
||||
fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% "
|
||||
"t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r",
|
||||
idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
|
||||
t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
|
||||
if (ibatch == ibatch_max) {
|
||||
@ -806,7 +981,10 @@ void ggml_opt_fit(
|
||||
|
||||
int64_t epoch = 1;
|
||||
|
||||
ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
|
||||
ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type);
|
||||
params.ctx_compute = ctx_compute;
|
||||
params.inputs = inputs;
|
||||
params.outputs = outputs;
|
||||
params.opt_period = opt_period;
|
||||
params.get_opt_pars = get_opt_pars;
|
||||
params.get_opt_pars_ud = &epoch;
|
||||
|
@ -151,6 +151,12 @@ struct rpc_msg_buffer_clear_req {
|
||||
uint8_t value;
|
||||
};
|
||||
|
||||
struct rpc_msg_set_tensor_hash_req {
|
||||
rpc_tensor tensor;
|
||||
uint64_t offset;
|
||||
uint64_t hash;
|
||||
};
|
||||
|
||||
struct rpc_msg_set_tensor_hash_rsp {
|
||||
uint8_t result;
|
||||
};
|
||||
@ -548,15 +554,12 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
|
||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||
rpc_tensor rpc_tensor = serialize_tensor(tensor);
|
||||
if (size > HASH_THRESHOLD) {
|
||||
// input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
|
||||
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + sizeof(uint64_t);
|
||||
std::vector<uint8_t> input(input_size, 0);
|
||||
uint64_t hash = fnv_hash((const uint8_t*)data, size);
|
||||
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
|
||||
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
||||
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &hash, sizeof(hash));
|
||||
rpc_msg_set_tensor_hash_req request;
|
||||
request.tensor = rpc_tensor;
|
||||
request.offset = offset;
|
||||
request.hash = fnv_hash((const uint8_t*)data, size);
|
||||
rpc_msg_set_tensor_hash_rsp response;
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, input.data(), input.size(), &response, sizeof(response));
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
|
||||
GGML_ASSERT(status);
|
||||
if (response.result) {
|
||||
// the server has the same data, no need to send it
|
||||
@ -864,7 +867,7 @@ public:
|
||||
bool free_buffer(const rpc_msg_free_buffer_req & request);
|
||||
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
|
||||
bool set_tensor(const std::vector<uint8_t> & input);
|
||||
bool set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response);
|
||||
bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
|
||||
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
|
||||
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
|
||||
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
|
||||
@ -1101,18 +1104,10 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response)
|
||||
bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response)
|
||||
{
|
||||
// serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
|
||||
if (input.size() != sizeof(rpc_tensor) + 16) {
|
||||
return false;
|
||||
}
|
||||
const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
|
||||
uint64_t offset;
|
||||
memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
|
||||
const uint64_t * hash = (const uint64_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset));
|
||||
std::vector<uint8_t> cached_file;
|
||||
if (!get_cached_file(*hash, cached_file)) {
|
||||
if (!get_cached_file(request.hash, cached_file)) {
|
||||
response.result = 0;
|
||||
return true;
|
||||
}
|
||||
@ -1125,25 +1120,28 @@ bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set
|
||||
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
||||
GGML_ASSERT(ctx_ptr != nullptr);
|
||||
ggml_context * ctx = ctx_ptr.get();
|
||||
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
|
||||
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
||||
if (tensor == nullptr) {
|
||||
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
|
||||
return false;
|
||||
}
|
||||
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash);
|
||||
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
|
||||
__func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
|
||||
|
||||
// sanitize tensor->data
|
||||
{
|
||||
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
|
||||
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
|
||||
|
||||
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
|
||||
if (request.tensor.data + request.offset < p0
|
||||
|| request.tensor.data + request.offset >= p1
|
||||
|| size > (p1 - request.tensor.data - request.offset)) {
|
||||
GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
|
||||
__func__, in_tensor->data, offset, size, *hash, p0, p1);
|
||||
__func__, request.tensor.data, request.offset, size, request.hash, p0, p1);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
ggml_backend_tensor_set(tensor, cached_file.data(), offset, size);
|
||||
ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size);
|
||||
response.result = 1;
|
||||
return true;
|
||||
}
|
||||
@ -1503,12 +1501,12 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_SET_TENSOR_HASH: {
|
||||
std::vector<uint8_t> input;
|
||||
if (!recv_msg(sockfd, input)) {
|
||||
rpc_msg_set_tensor_hash_req request;
|
||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_set_tensor_hash_rsp response;
|
||||
if (!server.set_tensor_hash(input, response)) {
|
||||
if (!server.set_tensor_hash(request, response)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
|
@ -52,9 +52,8 @@ target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
|
||||
find_package(DNNL)
|
||||
set(GGML_SYCL_DNNL 0)
|
||||
if(DNNL_FOUND)
|
||||
if (DEFINED ENV{ONEAPI_ROOT} AND NOT DEFINED DNNL_GPU_VENDOR)
|
||||
# Assuming oneDNN packaged with oneapi release is used which
|
||||
# supports only intel target
|
||||
if (NOT DEFINED DNNL_GPU_VENDOR)
|
||||
# default to intel target
|
||||
set(DNNL_GPU_VENDOR "INTEL")
|
||||
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
|
||||
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
|
||||
@ -108,6 +107,9 @@ endif()
|
||||
if (GGML_SYCL_TARGET STREQUAL "INTEL")
|
||||
# Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically
|
||||
# See https://github.com/uxlfoundation/oneMath/issues/654
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
set(SYCL_COMPILER ON)
|
||||
endif()
|
||||
find_package(MKL REQUIRED)
|
||||
target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS)
|
||||
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL)
|
||||
|
@ -14,23 +14,24 @@
|
||||
#define GGML_SYCL_BACKEND_HPP
|
||||
|
||||
#include "binbcast.hpp"
|
||||
#include "concat.hpp"
|
||||
#include "common.hpp"
|
||||
#include "concat.hpp"
|
||||
#include "conv.hpp"
|
||||
#include "convert.hpp"
|
||||
#include "cpy.hpp"
|
||||
#include "dequantize.hpp"
|
||||
#include "dmmv.hpp"
|
||||
#include "element_wise.hpp"
|
||||
#include "gla.hpp"
|
||||
#include "im2col.hpp"
|
||||
#include "mmq.hpp"
|
||||
#include "mmvq.hpp"
|
||||
#include "rope.hpp"
|
||||
#include "norm.hpp"
|
||||
#include "outprod.hpp"
|
||||
#include "quants.hpp"
|
||||
#include "rope.hpp"
|
||||
#include "softmax.hpp"
|
||||
#include "tsembd.hpp"
|
||||
#include "im2col.hpp"
|
||||
#include "wkv.hpp"
|
||||
#include "outprod.hpp"
|
||||
#include "element_wise.hpp"
|
||||
#include "cpy.hpp"
|
||||
#include "gla.hpp"
|
||||
|
||||
#endif // GGML_SYCL_BACKEND_HPP
|
||||
#endif // GGML_SYCL_BACKEND_HPP
|
||||
|
@ -42,6 +42,7 @@ void ggml_sycl_host_free(void* ptr);
|
||||
|
||||
extern int g_ggml_sycl_debug;
|
||||
extern int g_ggml_sycl_disable_optimize;
|
||||
extern int g_ggml_sycl_prioritize_dmmv;
|
||||
|
||||
#define GGML_SYCL_DEBUG(...) \
|
||||
do { \
|
||||
@ -114,17 +115,12 @@ static void crash() {
|
||||
GGML_ABORT("SYCL error");
|
||||
}
|
||||
|
||||
#define SYCL_CHECK(err) \
|
||||
do { \
|
||||
auto err_ = (err); \
|
||||
if (err_ != 0) \
|
||||
ggml_sycl_error( \
|
||||
#err, \
|
||||
__func__, \
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
"Meet error in this line code!"); \
|
||||
} while (0)
|
||||
#define SYCL_CHECK(err) \
|
||||
do { \
|
||||
auto err_ = (err); \
|
||||
if (err_ != 0) \
|
||||
ggml_sycl_error(#err, __func__, __FILE__, __LINE__, "Exception caught in this line of code."); \
|
||||
} while (0)
|
||||
|
||||
#if DPCT_COMPAT_RT_VERSION >= 11100
|
||||
#define GGML_SYCL_ASSUME(x) __builtin_assume(x)
|
||||
|
@ -437,41 +437,52 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
|
||||
}
|
||||
|
||||
template <typename src_t, typename dst_t>
|
||||
static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
|
||||
const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const sycl::nd_item<3> & item_ct1) {
|
||||
|
||||
const int64_t work_group_size = item_ct1.get_local_range(2);
|
||||
const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
|
||||
const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
|
||||
|
||||
const int64_t i01 = item_ct1.get_group(1);
|
||||
const int64_t i02 = item_ct1.get_group(0) % ne02;
|
||||
const int64_t i03 = item_ct1.get_group(0) / ne02;
|
||||
|
||||
// make each work-item deal with more elements since sycl global range can not exceed max int
|
||||
const src_t * x = (const src_t *) vx;
|
||||
for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
|
||||
y[i] = x[i];
|
||||
const src_t * x = static_cast<const src_t *>(vx);
|
||||
const int64_t ix = i03 * s03 + i02 * s02 + i01 * s01;
|
||||
const int64_t iy = ((i03 * ne02 + i02) * ne01 + i01) * ne00;
|
||||
|
||||
#pragma unroll
|
||||
for (int64_t i00 = global_id; i00 < ne00; i00 += work_group_size * item_ct1.get_group_range(2)) {
|
||||
y[iy + i00] = static_cast<dst_t>(x[ix + i00]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename src_t, typename dst_t>
|
||||
static void convert_unary_sycl(const void *__restrict__ vx,
|
||||
dst_t *__restrict__ y, const int64_t k,
|
||||
dpct::queue_ptr stream) {
|
||||
const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
|
||||
static void convert_unary_nc_sycl(const void * __restrict__ vx, dst_t * __restrict__ y,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03, dpct::queue_ptr queue) {
|
||||
dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
sycl::range<3> global_size(ne02 * ne03, ne01, ceil_div(ne00, SYCL_DEQUANTIZE_BLOCK_SIZE));
|
||||
|
||||
// decrease global range when it exceeds the max int
|
||||
int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE);
|
||||
sycl::range<3> block_nums(1, 1, num_blocks);
|
||||
sycl::range<3> local_range(1, 1, local_size);
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
// TODO: Downsample logic is separated from the kernel, a rewrite is desirable
|
||||
int64_t downsized_workgroup = downsample_sycl_global_range(global_size[0], SYCL_DEQUANTIZE_BLOCK_SIZE);
|
||||
sycl::range<3> workgroup_size(1, 1, downsized_workgroup);
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * local_range, local_range),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
convert_unary<src_t>(vx, y, k, item_ct1);
|
||||
});
|
||||
}
|
||||
queue->parallel_for(sycl::nd_range<3>(global_size * workgroup_size, workgroup_size), [=](sycl::nd_item<3> item_ct1) {
|
||||
convert_unary_nc<src_t>(vx, y, ne00, ne01, ne02, s01, s02, s03, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) {
|
||||
template <typename src_t, typename dst_t>
|
||||
static void convert_unary_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr queue) {
|
||||
convert_unary_nc_sycl<src_t>(vx, y, k, 1, 1, 1, k, k, k, queue);
|
||||
}
|
||||
|
||||
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
if (dst->src[0]->extra &&
|
||||
@ -574,3 +585,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_F32:
|
||||
return convert_unary_nc_sycl<float>;
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2024 Intel Corporation
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
@ -16,12 +16,19 @@
|
||||
#include "common.hpp"
|
||||
|
||||
template <typename T>
|
||||
using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
|
||||
int64_t k, dpct::queue_ptr stream);
|
||||
typedef to_t_sycl_t<float> to_fp32_sycl_t;
|
||||
using to_t_sycl_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, dpct::queue_ptr stream);
|
||||
typedef to_t_sycl_t<float> to_fp32_sycl_t;
|
||||
typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
|
||||
|
||||
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst);
|
||||
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst);
|
||||
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst);
|
||||
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_CONVERT_HPP
|
||||
// Nc = Non-contiguous
|
||||
template <typename T>
|
||||
using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
|
||||
int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue);
|
||||
|
||||
typedef to_t_nc_sycl_t<sycl::half> to_fp16_nc_sycl_t;
|
||||
to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type);
|
||||
|
||||
#endif // GGML_SYCL_CONVERT_HPP
|
||||
|
@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
|
||||
int g_ggml_sycl_debug = 0;
|
||||
int g_ggml_sycl_disable_optimize = 0;
|
||||
int g_ggml_sycl_disable_graph = 0;
|
||||
int g_ggml_sycl_prioritize_dmmv = 0;
|
||||
|
||||
static ggml_sycl_device_info ggml_sycl_init() {
|
||||
ggml_sycl_device_info info = {};
|
||||
@ -195,11 +196,13 @@ static void ggml_check_sycl() try {
|
||||
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
|
||||
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
|
||||
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
|
||||
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
|
||||
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
|
||||
GGML_LOG_INFO("Running with Environment Variables:\n");
|
||||
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
|
||||
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
|
||||
GGML_LOG_INFO("Build with Macros:\n");
|
||||
#if defined(GGML_SYCL_FORCE_MMQ)
|
||||
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
|
||||
@ -2694,35 +2697,31 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
|
||||
const sycl::half *src1_as_f16, char *dst,
|
||||
const void **ptrs_src, void **ptrs_dst,
|
||||
int64_t ne12, int64_t ne13, int64_t ne23,
|
||||
size_t nb02, size_t nb03, size_t nb12,
|
||||
size_t nb13, size_t nbd2, size_t nbd3,
|
||||
int64_t r2, int64_t r3,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
|
||||
item_ct1.get_local_id(1);
|
||||
static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
|
||||
const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
|
||||
size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
|
||||
int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
|
||||
const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
|
||||
const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
|
||||
|
||||
if (i13 >= ne13 || i12 >= ne12) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t i03 = i13 / r3;
|
||||
int64_t i02 = i12 / r2;
|
||||
const int64_t i03 = i13 / r3;
|
||||
const int64_t i02 = i12 / r2;
|
||||
|
||||
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
|
||||
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
|
||||
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
|
||||
const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
|
||||
const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
|
||||
uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);
|
||||
|
||||
ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
|
||||
ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
|
||||
ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
|
||||
}
|
||||
|
||||
static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
||||
const ggml_tensor *src0,
|
||||
const ggml_tensor *src1,
|
||||
ggml_tensor *dst) try {
|
||||
static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
|
||||
const ggml_tensor * src1, ggml_tensor * dst) try {
|
||||
GGML_ASSERT(!ggml_is_transposed(src0));
|
||||
GGML_ASSERT(!ggml_is_transposed(src1));
|
||||
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
|
||||
@ -2730,103 +2729,107 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
// TODO: see https://github.com/ggml-org/llama.cpp/pull/13155
|
||||
// Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
queue_ptr main_stream = ctx.stream();;
|
||||
queue_ptr queue = ctx.stream();
|
||||
|
||||
void * src0_ddq = src0->data;
|
||||
sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
|
||||
float * src1_ddf = (float *) src1->data;
|
||||
float * dst_ddf = (float *) dst->data;
|
||||
dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
|
||||
float * dst_ddf = static_cast<float *>(dst->data);
|
||||
|
||||
const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
|
||||
const size_t type_size_src1 = ggml_type_size(src1->type);
|
||||
GGML_ASSERT(nb10 == type_size_src1);
|
||||
|
||||
// SRC1 strides
|
||||
int64_t s11 = nb11 / type_size_src1;
|
||||
int64_t s12 = nb12 / type_size_src1;
|
||||
int64_t s13 = nb13 / type_size_src1;
|
||||
ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
|
||||
|
||||
// convert src1 to fp16
|
||||
ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
|
||||
if (src1->type != GGML_TYPE_F16) {
|
||||
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
|
||||
const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
|
||||
GGML_ASSERT(to_fp16_nc_sycl != nullptr);
|
||||
const int64_t ne_src1 = ggml_nelements(src1);
|
||||
src1_f16_alloc.alloc(ne_src1);
|
||||
GGML_ASSERT(to_fp16_sycl != nullptr);
|
||||
to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
|
||||
to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
|
||||
|
||||
src1_f16 = src1_f16_alloc.get();
|
||||
s11 = ne10;
|
||||
s12 = ne11 * s11;
|
||||
s13 = ne12 * s12;
|
||||
}
|
||||
sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
|
||||
: src1_f16_alloc.get();
|
||||
|
||||
char * dst_t;
|
||||
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
|
||||
char * dst_t = reinterpret_cast<char *>(dst_ddf);
|
||||
|
||||
dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
|
||||
dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
|
||||
dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
|
||||
dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
|
||||
|
||||
// dst strides
|
||||
size_t nbd2 = dst->nb[2];
|
||||
size_t nbd3 = dst->nb[3];
|
||||
|
||||
const float alpha_f32 = 1.0f;
|
||||
const float beta_f32 = 0.0f;
|
||||
const float beta_f32 = 0.0f;
|
||||
|
||||
const void * alpha = &alpha_f32;
|
||||
const void * beta = &beta_f32;
|
||||
|
||||
dst_t = (char *) dst_ddf;
|
||||
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
GGML_ASSERT(ne13 % ne03 == 0);
|
||||
|
||||
// broadcast factors
|
||||
const int64_t r2 = ne12/ne02;
|
||||
const int64_t r3 = ne13/ne03;
|
||||
const int64_t r2 = ne12 / ne02;
|
||||
const int64_t r3 = ne13 / ne03;
|
||||
|
||||
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
||||
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
||||
*main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
(const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
|
||||
(const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
|
||||
cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
|
||||
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
|
||||
src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
|
||||
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
|
||||
} else {
|
||||
const int ne23 = ne12*ne13;
|
||||
const int ne23 = ne12 * ne13;
|
||||
|
||||
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
|
||||
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
|
||||
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
|
||||
ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
|
||||
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
|
||||
|
||||
sycl::range<3> block_dims(1, ne12, ne13);
|
||||
/*
|
||||
DPCT1049:47: The work-group size passed to the SYCL kernel may exceed
|
||||
the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
{
|
||||
dpct::has_capability_or_fail(main_stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
main_stream->submit([&](sycl::handler &cgh) {
|
||||
const void **ptrs_src_get = ptrs_src.get();
|
||||
void **ptrs_dst_get = ptrs_dst.get();
|
||||
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
|
||||
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
|
||||
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_compute_batched_ptrs(
|
||||
src0_as_f16, src1_f16,
|
||||
dst_t, ptrs_src_get,
|
||||
ptrs_dst_get, ne12, ne13, ne23,
|
||||
nb02, nb03, nb12_scaled, nb13_scaled,
|
||||
nbd2, nbd3, r2, r3, item_ct1);
|
||||
});
|
||||
queue->submit([&](sycl::handler & cgh) {
|
||||
const void ** ptrs_src_get = ptrs_src.get();
|
||||
void ** ptrs_dst_get = ptrs_dst.get();
|
||||
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
|
||||
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
|
||||
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
|
||||
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
||||
*main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
*queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
|
||||
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
|
||||
(void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
|
||||
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
|
||||
(void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
|
||||
}
|
||||
} catch (const sycl::exception & exc) {
|
||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
|
||||
std::exit(1);
|
||||
}
|
||||
catch (sycl::exception const &exc) {
|
||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
||||
<< ", line:" << __LINE__ << std::endl;
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
enum class mul_mat_algo {
|
||||
DMMV = 0,
|
||||
MMVQ = 1,
|
||||
MUL_MAT_SYCL = 2,
|
||||
};
|
||||
|
||||
inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
|
||||
// TODO: accuracy issues in MMQ
|
||||
@ -2834,6 +2837,33 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
@ -2862,7 +2892,7 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
|
||||
GGML_ASSERT((size % sizeof(block_q4_0) == 0));
|
||||
GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
|
||||
int offset_blks = offset / sizeof(block_q4_0);
|
||||
auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;;
|
||||
auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
|
||||
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
|
||||
|
||||
stream->parallel_for(
|
||||
@ -2890,25 +2920,44 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
reorder_qw(data_device, ncols, nrows, size, 0, stream);
|
||||
}
|
||||
|
||||
/*
|
||||
* This function could be called when the OP (mul_mat) function support reorder optimizition.
|
||||
*/
|
||||
static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
if (!g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
|
||||
ctx->opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf.
|
||||
dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases.
|
||||
src0->type == GGML_TYPE_Q4_0 &&
|
||||
src1->ne[2]==1 && src1->ne[3]==1) {
|
||||
static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
|
||||
return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
|
||||
ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf.
|
||||
dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases.
|
||||
dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
|
||||
if (!extra) return; //only happen in CI/UT permute case.
|
||||
|
||||
if (extra->optimized_feature.reorder) return; //skip the tensor which is handled for reorder.
|
||||
|
||||
reorder_qw(src0, ctx->stream());
|
||||
extra->optimized_feature.reorder = true; //used to decode/dequan in next steps.
|
||||
static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
|
||||
ggml_tensor * dst, mul_mat_algo mm_algorithm) {
|
||||
if (!should_reorder_tensor(*ctx, dst)) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
|
||||
if (!extra || extra->optimized_feature.reorder) {
|
||||
return; // Skip permutations and already reordered tensors
|
||||
}
|
||||
|
||||
switch (mm_algorithm) {
|
||||
case mul_mat_algo::DMMV:
|
||||
if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case mul_mat_algo::MMVQ:
|
||||
if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case mul_mat_algo::MUL_MAT_SYCL:
|
||||
if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
reorder_qw(src0, ctx->stream());
|
||||
extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
|
||||
}
|
||||
|
||||
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
@ -2917,7 +2966,8 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
||||
int64_t min_compute_capability = INT_MAX;
|
||||
|
||||
if (split) {
|
||||
ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
|
||||
ggml_backend_sycl_split_buffer_type_context * buft_ctx =
|
||||
(ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
|
||||
auto & tensor_split = buft_ctx->tensor_split;
|
||||
for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
|
||||
// skip devices that are not going to do any work:
|
||||
@ -2930,7 +2980,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
||||
}
|
||||
}
|
||||
} else {
|
||||
min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
|
||||
min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
|
||||
}
|
||||
|
||||
// check data types and tensor shapes for custom matrix multiplication kernels:
|
||||
@ -2952,9 +3002,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
||||
use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
|
||||
#endif // SYCL_USE_XMX
|
||||
|
||||
|
||||
// mmvq path is faster in the CUDA backend.
|
||||
if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
|
||||
if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
|
||||
// Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
|
||||
// is enabled takes precedence over DMMV, the current if-else implementation
|
||||
// requires disabling DMMV if both conditions are met
|
||||
|| (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
|
||||
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
|
||||
}
|
||||
|
||||
if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
// TODO: Refactor and cleanup of mul mat dispatching.
|
||||
@ -2966,24 +3022,30 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
||||
// The kernel from the if path is faster for that specific case, but does not support all mul mats.
|
||||
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
|
||||
}
|
||||
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||
// KQV single-batch
|
||||
ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
|
||||
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
|
||||
// KQ + KQV multi-batch
|
||||
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
|
||||
} else if (use_dequantize_mul_mat_vec) {
|
||||
opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder.
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
|
||||
// save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
|
||||
constexpr bool convert_src1_to_q8_1 = false;
|
||||
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
|
||||
} else if (use_mul_mat_vec_q) {
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
|
||||
constexpr bool convert_src1_to_q8_1 = true;
|
||||
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
|
||||
} else if (use_mul_mat_q) {
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
|
||||
constexpr bool convert_src1_to_q8_1 = true;
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
|
||||
} else {
|
||||
opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder.
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
|
||||
constexpr bool convert_src1_to_q8_1 = false;
|
||||
// MUL_MAT_SYCL supports reorder
|
||||
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MUL_MAT_SYCL);
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
|
||||
}
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
|
||||
@ -3873,9 +3935,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (a->ne[3] != b->ne[3]) {
|
||||
return false;
|
||||
}
|
||||
if (!ggml_is_contiguous(b)) {
|
||||
return false;
|
||||
}
|
||||
ggml_type a_type = a->type;
|
||||
if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
|
||||
a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
|
||||
|
@ -1,6 +1,60 @@
|
||||
#include "mmvq.hpp"
|
||||
|
||||
#include "ggml.h"
|
||||
#include "common.hpp"
|
||||
#include "quants.hpp"
|
||||
#include "vecdotq.hpp"
|
||||
#include <cassert>
|
||||
|
||||
template <typename reorder_vec_dot_q_sycl>
|
||||
static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
const int ncols, const int nrows, const sycl::nd_item<3> & nd_item) {
|
||||
using block_type = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>;
|
||||
using block_traits = typename block_type::traits;
|
||||
|
||||
const auto sg = nd_item.get_sub_group();
|
||||
const int sg_range = sg.get_group_linear_range();
|
||||
const int workgroup_id = nd_item.get_group_linear_id();
|
||||
const int sg_id = sg.get_group_linear_id();
|
||||
const int row = workgroup_id * sg_range + sg_id;
|
||||
|
||||
if (row >= nrows) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int blocks_per_row = ncols / block_traits::qk;
|
||||
constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
|
||||
constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
|
||||
|
||||
static_assert(blocks_per_subgroup > 0);
|
||||
static_assert(block_elements_per_subgroup > 0);
|
||||
|
||||
const block_q8_1 * y = (const block_q8_1 *) vy;
|
||||
|
||||
float partial_sum = 0.0f;
|
||||
for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
|
||||
const int ibx = row * blocks_per_row + i; // x block index
|
||||
// TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
|
||||
const int bx_offset = block_type::get_block_offset(ibx);
|
||||
const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
|
||||
|
||||
// Y block index that aligns with ibx
|
||||
const int iby = i * block_type::block_to_q8_1_ratio();
|
||||
|
||||
#pragma unroll
|
||||
for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
|
||||
// x block quant index when casting the quants to int
|
||||
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
|
||||
|
||||
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs);
|
||||
}
|
||||
}
|
||||
|
||||
auto sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum, std::plus<>());
|
||||
|
||||
if (sg.leader()) {
|
||||
dst[row] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
|
||||
static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
@ -480,26 +534,39 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
|
||||
}
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
float *dst, const int ncols,
|
||||
const int nrows,
|
||||
static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
||||
const int nrows, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK4_0 == 0);
|
||||
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
||||
constexpr size_t num_subgroups = 16;
|
||||
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
||||
|
||||
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
|
||||
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
||||
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
|
||||
nd_item);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
|
||||
dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK4_0 == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
|
||||
VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -916,93 +983,95 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_mul_mat_vec_q(
|
||||
ggml_backend_sycl_context & ctx,
|
||||
const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
|
||||
float *dst_dd_i, const int64_t row_low, const int64_t row_high,
|
||||
const int64_t src1_ncols, const int64_t src1_padded_col_size,
|
||||
const dpct::queue_ptr &stream) {
|
||||
|
||||
void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
|
||||
ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low,
|
||||
const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_col_size,
|
||||
const dpct::queue_ptr & stream) {
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
GGML_ASSERT(ne10 % QK8_1 == 0);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t row_diff = row_high - row_low;
|
||||
|
||||
int id;
|
||||
SYCL_CHECK(
|
||||
CHECK_TRY_ERROR(id = get_current_device_id()));
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id()));
|
||||
const size_t q8_1_ts = sizeof(block_q8_1);
|
||||
const size_t q8_1_bs = QK8_1;
|
||||
// the main device has a larger memory buffer to hold the results from all GPUs
|
||||
// nrows_dst == nrows of the matrix that the kernel writes into
|
||||
|
||||
for (int i = 0; i < src1_ncols; i++)
|
||||
{
|
||||
for (int i = 0; i < src1_ncols; i++) {
|
||||
const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;
|
||||
const char* src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
|
||||
float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
|
||||
const char * src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
|
||||
float * dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
case GGML_TYPE_Q4_0:
|
||||
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
||||
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n");
|
||||
reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
} else {
|
||||
GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n");
|
||||
mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
GGML_UNUSED(src1);
|
||||
|
61
ggml/src/ggml-sycl/quants.hpp
Normal file
61
ggml/src/ggml-sycl/quants.hpp
Normal file
@ -0,0 +1,61 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2025 Codeplay Software Ltd.
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
#ifndef GGML_SYCL_QUANTS_HPP
|
||||
#define GGML_SYCL_QUANTS_HPP
|
||||
|
||||
#include "ggml-common.h"
|
||||
#include "ggml.h"
|
||||
|
||||
namespace ggml_sycl_reordered {
|
||||
|
||||
|
||||
// The reordered block moves quants (qs) and scales(d) to two
|
||||
// uniform regions of memory that is contiguous in the same tensor.
|
||||
// What this means is that instead of having:
|
||||
// [d0, qs0] [d1, qs1] [d2, qs2] ... [dN, qsN]
|
||||
// We have:
|
||||
// [qs0, qs1, qs2, ..., qsN] [d0, d1, d2, ..., dN]
|
||||
//
|
||||
// Notes: out-of-bounds qs will run into d values
|
||||
// Aligment relies on the allocated size of qs
|
||||
|
||||
template <ggml_type type> struct block_q_t;
|
||||
|
||||
|
||||
// qk number of weights / quants in a block
|
||||
// qr number of weights in a byte (described as 'before dequantization')
|
||||
// for quantization types that has low and high bits split, qr is calculated with
|
||||
// using the lower bits, e.g for Q6 quants QR6 is 2
|
||||
// qi number of 32 bit integers needed to represent all the quants from a block (`qs` field)
|
||||
// See ggml-common.h to see how these are calculated
|
||||
template <> struct block_q_t<GGML_TYPE_Q4_0> {
|
||||
struct traits {
|
||||
static constexpr uint32_t qk = QK4_0;
|
||||
static constexpr uint32_t qi = QI4_0;
|
||||
static constexpr uint32_t qr = QR4_0;
|
||||
static constexpr uint32_t vdr_mmvq = 2;
|
||||
};
|
||||
|
||||
static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
|
||||
|
||||
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half);
|
||||
}
|
||||
|
||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||
};
|
||||
|
||||
} // namespace ggml_sycl_reordered
|
||||
|
||||
#endif // GGML_SYCL_QUANTS_HPP
|
@ -1,6 +1,6 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2024 Intel Corporation
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
@ -14,8 +14,11 @@
|
||||
#define GGML_SYCL_VECDOTQ_HPP
|
||||
|
||||
#include "dpct/helper.hpp"
|
||||
#include "ggml.h"
|
||||
#include "quants.hpp"
|
||||
|
||||
typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
|
||||
typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
|
||||
const int & iqs);
|
||||
|
||||
static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
|
||||
const uint16_t* x16 =
|
||||
@ -252,13 +255,60 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh,
|
||||
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
||||
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
|
||||
|
||||
template <ggml_type T> struct reorder_vec_dot_q_sycl {
|
||||
static_assert(T != T, "ggml_type for reorder vecdot not implemented");
|
||||
};
|
||||
|
||||
template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
|
||||
static constexpr ggml_type gtype = GGML_TYPE_Q4_0;
|
||||
|
||||
using q4_0_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_0>;
|
||||
using q4_0_traits = typename q4_0_block::traits;
|
||||
|
||||
__dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, const sycl::half2 & ds8) {
|
||||
int sumi = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
|
||||
const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
|
||||
const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
|
||||
|
||||
// SIMD dot product of quantized values
|
||||
sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);
|
||||
sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
|
||||
}
|
||||
|
||||
const sycl::float2 ds8f = ds8.convert<float, sycl::rounding_mode::automatic>();
|
||||
|
||||
// second part effectively subtracts 8 from each quant value
|
||||
return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y());
|
||||
}
|
||||
|
||||
__dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
|
||||
const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
||||
const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
|
||||
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
|
||||
int v[q4_0_traits::vdr_mmvq];
|
||||
int u[2 * q4_0_traits::vdr_mmvq];
|
||||
|
||||
#pragma unroll
|
||||
|
||||
for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
|
||||
v[i] = get_int_from_uint8(bq4_0, iqs + i);
|
||||
u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
|
||||
u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + q4_0_traits::qi);
|
||||
}
|
||||
|
||||
return vec_dot_q4_0_q8_1_impl(v, u, d, bq8_1->ds);
|
||||
};
|
||||
};
|
||||
|
||||
#define VDR_Q4_0_Q8_1_MMVQ 2
|
||||
#define VDR_Q4_0_Q8_1_MMQ 4
|
||||
|
||||
template <int vdr>
|
||||
static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u,
|
||||
const float &d4,
|
||||
const sycl::half2 &ds8) {
|
||||
static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4,
|
||||
const sycl::half2 & ds8) {
|
||||
int sumi = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vdr; ++i) {
|
||||
@ -270,8 +320,7 @@ static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u,
|
||||
sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
|
||||
}
|
||||
|
||||
const sycl::float2 ds8f =
|
||||
ds8.convert<float, sycl::rounding_mode::automatic>();
|
||||
const sycl::float2 ds8f = ds8.convert<float, sycl::rounding_mode::automatic>();
|
||||
|
||||
// second part effectively subtracts 8 from each quant value
|
||||
return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y());
|
||||
@ -456,13 +505,13 @@ vec_dot_q4_0_q8_1(const void *__restrict__ vbq,
|
||||
const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
|
||||
|
||||
int v[VDR_Q4_0_Q8_1_MMVQ];
|
||||
int u[2*VDR_Q4_0_Q8_1_MMVQ];
|
||||
int u[2 * VDR_Q4_0_Q8_1_MMVQ];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
|
||||
v[i] = get_int_from_uint8(bq4_0->qs, iqs + i);
|
||||
u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
|
||||
u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
|
||||
v[i] = get_int_from_uint8(bq4_0->qs, iqs + i);
|
||||
u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
|
||||
u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
|
||||
}
|
||||
|
||||
return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
|
||||
|
@ -275,6 +275,7 @@ struct vk_device_struct {
|
||||
bool prefer_host_memory;
|
||||
bool float_controls_rte_fp16;
|
||||
bool subgroup_add;
|
||||
bool subgroup_shuffle;
|
||||
|
||||
bool integer_dot_product;
|
||||
|
||||
@ -402,12 +403,20 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
|
||||
|
||||
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
|
||||
|
||||
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
||||
|
||||
std::unordered_map<std::string, vk_pipeline_ref> pipelines;
|
||||
@ -1581,13 +1590,29 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
|
||||
|
||||
// number of rows/cols for flash attention shader
|
||||
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
||||
static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
|
||||
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
||||
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
|
||||
|
||||
static uint32_t get_fa_num_small_rows(bool scalar) {
|
||||
return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
|
||||
}
|
||||
|
||||
static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
|
||||
GGML_UNUSED(clamp);
|
||||
|
||||
if (scalar) {
|
||||
if (small_rows) {
|
||||
return {scalar_flash_attention_num_small_rows, 64};
|
||||
} else {
|
||||
return {scalar_flash_attention_num_large_rows, 32};
|
||||
}
|
||||
}
|
||||
|
||||
// small rows, large cols
|
||||
if (small_rows) {
|
||||
return {flash_attention_num_small_rows, 64};
|
||||
return {get_fa_num_small_rows(scalar), 32};
|
||||
}
|
||||
|
||||
// small cols to reduce register count
|
||||
if (ggml_is_quantized(type) || D == 256) {
|
||||
return {64, 32};
|
||||
@ -1632,7 +1657,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
||||
const uint32_t warps = warptile[0] / warptile[10];
|
||||
|
||||
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
|
||||
const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
|
||||
const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
|
||||
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
|
||||
|
||||
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
|
||||
@ -1882,65 +1907,66 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
|
||||
};
|
||||
|
||||
auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
||||
return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1};
|
||||
};
|
||||
|
||||
auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
||||
// For large number of rows, 128 invocations seems to work best.
|
||||
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
||||
// can't use 256 for D==80.
|
||||
// For scalar, use 128 (arbitrary)
|
||||
uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128);
|
||||
auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows);
|
||||
|
||||
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
||||
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
||||
const uint32_t D_lsb = D ^ (D & (D-1));
|
||||
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
|
||||
|
||||
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
|
||||
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
|
||||
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
|
||||
};
|
||||
|
||||
#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
|
||||
|
||||
#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256)
|
||||
|
||||
CREATE_FA(GGML_TYPE_F16, f16, true, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, )
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
|
||||
auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
||||
return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
|
||||
};
|
||||
|
||||
auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
||||
// For large number of rows, 128 invocations seems to work best.
|
||||
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
||||
// can't use 256 for D==80.
|
||||
uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
|
||||
auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
|
||||
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
|
||||
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
|
||||
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
|
||||
};
|
||||
|
||||
#define CREATE_FA2(TYPE, NAMELC, D) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
|
||||
|
||||
#define CREATE_FA(TYPE, NAMELC) \
|
||||
CREATE_FA2(TYPE, NAMELC, 64) \
|
||||
CREATE_FA2(TYPE, NAMELC, 80) \
|
||||
CREATE_FA2(TYPE, NAMELC, 96) \
|
||||
CREATE_FA2(TYPE, NAMELC, 112) \
|
||||
CREATE_FA2(TYPE, NAMELC, 128) \
|
||||
CREATE_FA2(TYPE, NAMELC, 256)
|
||||
|
||||
CREATE_FA(GGML_TYPE_F16, f16)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0)
|
||||
// K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
|
||||
//CREATE_FA(GGML_TYPE_Q2_K, q2_k)
|
||||
//CREATE_FA(GGML_TYPE_Q3_K, q3_k)
|
||||
//CREATE_FA(GGML_TYPE_Q4_K, q4_k)
|
||||
//CREATE_FA(GGML_TYPE_Q5_K, q5_k)
|
||||
//CREATE_FA(GGML_TYPE_Q6_K, q6_k)
|
||||
//CREATE_FA(GGML_TYPE_IQ1_S, iq1_s)
|
||||
//CREATE_FA(GGML_TYPE_IQ1_M, iq1_m)
|
||||
//CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs)
|
||||
//CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs)
|
||||
//CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
|
||||
//CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs)
|
||||
//CREATE_FA(GGML_TYPE_IQ3_S, iq3_s)
|
||||
//CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2)
|
||||
}
|
||||
#endif
|
||||
#undef CREATE_FA2
|
||||
#undef CREATE_FA
|
||||
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
||||
@ -2837,6 +2863,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
||||
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
|
||||
|
||||
device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
||||
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
|
||||
|
||||
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
|
||||
|
||||
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
|
||||
@ -5260,7 +5289,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
|
||||
const uint64_t nei0 = ids->ne[0];
|
||||
const uint64_t nei1 = ids->ne[1];
|
||||
GGML_ASSERT(nei0 * nei1 <= 3072);
|
||||
GGML_ASSERT(nei0 * nei1 <= 4096);
|
||||
|
||||
const uint32_t nbi1 = ids->nb[1];
|
||||
const uint32_t nbi2 = ids->nb[2];
|
||||
@ -5709,20 +5738,57 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
assert(q->type == GGML_TYPE_F32);
|
||||
assert(k->type == v->type);
|
||||
|
||||
bool scalar = !ctx->device->coopmat2;
|
||||
|
||||
uint32_t gqa_ratio = 1;
|
||||
uint32_t qk_ratio = neq2 / nek2;
|
||||
uint32_t workgroups_x = (uint32_t)neq1;
|
||||
uint32_t workgroups_y = (uint32_t)neq2;
|
||||
uint32_t workgroups_z = (uint32_t)neq3;
|
||||
|
||||
// For scalar FA, we can use the "large" size to accommodate qga.
|
||||
// For coopmat FA, we always use the small size (which is still pretty large for gqa).
|
||||
const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false);
|
||||
|
||||
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
||||
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
|
||||
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
|
||||
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
|
||||
// and change addressing calculations to index Q's dimension 2.
|
||||
gqa_ratio = qk_ratio;
|
||||
N = gqa_ratio;
|
||||
workgroups_y /= N;
|
||||
}
|
||||
|
||||
vk_pipeline *pipelines;
|
||||
// XXX TODO other backends may be changing accumulator precision to default to f32 soon
|
||||
bool f32acc = dst->op_params[3] == GGML_PREC_F32;
|
||||
bool small_rows = N <= flash_attention_num_small_rows;
|
||||
switch (D) {
|
||||
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
|
||||
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
|
||||
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
|
||||
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
|
||||
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
|
||||
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
|
||||
default:
|
||||
assert(!"unsupported D value");
|
||||
return;
|
||||
bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32;
|
||||
bool small_rows = N <= get_fa_num_small_rows(scalar);
|
||||
|
||||
if (scalar) {
|
||||
switch (D) {
|
||||
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
|
||||
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
|
||||
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
|
||||
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
|
||||
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
|
||||
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
|
||||
default:
|
||||
GGML_ASSERT(!"unsupported D value");
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
switch (D) {
|
||||
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
default:
|
||||
GGML_ASSERT(!"unsupported D value");
|
||||
return;
|
||||
}
|
||||
}
|
||||
assert(pipelines);
|
||||
|
||||
@ -5740,27 +5806,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
vk_pipeline pipeline = pipelines[aligned];
|
||||
assert(pipeline);
|
||||
|
||||
uint32_t gqa_ratio = 1;
|
||||
uint32_t qk_ratio = neq2 / nek2;
|
||||
uint32_t workgroups_x = (uint32_t)neq1;
|
||||
uint32_t workgroups_y = (uint32_t)neq2;
|
||||
uint32_t workgroups_z = (uint32_t)neq3;
|
||||
|
||||
if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows &&
|
||||
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
|
||||
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
|
||||
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
|
||||
// and change addressing calculations to index Q's dimension 2.
|
||||
gqa_ratio = qk_ratio;
|
||||
N = gqa_ratio;
|
||||
workgroups_y /= N;
|
||||
}
|
||||
|
||||
uint32_t split_kv = KV;
|
||||
uint32_t split_k = 1;
|
||||
|
||||
// Use a placeholder core count if one isn't available. split_k is a big help for perf.
|
||||
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
|
||||
|
||||
// Try to use split_k when KV is large enough to be worth the overhead
|
||||
if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) {
|
||||
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
|
||||
// Try to run two workgroups per SM.
|
||||
split_k = ctx->device->shader_core_count * 2 / workgroups_y;
|
||||
if (split_k > 1) {
|
||||
@ -9530,9 +9583,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||
if (!ggml_vk_get_device(ctx->device)->coopmat2) {
|
||||
return false;
|
||||
}
|
||||
auto device = ggml_vk_get_device(ctx->device);
|
||||
bool coopmat2 = device->coopmat2;
|
||||
switch (op->src[0]->ne[0]) {
|
||||
case 64:
|
||||
case 80:
|
||||
@ -9540,7 +9592,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case 112:
|
||||
case 128:
|
||||
case 256:
|
||||
case 575: // DeepSeek MLA
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
@ -9566,10 +9617,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
switch (op->src[1]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
// supported in scalar and coopmat2 paths
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
// K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
|
||||
//case GGML_TYPE_Q2_K:
|
||||
//case GGML_TYPE_Q3_K:
|
||||
@ -9585,10 +9638,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
//case GGML_TYPE_IQ3_S:
|
||||
//case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
// currently supported only in coopmat2 path
|
||||
if (!coopmat2) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
if (!coopmat2 && !device->subgroup_shuffle) {
|
||||
// scalar FA uses subgroupShuffle
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case GGML_OP_GET_ROWS:
|
||||
|
483
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
Normal file
483
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
Normal file
@ -0,0 +1,483 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 1) const uint32_t Br = 1;
|
||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||
layout (constant_id = 3) const uint32_t D = 32;
|
||||
|
||||
layout (constant_id = 5) const uint32_t D_split = 16;
|
||||
const uint32_t D_per_thread = D / D_split;
|
||||
|
||||
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split;
|
||||
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t N;
|
||||
uint32_t KV;
|
||||
|
||||
uint32_t ne1;
|
||||
uint32_t ne2;
|
||||
uint32_t ne3;
|
||||
|
||||
uint32_t neq2;
|
||||
uint32_t neq3;
|
||||
uint32_t nek2;
|
||||
uint32_t nek3;
|
||||
uint32_t nev2;
|
||||
uint32_t nev3;
|
||||
uint32_t nem1;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
uint32_t nb21;
|
||||
uint32_t nb22;
|
||||
uint32_t nb23;
|
||||
uint32_t nb31;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t mask;
|
||||
uint32_t n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t gqa_ratio;
|
||||
uint32_t split_kv;
|
||||
uint32_t k_num;
|
||||
} p;
|
||||
|
||||
layout (binding = 0) readonly buffer Q {float data_q[];};
|
||||
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
|
||||
layout (binding = 1) readonly buffer K {float16_t data_k[];};
|
||||
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
|
||||
layout (binding = 2) readonly buffer V {float16_t data_v[];};
|
||||
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
|
||||
layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
||||
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
#define BINDING_IDX_K 0
|
||||
#define BINDING_IDX_V 1
|
||||
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
#define BLOCK_BYTE_SIZE 34
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
#endif
|
||||
|
||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
// Store the output when doing grouped query attention.
|
||||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
uint32_t offset = (iq2 + r) * D + c;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Store column zero. This is used to save per-row m and L values for split_k.
|
||||
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
if (r < N && c == 0) {
|
||||
uint32_t offset = iq2 + r;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Load the slope matrix, indexed by Q's dimension 2.
|
||||
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
||||
{
|
||||
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
||||
|
||||
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
||||
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
||||
|
||||
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
||||
}
|
||||
|
||||
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
||||
shared vec4 tmpshv4[gl_WorkGroupSize.x];
|
||||
|
||||
shared float masksh[Bc][Br];
|
||||
shared vec4 Qf[Br][D / 4];
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
const uint32_t tid = gl_LocalInvocationIndex;
|
||||
const uint32_t N = p.N;
|
||||
const uint32_t KV = p.KV;
|
||||
|
||||
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
|
||||
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
|
||||
|
||||
uint32_t i = gl_WorkGroupID.x;
|
||||
uint32_t split_k_index = 0;
|
||||
|
||||
if (p.k_num > 1) {
|
||||
i = 0;
|
||||
split_k_index = gl_WorkGroupID.x;
|
||||
}
|
||||
|
||||
const uint32_t Tr = CEIL_DIV(N, Br);
|
||||
|
||||
const uint32_t start_j = split_k_index * p.split_kv / Bc;
|
||||
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
|
||||
|
||||
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
||||
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
||||
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
|
||||
const uint32_t iq3 = gl_WorkGroupID.z;
|
||||
|
||||
// broadcast factors
|
||||
const uint32_t rk2 = p.neq2/p.nek2;
|
||||
const uint32_t rk3 = p.neq3/p.nek3;
|
||||
|
||||
const uint32_t rv2 = p.neq2/p.nev2;
|
||||
const uint32_t rv3 = p.neq3/p.nev3;
|
||||
|
||||
// k indices
|
||||
const uint32_t ik3 = iq3 / rk3;
|
||||
const uint32_t ik2 = iq2 / rk2;
|
||||
|
||||
// v indices
|
||||
const uint32_t iv3 = iq3 / rv3;
|
||||
const uint32_t iv2 = iq2 / rv2;
|
||||
|
||||
// nb?1 are already divided by the type size and are in units of elements.
|
||||
// When using grouped query attention, Q is indexed by iq2, so the stride
|
||||
// should be nb02 (which is in bytes).
|
||||
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
||||
uint32_t k_stride = p.nb11;
|
||||
uint32_t v_stride = p.nb21;
|
||||
// When using grouped query attention, all rows use the same mask (stride 0).
|
||||
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
||||
// that prevents the compiler from folding the "&" through the select
|
||||
// and breaking the alignment detection.
|
||||
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
||||
|
||||
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (D / 4);
|
||||
uint32_t r = (idx + tid) / (D / 4);
|
||||
if (r < Br && d < D / 4 &&
|
||||
i * Br + r < N) {
|
||||
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
vec4 Of[Br][D_per_thread / 4];
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Of[r][d] = vec4(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
float Lf[Br], Mf[Br];
|
||||
|
||||
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
|
||||
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Lf[r] = 0;
|
||||
Mf[r] = NEG_FLT_MAX_OVER_2;
|
||||
}
|
||||
|
||||
float slope[Br];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
slope[r] = 1.0;
|
||||
}
|
||||
|
||||
// ALiBi
|
||||
if (p.max_bias > 0.0f) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
|
||||
}
|
||||
}
|
||||
|
||||
#if BLOCK_SIZE > 1
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
|
||||
#else
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
||||
#endif
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
float Sf[Br][cols_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
Sf[r][c] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
// Compute sum across the D_split
|
||||
[[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (p.logit_softcap != 0.0f) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (p.mask != 0) {
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br) {
|
||||
masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
float mvf = masksh[c * cols_per_iter + col_tid][r];
|
||||
|
||||
Sf[r][c] += slope[r]*mvf;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
rowmaxf[r] = Sf[r][0];
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
|
||||
}
|
||||
Moldf[r] = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
// P = e^(S - M)
|
||||
// eM = e^(Mold - M)
|
||||
Mf[r] = max(rowmaxf[r], Moldf[r]);
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
Pf[r][c] = exp(Sf[r][c] - Mf[r]);
|
||||
}
|
||||
eMf[r] = exp(Moldf[r] - Mf[r]);
|
||||
|
||||
// Compute sum across row of P
|
||||
rowsumf[r] = 0.0;
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
rowsumf[r] += Pf[r][c];
|
||||
}
|
||||
|
||||
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Of[r][d] = eMf[r] * Of[r][d];
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Of[r][d] += Pf[r][c] * Vf;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
}
|
||||
|
||||
// reduce across threads
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
float rowmaxf, eMf;
|
||||
|
||||
tmpsh[tid] = Mf[r];
|
||||
// Compute max across the row
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
|
||||
if (tid < s) {
|
||||
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
rowmaxf = tmpsh[d_tid];
|
||||
barrier();
|
||||
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
// eM = e^(Mold - M)
|
||||
Mf[r] = max(rowmaxf, Moldf);
|
||||
eMf = exp(Moldf - Mf[r]);
|
||||
|
||||
Lf[r] = eMf*Lf[r];
|
||||
|
||||
tmpsh[tid] = Lf[r];
|
||||
|
||||
// Compute sum across the row
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
|
||||
if (tid < s) {
|
||||
tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
Lf[r] = tmpsh[d_tid];
|
||||
barrier();
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
|
||||
Of[r][d] = eMf * Of[r][d];
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
|
||||
if (tid < s) {
|
||||
Of[r][d] += tmpshv4[tid + s];
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
Of[r][d] = tmpshv4[d_tid];
|
||||
barrier();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// If there is split_k, then the split_k resolve shader does the final
|
||||
// division by L. Store the intermediate O value and per-row m and L values.
|
||||
if (p.k_num > 1) {
|
||||
uint32_t o_offset = D * p.ne1 * split_k_index;
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (r < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (r < N) {
|
||||
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
||||
perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
float Lfrcp[Br];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Lfrcp[r] = 1.0 / Lf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Of[r][d] *= Lfrcp[r];
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1;
|
||||
|
||||
if (p.gqa_ratio > 1) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (r < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (i * Br + r < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -103,7 +103,7 @@ shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[3072];
|
||||
shared u16vec2 row_ids[4096];
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
@ -92,7 +92,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||
|
||||
shared u16vec4 row_ids[3072];
|
||||
shared u16vec4 row_ids[4096];
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
|
||||
B_TYPE b[];
|
||||
|
@ -101,7 +101,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
|
||||
#define LOAD_VEC_B 4
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[3072];
|
||||
shared u16vec2 row_ids[4096];
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
@ -421,7 +421,6 @@ void process_shaders() {
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
// flash attention
|
||||
for (const auto& f16acc : {false, true}) {
|
||||
std::string acctype = f16acc ? "float16_t" : "float";
|
||||
@ -432,6 +431,7 @@ void process_shaders() {
|
||||
}
|
||||
if (tname == "bf16") continue;
|
||||
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
|
||||
@ -440,9 +440,17 @@ void process_shaders() {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
|
||||
}
|
||||
#endif
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q8_0") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
// mul mat vec
|
||||
|
@ -2732,11 +2732,11 @@ void ggml_mul_mat_set_prec(
|
||||
c = ggml_mul_mat_id(ctx, as, b, ids);
|
||||
|
||||
as -> [cols, rows, n_expert]
|
||||
ids -> [n_experts_used, n_tokens] (i32)
|
||||
b -> [cols, n_expert_used, n_tokens]
|
||||
ids -> [n_expert_used, n_tokens] (i32)
|
||||
c -> [rows, n_expert_used, n_tokens]
|
||||
|
||||
in b, n_experts_used can be broadcasted to match the n_expert_used of ids
|
||||
in b, n_expert_used can be broadcasted to match the n_expert_used of ids
|
||||
|
||||
c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
|
||||
*/
|
||||
@ -5499,7 +5499,7 @@ static void ggml_compute_backward(
|
||||
// tensor = src0 * 1 + src1 * 0
|
||||
if (src0_needs_grads) {
|
||||
// dsrc0 = dtensor * 1
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, grad);
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0));
|
||||
}
|
||||
if (src1_needs_grads) {
|
||||
// dsrc1 = dtensor * 0 -> noop
|
||||
@ -5780,10 +5780,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
|
||||
}
|
||||
|
||||
void ggml_build_backward_expand(
|
||||
struct ggml_context * ctx_static,
|
||||
struct ggml_context * ctx_compute,
|
||||
struct ggml_cgraph * cgraph,
|
||||
bool accumulate) {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * cgraph,
|
||||
struct ggml_tensor ** grad_accs) {
|
||||
GGML_ASSERT(cgraph->n_nodes > 0);
|
||||
GGML_ASSERT(cgraph->grads);
|
||||
GGML_ASSERT(cgraph->grad_accs);
|
||||
@ -5856,21 +5855,24 @@ void ggml_build_backward_expand(
|
||||
GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
|
||||
node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
|
||||
|
||||
const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
|
||||
GGML_ASSERT(igrad != GGML_HASHSET_FULL);
|
||||
GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
|
||||
if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
|
||||
cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node);
|
||||
cgraph->grads[igrad] = cgraph->grad_accs[igrad];
|
||||
ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
|
||||
const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
|
||||
GGML_ASSERT(ihash != GGML_HASHSET_FULL);
|
||||
GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
|
||||
if (grad_accs && grad_accs[i]) {
|
||||
cgraph->grad_accs[ihash] = grad_accs[i];
|
||||
cgraph->grads[ihash] = cgraph->grad_accs[ihash];
|
||||
} else if (node->flags & GGML_TENSOR_FLAG_LOSS) {
|
||||
// loss tensors always need a gradient accumulator
|
||||
cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
|
||||
cgraph->grads[ihash] = cgraph->grad_accs[ihash];
|
||||
}
|
||||
grads_needed[igrad] = true;
|
||||
grads_needed[ihash] = true;
|
||||
}
|
||||
|
||||
for (int i = n_nodes_f - 1; i >= 0; --i) {
|
||||
// inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
|
||||
// use allocator to automatically make inplace operations
|
||||
ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
|
||||
ggml_compute_backward(ctx, cgraph, i, grads_needed);
|
||||
}
|
||||
|
||||
free(grads_needed);
|
||||
@ -6016,8 +6018,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
|
||||
struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
|
||||
struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
|
||||
struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
|
||||
ggml_graph_cpy(cgraph, result);
|
||||
return result;
|
||||
}
|
||||
@ -6036,6 +6038,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
|
||||
}
|
||||
|
||||
void ggml_graph_reset(struct ggml_cgraph * cgraph) {
|
||||
if (!cgraph) {
|
||||
return;
|
||||
}
|
||||
GGML_ASSERT(cgraph->grads != NULL);
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
@ -6345,8 +6350,8 @@ void ggml_set_output(struct ggml_tensor * tensor) {
|
||||
tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
|
||||
}
|
||||
|
||||
void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
|
||||
GGML_UNUSED(ctx); // TODO: remove this parameter
|
||||
void ggml_set_param(struct ggml_tensor * tensor) {
|
||||
GGML_ASSERT(tensor->op == GGML_OP_NONE);
|
||||
tensor->flags |= GGML_TENSOR_FLAG_PARAM;
|
||||
}
|
||||
|
||||
|
@ -1 +1 @@
|
||||
726dbd0636ae954ba3805c6e1fe6ecc69f851c5b
|
||||
148b286332db1259dcd299c04047a1fd31b02713
|
||||
|
@ -597,7 +597,7 @@ static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<
|
||||
auto & sched = allocr.sched;
|
||||
auto & meta = allocr.meta;
|
||||
|
||||
sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);
|
||||
sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false, true);
|
||||
|
||||
meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
|
||||
|
||||
|
Reference in New Issue
Block a user