blob: b211f0ac5ac2437a3fa6142ef689207efb442050 [file] [log] [blame]
// Copyright 2015 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package util
import (
"bytes"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/http"
"time"
"v.io/v23"
"v.io/v23/context"
"v.io/v23/vom"
)
const (
cookieLen = 16
)
// CSRFCop implements utilities for generating and validating tokens for
// cross-site-request-forgery prevention (also called XSRF).
type CSRFCop struct {
ctx *context.T
}
func NewCSRFCop(ctx *context.T) *CSRFCop {
return &CSRFCop{ctx: ctx}
}
// NewToken creates an anti-cross-site-request-forgery, aka CSRF aka XSRF token
// with some data bound to it that can be obtained by ValidateToken.
// It returns an error if the token could not be created.
func (c *CSRFCop) NewToken(w http.ResponseWriter, r *http.Request, cookieName string, data interface{}) (string, error) {
cookieValue, err := c.MaybeSetCookie(w, r, cookieName)
if err != nil {
return "", fmt.Errorf("bad cookie: %v", err)
}
var encData []byte
if data != nil {
if encData, err = vom.Encode(data); err != nil {
return "", err
}
}
hash := sha256.Sum256(cookieValue)
mac, err := NewMacaroon(v23.GetPrincipal(c.ctx), append(hash[:], encData...))
return string(mac), err
}
// ValidateToken checks the validity of the provided CSRF token for the
// provided request, and extracts the data encoded in the token into 'decoded'.
// If the token is invalid, return an error. This error should not be shown to end users,
// it is meant for the consumption by the server process only.
func (c *CSRFCop) ValidateToken(token string, req *http.Request, cookieName string, decoded interface{}) error {
cookie, err := req.Cookie(cookieName)
if err != nil {
return err
}
cookieValue, err := decodeCookieValue(cookie.Value)
if err != nil {
return fmt.Errorf("invalid cookie")
}
encodedInput, err := Macaroon(token).Decode(v23.GetPrincipal(c.ctx))
if err != nil {
return err
}
if len(encodedInput) < sha256.Size {
return fmt.Errorf("invalid token data: too short")
}
hash := sha256.Sum256(cookieValue)
if !bytes.Equal(hash[:], encodedInput[:sha256.Size]) {
return fmt.Errorf("invalid token data")
}
if decoded != nil {
if err := vom.Decode(encodedInput[sha256.Size:], decoded); err != nil {
return fmt.Errorf("invalid token data: %v", err)
}
}
return nil
}
func (c *CSRFCop) MaybeSetCookie(w http.ResponseWriter, req *http.Request, cookieName string) ([]byte, error) {
cookie, err := req.Cookie(cookieName)
switch err {
case nil:
if v, err := decodeCookieValue(cookie.Value); err == nil {
return v, nil
}
c.ctx.Infof("Invalid cookie: %#v, err: %v. Regenerating one.", cookie, err)
case http.ErrNoCookie:
// Intentionally blank: Cookie will be generated below.
default:
c.ctx.Infof("Error decoding cookie %q in request: %v. Regenerating one.", cookieName, err)
}
cookie, v := newCookie(c.ctx, cookieName)
if cookie == nil || v == nil {
return nil, fmt.Errorf("failed to create cookie")
}
http.SetCookie(w, cookie)
// We need to add the cookie to the request also to prevent repeatedly resetting cookies on multiple
// calls from the same request.
req.AddCookie(cookie)
return v, nil
}
func newCookie(ctx *context.T, cookieName string) (*http.Cookie, []byte) {
b := make([]byte, cookieLen)
if _, err := rand.Read(b); err != nil {
ctx.Errorf("newCookie failed: %v", err)
return nil, nil
}
return &http.Cookie{
Name: cookieName,
Value: b64encode(b),
Expires: time.Now().Add(time.Hour * 24),
HttpOnly: true,
Secure: true,
Path: "/",
}, b
}
func decodeCookieValue(v string) ([]byte, error) {
b, err := b64decode(v)
if err != nil {
return nil, err
}
if len(b) != cookieLen {
return nil, fmt.Errorf("invalid cookie length[%d]", len(b))
}
return b, nil
}
// Shorthands.
func b64encode(b []byte) string { return base64.URLEncoding.EncodeToString(b) }
func b64decode(s string) ([]byte, error) { return base64.URLEncoding.DecodeString(s) }