shader: Parse number literals in binary expressions correctly

Updates #1190
This commit is contained in:
Hajime Hoshi 2020-06-21 20:25:46 +09:00
parent afc39a100c
commit 29b70bea95
3 changed files with 91 additions and 9 deletions

View File

@ -384,8 +384,6 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl
if !ok { if !ok {
return nil, nil, nil, false return nil, nil, nil, false
} }
inits = append(inits, es...)
stmts = append(stmts, ss...)
if t.Main == shaderir.None { if t.Main == shaderir.None {
ts, ok := s.functionReturnTypes(block, init) ts, ok := s.functionReturnTypes(block, init)
@ -398,6 +396,18 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl
t = ts[0] t = ts[0]
} }
if es[0].Type == shaderir.NumberExpr {
switch t.Main {
case shaderir.Int:
es[0].ConstType = shaderir.ConstTypeInt
case shaderir.Float:
es[0].ConstType = shaderir.ConstTypeFloat
}
}
inits = append(inits, es...)
stmts = append(stmts, ss...)
default: default:
// Multiple-value context // Multiple-value context
@ -745,8 +755,8 @@ func (cs *compileState) assign(block *block, pos token.Pos, lhs, rhs []ast.Expr,
var rhsTypes []shaderir.Type var rhsTypes []shaderir.Type
for i, e := range lhs { for i, e := range lhs {
// Prase RHS first for the order of the statements.
if len(lhs) == len(rhs) { if len(lhs) == len(rhs) {
// Prase RHS first for the order of the statements.
r, origts, stmts, ok := cs.parseExpr(block, rhs[i]) r, origts, stmts, ok := cs.parseExpr(block, rhs[i])
if !ok { if !ok {
return false return false
@ -782,6 +792,15 @@ func (cs *compileState) assign(block *block, pos token.Pos, lhs, rhs []ast.Expr,
} }
block.ir.Stmts = append(block.ir.Stmts, stmts...) block.ir.Stmts = append(block.ir.Stmts, stmts...)
if r[0].Type == shaderir.NumberExpr {
switch block.vars[l[0].Index].typ.Main {
case shaderir.Int:
r[0].ConstType = shaderir.ConstTypeInt
case shaderir.Float:
r[0].ConstType = shaderir.ConstTypeFloat
}
}
block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{
Type: shaderir.Assign, Type: shaderir.Assign,
Exprs: []shaderir.Expr{l[0], r[0]}, Exprs: []shaderir.Expr{l[0], r[0]},
@ -845,12 +864,6 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
cs.addError(e.Pos(), fmt.Sprintf("literal not implemented: %#v", e)) cs.addError(e.Pos(), fmt.Sprintf("literal not implemented: %#v", e))
} }
case *ast.BinaryExpr: case *ast.BinaryExpr:
op, ok := shaderir.OpFromToken(e.Op)
if !ok {
cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", e.Op))
return nil, nil, nil, false
}
var stmts []shaderir.Stmt var stmts []shaderir.Stmt
// Prase LHS first for the order of the statements. // Prase LHS first for the order of the statements.
@ -876,6 +889,27 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
rhst := ts[0] rhst := ts[0]
if lhs[0].Type == shaderir.NumberExpr && rhs[0].Type == shaderir.NumberExpr {
op := e.Op
// https://golang.org/pkg/go/constant/#BinaryOp
// "To force integer division of Int operands, use op == token.QUO_ASSIGN instead of
// token.QUO; the result is guaranteed to be Int in this case."
if op == token.QUO && lhs[0].Const.Kind() == gconstant.Int && rhs[0].Const.Kind() == gconstant.Int {
op = token.QUO_ASSIGN
}
v := gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const)
t := shaderir.Type{Main: shaderir.Int}
if v.Kind() == gconstant.Float {
t = shaderir.Type{Main: shaderir.Float}
}
return []shaderir.Expr{
{
Type: shaderir.NumberExpr,
Const: v,
},
}, []shaderir.Type{t}, stmts, true
}
var t shaderir.Type var t shaderir.Type
if lhst.Equal(&rhst) { if lhst.Equal(&rhst) {
t = lhst t = lhst
@ -901,6 +935,12 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
} }
} }
op, ok := shaderir.OpFromToken(e.Op)
if !ok {
cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", e.Op))
return nil, nil, nil, false
}
return []shaderir.Expr{ return []shaderir.Expr{
{ {
Type: shaderir.Binary, Type: shaderir.Binary,

View File

@ -0,0 +1,25 @@
void F0(out vec4 l0) {
int l1 = 0;
float l2 = float(0);
float l3 = float(0);
float l4 = float(0);
l1 = 2;
l2 = 2.500000000e+00;
l3 = 2.500000000e+00;
l4 = 2.500000000e+00;
l0 = vec4(l1, l2, l3, l4);
return;
}
void F1(out vec4 l0) {
int l1 = 0;
float l2 = float(0);
float l3 = float(0);
float l4 = float(0);
l1 = 2;
l2 = 2.500000000e+00;
l3 = 2.500000000e+00;
l4 = 2.500000000e+00;
l0 = vec4(l1, l2, l3, l4);
return;
}

View File

@ -0,0 +1,17 @@
package main
func Foo1() vec4 {
x0 := 5 / 2
x1 := 5.0 / 2
x2 := 5 / 2.0
x3 := 5.0 / 2.0
return vec4(x0, x1, x2, x3)
}
func Foo2() vec4 {
var x0 = 5 / 2
var x1 = 5.0 / 2
var x2 = 5 / 2.0
var x3 = 5.0 / 2.0
return vec4(x0, x1, x2, x3)
}