internal/shader: more strict type checks with built-in functions

This commit is contained in:
Hajime Hoshi 2022-11-20 17:18:24 +09:00
parent 5d8216def3
commit dc1df824a5
3 changed files with 93 additions and 11 deletions

View File

@ -153,19 +153,21 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
case *ast.BasicLit: case *ast.BasicLit:
switch e.Kind { switch e.Kind {
case token.INT: case token.INT:
// The type is not determined yet.
return []shaderir.Expr{ return []shaderir.Expr{
{ {
Type: shaderir.NumberExpr, Type: shaderir.NumberExpr,
Const: gconstant.MakeFromLiteral(e.Value, e.Kind, 0), Const: gconstant.MakeFromLiteral(e.Value, e.Kind, 0),
}, },
}, []shaderir.Type{{Main: shaderir.Int}}, nil, true }, []shaderir.Type{{}}, nil, true
case token.FLOAT: case token.FLOAT:
// The type is not determined yet.
return []shaderir.Expr{ return []shaderir.Expr{
{ {
Type: shaderir.NumberExpr, Type: shaderir.NumberExpr,
Const: gconstant.MakeFromLiteral(e.Value, e.Kind, 0), Const: gconstant.MakeFromLiteral(e.Value, e.Kind, 0),
}, },
}, []shaderir.Type{{Main: shaderir.Float}}, nil, true }, []shaderir.Type{{}}, nil, true
default: default:
cs.addError(e.Pos(), fmt.Sprintf("literal not implemented: %#v", e)) cs.addError(e.Pos(), fmt.Sprintf("literal not implemented: %#v", e))
} }
@ -444,6 +446,20 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
ConstType: shaderir.ConstTypeInt, ConstType: shaderir.ConstTypeInt,
}, },
}, []shaderir.Type{{Main: shaderir.Int}}, stmts, true }, []shaderir.Type{{Main: shaderir.Int}}, stmts, true
case shaderir.BoolF:
if len(args) == 1 && args[0].Const != nil {
if args[0].Const.Kind() != gconstant.Bool {
cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type bool", args[0].Const.String()))
return nil, nil, nil, false
}
return []shaderir.Expr{
{
Type: shaderir.NumberExpr,
Const: args[0].Const,
ConstType: shaderir.ConstTypeBool,
},
}, []shaderir.Type{{Main: shaderir.Bool}}, stmts, true
}
case shaderir.IntF: case shaderir.IntF:
if len(args) == 1 && args[0].Const != nil { if len(args) == 1 && args[0].Const != nil {
if !canTruncateToInteger(args[0].Const) { if !canTruncateToInteger(args[0].Const) {

View File

@ -119,20 +119,28 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
return nil, false return nil, false
} }
case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
if (op == shaderir.MatrixMul || op == shaderir.Div) && rts[0].Main == shaderir.Float { if (op == shaderir.MatrixMul || op == shaderir.Div) &&
// OK (rts[0].Main == shaderir.Float ||
} else if lts[0].IsVector() && rts[0].Main == shaderir.Float { (rhs[0].Const != nil &&
// OK rhs[0].ConstType != shaderir.ConstTypeInt &&
gconstant.ToFloat(rhs[0].Const).Kind() != gconstant.Unknown)) {
if rhs[0].Const != nil {
rhs[0].Const = gconstant.ToFloat(rhs[0].Const)
rhs[0].ConstType = shaderir.ConstTypeFloat
}
} else if op == shaderir.MatrixMul && ((lts[0].Main == shaderir.Vec2 && rts[0].Main == shaderir.Mat2) || } else if op == shaderir.MatrixMul && ((lts[0].Main == shaderir.Vec2 && rts[0].Main == shaderir.Mat2) ||
(lts[0].Main == shaderir.Vec3 && rts[0].Main == shaderir.Mat3) || (lts[0].Main == shaderir.Vec3 && rts[0].Main == shaderir.Mat3) ||
(lts[0].Main == shaderir.Vec4 && rts[0].Main == shaderir.Mat4)) { (lts[0].Main == shaderir.Vec4 && rts[0].Main == shaderir.Mat4)) {
// OK // OK
} else if (op == shaderir.MatrixMul || op == shaderir.ComponentWiseMul || lts[0].IsVector()) && } else if (op == shaderir.MatrixMul || op == shaderir.ComponentWiseMul || lts[0].IsVector()) &&
rhs[0].Const != nil && (rts[0].Main == shaderir.Float ||
rhs[0].ConstType != shaderir.ConstTypeInt && (rhs[0].Const != nil &&
gconstant.ToFloat(rhs[0].Const).Kind() != gconstant.Unknown { rhs[0].ConstType != shaderir.ConstTypeInt &&
rhs[0].Const = gconstant.ToFloat(rhs[0].Const) gconstant.ToFloat(rhs[0].Const).Kind() != gconstant.Unknown)) {
rhs[0].ConstType = shaderir.ConstTypeFloat if rhs[0].Const != nil {
rhs[0].Const = gconstant.ToFloat(rhs[0].Const)
rhs[0].ConstType = shaderir.ConstTypeFloat
}
} else { } else {
cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: mismatched types %s and %s", lts[0].String(), rts[0].String())) cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: mismatched types %s and %s", lts[0].String(), rts[0].String()))
return nil, false return nil, false
@ -798,6 +806,10 @@ func canAssign(lt *shaderir.Type, rt *shaderir.Type, rc gconstant.Value) bool {
return false return false
} }
if !rt.Equal(&shaderir.Type{}) {
return false
}
switch lt.Main { switch lt.Main {
case shaderir.Bool: case shaderir.Bool:
return rc.Kind() == gconstant.Bool return rc.Kind() == gconstant.Bool

View File

@ -2491,24 +2491,78 @@ func TestConstType(t *testing.T) {
{stmt: "const a float = false", err: true}, {stmt: "const a float = false", err: true},
{stmt: "const a vec2 = false", err: true}, {stmt: "const a vec2 = false", err: true},
{stmt: "const a = bool(false)", err: false},
{stmt: "const a bool = bool(false)", err: false},
{stmt: "const a int = bool(false)", err: true},
{stmt: "const a float = bool(false)", err: true},
{stmt: "const a vec2 = bool(false)", err: true},
{stmt: "const a = int(false)", err: true},
{stmt: "const a bool = int(false)", err: true},
{stmt: "const a int = int(false)", err: true},
{stmt: "const a float = int(false)", err: true},
{stmt: "const a vec2 = int(false)", err: true},
{stmt: "const a = float(false)", err: true},
{stmt: "const a bool = float(false)", err: true},
{stmt: "const a int = float(false)", err: true},
{stmt: "const a float = float(false)", err: true},
{stmt: "const a vec2 = float(false)", err: true},
{stmt: "const a = 1", err: false}, {stmt: "const a = 1", err: false},
{stmt: "const a bool = 1", err: true}, {stmt: "const a bool = 1", err: true},
{stmt: "const a int = 1", err: false}, {stmt: "const a int = 1", err: false},
{stmt: "const a float = 1", err: false}, {stmt: "const a float = 1", err: false},
{stmt: "const a vec2 = 1", err: true}, {stmt: "const a vec2 = 1", err: true},
{stmt: "const a = int(1)", err: false},
{stmt: "const a bool = int(1)", err: true},
{stmt: "const a int = int(1)", err: false},
{stmt: "const a float = int(1)", err: true},
{stmt: "const a vec2 = int(1)", err: true},
{stmt: "const a = float(1)", err: false},
{stmt: "const a bool = float(1)", err: true},
{stmt: "const a int = float(1)", err: true},
{stmt: "const a float = float(1)", err: false},
{stmt: "const a vec2 = float(1)", err: true},
{stmt: "const a = 1.0", err: false}, {stmt: "const a = 1.0", err: false},
{stmt: "const a bool = 1.0", err: true}, {stmt: "const a bool = 1.0", err: true},
{stmt: "const a int = 1.0", err: false}, {stmt: "const a int = 1.0", err: false},
{stmt: "const a float = 1.0", err: false}, {stmt: "const a float = 1.0", err: false},
{stmt: "const a vec2 = 1.0", err: true}, {stmt: "const a vec2 = 1.0", err: true},
{stmt: "const a = int(1.0)", err: false},
{stmt: "const a bool = int(1.0)", err: true},
{stmt: "const a int = int(1.0)", err: false},
{stmt: "const a float = int(1.0)", err: true},
{stmt: "const a vec2 = int(1.0)", err: true},
{stmt: "const a = float(1.0)", err: false},
{stmt: "const a bool = float(1.0)", err: true},
{stmt: "const a int = float(1.0)", err: true},
{stmt: "const a float = float(1.0)", err: false},
{stmt: "const a vec2 = float(1.0)", err: true},
{stmt: "const a = 1.1", err: false}, {stmt: "const a = 1.1", err: false},
{stmt: "const a bool = 1.1", err: true}, {stmt: "const a bool = 1.1", err: true},
{stmt: "const a int = 1.1", err: true}, {stmt: "const a int = 1.1", err: true},
{stmt: "const a float = 1.1", err: false}, {stmt: "const a float = 1.1", err: false},
{stmt: "const a vec2 = 1.1", err: true}, {stmt: "const a vec2 = 1.1", err: true},
{stmt: "const a = int(1.1)", err: true},
{stmt: "const a bool = int(1.1)", err: true},
{stmt: "const a int = int(1.1)", err: true},
{stmt: "const a float = int(1.1)", err: true},
{stmt: "const a vec2 = int(1.1)", err: true},
{stmt: "const a = float(1.1)", err: false},
{stmt: "const a bool = float(1.1)", err: true},
{stmt: "const a int = float(1.1)", err: true},
{stmt: "const a float = float(1.1)", err: false},
{stmt: "const a vec2 = float(1.1)", err: true},
{stmt: "const a = vec2(0)", err: true}, {stmt: "const a = vec2(0)", err: true},
{stmt: "const a bool = vec2(0)", err: true}, {stmt: "const a bool = vec2(0)", err: true},
{stmt: "const a int = vec2(0)", err: true}, {stmt: "const a int = vec2(0)", err: true},