internal/shader: bug fix: needed to resolve const and non-const types

Closes #2922
This commit is contained in:
Hajime Hoshi 2024-03-10 19:14:05 +09:00
parent c9a973c6c1
commit 63e97c7064
2 changed files with 83 additions and 37 deletions

View File

@ -276,26 +276,26 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
// Process the expression as a regular function call. // Process the expression as a regular function call.
var t shaderir.Type var finalType shaderir.Type
switch callee.BuiltinFunc { switch callee.BuiltinFunc {
case shaderir.BoolF: case shaderir.BoolF:
if err := checkArgsForBoolBuiltinFunc(args, argts); err != nil { if err := checkArgsForBoolBuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error()) cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false return nil, nil, nil, false
} }
t = shaderir.Type{Main: shaderir.Bool} finalType = shaderir.Type{Main: shaderir.Bool}
case shaderir.IntF: case shaderir.IntF:
if err := checkArgsForIntBuiltinFunc(args, argts); err != nil { if err := checkArgsForIntBuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error()) cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false return nil, nil, nil, false
} }
t = shaderir.Type{Main: shaderir.Int} finalType = shaderir.Type{Main: shaderir.Int}
case shaderir.FloatF: case shaderir.FloatF:
if err := checkArgsForFloatBuiltinFunc(args, argts); err != nil { if err := checkArgsForFloatBuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error()) cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false return nil, nil, nil, false
} }
t = shaderir.Type{Main: shaderir.Float} finalType = shaderir.Type{Main: shaderir.Float}
case shaderir.Vec2F: case shaderir.Vec2F:
if err := checkArgsForVec2BuiltinFunc(args, argts); err != nil { if err := checkArgsForVec2BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error()) cs.addError(e.Pos(), err.Error())
@ -308,7 +308,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
args[i].Const = gconstant.ToFloat(args[i].Const) args[i].Const = gconstant.ToFloat(args[i].Const)
argts[i] = shaderir.Type{Main: shaderir.Float} argts[i] = shaderir.Type{Main: shaderir.Float}
} }
t = shaderir.Type{Main: shaderir.Vec2} finalType = shaderir.Type{Main: shaderir.Vec2}
case shaderir.Vec3F: case shaderir.Vec3F:
if err := checkArgsForVec3BuiltinFunc(args, argts); err != nil { if err := checkArgsForVec3BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error()) cs.addError(e.Pos(), err.Error())
@ -321,7 +321,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
args[i].Const = gconstant.ToFloat(args[i].Const) args[i].Const = gconstant.ToFloat(args[i].Const)
argts[i] = shaderir.Type{Main: shaderir.Float} argts[i] = shaderir.Type{Main: shaderir.Float}
} }
t = shaderir.Type{Main: shaderir.Vec3} finalType = shaderir.Type{Main: shaderir.Vec3}
case shaderir.Vec4F: case shaderir.Vec4F:
if err := checkArgsForVec4BuiltinFunc(args, argts); err != nil { if err := checkArgsForVec4BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error()) cs.addError(e.Pos(), err.Error())
@ -334,25 +334,25 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
args[i].Const = gconstant.ToFloat(args[i].Const) args[i].Const = gconstant.ToFloat(args[i].Const)
argts[i] = shaderir.Type{Main: shaderir.Float} argts[i] = shaderir.Type{Main: shaderir.Float}
} }
t = shaderir.Type{Main: shaderir.Vec4} finalType = shaderir.Type{Main: shaderir.Vec4}
case shaderir.IVec2F: case shaderir.IVec2F:
if err := checkArgsForIVec2BuiltinFunc(args, argts); err != nil { if err := checkArgsForIVec2BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error()) cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false return nil, nil, nil, false
} }
t = shaderir.Type{Main: shaderir.IVec2} finalType = shaderir.Type{Main: shaderir.IVec2}
case shaderir.IVec3F: case shaderir.IVec3F:
if err := checkArgsForIVec3BuiltinFunc(args, argts); err != nil { if err := checkArgsForIVec3BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error()) cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false return nil, nil, nil, false
} }
t = shaderir.Type{Main: shaderir.IVec3} finalType = shaderir.Type{Main: shaderir.IVec3}
case shaderir.IVec4F: case shaderir.IVec4F:
if err := checkArgsForIVec4BuiltinFunc(args, argts); err != nil { if err := checkArgsForIVec4BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error()) cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false return nil, nil, nil, false
} }
t = shaderir.Type{Main: shaderir.IVec4} finalType = shaderir.Type{Main: shaderir.IVec4}
case shaderir.Mat2F: case shaderir.Mat2F:
if err := checkArgsForMat2BuiltinFunc(args, argts); err != nil { if err := checkArgsForMat2BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error()) cs.addError(e.Pos(), err.Error())
@ -365,7 +365,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
args[i].Const = gconstant.ToFloat(args[i].Const) args[i].Const = gconstant.ToFloat(args[i].Const)
argts[i] = shaderir.Type{Main: shaderir.Float} argts[i] = shaderir.Type{Main: shaderir.Float}
} }
t = shaderir.Type{Main: shaderir.Mat2} finalType = shaderir.Type{Main: shaderir.Mat2}
case shaderir.Mat3F: case shaderir.Mat3F:
if err := checkArgsForMat3BuiltinFunc(args, argts); err != nil { if err := checkArgsForMat3BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error()) cs.addError(e.Pos(), err.Error())
@ -378,7 +378,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
args[i].Const = gconstant.ToFloat(args[i].Const) args[i].Const = gconstant.ToFloat(args[i].Const)
argts[i] = shaderir.Type{Main: shaderir.Float} argts[i] = shaderir.Type{Main: shaderir.Float}
} }
t = shaderir.Type{Main: shaderir.Mat3} finalType = shaderir.Type{Main: shaderir.Mat3}
case shaderir.Mat4F: case shaderir.Mat4F:
if err := checkArgsForMat4BuiltinFunc(args, argts); err != nil { if err := checkArgsForMat4BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error()) cs.addError(e.Pos(), err.Error())
@ -391,7 +391,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
args[i].Const = gconstant.ToFloat(args[i].Const) args[i].Const = gconstant.ToFloat(args[i].Const)
argts[i] = shaderir.Type{Main: shaderir.Float} argts[i] = shaderir.Type{Main: shaderir.Float}
} }
t = shaderir.Type{Main: shaderir.Mat4} finalType = shaderir.Type{Main: shaderir.Mat4}
case shaderir.TexelAt: case shaderir.TexelAt:
if len(args) != 2 { if len(args) != 2 {
cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 2 but %d", callee.BuiltinFunc, len(args))) cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 2 but %d", callee.BuiltinFunc, len(args)))
@ -405,7 +405,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as vec2 value in argument to %s", argts[1].String(), callee.BuiltinFunc)) cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as vec2 value in argument to %s", argts[1].String(), callee.BuiltinFunc))
return nil, nil, nil, false return nil, nil, nil, false
} }
t = shaderir.Type{Main: shaderir.Vec4} finalType = shaderir.Type{Main: shaderir.Vec4}
case shaderir.DiscardF: case shaderir.DiscardF:
if len(args) != 0 { if len(args) != 0 {
cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 0 but %d", callee.BuiltinFunc, len(args))) cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 0 but %d", callee.BuiltinFunc, len(args)))
@ -428,13 +428,17 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
switch callee.BuiltinFunc { switch callee.BuiltinFunc {
case shaderir.Clamp: case shaderir.Clamp:
if kind, allConsts := resolveConstKind(args, argts); allConsts { if kind, _ := resolveConstKind(args, argts); kind != gconstant.Unknown {
switch kind { switch kind {
case gconstant.Unknown:
cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s, %s, and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String(), argts[2].String()))
return nil, nil, nil, false
case gconstant.Int: case gconstant.Int:
for i, arg := range args { for i, arg := range args {
if arg.Const == nil {
if argts[i].Main != shaderir.Int {
cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s, %s, and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String(), argts[2].String()))
return nil, nil, nil, false
}
continue
}
v := gconstant.ToInt(arg.Const) v := gconstant.ToInt(arg.Const)
if v.Kind() == gconstant.Unknown { if v.Kind() == gconstant.Unknown {
cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type int", arg.Const.String())) cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type int", arg.Const.String()))
@ -445,6 +449,13 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
case gconstant.Float: case gconstant.Float:
for i, arg := range args { for i, arg := range args {
if arg.Const == nil {
if argts[i].Main != shaderir.Float {
cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s, %s, and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String(), argts[2].String()))
return nil, nil, nil, false
}
continue
}
v := gconstant.ToFloat(arg.Const) v := gconstant.ToFloat(arg.Const)
if v.Kind() == gconstant.Unknown { if v.Kind() == gconstant.Unknown {
cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type float", arg.Const.String())) cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type float", arg.Const.String()))
@ -544,9 +555,9 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
switch callee.BuiltinFunc { switch callee.BuiltinFunc {
case shaderir.Smoothstep: case shaderir.Smoothstep:
t = argts[2] finalType = argts[2]
default: default:
t = argts[0] finalType = argts[0]
} }
case shaderir.Atan2, shaderir.Pow, shaderir.Mod, shaderir.Min, shaderir.Max, shaderir.Step, shaderir.Distance, shaderir.Dot, shaderir.Cross, shaderir.Reflect: case shaderir.Atan2, shaderir.Pow, shaderir.Mod, shaderir.Min, shaderir.Max, shaderir.Step, shaderir.Distance, shaderir.Dot, shaderir.Cross, shaderir.Reflect:
@ -558,13 +569,17 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
switch callee.BuiltinFunc { switch callee.BuiltinFunc {
case shaderir.Min, shaderir.Max: case shaderir.Min, shaderir.Max:
if kind, allConsts := resolveConstKind(args, argts); allConsts { if kind, _ := resolveConstKind(args, argts); kind != gconstant.Unknown {
switch kind { switch kind {
case gconstant.Unknown:
cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String()))
return nil, nil, nil, false
case gconstant.Int: case gconstant.Int:
for i, arg := range args { for i, arg := range args {
if arg.Const == nil {
if argts[i].Main != shaderir.Int {
cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String()))
return nil, nil, nil, false
}
continue
}
v := gconstant.ToInt(arg.Const) v := gconstant.ToInt(arg.Const)
if v.Kind() == gconstant.Unknown { if v.Kind() == gconstant.Unknown {
cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type int", arg.Const.String())) cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type int", arg.Const.String()))
@ -575,6 +590,13 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
case gconstant.Float: case gconstant.Float:
for i, arg := range args { for i, arg := range args {
if arg.Const == nil {
if argts[i].Main != shaderir.Float {
cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String()))
return nil, nil, nil, false
}
continue
}
v := gconstant.ToFloat(arg.Const) v := gconstant.ToFloat(arg.Const)
if v.Kind() == gconstant.Unknown { if v.Kind() == gconstant.Unknown {
cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type float", arg.Const.String())) cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type float", arg.Const.String()))
@ -662,11 +684,11 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
switch callee.BuiltinFunc { switch callee.BuiltinFunc {
case shaderir.Distance, shaderir.Dot: case shaderir.Distance, shaderir.Dot:
t = shaderir.Type{Main: shaderir.Float} finalType = shaderir.Type{Main: shaderir.Float}
case shaderir.Step: case shaderir.Step:
t = argts[1] finalType = argts[1]
default: default:
t = argts[0] finalType = argts[0]
} }
default: default:
@ -711,9 +733,9 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
} }
if callee.BuiltinFunc == shaderir.Length { if callee.BuiltinFunc == shaderir.Length {
t = shaderir.Type{Main: shaderir.Float} finalType = shaderir.Type{Main: shaderir.Float}
} else { } else {
t = argts[0] finalType = argts[0]
} }
} }
return []shaderir.Expr{ return []shaderir.Expr{
@ -721,7 +743,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
Type: shaderir.Call, Type: shaderir.Call,
Exprs: append([]shaderir.Expr{callee}, args...), Exprs: append([]shaderir.Expr{callee}, args...),
}, },
}, []shaderir.Type{t}, stmts, true }, []shaderir.Type{finalType}, stmts, true
} }
if callee.Type != shaderir.FunctionExpr { if callee.Type != shaderir.FunctionExpr {
@ -1162,8 +1184,24 @@ func resolveConstKind(exprs []shaderir.Expr, ts []shaderir.Type) (kind gconstant
panic("not reached") panic("not reached")
} }
allConsts = true
for _, expr := range exprs { for _, expr := range exprs {
if expr.Const == nil { if expr.Const == nil {
allConsts = false
}
}
if !allConsts {
for _, t := range ts {
if t.Main == shaderir.None {
continue
}
if t.Main == shaderir.Float {
return gconstant.Float, false
}
if t.Main == shaderir.Int {
return gconstant.Int, false
}
return gconstant.Unknown, false return gconstant.Unknown, false
} }
} }
@ -1192,17 +1230,15 @@ func resolveConstKind(exprs []shaderir.Expr, ts []shaderir.Type) (kind gconstant
} }
} }
if kind == gconstant.Float { if kind != gconstant.Unknown {
return gconstant.Float, true return kind, true
} }
// Prefer floats over integers for non-typed constant values. // Prefer floats over integers for non-typed constant values.
// For example, max(1.0, 1) should return a float value. // For example, max(1.0, 1) should return a float value.
if kind == gconstant.Unknown { for _, expr := range exprs {
for _, expr := range exprs { if expr.Const.Kind() == gconstant.Float {
if expr.Const.Kind() == gconstant.Float { return gconstant.Float, true
return gconstant.Float, true
}
} }
} }

View File

@ -2347,6 +2347,10 @@ func TestSyntaxBuiltinFuncArgsMinMax(t *testing.T) {
{stmt: "a := {{.Func}}(int(1), int(1)); var _ int = a", err: false}, {stmt: "a := {{.Func}}(int(1), int(1)); var _ int = a", err: false},
{stmt: "a := {{.Func}}(int(1), float(1)); _ = a", err: true}, {stmt: "a := {{.Func}}(int(1), float(1)); _ = a", err: true},
{stmt: "a := {{.Func}}(float(1), int(1)); _ = a", err: true}, {stmt: "a := {{.Func}}(float(1), int(1)); _ = a", err: true},
{stmt: "x := 1.1; a := {{.Func}}(int(x), 1); _ = a", err: false},
{stmt: "x := 1; a := {{.Func}}(float(x), 1.1); _ = a", err: false},
{stmt: "x := 1.1; a := {{.Func}}(1, int(x)); _ = a", err: false},
{stmt: "x := 1; a := {{.Func}}(1.1, float(x)); _ = a", err: false},
{stmt: "a := {{.Func}}(1, vec2(1)); _ = a", err: true}, {stmt: "a := {{.Func}}(1, vec2(1)); _ = a", err: true},
{stmt: "a := {{.Func}}(1, vec3(1)); _ = a", err: true}, {stmt: "a := {{.Func}}(1, vec3(1)); _ = a", err: true},
{stmt: "a := {{.Func}}(1, vec4(1)); _ = a", err: true}, {stmt: "a := {{.Func}}(1, vec4(1)); _ = a", err: true},
@ -2547,6 +2551,12 @@ func TestSyntaxBuiltinFuncClampType(t *testing.T) {
{stmt: "a := clamp(float(1), 1, 1.0); var _ float = a", err: false}, {stmt: "a := clamp(float(1), 1, 1.0); var _ float = a", err: false},
{stmt: "a := clamp(float(1), 1.1, 1); _ = a", err: false}, {stmt: "a := clamp(float(1), 1.1, 1); _ = a", err: false},
{stmt: "a := clamp(float(1), 1, 1.1); _ = a", err: false}, {stmt: "a := clamp(float(1), 1, 1.1); _ = a", err: false},
{stmt: "x := 1.1; a := clamp(int(x), 1, 1); _ = a", err: false},
{stmt: "x := 1; a := clamp(float(x), 1.1, 1.1); _ = a", err: false},
{stmt: "x := 1.1; a := clamp(1, int(x), 1); _ = a", err: false},
{stmt: "x := 1; a := clamp(1.1, float(x), 1.1); _ = a", err: false},
{stmt: "x := 1.1; a := clamp(1, 1, int(x)); _ = a", err: false},
{stmt: "x := 1; a := clamp(1.1, 1.1, float(x)); _ = a", err: false},
{stmt: "a := clamp(1.0, 1, 1); var _ float = a", err: false}, {stmt: "a := clamp(1.0, 1, 1); var _ float = a", err: false},
{stmt: "a := clamp(1, 1.0, 1); var _ float = a", err: false}, {stmt: "a := clamp(1, 1.0, 1); var _ float = a", err: false},
{stmt: "a := clamp(1, 1, 1.0); var _ float = a", err: false}, {stmt: "a := clamp(1, 1, 1.0); var _ float = a", err: false},