diff --git a/internal/shader/expr.go b/internal/shader/expr.go index 3a4e31bc9..059d5751c 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -78,54 +78,6 @@ func isValidForModOp(lhs, rhs *shaderir.Expr, lhst, rhst shaderir.Type) bool { return false } -func canApplyBinaryOp(lhs, rhs *shaderir.Expr, lhst, rhst shaderir.Type, op shaderir.Op) bool { - if op == shaderir.AndAnd || op == shaderir.OrOr { - return lhst.Main == shaderir.Bool && rhst.Main == shaderir.Bool - } - - switch { - case lhs.Const != nil && rhs.Const != nil: - if canTruncateToFloat(lhs.Const) && canTruncateToFloat(rhs.Const) { - return true - } - if canTruncateToInteger(lhs.Const) && canTruncateToInteger(rhs.Const) { - return true - } - return lhs.Const.Kind() == rhs.Const.Kind() - - case lhs.Const != nil: - if rhst.Main == shaderir.Float { - return canTruncateToFloat(lhs.Const) - } - if rhst.Main == shaderir.Int { - return canTruncateToInteger(lhs.Const) - } - if rhst.Main == shaderir.Bool { - return lhs.Const.Kind() == gconstant.Bool - } - return false - - case rhs.Const != nil: - if lhst.Main == shaderir.Float { - return canTruncateToFloat(rhs.Const) - } - if lhst.Main == shaderir.Int { - return canTruncateToInteger(rhs.Const) - } - if lhst.Main == shaderir.Bool { - return rhs.Const.Kind() == gconstant.Bool - } - return false - } - - // Comparing matrices are forbidden (#2187). - if lhst.IsMatrix() || rhst.IsMatrix() { - return false - } - - return lhst.Equal(&rhst) -} - func goConstantKindString(k gconstant.Kind) string { switch k { case gconstant.Bool: @@ -211,7 +163,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", op)) return nil, nil, nil, false } - if !canApplyBinaryOp(&lhs[0], &rhs[0], lhst, rhst, op2) { + if !shaderir.AreValidTypesForBinaryOp(op2, &lhs[0], &rhs[0], lhst, rhst) { cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) return nil, nil, nil, false } @@ -242,11 +194,28 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } fallthrough default: + op2, ok := shaderir.OpFromToken(op, lhst, rhst) + if !ok { + cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", op)) + return nil, nil, nil, false + } + if !shaderir.AreValidTypesForBinaryOp(op2, &lhs[0], &rhs[0], lhst, rhst) { + cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) + return nil, nil, nil, false + } + v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const) - if v.Kind() == gconstant.Float { + + switch { + case lhst.Main == shaderir.Float || rhst.Main == shaderir.Float: t = shaderir.Type{Main: shaderir.Float} - } else { + case lhst.Main == shaderir.Int || rhst.Main == shaderir.Int: t = shaderir.Type{Main: shaderir.Int} + case lhst.Main == shaderir.Bool || rhst.Main == shaderir.Bool: + t = shaderir.Type{Main: shaderir.Bool} + default: + // If both operands are untyped, keep untyped. + t = shaderir.Type{} } } @@ -267,7 +236,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar var t shaderir.Type switch { case op == shaderir.LessThanOp || op == shaderir.LessThanEqualOp || op == shaderir.GreaterThanOp || op == shaderir.GreaterThanEqualOp || op == shaderir.EqualOp || op == shaderir.NotEqualOp || op == shaderir.VectorEqualOp || op == shaderir.VectorNotEqualOp || op == shaderir.AndAnd || op == shaderir.OrOr: - if !canApplyBinaryOp(&lhs[0], &rhs[0], lhst, rhst, op) { + if !shaderir.AreValidTypesForBinaryOp(op, &lhs[0], &rhs[0], lhst, rhst) { cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) return nil, nil, nil, false } diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 1a558abfc..df280a056 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -545,6 +545,9 @@ func (s *compileState) parseVariable(block *block, fname string, vs *ast.ValueSp s.addError(vs.Pos(), "the numbers of lhs and rhs don't match") } t = ts[0] + if t.Main == shaderir.None { + t = toDefaultType(es[0].Const) + } } if es[0].Type == shaderir.NumberExpr { @@ -567,6 +570,7 @@ func (s *compileState) parseVariable(block *block, fname string, vs *ast.ValueSp default: // Multiple-value context + // See testcase/var_multiple.go for an actual case. if i == 0 { init := vs.Values[0] @@ -593,6 +597,10 @@ func (s *compileState) parseVariable(block *block, fname string, vs *ast.ValueSp if t.Main == shaderir.None && len(inittypes) > 0 { t = inittypes[i] + // TODO: Is it possible to reach this? + if t.Main == shaderir.None { + t = toDefaultType(initexprs[i].Const) + } } if !canAssign(&t, &inittypes[i], initexprs[i].Const) { diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 7e99cbcb7..067fe669e 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -625,7 +625,6 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r cs.addError(pos, "single-value context and multiple-value context cannot be mixed") return nil, false } - t := ts[0] if t.Main == shaderir.None { t = toDefaultType(r[0].Const) diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index f3f4955be..9a1a9295b 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -3009,6 +3009,40 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { } } +// Issue #2704 +func TestSyntaxConstType3(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "const x = 1; const y = 1; _ = x * y", err: false}, + {stmt: "const x = 1; const y int = 1; _ = x * y", err: false}, + {stmt: "const x int = 1; const y = 1; _ = x * y", err: false}, + {stmt: "const x int = 1; const y int = 1; _ = x * y", err: false}, + {stmt: "const x = 1; const y float = 1; _ = x * y", err: false}, + {stmt: "const x float = 1; const y = 1; _ = x * y", err: false}, + {stmt: "const x float = 1; const y float = 1; _ = x * y", err: false}, + {stmt: "const x int = 1; const y float = 1; _ = x * y", err: true}, + {stmt: "const x float = 1; const y int = 1; _ = x * y", err: true}, + } + + for _, c := range cases { + stmt := c.stmt + src := fmt.Sprintf(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + %s + return position +}`, stmt) + _, err := compileToIR([]byte(src)) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", stmt, err) + } + } +} + // Issue #2348 func TestSyntaxCompositeLit(t *testing.T) { cases := []struct { diff --git a/internal/shader/testdata/number_binary.expected.vs b/internal/shader/testdata/number_binary.expected.vs index 58f567ae1..08d3bfc33 100644 --- a/internal/shader/testdata/number_binary.expected.vs +++ b/internal/shader/testdata/number_binary.expected.vs @@ -1,5 +1,6 @@ vec4 F0(void); vec4 F1(void); +vec4 F2(void); vec4 F0(void) { int l0 = 0; @@ -24,3 +25,15 @@ vec4 F1(void) { l3 = 2.5000000000e+00; return vec4(l0, l1, l2, l3); } + +vec4 F2(void) { + int l0 = 0; + float l1 = float(0); + float l2 = float(0); + float l3 = float(0); + l0 = 2; + l1 = 2.5000000000e+00; + l2 = 2.5000000000e+00; + l3 = 2.5000000000e+00; + return vec4(l0, l1, l2, l3); +} diff --git a/internal/shader/testdata/number_binary.go b/internal/shader/testdata/number_binary.go index 7f08fbb11..51612da0b 100644 --- a/internal/shader/testdata/number_binary.go +++ b/internal/shader/testdata/number_binary.go @@ -15,3 +15,8 @@ func Foo2() vec4 { var x3 = 5.0 / 2.0 return vec4(x0, x1, x2, x3) } + +func Foo3() vec4 { + var x0, x1, x2, x3 = 5 / 2, 5.0 / 2, 5 / 2.0, 5.0 / 2.0 + return vec4(x0, x1, x2, x3) +} diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go new file mode 100644 index 000000000..b64d1412a --- /dev/null +++ b/internal/shaderir/check.go @@ -0,0 +1,91 @@ +// Copyright 2023 The Ebitengine Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shaderir + +import ( + "go/constant" +) + +func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool { + if op == AndAnd || op == OrOr { + return lhst.Main == Bool && rhst.Main == Bool + } + + if op == VectorEqualOp || op == VectorNotEqualOp { + return lhst.IsVector() && rhst.IsVector() && lhst.Equal(&rhst) + } + + // Comparing matrices are forbidden (#2187). + if op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp || op == EqualOp || op == NotEqualOp { + if lhst.IsMatrix() || rhst.IsMatrix() { + return false + } + } + + // If both are untyped consts, compare the constants and try to truncate them if necessary. + if lhst.Main == None && rhst.Main == None { + if lhs.Const.Kind() == rhs.Const.Kind() { + return true + } + if lhs.Const.Kind() == constant.Float && constant.ToFloat(rhs.Const).Kind() != constant.Unknown { + return true + } + if rhs.Const.Kind() == constant.Float && constant.ToFloat(lhs.Const).Kind() != constant.Unknown { + return true + } + if lhs.Const.Kind() == constant.Int && constant.ToInt(rhs.Const).Kind() != constant.Unknown { + return true + } + if rhs.Const.Kind() == constant.Int && constant.ToInt(lhs.Const).Kind() != constant.Unknown { + return true + } + return false + } + + // If the types match, that's fine. + if lhst.Equal(&rhst) { + return true + } + + // If lhs is untyped and rhs is not, compare the constant and the type and try to truncate the constant if necessary. + if lhst.Main == None { + if rhst.Main == Float { + return constant.ToFloat(lhs.Const).Kind() != constant.Unknown + } + if rhst.Main == Int { + return constant.ToInt(lhs.Const).Kind() != constant.Unknown + } + if rhst.Main == Bool { + return lhs.Const.Kind() == constant.Bool + } + return false + } + + // Ditto. + if rhst.Main == None { + if lhst.Main == Float { + return constant.ToFloat(rhs.Const).Kind() != constant.Unknown + } + if lhst.Main == Int { + return constant.ToInt(rhs.Const).Kind() != constant.Unknown + } + if lhst.Main == Bool { + return rhs.Const.Kind() == constant.Bool + } + return false + } + + return false +} diff --git a/internal/shaderir/program.go b/internal/shaderir/program.go index 6c2486871..39ad677da 100644 --- a/internal/shaderir/program.go +++ b/internal/shaderir/program.go @@ -101,6 +101,7 @@ const ( Discard ) +// TODO: Remove ConstType (#2550) type ConstType int const ( @@ -114,11 +115,13 @@ type Expr struct { Type ExprType Exprs []Expr Const constant.Value - ConstType ConstType BuiltinFunc BuiltinFunc Swizzling string Index int Op Op + + // TODO: Remove ConstType (#2550) + ConstType ConstType } type ExprType int @@ -183,6 +186,10 @@ func OpFromToken(t token.Token, lhs, rhs Type) (Op, bool) { return ComponentWiseMul, true case token.QUO: return Div, true + case token.QUO_ASSIGN: + // QUO_ASSIGN indicates an integer division. + // https://pkg.go.dev/go/constant/#BinaryOp + return Div, true case token.REM: return ModOp, true case token.SHL: