diff --git a/audio/wav/decode.go b/audio/wav/decode.go index abd8593b7..097795d3f 100644 --- a/audio/wav/decode.go +++ b/audio/wav/decode.go @@ -26,8 +26,9 @@ import ( // Stream is a decoded audio stream. type Stream struct { - inner io.ReadSeeker - size int64 + inner io.ReadSeeker + size int64 + sampleRate int } // Read is implementation of io.Reader's Read. @@ -49,6 +50,11 @@ func (s *Stream) Length() int64 { return s.size } +// SampleRate returns the sample rate of the decoded stream. +func (s *Stream) SampleRate() int { + return s.sampleRate +} + type stream struct { src io.Reader headerSize int64 @@ -114,7 +120,7 @@ func (s *stream) Seek(offset int64, whence int) (int64, error) { // A Stream doesn't close src even if src implements io.Closer. // Closing the source is src owner's responsibility. func DecodeWithoutResampling(src io.Reader) (*Stream, error) { - s, _, err := decode(src) + s, err := decode(src) if err != nil { return nil, err } @@ -138,36 +144,37 @@ func DecodeWithoutResampling(src io.Reader) (*Stream, error) { // Resampling can be a very heavy task. Stream has a cache for resampling, but the size is limited. // Do not expect that Stream has a resampling cache even after whole data is played. func DecodeWithSampleRate(sampleRate int, src io.Reader) (*Stream, error) { - s, origSampleRate, err := decode(src) + s, err := decode(src) if err != nil { return nil, err } - if sampleRate == origSampleRate { + if sampleRate == s.sampleRate { return s, nil } - r := convert.NewResampling(s.inner, s.size, origSampleRate, sampleRate) + r := convert.NewResampling(s.inner, s.size, s.sampleRate, sampleRate) return &Stream{ - inner: r, - size: r.Length(), + inner: r, + size: r.Length(), + sampleRate: sampleRate, }, nil } -func decode(src io.Reader) (*Stream, int, error) { +func decode(src io.Reader) (*Stream, error) { buf := make([]byte, 12) n, err := io.ReadFull(src, buf) if n != len(buf) { - return nil, 0, fmt.Errorf("wav: invalid header") + return nil, fmt.Errorf("wav: invalid header") } if err != nil { - return nil, 0, err + return nil, err } if !bytes.Equal(buf[0:4], []byte("RIFF")) { - return nil, 0, fmt.Errorf("wav: invalid header: 'RIFF' not found") + return nil, fmt.Errorf("wav: invalid header: 'RIFF' not found") } if !bytes.Equal(buf[8:12], []byte("WAVE")) { - return nil, 0, fmt.Errorf("wav: invalid header: 'WAVE' not found") + return nil, fmt.Errorf("wav: invalid header: 'WAVE' not found") } // Read chunks @@ -181,10 +188,10 @@ chunks: buf := make([]byte, 8) n, err := io.ReadFull(src, buf) if n != len(buf) { - return nil, 0, fmt.Errorf("wav: invalid header") + return nil, fmt.Errorf("wav: invalid header") } if err != nil { - return nil, 0, err + return nil, err } headerSize += 8 size := int64(buf[4]) | int64(buf[5])<<8 | int64(buf[6])<<16 | int64(buf[7])<<24 @@ -192,19 +199,19 @@ chunks: case bytes.Equal(buf[0:4], []byte("fmt ")): // Size of 'fmt' header is usually 16, but can be more than 16. if size < 16 { - return nil, 0, fmt.Errorf("wav: invalid header: maybe non-PCM file?") + return nil, fmt.Errorf("wav: invalid header: maybe non-PCM file?") } buf := make([]byte, size) n, err := io.ReadFull(src, buf) if n != len(buf) { - return nil, 0, fmt.Errorf("wav: invalid header") + return nil, fmt.Errorf("wav: invalid header") } if err != nil { - return nil, 0, err + return nil, err } format := int(buf[0]) | int(buf[1])<<8 if format != 1 { - return nil, 0, fmt.Errorf("wav: format must be linear PCM") + return nil, fmt.Errorf("wav: format must be linear PCM") } channelCount := int(buf[2]) | int(buf[3])<<8 switch channelCount { @@ -213,11 +220,11 @@ chunks: case 2: mono = false default: - return nil, 0, fmt.Errorf("wav: number of channels must be 1 or 2 but was %d", channelCount) + return nil, fmt.Errorf("wav: number of channels must be 1 or 2 but was %d", channelCount) } bitsPerSample = int(buf[14]) | int(buf[15])<<8 if bitsPerSample != 8 && bitsPerSample != 16 { - return nil, 0, fmt.Errorf("wav: bits per sample must be 8 or 16 but was %d", bitsPerSample) + return nil, fmt.Errorf("wav: bits per sample must be 8 or 16 but was %d", bitsPerSample) } sampleRate = int(buf[4]) | int(buf[5])<<8 | int(buf[6])<<16 | int(buf[7])<<24 headerSize += size @@ -228,10 +235,10 @@ chunks: buf := make([]byte, size) n, err := io.ReadFull(src, buf) if n != len(buf) { - return nil, 0, fmt.Errorf("wav: invalid header") + return nil, fmt.Errorf("wav: invalid header") } if err != nil { - return nil, 0, err + return nil, err } headerSize += size } @@ -252,7 +259,11 @@ chunks: dataSize *= 2 } } - return &Stream{inner: s, size: dataSize}, sampleRate, nil + return &Stream{ + inner: s, + size: dataSize, + sampleRate: sampleRate, + }, nil } // Decode decodes WAV (RIFF) data to playable stream.