shader: Enable to specify entrypoint names

This commit is contained in:
Hajime Hoshi 2020-06-05 01:11:39 +09:00
parent d0aa18ddb9
commit 3dbf4c0a83
3 changed files with 13 additions and 13 deletions

View File

@ -25,11 +25,6 @@ import (
"github.com/hajimehoshi/ebiten/internal/shaderir"
)
const (
vertexEntry = "Vertex"
fragmentEntry = "Fragment"
)
type variable struct {
name string
typ typ
@ -52,6 +47,9 @@ type function struct {
type compileState struct {
fs *token.FileSet
vertexEntry string
fragmentEntry string
ir shaderir.Program
// uniforms is a collection of uniform variable names.
@ -108,7 +106,7 @@ func (p *ParseError) Error() string {
return strings.Join(p.errs, "\n")
}
func Compile(src []byte) (*shaderir.Program, error) {
func Compile(src []byte, vertexEntry, fragmentEntry string) (*shaderir.Program, error) {
fs := token.NewFileSet()
f, err := parser.ParseFile(fs, "", src, parser.AllErrors)
if err != nil {
@ -116,7 +114,9 @@ func Compile(src []byte) (*shaderir.Program, error) {
}
s := &compileState{
fs: fs,
fs: fs,
vertexEntry: vertexEntry,
fragmentEntry: fragmentEntry,
}
s.parse(f)
@ -205,9 +205,9 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) {
f := cs.parseFunc(b, d)
if b == &cs.global {
switch d.Name.Name {
case vertexEntry:
case cs.vertexEntry:
cs.ir.VertexFunc.Block = f.ir.Block
case fragmentEntry:
case cs.fragmentEntry:
cs.ir.FragmentFunc.Block = f.ir.Block
default:
b.funcs = append(b.funcs, f)
@ -328,7 +328,7 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function {
if block == &cs.global {
switch d.Name.Name {
case vertexEntry:
case cs.vertexEntry:
for _, t := range inT {
cs.ir.Attributes = append(cs.ir.Attributes, t)
}
@ -352,7 +352,7 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function {
}
}
cs.varyingParsed = true
case fragmentEntry:
case cs.fragmentEntry:
if len(inParams) == 0 {
cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have at least one vec4 parameter for a position"))
return function{}

View File

@ -211,7 +211,7 @@ void main(void) {
}
for _, tc := range tests {
t.Run(tc.Name, func(t *testing.T) {
s, err := Compile([]byte(tc.Src))
s, err := Compile([]byte(tc.Src), "Vertex", "Fragment")
if err != nil {
t.Error(err)
return

View File

@ -24,7 +24,7 @@ type Shader struct {
}
func NewShader(src []byte) (*Shader, error) {
s, err := shader.Compile(src)
s, err := shader.Compile(src, "Vertex", "Fragment")
if err != nil {
return nil, err
}