From be2123f7fdd5d49a783b2f72761e70cf07afcd35 Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Fri, 28 Jul 2023 00:54:36 +0900 Subject: [PATCH] Revert "internal/shader: bug fix: stricter type checks for the built-in functions" This reverts commit 287545b02ae25ce0b0bc14ff3e4d55ce40721800. Reason: test failures Updates #2712 --- internal/shader/expr.go | 20 ++-- internal/shader/syntax_test.go | 91 +++++--------- internal/shader/type.go | 211 ++++++--------------------------- internal/shaderir/check.go | 18 +-- internal/shaderir/type.go | 4 +- 5 files changed, 80 insertions(+), 264 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index 5ca5f38c2..16dd200cc 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -291,20 +291,14 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } case shaderir.IntF: if len(args) == 1 && args[0].Const != nil { - var v int64 - switch args[0].Const.Kind() { - case gconstant.Int: - v, _ = gconstant.Int64Val(args[0].Const) - case gconstant.Float: - fv, _ := gconstant.Float64Val(args[0].Const) - v = int64(fv) - default: - panic("not reached") + if !canTruncateToInteger(args[0].Const) { + cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type int", args[0].Const.String())) + return nil, nil, nil, false } return []shaderir.Expr{ { Type: shaderir.NumberExpr, - Const: gconstant.MakeInt64(v), + Const: gconstant.ToInt(args[0].Const), ConstType: shaderir.ConstTypeInt, }, }, []shaderir.Type{{Main: shaderir.Int}}, stmts, true @@ -363,19 +357,19 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } t = shaderir.Type{Main: shaderir.Vec4} case shaderir.IVec2F: - if err := checkArgsForIVec2BuiltinFunc(args, argts); err != nil { + if err := checkArgsForVec2BuiltinFunc(args, argts); err != nil { cs.addError(e.Pos(), err.Error()) return nil, nil, nil, false } t = shaderir.Type{Main: shaderir.IVec2} case shaderir.IVec3F: - if err := checkArgsForIVec3BuiltinFunc(args, argts); err != nil { + if err := checkArgsForVec3BuiltinFunc(args, argts); err != nil { cs.addError(e.Pos(), err.Error()) return nil, nil, nil, false } t = shaderir.Type{Main: shaderir.IVec3} case shaderir.IVec4F: - if err := checkArgsForIVec4BuiltinFunc(args, argts); err != nil { + if err := checkArgsForVec4BuiltinFunc(args, argts); err != nil { cs.addError(e.Pos(), err.Error()) return nil, nil, nil, false } diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index fa94159fe..11e76ce4a 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1653,8 +1653,8 @@ func TestSyntaxConstructorFuncType(t *testing.T) { {stmt: "i := 1.0; a := vec3(i, i, i); _ = a", err: false}, {stmt: "a := vec3(vec2(1), 1); _ = a", err: false}, {stmt: "a := vec3(1, vec2(1)); _ = a", err: false}, - {stmt: "a := vec3(ivec2(1), 1); _ = a", err: true}, - {stmt: "a := vec3(1, ivec2(1)); _ = a", err: true}, + {stmt: "a := vec3(ivec2(1), 1); _ = a", err: false}, + {stmt: "a := vec3(1, ivec2(1)); _ = a", err: false}, {stmt: "a := vec3(vec3(1), 1); _ = a", err: true}, {stmt: "a := vec3(1, vec3(1)); _ = a", err: true}, {stmt: "a := vec3(vec3(1), vec3(1), vec3(1)); _ = a", err: true}, @@ -1678,15 +1678,15 @@ func TestSyntaxConstructorFuncType(t *testing.T) { {stmt: "i := 1.0; a := vec4(i, i, i, i); _ = a", err: false}, {stmt: "a := vec4(vec2(1), 1, 1); _ = a", err: false}, {stmt: "a := vec4(1, vec2(1), 1); _ = a", err: false}, - {stmt: "a := vec4(ivec2(1), 1, 1); _ = a", err: true}, - {stmt: "a := vec4(1, ivec2(1), 1); _ = a", err: true}, + {stmt: "a := vec4(ivec2(1), 1, 1); _ = a", err: false}, + {stmt: "a := vec4(1, ivec2(1), 1); _ = a", err: false}, {stmt: "a := vec4(1, 1, vec2(1)); _ = a", err: false}, {stmt: "a := vec4(vec2(1), vec2(1)); _ = a", err: false}, - {stmt: "a := vec4(ivec2(1), ivec2(1)); _ = a", err: true}, + {stmt: "a := vec4(ivec2(1), ivec2(1)); _ = a", err: false}, {stmt: "a := vec4(vec3(1), 1); _ = a", err: false}, {stmt: "a := vec4(1, vec3(1)); _ = a", err: false}, - {stmt: "a := vec4(ivec3(1), 1); _ = a", err: true}, - {stmt: "a := vec4(1, ivec3(1)); _ = a", err: true}, + {stmt: "a := vec4(ivec3(1), 1); _ = a", err: false}, + {stmt: "a := vec4(1, ivec3(1)); _ = a", err: false}, {stmt: "a := vec4(vec4(1), 1); _ = a", err: true}, {stmt: "a := vec4(1, vec4(1)); _ = a", err: true}, {stmt: "a := vec4(vec4(1), vec4(1), vec4(1), vec4(1)); _ = a", err: true}, @@ -1695,7 +1695,7 @@ func TestSyntaxConstructorFuncType(t *testing.T) { {stmt: "a := ivec2(1); _ = a", err: false}, {stmt: "a := ivec2(1.0); _ = a", err: false}, {stmt: "i := 1; a := ivec2(i); _ = a", err: false}, - {stmt: "i := 1.0; a := ivec2(i); _ = a", err: true}, + {stmt: "i := 1.0; a := ivec2(i); _ = a", err: false}, {stmt: "a := ivec2(vec2(1)); _ = a", err: false}, {stmt: "a := ivec2(vec3(1)); _ = a", err: true}, {stmt: "a := ivec2(ivec2(1)); _ = a", err: false}, @@ -1704,7 +1704,7 @@ func TestSyntaxConstructorFuncType(t *testing.T) { {stmt: "a := ivec2(1, 1); _ = a", err: false}, {stmt: "a := ivec2(1.0, 1.0); _ = a", err: false}, {stmt: "i := 1; a := ivec2(i, i); _ = a", err: false}, - {stmt: "i := 1.0; a := ivec2(i, i); _ = a", err: true}, + {stmt: "i := 1.0; a := ivec2(i, i); _ = a", err: false}, {stmt: "a := ivec2(vec2(1), 1); _ = a", err: true}, {stmt: "a := ivec2(1, vec2(1)); _ = a", err: true}, {stmt: "a := ivec2(ivec2(1), 1); _ = a", err: true}, @@ -1714,9 +1714,9 @@ func TestSyntaxConstructorFuncType(t *testing.T) { {stmt: "a := ivec3(1); _ = a", err: false}, {stmt: "a := ivec3(1.0); _ = a", err: false}, - {stmt: "a := ivec3(1.1); _ = a", err: true}, + {stmt: "a := ivec3(1.1); _ = a", err: false}, {stmt: "i := 1; a := ivec3(i); _ = a", err: false}, - {stmt: "i := 1.0; a := ivec3(i); _ = a", err: true}, + {stmt: "i := 1.0; a := ivec3(i); _ = a", err: false}, {stmt: "a := ivec3(vec3(1)); _ = a", err: false}, {stmt: "a := ivec3(vec2(1)); _ = a", err: true}, {stmt: "a := ivec3(vec4(1)); _ = a", err: true}, @@ -1726,11 +1726,11 @@ func TestSyntaxConstructorFuncType(t *testing.T) { {stmt: "a := ivec3(1, 1, 1); _ = a", err: false}, {stmt: "a := ivec3(1.0, 1.0, 1.0); _ = a", err: false}, - {stmt: "a := ivec3(1.1, 1.1, 1.1); _ = a", err: true}, + {stmt: "a := ivec3(1.1, 1.1, 1.1); _ = a", err: false}, {stmt: "i := 1; a := ivec3(i, i, i); _ = a", err: false}, - {stmt: "i := 1.0; a := ivec3(i, i, i); _ = a", err: true}, - {stmt: "a := ivec3(vec2(1), 1); _ = a", err: true}, - {stmt: "a := ivec3(1, vec2(1)); _ = a", err: true}, + {stmt: "i := 1.0; a := ivec3(i, i, i); _ = a", err: false}, + {stmt: "a := ivec3(vec2(1), 1); _ = a", err: false}, + {stmt: "a := ivec3(1, vec2(1)); _ = a", err: false}, {stmt: "a := ivec3(ivec2(1), 1); _ = a", err: false}, {stmt: "a := ivec3(1, ivec2(1)); _ = a", err: false}, {stmt: "a := ivec3(vec3(1), 1); _ = a", err: true}, @@ -1741,7 +1741,7 @@ func TestSyntaxConstructorFuncType(t *testing.T) { {stmt: "a := ivec4(1); _ = a", err: false}, {stmt: "a := ivec4(1.0); _ = a", err: false}, {stmt: "i := 1; a := ivec4(i); _ = a", err: false}, - {stmt: "i := 1.0; a := ivec4(i); _ = a", err: true}, + {stmt: "i := 1.0; a := ivec4(i); _ = a", err: false}, {stmt: "a := ivec4(vec4(1)); _ = a", err: false}, {stmt: "a := ivec4(vec2(1)); _ = a", err: true}, {stmt: "a := ivec4(vec3(1)); _ = a", err: true}, @@ -1751,19 +1751,19 @@ func TestSyntaxConstructorFuncType(t *testing.T) { {stmt: "a := ivec4(1, 1, 1, 1); _ = a", err: false}, {stmt: "a := ivec4(1.0, 1.0, 1.0, 1.0); _ = a", err: false}, - {stmt: "a := ivec4(1.1, 1.1, 1.1, 1.1); _ = a", err: true}, + {stmt: "a := ivec4(1.1, 1.1, 1.1, 1.1); _ = a", err: false}, {stmt: "i := 1; a := ivec4(i, i, i, i); _ = a", err: false}, - {stmt: "i := 1.0; a := ivec4(i, i, i, i); _ = a", err: true}, - {stmt: "a := ivec4(vec2(1), 1, 1); _ = a", err: true}, - {stmt: "a := ivec4(1, vec2(1), 1); _ = a", err: true}, - {stmt: "a := ivec4(1, 1, vec2(1)); _ = a", err: true}, + {stmt: "i := 1.0; a := ivec4(i, i, i, i); _ = a", err: false}, + {stmt: "a := ivec4(vec2(1), 1, 1); _ = a", err: false}, + {stmt: "a := ivec4(1, vec2(1), 1); _ = a", err: false}, + {stmt: "a := ivec4(1, 1, vec2(1)); _ = a", err: false}, {stmt: "a := ivec4(ivec2(1), 1, 1); _ = a", err: false}, {stmt: "a := ivec4(1, ivec2(1), 1); _ = a", err: false}, {stmt: "a := ivec4(1, 1, ivec2(1)); _ = a", err: false}, - {stmt: "a := ivec4(vec2(1), vec2(1)); _ = a", err: true}, + {stmt: "a := ivec4(vec2(1), vec2(1)); _ = a", err: false}, {stmt: "a := ivec4(ivec2(1), ivec2(1)); _ = a", err: false}, - {stmt: "a := ivec4(vec3(1), 1); _ = a", err: true}, - {stmt: "a := ivec4(1, vec3(1)); _ = a", err: true}, + {stmt: "a := ivec4(vec3(1), 1); _ = a", err: false}, + {stmt: "a := ivec4(1, vec3(1)); _ = a", err: false}, {stmt: "a := ivec4(ivec3(1), 1); _ = a", err: false}, {stmt: "a := ivec4(1, ivec3(1)); _ = a", err: false}, {stmt: "a := ivec4(vec4(1), 1); _ = a", err: true}, @@ -1782,7 +1782,7 @@ func TestSyntaxConstructorFuncType(t *testing.T) { {stmt: "a := mat2(mat4(1)); _ = a", err: true}, {stmt: "a := mat2(vec2(1), vec2(1)); _ = a", err: false}, - {stmt: "a := mat2(ivec2(1), ivec2(1)); _ = a", err: true}, + {stmt: "a := mat2(ivec2(1), ivec2(1)); _ = a", err: false}, {stmt: "a := mat2(1, 1); _ = a", err: true}, {stmt: "a := mat2(1, vec2(1)); _ = a", err: true}, {stmt: "a := mat2(vec2(1), vec3(1)); _ = a", err: true}, @@ -1810,7 +1810,7 @@ func TestSyntaxConstructorFuncType(t *testing.T) { {stmt: "a := mat3(mat4(1)); _ = a", err: true}, {stmt: "a := mat3(vec3(1), vec3(1), vec3(1)); _ = a", err: false}, - {stmt: "a := mat3(ivec3(1), ivec3(1), ivec3(1)); _ = a", err: true}, + {stmt: "a := mat3(ivec3(1), ivec3(1), ivec3(1)); _ = a", err: false}, {stmt: "a := mat3(1, 1, 1); _ = a", err: true}, {stmt: "a := mat3(1, 1, vec3(1)); _ = a", err: true}, {stmt: "a := mat3(vec3(1), vec3(1), vec4(1)); _ = a", err: true}, @@ -1838,7 +1838,7 @@ func TestSyntaxConstructorFuncType(t *testing.T) { {stmt: "a := mat4(mat3(1)); _ = a", err: true}, {stmt: "a := mat4(vec4(1), vec4(1), vec4(1), vec4(1)); _ = a", err: false}, - {stmt: "a := mat4(ivec4(1), ivec4(1), ivec4(1), ivec4(1)); _ = a", err: true}, + {stmt: "a := mat4(ivec4(1), ivec4(1), ivec4(1), ivec4(1)); _ = a", err: false}, {stmt: "a := mat4(1, 1, 1, 1); _ = a", err: true}, {stmt: "a := mat4(1, 1, 1, vec4(1)); _ = a", err: true}, {stmt: "a := mat4(vec4(1), vec4(1), vec4(1), vec2(1)); _ = a", err: true}, @@ -3341,40 +3341,3 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { } } } - -// Issue #2712 -func TestSyntaxCast(t *testing.T) { - cases := []struct { - stmt string - err bool - }{ - {stmt: "a := int(1); _ = a", err: false}, - {stmt: "a := int(1.0); _ = a", err: false}, - {stmt: "a := int(1.1); _ = a", err: false}, - {stmt: "a := float(1); _ = a", err: false}, - {stmt: "a := float(1.0); _ = a", err: false}, - {stmt: "a := float(1.1); _ = a", err: false}, - {stmt: "a := 1; _ = int(a)", err: false}, - {stmt: "a := 1.0; _ = int(a)", err: false}, - {stmt: "a := 1.1; _ = int(a)", err: false}, - {stmt: "a := 1; _ = float(a)", err: false}, - {stmt: "a := 1.0; _ = float(a)", err: false}, - {stmt: "a := 1.1; _ = float(a)", err: false}, - } - - for _, c := range cases { - stmt := c.stmt - src := fmt.Sprintf(`package main - -func Fragment(position vec4, texCoord vec2, color vec4) vec4 { - %s - return position -}`, stmt) - _, err := compileToIR([]byte(src)) - if err == nil && c.err { - t.Errorf("%s must return an error but does not", stmt) - } else if err != nil && !c.err { - t.Errorf("%s must not return nil but returned %v", stmt, err) - } - } -} diff --git a/internal/shader/type.go b/internal/shader/type.go index 6df45ac2b..942dcacb5 100644 --- a/internal/shader/type.go +++ b/internal/shader/type.go @@ -106,17 +106,13 @@ func (cs *compileState) parseType(block *block, fname string, expr ast.Expr) (sh } } -func canBeFloat(expr shaderir.Expr, t shaderir.Type) bool { - if expr.Const != nil { - if t.Main == shaderir.Float { - return true - } - if canTruncateToFloat(expr.Const) { - return true - } - return false +func canBeFloatImplicitly(expr shaderir.Expr, t shaderir.Type) bool { + if expr.Const != nil && canTruncateToFloat(expr.Const) { + return true } + // canBeFloatImplicitly is used for a cast-like functions like float() or vec2(). + // A non-constant integer value is acceptable. if t.Main == shaderir.Int { return true } @@ -126,23 +122,6 @@ func canBeFloat(expr shaderir.Expr, t shaderir.Type) bool { return false } -func canBeInt(expr shaderir.Expr, t shaderir.Type) bool { - if expr.Const != nil { - if t.Main == shaderir.Float { - return true - } - if canTruncateToInteger(expr.Const) { - return true - } - return false - } - - if t.Main == shaderir.Int { - return true - } - return false -} - func checkArgsForBoolBuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error { if len(args) != len(argts) { return fmt.Errorf("the number of arguments and types doesn't match: %d vs %d", len(args), len(argts)) @@ -171,7 +150,7 @@ func checkArgsForIntBuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) err if argts[0].Main == shaderir.Int || argts[0].Main == shaderir.Float { return nil } - if args[0].Const != nil && (args[0].Const.Kind() == gconstant.Float || args[0].Const.Kind() == gconstant.Int) { + if args[0].Const != nil && canTruncateToInteger(args[0].Const) { return nil } return fmt.Errorf("invalid arguments for int: (%s)", argts[0].String()) @@ -185,10 +164,7 @@ func checkArgsForFloatBuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) e if len(args) != 1 { return fmt.Errorf("number of float's arguments must be 1 but %d", len(args)) } - if argts[0].Main == shaderir.Int || argts[0].Main == shaderir.Float { - return nil - } - if args[0].Const != nil && (args[0].Const.Kind() == gconstant.Float || args[0].Const.Kind() == gconstant.Int) { + if canBeFloatImplicitly(args[0], argts[0]) { return nil } return fmt.Errorf("invalid arguments for float: (%s)", argts[0].String()) @@ -201,15 +177,14 @@ func checkArgsForVec2BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) er switch len(args) { case 1: - if canBeFloat(args[0], argts[0]) { + if canBeFloatImplicitly(args[0], argts[0]) { return nil } - // Allow any vectors to perform a cast-like function. if argts[0].IsVector() && argts[0].VectorElementCount() == 2 { return nil } case 2: - if canBeFloat(args[0], argts[0]) && canBeFloat(args[1], argts[1]) { + if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) { return nil } default: @@ -230,22 +205,21 @@ func checkArgsForVec3BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) er switch len(args) { case 1: - if canBeFloat(args[0], argts[0]) { + if canBeFloatImplicitly(args[0], argts[0]) { return nil } - // Allow any vectors to perform a cast-like function. if argts[0].IsVector() && argts[0].VectorElementCount() == 3 { return nil } case 2: - if canBeFloat(args[0], argts[0]) && argts[1].IsFloatVector() && argts[1].VectorElementCount() == 2 { + if canBeFloatImplicitly(args[0], argts[0]) && argts[1].IsVector() && argts[1].VectorElementCount() == 2 { return nil } - if argts[0].IsFloatVector() && argts[0].VectorElementCount() == 2 && canBeFloat(args[1], argts[1]) { + if argts[0].IsVector() && argts[0].VectorElementCount() == 2 && canBeFloatImplicitly(args[1], argts[1]) { return nil } case 3: - if canBeFloat(args[0], argts[0]) && canBeFloat(args[1], argts[1]) && canBeFloat(args[2], argts[2]) { + if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) && canBeFloatImplicitly(args[2], argts[2]) { return nil } default: @@ -266,35 +240,34 @@ func checkArgsForVec4BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) er switch len(args) { case 1: - if canBeFloat(args[0], argts[0]) { + if canBeFloatImplicitly(args[0], argts[0]) { return nil } - // Allow any vectors to perform a cast-like function. if argts[0].IsVector() && argts[0].VectorElementCount() == 4 { return nil } case 2: - if canBeFloat(args[0], argts[0]) && argts[1].IsFloatVector() && argts[1].VectorElementCount() == 3 { + if canBeFloatImplicitly(args[0], argts[0]) && argts[1].IsVector() && argts[1].VectorElementCount() == 3 { return nil } - if argts[0].IsFloatVector() && argts[0].VectorElementCount() == 2 && argts[1].IsFloatVector() && argts[1].VectorElementCount() == 2 { + if argts[0].IsVector() && argts[0].VectorElementCount() == 2 && argts[1].IsVector() && argts[1].VectorElementCount() == 2 { return nil } - if argts[0].IsFloatVector() && argts[0].VectorElementCount() == 3 && canBeFloat(args[1], argts[1]) { + if argts[0].IsVector() && argts[0].VectorElementCount() == 3 && canBeFloatImplicitly(args[1], argts[1]) { return nil } case 3: - if canBeFloat(args[0], argts[0]) && canBeFloat(args[1], argts[1]) && argts[2].IsFloatVector() && argts[2].VectorElementCount() == 2 { + if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) && argts[2].IsVector() && argts[2].VectorElementCount() == 2 { return nil } - if canBeFloat(args[0], argts[0]) && argts[1].IsFloatVector() && argts[1].VectorElementCount() == 2 && canBeFloat(args[2], argts[2]) { + if canBeFloatImplicitly(args[0], argts[0]) && argts[1].IsVector() && argts[1].VectorElementCount() == 2 && canBeFloatImplicitly(args[2], argts[2]) { return nil } - if argts[0].IsFloatVector() && argts[0].VectorElementCount() == 2 && canBeFloat(args[1], argts[1]) && canBeFloat(args[2], argts[2]) { + if argts[0].IsVector() && argts[0].VectorElementCount() == 2 && canBeFloatImplicitly(args[1], argts[1]) && canBeFloatImplicitly(args[2], argts[2]) { return nil } case 4: - if canBeFloat(args[0], argts[0]) && canBeFloat(args[1], argts[1]) && canBeFloat(args[2], argts[2]) && canBeFloat(args[3], argts[3]) { + if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) && canBeFloatImplicitly(args[2], argts[2]) && canBeFloatImplicitly(args[3], argts[3]) { return nil } default: @@ -308,120 +281,6 @@ func checkArgsForVec4BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) er return fmt.Errorf("invalid arguments for vec4: (%s)", strings.Join(str, ", ")) } -func checkArgsForIVec2BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error { - if len(args) != len(argts) { - return fmt.Errorf("the number of arguments and types doesn't match: %d vs %d", len(args), len(argts)) - } - - switch len(args) { - case 1: - if canBeInt(args[0], argts[0]) { - return nil - } - // Allow any vectors to perform a cast-like function. - if argts[0].IsVector() && argts[0].VectorElementCount() == 2 { - return nil - } - case 2: - if canBeInt(args[0], argts[0]) && canBeInt(args[1], argts[1]) { - return nil - } - default: - return fmt.Errorf("invalid number of arguments for vec2") - } - - var str []string - for _, t := range argts { - str = append(str, t.String()) - } - return fmt.Errorf("invalid arguments for ivec2: (%s)", strings.Join(str, ", ")) -} - -func checkArgsForIVec3BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error { - if len(args) != len(argts) { - return fmt.Errorf("the number of arguments and types doesn't match: %d vs %d", len(args), len(argts)) - } - - switch len(args) { - case 1: - if canBeInt(args[0], argts[0]) { - return nil - } - // Allow any vectors to perform a cast-like function. - if argts[0].IsVector() && argts[0].VectorElementCount() == 3 { - return nil - } - case 2: - if canBeInt(args[0], argts[0]) && argts[1].IsIntVector() && argts[1].VectorElementCount() == 2 { - return nil - } - if argts[0].IsIntVector() && argts[0].VectorElementCount() == 2 && canBeInt(args[1], argts[1]) { - return nil - } - case 3: - if canBeInt(args[0], argts[0]) && canBeInt(args[1], argts[1]) && canBeInt(args[2], argts[2]) { - return nil - } - default: - return fmt.Errorf("invalid number of arguments for vec3") - } - - var str []string - for _, t := range argts { - str = append(str, t.String()) - } - return fmt.Errorf("invalid arguments for ivec3: (%s)", strings.Join(str, ", ")) -} - -func checkArgsForIVec4BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error { - if len(args) != len(argts) { - return fmt.Errorf("the number of arguments and types doesn't match: %d vs %d", len(args), len(argts)) - } - - switch len(args) { - case 1: - if canBeInt(args[0], argts[0]) { - return nil - } - // Allow any vectors to perform a cast-like function. - if argts[0].IsVector() && argts[0].VectorElementCount() == 4 { - return nil - } - case 2: - if canBeInt(args[0], argts[0]) && argts[1].IsIntVector() && argts[1].VectorElementCount() == 3 { - return nil - } - if argts[0].IsIntVector() && argts[0].VectorElementCount() == 2 && argts[1].IsIntVector() && argts[1].VectorElementCount() == 2 { - return nil - } - if argts[0].IsIntVector() && argts[0].VectorElementCount() == 3 && canBeInt(args[1], argts[1]) { - return nil - } - case 3: - if canBeInt(args[0], argts[0]) && canBeInt(args[1], argts[1]) && argts[2].IsIntVector() && argts[2].VectorElementCount() == 2 { - return nil - } - if canBeInt(args[0], argts[0]) && argts[1].IsIntVector() && argts[1].VectorElementCount() == 2 && canBeInt(args[2], argts[2]) { - return nil - } - if argts[0].IsIntVector() && argts[0].VectorElementCount() == 2 && canBeInt(args[1], argts[1]) && canBeInt(args[2], argts[2]) { - return nil - } - case 4: - if canBeInt(args[0], argts[0]) && canBeInt(args[1], argts[1]) && canBeInt(args[2], argts[2]) && canBeInt(args[3], argts[3]) { - return nil - } - default: - return fmt.Errorf("invalid number of arguments for vec4") - } - - var str []string - for _, t := range argts { - str = append(str, t.String()) - } - return fmt.Errorf("invalid arguments for ivec4: (%s)", strings.Join(str, ", ")) -} - func checkArgsForMat2BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error { if len(args) != len(argts) { return fmt.Errorf("the number of arguments and types doesn't match: %d vs %d", len(args), len(argts)) @@ -429,20 +288,20 @@ func checkArgsForMat2BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) er switch len(args) { case 1: - if canBeFloat(args[0], argts[0]) { + if canBeFloatImplicitly(args[0], argts[0]) { return nil } if argts[0].Main == shaderir.Mat2 { return nil } case 2: - if argts[0].IsFloatVector() && argts[0].VectorElementCount() == 2 && argts[1].IsFloatVector() && argts[1].VectorElementCount() == 2 { + if argts[0].IsVector() && argts[0].VectorElementCount() == 2 && argts[1].IsVector() && argts[1].VectorElementCount() == 2 { return nil } case 4: ok := true for i := range argts { - if !canBeFloat(args[i], argts[i]) { + if !canBeFloatImplicitly(args[i], argts[i]) { ok = false break } @@ -468,22 +327,22 @@ func checkArgsForMat3BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) er switch len(args) { case 1: - if canBeFloat(args[0], argts[0]) { + if canBeFloatImplicitly(args[0], argts[0]) { return nil } if argts[0].Main == shaderir.Mat3 { return nil } case 3: - if argts[0].IsFloatVector() && argts[0].VectorElementCount() == 3 && - argts[1].IsFloatVector() && argts[1].VectorElementCount() == 3 && - argts[2].IsFloatVector() && argts[2].VectorElementCount() == 3 { + if argts[0].IsVector() && argts[0].VectorElementCount() == 3 && + argts[1].IsVector() && argts[1].VectorElementCount() == 3 && + argts[2].IsVector() && argts[2].VectorElementCount() == 3 { return nil } case 9: ok := true for i := range argts { - if !canBeFloat(args[i], argts[i]) { + if !canBeFloatImplicitly(args[i], argts[i]) { ok = false break } @@ -509,23 +368,23 @@ func checkArgsForMat4BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) er switch len(args) { case 1: - if canBeFloat(args[0], argts[0]) { + if canBeFloatImplicitly(args[0], argts[0]) { return nil } if argts[0].Main == shaderir.Mat4 { return nil } case 4: - if argts[0].IsFloatVector() && argts[0].VectorElementCount() == 4 && - argts[1].IsFloatVector() && argts[1].VectorElementCount() == 4 && - argts[2].IsFloatVector() && argts[2].VectorElementCount() == 4 && - argts[3].IsFloatVector() && argts[3].VectorElementCount() == 4 { + if argts[0].IsVector() && argts[0].VectorElementCount() == 4 && + argts[1].IsVector() && argts[1].VectorElementCount() == 4 && + argts[2].IsVector() && argts[2].VectorElementCount() == 4 && + argts[3].IsVector() && argts[3].VectorElementCount() == 4 { return nil } case 16: ok := true for i := range argts { - if !canBeFloat(args[i], argts[i]) { + if !canBeFloatImplicitly(args[i], argts[i]) { ok = false break } diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index a37fbf309..ebb899264 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -39,10 +39,10 @@ func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) ( } if lhst.Main == None { - if (rhst.Main == Float || rhst.IsFloatVector() || rhst.IsMatrix()) && constant.ToFloat(lhs).Kind() != constant.Unknown { + if (rhst.Main == Float || rhst.isFloatVector() || rhst.IsMatrix()) && constant.ToFloat(lhs).Kind() != constant.Unknown { return constant.ToFloat(lhs), rhs, true } - if (rhst.Main == Int || rhst.IsIntVector()) && constant.ToInt(lhs).Kind() != constant.Unknown { + if (rhst.Main == Int || rhst.isIntVector()) && constant.ToInt(lhs).Kind() != constant.Unknown { return constant.ToInt(lhs), rhs, true } if rhst.Main == Bool && lhs.Kind() == constant.Bool { @@ -52,10 +52,10 @@ func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) ( } if rhst.Main == None { - if (lhst.Main == Float || lhst.IsFloatVector() || lhst.IsMatrix()) && constant.ToFloat(rhs).Kind() != constant.Unknown { + if (lhst.Main == Float || lhst.isFloatVector() || lhst.IsMatrix()) && constant.ToFloat(rhs).Kind() != constant.Unknown { return lhs, constant.ToFloat(rhs), true } - if (lhst.Main == Int || lhst.IsIntVector()) && constant.ToInt(rhs).Kind() != constant.Unknown { + if (lhst.Main == Int || lhst.isIntVector()) && constant.ToInt(rhs).Kind() != constant.Unknown { return lhs, constant.ToInt(rhs), true } if lhst.Main == Bool && rhs.Kind() == constant.Bool { @@ -122,7 +122,7 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool { if lhst.Main == IVec4 && rhst.Main == IVec4 { return true } - return (lhst.Main == Int || lhst.IsIntVector()) && rhst.Main == Int + return (lhst.Main == Int || lhst.isIntVector()) && rhst.Main == Int } if lhst.Equal(&rhst) { @@ -164,16 +164,16 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool { // fallback } - if lhst.IsFloatVector() && rhst.Main == Float { + if lhst.isFloatVector() && rhst.Main == Float { return true } - if rhst.IsFloatVector() && lhst.Main == Float { + if rhst.isFloatVector() && lhst.Main == Float { return true } - if lhst.IsIntVector() && rhst.Main == Int { + if lhst.isIntVector() && rhst.Main == Int { return true } - if rhst.IsIntVector() && lhst.Main == Int { + if rhst.isIntVector() && lhst.Main == Int { return true } diff --git a/internal/shaderir/type.go b/internal/shaderir/type.go index 807b5f8de..76733a70b 100644 --- a/internal/shaderir/type.go +++ b/internal/shaderir/type.go @@ -126,7 +126,7 @@ func (t *Type) IsVector() bool { return false } -func (t *Type) IsFloatVector() bool { +func (t *Type) isFloatVector() bool { switch t.Main { case Vec2, Vec3, Vec4: return true @@ -134,7 +134,7 @@ func (t *Type) IsFloatVector() bool { return false } -func (t *Type) IsIntVector() bool { +func (t *Type) isIntVector() bool { switch t.Main { case IVec2, IVec3, IVec4: return true