diff --git a/audio/wav/decode.go b/audio/wav/decode.go index 23e0e6a06..331590d3e 100644 --- a/audio/wav/decode.go +++ b/audio/wav/decode.go @@ -59,59 +59,6 @@ func (s *Stream) SampleRate() int { return s.sampleRate } -type stream struct { - src io.Reader - headerSize int64 - dataSize int64 - remaining int64 -} - -// Read is implementation of io.Reader's Read. -func (s *stream) Read(p []byte) (int, error) { - if s.remaining <= 0 { - return 0, io.EOF - } - if s.remaining < int64(len(p)) { - p = p[0:s.remaining] - } - n, err := s.src.Read(p) - s.remaining -= int64(n) - return n, err -} - -// Seek is implementation of io.Seeker's Seek. -// -// If the underlying source is not an io.Seeker, Seek panics. -func (s *stream) Seek(offset int64, whence int) (int64, error) { - seeker, ok := s.src.(io.Seeker) - if !ok { - panic("wav: s.src must be io.Seeker but not") - } - - switch whence { - case io.SeekStart: - offset = offset + s.headerSize - case io.SeekCurrent: - case io.SeekEnd: - offset = s.headerSize + s.dataSize + offset - whence = io.SeekStart - } - n, err := seeker.Seek(offset, whence) - if err != nil { - return 0, err - } - if n-s.headerSize < 0 { - return 0, fmt.Errorf("wav: invalid offset") - } - s.remaining = s.dataSize - (n - s.headerSize) - // There could be a tail in wav file. - if s.remaining < 0 { - s.remaining = 0 - return s.dataSize, nil - } - return n - s.headerSize, nil -} - // DecodeWithoutResampling decodes WAV (RIFF) data to playable stream. // // The format must be 1 or 2 channels, 8bit or 16bit little endian PCM. @@ -189,8 +136,8 @@ func decode(src io.Reader) (*Stream, error) { var sampleRate int chunks: for { - buf := make([]byte, 8) - n, err := io.ReadFull(src, buf) + var buf [8]byte + n, err := io.ReadFull(src, buf[:]) if n != len(buf) { return nil, fmt.Errorf("wav: invalid header") } @@ -247,12 +194,8 @@ chunks: headerSize += size } } - var s io.ReadSeeker = &stream{ - src: src, - headerSize: headerSize, - dataSize: dataSize, - remaining: dataSize, - } + + var s io.ReadSeeker = newSectionReader(src, headerSize, dataSize) if mono || bitsPerSample != 16 { s = convert.NewStereo16(s, mono, bitsPerSample != 16) diff --git a/audio/wav/sectionreader.go b/audio/wav/sectionreader.go new file mode 100644 index 000000000..45cec5d6c --- /dev/null +++ b/audio/wav/sectionreader.go @@ -0,0 +1,79 @@ +// Copyright 2024 The Ebitengine Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wav + +import ( + "fmt" + "io" +) + +// sectionReader is similar to io.SectionReader but takes an io.Reader instead of io.ReaderAt. +type sectionReader struct { + src io.Reader + offset int64 + size int64 + + pos int64 +} + +// newSectionReader creates a new sectionReader. +func newSectionReader(src io.Reader, offset int64, size int64) *sectionReader { + return §ionReader{ + src: src, + offset: offset, + size: size, + } +} + +// Read is implementation of io.Reader's Read. +func (s *sectionReader) Read(p []byte) (int, error) { + if s.pos >= s.size { + return 0, io.EOF + } + if s.pos+int64(len(p)) > s.size { + p = p[:s.size-s.pos] + } + n, err := s.src.Read(p) + s.pos += int64(n) + return n, err +} + +// Seek is implementation of io.Seeker's Seek. +// +// If the underlying source is not an io.Seeker, Seek panics. +func (s *sectionReader) Seek(offset int64, whence int) (int64, error) { + seeker, ok := s.src.(io.Seeker) + if !ok { + panic("wav: s.src must be io.Seeker but not") + } + + switch whence { + case io.SeekStart: + offset += s.offset + case io.SeekCurrent: + case io.SeekEnd: + offset += s.offset + s.size + whence = io.SeekStart + } + n, err := seeker.Seek(offset, whence) + if err != nil { + return 0, err + } + s.pos = n - s.offset + if s.pos < 0 || s.pos > s.size { + return 0, fmt.Errorf("wav: position is out of range") + } + return s.pos, nil +}