internal/shader: bug fix: don't allow a binary op with different typed constants

Closes #2704
This commit is contained in:
Hajime Hoshi 2023-07-23 23:40:22 +09:00
parent 8276a53dd6
commit a8c3eb7167
8 changed files with 180 additions and 54 deletions

View File

@ -78,54 +78,6 @@ func isValidForModOp(lhs, rhs *shaderir.Expr, lhst, rhst shaderir.Type) bool {
return false return false
} }
func canApplyBinaryOp(lhs, rhs *shaderir.Expr, lhst, rhst shaderir.Type, op shaderir.Op) bool {
if op == shaderir.AndAnd || op == shaderir.OrOr {
return lhst.Main == shaderir.Bool && rhst.Main == shaderir.Bool
}
switch {
case lhs.Const != nil && rhs.Const != nil:
if canTruncateToFloat(lhs.Const) && canTruncateToFloat(rhs.Const) {
return true
}
if canTruncateToInteger(lhs.Const) && canTruncateToInteger(rhs.Const) {
return true
}
return lhs.Const.Kind() == rhs.Const.Kind()
case lhs.Const != nil:
if rhst.Main == shaderir.Float {
return canTruncateToFloat(lhs.Const)
}
if rhst.Main == shaderir.Int {
return canTruncateToInteger(lhs.Const)
}
if rhst.Main == shaderir.Bool {
return lhs.Const.Kind() == gconstant.Bool
}
return false
case rhs.Const != nil:
if lhst.Main == shaderir.Float {
return canTruncateToFloat(rhs.Const)
}
if lhst.Main == shaderir.Int {
return canTruncateToInteger(rhs.Const)
}
if lhst.Main == shaderir.Bool {
return rhs.Const.Kind() == gconstant.Bool
}
return false
}
// Comparing matrices are forbidden (#2187).
if lhst.IsMatrix() || rhst.IsMatrix() {
return false
}
return lhst.Equal(&rhst)
}
func goConstantKindString(k gconstant.Kind) string { func goConstantKindString(k gconstant.Kind) string {
switch k { switch k {
case gconstant.Bool: case gconstant.Bool:
@ -211,7 +163,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", op)) cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", op))
return nil, nil, nil, false return nil, nil, nil, false
} }
if !canApplyBinaryOp(&lhs[0], &rhs[0], lhst, rhst, op2) { if !shaderir.AreValidTypesForBinaryOp(op2, &lhs[0], &rhs[0], lhst, rhst) {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String()))
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -242,11 +194,28 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
fallthrough fallthrough
default: default:
op2, ok := shaderir.OpFromToken(op, lhst, rhst)
if !ok {
cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", op))
return nil, nil, nil, false
}
if !shaderir.AreValidTypesForBinaryOp(op2, &lhs[0], &rhs[0], lhst, rhst) {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String()))
return nil, nil, nil, false
}
v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const) v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const)
if v.Kind() == gconstant.Float {
switch {
case lhst.Main == shaderir.Float || rhst.Main == shaderir.Float:
t = shaderir.Type{Main: shaderir.Float} t = shaderir.Type{Main: shaderir.Float}
} else { case lhst.Main == shaderir.Int || rhst.Main == shaderir.Int:
t = shaderir.Type{Main: shaderir.Int} t = shaderir.Type{Main: shaderir.Int}
case lhst.Main == shaderir.Bool || rhst.Main == shaderir.Bool:
t = shaderir.Type{Main: shaderir.Bool}
default:
// If both operands are untyped, keep untyped.
t = shaderir.Type{}
} }
} }
@ -267,7 +236,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
var t shaderir.Type var t shaderir.Type
switch { switch {
case op == shaderir.LessThanOp || op == shaderir.LessThanEqualOp || op == shaderir.GreaterThanOp || op == shaderir.GreaterThanEqualOp || op == shaderir.EqualOp || op == shaderir.NotEqualOp || op == shaderir.VectorEqualOp || op == shaderir.VectorNotEqualOp || op == shaderir.AndAnd || op == shaderir.OrOr: case op == shaderir.LessThanOp || op == shaderir.LessThanEqualOp || op == shaderir.GreaterThanOp || op == shaderir.GreaterThanEqualOp || op == shaderir.EqualOp || op == shaderir.NotEqualOp || op == shaderir.VectorEqualOp || op == shaderir.VectorNotEqualOp || op == shaderir.AndAnd || op == shaderir.OrOr:
if !canApplyBinaryOp(&lhs[0], &rhs[0], lhst, rhst, op) { if !shaderir.AreValidTypesForBinaryOp(op, &lhs[0], &rhs[0], lhst, rhst) {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false return nil, nil, nil, false
} }

View File

@ -545,6 +545,9 @@ func (s *compileState) parseVariable(block *block, fname string, vs *ast.ValueSp
s.addError(vs.Pos(), "the numbers of lhs and rhs don't match") s.addError(vs.Pos(), "the numbers of lhs and rhs don't match")
} }
t = ts[0] t = ts[0]
if t.Main == shaderir.None {
t = toDefaultType(es[0].Const)
}
} }
if es[0].Type == shaderir.NumberExpr { if es[0].Type == shaderir.NumberExpr {
@ -567,6 +570,7 @@ func (s *compileState) parseVariable(block *block, fname string, vs *ast.ValueSp
default: default:
// Multiple-value context // Multiple-value context
// See testcase/var_multiple.go for an actual case.
if i == 0 { if i == 0 {
init := vs.Values[0] init := vs.Values[0]
@ -593,6 +597,10 @@ func (s *compileState) parseVariable(block *block, fname string, vs *ast.ValueSp
if t.Main == shaderir.None && len(inittypes) > 0 { if t.Main == shaderir.None && len(inittypes) > 0 {
t = inittypes[i] t = inittypes[i]
// TODO: Is it possible to reach this?
if t.Main == shaderir.None {
t = toDefaultType(initexprs[i].Const)
}
} }
if !canAssign(&t, &inittypes[i], initexprs[i].Const) { if !canAssign(&t, &inittypes[i], initexprs[i].Const) {

View File

@ -625,7 +625,6 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
cs.addError(pos, "single-value context and multiple-value context cannot be mixed") cs.addError(pos, "single-value context and multiple-value context cannot be mixed")
return nil, false return nil, false
} }
t := ts[0] t := ts[0]
if t.Main == shaderir.None { if t.Main == shaderir.None {
t = toDefaultType(r[0].Const) t = toDefaultType(r[0].Const)

View File

@ -3009,6 +3009,40 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
} }
} }
// Issue #2704
func TestSyntaxConstType3(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "const x = 1; const y = 1; _ = x * y", err: false},
{stmt: "const x = 1; const y int = 1; _ = x * y", err: false},
{stmt: "const x int = 1; const y = 1; _ = x * y", err: false},
{stmt: "const x int = 1; const y int = 1; _ = x * y", err: false},
{stmt: "const x = 1; const y float = 1; _ = x * y", err: false},
{stmt: "const x float = 1; const y = 1; _ = x * y", err: false},
{stmt: "const x float = 1; const y float = 1; _ = x * y", err: false},
{stmt: "const x int = 1; const y float = 1; _ = x * y", err: true},
{stmt: "const x float = 1; const y int = 1; _ = x * y", err: true},
}
for _, c := range cases {
stmt := c.stmt
src := fmt.Sprintf(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
%s
return position
}`, stmt)
_, err := compileToIR([]byte(src))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", stmt, err)
}
}
}
// Issue #2348 // Issue #2348
func TestSyntaxCompositeLit(t *testing.T) { func TestSyntaxCompositeLit(t *testing.T) {
cases := []struct { cases := []struct {

View File

@ -1,5 +1,6 @@
vec4 F0(void); vec4 F0(void);
vec4 F1(void); vec4 F1(void);
vec4 F2(void);
vec4 F0(void) { vec4 F0(void) {
int l0 = 0; int l0 = 0;
@ -24,3 +25,15 @@ vec4 F1(void) {
l3 = 2.5000000000e+00; l3 = 2.5000000000e+00;
return vec4(l0, l1, l2, l3); return vec4(l0, l1, l2, l3);
} }
vec4 F2(void) {
int l0 = 0;
float l1 = float(0);
float l2 = float(0);
float l3 = float(0);
l0 = 2;
l1 = 2.5000000000e+00;
l2 = 2.5000000000e+00;
l3 = 2.5000000000e+00;
return vec4(l0, l1, l2, l3);
}

View File

@ -15,3 +15,8 @@ func Foo2() vec4 {
var x3 = 5.0 / 2.0 var x3 = 5.0 / 2.0
return vec4(x0, x1, x2, x3) return vec4(x0, x1, x2, x3)
} }
func Foo3() vec4 {
var x0, x1, x2, x3 = 5 / 2, 5.0 / 2, 5 / 2.0, 5.0 / 2.0
return vec4(x0, x1, x2, x3)
}

View File

@ -0,0 +1,91 @@
// Copyright 2023 The Ebitengine Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shaderir
import (
"go/constant"
)
func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool {
if op == AndAnd || op == OrOr {
return lhst.Main == Bool && rhst.Main == Bool
}
if op == VectorEqualOp || op == VectorNotEqualOp {
return lhst.IsVector() && rhst.IsVector() && lhst.Equal(&rhst)
}
// Comparing matrices are forbidden (#2187).
if op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp || op == EqualOp || op == NotEqualOp {
if lhst.IsMatrix() || rhst.IsMatrix() {
return false
}
}
// If both are untyped consts, compare the constants and try to truncate them if necessary.
if lhst.Main == None && rhst.Main == None {
if lhs.Const.Kind() == rhs.Const.Kind() {
return true
}
if lhs.Const.Kind() == constant.Float && constant.ToFloat(rhs.Const).Kind() != constant.Unknown {
return true
}
if rhs.Const.Kind() == constant.Float && constant.ToFloat(lhs.Const).Kind() != constant.Unknown {
return true
}
if lhs.Const.Kind() == constant.Int && constant.ToInt(rhs.Const).Kind() != constant.Unknown {
return true
}
if rhs.Const.Kind() == constant.Int && constant.ToInt(lhs.Const).Kind() != constant.Unknown {
return true
}
return false
}
// If the types match, that's fine.
if lhst.Equal(&rhst) {
return true
}
// If lhs is untyped and rhs is not, compare the constant and the type and try to truncate the constant if necessary.
if lhst.Main == None {
if rhst.Main == Float {
return constant.ToFloat(lhs.Const).Kind() != constant.Unknown
}
if rhst.Main == Int {
return constant.ToInt(lhs.Const).Kind() != constant.Unknown
}
if rhst.Main == Bool {
return lhs.Const.Kind() == constant.Bool
}
return false
}
// Ditto.
if rhst.Main == None {
if lhst.Main == Float {
return constant.ToFloat(rhs.Const).Kind() != constant.Unknown
}
if lhst.Main == Int {
return constant.ToInt(rhs.Const).Kind() != constant.Unknown
}
if lhst.Main == Bool {
return rhs.Const.Kind() == constant.Bool
}
return false
}
return false
}

View File

@ -101,6 +101,7 @@ const (
Discard Discard
) )
// TODO: Remove ConstType (#2550)
type ConstType int type ConstType int
const ( const (
@ -114,11 +115,13 @@ type Expr struct {
Type ExprType Type ExprType
Exprs []Expr Exprs []Expr
Const constant.Value Const constant.Value
ConstType ConstType
BuiltinFunc BuiltinFunc BuiltinFunc BuiltinFunc
Swizzling string Swizzling string
Index int Index int
Op Op Op Op
// TODO: Remove ConstType (#2550)
ConstType ConstType
} }
type ExprType int type ExprType int
@ -183,6 +186,10 @@ func OpFromToken(t token.Token, lhs, rhs Type) (Op, bool) {
return ComponentWiseMul, true return ComponentWiseMul, true
case token.QUO: case token.QUO:
return Div, true return Div, true
case token.QUO_ASSIGN:
// QUO_ASSIGN indicates an integer division.
// https://pkg.go.dev/go/constant/#BinaryOp
return Div, true
case token.REM: case token.REM:
return ModOp, true return ModOp, true
case token.SHL: case token.SHL: