From 0415773b940719905b0d3242d38799a2773d4b6e Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Fri, 21 Jan 2022 01:59:54 +0900 Subject: [PATCH] internal/shader: bug fix: allow the *= operator for a vector and a matrix Updates #1971 --- internal/shader/stmt.go | 7 ++++++ shader_test.go | 54 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 76b18a0be..2199a32c0 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -111,6 +111,13 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: if rts[0].Main == shaderir.Float { // OK + } else if op == shaderir.Mul && ((lts[0].Main == shaderir.Vec2 && rts[0].Main == shaderir.Mat2) || + (lts[0].Main == shaderir.Vec3 && rts[0].Main == shaderir.Mat3) || + (lts[0].Main == shaderir.Vec4 && rts[0].Main == shaderir.Mat4) || + (lts[0].Main == shaderir.Mat2 && rts[0].Main == shaderir.Vec2) || + (lts[0].Main == shaderir.Mat3 && rts[0].Main == shaderir.Vec3) || + (lts[0].Main == shaderir.Mat4 && rts[0].Main == shaderir.Vec4)) { + // OK } else if rhs[0].Const != nil && rhs[0].Const.Kind() == gconstant.Int { rhs[0].Const = gconstant.ToFloat(rhs[0].Const) rhs[0].ConstType = shaderir.ConstTypeFloat diff --git a/shader_test.go b/shader_test.go index 41a889d0b..0bea61f78 100644 --- a/shader_test.go +++ b/shader_test.go @@ -1600,3 +1600,57 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { } } } + +// Issue #1971 +func TestShaderOperatorMultiplyAssign(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := 1.0; a *= 2", err: false}, + {stmt: "a := 1.0; a *= 2.0", err: false}, + {stmt: "a := 1.0; a *= int(2)", err: true}, + {stmt: "a := 1.0; a *= vec2(2)", err: true}, + {stmt: "a := 1.0; a *= vec3(2)", err: true}, + {stmt: "a := 1.0; a *= vec4(2)", err: true}, + {stmt: "a := 1.0; a *= mat2(2)", err: true}, + {stmt: "a := 1.0; a *= mat3(2)", err: true}, + {stmt: "a := 1.0; a *= mat4(2)", err: true}, + {stmt: "a := vec2(1); a *= 2", err: false}, + {stmt: "a := vec2(1); a *= 2.0", err: false}, + {stmt: "a := vec2(1); a *= int(2)", err: true}, + {stmt: "a := vec2(1); a *= vec2(2)", err: false}, + {stmt: "a := vec2(1); a += vec2(2)", err: false}, + {stmt: "a := vec2(1); a *= vec3(2)", err: true}, + {stmt: "a := vec2(1); a *= vec4(2)", err: true}, + {stmt: "a := vec2(1); a *= mat2(2)", err: false}, + {stmt: "a := vec2(1); a += mat2(2)", err: true}, + {stmt: "a := vec2(1); a *= mat3(2)", err: true}, + {stmt: "a := vec2(1); a *= mat4(2)", err: true}, + {stmt: "a := mat2(1); a *= 2", err: false}, + {stmt: "a := mat2(1); a *= 2.0", err: false}, + {stmt: "a := mat2(1); a *= int(2)", err: true}, + {stmt: "a := mat2(1); a *= vec2(2)", err: false}, + {stmt: "a := mat2(1); a += vec2(2)", err: true}, + {stmt: "a := mat2(1); a *= vec3(2)", err: true}, + {stmt: "a := mat2(1); a *= vec4(2)", err: true}, + {stmt: "a := mat2(1); a *= mat2(2)", err: false}, + {stmt: "a := mat2(1); a += mat2(2)", err: false}, + {stmt: "a := mat2(1); a *= mat3(2)", err: true}, + {stmt: "a := mat2(1); a *= mat4(2)", err: true}, + } + + for _, c := range cases { + _, err := ebiten.NewShader([]byte(fmt.Sprintf(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + %s + return position +}`, c.stmt))) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", c.stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", c.stmt, err) + } + } +}