diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 067fe669e..c6406ac1f 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -500,12 +500,22 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP } if expr.Const != nil { switch outT.Main { + case shaderir.Bool: + if expr.Const.Kind() != gconstant.Bool { + cs.addError(stmt.Pos(), fmt.Sprintf("cannot use type %s as type %s in return argument", t.String(), &outT)) + return nil, false + } + t = shaderir.Type{Main: shaderir.Bool} case shaderir.Int: if !cs.forceToInt(stmt, &expr) { return nil, false } t = shaderir.Type{Main: shaderir.Int} case shaderir.Float: + if gconstant.ToFloat(expr.Const).Kind() == gconstant.Unknown { + cs.addError(stmt.Pos(), fmt.Sprintf("cannot use type %s as type %s in return argument", t.String(), &outT)) + return nil, false + } t = shaderir.Type{Main: shaderir.Float} } } diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 2aa23c310..029338bf3 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -3249,3 +3249,45 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { t.Error(err) } } + +// Issue #2706 +func TestSyntaxReturnConst(t *testing.T) { + cases := []struct { + typ string + stmt string + err bool + }{ + {typ: "bool", stmt: "true", err: false}, + {typ: "int", stmt: "true", err: true}, + {typ: "float", stmt: "true", err: true}, + {typ: "bool", stmt: "1", err: true}, + {typ: "int", stmt: "1", err: false}, + {typ: "float", stmt: "1", err: false}, + {typ: "bool", stmt: "1.0", err: true}, + {typ: "int", stmt: "1.0", err: false}, + {typ: "float", stmt: "1.0", err: false}, + {typ: "bool", stmt: "1.1", err: true}, + {typ: "int", stmt: "1.1", err: true}, + {typ: "float", stmt: "1.1", err: false}, + } + + for _, c := range cases { + typ := c.typ + stmt := c.stmt + src := fmt.Sprintf(`package main + +func Foo() %s { + return %s +} + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + return position +}`, typ, stmt) + _, err := compileToIR([]byte(src)) + if err == nil && c.err { + t.Errorf("return %s for type %s must return an error but does not", stmt, typ) + } else if err != nil && !c.err { + t.Errorf("return %s for type %s must not return nil but returned %v", stmt, typ, err) + } + } +}