Skip to content

Commit bfcee09

Browse files
committed
cuda : implement upscale with bicubic interpolation
1 parent 302a9d0 commit bfcee09

File tree

1 file changed

+87
-6
lines changed

1 file changed

+87
-6
lines changed

ggml/src/ggml-cuda/upscale.cu

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,70 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst,
8181
dst[index] = result;
8282
}
8383

84+
namespace bicubic_interpolation {
85+
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
86+
__device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
87+
88+
static __device__ float weight1(float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
89+
static __device__ float weight2(float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
90+
91+
static __device__ float bicubic(float p0, float p1, float p2, float p3, float x) {
92+
const float w0 = weight2(x + 1);
93+
const float w1 = weight1(x + 0);
94+
const float w2 = weight1(1 - x);
95+
const float w3 = weight2(2 - x);
96+
return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3;
97+
};
98+
} // namespace bicubic_interpolation
99+
100+
static __global__ void upscale_f32_bicubic(const float * x, float * dst,
101+
const int nb00, const int nb01, const int nb02, const int nb03,
102+
const int ne00_src, const int ne01_src,
103+
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
104+
const float sf0, const float sf1, const float sf2, const float sf3,
105+
const float pixel_offset) {
106+
using bicubic_interpolation::bicubic;
107+
108+
const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
109+
const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
110+
111+
if (index >= dst_total_elements) {
112+
return;
113+
}
114+
115+
const int i10_dst = index % ne10_dst;
116+
const int i11_dst = (index / ne10_dst) % ne11_dst;
117+
const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
118+
const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
119+
120+
const int i02_src = (int)(i12_dst / sf2);
121+
const int i03_src = (int)(i13_dst / sf3);
122+
123+
const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
124+
const int y0_src = (int)floorf(y_src_f);
125+
const float dy = y_src_f - (float)y0_src;
126+
127+
const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
128+
const int x0_src = (int)floorf(x_src_f);
129+
const float dx = x_src_f - (float)x0_src;
130+
131+
const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03;
132+
133+
auto load = [=](int x_off, int y_off) -> float {
134+
int i00_src = max(0, min(x0_src + x_off, ne00_src - 1));
135+
int i01_src = max(0, min(y0_src + y_off, ne01_src - 1));
136+
return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01);
137+
};
138+
139+
const float result = bicubic(
140+
bicubic(load(-1,-1), load(0,-1), load(1,-1), load(2,-1), dx),
141+
bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx),
142+
bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx),
143+
bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx), dy);
144+
145+
dst[index] = result;
146+
}
147+
84148
static void upscale_f32_cuda(const float * x, float * dst,
85149
const int nb00, const int nb01, const int nb02, const int nb03,
86150
const int ne10, const int ne11, const int ne12, const int ne13,
@@ -104,6 +168,18 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst,
104168
upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
105169
}
106170

171+
static void upscale_f32_bicubic_cuda(const float * x, float * dst,
172+
const int nb00, const int nb01, const int nb02, const int nb03,
173+
const int ne00_src, const int ne01_src,
174+
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
175+
const float sf0, const float sf1, const float sf2, const float sf3,
176+
const float pixel_offset, cudaStream_t stream) {
177+
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
178+
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
179+
180+
upscale_f32_bicubic<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
181+
}
182+
107183
void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
108184
const ggml_tensor * src0 = dst->src[0];
109185
const float * src0_d = (const float *)src0->data;
@@ -121,17 +197,22 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
121197
float sf2 = (float)dst->ne[2]/src0->ne[2];
122198
const float sf3 = (float)dst->ne[3]/src0->ne[3];
123199

200+
float pixel_offset = 0.5f;
201+
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
202+
sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0;
203+
sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1;
204+
pixel_offset = 0.0f;
205+
}
206+
124207
if (mode == GGML_SCALE_MODE_NEAREST) {
125208
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
126209
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
127-
float pixel_offset = 0.5f;
128-
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
129-
sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0;
130-
sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1;
131-
pixel_offset = 0.0f;
132-
}
133210
upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
134211
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
135212
sf0, sf1, sf2, sf3, pixel_offset, stream);
213+
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
214+
upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
215+
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
216+
sf0, sf1, sf2, sf3, pixel_offset, stream);
136217
}
137218
}

0 commit comments

Comments
 (0)