ggml : add ggml_scale_bias (llama/14417)

* ggml : add ggml_scale_bias

* ggml_vec_mad1_f32

* add more simd

* add CUDA

* sycl

* vulkan

* cann (placeholder)

* opencl

* will this fix cpu?

* fix cuda

* suggestions from coderabbit

* fix cann compile error

* vDSP_vsmsa

* rm __ARM_FEATURE_SVE

* use memcpy for op params

* make code looks more consistent

* use scalar for __ARM_FEATURE_SVE

* add x param to ggml_vec_mad1_f32
Author:    Xuan-Son Nguyen
Date:      2025-07-09 18:16:12 +02:00
Committer: Georgi Gerganov
Parent:    48b18f9eb8
Commit:    2021870fb8

13 changed files with 132 additions and 34 deletions

@@ -1294,6 +1294,19 @@ extern "C" {
             struct ggml_tensor * a,
             float                s);
 
+    // x = s * a + b
+    GGML_API struct ggml_tensor * ggml_scale_bias(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            float                s,
+            float                b);
+
+    GGML_API struct ggml_tensor * ggml_scale_bias_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            float                s,
+            float                b);
+
     // b -> view(a,offset,nb1,nb2,3), return modified a
     GGML_API struct ggml_tensor * ggml_set(
             struct ggml_context * ctx,
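
For context, a minimal usage sketch of the new API (not part of the commit; it assumes an already initialized ggml context, the CPU graph-compute helper ggml_graph_compute_with_ctx, and illustrative buffer sizes):

    // y = 2*x + 1 over a small f32 tensor; ggml_scale_bias(ctx, a, s, b) computes s*a + b
    struct ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    for (int i = 0; i < 4; ++i) {
        ggml_set_f32_1d(x, i, (float) i);              // x = {0, 1, 2, 3}
    }

    struct ggml_tensor * y = ggml_scale_bias(ctx, x, 2.0f, 1.0f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_graph_compute_with_ctx(ctx, gf, 1);           // y = {1, 3, 5, 7}

    ggml_free(ctx);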

@@ -2188,7 +2188,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
-        case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_CLAMP:
@@ -2210,6 +2209,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_SCALE:
+            float bias;
+            memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
+            return bias == 0.0f; // TODO: support bias != 0.0f
         case GGML_OP_SOFT_MAX:
             // TODO: support broadcast
             // ref: https://github.com/ggml-org/llama.cpp/pull/14435

@@ -4643,9 +4643,11 @@ static void ggml_compute_forward_scale_f32(
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
-    // scale factor
-    float v;
-    memcpy(&v, dst->op_params, sizeof(float));
+    float s; // scale factor
+    float b; // bias
+
+    memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -4664,12 +4666,22 @@ static void ggml_compute_forward_scale_f32(
     const size_t nb1 = dst->nb[1];
 
+    if (b == 0.0f) {
         for (int i1 = ir0; i1 < ir1; i1++) {
             if (dst->data != src0->data) {
                 // src0 is same shape as dst => same indices
+                // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
                 memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
             }
-            ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
+            ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
+        }
+    } else {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            ggml_vec_mad1_f32(nc,
+                    (float *) ((char *) dst->data + i1*nb1),
+                    (float *) ((char *) src0->data + i1*nb1),
+                    s, b);
+        }
     }
 }

@@ -351,6 +351,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
 #endif
 }
 
+inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
+#if defined(GGML_USE_ACCELERATE)
+    vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
+#elif defined(GGML_SIMD)
+    #if defined(__ARM_FEATURE_SVE)
+        // scalar ; TODO: Write SVE code
+        for (int i = 0; i < n; ++i) {
+            y[i] = x[i]*s + b;
+        }
+    #else
+        const int np = (n & ~(GGML_F32_STEP - 1));
+
+        GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);
+        GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);
+
+        GGML_F32_VEC ay[GGML_F32_ARR];
+
+        for (int i = 0; i < np; i += GGML_F32_STEP) {
+            for (int j = 0; j < GGML_F32_ARR; j++) {
+                ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+                ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb);
+
+                GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+            }
+        }
+
+        // leftovers
+        for (int i = np; i < n; ++i) {
+            y[i] = x[i]*s + b;
+        }
+    #endif
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = x[i]*s + b;
+    }
+#endif
+}
+
 //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
 inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
 #if defined(GGML_USE_ACCELERATE)
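
Whichever branch is compiled, the new helper is equivalent to the scalar loop y[i] = x[i]*s + b. A small usage sketch (the buffers and values below are illustrative, not from the commit):

    float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    float y[8];

    // y[i] = x[i]*0.5f + 10.0f  ->  {10.5, 11, 11.5, 12, 12.5, 13, 13.5, 14}
    ggml_vec_mad1_f32(8, y, x, 0.5f, 10.0f);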

@@ -1,18 +1,18 @@
 #include "scale.cuh"
 
-static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
 
-    dst[i] = scale * x[i];
+    dst[i] = scale * x[i] + bias;
 }
 
-static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
-    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
+    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
 }
 
 void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -25,7 +25,9 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     float scale;
-    memcpy(&scale, dst->op_params, sizeof(float));
+    float bias;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
 
-    scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
+    scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);
 }

@@ -2256,7 +2256,9 @@ static bool ggml_metal_encode_node(
             GGML_ASSERT(ggml_is_contiguous(src0));
 
             float scale;
-            memcpy(&scale, dst->op_params, sizeof(scale));
+            float bias;
+            memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float));
+            memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float));
 
             int64_t n = ggml_nelements(dst);
@@ -2273,6 +2275,7 @@ static bool ggml_metal_encode_node(
             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
             [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
             [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+            [encoder setBytes:&bias length:sizeof(bias) atIndex:3];
 
             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
         } break;

@@ -1014,16 +1014,18 @@ kernel void kernel_scale(
         device const float * src0,
         device float * dst,
         constant float & scale,
+        constant float & bias,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * scale;
+    dst[tpig] = src0[tpig] * scale + bias;
 }
 
 kernel void kernel_scale_4(
         device const float4 * src0,
         device float4 * dst,
         constant float & scale,
+        constant float & bias,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * scale;
+    dst[tpig] = src0[tpig] * scale + bias;
 }
 
 kernel void kernel_clamp(

@@ -5587,7 +5587,9 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
     float scale;
-    memcpy(&scale, dst->op_params, sizeof(scale));
+    float bias;
+    memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(float));
+    memcpy(&bias, ((int32_t *) dst->op_params) + 1, sizeof(float));
 
     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
@@ -5602,6 +5604,7 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
     CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
     CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
     CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias));
 
     int n = ggml_nelements(dst)/4;

@@ -8,9 +8,10 @@ kernel void kernel_scale(
         ulong offset0,
         global float4 * dst,
         ulong offsetd,
-        float scale
+        float scale,
+        float bias
 ) {
     src0 = (global float4*)((global char*)src0 + offset0);
     dst = (global float4*)((global char*)dst + offsetd);
-    dst[get_global_id(0)] = src0[get_global_id(0)] * scale;
+    dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias;
 }

@@ -1695,7 +1695,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-static void scale_f32(const float * x, float * dst, const float scale, const int k,
+static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k,
                       const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -1704,7 +1704,7 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
         return;
     }
 
-    dst[i] = scale * x[i];
+    dst[i] = scale * x[i] + bias;
 }
@@ -1842,7 +1842,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
-static void scale_f32_sycl(const float *x, float *dst, const float scale,
+static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias,
                            const int k, queue_ptr stream) {
     const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
     stream->parallel_for(
@@ -1850,7 +1850,7 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
                               sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
                           sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
         [=](sycl::nd_item<3> item_ct1) {
-            scale_f32(x, dst, scale, k, item_ct1);
+            scale_f32(x, dst, scale, bias, k, item_ct1);
         });
 }
@@ -2319,9 +2319,11 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds
     float * dst_dd = static_cast<float *>(dst->data);
 
     float scale;
-    memcpy(&scale, dst->op_params, sizeof(float));
+    float bias;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
 
-    scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
+    scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream);
     /*
     DPCT1010:87: SYCL uses exceptions to report errors and does not use the
     error codes. The call was replaced with 0. You need to rewrite this code.

@@ -7508,7 +7508,7 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        op_params[0], 0.0f,
+        op_params[0], op_params[1],
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
     }, dryrun);
 }

@@ -18,7 +18,7 @@ void main() {
             continue;
         }
 
-        data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1));
+        data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2));
         idx += num_threads;
     }
 }

@@ -3061,12 +3061,14 @@ static struct ggml_tensor * ggml_scale_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         float s,
+        float b,
         bool inplace) {
     GGML_ASSERT(ggml_is_padded_1d(a));
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_set_op_params(result, &s, sizeof(s));
+    float params[2] = { s, b };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = GGML_OP_SCALE;
     result->src[0] = a;
@@ -3078,14 +3080,30 @@ struct ggml_tensor * ggml_scale(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         float s) {
-    return ggml_scale_impl(ctx, a, s, false);
+    return ggml_scale_impl(ctx, a, s, 0.0, false);
 }
 
 struct ggml_tensor * ggml_scale_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         float s) {
-    return ggml_scale_impl(ctx, a, s, true);
+    return ggml_scale_impl(ctx, a, s, 0.0, true);
+}
+
+struct ggml_tensor * ggml_scale_bias(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        float s,
+        float b) {
+    return ggml_scale_impl(ctx, a, s, b, false);
+}
+
+struct ggml_tensor * ggml_scale_bias_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        float s,
+        float b) {
+    return ggml_scale_impl(ctx, a, s, b, true);
 }
 
 // ggml_set
@@ -5769,7 +5787,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_MEAN: {
             if (src0_needs_grads) {
-                ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
+                ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));
             }
         } break;
         case GGML_OP_REPEAT: {
@@ -5846,7 +5864,7 @@ static void ggml_compute_backward(
             if (src0_needs_grads) {
                 float s;
                 memcpy(&s, tensor->op_params, sizeof(float));
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false));
             }
         } break;
         case GGML_OP_SET: {