blob: 516cf846919938c3bea47a696db26724da93b964 [file] [log] [blame]
package util
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/http"
"time"
"veyron2/vlog"
)
const cookieLen = 16
type CSRFCop struct{}
func NewCSRFCop() *CSRFCop { return new(CSRFCop) }
func (*CSRFCop) maybeSetCookie(w http.ResponseWriter, req *http.Request, cookieName string) ([]byte, error) {
cookie, err := req.Cookie(cookieName)
switch err {
case nil:
if v, err := decodeCookieValue(cookie.Value); err == nil {
return v, nil
}
vlog.Infof("Invalid cookie: %#v, err: %v. Regenerating one.", cookie, err)
case http.ErrNoCookie:
// Intentionally blank: Cookie will be generated below.
default:
vlog.Infof("Error decoding cookie %q in request: %v. Regenerating one.", cookieName, err)
}
cookie, v := newCookie(cookieName)
if cookie == nil || v == nil {
return nil, fmt.Errorf("failed to create cookie")
}
http.SetCookie(w, cookie)
return v, nil
}
// NewToken creates an anti-cross-site-request-forgery, aka CSRF aka XSRF token.
// It returns an error if the token could not be created.
func (c *CSRFCop) NewToken(w http.ResponseWriter, r *http.Request, cookieName string) (string, error) {
cookie, err := c.maybeSetCookie(w, r, cookieName)
if err != nil {
return "", fmt.Errorf("bad cookie: %v", err)
}
return c.newToken(cookie), nil
}
func (c *CSRFCop) newToken(cookie []byte) string {
return b64encode(hmac.New(sha256.New, cookie).Sum(nil))
}
// ValidateToken checks the validity of the provided CSRF token for the
// provided request, returning nil if the token is valid and an error
// otherwise.
// The returned error should not be shown to end-users, it is meant for
// consumption of the server process only.
func (c *CSRFCop) ValidateToken(token string, req *http.Request, cookieName string) error {
cookie, err := req.Cookie(cookieName)
if err != nil {
return err
}
cookieValue, err := decodeCookieValue(cookie.Value)
if err != nil {
return fmt.Errorf("invalid cookie")
}
if token != c.newToken(cookieValue) {
return fmt.Errorf("invalid CSRF token")
}
return nil
}
func requestString(r *http.Request) string {
var buf bytes.Buffer
r.Write(&buf)
return buf.String()
}
type badRequestData struct {
Request string
Error error
}
func newCookie(cookieName string) (*http.Cookie, []byte) {
b := make([]byte, cookieLen)
if n, err := rand.Read(b); n != cookieLen || err != nil {
vlog.Errorf("newCookie failed: Read %d random bytes, wanted %d. err: %v", n, cookieLen, err)
return nil, nil
}
return &http.Cookie{
Name: cookieName,
Value: b64encode(b),
Expires: time.Now().Add(time.Hour * 24),
HttpOnly: true,
}, b
}
func decodeCookieValue(v string) ([]byte, error) {
b, err := b64decode(v)
if err != nil {
return nil, err
}
if len(b) != cookieLen {
return nil, fmt.Errorf("invalid cookie length[%d]", len(b))
}
return b, nil
}
// Shorthands.
func b64encode(b []byte) string { return base64.URLEncoding.EncodeToString(b) }
func b64decode(s string) ([]byte, error) { return base64.URLEncoding.DecodeString(s) }