From 06bc569b738fcbe20f20dd02e2aa8c2f94743ec7 Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Fri, 20 Jan 2023 02:26:26 +0900 Subject: [PATCH] internal/shader: bug fix: wrongly typed constants were unexpectedly used Closes #2549 --- internal/shader/shader.go | 5 ++++- internal/shader/syntax_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/internal/shader/shader.go b/internal/shader/shader.go index b1b048334..1353c63ce 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -623,21 +623,24 @@ func (s *compileState) parseConstant(block *block, fname string, vs *ast.ValueSp return nil, false } + c := es[0].Const constType := es[0].ConstType switch t.Main { case shaderir.Bool: constType = shaderir.ConstTypeBool case shaderir.Int: constType = shaderir.ConstTypeInt + c = gconstant.ToInt(c) case shaderir.Float: constType = shaderir.ConstTypeFloat + c = gconstant.ToFloat(c) } cs = append(cs, constant{ name: name, typ: t, ctyp: constType, - value: es[0].Const, + value: c, }) } return cs, true diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 04efc5340..eab4aec62 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -2968,3 +2968,34 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { } } } + +// Issue #2549 +func TestConstType2(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "const x = 1; y := x*x; _ = vec4(1) / y", err: true}, + {stmt: "const x = 1.0; y := x*x; _ = vec4(1) / y", err: false}, + {stmt: "const x int = 1; y := x*x; _ = vec4(1) / y", err: true}, + {stmt: "const x int = 1.0; y := x*x; _ = vec4(1) / y", err: true}, + {stmt: "const x float = 1; y := x*x; _ = vec4(1) / y", err: false}, + {stmt: "const x float = 1.0; y := x*x; _ = vec4(1) / y", err: false}, + } + + for _, c := range cases { + stmt := c.stmt + src := fmt.Sprintf(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + %s + return position +}`, stmt) + _, err := compileToIR([]byte(src)) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", stmt, err) + } + } +}