From dc1df824a5c28ef28d213791c01dcdfdb4f7cbe5 Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Sun, 20 Nov 2022 17:18:24 +0900 Subject: [PATCH] internal/shader: more strict type checks with built-in functions --- internal/shader/expr.go | 20 +++++++++++-- internal/shader/stmt.go | 30 +++++++++++++------ internal/shader/syntax_test.go | 54 ++++++++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 11 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index 2581e1735..f4a964306 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -153,19 +153,21 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar case *ast.BasicLit: switch e.Kind { case token.INT: + // The type is not determined yet. return []shaderir.Expr{ { Type: shaderir.NumberExpr, Const: gconstant.MakeFromLiteral(e.Value, e.Kind, 0), }, - }, []shaderir.Type{{Main: shaderir.Int}}, nil, true + }, []shaderir.Type{{}}, nil, true case token.FLOAT: + // The type is not determined yet. return []shaderir.Expr{ { Type: shaderir.NumberExpr, Const: gconstant.MakeFromLiteral(e.Value, e.Kind, 0), }, - }, []shaderir.Type{{Main: shaderir.Float}}, nil, true + }, []shaderir.Type{{}}, nil, true default: cs.addError(e.Pos(), fmt.Sprintf("literal not implemented: %#v", e)) } @@ -444,6 +446,20 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar ConstType: shaderir.ConstTypeInt, }, }, []shaderir.Type{{Main: shaderir.Int}}, stmts, true + case shaderir.BoolF: + if len(args) == 1 && args[0].Const != nil { + if args[0].Const.Kind() != gconstant.Bool { + cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type bool", args[0].Const.String())) + return nil, nil, nil, false + } + return []shaderir.Expr{ + { + Type: shaderir.NumberExpr, + Const: args[0].Const, + ConstType: shaderir.ConstTypeBool, + }, + }, []shaderir.Type{{Main: shaderir.Bool}}, stmts, true + } case shaderir.IntF: if len(args) == 1 && args[0].Const != nil { if !canTruncateToInteger(args[0].Const) { diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 8774a2884..9bec9e355 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -119,20 +119,28 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP return nil, false } case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: - if (op == shaderir.MatrixMul || op == shaderir.Div) && rts[0].Main == shaderir.Float { - // OK - } else if lts[0].IsVector() && rts[0].Main == shaderir.Float { - // OK + if (op == shaderir.MatrixMul || op == shaderir.Div) && + (rts[0].Main == shaderir.Float || + (rhs[0].Const != nil && + rhs[0].ConstType != shaderir.ConstTypeInt && + gconstant.ToFloat(rhs[0].Const).Kind() != gconstant.Unknown)) { + if rhs[0].Const != nil { + rhs[0].Const = gconstant.ToFloat(rhs[0].Const) + rhs[0].ConstType = shaderir.ConstTypeFloat + } } else if op == shaderir.MatrixMul && ((lts[0].Main == shaderir.Vec2 && rts[0].Main == shaderir.Mat2) || (lts[0].Main == shaderir.Vec3 && rts[0].Main == shaderir.Mat3) || (lts[0].Main == shaderir.Vec4 && rts[0].Main == shaderir.Mat4)) { // OK } else if (op == shaderir.MatrixMul || op == shaderir.ComponentWiseMul || lts[0].IsVector()) && - rhs[0].Const != nil && - rhs[0].ConstType != shaderir.ConstTypeInt && - gconstant.ToFloat(rhs[0].Const).Kind() != gconstant.Unknown { - rhs[0].Const = gconstant.ToFloat(rhs[0].Const) - rhs[0].ConstType = shaderir.ConstTypeFloat + (rts[0].Main == shaderir.Float || + (rhs[0].Const != nil && + rhs[0].ConstType != shaderir.ConstTypeInt && + gconstant.ToFloat(rhs[0].Const).Kind() != gconstant.Unknown)) { + if rhs[0].Const != nil { + rhs[0].Const = gconstant.ToFloat(rhs[0].Const) + rhs[0].ConstType = shaderir.ConstTypeFloat + } } else { cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: mismatched types %s and %s", lts[0].String(), rts[0].String())) return nil, false @@ -798,6 +806,10 @@ func canAssign(lt *shaderir.Type, rt *shaderir.Type, rc gconstant.Value) bool { return false } + if !rt.Equal(&shaderir.Type{}) { + return false + } + switch lt.Main { case shaderir.Bool: return rc.Kind() == gconstant.Bool diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index cba40766b..f90b53852 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -2491,24 +2491,78 @@ func TestConstType(t *testing.T) { {stmt: "const a float = false", err: true}, {stmt: "const a vec2 = false", err: true}, + {stmt: "const a = bool(false)", err: false}, + {stmt: "const a bool = bool(false)", err: false}, + {stmt: "const a int = bool(false)", err: true}, + {stmt: "const a float = bool(false)", err: true}, + {stmt: "const a vec2 = bool(false)", err: true}, + + {stmt: "const a = int(false)", err: true}, + {stmt: "const a bool = int(false)", err: true}, + {stmt: "const a int = int(false)", err: true}, + {stmt: "const a float = int(false)", err: true}, + {stmt: "const a vec2 = int(false)", err: true}, + + {stmt: "const a = float(false)", err: true}, + {stmt: "const a bool = float(false)", err: true}, + {stmt: "const a int = float(false)", err: true}, + {stmt: "const a float = float(false)", err: true}, + {stmt: "const a vec2 = float(false)", err: true}, + {stmt: "const a = 1", err: false}, {stmt: "const a bool = 1", err: true}, {stmt: "const a int = 1", err: false}, {stmt: "const a float = 1", err: false}, {stmt: "const a vec2 = 1", err: true}, + {stmt: "const a = int(1)", err: false}, + {stmt: "const a bool = int(1)", err: true}, + {stmt: "const a int = int(1)", err: false}, + {stmt: "const a float = int(1)", err: true}, + {stmt: "const a vec2 = int(1)", err: true}, + + {stmt: "const a = float(1)", err: false}, + {stmt: "const a bool = float(1)", err: true}, + {stmt: "const a int = float(1)", err: true}, + {stmt: "const a float = float(1)", err: false}, + {stmt: "const a vec2 = float(1)", err: true}, + {stmt: "const a = 1.0", err: false}, {stmt: "const a bool = 1.0", err: true}, {stmt: "const a int = 1.0", err: false}, {stmt: "const a float = 1.0", err: false}, {stmt: "const a vec2 = 1.0", err: true}, + {stmt: "const a = int(1.0)", err: false}, + {stmt: "const a bool = int(1.0)", err: true}, + {stmt: "const a int = int(1.0)", err: false}, + {stmt: "const a float = int(1.0)", err: true}, + {stmt: "const a vec2 = int(1.0)", err: true}, + + {stmt: "const a = float(1.0)", err: false}, + {stmt: "const a bool = float(1.0)", err: true}, + {stmt: "const a int = float(1.0)", err: true}, + {stmt: "const a float = float(1.0)", err: false}, + {stmt: "const a vec2 = float(1.0)", err: true}, + {stmt: "const a = 1.1", err: false}, {stmt: "const a bool = 1.1", err: true}, {stmt: "const a int = 1.1", err: true}, {stmt: "const a float = 1.1", err: false}, {stmt: "const a vec2 = 1.1", err: true}, + {stmt: "const a = int(1.1)", err: true}, + {stmt: "const a bool = int(1.1)", err: true}, + {stmt: "const a int = int(1.1)", err: true}, + {stmt: "const a float = int(1.1)", err: true}, + {stmt: "const a vec2 = int(1.1)", err: true}, + + {stmt: "const a = float(1.1)", err: false}, + {stmt: "const a bool = float(1.1)", err: true}, + {stmt: "const a int = float(1.1)", err: true}, + {stmt: "const a float = float(1.1)", err: false}, + {stmt: "const a vec2 = float(1.1)", err: true}, + {stmt: "const a = vec2(0)", err: true}, {stmt: "const a bool = vec2(0)", err: true}, {stmt: "const a int = vec2(0)", err: true},