Skip to content

Commit 302a9d0

Browse files
committed
vulkan : implement upscale with bicubic interpolation
1 parent 9f05247 commit 302a9d0

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -595,7 +595,7 @@ struct vk_device_struct {
595595
vk_pipeline pipeline_add_id_f32;
596596

597597
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
598-
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32;
598+
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32;
599599
vk_pipeline pipeline_scale_f32;
600600
vk_pipeline pipeline_sqr_f32;
601601
vk_pipeline pipeline_sqrt_f32;
@@ -3664,6 +3664,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
36643664

36653665
ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
36663666
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
3667+
ggml_vk_create_pipeline(device, device->pipeline_upscale_bicubic_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BICUBIC}, 1);
36673668

36683669
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
36693670

@@ -8220,6 +8221,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
82208221
return ctx->device->pipeline_upscale_nearest_f32;
82218222
case GGML_SCALE_MODE_BILINEAR:
82228223
return ctx->device->pipeline_upscale_bilinear_f32;
8224+
case GGML_SCALE_MODE_BICUBIC:
8225+
return ctx->device->pipeline_upscale_bicubic_f32;
82238226
default:
82248227
return nullptr;
82258228
}

ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
2020
// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag
2121
#define NEAREST 0
2222
#define BILINEAR 1
23+
#define BICUBIC 2
2324

2425
layout (constant_id = 0) const uint scale_mode = 0;
2526

@@ -61,6 +62,39 @@ float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
6162
return fetch_bilinear(c0, c1, d, i12, i13);
6263
}
6364

65+
// Bicubic interpolation with alpha = -0.75
66+
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
67+
// Cubic kernel coefficients, ordered to match powers(): (x^3, x^2, x, 1).
// bcoeffs1 is the polynomial (a+2)x^3 - (a+3)x^2 + 1, applied below to the
// two inner samples (distances x and 1 - x); a = -0.75 gives 1.25, -2.25, 0, 1.
const vec4 bcoeffs1 = vec4( 1.25, -2.25, 0.0, 1.0);
68+
// bcoeffs2 is the polynomial a*x^3 - 5a*x^2 + 8a*x - 4a, applied below to the
// two outer samples (distances x + 1 and 2 - x); a = -0.75 gives -0.75, 3.75, -6, 3.
const vec4 bcoeffs2 = vec4(-0.75, 3.75, -6.0, 3.0);
69+
// Returns (x^3, x^2, x, 1) so a cubic can be evaluated as dot(coeffs, powers(x)).
vec4 powers(float x) { return vec4(x*x*x, x*x, x, 1); }
70+
71+
// One-dimensional cubic convolution over four consecutive samples p0..p3.
// x is the position between p1 (x = 0) and p2 (x = 1). Each sample is weighted
// by the kernel polynomial evaluated at its distance from that position:
// x + 1 for p0, x for p1, 1 - x for p2 and 2 - x for p3.
// Returns the weighted sum of the four samples.
float bicubic(float p0, float p1, float p2, float p3, float x) {
72+
return p0 * dot(bcoeffs2, powers(x + 1)) +
73+
p1 * dot(bcoeffs1, powers(x )) +
74+
p2 * dot(bcoeffs1, powers(1 - x)) +
75+
p3 * dot(bcoeffs2, powers(2 - x));
76+
}
77+
78+
// Reads the source element at offset (a, b) from the integer sample position i.
// Each coordinate is clamped to [0, res], so taps that fall outside the source
// reuse the nearest edge element. Relies on `i`, `res` and `base` being defined
// in the scope where the macro is expanded.
#define FETCH(a,b) data_a[base + clamp(i.x+(a), 0, res.x) * p.nb00 + clamp(i.y+(b), 0, res.y) * p.nb01]
79+
80+
// Computes one bicubically interpolated output value from a 4x4 neighbourhood
// of the source buffer.
// i10..i13: presumably the destination element indices along dimensions 0..3
// (they are passed straight through from main()) - TODO confirm.
// Returns the interpolated value as a float.
float interpolate_bicubic(uint i10, uint i11, uint i12, uint i13) {
81+
// Largest valid source index along dimensions 0 and 1; upper bound for FETCH's clamp.
const ivec2 res = ivec2(p.ne00 - 1, p.ne01 - 1);
82+
83+
// Map the destination position to a source coordinate: shift by pixel_offset,
// divide by the per-axis scale factor, then shift back.
const vec2 coord = (vec2(i10, i11) + p.pixel_offset) / vec2(p.sf0, p.sf1) - p.pixel_offset;
84+
// Fractional part: interpolation position between the two central taps.
const vec2 d = fract(coord);
85+
// Integer part: the tap at offset (0, 0); may be negative, FETCH clamps it.
const ivec2 i = ivec2(floor(coord));
86+
87+
// Dimensions 2 and 3 are not interpolated: the index is divided by the scale
// factor and converted to uint.
const uint i02 = uint(i12 / p.sf2);
88+
const uint i03 = uint(i13 / p.sf3);
89+
// Offset of the selected 2D slice within data_a.
const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02;
90+
91+
// Interpolate four rows along x (rows at y offsets -1..2), then interpolate
// those four results along y.
return bicubic(
92+
bicubic(FETCH(-1,-1), FETCH(0,-1), FETCH(1,-1), FETCH(2,-1), d.x),
93+
bicubic(FETCH(-1, 0), FETCH(0, 0), FETCH(1, 0), FETCH(2, 0), d.x),
94+
bicubic(FETCH(-1, 1), FETCH(0, 1), FETCH(1, 1), FETCH(2, 1), d.x),
95+
bicubic(FETCH(-1, 2), FETCH(0, 2), FETCH(1, 2), FETCH(2, 2), d.x), d.y);
96+
}
97+
6498
void main() {
6599
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
66100

@@ -81,6 +115,9 @@ void main() {
81115
case BILINEAR:
82116
result = interpolate_bilinear(i10, i11, i12, i13);
83117
break;
118+
case BICUBIC:
119+
result = interpolate_bicubic(i10, i11, i12, i13);
120+
break;
84121
}
85122

86123
data_d[p.d_offset + idx] = D_TYPE(result);

0 commit comments

Comments
 (0)