internal/shader: stricter const type check

This commit is contained in:
Hajime Hoshi 2022-11-20 15:36:07 +09:00
parent 685a6acb05
commit 5d8216def3
4 changed files with 103 additions and 16 deletions

View File

@ -731,7 +731,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
}
}
if !canAssign(&args[i], &p, &argts[i]) {
if !canAssign(&p, &argts[i], args[i].Const) {
cs.addError(e.Pos(), fmt.Sprintf("cannot use type %s as type %s in argument", argts[i].String(), p.String()))
return nil, nil, nil, false
}

View File

@ -507,7 +507,7 @@ func (s *compileState) parseVariable(block *block, fname string, vs *ast.ValueSp
}
for i, rt := range rts {
if !canAssign(&es[i], &t, &rt) {
if !canAssign(&t, &rt, es[i].Const) {
s.addError(vs.Pos(), fmt.Sprintf("cannot use type %s as type %s in variable declaration", rt.String(), t.String()))
}
}
@ -545,7 +545,7 @@ func (s *compileState) parseVariable(block *block, fname string, vs *ast.ValueSp
t = inittypes[i]
}
if !canAssign(&initexprs[i], &t, &inittypes[i]) {
if !canAssign(&t, &inittypes[i], initexprs[i].Const) {
s.addError(vs.Pos(), fmt.Sprintf("cannot use type %s as type %s in variable declaration", inittypes[i].String(), t.String()))
}
@ -617,10 +617,26 @@ func (s *compileState) parseConstant(block *block, fname string, vs *ast.ValueSp
s.addError(vs.Pos(), fmt.Sprintf("constant expression must be a number but not: %s", n))
return nil, false
}
if !t.Equal(&shaderir.Type{}) && !canAssign(&t, &ts[0], es[0].Const) {
s.addError(vs.Pos(), fmt.Sprintf("cannot use %v as %s value in constant declaration", es[0].Const, t.String()))
return nil, false
}
constType := es[0].ConstType
switch t.Main {
case shaderir.Bool:
constType = shaderir.ConstTypeBool
case shaderir.Int:
constType = shaderir.ConstTypeInt
case shaderir.Float:
constType = shaderir.ConstTypeFloat
}
cs = append(cs, constant{
name: name,
typ: t,
ctyp: es[0].ConstType,
ctyp: constType,
value: es[0].Const,
})
}

View File

@ -667,7 +667,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
}
for i := range lts {
if !canAssign(&r[i], &lts[i], &rts[i]) {
if !canAssign(&lts[i], &rts[i], r[i].Const) {
cs.addError(pos, fmt.Sprintf("cannot use type %s as type %s in variable declaration", rts[i].String(), lts[i].String()))
return nil, false
}
@ -755,7 +755,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
allblank = false
for i, lt := range lts {
if !canAssign(&rhsExprs[i], &lt, &rhsTypes[i]) {
if !canAssign(&lt, &rhsTypes[i], rhsExprs[i].Const) {
cs.addError(pos, fmt.Sprintf("cannot use type %s as type %s in variable declaration", rhsTypes[i].String(), lt.String()))
return nil, false
}
@ -789,19 +789,23 @@ func toDefaultType(v gconstant.Value) shaderir.Type {
return shaderir.Type{}
}
func canAssign(re *shaderir.Expr, lt *shaderir.Type, rt *shaderir.Type) bool {
func canAssign(lt *shaderir.Type, rt *shaderir.Type, rc gconstant.Value) bool {
if lt.Equal(rt) {
return true
}
if re.Const != nil {
switch lt.Main {
case shaderir.Bool:
return re.Const.Kind() == gconstant.Bool
case shaderir.Int:
return canTruncateToInteger(re.Const)
case shaderir.Float:
return gconstant.ToFloat(re.Const).Kind() != gconstant.Unknown
}
if rc == nil {
return false
}
switch lt.Main {
case shaderir.Bool:
return rc.Kind() == gconstant.Bool
case shaderir.Int:
return canTruncateToInteger(rc)
case shaderir.Float:
return gconstant.ToFloat(rc).Kind() != gconstant.Unknown
}
return false
}

View File

@ -1127,6 +1127,10 @@ func TestSyntaxOperatorMultiplyAssign(t *testing.T) {
{stmt: "a := 1.0; a *= 2.0", err: false},
{stmt: "const c = 2; a := 1.0; a *= c", err: false},
{stmt: "const c = 2.0; a := 1.0; a *= c", err: false},
{stmt: "const c int = 2; a := 1.0; a *= c", err: true},
{stmt: "const c int = 2.0; a := 1.0; a *= c", err: true},
{stmt: "const c float = 2; a := 1.0; a *= c", err: false},
{stmt: "const c float = 2.0; a := 1.0; a *= c", err: false},
{stmt: "a := 1.0; a *= int(2)", err: true},
{stmt: "a := 1.0; a *= vec2(2)", err: true},
{stmt: "a := 1.0; a *= vec3(2)", err: true},
@ -1134,10 +1138,15 @@ func TestSyntaxOperatorMultiplyAssign(t *testing.T) {
{stmt: "a := 1.0; a *= mat2(2)", err: true},
{stmt: "a := 1.0; a *= mat3(2)", err: true},
{stmt: "a := 1.0; a *= mat4(2)", err: true},
{stmt: "a := vec2(1); a *= 2", err: false},
{stmt: "a := vec2(1); a *= 2.0", err: false},
{stmt: "const c = 2; a := vec2(1); a *= c", err: false},
{stmt: "const c = 2.0; a := vec2(1); a *= c", err: false},
{stmt: "const c int = 2; a := vec2(1); a *= c", err: true},
{stmt: "const c int = 2.0; a := vec2(1); a *= c", err: true},
{stmt: "const c float = 2; a := vec2(1); a *= c", err: false},
{stmt: "const c float = 2.0; a := vec2(1); a *= c", err: false},
{stmt: "a := vec2(1); a /= 2.0", err: false},
{stmt: "a := vec2(1); a += 2.0", err: false},
{stmt: "a := vec2(1); a *= int(2)", err: true},
@ -1152,10 +1161,15 @@ func TestSyntaxOperatorMultiplyAssign(t *testing.T) {
{stmt: "a := vec2(1); a /= mat2(2)", err: true},
{stmt: "a := vec2(1); a *= mat3(2)", err: true},
{stmt: "a := vec2(1); a *= mat4(2)", err: true},
{stmt: "a := mat2(1); a *= 2", err: false},
{stmt: "a := mat2(1); a *= 2.0", err: false},
{stmt: "const c = 2; a := mat2(1); a *= c", err: false},
{stmt: "const c = 2.0; a := mat2(1); a *= c", err: false},
{stmt: "const c int = 2; a := mat2(1); a *= c", err: true},
{stmt: "const c int = 2.0; a := mat2(1); a *= c", err: true},
{stmt: "const c float = 2; a := mat2(1); a *= c", err: false},
{stmt: "const c float = 2.0; a := mat2(1); a *= c", err: false},
{stmt: "a := mat2(1); a /= 2.0", err: false},
{stmt: "a := mat2(1); a += 2.0", err: true},
{stmt: "a := mat2(1); a *= int(2)", err: true},
@ -2465,3 +2479,56 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
}
}
}
func TestConstType(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "const a = false", err: false},
{stmt: "const a bool = false", err: false},
{stmt: "const a int = false", err: true},
{stmt: "const a float = false", err: true},
{stmt: "const a vec2 = false", err: true},
{stmt: "const a = 1", err: false},
{stmt: "const a bool = 1", err: true},
{stmt: "const a int = 1", err: false},
{stmt: "const a float = 1", err: false},
{stmt: "const a vec2 = 1", err: true},
{stmt: "const a = 1.0", err: false},
{stmt: "const a bool = 1.0", err: true},
{stmt: "const a int = 1.0", err: false},
{stmt: "const a float = 1.0", err: false},
{stmt: "const a vec2 = 1.0", err: true},
{stmt: "const a = 1.1", err: false},
{stmt: "const a bool = 1.1", err: true},
{stmt: "const a int = 1.1", err: true},
{stmt: "const a float = 1.1", err: false},
{stmt: "const a vec2 = 1.1", err: true},
{stmt: "const a = vec2(0)", err: true},
{stmt: "const a bool = vec2(0)", err: true},
{stmt: "const a int = vec2(0)", err: true},
{stmt: "const a float = vec2(0)", err: true},
{stmt: "const a vec2 = vec2(0)", err: true},
}
for _, c := range cases {
stmt := c.stmt
src := fmt.Sprintf(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
%s
return position
}`, stmt)
_, err := compileToIR([]byte(src))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", stmt, err)
}
}
}