diff --git a/internal/shader/expr.go b/internal/shader/expr.go index abdf93eaf..273d4cb3e 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -276,26 +276,26 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } // Process the expression as a regular function call. - var t shaderir.Type + var finalType shaderir.Type switch callee.BuiltinFunc { case shaderir.BoolF: if err := checkArgsForBoolBuiltinFunc(args, argts); err != nil { cs.addError(e.Pos(), err.Error()) return nil, nil, nil, false } - t = shaderir.Type{Main: shaderir.Bool} + finalType = shaderir.Type{Main: shaderir.Bool} case shaderir.IntF: if err := checkArgsForIntBuiltinFunc(args, argts); err != nil { cs.addError(e.Pos(), err.Error()) return nil, nil, nil, false } - t = shaderir.Type{Main: shaderir.Int} + finalType = shaderir.Type{Main: shaderir.Int} case shaderir.FloatF: if err := checkArgsForFloatBuiltinFunc(args, argts); err != nil { cs.addError(e.Pos(), err.Error()) return nil, nil, nil, false } - t = shaderir.Type{Main: shaderir.Float} + finalType = shaderir.Type{Main: shaderir.Float} case shaderir.Vec2F: if err := checkArgsForVec2BuiltinFunc(args, argts); err != nil { cs.addError(e.Pos(), err.Error()) @@ -308,7 +308,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar args[i].Const = gconstant.ToFloat(args[i].Const) argts[i] = shaderir.Type{Main: shaderir.Float} } - t = shaderir.Type{Main: shaderir.Vec2} + finalType = shaderir.Type{Main: shaderir.Vec2} case shaderir.Vec3F: if err := checkArgsForVec3BuiltinFunc(args, argts); err != nil { cs.addError(e.Pos(), err.Error()) @@ -321,7 +321,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar args[i].Const = gconstant.ToFloat(args[i].Const) argts[i] = shaderir.Type{Main: shaderir.Float} } - t = shaderir.Type{Main: shaderir.Vec3} + finalType = shaderir.Type{Main: shaderir.Vec3} case shaderir.Vec4F: if err := checkArgsForVec4BuiltinFunc(args, argts); err != nil { cs.addError(e.Pos(), err.Error()) @@ -334,25 +334,25 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar args[i].Const = gconstant.ToFloat(args[i].Const) argts[i] = shaderir.Type{Main: shaderir.Float} } - t = shaderir.Type{Main: shaderir.Vec4} + finalType = shaderir.Type{Main: shaderir.Vec4} case shaderir.IVec2F: if err := checkArgsForIVec2BuiltinFunc(args, argts); err != nil { cs.addError(e.Pos(), err.Error()) return nil, nil, nil, false } - t = shaderir.Type{Main: shaderir.IVec2} + finalType = shaderir.Type{Main: shaderir.IVec2} case shaderir.IVec3F: if err := checkArgsForIVec3BuiltinFunc(args, argts); err != nil { cs.addError(e.Pos(), err.Error()) return nil, nil, nil, false } - t = shaderir.Type{Main: shaderir.IVec3} + finalType = shaderir.Type{Main: shaderir.IVec3} case shaderir.IVec4F: if err := checkArgsForIVec4BuiltinFunc(args, argts); err != nil { cs.addError(e.Pos(), err.Error()) return nil, nil, nil, false } - t = shaderir.Type{Main: shaderir.IVec4} + finalType = shaderir.Type{Main: shaderir.IVec4} case shaderir.Mat2F: if err := checkArgsForMat2BuiltinFunc(args, argts); err != nil { cs.addError(e.Pos(), err.Error()) @@ -365,7 +365,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar args[i].Const = gconstant.ToFloat(args[i].Const) argts[i] = shaderir.Type{Main: shaderir.Float} } - t = shaderir.Type{Main: shaderir.Mat2} + finalType = shaderir.Type{Main: shaderir.Mat2} case shaderir.Mat3F: if err := checkArgsForMat3BuiltinFunc(args, argts); err != nil { cs.addError(e.Pos(), err.Error()) @@ -378,7 +378,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar args[i].Const = gconstant.ToFloat(args[i].Const) argts[i] = shaderir.Type{Main: shaderir.Float} } - t = shaderir.Type{Main: shaderir.Mat3} + finalType = shaderir.Type{Main: shaderir.Mat3} case shaderir.Mat4F: if err := checkArgsForMat4BuiltinFunc(args, argts); err != nil { cs.addError(e.Pos(), err.Error()) @@ -391,7 +391,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar args[i].Const = gconstant.ToFloat(args[i].Const) argts[i] = shaderir.Type{Main: shaderir.Float} } - t = shaderir.Type{Main: shaderir.Mat4} + finalType = shaderir.Type{Main: shaderir.Mat4} case shaderir.TexelAt: if len(args) != 2 { cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 2 but %d", callee.BuiltinFunc, len(args))) @@ -405,7 +405,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as vec2 value in argument to %s", argts[1].String(), callee.BuiltinFunc)) return nil, nil, nil, false } - t = shaderir.Type{Main: shaderir.Vec4} + finalType = shaderir.Type{Main: shaderir.Vec4} case shaderir.DiscardF: if len(args) != 0 { cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 0 but %d", callee.BuiltinFunc, len(args))) @@ -428,13 +428,17 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } switch callee.BuiltinFunc { case shaderir.Clamp: - if kind, allConsts := resolveConstKind(args, argts); allConsts { + if kind, _ := resolveConstKind(args, argts); kind != gconstant.Unknown { switch kind { - case gconstant.Unknown: - cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s, %s, and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String(), argts[2].String())) - return nil, nil, nil, false case gconstant.Int: for i, arg := range args { + if arg.Const == nil { + if argts[i].Main != shaderir.Int { + cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s, %s, and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String(), argts[2].String())) + return nil, nil, nil, false + } + continue + } v := gconstant.ToInt(arg.Const) if v.Kind() == gconstant.Unknown { cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type int", arg.Const.String())) @@ -445,6 +449,13 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } case gconstant.Float: for i, arg := range args { + if arg.Const == nil { + if argts[i].Main != shaderir.Float { + cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s, %s, and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String(), argts[2].String())) + return nil, nil, nil, false + } + continue + } v := gconstant.ToFloat(arg.Const) if v.Kind() == gconstant.Unknown { cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type float", arg.Const.String())) @@ -544,9 +555,9 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar switch callee.BuiltinFunc { case shaderir.Smoothstep: - t = argts[2] + finalType = argts[2] default: - t = argts[0] + finalType = argts[0] } case shaderir.Atan2, shaderir.Pow, shaderir.Mod, shaderir.Min, shaderir.Max, shaderir.Step, shaderir.Distance, shaderir.Dot, shaderir.Cross, shaderir.Reflect: @@ -558,13 +569,17 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar switch callee.BuiltinFunc { case shaderir.Min, shaderir.Max: - if kind, allConsts := resolveConstKind(args, argts); allConsts { + if kind, _ := resolveConstKind(args, argts); kind != gconstant.Unknown { switch kind { - case gconstant.Unknown: - cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String())) - return nil, nil, nil, false case gconstant.Int: for i, arg := range args { + if arg.Const == nil { + if argts[i].Main != shaderir.Int { + cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String())) + return nil, nil, nil, false + } + continue + } v := gconstant.ToInt(arg.Const) if v.Kind() == gconstant.Unknown { cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type int", arg.Const.String())) @@ -575,6 +590,13 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } case gconstant.Float: for i, arg := range args { + if arg.Const == nil { + if argts[i].Main != shaderir.Float { + cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String())) + return nil, nil, nil, false + } + continue + } v := gconstant.ToFloat(arg.Const) if v.Kind() == gconstant.Unknown { cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type float", arg.Const.String())) @@ -662,11 +684,11 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } switch callee.BuiltinFunc { case shaderir.Distance, shaderir.Dot: - t = shaderir.Type{Main: shaderir.Float} + finalType = shaderir.Type{Main: shaderir.Float} case shaderir.Step: - t = argts[1] + finalType = argts[1] default: - t = argts[0] + finalType = argts[0] } default: @@ -711,9 +733,9 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } } if callee.BuiltinFunc == shaderir.Length { - t = shaderir.Type{Main: shaderir.Float} + finalType = shaderir.Type{Main: shaderir.Float} } else { - t = argts[0] + finalType = argts[0] } } return []shaderir.Expr{ @@ -721,7 +743,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar Type: shaderir.Call, Exprs: append([]shaderir.Expr{callee}, args...), }, - }, []shaderir.Type{t}, stmts, true + }, []shaderir.Type{finalType}, stmts, true } if callee.Type != shaderir.FunctionExpr { @@ -1162,8 +1184,24 @@ func resolveConstKind(exprs []shaderir.Expr, ts []shaderir.Type) (kind gconstant panic("not reached") } + allConsts = true for _, expr := range exprs { if expr.Const == nil { + allConsts = false + } + } + + if !allConsts { + for _, t := range ts { + if t.Main == shaderir.None { + continue + } + if t.Main == shaderir.Float { + return gconstant.Float, false + } + if t.Main == shaderir.Int { + return gconstant.Int, false + } return gconstant.Unknown, false } } @@ -1192,17 +1230,15 @@ func resolveConstKind(exprs []shaderir.Expr, ts []shaderir.Type) (kind gconstant } } - if kind == gconstant.Float { - return gconstant.Float, true + if kind != gconstant.Unknown { + return kind, true } // Prefer floats over integers for non-typed constant values. // For example, max(1.0, 1) should return a float value. - if kind == gconstant.Unknown { - for _, expr := range exprs { - if expr.Const.Kind() == gconstant.Float { - return gconstant.Float, true - } + for _, expr := range exprs { + if expr.Const.Kind() == gconstant.Float { + return gconstant.Float, true } } diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 42dd1219c..d62cada9a 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -2347,6 +2347,10 @@ func TestSyntaxBuiltinFuncArgsMinMax(t *testing.T) { {stmt: "a := {{.Func}}(int(1), int(1)); var _ int = a", err: false}, {stmt: "a := {{.Func}}(int(1), float(1)); _ = a", err: true}, {stmt: "a := {{.Func}}(float(1), int(1)); _ = a", err: true}, + {stmt: "x := 1.1; a := {{.Func}}(int(x), 1); _ = a", err: false}, + {stmt: "x := 1; a := {{.Func}}(float(x), 1.1); _ = a", err: false}, + {stmt: "x := 1.1; a := {{.Func}}(1, int(x)); _ = a", err: false}, + {stmt: "x := 1; a := {{.Func}}(1.1, float(x)); _ = a", err: false}, {stmt: "a := {{.Func}}(1, vec2(1)); _ = a", err: true}, {stmt: "a := {{.Func}}(1, vec3(1)); _ = a", err: true}, {stmt: "a := {{.Func}}(1, vec4(1)); _ = a", err: true}, @@ -2547,6 +2551,12 @@ func TestSyntaxBuiltinFuncClampType(t *testing.T) { {stmt: "a := clamp(float(1), 1, 1.0); var _ float = a", err: false}, {stmt: "a := clamp(float(1), 1.1, 1); _ = a", err: false}, {stmt: "a := clamp(float(1), 1, 1.1); _ = a", err: false}, + {stmt: "x := 1.1; a := clamp(int(x), 1, 1); _ = a", err: false}, + {stmt: "x := 1; a := clamp(float(x), 1.1, 1.1); _ = a", err: false}, + {stmt: "x := 1.1; a := clamp(1, int(x), 1); _ = a", err: false}, + {stmt: "x := 1; a := clamp(1.1, float(x), 1.1); _ = a", err: false}, + {stmt: "x := 1.1; a := clamp(1, 1, int(x)); _ = a", err: false}, + {stmt: "x := 1; a := clamp(1.1, 1.1, float(x)); _ = a", err: false}, {stmt: "a := clamp(1.0, 1, 1); var _ float = a", err: false}, {stmt: "a := clamp(1, 1.0, 1); var _ float = a", err: false}, {stmt: "a := clamp(1, 1, 1.0); var _ float = a", err: false},