diff --git a/cipher_suites.go b/cipher_suites.go index d6bcc19..7efbe5a 100644 --- a/cipher_suites.go +++ b/cipher_suites.go @@ -131,7 +131,7 @@ func macSHA1(version uint16, key []byte) macFunction { copy(mac.key, key) return mac } - return tls10MAC{hmac.New(sha1.New, key)} + return tls10MAC{hmac.New(newConstantTimeHash(sha1.New), key)} } // macSHA256 returns a SHA-256 based MAC. These are only supported in TLS 1.2 @@ -142,7 +142,7 @@ func macSHA256(version uint16, key []byte) macFunction { type macFunction interface { Size() int - MAC(digestBuf, seq, header, data []byte) []byte + MAC(digestBuf, seq, header, data, extra []byte) []byte } // fixedNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to @@ -200,7 +200,9 @@ var ssl30Pad1 = [48]byte{0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0 var ssl30Pad2 = [48]byte{0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c} -func (s ssl30MAC) MAC(digestBuf, seq, header, data []byte) []byte { +// MAC does not offer constant timing guarantees for SSL v3.0, since it's deemed +// useless considering the similar, protocol-level POODLE vulnerability. +func (s ssl30MAC) MAC(digestBuf, seq, header, data, extra []byte) []byte { padLength := 48 if s.h.Size() == 20 { padLength = 40 @@ -222,6 +224,29 @@ func (s ssl30MAC) MAC(digestBuf, seq, header, data []byte) []byte { return s.h.Sum(digestBuf[:0]) } +type constantTimeHash interface { + hash.Hash + ConstantTimeSum(b []byte) []byte +} + +// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces +// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC. +type cthWrapper struct { + h constantTimeHash +} + +func (c *cthWrapper) Size() int { return c.h.Size() } +func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() } +func (c *cthWrapper) Reset() { c.h.Reset() } +func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) } +func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) } + +func newConstantTimeHash(h func() hash.Hash) func() hash.Hash { + return func() hash.Hash { + return &cthWrapper{h().(constantTimeHash)} + } +} + // tls10MAC implements the TLS 1.0 MAC function. RFC 2246, section 6.2.3. type tls10MAC struct { h hash.Hash @@ -231,12 +256,19 @@ func (s tls10MAC) Size() int { return s.h.Size() } -func (s tls10MAC) MAC(digestBuf, seq, header, data []byte) []byte { +// MAC is guaranteed to take constant time, as long as +// len(seq)+len(header)+len(data)+len(extra) is constant. extra is not fed into +// the MAC, but is only provided to make the timing profile constant. +func (s tls10MAC) MAC(digestBuf, seq, header, data, extra []byte) []byte { s.h.Reset() s.h.Write(seq) s.h.Write(header) s.h.Write(data) - return s.h.Sum(digestBuf[:0]) + res := s.h.Sum(digestBuf[:0]) + if extra != nil { + s.h.Write(extra) + } + return res } func rsaKA(version uint16) keyAgreement { diff --git a/conn.go b/conn.go index 6fd4864..a0c29f0 100644 --- a/conn.go +++ b/conn.go @@ -193,18 +193,18 @@ func (hc *halfConn) incSeq() { panic("TLS: sequence number wraparound") } -// removePadding returns an unpadded slice, in constant time, which is a prefix -// of the input. It also returns a byte which is equal to 255 if the padding -// was valid and 0 otherwise. See RFC 2246, section 6.2.3.2 -func removePadding(payload []byte) ([]byte, byte) { +// extractPadding returns, in constant time, the length of the padding to remove +// from the end of payload. It also returns a byte which is equal to 255 if the +// padding was valid and 0 otherwise. See RFC 2246, section 6.2.3.2 +func extractPadding(payload []byte) (toRemove int, good byte) { if len(payload) < 1 { - return payload, 0 + return 0, 0 } paddingLen := payload[len(payload)-1] t := uint(len(payload)-1) - uint(paddingLen) // if len(payload) >= (paddingLen - 1) then the MSB of t is zero - good := byte(int32(^t) >> 31) + good = byte(int32(^t) >> 31) toCheck := 255 // the maximum possible padding length // The length of the padded data is public, so we can use an if here @@ -227,24 +227,24 @@ func removePadding(payload []byte) ([]byte, byte) { good &= good << 1 good = uint8(int8(good) >> 7) - toRemove := good&paddingLen + 1 - return payload[:len(payload)-int(toRemove)], good + toRemove = int(paddingLen) + 1 + return } -// removePaddingSSL30 is a replacement for removePadding in the case that the +// extractPaddingSSL30 is a replacement for extractPadding in the case that the // protocol version is SSLv3. In this version, the contents of the padding // are random and cannot be checked. -func removePaddingSSL30(payload []byte) ([]byte, byte) { +func extractPaddingSSL30(payload []byte) (toRemove int, good byte) { if len(payload) < 1 { - return payload, 0 + return 0, 0 } paddingLen := int(payload[len(payload)-1]) + 1 if paddingLen > len(payload) { - return payload, 0 + return 0, 0 } - return payload[:len(payload)-paddingLen], 255 + return paddingLen, 255 } func roundUp(a, b int) int { @@ -270,6 +270,7 @@ func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) } paddingGood := byte(255) + paddingLen := 0 explicitIVLen := 0 // decrypt @@ -312,22 +313,17 @@ func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) } c.CryptBlocks(payload, payload) if hc.version == VersionSSL30 { - payload, paddingGood = removePaddingSSL30(payload) + paddingLen, paddingGood = extractPaddingSSL30(payload) } else { - payload, paddingGood = removePadding(payload) + paddingLen, paddingGood = extractPadding(payload) + + // To protect against CBC padding oracles like Lucky13, the data + // past paddingLen (which is secret) is passed to the MAC + // function as extra data, to be fed into the HMAC after + // computing the digest. This makes the MAC constant time as + // long as the digest computation is constant time and does not + // affect the subsequent write. } - b.resize(recordHeaderLen + explicitIVLen + len(payload)) - - // note that we still have a timing side-channel in the - // MAC check, below. An attacker can align the record - // so that a correct padding will cause one less hash - // block to be calculated. Then they can iteratively - // decrypt a record by breaking each byte. See - // "Password Interception in a SSL/TLS Channel", Brice - // Canvel et al. - // - // However, our behavior matches OpenSSL, so we leak - // only as much as they do. default: panic("unknown cipher type") } @@ -340,17 +336,19 @@ func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) } // strip mac off payload, b.data - n := len(payload) - macSize + n := len(payload) - macSize - paddingLen + n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 } b.data[3] = byte(n >> 8) b.data[4] = byte(n) - b.resize(recordHeaderLen + explicitIVLen + n) - remoteMAC := payload[n:] - localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], payload[:n]) + remoteMAC := payload[n : n+macSize] + localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], payload[:n], payload[n+macSize:]) if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 { return false, 0, alertBadRecordMAC } hc.inDigestBuf = localMAC + + b.resize(recordHeaderLen + explicitIVLen + n) } hc.incSeq() @@ -378,7 +376,7 @@ func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) { func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) { // mac if hc.mac != nil { - mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:]) + mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:], nil) n := len(b.data) b.resize(n + len(mac)) diff --git a/conn_test.go b/conn_test.go index 15397d6..5e5c7a2 100644 --- a/conn_test.go +++ b/conn_test.go @@ -40,7 +40,7 @@ var paddingTests = []struct { func TestRemovePadding(t *testing.T) { for i, test := range paddingTests { - payload, good := removePadding(test.in) + paddingLen, good := extractPadding(test.in) expectedGood := byte(255) if !test.good { expectedGood = 0 @@ -48,8 +48,8 @@ func TestRemovePadding(t *testing.T) { if good != expectedGood { t.Errorf("#%d: wrong validity, want:%d got:%d", i, expectedGood, good) } - if good == 255 && len(payload) != test.expectedLen { - t.Errorf("#%d: got %d, want %d", i, len(payload), test.expectedLen) + if good == 255 && len(test.in)-paddingLen != test.expectedLen { + t.Errorf("#%d: got %d, want %d", i, len(test.in)-paddingLen, test.expectedLen) } } }