mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-02-05 04:50:18 +01:00
ggml : refactor rope norm/neox (llama/7634)
* ggml : unify rope norm/neox (CPU) * ggml : fix compile warning * ggml : remove GLM rope mode ggml-ci * metal : better rope implementation ggml-ci * cuda : better rope implementation ggml-ci * naming : n_orig_ctx -> n_ctx_orig ggml-ci * dev : add reminders to update backends ggml-ci * vulkan : fix ggml_rope_ext() usage * cuda : fix array size + indents ggml-ci
This commit is contained in:
parent
e666315fa8
commit
abab4500fa
@ -1,7 +1,7 @@
|
||||
#include "rope.cuh"
|
||||
|
||||
struct rope_corr_dims {
|
||||
float v[4];
|
||||
float v[2];
|
||||
};
|
||||
|
||||
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
@ -13,8 +13,7 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
|
||||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||
static __device__ void rope_yarn(
|
||||
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
|
||||
float * cos_theta, float * sin_theta
|
||||
) {
|
||||
float * cos_theta, float * sin_theta) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta = theta_interp;
|
||||
@ -29,27 +28,38 @@ static __device__ void rope_yarn(
|
||||
*sin_theta = sinf(theta) * mscale;
|
||||
}
|
||||
|
||||
// rope == RoPE == rotary positional embedding
|
||||
template<typename T, bool has_pos>
|
||||
static __global__ void rope(
|
||||
const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
|
||||
float ext_factor, float attn_factor, rope_corr_dims corr_dims
|
||||
) {
|
||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
template<typename T, bool has_ff>
|
||||
static __global__ void rope_norm(
|
||||
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
|
||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (col >= ncols) {
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
const int i = row*ncols + col;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
const int i = row*ne0 + i0;
|
||||
|
||||
dst[i + 0] = x[i + 0];
|
||||
dst[i + 1] = x[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int i = row*ne0 + i0;
|
||||
const int i2 = row/p_delta_rows;
|
||||
|
||||
const int p = has_pos ? pos[i2] : 0;
|
||||
const float theta_base = p*powf(freq_base, -float(col)/ncols);
|
||||
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
const float x0 = x[i + 0];
|
||||
const float x1 = x[i + 1];
|
||||
@ -58,23 +68,20 @@ static __global__ void rope(
|
||||
dst[i + 1] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
template<typename T, bool has_pos, bool has_freq_facs>
|
||||
template<typename T, bool has_ff>
|
||||
static __global__ void rope_neox(
|
||||
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
|
||||
) {
|
||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
|
||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (col >= ncols) {
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
const int ib = col / n_dims;
|
||||
const int ic = col % n_dims;
|
||||
|
||||
if (ib > 0) {
|
||||
const int i = row*ncols + ib*n_dims + ic;
|
||||
if (i0 >= n_dims) {
|
||||
const int i = row*ne0 + i0;
|
||||
|
||||
dst[i + 0] = x[i + 0];
|
||||
dst[i + 1] = x[i + 1];
|
||||
@ -82,16 +89,17 @@ static __global__ void rope_neox(
|
||||
return;
|
||||
}
|
||||
|
||||
const int i = row*ncols + ib*n_dims + ic/2;
|
||||
const int i = row*ne0 + i0/2;
|
||||
const int i2 = row/p_delta_rows;
|
||||
|
||||
const int p = has_pos ? pos[i2] : 0;
|
||||
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
|
||||
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
|
||||
|
||||
const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;
|
||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
const float x0 = x[i + 0];
|
||||
const float x1 = x[i + n_dims/2];
|
||||
@ -100,144 +108,81 @@ static __global__ void rope_neox(
|
||||
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
static __global__ void rope_glm_f32(
|
||||
const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
|
||||
int n_ctx
|
||||
) {
|
||||
const int col = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
const int half_n_dims = ncols/4;
|
||||
|
||||
if (col >= half_n_dims) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int i = row*ncols + col;
|
||||
const int i2 = row/p_delta_rows;
|
||||
|
||||
const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
|
||||
// FIXME: this is likely wrong
|
||||
const int p = pos != nullptr ? pos[i2] : 0;
|
||||
|
||||
const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta);
|
||||
|
||||
const float x0 = x[i + 0];
|
||||
const float x1 = x[i + half_n_dims];
|
||||
|
||||
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
||||
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
|
||||
|
||||
const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
|
||||
const float sin_block_theta = sinf(block_theta);
|
||||
const float cos_block_theta = cosf(block_theta);
|
||||
|
||||
const float x2 = x[i + half_n_dims * 2];
|
||||
const float x3 = x[i + half_n_dims * 3];
|
||||
|
||||
dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
|
||||
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
static void rope_cuda(
|
||||
const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
|
||||
) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
static void rope_norm_cuda(
|
||||
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nrows, num_blocks_x, 1);
|
||||
if (pos == nullptr) {
|
||||
rope<T, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
|
||||
);
|
||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nr, n_blocks_x, 1);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors
|
||||
);
|
||||
} else {
|
||||
rope<T, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
|
||||
);
|
||||
rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void rope_neox_cuda(
|
||||
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
|
||||
) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nrows, num_blocks_x, 1);
|
||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nr, n_blocks_x, 1);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
|
||||
if (pos == nullptr) {
|
||||
if (freq_factors == nullptr) {
|
||||
rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
if (freq_factors == nullptr) {
|
||||
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors
|
||||
);
|
||||
} else {
|
||||
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors
|
||||
);
|
||||
}
|
||||
} else {
|
||||
if (freq_factors == nullptr) {
|
||||
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors
|
||||
);
|
||||
} else {
|
||||
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void rope_glm_f32_cuda(
|
||||
const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, int n_ctx, cudaStream_t stream
|
||||
) {
|
||||
GGML_ASSERT(ncols % 4 == 0);
|
||||
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
|
||||
const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
|
||||
const dim3 block_nums(num_blocks_x, nrows, 1);
|
||||
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
|
||||
static void rope_norm_cuda_f16(
|
||||
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
||||
|
||||
rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
||||
}
|
||||
|
||||
static void rope_cuda_f16(
|
||||
const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
|
||||
static void rope_norm_cuda_f32(
|
||||
const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
||||
|
||||
rope_cuda<half>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
|
||||
}
|
||||
|
||||
static void rope_cuda_f32(
|
||||
const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
|
||||
|
||||
rope_cuda<float>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
|
||||
rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
||||
}
|
||||
|
||||
static void rope_neox_cuda_f16(
|
||||
const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
||||
|
||||
rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
||||
rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
||||
}
|
||||
|
||||
static void rope_neox_cuda_f32(
|
||||
const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
|
||||
) {
|
||||
|
||||
rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
||||
rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
@ -258,16 +203,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
const int64_t nr = ggml_nrows(src0);
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
// RoPE alteration for extended context
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
@ -275,38 +226,28 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
const float * freq_factors = nullptr;
|
||||
const int32_t * pos = nullptr;
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
pos = (const int32_t *) src1_d;
|
||||
const int32_t * pos = (const int32_t *) src1_d;
|
||||
|
||||
if (is_neox) {
|
||||
if (src2 != nullptr) {
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
|
||||
const float * freq_factors = nullptr;
|
||||
if (src2 != nullptr) {
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
|
||||
rope_corr_dims corr_dims;
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||
|
||||
// compute
|
||||
if (is_glm) {
|
||||
GGML_ASSERT(false);
|
||||
rope_glm_f32_cuda(src0_d, dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, stream);
|
||||
} else if (is_neox) {
|
||||
if (is_neox) {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_neox_cuda_f32(
|
||||
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, stream
|
||||
);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_neox_cuda_f16(
|
||||
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, stream
|
||||
);
|
||||
} else {
|
||||
@ -314,14 +255,14 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
}
|
||||
} else {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_cuda_f32(
|
||||
(const float *)src0_d, (float *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, stream
|
||||
rope_norm_cuda_f32(
|
||||
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, stream
|
||||
);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_cuda_f16(
|
||||
(const half *)src0_d, (half *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, stream
|
||||
rope_norm_cuda_f16(
|
||||
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, stream
|
||||
);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
|
@ -1192,7 +1192,7 @@ static void ggml_vk_rope(
|
||||
const std::shared_ptr<kp::Tensor>& inB,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
||||
ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx,
|
||||
ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
|
||||
float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
|
||||
int32_t ne01, int32_t ne02, int32_t ne03,
|
||||
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
|
||||
@ -1221,14 +1221,14 @@ static void ggml_vk_rope(
|
||||
|
||||
struct PushConstants {
|
||||
uint32_t inAOff, inBOff, outOff;
|
||||
int32_t n_dims, mode, n_orig_ctx;
|
||||
int32_t n_dims, mode, n_ctx_orig;
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
uint32_t nb00, nb01, nb02, nb03;
|
||||
int32_t ne0;
|
||||
uint32_t nb0, nb1, nb2, nb3;
|
||||
} pushConsts {
|
||||
safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
|
||||
n_dims, mode, n_orig_ctx,
|
||||
n_dims, mode, n_ctx_orig,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
|
||||
nb00, nb01, nb02, nb03,
|
||||
ne0,
|
||||
@ -1692,13 +1692,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
||||
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
|
||||
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
|
||||
|
||||
#pragma message("TODO: update rope NORM mode to match NEOX mode")
|
||||
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
|
||||
|
||||
GGML_ASSERT(ne10 == ne02);
|
||||
GGML_ASSERT(src0t == dstt);
|
||||
// const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
// skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
@ -1708,7 +1711,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
ggml_vk_rope(
|
||||
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx,
|
||||
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
|
||||
ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
|
||||
);
|
||||
|
54
ggml-metal.m
54
ggml-metal.m
@ -172,8 +172,10 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
||||
@ -626,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
||||
@ -2285,7 +2289,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
@ -2302,21 +2306,22 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
|
||||
|
||||
if (!is_neox) {
|
||||
GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
|
||||
}
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
|
||||
default: GGML_ASSERT(false);
|
||||
};
|
||||
if (!is_neox) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
|
||||
default: GGML_ASSERT(false);
|
||||
};
|
||||
} else {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
|
||||
default: GGML_ASSERT(false);
|
||||
};
|
||||
}
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
@ -2345,14 +2350,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
||||
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
||||
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
||||
[encoder setBytes:&mode length:sizeof( int) atIndex:22];
|
||||
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
|
||||
[encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
|
||||
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
|
||||
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
|
||||
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
|
||||
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
|
||||
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
|
||||
[encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
|
||||
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
||||
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
||||
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
||||
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
|
||||
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
|
||||
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
|
200
ggml-metal.metal
200
ggml-metal.metal
@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||
static void rope_yarn(
|
||||
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
||||
thread float * cos_theta, thread float * sin_theta
|
||||
) {
|
||||
thread float * cos_theta, thread float * sin_theta) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta = theta_interp;
|
||||
@ -1672,55 +1671,20 @@ static void rope_yarn(
|
||||
|
||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||
static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
||||
return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
||||
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
|
||||
return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
||||
}
|
||||
|
||||
static void rope_yarn_corr_dims(
|
||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||
) {
|
||||
// start and end correction dims
|
||||
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
|
||||
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
|
||||
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
|
||||
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
|
||||
}
|
||||
|
||||
typedef void (rope_t)(
|
||||
device const void * src0,
|
||||
device const int32_t * src1,
|
||||
device const float * src2,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
constant int & n_past,
|
||||
constant int & n_dims,
|
||||
constant int & mode,
|
||||
constant int & n_orig_ctx,
|
||||
constant float & freq_base,
|
||||
constant float & freq_scale,
|
||||
constant float & ext_factor,
|
||||
constant float & attn_factor,
|
||||
constant float & beta_fast,
|
||||
constant float & beta_slow,
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg[[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]);
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope(
|
||||
kernel void kernel_rope_norm(
|
||||
device const void * src0,
|
||||
device const int32_t * src1,
|
||||
device const float * src2,
|
||||
@ -1743,8 +1707,7 @@ kernel void kernel_rope(
|
||||
constant uint64_t & nb3,
|
||||
constant int & n_past,
|
||||
constant int & n_dims,
|
||||
constant int & mode,
|
||||
constant int & n_orig_ctx,
|
||||
constant int & n_ctx_orig,
|
||||
constant float & freq_base,
|
||||
constant float & freq_scale,
|
||||
constant float & ext_factor,
|
||||
@ -1758,69 +1721,130 @@ kernel void kernel_rope(
|
||||
const int64_t i2 = tgpig[1];
|
||||
const int64_t i1 = tgpig[0];
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = src1;
|
||||
|
||||
const int64_t p = pos[i2];
|
||||
|
||||
const float theta_base = (float)p;
|
||||
const float theta_base = (float) pos[i2];
|
||||
const float inv_ndims = -1.f/n_dims;
|
||||
|
||||
if (!is_neox) {
|
||||
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < n_dims) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const T x0 = src[0];
|
||||
const T x1 = src[1];
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[1];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
} else {
|
||||
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
|
||||
if (ic < n_dims) {
|
||||
const int64_t i0 = ic/2;
|
||||
} else {
|
||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
|
||||
|
||||
const float theta = theta_base * pow(freq_base, inv_ndims*ic);
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
} else {
|
||||
const int64_t i0 = ic;
|
||||
|
||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
||||
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
||||
template<typename T>
|
||||
kernel void kernel_rope_neox(
|
||||
device const void * src0,
|
||||
device const int32_t * src1,
|
||||
device const float * src2,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
constant int & n_past,
|
||||
constant int & n_dims,
|
||||
constant int & n_ctx_orig,
|
||||
constant float & freq_base,
|
||||
constant float & freq_scale,
|
||||
constant float & ext_factor,
|
||||
constant float & attn_factor,
|
||||
constant float & beta_fast,
|
||||
constant float & beta_slow,
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg[[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int64_t i3 = tgpig[2];
|
||||
const int64_t i2 = tgpig[1];
|
||||
const int64_t i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = src1;
|
||||
|
||||
const float theta_base = (float) pos[i2];
|
||||
const float inv_ndims = -1.f/n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < n_dims) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
} else {
|
||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
||||
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
||||
|
||||
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
||||
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
||||
|
||||
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
||||
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
||||
|
||||
typedef void (im2col_t)(
|
||||
device const float * x,
|
||||
|
@ -8928,49 +8928,6 @@ static void rope_neox(
|
||||
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
static void rope_glm_f32(
|
||||
const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
|
||||
int n_ctx
|
||||
, const sycl::nd_item<3> &item_ct1) {
|
||||
const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
const int half_n_dims = ncols/4;
|
||||
|
||||
if (col >= half_n_dims) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
item_ct1.get_local_id(1);
|
||||
const int i = row*ncols + col;
|
||||
const int i2 = row/p_delta_rows;
|
||||
|
||||
const float col_theta_scale = dpct::pow(freq_base, -2.0f * col / ncols);
|
||||
// FIXME: this is likely wrong
|
||||
const int p = pos != nullptr ? pos[i2] : 0;
|
||||
|
||||
const float theta = sycl::min(p, n_ctx - 2) * freq_scale * col_theta_scale;
|
||||
const float sin_theta = sycl::sin((float)theta);
|
||||
const float cos_theta = sycl::cos((float)theta);
|
||||
|
||||
const float x0 = x[i + 0];
|
||||
const float x1 = x[i + half_n_dims];
|
||||
|
||||
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
||||
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
|
||||
|
||||
const float block_theta =
|
||||
((float)sycl::max(p - n_ctx - 2, 0)) * col_theta_scale;
|
||||
const float sin_block_theta = sycl::sin((float)block_theta);
|
||||
const float cos_block_theta = sycl::cos((float)block_theta);
|
||||
|
||||
const float x2 = x[i + half_n_dims * 2];
|
||||
const float x3 = x[i + half_n_dims * 3];
|
||||
|
||||
dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
|
||||
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
|
||||
}
|
||||
|
||||
static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int row = item_ct1.get_group(1);
|
||||
@ -12520,22 +12477,6 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
|
||||
}
|
||||
}
|
||||
|
||||
static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
|
||||
const int32_t *pos, float freq_scale,
|
||||
int p_delta_rows, float freq_base, int n_ctx,
|
||||
dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % 4 == 0);
|
||||
const sycl::range<3> block_dims(1, 1, SYCL_ROPE_BLOCK_SIZE / 4);
|
||||
const int num_blocks_x = (ncols + SYCL_ROPE_BLOCK_SIZE - 1) / SYCL_ROPE_BLOCK_SIZE;
|
||||
const sycl::range<3> block_nums(1, nrows, num_blocks_x);
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rope_glm_f32(x, dst, ncols, pos, freq_scale,
|
||||
p_delta_rows, freq_base, n_ctx,
|
||||
item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
|
||||
const int nrows, dpct::queue_ptr stream) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
@ -14066,8 +14007,8 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
// RoPE alteration for extended context
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
@ -14087,7 +14028,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
}
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
#pragma message("TODO: update rope NORM mode to match NEOX mode")
|
||||
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
|
||||
|
||||
if (is_neox) {
|
||||
pos = (const int32_t *) src1_dd;
|
||||
@ -14100,13 +14043,10 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
}
|
||||
|
||||
rope_corr_dims corr_dims;
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||
|
||||
// compute
|
||||
if (is_glm) {
|
||||
GGML_ASSERT(false);
|
||||
rope_glm_f32_sycl(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
|
||||
} else if (is_neox) {
|
||||
if (is_neox) {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_neox_sycl(
|
||||
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
|
@ -3898,11 +3898,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
{
|
||||
const int mode = ((const int32_t *) dst->op_params)[2];
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
if (is_glm) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (is_neox) {
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
@ -4401,7 +4396,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
// const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
const float freq_base = ((float *) dst->op_params)[5];
|
||||
const float freq_scale = ((float *) dst->op_params)[6];
|
||||
const float ext_factor = ((float *) dst->op_params)[7];
|
||||
@ -4410,12 +4405,12 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
|
||||
const float beta_slow = ((float *) dst->op_params)[10];
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
GGML_ASSERT(!is_glm);
|
||||
#pragma message("TODO: update rope NORM mode to match NEOX mode")
|
||||
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
|
||||
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
if (is_neox) {
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
@ -6485,9 +6480,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
const int mode = ((const int32_t *) op->op_params)[2];
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
return !is_glm;
|
||||
return true;
|
||||
} break;
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
@ -6992,15 +6986,15 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
|
||||
} else if (tensor->op == GGML_OP_ROPE) {
|
||||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||
const int n_ggml_ctx = ((int32_t *) tensor->op_params)[3];
|
||||
const int n_orig_ggml_ctx = ((int32_t *) tensor->op_params)[4];
|
||||
//const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
|
||||
const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4];
|
||||
float freq_base = ((float *) tensor->op_params)[5];
|
||||
float freq_scale = ((float *) tensor->op_params)[6];
|
||||
float ext_factor = ((float *) tensor->op_params)[7];
|
||||
float attn_factor = ((float *) tensor->op_params)[8];
|
||||
float beta_fast = ((float *) tensor->op_params)[9];
|
||||
float beta_slow = ((float *) tensor->op_params)[10];
|
||||
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ggml_ctx, n_orig_ggml_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
} else if (tensor->op == GGML_OP_UNARY) {
|
||||
switch (ggml_get_unary_op(tensor)) {
|
||||
case GGML_UNARY_OP_SILU:
|
||||
|
334
ggml.c
334
ggml.c
@ -6250,16 +6250,13 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
float xpos_base,
|
||||
bool xpos_down,
|
||||
bool inplace) {
|
||||
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
|
||||
|
||||
@ -6280,15 +6277,13 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
|
||||
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
|
||||
memcpy(params + 5, &freq_base, sizeof(float));
|
||||
memcpy(params + 6, &freq_scale, sizeof(float));
|
||||
memcpy(params + 7, &ext_factor, sizeof(float));
|
||||
memcpy(params + 8, &attn_factor, sizeof(float));
|
||||
memcpy(params + 9, &beta_fast, sizeof(float));
|
||||
memcpy(params + 10, &beta_slow, sizeof(float));
|
||||
memcpy(params + 11, &xpos_base, sizeof(float));
|
||||
memcpy(params + 12, &xpos_down, sizeof(bool));
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_ROPE;
|
||||
@ -6305,10 +6300,9 @@ struct ggml_tensor * ggml_rope(
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx) {
|
||||
int mode) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
|
||||
ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
|
||||
);
|
||||
}
|
||||
|
||||
@ -6317,10 +6311,9 @@ struct ggml_tensor * ggml_rope_inplace(
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx) {
|
||||
int mode) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
|
||||
ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
|
||||
);
|
||||
}
|
||||
|
||||
@ -6331,8 +6324,7 @@ struct ggml_tensor * ggml_rope_ext(
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
@ -6340,8 +6332,8 @@ struct ggml_tensor * ggml_rope_ext(
|
||||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
|
||||
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, false
|
||||
);
|
||||
}
|
||||
|
||||
@ -6352,8 +6344,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
@ -6361,8 +6352,8 @@ struct ggml_tensor * ggml_rope_ext_inplace(
|
||||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
|
||||
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, true
|
||||
);
|
||||
}
|
||||
|
||||
@ -6372,8 +6363,7 @@ struct ggml_tensor * ggml_rope_custom(
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
@ -6381,8 +6371,8 @@ struct ggml_tensor * ggml_rope_custom(
|
||||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, false
|
||||
);
|
||||
}
|
||||
|
||||
@ -6392,8 +6382,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
@ -6401,21 +6390,11 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
||||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, true
|
||||
);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_xpos_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
float base,
|
||||
bool down) {
|
||||
return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
|
||||
}
|
||||
|
||||
// ggml_rope_back
|
||||
|
||||
struct ggml_tensor * ggml_rope_back(
|
||||
@ -6425,16 +6404,13 @@ struct ggml_tensor * ggml_rope_back(
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
float xpos_base,
|
||||
bool xpos_down) {
|
||||
float beta_slow) {
|
||||
GGML_ASSERT(ggml_is_vector(b));
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
||||
@ -6450,15 +6426,13 @@ struct ggml_tensor * ggml_rope_back(
|
||||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
|
||||
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
|
||||
memcpy(params + 5, &freq_base, sizeof(float));
|
||||
memcpy(params + 6, &freq_scale, sizeof(float));
|
||||
memcpy(params + 7, &ext_factor, sizeof(float));
|
||||
memcpy(params + 8, &attn_factor, sizeof(float));
|
||||
memcpy(params + 9, &beta_fast, sizeof(float));
|
||||
memcpy(params + 10, &beta_slow, sizeof(float));
|
||||
memcpy(params + 11, &xpos_base, sizeof(float));
|
||||
memcpy(params + 12, &xpos_down, sizeof(bool));
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_ROPE_BACK;
|
||||
@ -14227,8 +14201,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||
static void rope_yarn(
|
||||
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
||||
float * cos_theta, float * sin_theta
|
||||
) {
|
||||
float * cos_theta, float * sin_theta) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta = theta_interp;
|
||||
@ -14245,18 +14218,19 @@ static void rope_yarn(
|
||||
|
||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||
static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
||||
return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
|
||||
static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
|
||||
return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
|
||||
}
|
||||
|
||||
static void ggml_rope_cache_init(
|
||||
float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
||||
float * cache, float sin_sign, float theta_scale
|
||||
) {
|
||||
float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
||||
float * cache, float sin_sign, float theta_scale) {
|
||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||
float theta = theta_base;
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
|
||||
rope_yarn(
|
||||
theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
|
||||
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
|
||||
);
|
||||
cache[i0 + 1] *= sin_sign;
|
||||
|
||||
@ -14265,11 +14239,11 @@ static void ggml_rope_cache_init(
|
||||
}
|
||||
|
||||
GGML_CALL void ggml_rope_yarn_corr_dims(
|
||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||
) {
|
||||
// start and end correction dims
|
||||
float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base));
|
||||
float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base));
|
||||
float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
|
||||
float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
|
||||
dims[0] = MAX(0, start);
|
||||
dims[1] = MIN(n_dims - 1, end);
|
||||
}
|
||||
@ -14289,15 +14263,11 @@ static void ggml_compute_forward_rope_f32(
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
|
||||
// these two only relevant for xPos RoPE:
|
||||
float xpos_base;
|
||||
bool xpos_down;
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
@ -14305,8 +14275,6 @@ static void ggml_compute_forward_rope_f32(
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
|
||||
memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
@ -14336,20 +14304,15 @@ static void ggml_compute_forward_rope_f32(
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (is_neox) {
|
||||
if (src2 != NULL) {
|
||||
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
|
||||
if (src2 != NULL) {
|
||||
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
|
||||
// backward process uses inverse rotation by cos and sin.
|
||||
@ -14364,95 +14327,51 @@ static void ggml_compute_forward_rope_f32(
|
||||
const int64_t p = pos[i2];
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
|
||||
ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
||||
float theta_base = (float)p;
|
||||
|
||||
if (is_glm) {
|
||||
theta_base = MIN(p, n_ctx - 2);
|
||||
float block_theta = MAX(p - (n_ctx - 2), 0);
|
||||
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
|
||||
const float cos_theta = cosf(theta_base);
|
||||
const float sin_theta = sinf(theta_base) * sin_sign;
|
||||
const float cos_block_theta = cosf(block_theta);
|
||||
const float sin_block_theta = sinf(block_theta) * sin_sign;
|
||||
|
||||
theta_base *= theta_scale;
|
||||
block_theta *= theta_scale;
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
const float x2 = src[n_dims];
|
||||
const float x3 = src[n_dims/2*3];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta;
|
||||
dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
|
||||
}
|
||||
} else if (!is_neox) {
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
if (!is_neox) {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
// zeta scaling for xPos only:
|
||||
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
||||
if (xpos_down) zeta = 1.0f / zeta;
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[1];
|
||||
|
||||
dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta;
|
||||
dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
} else {
|
||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||
for (int64_t ic = 0; ic < ne0; ic += 2) {
|
||||
if (ic < n_dims) {
|
||||
const int64_t i0 = ic/2;
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(
|
||||
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
|
||||
&cos_theta, &sin_theta
|
||||
);
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
sin_theta *= sin_sign;
|
||||
theta_base *= theta_scale;
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
} else {
|
||||
const int64_t i0 = ic;
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -14477,8 +14396,8 @@ static void ggml_compute_forward_rope_f16(
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
@ -14514,20 +14433,15 @@ static void ggml_compute_forward_rope_f16(
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (is_neox) {
|
||||
if (src2 != NULL) {
|
||||
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
|
||||
if (src2 != NULL) {
|
||||
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
|
||||
// backward process uses inverse rotation by cos and sin.
|
||||
@ -14542,43 +14456,14 @@ static void ggml_compute_forward_rope_f16(
|
||||
const int64_t p = pos[i2];
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
|
||||
ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
||||
float theta_base = (float)p;
|
||||
|
||||
if (is_glm) {
|
||||
theta_base = MIN(p, n_ctx - 2);
|
||||
float block_theta = MAX(p - (n_ctx - 2), 0);
|
||||
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
|
||||
const float cos_theta = cosf(theta_base);
|
||||
const float sin_theta = sinf(theta_base) * sin_sign;
|
||||
const float cos_block_theta = cosf(block_theta);
|
||||
const float sin_block_theta = sinf(block_theta) * sin_sign;
|
||||
|
||||
theta_base *= theta_scale;
|
||||
block_theta *= theta_scale;
|
||||
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
||||
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
|
||||
const float x2 = GGML_FP16_TO_FP32(src[n_dims]);
|
||||
const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]);
|
||||
|
||||
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
|
||||
dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
|
||||
}
|
||||
} else if (!is_neox) {
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
if (!is_neox) {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
@ -14592,41 +14477,30 @@ static void ggml_compute_forward_rope_f16(
|
||||
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
} else {
|
||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||
for (int64_t ic = 0; ic < ne0; ic += 2) {
|
||||
if (ic < n_dims) {
|
||||
const int64_t i0 = ic/2;
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(
|
||||
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
|
||||
&cos_theta, &sin_theta
|
||||
);
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
sin_theta *= sin_sign;
|
||||
theta_base *= theta_scale;
|
||||
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
||||
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
|
||||
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
||||
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
|
||||
|
||||
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
} else {
|
||||
const int64_t i0 = ic;
|
||||
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -18327,9 +18201,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
||||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
|
||||
//const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
|
||||
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
|
||||
@ -18337,8 +18211,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
|
||||
memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
|
||||
memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
|
||||
|
||||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
@ -18348,16 +18220,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
src2,
|
||||
n_dims,
|
||||
mode,
|
||||
n_ctx,
|
||||
n_orig_ctx,
|
||||
n_ctx_orig,
|
||||
freq_base,
|
||||
freq_scale,
|
||||
ext_factor,
|
||||
attn_factor,
|
||||
beta_fast,
|
||||
beta_slow,
|
||||
xpos_base,
|
||||
xpos_down),
|
||||
beta_slow),
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
@ -18367,9 +18236,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
||||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
|
||||
//const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
|
||||
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
|
||||
@ -18377,8 +18246,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
|
||||
memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
|
||||
memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
|
||||
|
||||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
@ -18388,16 +18255,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
src2,
|
||||
n_dims,
|
||||
mode,
|
||||
n_ctx,
|
||||
n_orig_ctx,
|
||||
n_ctx_orig,
|
||||
freq_base,
|
||||
freq_scale,
|
||||
ext_factor,
|
||||
attn_factor,
|
||||
beta_fast,
|
||||
beta_slow,
|
||||
xpos_base,
|
||||
xpos_down,
|
||||
false),
|
||||
zero_table);
|
||||
}
|
||||
|
36
ggml.h
36
ggml.h
@ -1465,7 +1465,6 @@ extern "C" {
|
||||
// rotary position embedding
|
||||
// if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
|
||||
// if mode & 2 == 1, GPT-NeoX style
|
||||
// if mode & 4 == 1, ChatGLM style
|
||||
//
|
||||
// b is an int32 vector with size a->ne[2], it contains the positions
|
||||
// c is freq factors (e.g. phi3-128k), (optional)
|
||||
@ -1474,8 +1473,7 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx);
|
||||
int mode);
|
||||
|
||||
// in-place, returns view(a)
|
||||
GGML_API struct ggml_tensor * ggml_rope_inplace(
|
||||
@ -1483,8 +1481,7 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx);
|
||||
int mode);
|
||||
|
||||
// custom RoPE
|
||||
GGML_API struct ggml_tensor * ggml_rope_ext(
|
||||
@ -1494,8 +1491,7 @@ extern "C" {
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
@ -1511,8 +1507,7 @@ extern "C" {
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
@ -1526,8 +1521,7 @@ extern "C" {
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
@ -1542,8 +1536,7 @@ extern "C" {
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
@ -1552,17 +1545,9 @@ extern "C" {
|
||||
float beta_slow),
|
||||
"use ggml_rope_ext_inplace instead");
|
||||
|
||||
struct ggml_tensor * ggml_rope_xpos_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
float base,
|
||||
bool down);
|
||||
|
||||
// compute correction dims for YaRN RoPE scaling
|
||||
GGML_CALL void ggml_rope_yarn_corr_dims(
|
||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
||||
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
||||
|
||||
// rotary position embedding backward, i.e compute dx from dy
|
||||
// a - dy
|
||||
@ -1573,16 +1558,13 @@ extern "C" {
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
float xpos_base,
|
||||
bool xpos_down);
|
||||
float beta_slow);
|
||||
|
||||
// clamp
|
||||
// in-place, returns view(a)
|
||||
|
Loading…
Reference in New Issue
Block a user