From 5d8216def3715043e5a234d65d55bc851323202e Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Sun, 20 Nov 2022 15:36:07 +0900 Subject: [PATCH] internal/shader: stricter const type check --- internal/shader/expr.go | 2 +- internal/shader/shader.go | 22 +++++++++-- internal/shader/stmt.go | 28 ++++++++------ internal/shader/syntax_test.go | 67 ++++++++++++++++++++++++++++++++++ 4 files changed, 103 insertions(+), 16 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index a2693812f..2581e1735 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -731,7 +731,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } } - if !canAssign(&args[i], &p, &argts[i]) { + if !canAssign(&p, &argts[i], args[i].Const) { cs.addError(e.Pos(), fmt.Sprintf("cannot use type %s as type %s in argument", argts[i].String(), p.String())) return nil, nil, nil, false } diff --git a/internal/shader/shader.go b/internal/shader/shader.go index c083932bf..1494137ca 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -507,7 +507,7 @@ func (s *compileState) parseVariable(block *block, fname string, vs *ast.ValueSp } for i, rt := range rts { - if !canAssign(&es[i], &t, &rt) { + if !canAssign(&t, &rt, es[i].Const) { s.addError(vs.Pos(), fmt.Sprintf("cannot use type %s as type %s in variable declaration", rt.String(), t.String())) } } @@ -545,7 +545,7 @@ func (s *compileState) parseVariable(block *block, fname string, vs *ast.ValueSp t = inittypes[i] } - if !canAssign(&initexprs[i], &t, &inittypes[i]) { + if !canAssign(&t, &inittypes[i], initexprs[i].Const) { s.addError(vs.Pos(), fmt.Sprintf("cannot use type %s as type %s in variable declaration", inittypes[i].String(), t.String())) } @@ -617,10 +617,26 @@ func (s *compileState) parseConstant(block *block, fname string, vs *ast.ValueSp s.addError(vs.Pos(), fmt.Sprintf("constant expression must be a number but not: %s", n)) return nil, false } + + if !t.Equal(&shaderir.Type{}) && !canAssign(&t, &ts[0], es[0].Const) { + s.addError(vs.Pos(), fmt.Sprintf("cannot use %v as %s value in constant declaration", es[0].Const, t.String())) + return nil, false + } + + constType := es[0].ConstType + switch t.Main { + case shaderir.Bool: + constType = shaderir.ConstTypeBool + case shaderir.Int: + constType = shaderir.ConstTypeInt + case shaderir.Float: + constType = shaderir.ConstTypeFloat + } + cs = append(cs, constant{ name: name, typ: t, - ctyp: es[0].ConstType, + ctyp: constType, value: es[0].Const, }) } diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 91910341c..8774a2884 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -667,7 +667,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r } for i := range lts { - if !canAssign(&r[i], <s[i], &rts[i]) { + if !canAssign(<s[i], &rts[i], r[i].Const) { cs.addError(pos, fmt.Sprintf("cannot use type %s as type %s in variable declaration", rts[i].String(), lts[i].String())) return nil, false } @@ -755,7 +755,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r allblank = false for i, lt := range lts { - if !canAssign(&rhsExprs[i], <, &rhsTypes[i]) { + if !canAssign(<, &rhsTypes[i], rhsExprs[i].Const) { cs.addError(pos, fmt.Sprintf("cannot use type %s as type %s in variable declaration", rhsTypes[i].String(), lt.String())) return nil, false } @@ -789,19 +789,23 @@ func toDefaultType(v gconstant.Value) shaderir.Type { return shaderir.Type{} } -func canAssign(re *shaderir.Expr, lt *shaderir.Type, rt *shaderir.Type) bool { +func canAssign(lt *shaderir.Type, rt *shaderir.Type, rc gconstant.Value) bool { if lt.Equal(rt) { return true } - if re.Const != nil { - switch lt.Main { - case shaderir.Bool: - return re.Const.Kind() == gconstant.Bool - case shaderir.Int: - return canTruncateToInteger(re.Const) - case shaderir.Float: - return gconstant.ToFloat(re.Const).Kind() != gconstant.Unknown - } + + if rc == nil { + return false } + + switch lt.Main { + case shaderir.Bool: + return rc.Kind() == gconstant.Bool + case shaderir.Int: + return canTruncateToInteger(rc) + case shaderir.Float: + return gconstant.ToFloat(rc).Kind() != gconstant.Unknown + } + return false } diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index c9b193126..cba40766b 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1127,6 +1127,10 @@ func TestSyntaxOperatorMultiplyAssign(t *testing.T) { {stmt: "a := 1.0; a *= 2.0", err: false}, {stmt: "const c = 2; a := 1.0; a *= c", err: false}, {stmt: "const c = 2.0; a := 1.0; a *= c", err: false}, + {stmt: "const c int = 2; a := 1.0; a *= c", err: true}, + {stmt: "const c int = 2.0; a := 1.0; a *= c", err: true}, + {stmt: "const c float = 2; a := 1.0; a *= c", err: false}, + {stmt: "const c float = 2.0; a := 1.0; a *= c", err: false}, {stmt: "a := 1.0; a *= int(2)", err: true}, {stmt: "a := 1.0; a *= vec2(2)", err: true}, {stmt: "a := 1.0; a *= vec3(2)", err: true}, @@ -1134,10 +1138,15 @@ func TestSyntaxOperatorMultiplyAssign(t *testing.T) { {stmt: "a := 1.0; a *= mat2(2)", err: true}, {stmt: "a := 1.0; a *= mat3(2)", err: true}, {stmt: "a := 1.0; a *= mat4(2)", err: true}, + {stmt: "a := vec2(1); a *= 2", err: false}, {stmt: "a := vec2(1); a *= 2.0", err: false}, {stmt: "const c = 2; a := vec2(1); a *= c", err: false}, {stmt: "const c = 2.0; a := vec2(1); a *= c", err: false}, + {stmt: "const c int = 2; a := vec2(1); a *= c", err: true}, + {stmt: "const c int = 2.0; a := vec2(1); a *= c", err: true}, + {stmt: "const c float = 2; a := vec2(1); a *= c", err: false}, + {stmt: "const c float = 2.0; a := vec2(1); a *= c", err: false}, {stmt: "a := vec2(1); a /= 2.0", err: false}, {stmt: "a := vec2(1); a += 2.0", err: false}, {stmt: "a := vec2(1); a *= int(2)", err: true}, @@ -1152,10 +1161,15 @@ func TestSyntaxOperatorMultiplyAssign(t *testing.T) { {stmt: "a := vec2(1); a /= mat2(2)", err: true}, {stmt: "a := vec2(1); a *= mat3(2)", err: true}, {stmt: "a := vec2(1); a *= mat4(2)", err: true}, + {stmt: "a := mat2(1); a *= 2", err: false}, {stmt: "a := mat2(1); a *= 2.0", err: false}, {stmt: "const c = 2; a := mat2(1); a *= c", err: false}, {stmt: "const c = 2.0; a := mat2(1); a *= c", err: false}, + {stmt: "const c int = 2; a := mat2(1); a *= c", err: true}, + {stmt: "const c int = 2.0; a := mat2(1); a *= c", err: true}, + {stmt: "const c float = 2; a := mat2(1); a *= c", err: false}, + {stmt: "const c float = 2.0; a := mat2(1); a *= c", err: false}, {stmt: "a := mat2(1); a /= 2.0", err: false}, {stmt: "a := mat2(1); a += 2.0", err: true}, {stmt: "a := mat2(1); a *= int(2)", err: true}, @@ -2465,3 +2479,56 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { } } } + +func TestConstType(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "const a = false", err: false}, + {stmt: "const a bool = false", err: false}, + {stmt: "const a int = false", err: true}, + {stmt: "const a float = false", err: true}, + {stmt: "const a vec2 = false", err: true}, + + {stmt: "const a = 1", err: false}, + {stmt: "const a bool = 1", err: true}, + {stmt: "const a int = 1", err: false}, + {stmt: "const a float = 1", err: false}, + {stmt: "const a vec2 = 1", err: true}, + + {stmt: "const a = 1.0", err: false}, + {stmt: "const a bool = 1.0", err: true}, + {stmt: "const a int = 1.0", err: false}, + {stmt: "const a float = 1.0", err: false}, + {stmt: "const a vec2 = 1.0", err: true}, + + {stmt: "const a = 1.1", err: false}, + {stmt: "const a bool = 1.1", err: true}, + {stmt: "const a int = 1.1", err: true}, + {stmt: "const a float = 1.1", err: false}, + {stmt: "const a vec2 = 1.1", err: true}, + + {stmt: "const a = vec2(0)", err: true}, + {stmt: "const a bool = vec2(0)", err: true}, + {stmt: "const a int = vec2(0)", err: true}, + {stmt: "const a float = vec2(0)", err: true}, + {stmt: "const a vec2 = vec2(0)", err: true}, + } + + for _, c := range cases { + stmt := c.stmt + src := fmt.Sprintf(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + %s + return position +}`, stmt) + _, err := compileToIR([]byte(src)) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", stmt, err) + } + } +}