From fdddb0e2658dde0278f83835247d9b4e9e8afa11 Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Tue, 11 Feb 2025 16:27:08 +0900 Subject: [PATCH] audio/internal/convert: bug fix: make Resampling work with non-seekable source Closes #3193 --- audio/internal/convert/resampling.go | 83 ++++++++++------ audio/internal/convert/resampling_test.go | 112 ++++++++++++++-------- 2 files changed, 125 insertions(+), 70 deletions(-) diff --git a/audio/internal/convert/resampling.go b/audio/internal/convert/resampling.go index 967efd69a..c52a6ac03 100644 --- a/audio/internal/convert/resampling.go +++ b/audio/internal/convert/resampling.go @@ -15,6 +15,7 @@ package convert import ( + "fmt" "io" "math" "sync" @@ -78,7 +79,7 @@ func sinc01(x float64) float64 { } type Resampling struct { - source io.ReadSeeker + source io.Reader size int64 from int to int @@ -88,9 +89,11 @@ type Resampling struct { srcBufL map[int64][]float64 srcBufR map[int64][]float64 lruSrcBlocks []int64 + eof bool + eofBufIndex int64 } -func NewResampling(source io.ReadSeeker, size int64, from, to int, bitDepthInBytes int) *Resampling { +func NewResampling(source io.Reader, size int64, from, to int, bitDepthInBytes int) *Resampling { r := &Resampling{ source: source, size: size, @@ -100,6 +103,7 @@ func NewResampling(source io.ReadSeeker, size int64, from, to int, bitDepthInByt srcBlock: -1, srcBufL: map[int64][]float64{}, srcBufR: map[int64][]float64{}, + eofBufIndex: -1, } return r } @@ -121,23 +125,25 @@ func (r *Resampling) src(i int64) (float64, float64, error) { return 0, 0, nil } sizePerSample := int64(r.bytesPerSample()) - if r.size/sizePerSample <= i { - return 0, 0, nil - } nextPos := int64(i) / resamplingBufferSize if _, ok := r.srcBufL[nextPos]; !ok { if r.srcBlock+1 != nextPos { - if _, err := r.source.Seek(nextPos*resamplingBufferSize*sizePerSample, io.SeekStart); err != nil { + seeker, ok := r.source.(io.Seeker) + if !ok { + return 0, 0, fmt.Errorf("convert: source must be io.Seeker") + } + if _, 
err := seeker.Seek(nextPos*resamplingBufferSize*sizePerSample, io.SeekStart); err != nil { return 0, 0, err } } buf := make([]byte, resamplingBufferSize*sizePerSample) - c := 0 + var c int for c < len(buf) { n, err := r.source.Read(buf[c:]) c += n if err != nil { if err == io.EOF { + r.eofBufIndex = nextPos break } return 0, 0, err @@ -188,7 +194,16 @@ func (r *Resampling) src(i int64) (float64, float64, error) { r.lruSrcBlocks = append(r.lruSrcBlocks, r.srcBlock) } ii := i % resamplingBufferSize - return r.srcBufL[r.srcBlock][ii], r.srcBufR[r.srcBlock][ii], nil + var err error + if r.eofBufIndex == r.srcBlock && ii >= int64(len(r.srcBufL[r.srcBlock])-1) { + err = io.EOF + } + if _, ok := r.source.(io.Seeker); ok { + if r.size/sizePerSample <= i { + err = io.EOF + } + } + return r.srcBufL[r.srcBlock][ii], r.srcBufR[r.srcBlock][ii], err } func (r *Resampling) at(t int64) (float64, float64, error) { @@ -198,21 +213,18 @@ func (r *Resampling) at(t int64) (float64, float64, error) { if startN < 0 { startN = 0 } - sizePerSample := int64(r.bytesPerSample()) - if r.size/sizePerSample <= startN { - startN = r.size/sizePerSample - 1 - } endN := int64(tInSrc + windowSize) - if r.size/sizePerSample <= endN { - endN = r.size/sizePerSample - 1 - } lv := 0.0 rv := 0.0 + var eof bool for n := startN; n <= endN; n++ { srcL, srcR, err := r.src(n) - if err != nil { + if err != nil && err != io.EOF { return 0, 0, err } + if err == io.EOF { + eof = true + } d := tInSrc - float64(n) w := 0.5 + 0.5*fastCos01(d/(windowSize*2+1)) s := sinc01(d/2) * w @@ -231,27 +243,31 @@ func (r *Resampling) at(t int64) (float64, float64, error) { if rv > 1 { rv = 1 } + if eof { + return lv, rv, io.EOF + } return lv, rv, nil } func (r *Resampling) Read(b []byte) (int, error) { - if r.pos == r.Length() { + if r.eof { return 0, io.EOF } + size := r.bytesPerSample() n := len(b) / size * size - if r.Length()-r.pos <= int64(n) { - n = int(r.Length() - r.pos) - } switch r.bitDepthInBytes { case 2: for i := 0; 
i < n/size; i++ { - l, r, err := r.at(r.pos/int64(size) + int64(i)) - if err != nil { + ldata, rdata, err := r.at(r.pos/int64(size) + int64(i)) + if err != nil && err != io.EOF { return 0, err } - l16 := int16(l * (1<<15 - 1)) - r16 := int16(r * (1<<15 - 1)) + if err == io.EOF { + r.eof = true + } + l16 := int16(ldata * (1<<15 - 1)) + r16 := int16(rdata * (1<<15 - 1)) b[4*i] = byte(l16) b[4*i+1] = byte(l16 >> 8) b[4*i+2] = byte(r16) @@ -259,12 +275,15 @@ func (r *Resampling) Read(b []byte) (int, error) { } case 4: for i := 0; i < n/size; i++ { - l, r, err := r.at(r.pos/int64(size) + int64(i)) - if err != nil { + ldata, rdata, err := r.at(r.pos/int64(size) + int64(i)) + if err != nil && err != io.EOF { return 0, err } - l32 := float32(l) - r32 := float32(r) + if err == io.EOF { + r.eof = true + } + l32 := float32(ldata) + r32 := float32(rdata) l32b := math.Float32bits(l32) r32b := math.Float32bits(r32) b[8*i] = byte(l32b) @@ -280,10 +299,18 @@ func (r *Resampling) Read(b []byte) (int, error) { panic("not reached") } r.pos += int64(n) + if r.eof { + return n, io.EOF + } return n, nil } func (r *Resampling) Seek(offset int64, whence int) (int64, error) { + if _, ok := r.source.(io.Seeker); !ok { + return 0, fmt.Errorf("convert: source must be io.Seeker") + } + + r.eof = false switch whence { case io.SeekStart: r.pos = offset diff --git a/audio/internal/convert/resampling_test.go b/audio/internal/convert/resampling_test.go index e8bc17c4f..4771a4f97 100644 --- a/audio/internal/convert/resampling_test.go +++ b/audio/internal/convert/resampling_test.go @@ -67,6 +67,14 @@ func newSoundBytes(sampleRate int, bitDepthInBytes int) []byte { return b } +type reader struct { + r io.Reader +} + +func (r *reader) Read(buf []byte) (int, error) { + return r.r.Read(buf) +} + func TestResampling(t *testing.T) { cases := []struct { In int @@ -87,51 +95,71 @@ func TestResampling(t *testing.T) { for _, bitDepthInBytes := range []int{2, 4} { bitDepthInBytes := bitDepthInBytes 
t.Run(fmt.Sprintf("bitDepthInBytes=%d", bitDepthInBytes), func(t *testing.T) { - inB := newSoundBytes(c.In, bitDepthInBytes) - outS := convert.NewResampling(bytes.NewReader(inB), int64(len(inB)), c.In, c.Out, bitDepthInBytes) - var gotB []byte - for { - var buf [97]byte - n, err := outS.Read(buf[:]) - gotB = append(gotB, buf[:n]...) - if err != nil { - if err != io.EOF { - t.Fatal(err) + for _, seek := range []bool{false, true} { + t.Run(fmt.Sprintf("seek=%v", seek), func(t *testing.T) { + inB := newSoundBytes(c.In, bitDepthInBytes) + l := int64(len(inB)) + if !seek { + l = 0 } - break - } - cur, err := outS.Seek(0, io.SeekCurrent) - if err != nil { - t.Fatal(err) - } - // Shifting by incomplete bytes should not affect the result. - for i := 0; i < bitDepthInBytes*2; i++ { - pos, err := outS.Seek(int64(i), io.SeekCurrent) - if err != nil { - t.Fatal(err) + var src io.Reader = bytes.NewReader(inB) + if !seek { + src = &reader{r: src} } - if cur != pos { - t.Errorf("cur: %d, pos: %d", cur, pos) + outS := convert.NewResampling(src, l, c.In, c.Out, bitDepthInBytes) + var gotB []byte + for { + var buf [97]byte + n, err := outS.Read(buf[:]) + gotB = append(gotB, buf[:n]...) + if err != nil { + if err != io.EOF { + t.Fatal(err) + } + break + } + if seek { + cur, err := outS.Seek(0, io.SeekCurrent) + if err != nil { + t.Fatal(err) + } + // Shifting by incomplete bytes should not affect the result. 
+ for i := 0; i < bitDepthInBytes*2; i++ { + pos, err := outS.Seek(int64(i), io.SeekCurrent) + if err != nil { + t.Fatal(err) + } + if cur != pos { + t.Errorf("cur: %d, pos: %d", cur, pos) + } + } + } } - } - } - wantB := newSoundBytes(c.Out, bitDepthInBytes) - if len(gotB) != len(wantB) { - t.Errorf("len(gotB) == %d but len(wantB) == %d", len(gotB), len(wantB)) - } - for i := 0; i < len(gotB)/bitDepthInBytes; i++ { - var got, want float64 - switch bitDepthInBytes { - case 2: - got = float64(int16(gotB[2*i])|(int16(gotB[2*i+1])<<8)) / (1<<15 - 1) - want = float64(int16(wantB[2*i])|(int16(wantB[2*i+1])<<8)) / (1<<15 - 1) - case 4: - got = float64(math.Float32frombits(uint32(gotB[4*i]) | (uint32(gotB[4*i+1]) << 8) | (uint32(gotB[4*i+2]) << 16) | (uint32(gotB[4*i+3]) << 24))) - want = float64(math.Float32frombits(uint32(wantB[4*i]) | (uint32(wantB[4*i+1]) << 8) | (uint32(wantB[4*i+2]) << 16) | (uint32(wantB[4*i+3]) << 24))) - } - if math.Abs(got-want) > 0.025 { - t.Errorf("sample rate: %d, index: %d: got: %f, want: %f", c.Out, i, got, want) - } + wantB := newSoundBytes(c.Out, bitDepthInBytes) + // 256 is an arbitrary number. + // In most cases, len(gotB) must be >= len(wantB), but there are some numerical errors. 
+ if len(gotB) < len(wantB)-256 { + t.Errorf("len(gotB) >= len(wantB) - 256, but len(gotB) == %d, len(wantB) == %d", len(gotB), len(wantB)) + } + for i := 0; i < len(gotB)/bitDepthInBytes; i++ { + var got, want float64 + switch bitDepthInBytes { + case 2: + got = float64(int16(gotB[2*i])|(int16(gotB[2*i+1])<<8)) / (1<<15 - 1) + if i < len(wantB)/2 { + want = float64(int16(wantB[2*i])|(int16(wantB[2*i+1])<<8)) / (1<<15 - 1) + } + case 4: + got = float64(math.Float32frombits(uint32(gotB[4*i]) | (uint32(gotB[4*i+1]) << 8) | (uint32(gotB[4*i+2]) << 16) | (uint32(gotB[4*i+3]) << 24))) + if i < len(wantB)/4 { + want = float64(math.Float32frombits(uint32(wantB[4*i]) | (uint32(wantB[4*i+1]) << 8) | (uint32(wantB[4*i+2]) << 16) | (uint32(wantB[4*i+3]) << 24))) + } + } + if math.Abs(got-want) > 0.025 { + t.Errorf("sample rate: %d, index: %d: got: %f, want: %f", c.Out, i, got, want) + } + } + }) } }) }