mirror of
https://github.com/hajimehoshi/ebiten.git
synced 2025-01-11 19:48:54 +01:00
internal/shader: stricter const type check
This commit is contained in:
parent
685a6acb05
commit
5d8216def3
@ -731,7 +731,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
|
||||
}
|
||||
}
|
||||
|
||||
if !canAssign(&args[i], &p, &argts[i]) {
|
||||
if !canAssign(&p, &argts[i], args[i].Const) {
|
||||
cs.addError(e.Pos(), fmt.Sprintf("cannot use type %s as type %s in argument", argts[i].String(), p.String()))
|
||||
return nil, nil, nil, false
|
||||
}
|
||||
|
@ -507,7 +507,7 @@ func (s *compileState) parseVariable(block *block, fname string, vs *ast.ValueSp
|
||||
}
|
||||
|
||||
for i, rt := range rts {
|
||||
if !canAssign(&es[i], &t, &rt) {
|
||||
if !canAssign(&t, &rt, es[i].Const) {
|
||||
s.addError(vs.Pos(), fmt.Sprintf("cannot use type %s as type %s in variable declaration", rt.String(), t.String()))
|
||||
}
|
||||
}
|
||||
@ -545,7 +545,7 @@ func (s *compileState) parseVariable(block *block, fname string, vs *ast.ValueSp
|
||||
t = inittypes[i]
|
||||
}
|
||||
|
||||
if !canAssign(&initexprs[i], &t, &inittypes[i]) {
|
||||
if !canAssign(&t, &inittypes[i], initexprs[i].Const) {
|
||||
s.addError(vs.Pos(), fmt.Sprintf("cannot use type %s as type %s in variable declaration", inittypes[i].String(), t.String()))
|
||||
}
|
||||
|
||||
@ -617,10 +617,26 @@ func (s *compileState) parseConstant(block *block, fname string, vs *ast.ValueSp
|
||||
s.addError(vs.Pos(), fmt.Sprintf("constant expression must be a number but not: %s", n))
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if !t.Equal(&shaderir.Type{}) && !canAssign(&t, &ts[0], es[0].Const) {
|
||||
s.addError(vs.Pos(), fmt.Sprintf("cannot use %v as %s value in constant declaration", es[0].Const, t.String()))
|
||||
return nil, false
|
||||
}
|
||||
|
||||
constType := es[0].ConstType
|
||||
switch t.Main {
|
||||
case shaderir.Bool:
|
||||
constType = shaderir.ConstTypeBool
|
||||
case shaderir.Int:
|
||||
constType = shaderir.ConstTypeInt
|
||||
case shaderir.Float:
|
||||
constType = shaderir.ConstTypeFloat
|
||||
}
|
||||
|
||||
cs = append(cs, constant{
|
||||
name: name,
|
||||
typ: t,
|
||||
ctyp: es[0].ConstType,
|
||||
ctyp: constType,
|
||||
value: es[0].Const,
|
||||
})
|
||||
}
|
||||
|
@ -667,7 +667,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
|
||||
}
|
||||
|
||||
for i := range lts {
|
||||
if !canAssign(&r[i], <s[i], &rts[i]) {
|
||||
if !canAssign(<s[i], &rts[i], r[i].Const) {
|
||||
cs.addError(pos, fmt.Sprintf("cannot use type %s as type %s in variable declaration", rts[i].String(), lts[i].String()))
|
||||
return nil, false
|
||||
}
|
||||
@ -755,7 +755,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
|
||||
allblank = false
|
||||
|
||||
for i, lt := range lts {
|
||||
if !canAssign(&rhsExprs[i], <, &rhsTypes[i]) {
|
||||
if !canAssign(<, &rhsTypes[i], rhsExprs[i].Const) {
|
||||
cs.addError(pos, fmt.Sprintf("cannot use type %s as type %s in variable declaration", rhsTypes[i].String(), lt.String()))
|
||||
return nil, false
|
||||
}
|
||||
@ -789,19 +789,23 @@ func toDefaultType(v gconstant.Value) shaderir.Type {
|
||||
return shaderir.Type{}
|
||||
}
|
||||
|
||||
func canAssign(re *shaderir.Expr, lt *shaderir.Type, rt *shaderir.Type) bool {
|
||||
func canAssign(lt *shaderir.Type, rt *shaderir.Type, rc gconstant.Value) bool {
|
||||
if lt.Equal(rt) {
|
||||
return true
|
||||
}
|
||||
if re.Const != nil {
|
||||
switch lt.Main {
|
||||
case shaderir.Bool:
|
||||
return re.Const.Kind() == gconstant.Bool
|
||||
case shaderir.Int:
|
||||
return canTruncateToInteger(re.Const)
|
||||
case shaderir.Float:
|
||||
return gconstant.ToFloat(re.Const).Kind() != gconstant.Unknown
|
||||
}
|
||||
}
|
||||
|
||||
if rc == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch lt.Main {
|
||||
case shaderir.Bool:
|
||||
return rc.Kind() == gconstant.Bool
|
||||
case shaderir.Int:
|
||||
return canTruncateToInteger(rc)
|
||||
case shaderir.Float:
|
||||
return gconstant.ToFloat(rc).Kind() != gconstant.Unknown
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@ -1127,6 +1127,10 @@ func TestSyntaxOperatorMultiplyAssign(t *testing.T) {
|
||||
{stmt: "a := 1.0; a *= 2.0", err: false},
|
||||
{stmt: "const c = 2; a := 1.0; a *= c", err: false},
|
||||
{stmt: "const c = 2.0; a := 1.0; a *= c", err: false},
|
||||
{stmt: "const c int = 2; a := 1.0; a *= c", err: true},
|
||||
{stmt: "const c int = 2.0; a := 1.0; a *= c", err: true},
|
||||
{stmt: "const c float = 2; a := 1.0; a *= c", err: false},
|
||||
{stmt: "const c float = 2.0; a := 1.0; a *= c", err: false},
|
||||
{stmt: "a := 1.0; a *= int(2)", err: true},
|
||||
{stmt: "a := 1.0; a *= vec2(2)", err: true},
|
||||
{stmt: "a := 1.0; a *= vec3(2)", err: true},
|
||||
@ -1134,10 +1138,15 @@ func TestSyntaxOperatorMultiplyAssign(t *testing.T) {
|
||||
{stmt: "a := 1.0; a *= mat2(2)", err: true},
|
||||
{stmt: "a := 1.0; a *= mat3(2)", err: true},
|
||||
{stmt: "a := 1.0; a *= mat4(2)", err: true},
|
||||
|
||||
{stmt: "a := vec2(1); a *= 2", err: false},
|
||||
{stmt: "a := vec2(1); a *= 2.0", err: false},
|
||||
{stmt: "const c = 2; a := vec2(1); a *= c", err: false},
|
||||
{stmt: "const c = 2.0; a := vec2(1); a *= c", err: false},
|
||||
{stmt: "const c int = 2; a := vec2(1); a *= c", err: true},
|
||||
{stmt: "const c int = 2.0; a := vec2(1); a *= c", err: true},
|
||||
{stmt: "const c float = 2; a := vec2(1); a *= c", err: false},
|
||||
{stmt: "const c float = 2.0; a := vec2(1); a *= c", err: false},
|
||||
{stmt: "a := vec2(1); a /= 2.0", err: false},
|
||||
{stmt: "a := vec2(1); a += 2.0", err: false},
|
||||
{stmt: "a := vec2(1); a *= int(2)", err: true},
|
||||
@ -1152,10 +1161,15 @@ func TestSyntaxOperatorMultiplyAssign(t *testing.T) {
|
||||
{stmt: "a := vec2(1); a /= mat2(2)", err: true},
|
||||
{stmt: "a := vec2(1); a *= mat3(2)", err: true},
|
||||
{stmt: "a := vec2(1); a *= mat4(2)", err: true},
|
||||
|
||||
{stmt: "a := mat2(1); a *= 2", err: false},
|
||||
{stmt: "a := mat2(1); a *= 2.0", err: false},
|
||||
{stmt: "const c = 2; a := mat2(1); a *= c", err: false},
|
||||
{stmt: "const c = 2.0; a := mat2(1); a *= c", err: false},
|
||||
{stmt: "const c int = 2; a := mat2(1); a *= c", err: true},
|
||||
{stmt: "const c int = 2.0; a := mat2(1); a *= c", err: true},
|
||||
{stmt: "const c float = 2; a := mat2(1); a *= c", err: false},
|
||||
{stmt: "const c float = 2.0; a := mat2(1); a *= c", err: false},
|
||||
{stmt: "a := mat2(1); a /= 2.0", err: false},
|
||||
{stmt: "a := mat2(1); a += 2.0", err: true},
|
||||
{stmt: "a := mat2(1); a *= int(2)", err: true},
|
||||
@ -2465,3 +2479,56 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstType(t *testing.T) {
|
||||
cases := []struct {
|
||||
stmt string
|
||||
err bool
|
||||
}{
|
||||
{stmt: "const a = false", err: false},
|
||||
{stmt: "const a bool = false", err: false},
|
||||
{stmt: "const a int = false", err: true},
|
||||
{stmt: "const a float = false", err: true},
|
||||
{stmt: "const a vec2 = false", err: true},
|
||||
|
||||
{stmt: "const a = 1", err: false},
|
||||
{stmt: "const a bool = 1", err: true},
|
||||
{stmt: "const a int = 1", err: false},
|
||||
{stmt: "const a float = 1", err: false},
|
||||
{stmt: "const a vec2 = 1", err: true},
|
||||
|
||||
{stmt: "const a = 1.0", err: false},
|
||||
{stmt: "const a bool = 1.0", err: true},
|
||||
{stmt: "const a int = 1.0", err: false},
|
||||
{stmt: "const a float = 1.0", err: false},
|
||||
{stmt: "const a vec2 = 1.0", err: true},
|
||||
|
||||
{stmt: "const a = 1.1", err: false},
|
||||
{stmt: "const a bool = 1.1", err: true},
|
||||
{stmt: "const a int = 1.1", err: true},
|
||||
{stmt: "const a float = 1.1", err: false},
|
||||
{stmt: "const a vec2 = 1.1", err: true},
|
||||
|
||||
{stmt: "const a = vec2(0)", err: true},
|
||||
{stmt: "const a bool = vec2(0)", err: true},
|
||||
{stmt: "const a int = vec2(0)", err: true},
|
||||
{stmt: "const a float = vec2(0)", err: true},
|
||||
{stmt: "const a vec2 = vec2(0)", err: true},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
stmt := c.stmt
|
||||
src := fmt.Sprintf(`package main
|
||||
|
||||
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
|
||||
%s
|
||||
return position
|
||||
}`, stmt)
|
||||
_, err := compileToIR([]byte(src))
|
||||
if err == nil && c.err {
|
||||
t.Errorf("%s must return an error but does not", stmt)
|
||||
} else if err != nil && !c.err {
|
||||
t.Errorf("%s must not return nil but returned %v", stmt, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user