audio/internal/convert: bug fix: make Resampling work with non-seekable source

Closes #3193
This commit is contained in:
Hajime Hoshi 2025-02-11 16:27:08 +09:00
parent 4df5d02c36
commit fdddb0e265
2 changed files with 125 additions and 70 deletions

View File

@ -15,6 +15,7 @@
package convert
import (
"fmt"
"io"
"math"
"sync"
@ -78,7 +79,7 @@ func sinc01(x float64) float64 {
}
type Resampling struct {
source io.ReadSeeker
source io.Reader
size int64
from int
to int
@ -88,9 +89,11 @@ type Resampling struct {
srcBufL map[int64][]float64
srcBufR map[int64][]float64
lruSrcBlocks []int64
eof bool
eofBufIndex int64
}
func NewResampling(source io.ReadSeeker, size int64, from, to int, bitDepthInBytes int) *Resampling {
func NewResampling(source io.Reader, size int64, from, to int, bitDepthInBytes int) *Resampling {
r := &Resampling{
source: source,
size: size,
@ -100,6 +103,7 @@ func NewResampling(source io.ReadSeeker, size int64, from, to int, bitDepthInByt
srcBlock: -1,
srcBufL: map[int64][]float64{},
srcBufR: map[int64][]float64{},
eofBufIndex: -1,
}
return r
}
@ -121,23 +125,25 @@ func (r *Resampling) src(i int64) (float64, float64, error) {
return 0, 0, nil
}
sizePerSample := int64(r.bytesPerSample())
if r.size/sizePerSample <= i {
return 0, 0, nil
}
nextPos := int64(i) / resamplingBufferSize
if _, ok := r.srcBufL[nextPos]; !ok {
if r.srcBlock+1 != nextPos {
if _, err := r.source.Seek(nextPos*resamplingBufferSize*sizePerSample, io.SeekStart); err != nil {
seeker, ok := r.source.(io.Seeker)
if !ok {
return 0, 0, fmt.Errorf("convert: source must be io.Seeker")
}
if _, err := seeker.Seek(nextPos*resamplingBufferSize*sizePerSample, io.SeekStart); err != nil {
return 0, 0, err
}
}
buf := make([]byte, resamplingBufferSize*sizePerSample)
c := 0
var c int
for c < len(buf) {
n, err := r.source.Read(buf[c:])
c += n
if err != nil {
if err == io.EOF {
r.eofBufIndex = nextPos
break
}
return 0, 0, err
@ -188,7 +194,16 @@ func (r *Resampling) src(i int64) (float64, float64, error) {
r.lruSrcBlocks = append(r.lruSrcBlocks, r.srcBlock)
}
ii := i % resamplingBufferSize
return r.srcBufL[r.srcBlock][ii], r.srcBufR[r.srcBlock][ii], nil
var err error
if r.eofBufIndex == r.srcBlock && ii >= int64(len(r.srcBufL[r.srcBlock])-1) {
err = io.EOF
}
if _, ok := r.source.(io.Seeker); ok {
if r.size/sizePerSample <= i {
err = io.EOF
}
}
return r.srcBufL[r.srcBlock][ii], r.srcBufR[r.srcBlock][ii], err
}
func (r *Resampling) at(t int64) (float64, float64, error) {
@ -198,21 +213,18 @@ func (r *Resampling) at(t int64) (float64, float64, error) {
if startN < 0 {
startN = 0
}
sizePerSample := int64(r.bytesPerSample())
if r.size/sizePerSample <= startN {
startN = r.size/sizePerSample - 1
}
endN := int64(tInSrc + windowSize)
if r.size/sizePerSample <= endN {
endN = r.size/sizePerSample - 1
}
lv := 0.0
rv := 0.0
var eof bool
for n := startN; n <= endN; n++ {
srcL, srcR, err := r.src(n)
if err != nil {
if err != nil && err != io.EOF {
return 0, 0, err
}
if err == io.EOF {
eof = true
}
d := tInSrc - float64(n)
w := 0.5 + 0.5*fastCos01(d/(windowSize*2+1))
s := sinc01(d/2) * w
@ -231,27 +243,31 @@ func (r *Resampling) at(t int64) (float64, float64, error) {
if rv > 1 {
rv = 1
}
if eof {
return lv, rv, io.EOF
}
return lv, rv, nil
}
func (r *Resampling) Read(b []byte) (int, error) {
if r.pos == r.Length() {
if r.eof {
return 0, io.EOF
}
size := r.bytesPerSample()
n := len(b) / size * size
if r.Length()-r.pos <= int64(n) {
n = int(r.Length() - r.pos)
}
switch r.bitDepthInBytes {
case 2:
for i := 0; i < n/size; i++ {
l, r, err := r.at(r.pos/int64(size) + int64(i))
if err != nil {
ldata, rdata, err := r.at(r.pos/int64(size) + int64(i))
if err != nil && err != io.EOF {
return 0, err
}
l16 := int16(l * (1<<15 - 1))
r16 := int16(r * (1<<15 - 1))
if err == io.EOF {
r.eof = true
}
l16 := int16(ldata * (1<<15 - 1))
r16 := int16(rdata * (1<<15 - 1))
b[4*i] = byte(l16)
b[4*i+1] = byte(l16 >> 8)
b[4*i+2] = byte(r16)
@ -259,12 +275,15 @@ func (r *Resampling) Read(b []byte) (int, error) {
}
case 4:
for i := 0; i < n/size; i++ {
l, r, err := r.at(r.pos/int64(size) + int64(i))
if err != nil {
ldata, rdata, err := r.at(r.pos/int64(size) + int64(i))
if err != nil && err != io.EOF {
return 0, err
}
l32 := float32(l)
r32 := float32(r)
if err == io.EOF {
r.eof = true
}
l32 := float32(ldata)
r32 := float32(rdata)
l32b := math.Float32bits(l32)
r32b := math.Float32bits(r32)
b[8*i] = byte(l32b)
@ -280,10 +299,18 @@ func (r *Resampling) Read(b []byte) (int, error) {
panic("not reached")
}
r.pos += int64(n)
if r.eof {
return n, io.EOF
}
return n, nil
}
func (r *Resampling) Seek(offset int64, whence int) (int64, error) {
if _, ok := r.source.(io.Seeker); !ok {
return 0, fmt.Errorf("convert: source must be io.Seeker")
}
r.eof = false
switch whence {
case io.SeekStart:
r.pos = offset

View File

@ -67,6 +67,14 @@ func newSoundBytes(sampleRate int, bitDepthInBytes int) []byte {
return b
}
type reader struct {
r io.Reader
}
func (r *reader) Read(buf []byte) (int, error) {
return r.r.Read(buf)
}
func TestResampling(t *testing.T) {
cases := []struct {
In int
@ -87,51 +95,71 @@ func TestResampling(t *testing.T) {
for _, bitDepthInBytes := range []int{2, 4} {
bitDepthInBytes := bitDepthInBytes
t.Run(fmt.Sprintf("bitDepthInBytes=%d", bitDepthInBytes), func(t *testing.T) {
inB := newSoundBytes(c.In, bitDepthInBytes)
outS := convert.NewResampling(bytes.NewReader(inB), int64(len(inB)), c.In, c.Out, bitDepthInBytes)
var gotB []byte
for {
var buf [97]byte
n, err := outS.Read(buf[:])
gotB = append(gotB, buf[:n]...)
if err != nil {
if err != io.EOF {
t.Fatal(err)
for _, seek := range []bool{false, true} {
t.Run(fmt.Sprintf("seek=%v", seek), func(t *testing.T) {
inB := newSoundBytes(c.In, bitDepthInBytes)
l := int64(len(inB))
if !seek {
l = 0
}
break
}
cur, err := outS.Seek(0, io.SeekCurrent)
if err != nil {
t.Fatal(err)
}
// Shifting by incomplete bytes should not affect the result.
for i := 0; i < bitDepthInBytes*2; i++ {
pos, err := outS.Seek(int64(i), io.SeekCurrent)
if err != nil {
t.Fatal(err)
var src io.Reader = bytes.NewReader(inB)
if !seek {
src = &reader{r: src}
}
if cur != pos {
t.Errorf("cur: %d, pos: %d", cur, pos)
outS := convert.NewResampling(src, l, c.In, c.Out, bitDepthInBytes)
var gotB []byte
for {
var buf [97]byte
n, err := outS.Read(buf[:])
gotB = append(gotB, buf[:n]...)
if err != nil {
if err != io.EOF {
t.Fatal(err)
}
break
}
if seek {
cur, err := outS.Seek(0, io.SeekCurrent)
if err != nil {
t.Fatal(err)
}
// Shifting by incomplete bytes should not affect the result.
for i := 0; i < bitDepthInBytes*2; i++ {
pos, err := outS.Seek(int64(i), io.SeekCurrent)
if err != nil {
t.Fatal(err)
}
if cur != pos {
t.Errorf("cur: %d, pos: %d", cur, pos)
}
}
}
}
}
}
wantB := newSoundBytes(c.Out, bitDepthInBytes)
if len(gotB) != len(wantB) {
t.Errorf("len(gotB) == %d but len(wantB) == %d", len(gotB), len(wantB))
}
for i := 0; i < len(gotB)/bitDepthInBytes; i++ {
var got, want float64
switch bitDepthInBytes {
case 2:
got = float64(int16(gotB[2*i])|(int16(gotB[2*i+1])<<8)) / (1<<15 - 1)
want = float64(int16(wantB[2*i])|(int16(wantB[2*i+1])<<8)) / (1<<15 - 1)
case 4:
got = float64(math.Float32frombits(uint32(gotB[4*i]) | (uint32(gotB[4*i+1]) << 8) | (uint32(gotB[4*i+2]) << 16) | (uint32(gotB[4*i+3]) << 24)))
want = float64(math.Float32frombits(uint32(wantB[4*i]) | (uint32(wantB[4*i+1]) << 8) | (uint32(wantB[4*i+2]) << 16) | (uint32(wantB[4*i+3]) << 24)))
}
if math.Abs(got-want) > 0.025 {
t.Errorf("sample rate: %d, index: %d: got: %f, want: %f", c.Out, i, got, want)
}
wantB := newSoundBytes(c.Out, bitDepthInBytes)
// 256 is an arbitrary number.
// In most cases, len(gotB) must >= len(wantB), but there are some numerical errors.
if len(gotB) < len(wantB)-256 {
t.Errorf("len(gotB) >= len(wantB) - 256, but len(gotB) == %d, len(wantB) == %d", len(gotB), len(wantB))
}
for i := 0; i < len(gotB)/bitDepthInBytes; i++ {
var got, want float64
switch bitDepthInBytes {
case 2:
got = float64(int16(gotB[2*i])|(int16(gotB[2*i+1])<<8)) / (1<<15 - 1)
if i < len(wantB)/2 {
want = float64(int16(wantB[2*i])|(int16(wantB[2*i+1])<<8)) / (1<<15 - 1)
}
case 4:
got = float64(math.Float32frombits(uint32(gotB[4*i]) | (uint32(gotB[4*i+1]) << 8) | (uint32(gotB[4*i+2]) << 16) | (uint32(gotB[4*i+3]) << 24)))
if i < len(wantB)/4 {
want = float64(math.Float32frombits(uint32(wantB[4*i]) | (uint32(wantB[4*i+1]) << 8) | (uint32(wantB[4*i+2]) << 16) | (uint32(wantB[4*i+3]) << 24)))
}
}
if math.Abs(got-want) > 0.025 {
t.Errorf("sample rate: %d, index: %d: got: %f, want: %f", c.Out, i, got, want)
}
}
})
}
})
}