diff --git a/audio/wav/decode.go b/audio/wav/decode.go index bb97aa616..1dc0489b3 100644 --- a/audio/wav/decode.go +++ b/audio/wav/decode.go @@ -114,7 +114,11 @@ func (s *stream) Seek(offset int64, whence int) (int64, error) { // A Stream doesn't close src even if src implements io.Closer. // Closing the source is src owner's responsibility. func DecodeWithoutResampling(src io.Reader) (*Stream, error) { - return decode(src, nil) + s, _, err := decode(src) + if err != nil { + return nil, err + } + return s, nil } // DecodeWithSampleRate decodes WAV (RIFF) data to playable stream. @@ -131,41 +135,53 @@ func DecodeWithoutResampling(src io.Reader) (*Stream, error) { // A Stream doesn't close src even if src implements io.Closer. // Closing the source is src owner's responsibility. func DecodeWithSampleRate(sampleRate int, src io.Reader) (*Stream, error) { - return decode(src, &sampleRate) -} - -func decode(src io.Reader, sampleRate *int) (*Stream, error) { - buf := make([]byte, 12) - n, err := io.ReadFull(src, buf) - if n != len(buf) { - return nil, fmt.Errorf("wav: invalid header") - } + s, origSampleRate, err := decode(src) if err != nil { return nil, err } + + if sampleRate == origSampleRate { + return s, nil + } + + r := convert.NewResampling(s.inner, s.size, origSampleRate, sampleRate) + return &Stream{ + inner: r, + size: r.Length(), + }, nil +} + +func decode(src io.Reader) (*Stream, int, error) { + buf := make([]byte, 12) + n, err := io.ReadFull(src, buf) + if n != len(buf) { + return nil, 0, fmt.Errorf("wav: invalid header") + } + if err != nil { + return nil, 0, err + } if !bytes.Equal(buf[0:4], []byte("RIFF")) { - return nil, fmt.Errorf("wav: invalid header: 'RIFF' not found") + return nil, 0, fmt.Errorf("wav: invalid header: 'RIFF' not found") } if !bytes.Equal(buf[8:12], []byte("WAVE")) { - return nil, fmt.Errorf("wav: invalid header: 'WAVE' not found") + return nil, 0, fmt.Errorf("wav: invalid header: 'WAVE' not found") } // Read chunks - dataSize := int64(0) + var dataSize int64 headerSize := int64(len(buf)) - sampleRateFrom := 0 - sampleRateTo := 0 - mono := false - bitsPerSample := 0 + var mono bool + var bitsPerSample int + var sampleRate int chunks: for { buf := make([]byte, 8) n, err := io.ReadFull(src, buf) if n != len(buf) { - return nil, fmt.Errorf("wav: invalid header") + return nil, 0, fmt.Errorf("wav: invalid header") } if err != nil { - return nil, err + return nil, 0, err } headerSize += 8 size := int64(buf[4]) | int64(buf[5])<<8 | int64(buf[6])<<16 | int64(buf[7])<<24 @@ -173,19 +189,19 @@ chunks: case bytes.Equal(buf[0:4], []byte("fmt ")): // Size of 'fmt' header is usually 16, but can be more than 16. if size < 16 { - return nil, fmt.Errorf("wav: invalid header: maybe non-PCM file?") + return nil, 0, fmt.Errorf("wav: invalid header: maybe non-PCM file?") } buf := make([]byte, size) n, err := io.ReadFull(src, buf) if n != len(buf) { - return nil, fmt.Errorf("wav: invalid header") + return nil, 0, fmt.Errorf("wav: invalid header") } if err != nil { - return nil, err + return nil, 0, err } format := int(buf[0]) | int(buf[1])<<8 if format != 1 { - return nil, fmt.Errorf("wav: format must be linear PCM") + return nil, 0, fmt.Errorf("wav: format must be linear PCM") } channelCount := int(buf[2]) | int(buf[3])<<8 switch channelCount { @@ -194,17 +210,13 @@ chunks: case 2: mono = false default: - return nil, fmt.Errorf("wav: number of channels must be 1 or 2 but was %d", channelCount) + return nil, 0, fmt.Errorf("wav: number of channels must be 1 or 2 but was %d", channelCount) } bitsPerSample = int(buf[14]) | int(buf[15])<<8 if bitsPerSample != 8 && bitsPerSample != 16 { - return nil, fmt.Errorf("wav: bits per sample must be 8 or 16 but was %d", bitsPerSample) - } - origSampleRate := int64(buf[4]) | int64(buf[5])<<8 | int64(buf[6])<<16 | int64(buf[7])<<24 - if sampleRate != nil && int64(*sampleRate) != origSampleRate { - sampleRateFrom = int(origSampleRate) - sampleRateTo = *sampleRate + return nil, 0, fmt.Errorf("wav: bits per sample must be 8 or 16 but was %d", bitsPerSample) } + sampleRate = int(buf[4]) | int(buf[5])<<8 | int(buf[6])<<16 | int(buf[7])<<24 headerSize += size case bytes.Equal(buf[0:4], []byte("data")): dataSize = size @@ -213,10 +225,10 @@ chunks: buf := make([]byte, size) n, err := io.ReadFull(src, buf) if n != len(buf) { - return nil, fmt.Errorf("wav: invalid header") + return nil, 0, fmt.Errorf("wav: invalid header") } if err != nil { - return nil, err + return nil, 0, err } headerSize += size } @@ -237,13 +249,7 @@ chunks: dataSize *= 2 } } - if sampleRateFrom != sampleRateTo { - r := convert.NewResampling(s, dataSize, sampleRateFrom, sampleRateTo) - s = r - dataSize = r.Length() - } - ss := &Stream{inner: s, size: dataSize} - return ss, nil + return &Stream{inner: s, size: dataSize}, sampleRate, nil } // Decode decodes WAV (RIFF) data to playable stream.