diff --git a/colorm.go b/colorm.go index ca6e6c1b6..03af81136 100644 --- a/colorm.go +++ b/colorm.go @@ -104,11 +104,7 @@ func (c *ColorM) ChangeHSV(hueTheta float64, saturationScale float64, valueScale // Element returns a value of a matrix at (i, j). func (c *ColorM) Element(i, j int) float64 { - b, t := c.impl.UnsafeElements() - if j < ColorMDim-1 { - return float64(b[i+j*(ColorMDim-1)]) - } - return float64(t[i]) + return float64(c.impl.Element(i, j)) } // SetElement sets an element at (i, j). diff --git a/internal/affine/colorm.go b/internal/affine/colorm.go index b1e2405c4..b112883a4 100644 --- a/internal/affine/colorm.go +++ b/internal/affine/colorm.go @@ -319,6 +319,15 @@ func (c *ColorM) Invert() *ColorM { return m } +// Element returns a value of a matrix at (i, j). +func (c *ColorM) Element(i, j int) float32 { + b, t := c.UnsafeElements() + if j < ColorMDim-1 { + return b[i+j*(ColorMDim-1)] + } + return t[i] +} + // SetElement sets an element at (i, j). func (c *ColorM) SetElement(i, j int, element float32) *ColorM { newC := &ColorM{ diff --git a/internal/affine/colorm_test.go b/internal/affine/colorm_test.go index 3ea8b3afc..4e254acc8 100644 --- a/internal/affine/colorm_test.go +++ b/internal/affine/colorm_test.go @@ -136,6 +136,26 @@ func arrayToColorM(es [4][5]float32) *ColorM { return a } +func abs(x float32) float32 { + if x < 0 { + return -x + } + return x +} + +func equalWithDelta(a, b *ColorM, delta float32) bool { + for j := 0; j < 5; j++ { + for i := 0; i < 4; i++ { + ea := a.Element(i, j) + eb := b.Element(i, j) + if abs(ea-eb) > delta { + return false + } + } + } + return true +} + func TestColorMInvert(t *testing.T) { cases := []struct { In *ColorM @@ -159,10 +179,24 @@ func TestColorMInvert(t *testing.T) { {7, -4, -2, 1, 0}, }), }, + { + In: arrayToColorM([4][5]float32{ + {1, 2, 3, 4, 5}, + {5, 1, 2, 3, 4}, + {4, 5, 1, 2, 3}, + {3, 4, 5, 1, 2}, + }), + Out: arrayToColorM([4][5]float32{ + {-6 / 35.0, 3 / 14.0, 1 / 70.0, 1 / 70.0, -1 / 14.0}, + {1 / 35.0, -13 / 70.0, 3 / 14.0, 1 / 70.0, -1 / 14.0}, + {1 / 35.0, 1 / 70.0, -13 / 70.0, 3 / 14.0, -1 / 14.0}, + {9 / 35.0, 1 / 35.0, 1 / 35.0, -6 / 35.0, -8 / 7.0}, + }), + }, } for _, c := range cases { - if got, want := c.In.Invert(), c.Out; !got.Equals(want) { + if got, want := c.In.Invert(), c.Out; !equalWithDelta(got, want, 1e-6) { t.Errorf("got: %v, want: %v", got, want) } }