ggml : resolve merge (ggml/0)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-05-11 16:25:50 +03:00
parent e54329da7b
commit accada542a
3 changed files with 9 additions and 4 deletions

View File

@ -71,6 +71,7 @@ bool ggml_common_quantize_0(
case GGML_FTYPE_MOSTLY_IQ4_NL: case GGML_FTYPE_MOSTLY_IQ4_NL:
case GGML_FTYPE_MOSTLY_IQ4_XS: case GGML_FTYPE_MOSTLY_IQ4_XS:
case GGML_FTYPE_MOSTLY_IQ1_M: case GGML_FTYPE_MOSTLY_IQ1_M:
case GGML_FTYPE_MOSTLY_BF16:
{ {
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype); fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
return false; return false;
@ -207,6 +208,7 @@ bool ggml_common_quantize_0(
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_M:
case GGML_TYPE_BF16:
case GGML_TYPE_COUNT: case GGML_TYPE_COUNT:
{ {
fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype));

View File

@ -296,7 +296,7 @@ kernel void kernel_silu(
dst[tpig] = x / (1.0f + exp(-x)); dst[tpig] = x / (1.0f + exp(-x));
} }
+kernel void kernel_silu_4( kernel void kernel_silu_4(
device const float4 * src0, device const float4 * src0,
device float4 * dst, device float4 * dst,
uint tpig[[thread_position_in_grid]]) { uint tpig[[thread_position_in_grid]]) {
@ -2217,7 +2217,7 @@ kernel void kernel_flash_attn_ext_f16(
// ALiBi // ALiBi
if (max_bias > 0.0f) { if (max_bias > 0.0f) {
const short h = iq2; const uint32_t h = iq2;
const float base = h < n_head_log2 ? m0 : m1; const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
@ -2473,7 +2473,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
// ALiBi // ALiBi
if (max_bias > 0.0f) { if (max_bias > 0.0f) {
const short h = iq2; const uint32_t h = iq2;
const float base = h < n_head_log2 ? m0 : m1; const float base = h < n_head_log2 ? m0 : m1;
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

5
ggml.c
View File

@ -4,7 +4,6 @@
#include "ggml-impl.h" #include "ggml-impl.h"
#include "ggml-quants.h" #include "ggml-quants.h"
#include "ggml.h" #include "ggml.h"
#include "sgemm.h"
#if defined(_MSC_VER) || defined(__MINGW32__) #if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW #include <malloc.h> // using malloc.h with MSC/MINGW
@ -37,6 +36,10 @@
#undef GGML_USE_LLAMAFILE #undef GGML_USE_LLAMAFILE
#endif #endif
#ifdef GGML_USE_LLAMAFILE
#include "sgemm.h"
#endif
#if defined(_MSC_VER) #if defined(_MSC_VER)
// disable "possible loss of data" to avoid hundreds of casts // disable "possible loss of data" to avoid hundreds of casts
// we should just be careful :) // we should just be careful :)