shader: Forbid unexported global variables

This commit is contained in:
Hajime Hoshi 2020-05-10 21:57:12 +09:00
parent a39376ad47
commit 02eafb2929
2 changed files with 29 additions and 41 deletions

View File

@ -106,18 +106,7 @@ func (s *Shader) addError(pos token.Pos, str string) {
func (sh *Shader) parse(f *ast.File) {
for _, d := range f.Decls {
sh.parseDecl(&sh.global, d)
}
vars := make([]variable, len(sh.global.vars))
copy(vars, sh.global.vars)
sh.global.vars = nil
for _, v := range vars {
if 'A' <= v.name[0] && v.name[0] <= 'Z' {
sh.uniforms = append(sh.uniforms, v)
} else {
sh.global.vars = append(sh.global.vars, v)
}
sh.parseDecl(&sh.global, d, true)
}
// TODO: This is duplicated with parseBlock.
@ -135,7 +124,7 @@ func (sh *Shader) parse(f *ast.File) {
})
}
func (sh *Shader) parseDecl(b *block, d ast.Decl) {
func (sh *Shader) parseDecl(b *block, d ast.Decl, global bool) {
switch d := d.(type) {
case *ast.GenDecl:
switch d.Tok {
@ -157,7 +146,17 @@ func (sh *Shader) parseDecl(b *block, d ast.Decl) {
for _, s := range d.Specs {
s := s.(*ast.ValueSpec)
vs := sh.parseVariable(b, s)
b.vars = append(b.vars, vs...)
if !global {
b.vars = append(b.vars, vs...)
continue
}
for i, v := range vs {
if 'A' <= v.name[0] && v.name[0] <= 'Z' {
sh.uniforms = append(sh.uniforms, v)
} else {
sh.addError(s.Names[i].Pos(), fmt.Sprintf("global variables must be exposed: %s", v.name))
}
}
}
case token.IMPORT:
sh.addError(d.Pos(), "import is forbidden")
@ -374,7 +373,7 @@ func (sh *Shader) parseBlock(outer *block, b *ast.BlockStmt) *block {
block: sh.parseBlock(block, l),
})
case *ast.DeclStmt:
sh.parseDecl(block, l.Decl)
sh.parseDecl(block, l.Decl, false)
case *ast.ReturnStmt:
var exprs []ast.Expr
for _, r := range l.Results {
@ -418,6 +417,13 @@ func (s *Shader) detectType(b *block, expr ast.Expr) typ {
return v.typ
}
}
if b == &s.global {
for _, v := range s.uniforms {
if v.name == n {
return v.typ
}
}
}
if b.outer != nil {
return s.detectType(b.outer, e)
}

View File

@ -40,7 +40,6 @@ var Foo float
var (
Bar vec2
Baz, Quux vec3
qux vec4
)
const C1 float = 1
@ -61,7 +60,6 @@ var Bar uniform vec2
var Baz uniform vec3
var Foo uniform float
var Quux uniform vec3
var qux vec4
const C1 float = 1
const C2 float = 2
const C3 float = 3
@ -80,31 +78,15 @@ func F1(a vec2, b vec2) (_ vec4) {
Name: "AutoType",
Src: `package main
var v0 = 0.0
var V0 = 0.0
func F() {
v1 := v0
v1 := V0
}
`,
Dump: `var v0 float = 0.0
Dump: `var V0 uniform float
func F() {
var v1 float
v1 = v0
}
`,
},
{
Name: "AutoType",
Src: `package main
var v0 = 0.0
func F() {
v1 := v0
}
`,
Dump: `var v0 float = 0.0
func F() {
var v1 float
v1 = v0
v1 = V0
}
`,
},
@ -112,18 +94,18 @@ func F() {
Name: "AutoType2",
Src: `package main
var v0 = 0.0
var V0 = 0.0
func F() {
v1 := v0
v1 := V0
{
v2 := v1
}
}
`,
Dump: `var v0 float = 0.0
Dump: `var V0 uniform float
func F() {
var v1 float
v1 = v0
v1 = V0
{
var v2 float
v2 = v1