internal/shader: bug fix: return true for float must fail

Closes #2706
This commit is contained in:
Hajime Hoshi 2023-07-25 02:41:41 +09:00
parent 29545906c0
commit b743b7ab50
2 changed files with 52 additions and 0 deletions

View File

@ -500,12 +500,22 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
}
if expr.Const != nil {
switch outT.Main {
case shaderir.Bool:
if expr.Const.Kind() != gconstant.Bool {
cs.addError(stmt.Pos(), fmt.Sprintf("cannot use type %s as type %s in return argument", t.String(), &outT))
return nil, false
}
t = shaderir.Type{Main: shaderir.Bool}
case shaderir.Int:
if !cs.forceToInt(stmt, &expr) {
return nil, false
}
t = shaderir.Type{Main: shaderir.Int}
case shaderir.Float:
if gconstant.ToFloat(expr.Const).Kind() == gconstant.Unknown {
cs.addError(stmt.Pos(), fmt.Sprintf("cannot use type %s as type %s in return argument", t.String(), &outT))
return nil, false
}
t = shaderir.Type{Main: shaderir.Float}
}
}

View File

@ -3249,3 +3249,45 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
t.Error(err)
}
}
// Issue #2706
func TestSyntaxReturnConst(t *testing.T) {
cases := []struct {
typ string
stmt string
err bool
}{
{typ: "bool", stmt: "true", err: false},
{typ: "int", stmt: "true", err: true},
{typ: "float", stmt: "true", err: true},
{typ: "bool", stmt: "1", err: true},
{typ: "int", stmt: "1", err: false},
{typ: "float", stmt: "1", err: false},
{typ: "bool", stmt: "1.0", err: true},
{typ: "int", stmt: "1.0", err: false},
{typ: "float", stmt: "1.0", err: false},
{typ: "bool", stmt: "1.1", err: true},
{typ: "int", stmt: "1.1", err: true},
{typ: "float", stmt: "1.1", err: false},
}
for _, c := range cases {
typ := c.typ
stmt := c.stmt
src := fmt.Sprintf(`package main
func Foo() %s {
return %s
}
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
return position
}`, typ, stmt)
_, err := compileToIR([]byte(src))
if err == nil && c.err {
t.Errorf("return %s for type %s must return an error but does not", stmt, typ)
} else if err != nil && !c.err {
t.Errorf("return %s for type %s must not return nil but returned %v", stmt, typ, err)
}
}
}