mirror of
https://github.com/hajimehoshi/ebiten.git
synced 2025-02-23 00:10:11 +01:00
add basic checks
This commit is contained in:
parent
66a4b20bda
commit
f44640778d
124
internal/shader/delayed.go
Normal file
124
internal/shader/delayed.go
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
// Copyright 2024 The Ebiten Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package shader
|
||||||
|
|
||||||
|
import (
|
||||||
|
"go/ast"
|
||||||
|
gconstant "go/constant"
|
||||||
|
"go/token"
|
||||||
|
|
||||||
|
"github.com/hajimehoshi/ebiten/v2/internal/shaderir"
|
||||||
|
)
|
||||||
|
|
||||||
|
type resolveTypeStatus int
|
||||||
|
|
||||||
|
const (
|
||||||
|
resolveUnsure resolveTypeStatus = iota
|
||||||
|
resolveOk
|
||||||
|
resolveFail
|
||||||
|
)
|
||||||
|
|
||||||
|
type delayedValidator interface {
|
||||||
|
Validate(expr ast.Expr) resolveTypeStatus
|
||||||
|
Pos() token.Pos
|
||||||
|
Error() string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *compileState) tryValidateDelayed(cexpr ast.Expr) (ok bool) {
|
||||||
|
valExprs := make([]ast.Expr, 0, len(cs.delayedTypeCheks))
|
||||||
|
for k := range cs.delayedTypeCheks {
|
||||||
|
valExprs = append(valExprs, k)
|
||||||
|
}
|
||||||
|
for _, expr := range valExprs {
|
||||||
|
if cexpr == expr {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Check if delayed validation can be done by adding current context
|
||||||
|
cres := cs.delayedTypeCheks[expr].Validate(cexpr)
|
||||||
|
switch cres {
|
||||||
|
case resolveFail:
|
||||||
|
cs.addError(cs.delayedTypeCheks[expr].Pos(), cs.delayedTypeCheks[expr].Error())
|
||||||
|
return false
|
||||||
|
case resolveOk:
|
||||||
|
delete(cs.delayedTypeCheks, expr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type delayedShiftValidator struct {
|
||||||
|
value gconstant.Value
|
||||||
|
pos token.Pos
|
||||||
|
last ast.Expr
|
||||||
|
}
|
||||||
|
|
||||||
|
func isArgDefaultTypeInt(f shaderir.BuiltinFunc) bool {
|
||||||
|
return f == shaderir.IntF || f == shaderir.IVec2F || f == shaderir.IVec3F || f == shaderir.IVec4F
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *delayedShiftValidator) Validate(cexpr ast.Expr) (rs resolveTypeStatus) {
|
||||||
|
switch cexpr.(type) {
|
||||||
|
case *ast.Ident:
|
||||||
|
ident := cexpr.(*ast.Ident)
|
||||||
|
// For BuiltinFunc, only int* are allowed
|
||||||
|
if fname, ok := shaderir.ParseBuiltinFunc(ident.Name); ok {
|
||||||
|
if isArgDefaultTypeInt(fname) {
|
||||||
|
return resolveOk
|
||||||
|
}
|
||||||
|
return resolveFail
|
||||||
|
}
|
||||||
|
// Untyped constant must represent int
|
||||||
|
if ident.Name == "_" {
|
||||||
|
if d.value != nil && d.value.Kind() == gconstant.Int {
|
||||||
|
return resolveOk
|
||||||
|
}
|
||||||
|
return resolveFail
|
||||||
|
}
|
||||||
|
if ident.Obj != nil {
|
||||||
|
if t, ok := ident.Obj.Type.(*ast.Ident); ok {
|
||||||
|
return d.Validate(t)
|
||||||
|
}
|
||||||
|
if decl, ok := ident.Obj.Decl.(*ast.ValueSpec); ok {
|
||||||
|
return d.Validate(decl.Type)
|
||||||
|
}
|
||||||
|
if _, ok := ident.Obj.Decl.(*ast.AssignStmt); ok {
|
||||||
|
if d.value != nil && d.value.Kind() == gconstant.Int {
|
||||||
|
return resolveOk
|
||||||
|
}
|
||||||
|
return resolveFail
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case *ast.BinaryExpr:
|
||||||
|
bs := cexpr.(*ast.BinaryExpr)
|
||||||
|
left, right := bs.X, bs.Y
|
||||||
|
if bs.Y == d.last {
|
||||||
|
left, right = right, left
|
||||||
|
}
|
||||||
|
|
||||||
|
rightCheck := d.Validate(right)
|
||||||
|
d.last = cexpr
|
||||||
|
return rightCheck
|
||||||
|
}
|
||||||
|
return resolveUnsure
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d delayedShiftValidator) Pos() token.Pos {
|
||||||
|
return d.pos
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d delayedShiftValidator) Error() string {
|
||||||
|
return "left shift operand should be int"
|
||||||
|
}
|
@ -36,7 +36,12 @@ func canTruncateToFloat(v gconstant.Value) bool {
|
|||||||
|
|
||||||
var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`)
|
var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`)
|
||||||
|
|
||||||
func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt, bool) {
|
func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) (rexpr []shaderir.Expr, rtype []shaderir.Type, rstmt []shaderir.Stmt, ok bool) {
|
||||||
|
defer func() {
|
||||||
|
// Due to use of early return in the parsing, delayed checks are conducted in defer
|
||||||
|
ok = ok && cs.tryValidateDelayed(expr)
|
||||||
|
}()
|
||||||
|
|
||||||
switch e := expr.(type) {
|
switch e := expr.(type) {
|
||||||
case *ast.BasicLit:
|
case *ast.BasicLit:
|
||||||
switch e.Kind {
|
switch e.Kind {
|
||||||
@ -103,6 +108,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
|
|||||||
// Resolve untyped constants.
|
// Resolve untyped constants.
|
||||||
var l gconstant.Value
|
var l gconstant.Value
|
||||||
var r gconstant.Value
|
var r gconstant.Value
|
||||||
|
origLvalue := lhs[0].Const
|
||||||
if op2 == shaderir.LeftShift || op2 == shaderir.RightShift {
|
if op2 == shaderir.LeftShift || op2 == shaderir.RightShift {
|
||||||
l, r, ok = shaderir.ResolveUntypedConstsForBitShiftOp(lhs[0].Const, rhs[0].Const, lhst, rhst)
|
l, r, ok = shaderir.ResolveUntypedConstsForBitShiftOp(lhs[0].Const, rhs[0].Const, lhst, rhst)
|
||||||
} else {
|
} else {
|
||||||
@ -126,6 +132,9 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
|
|||||||
if lhst.Main == shaderir.None && lhs[0].Const != nil {
|
if lhst.Main == shaderir.None && lhs[0].Const != nil {
|
||||||
lhst = shaderir.Type{Main: shaderir.Int}
|
lhst = shaderir.Type{Main: shaderir.Int}
|
||||||
// Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone.
|
// Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone.
|
||||||
|
if rhs[0].Const == nil {
|
||||||
|
cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue, pos: e.Pos(), last: expr})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -61,6 +61,8 @@ type compileState struct {
|
|||||||
|
|
||||||
varyingParsed bool
|
varyingParsed bool
|
||||||
|
|
||||||
|
delayedTypeCheks map[ast.Expr]delayedValidator
|
||||||
|
|
||||||
errs []string
|
errs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -82,6 +84,13 @@ func (cs *compileState) findUniformVariable(name string) (int, bool) {
|
|||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedValidator) {
|
||||||
|
if cs.delayedTypeCheks == nil {
|
||||||
|
cs.delayedTypeCheks = make(map[ast.Expr]delayedValidator, 1)
|
||||||
|
}
|
||||||
|
cs.delayedTypeCheks[at] = check
|
||||||
|
}
|
||||||
|
|
||||||
type typ struct {
|
type typ struct {
|
||||||
name string
|
name string
|
||||||
ir shaderir.Type
|
ir shaderir.Type
|
||||||
@ -350,6 +359,12 @@ func (cs *compileState) parse(f *ast.File) {
|
|||||||
for _, f := range cs.funcs {
|
for _, f := range cs.funcs {
|
||||||
cs.ir.Funcs = append(cs.ir.Funcs, f.ir)
|
cs.ir.Funcs = append(cs.ir.Funcs, f.ir)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if len(cs.delayedTypeCheks) != 0 {
|
||||||
|
// for _, check := range cs.delayedTypeCheks {
|
||||||
|
// cs.addError(check.Pos(), check.Error())
|
||||||
|
// }
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *compileState) parseDecl(b *block, fname string, d ast.Decl) ([]shaderir.Stmt, bool) {
|
func (cs *compileState) parseDecl(b *block, fname string, d ast.Decl) ([]shaderir.Stmt, bool) {
|
||||||
|
@ -1320,23 +1320,27 @@ func TestSyntaxOperatorShift(t *testing.T) {
|
|||||||
stmt string
|
stmt string
|
||||||
err bool
|
err bool
|
||||||
}{
|
}{
|
||||||
// {stmt: "s := 1; var a float = float(1 << s); _ = a", err: true},
|
{stmt: "s := 1; t := 2; _ = t + 1.0 << s", err: false},
|
||||||
// {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true},
|
{stmt: "s := 1; _ = 1 << s", err: false},
|
||||||
// {stmt: "s := 1; var a int = int(1 << s); _ = a", err: false},
|
{stmt: "s := 1; _ = 1.0 << s", err: true},
|
||||||
// {stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false},
|
{stmt: "var a = 1; b := a << 2.0; _ = b", err: false},
|
||||||
// {stmt: "s := 1; a := 1 << s; _ = a", err: false},
|
{stmt: "s := 1; var a float; a = 1 << s; _ = a", err: true},
|
||||||
// {stmt: "s := 1; a := 1.0 << s; _ = a", err: true},
|
{stmt: "s := 1; var a float = 1 << s; _ = a", err: true},
|
||||||
// {stmt: "s := 1; a := int(1.0 << s); _ = a", err: false},
|
{stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true},
|
||||||
// {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true},
|
{stmt: "s := 1; var a int = int(1 << s); _ = a", err: false},
|
||||||
// {stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true},
|
{stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false},
|
||||||
// {stmt: "s := 1; var a float = 1 << s; _ = a", err: true},
|
{stmt: "s := 1; a := 1 << s; _ = a", err: false},
|
||||||
// {stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false},
|
{stmt: "s := 1; a := 1.0 << s; _ = a", err: true},
|
||||||
// {stmt: "s := 1; var a int = 1 << s; _ = a", err: false},
|
{stmt: "s := 1; a := int(1.0 << s); _ = a", err: false},
|
||||||
|
{stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true},
|
||||||
|
{stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true},
|
||||||
|
{stmt: "s := 1; var a float = 1 << s; _ = a", err: true},
|
||||||
|
{stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false},
|
||||||
|
{stmt: "s := 1; var a int = 1 << s; _ = a", err: false},
|
||||||
|
|
||||||
{stmt: "var a float = 1.0 << 2.0; _ = a", err: false},
|
{stmt: "var a float = 1.0 << 2.0; _ = a", err: false},
|
||||||
{stmt: "var a int = 1.0 << 2; _ = a", err: false},
|
{stmt: "var a int = 1.0 << 2; _ = a", err: false},
|
||||||
{stmt: "var a float = 1.0 << 2; _ = a", err: false},
|
{stmt: "var a float = 1.0 << 2; _ = a", err: false},
|
||||||
{stmt: "var a = 1.0 << 2; _ = a", err: false},
|
|
||||||
{stmt: "a := 1 << 2.0; _ = a", err: false},
|
{stmt: "a := 1 << 2.0; _ = a", err: false},
|
||||||
{stmt: "a := 1.0 << 2; _ = a", err: false},
|
{stmt: "a := 1.0 << 2; _ = a", err: false},
|
||||||
{stmt: "a := 1.0 << 2.0; _ = a", err: false},
|
{stmt: "a := 1.0 << 2.0; _ = a", err: false},
|
||||||
@ -1362,36 +1366,6 @@ func TestSyntaxOperatorShift(t *testing.T) {
|
|||||||
{stmt: "a := ivec2(1) << vec2(2); _ = a", err: true},
|
{stmt: "a := ivec2(1) << vec2(2); _ = a", err: true},
|
||||||
{stmt: "a := vec3(1) << ivec2(2); _ = a", err: true},
|
{stmt: "a := vec3(1) << ivec2(2); _ = a", err: true},
|
||||||
{stmt: "a := ivec2(1) << vec3(2); _ = a", err: true},
|
{stmt: "a := ivec2(1) << vec3(2); _ = a", err: true},
|
||||||
|
|
||||||
{stmt: "var a float = 1.0 >> 2.0; _ = a", err: false},
|
|
||||||
{stmt: "var a int = 1.0 >> 2; _ = a", err: false},
|
|
||||||
{stmt: "var a float = 1.0 >> 2; _ = a", err: false},
|
|
||||||
{stmt: "var a = 1.0 >> 2; _ = a", err: false},
|
|
||||||
{stmt: "a := 1 >> 2.0; _ = a", err: false},
|
|
||||||
{stmt: "a := 1.0 >> 2; _ = a", err: false},
|
|
||||||
{stmt: "a := 1.0 >> 2.0; _ = a", err: false},
|
|
||||||
{stmt: "a := 1 >> 2; _ = a", err: false},
|
|
||||||
{stmt: "a := float(1.0) >> 2; _ = a", err: true},
|
|
||||||
{stmt: "a := 1 >> float(2.0); _ = a", err: false},
|
|
||||||
{stmt: "a := 1.0 >> float(2.0); _ = a", err: false},
|
|
||||||
{stmt: "a := ivec2(1) >> 2; _ = a", err: false},
|
|
||||||
{stmt: "a := 1 >> ivec2(2); _ = a", err: true},
|
|
||||||
{stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true},
|
|
||||||
{stmt: "a := float(1.0) >> ivec2(2); _ = a", err: true},
|
|
||||||
{stmt: "a := ivec2(1) >> ivec2(2); _ = a", err: false},
|
|
||||||
{stmt: "a := ivec3(1) >> ivec2(2); _ = a", err: true},
|
|
||||||
{stmt: "a := ivec2(1) >> ivec3(2); _ = a", err: true},
|
|
||||||
{stmt: "a := 1 >> vec2(2); _ = a", err: true},
|
|
||||||
{stmt: "a := vec2(1) >> 2; _ = a", err: true},
|
|
||||||
{stmt: "a := float(1.0) >> vec2(2); _ = a", err: true},
|
|
||||||
{stmt: "a := vec2(1) >> float(2.0); _ = a", err: true},
|
|
||||||
{stmt: "a := vec2(1) >> vec2(2); _ = a", err: true},
|
|
||||||
{stmt: "a := vec2(1) >> vec3(2); _ = a", err: true},
|
|
||||||
{stmt: "a := vec3(1) >> vec2(2); _ = a", err: true},
|
|
||||||
{stmt: "a := vec2(1) >> ivec2(2); _ = a", err: true},
|
|
||||||
{stmt: "a := ivec2(1) >> vec2(2); _ = a", err: true},
|
|
||||||
{stmt: "a := vec3(1) >> ivec2(2); _ = a", err: true},
|
|
||||||
{stmt: "a := ivec2(1) >> vec3(2); _ = a", err: true},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, c := range cases {
|
for _, c := range cases {
|
||||||
@ -1407,6 +1381,80 @@ func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
|
|||||||
t.Errorf("%s must not return nil but returned %v", c.stmt, err)
|
t.Errorf("%s must not return nil but returned %v", c.stmt, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
casesFunc := []struct {
|
||||||
|
prog string
|
||||||
|
err bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
prog: `package main
|
||||||
|
func Foo(x int) {
|
||||||
|
_ = x
|
||||||
|
}
|
||||||
|
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
|
||||||
|
s := 1
|
||||||
|
Foo(1 << s)
|
||||||
|
return dstPos
|
||||||
|
}`,
|
||||||
|
err: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
prog: `package main
|
||||||
|
func Foo(x int) {
|
||||||
|
_ = x
|
||||||
|
}
|
||||||
|
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
|
||||||
|
s := 1
|
||||||
|
Foo(1.0 << s)
|
||||||
|
return dstPos
|
||||||
|
}`,
|
||||||
|
err: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
prog: `package main
|
||||||
|
func Foo(x float) {
|
||||||
|
_ = x
|
||||||
|
}
|
||||||
|
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
|
||||||
|
s := 1
|
||||||
|
Foo(1 << s)
|
||||||
|
return dstPos
|
||||||
|
}`,
|
||||||
|
err: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
prog: `package main
|
||||||
|
func Foo(x float) {
|
||||||
|
_ = x
|
||||||
|
}
|
||||||
|
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
|
||||||
|
s := 1
|
||||||
|
Foo(1 << s)
|
||||||
|
return dstPos
|
||||||
|
}`,
|
||||||
|
err: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
prog: `package main
|
||||||
|
func Foo(x float) {
|
||||||
|
_ = x
|
||||||
|
}
|
||||||
|
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
|
||||||
|
Foo(1.0 << 2.0)
|
||||||
|
return dstPos
|
||||||
|
}`,
|
||||||
|
err: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range casesFunc {
|
||||||
|
_, err := compileToIR([]byte(c.prog))
|
||||||
|
if err == nil && c.err {
|
||||||
|
t.Errorf("%s must return an error but does not", c.prog)
|
||||||
|
} else if err != nil && !c.err {
|
||||||
|
t.Errorf("%s must not return nil but returned %v", c.prog, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSyntaxOperatorShiftAssign(t *testing.T) {
|
func TestSyntaxOperatorShiftAssign(t *testing.T) {
|
||||||
|
Loading…
Reference in New Issue
Block a user