From 6b85d9deb00c9639441ef93d8f6c5d6a1737f8c9 Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Sat, 21 Sep 2024 21:53:03 +0900 Subject: [PATCH] internal/shader: add check for out-of-bounds Closes #3112 --- internal/shader/expr.go | 25 +++++++++++ internal/shader/syntax_test.go | 78 ++++++++++++++++++++++++++++++++++ internal/shaderir/type.go | 13 ++++++ 3 files changed, 116 insertions(+) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index c47c4c5b9..f643cc1f9 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -1156,6 +1156,31 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar x := exprs[0] t := ts[0] + // Check the length only when the index is a constant. + if idx.Const != nil { + var length int + switch { + case t.Main == shaderir.Array: + length = t.Length + case t.IsFloatVector() || t.IsIntVector(): + length = t.VectorElementCount() + case t.IsMatrix(): + length = t.MatrixSize() + default: + cs.addError(e.Pos(), fmt.Sprintf("index operator cannot be applied to the type %s", t.String())) + return nil, nil, nil, false + } + v, ok := gconstant.Int64Val(gconstant.ToInt(idx.Const)) + if !ok { + cs.addError(e.Pos(), fmt.Sprintf("constant %s cannot be used as an index", idx.Const.String())) + return nil, nil, nil, false + } + if v < 0 || int(v) >= length { + cs.addError(e.Pos(), fmt.Sprintf("index out of range: %d", v)) + return nil, nil, nil, false + } + } + var typ shaderir.Type switch t.Main { case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4: diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 744e1f1e6..0470f9e06 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -4404,3 +4404,81 @@ func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { } } } + +func TestSyntaxArrayOutOfBounds(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := [0]int{}; _ = a[-1]", err: true}, + {stmt: "a := [0]int{}; _ = a[0]", err: true}, + {stmt: "a := [0]int{}; _ = a[1]", err: true}, + {stmt: "a := [0]int{}; _ = a[2]", err: true}, + {stmt: "a := [0]int{}; _ = a[3]", err: true}, + {stmt: "a := [0]int{}; b := -1; _ = a[b]", err: false}, + {stmt: "a := [0]int{}; b := 0; _ = a[b]", err: false}, + {stmt: "a := [0]int{}; b := 1; _ = a[b]", err: false}, + {stmt: "a := [0]int{}; b := 2; _ = a[b]", err: false}, + {stmt: "a := [0]int{}; b := 3; _ = a[b]", err: false}, + + {stmt: "a := [1]int{}; _ = a[-1]", err: true}, + {stmt: "a := [1]int{}; _ = a[0]", err: false}, + {stmt: "a := [1]int{}; _ = a[1]", err: true}, + {stmt: "a := [1]int{}; _ = a[2]", err: true}, + {stmt: "a := [1]int{}; _ = a[3]", err: true}, + {stmt: "a := [1]int{}; b := -1; _ = a[b]", err: false}, + {stmt: "a := [1]int{}; b := 0; _ = a[b]", err: false}, + {stmt: "a := [1]int{}; b := 1; _ = a[b]", err: false}, + {stmt: "a := [1]int{}; b := 2; _ = a[b]", err: false}, + {stmt: "a := [1]int{}; b := 3; _ = a[b]", err: false}, + + {stmt: "a := [2]int{}; _ = a[-1]", err: true}, + {stmt: "a := [2]int{}; _ = a[0]", err: false}, + {stmt: "a := [2]int{}; _ = a[1]", err: false}, + {stmt: "a := [2]int{}; _ = a[2]", err: true}, + {stmt: "a := [2]int{}; _ = a[3]", err: true}, + {stmt: "a := [2]int{}; b := -1; _ = a[b]", err: false}, + {stmt: "a := [2]int{}; b := 0; _ = a[b]", err: false}, + {stmt: "a := [2]int{}; b := 1; _ = a[b]", err: false}, + {stmt: "a := [2]int{}; b := 2; _ = a[b]", err: false}, + {stmt: "a := [2]int{}; b := 3; _ = a[b]", err: false}, + + {stmt: "a := vec2(0); _ = a[-1]", err: true}, + {stmt: "a := vec2(0); _ = a[0]", err: false}, + {stmt: "a := vec2(0); _ = a[1]", err: false}, + {stmt: "a := vec2(0); _ = a[2]", err: true}, + {stmt: "a := vec2(0); _ = a[3]", err: true}, + {stmt: "a := vec2(0); b := -1; _ = a[b]", err: false}, + {stmt: "a := vec2(0); b := 0; _ = a[b]", err: false}, + {stmt: "a := vec2(0); b := 1; _ = a[b]", err: false}, + {stmt: "a := vec2(0); b := 2; _ = a[b]", err: false}, + {stmt: "a := vec2(0); b := 3; _ = a[b]", err: false}, + + {stmt: "a := mat3(0); _ = a[-1]", err: true}, + {stmt: "a := mat3(0); _ = a[0]", err: false}, + {stmt: "a := mat3(0); _ = a[1]", err: false}, + {stmt: "a := mat3(0); _ = a[2]", err: false}, + {stmt: "a := mat3(0); _ = a[3]", err: true}, + {stmt: "a := mat3(0); b := -1; _ = a[b]", err: false}, + {stmt: "a := mat3(0); b := 0; _ = a[b]", err: false}, + {stmt: "a := mat3(0); b := 1; _ = a[b]", err: false}, + {stmt: "a := mat3(0); b := 2; _ = a[b]", err: false}, + {stmt: "a := mat3(0); b := 3; _ = a[b]", err: false}, + } + + for _, c := range cases { + stmt := c.stmt + src := fmt.Sprintf(`package main + +func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + %s + return dstPos +}`, stmt) + _, err := compileToIR([]byte(src)) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", stmt, err) + } + } +} diff --git a/internal/shaderir/type.go b/internal/shaderir/type.go index 8601b5842..f5a9d3cee 100644 --- a/internal/shaderir/type.go +++ b/internal/shaderir/type.go @@ -161,6 +161,19 @@ func (t Type) IsMatrix() bool { return false } +func (t Type) MatrixSize() int { + switch t.Main { + case Mat2: + return 2 + case Mat3: + return 3 + case Mat4: + return 4 + default: + return -1 + } +} + type BasicType int const (