diff --git a/common.go b/common.go index 80db411..457ac44 100644 --- a/common.go +++ b/common.go @@ -84,7 +84,10 @@ const ( extensionSCT uint16 = 18 // https://tools.ietf.org/html/rfc6962#section-6 extensionSessionTicket uint16 = 35 extensionKeyShare uint16 = 40 + extensionPreSharedKey uint16 = 41 extensionSupportedVersions uint16 = 43 + extensionPSKKeyExchangeModes uint16 = 45 + extensionTicketEarlyDataInfo uint16 = 46 extensionNextProtoNeg uint16 = 13172 // not IANA assigned extensionRenegotiationInfo uint16 = 0xff01 ) @@ -94,6 +97,12 @@ const ( scsvRenegotiation uint16 = 0x00ff ) +// PSK Key Exchange Modes +// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.7 +const ( + pskDHEKeyExchange uint8 = 1 +) + // CurveID is the type of a TLS identifier for an elliptic curve. See // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8 // @@ -115,6 +124,15 @@ type keyShare struct { data []byte } +// TLS 1.3 PSK Identity and Binder, as sent by the client +// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6 + +type psk struct { + identity []byte + obfTicketAge uint32 + binder []byte +} + // TLS Elliptic Curve Point Formats // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 const ( diff --git a/conn.go b/conn.go index 8e916d9..d1e358c 100644 --- a/conn.go +++ b/conn.go @@ -1032,7 +1032,11 @@ func (c *Conn) readHandshake() (interface{}, error) { case typeEncryptedExtensions: m = new(encryptedExtensionsMsg) case typeNewSessionTicket: - m = new(newSessionTicketMsg) + if c.vers >= VersionTLS13 { + m = new(newSessionTicketMsg13) + } else { + m = new(newSessionTicketMsg) + } case typeCertificate: if c.vers >= VersionTLS13 { m = new(certificateMsg13) diff --git a/handshake_messages.go b/handshake_messages.go index 19b5129..31d1513 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -11,6 +11,7 @@ import ( type clientHelloMsg struct { raw []byte + rawTruncated []byte // for PSK binding vers uint16 random []byte sessionId []byte @@ -30,6 +31,8 @@ type clientHelloMsg struct { alpnProtocols []string keyShares []keyShare supportedVersions []uint16 + psks []psk + pskKeyExchangeModes []uint8 } func (m *clientHelloMsg) equal(i interface{}) bool { @@ -366,6 +369,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } m.sessionId = data[39 : 39+sessionIdLen] data = data[39+sessionIdLen:] + bindersOffset := 39 + sessionIdLen if len(data) < 2 { return false } @@ -384,6 +388,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } } data = data[2+cipherSuiteLen:] + bindersOffset += 2 + cipherSuiteLen if len(data) < 1 { return false } @@ -394,6 +399,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.compressionMethods = data[1 : 1+compressionMethodsLen] data = data[1+compressionMethodsLen:] + bindersOffset += 1 + compressionMethodsLen m.nextProtoNeg = false m.serverName = "" @@ -405,6 +411,8 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.scts = false m.keyShares = nil m.supportedVersions = nil + m.psks = nil + m.pskKeyExchangeModes = nil if len(data) == 0 { // ClientHello is optionally followed by extension data @@ -416,6 +424,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { extensionsLength := int(data[0])<<8 | int(data[1]) data = data[2:] + bindersOffset += 2 if extensionsLength != len(data) { return false } @@ -427,6 +436,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { extension := uint16(data[0])<<8 | uint16(data[1]) length := int(data[2])<<8 | int(data[3]) data = data[4:] + bindersOffset += 4 if len(data) < length { return false } @@ -597,8 +607,70 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.supportedVersions = append(m.supportedVersions, v) d = d[2:] } + case extensionPreSharedKey: + // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6 + if length < 43 { + return false + } + li := int(data[0])<<8 | int(data[1]) + if li > length-37 { + return false + } + d := data[2 : 2+li] + bindersOffset += 2 + li + for len(d) > 0 { + if len(d) < 6 { + return false + } + l := int(d[0])<<8 | int(d[1]) + if len(d) < 2+l+4 { + return false + } + m.psks = append(m.psks, psk{ + identity: d[2 : 2+l], + obfTicketAge: uint32(d[l+2])<<24 | uint32(d[l+3])<<16 | + uint32(d[l+4])<<8 | uint32(d[l+5]), + }) + d = d[2+l+4:] + } + lb := int(data[li+2])<<8 | int(data[li+3]) + d = data[2+li+2:] + if lb != len(d) { + return false + } + i := 0 + for len(d) > 0 { + if len(d) < 1 { + return false + } + l := int(d[0]) + if l > len(d)-1 { + return false + } + m.psks[i].binder = d[1 : 1+l] + d = d[1+l:] + i++ + } + if i != len(m.psks) { + return false + } + // Ensure this extension is the last one in the Client Hello + if len(data) != length { + return false + } + m.rawTruncated = m.raw[:bindersOffset] + case extensionPSKKeyExchangeModes: + if length < 2 { + return false + } + l := int(data[0]) + if length != l+1 { + return false + } + m.pskKeyExchangeModes = data[1:length] } data = data[length:] + bindersOffset += length } return true @@ -941,6 +1013,8 @@ type serverHelloMsg13 struct { random []byte cipherSuite uint16 keyShare keyShare + psk bool + pskIdentity uint16 } func (m *serverHelloMsg13) equal(i interface{}) bool { @@ -954,7 +1028,9 @@ func (m *serverHelloMsg13) equal(i interface{}) bool { bytes.Equal(m.random, m1.random) && m.cipherSuite == m1.cipherSuite && m.keyShare.group == m1.keyShare.group && - bytes.Equal(m.keyShare.data, m1.keyShare.data) + bytes.Equal(m.keyShare.data, m1.keyShare.data) && + m.psk == m1.psk && + m.pskIdentity == m1.pskIdentity } func (m *serverHelloMsg13) marshal() []byte { @@ -966,6 +1042,9 @@ func (m *serverHelloMsg13) marshal() []byte { if m.keyShare.group != 0 { length += 8 + len(m.keyShare.data) } + if m.psk { + length += 6 + } x := make([]byte, 4+length) x[0] = typeServerHello @@ -982,6 +1061,15 @@ func (m *serverHelloMsg13) marshal() []byte { x[40] = uint8(len(z) >> 8) x[41] = uint8(len(z)) + if m.psk { + z[0] = byte(extensionPreSharedKey >> 8) + z[1] = byte(extensionPreSharedKey) + z[3] = 2 + z[4] = byte(m.pskIdentity >> 8) + z[5] = byte(m.pskIdentity) + z = z[6:] + } + if m.keyShare.group != 0 { z[0] = uint8(extensionKeyShare >> 8) z[1] = uint8(extensionKeyShare) @@ -1008,6 +1096,8 @@ func (m *serverHelloMsg13) unmarshal(data []byte) bool { m.vers = uint16(data[4])<<8 | uint16(data[5]) m.random = data[6:38] m.cipherSuite = uint16(data[38])<<8 | uint16(data[39]) + m.psk = false + m.pskIdentity = 0 extensionsLength := int(data[40])<<8 | int(data[41]) data = data[42:] @@ -1029,6 +1119,12 @@ func (m *serverHelloMsg13) unmarshal(data []byte) bool { switch extension { default: return false + case extensionPreSharedKey: + if length != 2 { + return false + } + m.psk = true + m.pskIdentity = uint16(data[0])<<8 | uint16(data[1]) case extensionKeyShare: if length < 2 { return false @@ -1925,6 +2021,131 @@ func (m *newSessionTicketMsg) unmarshal(data []byte) bool { return true } +type newSessionTicketMsg13 struct { + raw []byte + lifetime uint32 + ageAdd uint32 + ticket []byte + withEarlyDataInfo bool + maxEarlyDataLength uint32 +} + +func (m *newSessionTicketMsg13) equal(i interface{}) bool { + m1, ok := i.(*newSessionTicketMsg13) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.lifetime == m1.lifetime && + m.ageAdd == m1.ageAdd && + bytes.Equal(m.ticket, m1.ticket) && + m.withEarlyDataInfo == m1.withEarlyDataInfo && + m.maxEarlyDataLength == m1.maxEarlyDataLength +} + +func (m *newSessionTicketMsg13) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6 + ticketLen := len(m.ticket) + length := 12 + ticketLen + if m.withEarlyDataInfo { + length += 8 + } + x = make([]byte, 4+length) + x[0] = typeNewSessionTicket + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + x[4] = uint8(m.lifetime >> 24) + x[5] = uint8(m.lifetime >> 16) + x[6] = uint8(m.lifetime >> 8) + x[7] = uint8(m.lifetime) + x[8] = uint8(m.ageAdd >> 24) + x[9] = uint8(m.ageAdd >> 16) + x[10] = uint8(m.ageAdd >> 8) + x[11] = uint8(m.ageAdd) + + x[12] = uint8(ticketLen >> 8) + x[13] = uint8(ticketLen) + copy(x[14:], m.ticket) + + if m.withEarlyDataInfo { + z := x[14+ticketLen:] + z[1] = 8 + z[2] = uint8(extensionTicketEarlyDataInfo >> 8) + z[3] = uint8(extensionTicketEarlyDataInfo) + z[5] = 4 + z[6] = uint8(m.maxEarlyDataLength >> 24) + z[7] = uint8(m.maxEarlyDataLength >> 16) + z[8] = uint8(m.maxEarlyDataLength >> 8) + z[9] = uint8(m.maxEarlyDataLength) + } + + m.raw = x + + return +} + +func (m *newSessionTicketMsg13) unmarshal(data []byte) bool { + m.raw = data + m.maxEarlyDataLength = 0 + m.withEarlyDataInfo = false + + if len(data) < 16 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + m.lifetime = uint32(data[4])<<24 | uint32(data[5])<<16 | + uint32(data[6])<<8 | uint32(data[7]) + m.ageAdd = uint32(data[8])<<24 | uint32(data[9])<<16 | + uint32(data[10])<<8 | uint32(data[11]) + + ticketLen := int(data[12])<<8 + int(data[13]) + if 14+ticketLen > len(data) { + return false + } + m.ticket = data[14 : 14+ticketLen] + + data = data[14+ticketLen:] + extLen := int(data[0])<<8 + int(data[1]) + if extLen != len(data)-2 { + return false + } + + data = data[2:] + for len(data) > 0 { + if len(data) < 4 { + return false + } + extType := uint16(data[0])<<8 + uint16(data[1]) + length := int(data[2])<<8 + int(data[3]) + data = data[4:] + + switch extType { + case extensionTicketEarlyDataInfo: + if length != 4 { + return false + } + m.withEarlyDataInfo = true + m.maxEarlyDataLength = uint32(data[0])<<24 | uint32(data[1])<<16 | + uint32(data[2])<<8 | uint32(data[3]) + } + data = data[length:] + } + + return true +} + type helloRequestMsg struct { } diff --git a/handshake_messages_test.go b/handshake_messages_test.go index db89801..1dd0e3b 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -29,6 +29,8 @@ var tests = []interface{}{ &serverHelloMsg13{}, &encryptedExtensionsMsg{}, &certificateMsg13{}, + &newSessionTicketMsg13{}, + &sessionState13{}, } type testMessage interface { @@ -210,6 +212,10 @@ func (*serverHelloMsg13) Generate(rand *rand.Rand, size int) reflect.Value { m.cipherSuite = uint16(rand.Int31()) m.keyShare.group = CurveID(rand.Intn(30000)) m.keyShare.data = randomBytes(rand.Intn(300), rand) + if rand.Intn(10) > 5 { + m.psk = true + m.pskIdentity = uint16(rand.Int31()) + } return reflect.ValueOf(m) } @@ -294,6 +300,18 @@ func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { return reflect.ValueOf(m) } +func (*newSessionTicketMsg13) Generate(rand *rand.Rand, size int) reflect.Value { + m := &newSessionTicketMsg13{} + m.ageAdd = uint32(rand.Intn(0xffffffff)) + m.lifetime = uint32(rand.Intn(0xffffffff)) + m.ticket = randomBytes(rand.Intn(40), rand) + if rand.Intn(10) > 5 { + m.withEarlyDataInfo = true + m.maxEarlyDataLength = uint32(rand.Intn(0xffffffff)) + } + return reflect.ValueOf(m) +} + func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { s := &sessionState{} s.vers = uint16(rand.Intn(10000)) @@ -307,6 +325,16 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { return reflect.ValueOf(s) } +func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value { + s := &sessionState13{} + s.vers = uint16(rand.Intn(10000)) + s.hash = uint16(rand.Intn(10000)) + s.ageAdd = uint32(rand.Intn(0xffffffff)) + s.createdAt = uint64(rand.Int63n(0xfffffffffffffff)) + s.resumptionSecret = randomBytes(rand.Intn(100), rand) + return reflect.ValueOf(s) +} + func TestRejectEmptySCTList(t *testing.T) { // https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that // empty SCT lists are invalid. diff --git a/ticket.go b/ticket.go index 3e7aa93..c0d79d7 100644 --- a/ticket.go +++ b/ticket.go @@ -129,6 +129,70 @@ func (s *sessionState) unmarshal(data []byte) bool { return len(data) == 0 } +type sessionState13 struct { + vers uint16 + hash uint16 // crypto.Hash value + ageAdd uint32 + createdAt uint64 + resumptionSecret []byte +} + +func (s *sessionState13) equal(i interface{}) bool { + s1, ok := i.(*sessionState13) + if !ok { + return false + } + + return s.vers == s1.vers && + s.hash == s1.hash && + s.ageAdd == s1.ageAdd && + bytes.Equal(s.resumptionSecret, s1.resumptionSecret) +} + +func (s *sessionState13) marshal() []byte { + length := 2 + 2 + 4 + 8 + 2 + len(s.resumptionSecret) + + x := make([]byte, length) + x[0] = byte(s.vers >> 8) + x[1] = byte(s.vers) + x[2] = byte(s.hash >> 8) + x[3] = byte(s.hash) + x[4] = byte(s.ageAdd >> 24) + x[5] = byte(s.ageAdd >> 16) + x[6] = byte(s.ageAdd >> 8) + x[7] = byte(s.ageAdd) + x[8] = byte(s.createdAt >> 56) + x[9] = byte(s.createdAt >> 48) + x[10] = byte(s.createdAt >> 40) + x[11] = byte(s.createdAt >> 32) + x[12] = byte(s.createdAt >> 24) + x[13] = byte(s.createdAt >> 16) + x[14] = byte(s.createdAt >> 8) + x[15] = byte(s.createdAt) + x[16] = byte(len(s.resumptionSecret) >> 8) + x[17] = byte(len(s.resumptionSecret)) + + copy(x[18:], s.resumptionSecret) + + return x +} + +func (s *sessionState13) unmarshal(data []byte) bool { + if len(data) < 18 { + return false + } + + s.vers = uint16(data[0])<<8 | uint16(data[1]) + s.hash = uint16(data[2])<<8 | uint16(data[3]) + s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7]) + s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 | + uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15]) + l := uint16(data[16])<<8 | uint16(data[17]) + s.resumptionSecret = data[18:] + + return int(l) == len(s.resumptionSecret) +} + func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) { serialized := state.marshal() encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(serialized)+sha256.Size)