internal/shader: bug fix: needed to resolve const and non-const types

Closes #2922
This commit is contained in:
Hajime Hoshi 2024-03-10 19:14:05 +09:00
parent c9a973c6c1
commit 63e97c7064
2 changed files with 83 additions and 37 deletions

View File

@ -276,26 +276,26 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
}
// Process the expression as a regular function call.
var t shaderir.Type
var finalType shaderir.Type
switch callee.BuiltinFunc {
case shaderir.BoolF:
if err := checkArgsForBoolBuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false
}
t = shaderir.Type{Main: shaderir.Bool}
finalType = shaderir.Type{Main: shaderir.Bool}
case shaderir.IntF:
if err := checkArgsForIntBuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false
}
t = shaderir.Type{Main: shaderir.Int}
finalType = shaderir.Type{Main: shaderir.Int}
case shaderir.FloatF:
if err := checkArgsForFloatBuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false
}
t = shaderir.Type{Main: shaderir.Float}
finalType = shaderir.Type{Main: shaderir.Float}
case shaderir.Vec2F:
if err := checkArgsForVec2BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
@ -308,7 +308,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
args[i].Const = gconstant.ToFloat(args[i].Const)
argts[i] = shaderir.Type{Main: shaderir.Float}
}
t = shaderir.Type{Main: shaderir.Vec2}
finalType = shaderir.Type{Main: shaderir.Vec2}
case shaderir.Vec3F:
if err := checkArgsForVec3BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
@ -321,7 +321,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
args[i].Const = gconstant.ToFloat(args[i].Const)
argts[i] = shaderir.Type{Main: shaderir.Float}
}
t = shaderir.Type{Main: shaderir.Vec3}
finalType = shaderir.Type{Main: shaderir.Vec3}
case shaderir.Vec4F:
if err := checkArgsForVec4BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
@ -334,25 +334,25 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
args[i].Const = gconstant.ToFloat(args[i].Const)
argts[i] = shaderir.Type{Main: shaderir.Float}
}
t = shaderir.Type{Main: shaderir.Vec4}
finalType = shaderir.Type{Main: shaderir.Vec4}
case shaderir.IVec2F:
if err := checkArgsForIVec2BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false
}
t = shaderir.Type{Main: shaderir.IVec2}
finalType = shaderir.Type{Main: shaderir.IVec2}
case shaderir.IVec3F:
if err := checkArgsForIVec3BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false
}
t = shaderir.Type{Main: shaderir.IVec3}
finalType = shaderir.Type{Main: shaderir.IVec3}
case shaderir.IVec4F:
if err := checkArgsForIVec4BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false
}
t = shaderir.Type{Main: shaderir.IVec4}
finalType = shaderir.Type{Main: shaderir.IVec4}
case shaderir.Mat2F:
if err := checkArgsForMat2BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
@ -365,7 +365,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
args[i].Const = gconstant.ToFloat(args[i].Const)
argts[i] = shaderir.Type{Main: shaderir.Float}
}
t = shaderir.Type{Main: shaderir.Mat2}
finalType = shaderir.Type{Main: shaderir.Mat2}
case shaderir.Mat3F:
if err := checkArgsForMat3BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
@ -378,7 +378,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
args[i].Const = gconstant.ToFloat(args[i].Const)
argts[i] = shaderir.Type{Main: shaderir.Float}
}
t = shaderir.Type{Main: shaderir.Mat3}
finalType = shaderir.Type{Main: shaderir.Mat3}
case shaderir.Mat4F:
if err := checkArgsForMat4BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
@ -391,7 +391,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
args[i].Const = gconstant.ToFloat(args[i].Const)
argts[i] = shaderir.Type{Main: shaderir.Float}
}
t = shaderir.Type{Main: shaderir.Mat4}
finalType = shaderir.Type{Main: shaderir.Mat4}
case shaderir.TexelAt:
if len(args) != 2 {
cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 2 but %d", callee.BuiltinFunc, len(args)))
@ -405,7 +405,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as vec2 value in argument to %s", argts[1].String(), callee.BuiltinFunc))
return nil, nil, nil, false
}
t = shaderir.Type{Main: shaderir.Vec4}
finalType = shaderir.Type{Main: shaderir.Vec4}
case shaderir.DiscardF:
if len(args) != 0 {
cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 0 but %d", callee.BuiltinFunc, len(args)))
@ -428,13 +428,17 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
}
switch callee.BuiltinFunc {
case shaderir.Clamp:
if kind, allConsts := resolveConstKind(args, argts); allConsts {
if kind, _ := resolveConstKind(args, argts); kind != gconstant.Unknown {
switch kind {
case gconstant.Unknown:
cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s, %s, and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String(), argts[2].String()))
return nil, nil, nil, false
case gconstant.Int:
for i, arg := range args {
if arg.Const == nil {
if argts[i].Main != shaderir.Int {
cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s, %s, and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String(), argts[2].String()))
return nil, nil, nil, false
}
continue
}
v := gconstant.ToInt(arg.Const)
if v.Kind() == gconstant.Unknown {
cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type int", arg.Const.String()))
@ -445,6 +449,13 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
}
case gconstant.Float:
for i, arg := range args {
if arg.Const == nil {
if argts[i].Main != shaderir.Float {
cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s, %s, and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String(), argts[2].String()))
return nil, nil, nil, false
}
continue
}
v := gconstant.ToFloat(arg.Const)
if v.Kind() == gconstant.Unknown {
cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type float", arg.Const.String()))
@ -544,9 +555,9 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
switch callee.BuiltinFunc {
case shaderir.Smoothstep:
t = argts[2]
finalType = argts[2]
default:
t = argts[0]
finalType = argts[0]
}
case shaderir.Atan2, shaderir.Pow, shaderir.Mod, shaderir.Min, shaderir.Max, shaderir.Step, shaderir.Distance, shaderir.Dot, shaderir.Cross, shaderir.Reflect:
@ -558,13 +569,17 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
switch callee.BuiltinFunc {
case shaderir.Min, shaderir.Max:
if kind, allConsts := resolveConstKind(args, argts); allConsts {
if kind, _ := resolveConstKind(args, argts); kind != gconstant.Unknown {
switch kind {
case gconstant.Unknown:
cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String()))
return nil, nil, nil, false
case gconstant.Int:
for i, arg := range args {
if arg.Const == nil {
if argts[i].Main != shaderir.Int {
cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String()))
return nil, nil, nil, false
}
continue
}
v := gconstant.ToInt(arg.Const)
if v.Kind() == gconstant.Unknown {
cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type int", arg.Const.String()))
@ -575,6 +590,13 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
}
case gconstant.Float:
for i, arg := range args {
if arg.Const == nil {
if argts[i].Main != shaderir.Float {
cs.addError(e.Pos(), fmt.Sprintf("%s's arguments don't match: %s and %s", callee.BuiltinFunc, argts[0].String(), argts[1].String()))
return nil, nil, nil, false
}
continue
}
v := gconstant.ToFloat(arg.Const)
if v.Kind() == gconstant.Unknown {
cs.addError(e.Pos(), fmt.Sprintf("cannot convert %s to type float", arg.Const.String()))
@ -662,11 +684,11 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
}
switch callee.BuiltinFunc {
case shaderir.Distance, shaderir.Dot:
t = shaderir.Type{Main: shaderir.Float}
finalType = shaderir.Type{Main: shaderir.Float}
case shaderir.Step:
t = argts[1]
finalType = argts[1]
default:
t = argts[0]
finalType = argts[0]
}
default:
@ -711,9 +733,9 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
}
}
if callee.BuiltinFunc == shaderir.Length {
t = shaderir.Type{Main: shaderir.Float}
finalType = shaderir.Type{Main: shaderir.Float}
} else {
t = argts[0]
finalType = argts[0]
}
}
return []shaderir.Expr{
@ -721,7 +743,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
Type: shaderir.Call,
Exprs: append([]shaderir.Expr{callee}, args...),
},
}, []shaderir.Type{t}, stmts, true
}, []shaderir.Type{finalType}, stmts, true
}
if callee.Type != shaderir.FunctionExpr {
@ -1162,8 +1184,24 @@ func resolveConstKind(exprs []shaderir.Expr, ts []shaderir.Type) (kind gconstant
panic("not reached")
}
allConsts = true
for _, expr := range exprs {
if expr.Const == nil {
allConsts = false
}
}
if !allConsts {
for _, t := range ts {
if t.Main == shaderir.None {
continue
}
if t.Main == shaderir.Float {
return gconstant.Float, false
}
if t.Main == shaderir.Int {
return gconstant.Int, false
}
return gconstant.Unknown, false
}
}
@ -1192,19 +1230,17 @@ func resolveConstKind(exprs []shaderir.Expr, ts []shaderir.Type) (kind gconstant
}
}
if kind == gconstant.Float {
return gconstant.Float, true
if kind != gconstant.Unknown {
return kind, true
}
// Prefer floats over integers for non-typed constant values.
// For example, max(1.0, 1) should return a float value.
if kind == gconstant.Unknown {
for _, expr := range exprs {
if expr.Const.Kind() == gconstant.Float {
return gconstant.Float, true
}
}
}
return gconstant.Int, true
}

View File

@ -2347,6 +2347,10 @@ func TestSyntaxBuiltinFuncArgsMinMax(t *testing.T) {
{stmt: "a := {{.Func}}(int(1), int(1)); var _ int = a", err: false},
{stmt: "a := {{.Func}}(int(1), float(1)); _ = a", err: true},
{stmt: "a := {{.Func}}(float(1), int(1)); _ = a", err: true},
{stmt: "x := 1.1; a := {{.Func}}(int(x), 1); _ = a", err: false},
{stmt: "x := 1; a := {{.Func}}(float(x), 1.1); _ = a", err: false},
{stmt: "x := 1.1; a := {{.Func}}(1, int(x)); _ = a", err: false},
{stmt: "x := 1; a := {{.Func}}(1.1, float(x)); _ = a", err: false},
{stmt: "a := {{.Func}}(1, vec2(1)); _ = a", err: true},
{stmt: "a := {{.Func}}(1, vec3(1)); _ = a", err: true},
{stmt: "a := {{.Func}}(1, vec4(1)); _ = a", err: true},
@ -2547,6 +2551,12 @@ func TestSyntaxBuiltinFuncClampType(t *testing.T) {
{stmt: "a := clamp(float(1), 1, 1.0); var _ float = a", err: false},
{stmt: "a := clamp(float(1), 1.1, 1); _ = a", err: false},
{stmt: "a := clamp(float(1), 1, 1.1); _ = a", err: false},
{stmt: "x := 1.1; a := clamp(int(x), 1, 1); _ = a", err: false},
{stmt: "x := 1; a := clamp(float(x), 1.1, 1.1); _ = a", err: false},
{stmt: "x := 1.1; a := clamp(1, int(x), 1); _ = a", err: false},
{stmt: "x := 1; a := clamp(1.1, float(x), 1.1); _ = a", err: false},
{stmt: "x := 1.1; a := clamp(1, 1, int(x)); _ = a", err: false},
{stmt: "x := 1; a := clamp(1.1, 1.1, float(x)); _ = a", err: false},
{stmt: "a := clamp(1.0, 1, 1); var _ float = a", err: false},
{stmt: "a := clamp(1, 1.0, 1); var _ float = a", err: false},
{stmt: "a := clamp(1, 1, 1.0); var _ float = a", err: false},