internal/shader: bug fix: wrongly typed constants were unexpectedly used

Closes #2549
This commit is contained in:
Hajime Hoshi 2023-01-20 02:26:26 +09:00
parent f054a7634a
commit 06bc569b73
2 changed files with 35 additions and 1 deletions

View File

@ -623,21 +623,24 @@ func (s *compileState) parseConstant(block *block, fname string, vs *ast.ValueSp
return nil, false return nil, false
} }
c := es[0].Const
constType := es[0].ConstType constType := es[0].ConstType
switch t.Main { switch t.Main {
case shaderir.Bool: case shaderir.Bool:
constType = shaderir.ConstTypeBool constType = shaderir.ConstTypeBool
case shaderir.Int: case shaderir.Int:
constType = shaderir.ConstTypeInt constType = shaderir.ConstTypeInt
c = gconstant.ToInt(c)
case shaderir.Float: case shaderir.Float:
constType = shaderir.ConstTypeFloat constType = shaderir.ConstTypeFloat
c = gconstant.ToFloat(c)
} }
cs = append(cs, constant{ cs = append(cs, constant{
name: name, name: name,
typ: t, typ: t,
ctyp: constType, ctyp: constType,
value: es[0].Const, value: c,
}) })
} }
return cs, true return cs, true

View File

@ -2968,3 +2968,34 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
} }
} }
} }
// Issue #2549
func TestConstType2(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "const x = 1; y := x*x; _ = vec4(1) / y", err: true},
{stmt: "const x = 1.0; y := x*x; _ = vec4(1) / y", err: false},
{stmt: "const x int = 1; y := x*x; _ = vec4(1) / y", err: true},
{stmt: "const x int = 1.0; y := x*x; _ = vec4(1) / y", err: true},
{stmt: "const x float = 1; y := x*x; _ = vec4(1) / y", err: false},
{stmt: "const x float = 1.0; y := x*x; _ = vec4(1) / y", err: false},
}
for _, c := range cases {
stmt := c.stmt
src := fmt.Sprintf(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
%s
return position
}`, stmt)
_, err := compileToIR([]byte(src))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", stmt, err)
}
}
}