internal/shader: add check for out-of-bounds

Closes #3112
This commit is contained in:
Hajime Hoshi 2024-09-21 21:53:03 +09:00
parent eabc697022
commit 6b85d9deb0
3 changed files with 116 additions and 0 deletions

View File

@ -1156,6 +1156,31 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
x := exprs[0]
t := ts[0]
// Check the length only when the index is a constant.
if idx.Const != nil {
var length int
switch {
case t.Main == shaderir.Array:
length = t.Length
case t.IsFloatVector() || t.IsIntVector():
length = t.VectorElementCount()
case t.IsMatrix():
length = t.MatrixSize()
default:
cs.addError(e.Pos(), fmt.Sprintf("index operator cannot be applied to the type %s", t.String()))
return nil, nil, nil, false
}
v, ok := gconstant.Int64Val(gconstant.ToInt(idx.Const))
if !ok {
cs.addError(e.Pos(), fmt.Sprintf("constant %s cannot be used as an index", idx.Const.String()))
return nil, nil, nil, false
}
if v < 0 || int(v) >= length {
cs.addError(e.Pos(), fmt.Sprintf("index out of range: %d", v))
return nil, nil, nil, false
}
}
var typ shaderir.Type
switch t.Main {
case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4:

View File

@ -4404,3 +4404,81 @@ func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
}
}
}
func TestSyntaxArrayOutOfBounds(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "a := [0]int{}; _ = a[-1]", err: true},
{stmt: "a := [0]int{}; _ = a[0]", err: true},
{stmt: "a := [0]int{}; _ = a[1]", err: true},
{stmt: "a := [0]int{}; _ = a[2]", err: true},
{stmt: "a := [0]int{}; _ = a[3]", err: true},
{stmt: "a := [0]int{}; b := -1; _ = a[b]", err: false},
{stmt: "a := [0]int{}; b := 0; _ = a[b]", err: false},
{stmt: "a := [0]int{}; b := 1; _ = a[b]", err: false},
{stmt: "a := [0]int{}; b := 2; _ = a[b]", err: false},
{stmt: "a := [0]int{}; b := 3; _ = a[b]", err: false},
{stmt: "a := [1]int{}; _ = a[-1]", err: true},
{stmt: "a := [1]int{}; _ = a[0]", err: false},
{stmt: "a := [1]int{}; _ = a[1]", err: true},
{stmt: "a := [1]int{}; _ = a[2]", err: true},
{stmt: "a := [1]int{}; _ = a[3]", err: true},
{stmt: "a := [1]int{}; b := -1; _ = a[b]", err: false},
{stmt: "a := [1]int{}; b := 0; _ = a[b]", err: false},
{stmt: "a := [1]int{}; b := 1; _ = a[b]", err: false},
{stmt: "a := [1]int{}; b := 2; _ = a[b]", err: false},
{stmt: "a := [1]int{}; b := 3; _ = a[b]", err: false},
{stmt: "a := [2]int{}; _ = a[-1]", err: true},
{stmt: "a := [2]int{}; _ = a[0]", err: false},
{stmt: "a := [2]int{}; _ = a[1]", err: false},
{stmt: "a := [2]int{}; _ = a[2]", err: true},
{stmt: "a := [2]int{}; _ = a[3]", err: true},
{stmt: "a := [2]int{}; b := -1; _ = a[b]", err: false},
{stmt: "a := [2]int{}; b := 0; _ = a[b]", err: false},
{stmt: "a := [2]int{}; b := 1; _ = a[b]", err: false},
{stmt: "a := [2]int{}; b := 2; _ = a[b]", err: false},
{stmt: "a := [2]int{}; b := 3; _ = a[b]", err: false},
{stmt: "a := vec2(0); _ = a[-1]", err: true},
{stmt: "a := vec2(0); _ = a[0]", err: false},
{stmt: "a := vec2(0); _ = a[1]", err: false},
{stmt: "a := vec2(0); _ = a[2]", err: true},
{stmt: "a := vec2(0); _ = a[3]", err: true},
{stmt: "a := vec2(0); b := -1; _ = a[b]", err: false},
{stmt: "a := vec2(0); b := 0; _ = a[b]", err: false},
{stmt: "a := vec2(0); b := 1; _ = a[b]", err: false},
{stmt: "a := vec2(0); b := 2; _ = a[b]", err: false},
{stmt: "a := vec2(0); b := 3; _ = a[b]", err: false},
{stmt: "a := mat3(0); _ = a[-1]", err: true},
{stmt: "a := mat3(0); _ = a[0]", err: false},
{stmt: "a := mat3(0); _ = a[1]", err: false},
{stmt: "a := mat3(0); _ = a[2]", err: false},
{stmt: "a := mat3(0); _ = a[3]", err: true},
{stmt: "a := mat3(0); b := -1; _ = a[b]", err: false},
{stmt: "a := mat3(0); b := 0; _ = a[b]", err: false},
{stmt: "a := mat3(0); b := 1; _ = a[b]", err: false},
{stmt: "a := mat3(0); b := 2; _ = a[b]", err: false},
{stmt: "a := mat3(0); b := 3; _ = a[b]", err: false},
}
for _, c := range cases {
stmt := c.stmt
src := fmt.Sprintf(`package main
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
%s
return dstPos
}`, stmt)
_, err := compileToIR([]byte(src))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", stmt, err)
}
}
}

View File

@ -161,6 +161,19 @@ func (t Type) IsMatrix() bool {
return false
}
func (t Type) MatrixSize() int {
switch t.Main {
case Mat2:
return 2
case Mat3:
return 3
case Mat4:
return 4
default:
return -1
}
}
type BasicType int
const (