crypto/tls: implement TLS 1.3 PSK messages

This commit is contained in:
Filippo Valsorda 2016-11-21 17:24:45 -05:00 committed by Peter Wu
parent 6c3765bb15
commit 453bd6af77
5 changed files with 337 additions and 2 deletions

View File

@ -84,7 +84,10 @@ const (
extensionSCT uint16 = 18 // https://tools.ietf.org/html/rfc6962#section-6 extensionSCT uint16 = 18 // https://tools.ietf.org/html/rfc6962#section-6
extensionSessionTicket uint16 = 35 extensionSessionTicket uint16 = 35
extensionKeyShare uint16 = 40 extensionKeyShare uint16 = 40
extensionPreSharedKey uint16 = 41
extensionSupportedVersions uint16 = 43 extensionSupportedVersions uint16 = 43
extensionPSKKeyExchangeModes uint16 = 45
extensionTicketEarlyDataInfo uint16 = 46
extensionNextProtoNeg uint16 = 13172 // not IANA assigned extensionNextProtoNeg uint16 = 13172 // not IANA assigned
extensionRenegotiationInfo uint16 = 0xff01 extensionRenegotiationInfo uint16 = 0xff01
) )
@ -94,6 +97,12 @@ const (
scsvRenegotiation uint16 = 0x00ff scsvRenegotiation uint16 = 0x00ff
) )
// PSK Key Exchange Modes
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.7
const (
pskDHEKeyExchange uint8 = 1
)
// CurveID is the type of a TLS identifier for an elliptic curve. See // CurveID is the type of a TLS identifier for an elliptic curve. See
// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8 // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8
// //
@ -115,6 +124,15 @@ type keyShare struct {
data []byte data []byte
} }
// TLS 1.3 PSK Identity and Binder, as sent by the client
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6
type psk struct {
identity []byte
obfTicketAge uint32
binder []byte
}
// TLS Elliptic Curve Point Formats // TLS Elliptic Curve Point Formats
// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9
const ( const (

View File

@ -1032,7 +1032,11 @@ func (c *Conn) readHandshake() (interface{}, error) {
case typeEncryptedExtensions: case typeEncryptedExtensions:
m = new(encryptedExtensionsMsg) m = new(encryptedExtensionsMsg)
case typeNewSessionTicket: case typeNewSessionTicket:
m = new(newSessionTicketMsg) if c.vers >= VersionTLS13 {
m = new(newSessionTicketMsg13)
} else {
m = new(newSessionTicketMsg)
}
case typeCertificate: case typeCertificate:
if c.vers >= VersionTLS13 { if c.vers >= VersionTLS13 {
m = new(certificateMsg13) m = new(certificateMsg13)

View File

@ -11,6 +11,7 @@ import (
type clientHelloMsg struct { type clientHelloMsg struct {
raw []byte raw []byte
rawTruncated []byte // for PSK binding
vers uint16 vers uint16
random []byte random []byte
sessionId []byte sessionId []byte
@ -30,6 +31,8 @@ type clientHelloMsg struct {
alpnProtocols []string alpnProtocols []string
keyShares []keyShare keyShares []keyShare
supportedVersions []uint16 supportedVersions []uint16
psks []psk
pskKeyExchangeModes []uint8
} }
func (m *clientHelloMsg) equal(i interface{}) bool { func (m *clientHelloMsg) equal(i interface{}) bool {
@ -366,6 +369,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
} }
m.sessionId = data[39 : 39+sessionIdLen] m.sessionId = data[39 : 39+sessionIdLen]
data = data[39+sessionIdLen:] data = data[39+sessionIdLen:]
bindersOffset := 39 + sessionIdLen
if len(data) < 2 { if len(data) < 2 {
return false return false
} }
@ -384,6 +388,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
} }
} }
data = data[2+cipherSuiteLen:] data = data[2+cipherSuiteLen:]
bindersOffset += 2 + cipherSuiteLen
if len(data) < 1 { if len(data) < 1 {
return false return false
} }
@ -394,6 +399,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.compressionMethods = data[1 : 1+compressionMethodsLen] m.compressionMethods = data[1 : 1+compressionMethodsLen]
data = data[1+compressionMethodsLen:] data = data[1+compressionMethodsLen:]
bindersOffset += 1 + compressionMethodsLen
m.nextProtoNeg = false m.nextProtoNeg = false
m.serverName = "" m.serverName = ""
@ -405,6 +411,8 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.scts = false m.scts = false
m.keyShares = nil m.keyShares = nil
m.supportedVersions = nil m.supportedVersions = nil
m.psks = nil
m.pskKeyExchangeModes = nil
if len(data) == 0 { if len(data) == 0 {
// ClientHello is optionally followed by extension data // ClientHello is optionally followed by extension data
@ -416,6 +424,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
extensionsLength := int(data[0])<<8 | int(data[1]) extensionsLength := int(data[0])<<8 | int(data[1])
data = data[2:] data = data[2:]
bindersOffset += 2
if extensionsLength != len(data) { if extensionsLength != len(data) {
return false return false
} }
@ -427,6 +436,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
extension := uint16(data[0])<<8 | uint16(data[1]) extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3]) length := int(data[2])<<8 | int(data[3])
data = data[4:] data = data[4:]
bindersOffset += 4
if len(data) < length { if len(data) < length {
return false return false
} }
@ -597,8 +607,70 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.supportedVersions = append(m.supportedVersions, v) m.supportedVersions = append(m.supportedVersions, v)
d = d[2:] d = d[2:]
} }
case extensionPreSharedKey:
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6
if length < 43 {
return false
}
li := int(data[0])<<8 | int(data[1])
if li > length-37 {
return false
}
d := data[2 : 2+li]
bindersOffset += 2 + li
for len(d) > 0 {
if len(d) < 6 {
return false
}
l := int(d[0])<<8 | int(d[1])
if len(d) < 2+l+4 {
return false
}
m.psks = append(m.psks, psk{
identity: d[2 : 2+l],
obfTicketAge: uint32(d[l+2])<<24 | uint32(d[l+3])<<16 |
uint32(d[l+4])<<8 | uint32(d[l+5]),
})
d = d[2+l+4:]
}
lb := int(data[li+2])<<8 | int(data[li+3])
d = data[2+li+2:]
if lb != len(d) {
return false
}
i := 0
for len(d) > 0 {
if len(d) < 1 {
return false
}
l := int(d[0])
if l > len(d)-1 {
return false
}
m.psks[i].binder = d[1 : 1+l]
d = d[1+l:]
i++
}
if i != len(m.psks) {
return false
}
// Ensure this extension is the last one in the Client Hello
if len(data) != length {
return false
}
m.rawTruncated = m.raw[:bindersOffset]
case extensionPSKKeyExchangeModes:
if length < 2 {
return false
}
l := int(data[0])
if length != l+1 {
return false
}
m.pskKeyExchangeModes = data[1:length]
} }
data = data[length:] data = data[length:]
bindersOffset += length
} }
return true return true
@ -941,6 +1013,8 @@ type serverHelloMsg13 struct {
random []byte random []byte
cipherSuite uint16 cipherSuite uint16
keyShare keyShare keyShare keyShare
psk bool
pskIdentity uint16
} }
func (m *serverHelloMsg13) equal(i interface{}) bool { func (m *serverHelloMsg13) equal(i interface{}) bool {
@ -954,7 +1028,9 @@ func (m *serverHelloMsg13) equal(i interface{}) bool {
bytes.Equal(m.random, m1.random) && bytes.Equal(m.random, m1.random) &&
m.cipherSuite == m1.cipherSuite && m.cipherSuite == m1.cipherSuite &&
m.keyShare.group == m1.keyShare.group && m.keyShare.group == m1.keyShare.group &&
bytes.Equal(m.keyShare.data, m1.keyShare.data) bytes.Equal(m.keyShare.data, m1.keyShare.data) &&
m.psk == m1.psk &&
m.pskIdentity == m1.pskIdentity
} }
func (m *serverHelloMsg13) marshal() []byte { func (m *serverHelloMsg13) marshal() []byte {
@ -966,6 +1042,9 @@ func (m *serverHelloMsg13) marshal() []byte {
if m.keyShare.group != 0 { if m.keyShare.group != 0 {
length += 8 + len(m.keyShare.data) length += 8 + len(m.keyShare.data)
} }
if m.psk {
length += 6
}
x := make([]byte, 4+length) x := make([]byte, 4+length)
x[0] = typeServerHello x[0] = typeServerHello
@ -982,6 +1061,15 @@ func (m *serverHelloMsg13) marshal() []byte {
x[40] = uint8(len(z) >> 8) x[40] = uint8(len(z) >> 8)
x[41] = uint8(len(z)) x[41] = uint8(len(z))
if m.psk {
z[0] = byte(extensionPreSharedKey >> 8)
z[1] = byte(extensionPreSharedKey)
z[3] = 2
z[4] = byte(m.pskIdentity >> 8)
z[5] = byte(m.pskIdentity)
z = z[6:]
}
if m.keyShare.group != 0 { if m.keyShare.group != 0 {
z[0] = uint8(extensionKeyShare >> 8) z[0] = uint8(extensionKeyShare >> 8)
z[1] = uint8(extensionKeyShare) z[1] = uint8(extensionKeyShare)
@ -1008,6 +1096,8 @@ func (m *serverHelloMsg13) unmarshal(data []byte) bool {
m.vers = uint16(data[4])<<8 | uint16(data[5]) m.vers = uint16(data[4])<<8 | uint16(data[5])
m.random = data[6:38] m.random = data[6:38]
m.cipherSuite = uint16(data[38])<<8 | uint16(data[39]) m.cipherSuite = uint16(data[38])<<8 | uint16(data[39])
m.psk = false
m.pskIdentity = 0
extensionsLength := int(data[40])<<8 | int(data[41]) extensionsLength := int(data[40])<<8 | int(data[41])
data = data[42:] data = data[42:]
@ -1029,6 +1119,12 @@ func (m *serverHelloMsg13) unmarshal(data []byte) bool {
switch extension { switch extension {
default: default:
return false return false
case extensionPreSharedKey:
if length != 2 {
return false
}
m.psk = true
m.pskIdentity = uint16(data[0])<<8 | uint16(data[1])
case extensionKeyShare: case extensionKeyShare:
if length < 2 { if length < 2 {
return false return false
@ -1925,6 +2021,131 @@ func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
return true return true
} }
type newSessionTicketMsg13 struct {
raw []byte
lifetime uint32
ageAdd uint32
ticket []byte
withEarlyDataInfo bool
maxEarlyDataLength uint32
}
func (m *newSessionTicketMsg13) equal(i interface{}) bool {
m1, ok := i.(*newSessionTicketMsg13)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.lifetime == m1.lifetime &&
m.ageAdd == m1.ageAdd &&
bytes.Equal(m.ticket, m1.ticket) &&
m.withEarlyDataInfo == m1.withEarlyDataInfo &&
m.maxEarlyDataLength == m1.maxEarlyDataLength
}
func (m *newSessionTicketMsg13) marshal() (x []byte) {
if m.raw != nil {
return m.raw
}
// See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6
ticketLen := len(m.ticket)
length := 12 + ticketLen
if m.withEarlyDataInfo {
length += 8
}
x = make([]byte, 4+length)
x[0] = typeNewSessionTicket
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
x[4] = uint8(m.lifetime >> 24)
x[5] = uint8(m.lifetime >> 16)
x[6] = uint8(m.lifetime >> 8)
x[7] = uint8(m.lifetime)
x[8] = uint8(m.ageAdd >> 24)
x[9] = uint8(m.ageAdd >> 16)
x[10] = uint8(m.ageAdd >> 8)
x[11] = uint8(m.ageAdd)
x[12] = uint8(ticketLen >> 8)
x[13] = uint8(ticketLen)
copy(x[14:], m.ticket)
if m.withEarlyDataInfo {
z := x[14+ticketLen:]
z[1] = 8
z[2] = uint8(extensionTicketEarlyDataInfo >> 8)
z[3] = uint8(extensionTicketEarlyDataInfo)
z[5] = 4
z[6] = uint8(m.maxEarlyDataLength >> 24)
z[7] = uint8(m.maxEarlyDataLength >> 16)
z[8] = uint8(m.maxEarlyDataLength >> 8)
z[9] = uint8(m.maxEarlyDataLength)
}
m.raw = x
return
}
func (m *newSessionTicketMsg13) unmarshal(data []byte) bool {
m.raw = data
m.maxEarlyDataLength = 0
m.withEarlyDataInfo = false
if len(data) < 16 {
return false
}
length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
if uint32(len(data))-4 != length {
return false
}
m.lifetime = uint32(data[4])<<24 | uint32(data[5])<<16 |
uint32(data[6])<<8 | uint32(data[7])
m.ageAdd = uint32(data[8])<<24 | uint32(data[9])<<16 |
uint32(data[10])<<8 | uint32(data[11])
ticketLen := int(data[12])<<8 + int(data[13])
if 14+ticketLen > len(data) {
return false
}
m.ticket = data[14 : 14+ticketLen]
data = data[14+ticketLen:]
extLen := int(data[0])<<8 + int(data[1])
if extLen != len(data)-2 {
return false
}
data = data[2:]
for len(data) > 0 {
if len(data) < 4 {
return false
}
extType := uint16(data[0])<<8 + uint16(data[1])
length := int(data[2])<<8 + int(data[3])
data = data[4:]
switch extType {
case extensionTicketEarlyDataInfo:
if length != 4 {
return false
}
m.withEarlyDataInfo = true
m.maxEarlyDataLength = uint32(data[0])<<24 | uint32(data[1])<<16 |
uint32(data[2])<<8 | uint32(data[3])
}
data = data[length:]
}
return true
}
type helloRequestMsg struct { type helloRequestMsg struct {
} }

View File

@ -29,6 +29,8 @@ var tests = []interface{}{
&serverHelloMsg13{}, &serverHelloMsg13{},
&encryptedExtensionsMsg{}, &encryptedExtensionsMsg{},
&certificateMsg13{}, &certificateMsg13{},
&newSessionTicketMsg13{},
&sessionState13{},
} }
type testMessage interface { type testMessage interface {
@ -210,6 +212,10 @@ func (*serverHelloMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
m.cipherSuite = uint16(rand.Int31()) m.cipherSuite = uint16(rand.Int31())
m.keyShare.group = CurveID(rand.Intn(30000)) m.keyShare.group = CurveID(rand.Intn(30000))
m.keyShare.data = randomBytes(rand.Intn(300), rand) m.keyShare.data = randomBytes(rand.Intn(300), rand)
if rand.Intn(10) > 5 {
m.psk = true
m.pskIdentity = uint16(rand.Int31())
}
return reflect.ValueOf(m) return reflect.ValueOf(m)
} }
@ -294,6 +300,18 @@ func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
return reflect.ValueOf(m) return reflect.ValueOf(m)
} }
func (*newSessionTicketMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
m := &newSessionTicketMsg13{}
m.ageAdd = uint32(rand.Intn(0xffffffff))
m.lifetime = uint32(rand.Intn(0xffffffff))
m.ticket = randomBytes(rand.Intn(40), rand)
if rand.Intn(10) > 5 {
m.withEarlyDataInfo = true
m.maxEarlyDataLength = uint32(rand.Intn(0xffffffff))
}
return reflect.ValueOf(m)
}
func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
s := &sessionState{} s := &sessionState{}
s.vers = uint16(rand.Intn(10000)) s.vers = uint16(rand.Intn(10000))
@ -307,6 +325,16 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
return reflect.ValueOf(s) return reflect.ValueOf(s)
} }
func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value {
s := &sessionState13{}
s.vers = uint16(rand.Intn(10000))
s.hash = uint16(rand.Intn(10000))
s.ageAdd = uint32(rand.Intn(0xffffffff))
s.createdAt = uint64(rand.Int63n(0xfffffffffffffff))
s.resumptionSecret = randomBytes(rand.Intn(100), rand)
return reflect.ValueOf(s)
}
func TestRejectEmptySCTList(t *testing.T) { func TestRejectEmptySCTList(t *testing.T) {
// https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that // https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that
// empty SCT lists are invalid. // empty SCT lists are invalid.

View File

@ -129,6 +129,70 @@ func (s *sessionState) unmarshal(data []byte) bool {
return len(data) == 0 return len(data) == 0
} }
type sessionState13 struct {
vers uint16
hash uint16 // crypto.Hash value
ageAdd uint32
createdAt uint64
resumptionSecret []byte
}
func (s *sessionState13) equal(i interface{}) bool {
s1, ok := i.(*sessionState13)
if !ok {
return false
}
return s.vers == s1.vers &&
s.hash == s1.hash &&
s.ageAdd == s1.ageAdd &&
bytes.Equal(s.resumptionSecret, s1.resumptionSecret)
}
func (s *sessionState13) marshal() []byte {
length := 2 + 2 + 4 + 8 + 2 + len(s.resumptionSecret)
x := make([]byte, length)
x[0] = byte(s.vers >> 8)
x[1] = byte(s.vers)
x[2] = byte(s.hash >> 8)
x[3] = byte(s.hash)
x[4] = byte(s.ageAdd >> 24)
x[5] = byte(s.ageAdd >> 16)
x[6] = byte(s.ageAdd >> 8)
x[7] = byte(s.ageAdd)
x[8] = byte(s.createdAt >> 56)
x[9] = byte(s.createdAt >> 48)
x[10] = byte(s.createdAt >> 40)
x[11] = byte(s.createdAt >> 32)
x[12] = byte(s.createdAt >> 24)
x[13] = byte(s.createdAt >> 16)
x[14] = byte(s.createdAt >> 8)
x[15] = byte(s.createdAt)
x[16] = byte(len(s.resumptionSecret) >> 8)
x[17] = byte(len(s.resumptionSecret))
copy(x[18:], s.resumptionSecret)
return x
}
func (s *sessionState13) unmarshal(data []byte) bool {
if len(data) < 18 {
return false
}
s.vers = uint16(data[0])<<8 | uint16(data[1])
s.hash = uint16(data[2])<<8 | uint16(data[3])
s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 |
uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15])
l := uint16(data[16])<<8 | uint16(data[17])
s.resumptionSecret = data[18:]
return int(l) == len(s.resumptionSecret)
}
func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) { func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) {
serialized := state.marshal() serialized := state.marshal()
encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(serialized)+sha256.Size) encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(serialized)+sha256.Size)