internal/shader: bug fix: check the type for composite literal

Closes #2348
This commit is contained in:
Hajime Hoshi 2023-02-20 23:07:08 +09:00
parent 57025641a0
commit ad7d5a86f9
2 changed files with 36 additions and 4 deletions

View File

@ -1048,6 +1048,10 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
if !ok {
return nil, nil, nil, false
}
if t.Main != shaderir.Array {
cs.addError(e.Pos(), fmt.Sprintf("invalid composite literal type %s", t.String()))
return nil, nil, nil, false
}
if t.Main == shaderir.Array && t.Length == -1 {
t.Length = len(e.Elts)
}

View File

@ -2724,7 +2724,7 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
}
}
func TestTypeRedeclaration(t *testing.T) {
func TestSyntaxTypeRedeclaration(t *testing.T) {
cases := []struct {
stmt string
err bool
@ -2752,7 +2752,7 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
}
}
func TestSwizzling(t *testing.T) {
func TestSyntaxSwizzling(t *testing.T) {
cases := []struct {
stmt string
err bool
@ -2841,7 +2841,7 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
}
}
func TestConstType(t *testing.T) {
func TestSyntaxConstType(t *testing.T) {
cases := []struct {
stmt string
err bool
@ -2970,7 +2970,7 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
}
// Issue #2549
func TestConstType2(t *testing.T) {
func TestSyntaxConstType2(t *testing.T) {
cases := []struct {
stmt string
err bool
@ -2999,3 +2999,31 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
}
}
}
// Issue #2348
func TestSyntaxCompositeLit(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "_ = undefined{1, 2, 3, 4}", err: true},
{stmt: "_ = int{1, 2, 3, 4}", err: true},
{stmt: "_ = vec4{1, 2, 3, 4}", err: true},
}
for _, c := range cases {
stmt := c.stmt
src := fmt.Sprintf(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
%s
return position
}`, stmt)
_, err := compileToIR([]byte(src))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", stmt, err)
}
}
}