blob: 4c563998c7d0cc0487c870c667b651fe66124ec9 [file] [log] [blame]
// Copyright 2016 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"time"
)
// cookieBaker abstracts how a payload and its accompanying CSRF token are
// written to, and read back from, a named HTTP cookie.
type cookieBaker interface {
	// set writes a cookie called name to w that carries both payload and
	// csrfToken.
	set(w http.ResponseWriter, name, payload, csrfToken string) error

	// get reads the cookie called name from req and returns the payload
	// and CSRF token it carries. A missing cookie is not an error: get
	// then returns ("", "", nil).
	get(req *http.Request, name string) (payload, csrfToken string, err error)
}
// signedCookieBaker is a cookieBaker whose cookies carry an HMAC and an
// expiry so that tampered or stale cookies are rejected when read back.
type signedCookieBaker struct {
	// secure is copied into the Secure attribute of every cookie written
	// by set.
	secure bool
	// signKey is the secret key used to compute and verify cookie HMACs.
	signKey string
	// validity is added to the current time to obtain both the signed
	// expiry stored inside the cookie and the cookie's Expires attribute.
	validity time.Duration
}
// signedCookie is the JSON document stored (base64-encoded) as the value of
// a cookie produced by signedCookieBaker.
type signedCookie struct {
	// Payload is the value the cookie is meant to keep.
	Payload string `json:"payload"`
	// Expiry ensures cookies cannot be used beyond their intended validity.
	Expiry time.Time `json:"expiry"`
	// HMAC authenticates the cookie name together with Payload, Expiry and
	// CSRFToken, and so prevents tampering with the cookie.
	HMAC []byte `json:"hmac"`
	// CSRFToken prevents CSRF attacks by ensuring that requests must
	// contain a csrf token that matches the cookie.
	//
	// The JSON key is "csrfToken". encoding/json matches object keys to
	// fields case-insensitively when decoding, so a document that uses the
	// key "CSRFToken" is still accepted.
	CSRFToken string `json:"csrfToken"`
}
// computeHMAC returns the HMAC-SHA256, keyed with signKey, of the given
// fields. Each field is preceded by its byte length rendered as eight
// zero-padded hex digits, so that shifting bytes between adjacent fields
// yields a different MAC input.
func computeHMAC(signKey string, fields ...string) []byte {
	digest := hmac.New(sha256.New, []byte(signKey))
	for i := range fields {
		lengthPrefix := fmt.Sprintf("%08x", len(fields[i]))
		// hash.Hash.Write never returns an error.
		digest.Write([]byte(lengthPrefix))
		digest.Write([]byte(fields[i]))
	}
	return digest.Sum(nil)
}
// computeHMAC returns the HMAC, keyed with signKey, that binds the cookie
// name to the cookie's Payload, Expiry and CSRFToken.
//
// Expiry is fed to the MAC as its UTC RFC 3339 (nanosecond) encoding rather
// than Time.String(): String() includes the monotonic clock reading carried
// by values derived from time.Now() and a location-dependent zone name,
// neither of which survives the JSON round trip, so a MAC computed when the
// cookie is created would not match the one recomputed after decoding.
// UTC() drops the monotonic reading and fixes the zone. Cookies whose HMAC
// was computed over the String() form do not verify against this encoding.
func (c *signedCookie) computeHMAC(name, signKey string) []byte {
	expiry := c.Expiry.UTC().Format(time.RFC3339Nano)
	return computeHMAC(signKey, name, c.Payload, expiry, c.CSRFToken)
}
// verifyHMAC reports whether the HMAC stored in the cookie equals the one
// recomputed from the cookie's fields for the given cookie name and signing
// key. The comparison is done with hmac.Equal.
func (c *signedCookie) verifyHMAC(name, signKey string) bool {
	expected := c.computeHMAC(name, signKey)
	return hmac.Equal(c.HMAC, expected)
}
// packCookie builds the value of the cookie called name: a signedCookie
// holding payload, csrfToken and an expiry of now plus b.validity, signed
// with b.signKey, marshaled to JSON and then base64 (URL alphabet) encoded.
func (b *signedCookieBaker) packCookie(name, payload, csrfToken string) (string, error) {
	var cookie signedCookie
	cookie.Payload = payload
	cookie.CSRFToken = csrfToken
	cookie.Expiry = time.Now().Add(b.validity)
	// The HMAC must be computed last: it covers the fields set above.
	cookie.HMAC = cookie.computeHMAC(name, b.signKey)

	encoded, err := json.Marshal(cookie)
	if err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(encoded), nil
}
// unpackCookie decodes a cookie value produced by packCookie and returns the
// payload and CSRF token it carries. It returns an error if the value is not
// valid base64 or JSON, if the cookie's expiry has passed, or if its HMAC
// does not verify for the given cookie name and b.signKey.
func (b *signedCookieBaker) unpackCookie(name, value string) (string, string, error) {
	raw, err := base64.URLEncoding.DecodeString(value)
	if err != nil {
		return "", "", err
	}

	var cookie signedCookie
	if err = json.Unmarshal(raw, &cookie); err != nil {
		return "", "", err
	}

	// Cases are evaluated in order: expiry first, then the signature.
	switch {
	case time.Now().After(cookie.Expiry):
		return "", "", fmt.Errorf("cookie expired on %v", cookie.Expiry)
	case !cookie.verifyHMAC(name, b.signKey):
		return "", "", fmt.Errorf("HMAC mismatching for cookie %v", cookie)
	}
	return cookie.Payload, cookie.CSRFToken, nil
}
// set implements cookieBaker. It packs payload and csrfToken into a signed
// value and writes it to w as an HttpOnly cookie called name, scoped to
// routeRoot, expiring b.validity from now, with the Secure attribute taken
// from b.secure.
func (b *signedCookieBaker) set(w http.ResponseWriter, name, payload, csrfToken string) error {
	value, err := b.packCookie(name, payload, csrfToken)
	if err != nil {
		return err
	}
	http.SetCookie(w, &http.Cookie{
		Name:     name,
		Value:    value,
		Path:     routeRoot,
		Expires:  time.Now().Add(b.validity),
		Secure:   b.secure,
		HttpOnly: true,
	})
	return nil
}
// get implements cookieBaker. It looks up the cookie called name on req and
// unpacks it into its payload and CSRF token. If the request has no such
// cookie it returns ("", "", nil); any other lookup or unpacking error is
// returned to the caller.
func (b *signedCookieBaker) get(req *http.Request, name string) (string, string, error) {
	cookie, err := req.Cookie(name)
	switch {
	case err == http.ErrNoCookie:
		// An absent cookie is reported as empty values, not as a failure.
		return "", "", nil
	case err != nil:
		return "", "", err
	}
	return b.unpackCookie(name, cookie.Value)
}