blob: 10f7c611f22b9b5fdaccb0586438c20260a74c21 [file] [log] [blame]
// Copyright 2015 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// LimitedWriter is an io.Writer wrapper that limits the total number of bytes
// written to the underlying writer.
//
// All attempted writes count against the limit, regardless of whether they
// succeed.
// A LimitedWriter is not safe for concurrent use by multiple goroutines.
package lib
import (
"errors"
"io"
"net/http"
"sync"
)
var (
	// ErrWriteLimitExceeded is returned by LimitedWriter.Write when accepting
	// the given bytes would push the running total past the writer's limit.
	ErrWriteLimitExceeded = errors.New("LimitedWriter: write limit exceeded")
)
// LimitedWriter wraps an io.Writer and caps the total number of bytes that
// may be written through it. Construct one with NewLimitedWriter; the zero
// value has a limit of 0 and no underlying writer.
type LimitedWriter struct {
	// The embedded destination that in-budget writes are forwarded to.
	io.Writer
	// maxLen is the maximum total number of bytes that may be counted.
	maxLen int
	// maxLenExceededCb, when non-nil, is invoked each time a write is
	// rejected for exceeding maxLen.
	maxLenExceededCb func()
	// lenWritten is the number of bytes counted against maxLen so far.
	lenWritten int
}
// NewLimitedWriter returns a LimitedWriter that forwards writes to writer
// for as long as the total number of bytes counted stays within maxLen.
// maxLenExceededCb may be nil; when set, it is invoked every time a write
// is rejected for exceeding the limit.
func NewLimitedWriter(writer io.Writer, maxLen int, maxLenExceededCb func()) *LimitedWriter {
	lw := new(LimitedWriter)
	lw.Writer = writer
	lw.maxLen = maxLen
	lw.maxLenExceededCb = maxLenExceededCb
	return lw
}
// Write forwards p to the underlying writer if doing so keeps the total
// number of counted bytes within maxLen.
//
// If p would push the running total past maxLen, nothing is written. The
// running total is then pinned at maxLen, so every later non-empty write
// also fails. maxLenExceededCb is called if set, and ErrWriteLimitExceeded
// is returned. An empty p that fits is a no-op returning (0, nil).
//
// Bytes are counted before the underlying Write is attempted. A failed or
// short underlying write therefore still consumes budget.
func (t *LimitedWriter) Write(p []byte) (n int, err error) {
	// Compare against the remaining budget instead of computing
	// lenWritten+len(p). That sum can overflow int, notably on 32-bit
	// platforms with large limits. It would then wrap negative, let an
	// over-limit write through, and leave lenWritten negative.
	// maxLen-lenWritten cannot overflow because lenWritten never exceeds
	// maxLen.
	if len(p) > t.maxLen-t.lenWritten {
		t.lenWritten = t.maxLen
		if t.maxLenExceededCb != nil {
			t.maxLenExceededCb()
		}
		return 0, ErrWriteLimitExceeded
	}
	if len(p) == 0 {
		return 0, nil
	}
	t.lenWritten += len(p)
	return t.Writer.Write(p)
}
// Compile-time check that *LimitedWriter implements http.Flusher.
var _ http.Flusher = (*LimitedWriter)(nil)

// Flush forwards to the underlying writer's Flush when that writer
// implements http.Flusher, and does nothing otherwise. Flushing does not
// touch the byte count used for the write limit.
func (t *LimitedWriter) Flush() {
	flusher, ok := t.Writer.(http.Flusher)
	if !ok {
		return
	}
	flusher.Flush()
}
// DoOnce returns a function that runs f on its first invocation and does
// nothing on every later one. The guard is a sync.Once, so the returned
// function may be called from multiple goroutines.
func DoOnce(f func()) func() {
	guard := new(sync.Once)
	return func() { guard.Do(f) }
}