diff --git a/internal/shader/expr.go b/internal/shader/expr.go index 8b9389201..6d148f53b 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -393,8 +393,9 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable if c, ok := block.findConstant(e.Name); ok { return []shaderir.Expr{ { - Type: shaderir.NumberExpr, - Const: c.value, + Type: shaderir.NumberExpr, + Const: c.value, + ConstType: c.ctyp, }, }, []shaderir.Type{c.typ}, nil, true } diff --git a/internal/shader/shader.go b/internal/shader/shader.go index c9b723511..cf698f5da 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -33,6 +33,7 @@ type variable struct { type constant struct { name string typ shaderir.Type + ctyp shaderir.ConstType value gconstant.Value } @@ -581,7 +582,7 @@ func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) ([]constan } } - es, _, ss, ok := s.parseExpr(block, vs.Values[i], false) + es, ts, ss, ok := s.parseExpr(block, vs.Values[i], false) if !ok { return nil, false } @@ -589,7 +590,7 @@ func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) ([]constan s.addError(vs.Pos(), fmt.Sprintf("invalid constant expression: %s", name)) return nil, false } - if len(es) != 1 { + if len(ts) != 1 || len(es) != 1 { s.addError(vs.Pos(), fmt.Sprintf("invalid constant expression: %s", n)) return nil, false } @@ -600,6 +601,7 @@ func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) ([]constan cs = append(cs, constant{ name: name, typ: t, + ctyp: es[0].ConstType, value: es[0].Const, }) } diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index b52947ab9..e653105b7 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -520,7 +520,11 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r return nil, false } - block.addNamedLocalVariable(name, ts[0], e.Pos()) + t := ts[0] + if t.Main == shaderir.None { + t = toDefaultType(r[0].Const) + } + block.addNamedLocalVariable(name, t, e.Pos()) } if len(r) > 1 { @@ -582,8 +586,12 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r }) } else { // For variable swapping, use temporary variables. + t := origts[0] + if t.Main == shaderir.None { + t = toDefaultType(r[0].Const) + } block.vars = append(block.vars, variable{ - typ: origts[0], + typ: t, }) idx := len(block.vars) - 1 stmts = append(stmts, @@ -632,7 +640,13 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r } } } - block.addNamedLocalVariable(name, rhsTypes[i], e.Pos()) + t := rhsTypes[i] + if t.Main == shaderir.None { + // TODO: This is to determine a type when the rhs is a constant, + // but there are no actual cases when len(lhs) != len(rhs). Is this correct? + t = toDefaultType(rhsExprs[i].Const) + } + block.addNamedLocalVariable(name, t, e.Pos()) } l, _, ss, ok := cs.parseExpr(block, lhs[i], false) @@ -660,3 +674,16 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r return stmts, true } + +func toDefaultType(v gconstant.Value) shaderir.Type { + switch v.Kind() { + case gconstant.Bool: + return shaderir.Type{Main: shaderir.Bool} + case gconstant.Int: + return shaderir.Type{Main: shaderir.Int} + case gconstant.Float: + return shaderir.Type{Main: shaderir.Float} + } + // TODO: Should this be an error? + return shaderir.Type{} +} diff --git a/internal/shader/testdata/const3.expected.vs b/internal/shader/testdata/const3.expected.vs new file mode 100644 index 000000000..f94bb34d8 --- /dev/null +++ b/internal/shader/testdata/const3.expected.vs @@ -0,0 +1,22 @@ +void F0(void); + +void F0(void) { + int l0 = 0; + int l1 = 0; + float l2 = float(0); + int l3 = 0; + int l4 = 0; + int l5 = 0; + float l6 = float(0); + float l7 = float(0); + bool l8 = false; + l0 = 1; + l1 = 1; + l2 = 1.0; + l3 = 1; + l5 = 1; + l4 = l5; + l7 = 1.0; + l6 = l7; + l8 = false; +} diff --git a/internal/shader/testdata/const3.go b/internal/shader/testdata/const3.go new file mode 100644 index 000000000..c3370e827 --- /dev/null +++ b/internal/shader/testdata/const3.go @@ -0,0 +1,23 @@ +package main + +const a = 1 +const b float = 1 +const c int = 1 +const d, e = 1, 1.0 +const f = false + +func Foo() { + l0 := 1 + la := a + lb := b + lc := c + ld, le := d, e + lf := f + _ = l0 + _ = la + _ = lb + _ = lc + _ = ld + _ = le + _ = lf +}