ggml : add asserts for type conversion in fattn kernels (llama/9971)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-21 16:20:46 +03:00
parent 25f9fee6fb
commit 741c138aa1

View File

@@ -325,8 +325,9 @@ struct ggml_logger_state {
static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL}; static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};
static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) { static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
if (format == NULL) if (format == NULL) {
return; return;
}
va_list args_copy; va_list args_copy;
va_copy(args_copy, args); va_copy(args_copy, args);
char buffer[128]; char buffer[128];
@@ -15690,6 +15691,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot; ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
ggml_to_float_t const v_to_float = type_traits[v->type].to_float; ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
// loop over n_batch and n_head // loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) { for (int ir = ir0; ir < ir1; ++ir) {
// q indices // q indices