internal/shader: Bug fix: Treat multiple constant definitions in one statement correctly

Updates #1192
This commit is contained in:
Hajime Hoshi 2021-04-09 00:32:15 +09:00
parent 3b6fa891ac
commit 1cdc6ea72b
5 changed files with 82 additions and 7 deletions

View File

@ -393,8 +393,9 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
if c, ok := block.findConstant(e.Name); ok { if c, ok := block.findConstant(e.Name); ok {
return []shaderir.Expr{ return []shaderir.Expr{
{ {
Type: shaderir.NumberExpr, Type: shaderir.NumberExpr,
Const: c.value, Const: c.value,
ConstType: c.ctyp,
}, },
}, []shaderir.Type{c.typ}, nil, true }, []shaderir.Type{c.typ}, nil, true
} }

View File

@ -33,6 +33,7 @@ type variable struct {
type constant struct { type constant struct {
name string name string
typ shaderir.Type typ shaderir.Type
ctyp shaderir.ConstType
value gconstant.Value value gconstant.Value
} }
@ -581,7 +582,7 @@ func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) ([]constan
} }
} }
es, _, ss, ok := s.parseExpr(block, vs.Values[i], false) es, ts, ss, ok := s.parseExpr(block, vs.Values[i], false)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -589,7 +590,7 @@ func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) ([]constan
s.addError(vs.Pos(), fmt.Sprintf("invalid constant expression: %s", name)) s.addError(vs.Pos(), fmt.Sprintf("invalid constant expression: %s", name))
return nil, false return nil, false
} }
if len(es) != 1 { if len(ts) != 1 || len(es) != 1 {
s.addError(vs.Pos(), fmt.Sprintf("invalid constant expression: %s", n)) s.addError(vs.Pos(), fmt.Sprintf("invalid constant expression: %s", n))
return nil, false return nil, false
} }
@ -600,6 +601,7 @@ func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) ([]constan
cs = append(cs, constant{ cs = append(cs, constant{
name: name, name: name,
typ: t, typ: t,
ctyp: es[0].ConstType,
value: es[0].Const, value: es[0].Const,
}) })
} }

View File

@ -520,7 +520,11 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
return nil, false return nil, false
} }
block.addNamedLocalVariable(name, ts[0], e.Pos()) t := ts[0]
if t.Main == shaderir.None {
t = toDefaultType(r[0].Const)
}
block.addNamedLocalVariable(name, t, e.Pos())
} }
if len(r) > 1 { if len(r) > 1 {
@ -582,8 +586,12 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
}) })
} else { } else {
// For variable swapping, use temporary variables. // For variable swapping, use temporary variables.
t := origts[0]
if t.Main == shaderir.None {
t = toDefaultType(r[0].Const)
}
block.vars = append(block.vars, variable{ block.vars = append(block.vars, variable{
typ: origts[0], typ: t,
}) })
idx := len(block.vars) - 1 idx := len(block.vars) - 1
stmts = append(stmts, stmts = append(stmts,
@ -632,7 +640,13 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
} }
} }
} }
block.addNamedLocalVariable(name, rhsTypes[i], e.Pos()) t := rhsTypes[i]
if t.Main == shaderir.None {
// TODO: This is to determine a type when the rhs is a constant,
// but there are no actual cases when len(lhs) != len(rhs). Is this correct?
t = toDefaultType(rhsExprs[i].Const)
}
block.addNamedLocalVariable(name, t, e.Pos())
} }
l, _, ss, ok := cs.parseExpr(block, lhs[i], false) l, _, ss, ok := cs.parseExpr(block, lhs[i], false)
@ -660,3 +674,16 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
return stmts, true return stmts, true
} }
func toDefaultType(v gconstant.Value) shaderir.Type {
switch v.Kind() {
case gconstant.Bool:
return shaderir.Type{Main: shaderir.Bool}
case gconstant.Int:
return shaderir.Type{Main: shaderir.Int}
case gconstant.Float:
return shaderir.Type{Main: shaderir.Float}
}
// TODO: Should this be an error?
return shaderir.Type{}
}

View File

@ -0,0 +1,22 @@
void F0(void);
void F0(void) {
int l0 = 0;
int l1 = 0;
float l2 = float(0);
int l3 = 0;
int l4 = 0;
int l5 = 0;
float l6 = float(0);
float l7 = float(0);
bool l8 = false;
l0 = 1;
l1 = 1;
l2 = 1.0;
l3 = 1;
l5 = 1;
l4 = l5;
l7 = 1.0;
l6 = l7;
l8 = false;
}

23
internal/shader/testdata/const3.go vendored Normal file
View File

@ -0,0 +1,23 @@
package main
const a = 1
const b float = 1
const c int = 1
const d, e = 1, 1.0
const f = false
func Foo() {
l0 := 1
la := a
lb := b
lc := c
ld, le := d, e
lf := f
_ = l0
_ = la
_ = lb
_ = lc
_ = ld
_ = le
_ = lf
}