diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 86cd576ac..450adf596 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -121,7 +121,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP } case shaderir.Float: if rhs[0].Const != nil && - rhs[0].ConstType != shaderir.ConstTypeInt && + (rts[0].Main == shaderir.None || rts[0].Main == shaderir.Float) && gconstant.ToFloat(rhs[0].Const).Kind() != gconstant.Unknown { rhs[0].Const = gconstant.ToFloat(rhs[0].Const) rhs[0].ConstType = shaderir.ConstTypeFloat @@ -133,7 +133,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP if (op == shaderir.MatrixMul || op == shaderir.Div) && (rts[0].Main == shaderir.Float || (rhs[0].Const != nil && - rhs[0].ConstType != shaderir.ConstTypeInt && + (rts[0].Main == shaderir.None || rts[0].Main == shaderir.Float) && gconstant.ToFloat(rhs[0].Const).Kind() != gconstant.Unknown)) { if rhs[0].Const != nil { rhs[0].Const = gconstant.ToFloat(rhs[0].Const) @@ -146,7 +146,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP } else if (op == shaderir.MatrixMul || op == shaderir.ComponentWiseMul || lts[0].IsFloatVector()) && (rts[0].Main == shaderir.Float || (rhs[0].Const != nil && - rhs[0].ConstType != shaderir.ConstTypeInt && + (rts[0].Main == shaderir.None || rts[0].Main == shaderir.Float) && gconstant.ToFloat(rhs[0].Const).Kind() != gconstant.Unknown)) { if rhs[0].Const != nil { rhs[0].Const = gconstant.ToFloat(rhs[0].Const)