add return type for type resolving

This commit is contained in:
aoyako 2024-02-27 19:29:39 +09:00
parent c90d02f8d4
commit 7f9d997175
6 changed files with 124 additions and 52 deletions

View File

@ -101,7 +101,13 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
}
// Resolve untyped constants.
l, r, ok := shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst)
var l gconstant.Value
var r gconstant.Value
if op2 == shaderir.LeftShift || op2 == shaderir.RightShift {
l, r, ok = shaderir.ResolveUntypedConstsForBitShiftOp(lhs[0].Const, rhs[0].Const, lhst, rhst)
} else {
l, r, ok = shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst)
}
if !ok {
// TODO: Show a better type name for untyped constants.
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String()))
@ -109,27 +115,45 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
}
lhs[0].Const, rhs[0].Const = l, r
// If either is typed, resolve the other type.
// If both are untyped, keep them untyped.
if lhst.Main != shaderir.None || rhst.Main != shaderir.None {
if lhs[0].Const != nil {
switch lhs[0].Const.Kind() {
case gconstant.Float:
lhst = shaderir.Type{Main: shaderir.Float}
case gconstant.Int:
lhst = shaderir.Type{Main: shaderir.Int}
case gconstant.Bool:
lhst = shaderir.Type{Main: shaderir.Bool}
if op2 == shaderir.LeftShift || op2 == shaderir.RightShift {
if !(lhst.Main == shaderir.None && rhst.Main == shaderir.None) {
// If both are const
if rhs[0].Const != nil && (rhst.Main == shaderir.None || lhs[0].Const != nil) {
rhst = shaderir.Type{Main: shaderir.Int}
}
// If left is untyped const
if lhst.Main == shaderir.None && lhs[0].Const != nil {
if rhs[0].Const != nil {
lhst = shaderir.Type{Main: shaderir.Int}
} else {
lhst = shaderir.Type{Main: shaderir.DeducedInt}
}
}
}
if rhs[0].Const != nil {
switch rhs[0].Const.Kind() {
case gconstant.Float:
rhst = shaderir.Type{Main: shaderir.Float}
case gconstant.Int:
rhst = shaderir.Type{Main: shaderir.Int}
case gconstant.Bool:
rhst = shaderir.Type{Main: shaderir.Bool}
} else {
// If either is typed, resolve the other type.
// If both are untyped, keep them untyped.
if lhst.Main != shaderir.None || rhst.Main != shaderir.None {
if lhs[0].Const != nil {
switch lhs[0].Const.Kind() {
case gconstant.Float:
lhst = shaderir.Type{Main: shaderir.Float}
case gconstant.Int:
lhst = shaderir.Type{Main: shaderir.Int}
case gconstant.Bool:
lhst = shaderir.Type{Main: shaderir.Bool}
}
}
if rhs[0].Const != nil {
switch rhs[0].Const.Kind() {
case gconstant.Float:
rhst = shaderir.Type{Main: shaderir.Float}
case gconstant.Int:
rhst = shaderir.Type{Main: shaderir.Int}
case gconstant.Bool:
rhst = shaderir.Type{Main: shaderir.Bool}
}
}
}
}

View File

@ -514,6 +514,9 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
return nil, false
}
t := ts[0]
if t.Main == shaderir.DeducedInt {
cs.addError(pos, "invalid operation: shifted operand 1 (type float) must be integer")
}
if t.Main == shaderir.None {
t = toDefaultType(r[0].Const)
}
@ -705,6 +708,9 @@ func canAssign(lt *shaderir.Type, rt *shaderir.Type, rc gconstant.Value) bool {
if lt.Equal(rt) {
return true
}
if lt.Main == shaderir.Int && rt.Main == shaderir.DeducedInt {
return true
}
if rc == nil {
return false

View File

@ -1320,12 +1320,30 @@ func TestSyntaxOperatorShift(t *testing.T) {
stmt string
err bool
}{
{stmt: "a := 1 << 2.0; _ = a", err: true},
{stmt: "a := 1.0 << 2; _ = a", err: true},
{stmt: "a := 1.0 << 2.0; _ = a", err: true},
// {stmt: "s := 1; var a float = float(1 << s); _ = a", err: true},
// {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true},
// {stmt: "s := 1; var a int = int(1 << s); _ = a", err: false},
// {stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false},
{stmt: "s := 1; a := 1 << s; _ = a", err: false},
{stmt: "s := 1; a := 1.0 << s; _ = a", err: true},
{stmt: "s := 1; a := int(1.0 << s); _ = a", err: false},
{stmt: "var a float = 1.0 << 2.0; _ = a", err: false},
{stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true},
{stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 << s; _ = a", err: true},
{stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false},
{stmt: "s := 1; var a int = 1 << s; _ = a", err: false},
{stmt: "var a int = 1.0 << 2; _ = a", err: false},
{stmt: "var a float = 1.0 << 2; _ = a", err: false},
{stmt: "var a = 1.0 << 2; _ = a", err: false},
{stmt: "a := 1 << 2.0; _ = a", err: false},
{stmt: "a := 1.0 << 2; _ = a", err: false},
{stmt: "a := 1.0 << 2.0; _ = a", err: false},
{stmt: "a := 1 << 2; _ = a", err: false},
{stmt: "a := float(1.0) << 2; _ = a", err: true},
{stmt: "a := 1 << float(2.0); _ = a", err: true},
{stmt: "a := 1 << float(2.0); _ = a", err: false},
{stmt: "a := 1.0 << float(2.0); _ = a", err: false},
{stmt: "a := ivec2(1) << 2; _ = a", err: false},
{stmt: "a := 1 << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << float(2.0); _ = a", err: true},
@ -1345,9 +1363,22 @@ func TestSyntaxOperatorShift(t *testing.T) {
{stmt: "a := vec3(1) << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << vec3(2); _ = a", err: true},
{stmt: "s := 1; var a float = float(1.0 >> s); _ = a", err: true},
{stmt: "s := 1; var a int = int(1.0 >> s); _ = a", err: false},
{stmt: "s := 1; var a float = 1.0 >> s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 >> s; _ = a", err: true},
{stmt: "s := 1; var a int = 1.0 >> s; _ = a", err: false},
{stmt: "s := 1; var a int = 1 >> s; _ = a", err: false},
{stmt: "var a int = 1.0 >> 2; _ = a", err: false},
{stmt: "var a float = 1.0 >> 2; _ = a", err: false},
{stmt: "var a = 1.0 >> 2; _ = a", err: false},
{stmt: "a := 1 >> 2.0; _ = a", err: false},
{stmt: "a := 1.0 >> 2; _ = a", err: false},
{stmt: "a := 1.0 >> 2.0; _ = a", err: false},
{stmt: "a := 1 >> 2; _ = a", err: false},
{stmt: "a := float(1.0) >> 2; _ = a", err: true},
{stmt: "a := 1 >> float(2.0); _ = a", err: true},
{stmt: "a := 1 >> float(2.0); _ = a", err: false},
{stmt: "a := 1.0 >> float(2.0); _ = a", err: false},
{stmt: "a := ivec2(1) >> 2; _ = a", err: false},
{stmt: "a := 1 >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true},

View File

@ -165,7 +165,7 @@ func checkArgsForIntBuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) err
if len(args) != 1 {
return fmt.Errorf("number of int's arguments must be 1 but %d", len(args))
}
if argts[0].Main == shaderir.Int || argts[0].Main == shaderir.Float {
if argts[0].Main == shaderir.Int || argts[0].Main == shaderir.Float || argts[0].Main == shaderir.DeducedInt {
return nil
}
if args[0].Const != nil && gconstant.ToInt(args[0].Const).Kind() != gconstant.Unknown {

View File

@ -18,6 +18,29 @@ import (
"go/constant"
)
func ResolveUntypedConstsForBitShiftOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) {
cLhs := lhs
cRhs := rhs
// Right is const -> int
if rhs != nil {
cRhs = constant.ToInt(rhs)
if cRhs.Kind() == constant.Unknown {
return nil, nil, false
}
}
// Left if untyped const -> int
if lhs != nil && lhst.Main == None {
cLhs = constant.ToInt(lhs)
if cLhs.Kind() == constant.Unknown {
return nil, nil, false
}
}
return cLhs, cRhs, true
}
func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) {
if lhst.Main == None && rhst.Main == None {
if lhs.Kind() == rhs.Kind() {
@ -98,13 +121,6 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value)
return Type{}, false
}
if op == LeftShift || op == RightShift {
if lhsConst.Kind() == constant.Int && rhsConst.Kind() == constant.Int {
return Type{Main: Int}, true
}
return Type{}, false
}
if op == EqualOp || op == NotEqualOp || op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp {
return Type{Main: Bool}, true
}
@ -128,6 +144,19 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value)
panic("shaderir: cannot resolve untyped values")
}
if op == LeftShift || op == RightShift {
if (lhst.Main == Int || lhst.Main == DeducedInt) && rhst.Main == Int {
return Type{Main: lhst.Main}, true
}
if lhst.IsIntVector() && rhst.Main == Int {
return Type{Main: lhst.Main}, true
}
if lhst.IsIntVector() && rhst.IsIntVector() && lhst.VectorElementCount() == rhst.VectorElementCount() {
return Type{Main: lhst.Main}, true
}
return Type{}, false
}
if op == AndAnd || op == OrOr {
if lhst.Main == Bool && rhst.Main == Bool {
return Type{Main: Bool}, true
@ -202,25 +231,6 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value)
return Type{}, false
}
if op == LeftShift || op == RightShift {
if lhst.Main == Int && rhst.Main == Int {
return Type{Main: Int}, true
}
if lhst.Main == IVec2 && rhst.Main == IVec2 {
return Type{Main: IVec2}, true
}
if lhst.Main == IVec3 && rhst.Main == IVec3 {
return Type{Main: IVec3}, true
}
if lhst.Main == IVec4 && rhst.Main == IVec4 {
return Type{Main: IVec4}, true
}
if lhst.IsIntVector() && rhst.Main == Int {
return lhst, true
}
return Type{}, false
}
if lhst.Equal(&rhst) {
if lhst.Main == None {
return rhst, true

View File

@ -180,6 +180,7 @@ const (
Texture
Array
Struct
DeducedInt
)
func descendantLocalVars(block, target *Block) ([]Type, bool) {