crypto/tls: implement TLS 1.3 PSK messages
This commit is contained in:
부모
6c3765bb15
커밋
453bd6af77
18
common.go
18
common.go
@ -84,7 +84,10 @@ const (
|
||||
extensionSCT uint16 = 18 // https://tools.ietf.org/html/rfc6962#section-6
|
||||
extensionSessionTicket uint16 = 35
|
||||
extensionKeyShare uint16 = 40
|
||||
extensionPreSharedKey uint16 = 41
|
||||
extensionSupportedVersions uint16 = 43
|
||||
extensionPSKKeyExchangeModes uint16 = 45
|
||||
extensionTicketEarlyDataInfo uint16 = 46
|
||||
extensionNextProtoNeg uint16 = 13172 // not IANA assigned
|
||||
extensionRenegotiationInfo uint16 = 0xff01
|
||||
)
|
||||
@ -94,6 +97,12 @@ const (
|
||||
scsvRenegotiation uint16 = 0x00ff
|
||||
)
|
||||
|
||||
// PSK Key Exchange Modes
|
||||
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.7
|
||||
const (
|
||||
pskDHEKeyExchange uint8 = 1
|
||||
)
|
||||
|
||||
// CurveID is the type of a TLS identifier for an elliptic curve. See
|
||||
// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8
|
||||
//
|
||||
@ -115,6 +124,15 @@ type keyShare struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
// TLS 1.3 PSK Identity and Binder, as sent by the client
|
||||
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6
|
||||
|
||||
type psk struct {
|
||||
identity []byte
|
||||
obfTicketAge uint32
|
||||
binder []byte
|
||||
}
|
||||
|
||||
// TLS Elliptic Curve Point Formats
|
||||
// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9
|
||||
const (
|
||||
|
4
conn.go
4
conn.go
@ -1032,7 +1032,11 @@ func (c *Conn) readHandshake() (interface{}, error) {
|
||||
case typeEncryptedExtensions:
|
||||
m = new(encryptedExtensionsMsg)
|
||||
case typeNewSessionTicket:
|
||||
if c.vers >= VersionTLS13 {
|
||||
m = new(newSessionTicketMsg13)
|
||||
} else {
|
||||
m = new(newSessionTicketMsg)
|
||||
}
|
||||
case typeCertificate:
|
||||
if c.vers >= VersionTLS13 {
|
||||
m = new(certificateMsg13)
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
|
||||
type clientHelloMsg struct {
|
||||
raw []byte
|
||||
rawTruncated []byte // for PSK binding
|
||||
vers uint16
|
||||
random []byte
|
||||
sessionId []byte
|
||||
@ -30,6 +31,8 @@ type clientHelloMsg struct {
|
||||
alpnProtocols []string
|
||||
keyShares []keyShare
|
||||
supportedVersions []uint16
|
||||
psks []psk
|
||||
pskKeyExchangeModes []uint8
|
||||
}
|
||||
|
||||
func (m *clientHelloMsg) equal(i interface{}) bool {
|
||||
@ -366,6 +369,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
|
||||
}
|
||||
m.sessionId = data[39 : 39+sessionIdLen]
|
||||
data = data[39+sessionIdLen:]
|
||||
bindersOffset := 39 + sessionIdLen
|
||||
if len(data) < 2 {
|
||||
return false
|
||||
}
|
||||
@ -384,6 +388,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
|
||||
}
|
||||
}
|
||||
data = data[2+cipherSuiteLen:]
|
||||
bindersOffset += 2 + cipherSuiteLen
|
||||
if len(data) < 1 {
|
||||
return false
|
||||
}
|
||||
@ -394,6 +399,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
|
||||
m.compressionMethods = data[1 : 1+compressionMethodsLen]
|
||||
|
||||
data = data[1+compressionMethodsLen:]
|
||||
bindersOffset += 1 + compressionMethodsLen
|
||||
|
||||
m.nextProtoNeg = false
|
||||
m.serverName = ""
|
||||
@ -405,6 +411,8 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
|
||||
m.scts = false
|
||||
m.keyShares = nil
|
||||
m.supportedVersions = nil
|
||||
m.psks = nil
|
||||
m.pskKeyExchangeModes = nil
|
||||
|
||||
if len(data) == 0 {
|
||||
// ClientHello is optionally followed by extension data
|
||||
@ -416,6 +424,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
|
||||
|
||||
extensionsLength := int(data[0])<<8 | int(data[1])
|
||||
data = data[2:]
|
||||
bindersOffset += 2
|
||||
if extensionsLength != len(data) {
|
||||
return false
|
||||
}
|
||||
@ -427,6 +436,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
|
||||
extension := uint16(data[0])<<8 | uint16(data[1])
|
||||
length := int(data[2])<<8 | int(data[3])
|
||||
data = data[4:]
|
||||
bindersOffset += 4
|
||||
if len(data) < length {
|
||||
return false
|
||||
}
|
||||
@ -597,8 +607,70 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
|
||||
m.supportedVersions = append(m.supportedVersions, v)
|
||||
d = d[2:]
|
||||
}
|
||||
case extensionPreSharedKey:
|
||||
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6
|
||||
if length < 43 {
|
||||
return false
|
||||
}
|
||||
li := int(data[0])<<8 | int(data[1])
|
||||
if li > length-37 {
|
||||
return false
|
||||
}
|
||||
d := data[2 : 2+li]
|
||||
bindersOffset += 2 + li
|
||||
for len(d) > 0 {
|
||||
if len(d) < 6 {
|
||||
return false
|
||||
}
|
||||
l := int(d[0])<<8 | int(d[1])
|
||||
if len(d) < 2+l+4 {
|
||||
return false
|
||||
}
|
||||
m.psks = append(m.psks, psk{
|
||||
identity: d[2 : 2+l],
|
||||
obfTicketAge: uint32(d[l+2])<<24 | uint32(d[l+3])<<16 |
|
||||
uint32(d[l+4])<<8 | uint32(d[l+5]),
|
||||
})
|
||||
d = d[2+l+4:]
|
||||
}
|
||||
lb := int(data[li+2])<<8 | int(data[li+3])
|
||||
d = data[2+li+2:]
|
||||
if lb != len(d) {
|
||||
return false
|
||||
}
|
||||
i := 0
|
||||
for len(d) > 0 {
|
||||
if len(d) < 1 {
|
||||
return false
|
||||
}
|
||||
l := int(d[0])
|
||||
if l > len(d)-1 {
|
||||
return false
|
||||
}
|
||||
m.psks[i].binder = d[1 : 1+l]
|
||||
d = d[1+l:]
|
||||
i++
|
||||
}
|
||||
if i != len(m.psks) {
|
||||
return false
|
||||
}
|
||||
// Ensure this extension is the last one in the Client Hello
|
||||
if len(data) != length {
|
||||
return false
|
||||
}
|
||||
m.rawTruncated = m.raw[:bindersOffset]
|
||||
case extensionPSKKeyExchangeModes:
|
||||
if length < 2 {
|
||||
return false
|
||||
}
|
||||
l := int(data[0])
|
||||
if length != l+1 {
|
||||
return false
|
||||
}
|
||||
m.pskKeyExchangeModes = data[1:length]
|
||||
}
|
||||
data = data[length:]
|
||||
bindersOffset += length
|
||||
}
|
||||
|
||||
return true
|
||||
@ -941,6 +1013,8 @@ type serverHelloMsg13 struct {
|
||||
random []byte
|
||||
cipherSuite uint16
|
||||
keyShare keyShare
|
||||
psk bool
|
||||
pskIdentity uint16
|
||||
}
|
||||
|
||||
func (m *serverHelloMsg13) equal(i interface{}) bool {
|
||||
@ -954,7 +1028,9 @@ func (m *serverHelloMsg13) equal(i interface{}) bool {
|
||||
bytes.Equal(m.random, m1.random) &&
|
||||
m.cipherSuite == m1.cipherSuite &&
|
||||
m.keyShare.group == m1.keyShare.group &&
|
||||
bytes.Equal(m.keyShare.data, m1.keyShare.data)
|
||||
bytes.Equal(m.keyShare.data, m1.keyShare.data) &&
|
||||
m.psk == m1.psk &&
|
||||
m.pskIdentity == m1.pskIdentity
|
||||
}
|
||||
|
||||
func (m *serverHelloMsg13) marshal() []byte {
|
||||
@ -966,6 +1042,9 @@ func (m *serverHelloMsg13) marshal() []byte {
|
||||
if m.keyShare.group != 0 {
|
||||
length += 8 + len(m.keyShare.data)
|
||||
}
|
||||
if m.psk {
|
||||
length += 6
|
||||
}
|
||||
|
||||
x := make([]byte, 4+length)
|
||||
x[0] = typeServerHello
|
||||
@ -982,6 +1061,15 @@ func (m *serverHelloMsg13) marshal() []byte {
|
||||
x[40] = uint8(len(z) >> 8)
|
||||
x[41] = uint8(len(z))
|
||||
|
||||
if m.psk {
|
||||
z[0] = byte(extensionPreSharedKey >> 8)
|
||||
z[1] = byte(extensionPreSharedKey)
|
||||
z[3] = 2
|
||||
z[4] = byte(m.pskIdentity >> 8)
|
||||
z[5] = byte(m.pskIdentity)
|
||||
z = z[6:]
|
||||
}
|
||||
|
||||
if m.keyShare.group != 0 {
|
||||
z[0] = uint8(extensionKeyShare >> 8)
|
||||
z[1] = uint8(extensionKeyShare)
|
||||
@ -1008,6 +1096,8 @@ func (m *serverHelloMsg13) unmarshal(data []byte) bool {
|
||||
m.vers = uint16(data[4])<<8 | uint16(data[5])
|
||||
m.random = data[6:38]
|
||||
m.cipherSuite = uint16(data[38])<<8 | uint16(data[39])
|
||||
m.psk = false
|
||||
m.pskIdentity = 0
|
||||
|
||||
extensionsLength := int(data[40])<<8 | int(data[41])
|
||||
data = data[42:]
|
||||
@ -1029,6 +1119,12 @@ func (m *serverHelloMsg13) unmarshal(data []byte) bool {
|
||||
switch extension {
|
||||
default:
|
||||
return false
|
||||
case extensionPreSharedKey:
|
||||
if length != 2 {
|
||||
return false
|
||||
}
|
||||
m.psk = true
|
||||
m.pskIdentity = uint16(data[0])<<8 | uint16(data[1])
|
||||
case extensionKeyShare:
|
||||
if length < 2 {
|
||||
return false
|
||||
@ -1925,6 +2021,131 @@ func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type newSessionTicketMsg13 struct {
|
||||
raw []byte
|
||||
lifetime uint32
|
||||
ageAdd uint32
|
||||
ticket []byte
|
||||
withEarlyDataInfo bool
|
||||
maxEarlyDataLength uint32
|
||||
}
|
||||
|
||||
func (m *newSessionTicketMsg13) equal(i interface{}) bool {
|
||||
m1, ok := i.(*newSessionTicketMsg13)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(m.raw, m1.raw) &&
|
||||
m.lifetime == m1.lifetime &&
|
||||
m.ageAdd == m1.ageAdd &&
|
||||
bytes.Equal(m.ticket, m1.ticket) &&
|
||||
m.withEarlyDataInfo == m1.withEarlyDataInfo &&
|
||||
m.maxEarlyDataLength == m1.maxEarlyDataLength
|
||||
}
|
||||
|
||||
func (m *newSessionTicketMsg13) marshal() (x []byte) {
|
||||
if m.raw != nil {
|
||||
return m.raw
|
||||
}
|
||||
|
||||
// See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6
|
||||
ticketLen := len(m.ticket)
|
||||
length := 12 + ticketLen
|
||||
if m.withEarlyDataInfo {
|
||||
length += 8
|
||||
}
|
||||
x = make([]byte, 4+length)
|
||||
x[0] = typeNewSessionTicket
|
||||
x[1] = uint8(length >> 16)
|
||||
x[2] = uint8(length >> 8)
|
||||
x[3] = uint8(length)
|
||||
|
||||
x[4] = uint8(m.lifetime >> 24)
|
||||
x[5] = uint8(m.lifetime >> 16)
|
||||
x[6] = uint8(m.lifetime >> 8)
|
||||
x[7] = uint8(m.lifetime)
|
||||
x[8] = uint8(m.ageAdd >> 24)
|
||||
x[9] = uint8(m.ageAdd >> 16)
|
||||
x[10] = uint8(m.ageAdd >> 8)
|
||||
x[11] = uint8(m.ageAdd)
|
||||
|
||||
x[12] = uint8(ticketLen >> 8)
|
||||
x[13] = uint8(ticketLen)
|
||||
copy(x[14:], m.ticket)
|
||||
|
||||
if m.withEarlyDataInfo {
|
||||
z := x[14+ticketLen:]
|
||||
z[1] = 8
|
||||
z[2] = uint8(extensionTicketEarlyDataInfo >> 8)
|
||||
z[3] = uint8(extensionTicketEarlyDataInfo)
|
||||
z[5] = 4
|
||||
z[6] = uint8(m.maxEarlyDataLength >> 24)
|
||||
z[7] = uint8(m.maxEarlyDataLength >> 16)
|
||||
z[8] = uint8(m.maxEarlyDataLength >> 8)
|
||||
z[9] = uint8(m.maxEarlyDataLength)
|
||||
}
|
||||
|
||||
m.raw = x
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (m *newSessionTicketMsg13) unmarshal(data []byte) bool {
|
||||
m.raw = data
|
||||
m.maxEarlyDataLength = 0
|
||||
m.withEarlyDataInfo = false
|
||||
|
||||
if len(data) < 16 {
|
||||
return false
|
||||
}
|
||||
|
||||
length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
|
||||
if uint32(len(data))-4 != length {
|
||||
return false
|
||||
}
|
||||
|
||||
m.lifetime = uint32(data[4])<<24 | uint32(data[5])<<16 |
|
||||
uint32(data[6])<<8 | uint32(data[7])
|
||||
m.ageAdd = uint32(data[8])<<24 | uint32(data[9])<<16 |
|
||||
uint32(data[10])<<8 | uint32(data[11])
|
||||
|
||||
ticketLen := int(data[12])<<8 + int(data[13])
|
||||
if 14+ticketLen > len(data) {
|
||||
return false
|
||||
}
|
||||
m.ticket = data[14 : 14+ticketLen]
|
||||
|
||||
data = data[14+ticketLen:]
|
||||
extLen := int(data[0])<<8 + int(data[1])
|
||||
if extLen != len(data)-2 {
|
||||
return false
|
||||
}
|
||||
|
||||
data = data[2:]
|
||||
for len(data) > 0 {
|
||||
if len(data) < 4 {
|
||||
return false
|
||||
}
|
||||
extType := uint16(data[0])<<8 + uint16(data[1])
|
||||
length := int(data[2])<<8 + int(data[3])
|
||||
data = data[4:]
|
||||
|
||||
switch extType {
|
||||
case extensionTicketEarlyDataInfo:
|
||||
if length != 4 {
|
||||
return false
|
||||
}
|
||||
m.withEarlyDataInfo = true
|
||||
m.maxEarlyDataLength = uint32(data[0])<<24 | uint32(data[1])<<16 |
|
||||
uint32(data[2])<<8 | uint32(data[3])
|
||||
}
|
||||
data = data[length:]
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
type helloRequestMsg struct {
|
||||
}
|
||||
|
||||
|
@ -29,6 +29,8 @@ var tests = []interface{}{
|
||||
&serverHelloMsg13{},
|
||||
&encryptedExtensionsMsg{},
|
||||
&certificateMsg13{},
|
||||
&newSessionTicketMsg13{},
|
||||
&sessionState13{},
|
||||
}
|
||||
|
||||
type testMessage interface {
|
||||
@ -210,6 +212,10 @@ func (*serverHelloMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
m.cipherSuite = uint16(rand.Int31())
|
||||
m.keyShare.group = CurveID(rand.Intn(30000))
|
||||
m.keyShare.data = randomBytes(rand.Intn(300), rand)
|
||||
if rand.Intn(10) > 5 {
|
||||
m.psk = true
|
||||
m.pskIdentity = uint16(rand.Int31())
|
||||
}
|
||||
|
||||
return reflect.ValueOf(m)
|
||||
}
|
||||
@ -294,6 +300,18 @@ func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
return reflect.ValueOf(m)
|
||||
}
|
||||
|
||||
func (*newSessionTicketMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
m := &newSessionTicketMsg13{}
|
||||
m.ageAdd = uint32(rand.Intn(0xffffffff))
|
||||
m.lifetime = uint32(rand.Intn(0xffffffff))
|
||||
m.ticket = randomBytes(rand.Intn(40), rand)
|
||||
if rand.Intn(10) > 5 {
|
||||
m.withEarlyDataInfo = true
|
||||
m.maxEarlyDataLength = uint32(rand.Intn(0xffffffff))
|
||||
}
|
||||
return reflect.ValueOf(m)
|
||||
}
|
||||
|
||||
func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
s := &sessionState{}
|
||||
s.vers = uint16(rand.Intn(10000))
|
||||
@ -307,6 +325,16 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
return reflect.ValueOf(s)
|
||||
}
|
||||
|
||||
func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
s := &sessionState13{}
|
||||
s.vers = uint16(rand.Intn(10000))
|
||||
s.hash = uint16(rand.Intn(10000))
|
||||
s.ageAdd = uint32(rand.Intn(0xffffffff))
|
||||
s.createdAt = uint64(rand.Int63n(0xfffffffffffffff))
|
||||
s.resumptionSecret = randomBytes(rand.Intn(100), rand)
|
||||
return reflect.ValueOf(s)
|
||||
}
|
||||
|
||||
func TestRejectEmptySCTList(t *testing.T) {
|
||||
// https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that
|
||||
// empty SCT lists are invalid.
|
||||
|
64
ticket.go
64
ticket.go
@ -129,6 +129,70 @@ func (s *sessionState) unmarshal(data []byte) bool {
|
||||
return len(data) == 0
|
||||
}
|
||||
|
||||
type sessionState13 struct {
|
||||
vers uint16
|
||||
hash uint16 // crypto.Hash value
|
||||
ageAdd uint32
|
||||
createdAt uint64
|
||||
resumptionSecret []byte
|
||||
}
|
||||
|
||||
func (s *sessionState13) equal(i interface{}) bool {
|
||||
s1, ok := i.(*sessionState13)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return s.vers == s1.vers &&
|
||||
s.hash == s1.hash &&
|
||||
s.ageAdd == s1.ageAdd &&
|
||||
bytes.Equal(s.resumptionSecret, s1.resumptionSecret)
|
||||
}
|
||||
|
||||
func (s *sessionState13) marshal() []byte {
|
||||
length := 2 + 2 + 4 + 8 + 2 + len(s.resumptionSecret)
|
||||
|
||||
x := make([]byte, length)
|
||||
x[0] = byte(s.vers >> 8)
|
||||
x[1] = byte(s.vers)
|
||||
x[2] = byte(s.hash >> 8)
|
||||
x[3] = byte(s.hash)
|
||||
x[4] = byte(s.ageAdd >> 24)
|
||||
x[5] = byte(s.ageAdd >> 16)
|
||||
x[6] = byte(s.ageAdd >> 8)
|
||||
x[7] = byte(s.ageAdd)
|
||||
x[8] = byte(s.createdAt >> 56)
|
||||
x[9] = byte(s.createdAt >> 48)
|
||||
x[10] = byte(s.createdAt >> 40)
|
||||
x[11] = byte(s.createdAt >> 32)
|
||||
x[12] = byte(s.createdAt >> 24)
|
||||
x[13] = byte(s.createdAt >> 16)
|
||||
x[14] = byte(s.createdAt >> 8)
|
||||
x[15] = byte(s.createdAt)
|
||||
x[16] = byte(len(s.resumptionSecret) >> 8)
|
||||
x[17] = byte(len(s.resumptionSecret))
|
||||
|
||||
copy(x[18:], s.resumptionSecret)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
func (s *sessionState13) unmarshal(data []byte) bool {
|
||||
if len(data) < 18 {
|
||||
return false
|
||||
}
|
||||
|
||||
s.vers = uint16(data[0])<<8 | uint16(data[1])
|
||||
s.hash = uint16(data[2])<<8 | uint16(data[3])
|
||||
s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
|
||||
s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 |
|
||||
uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15])
|
||||
l := uint16(data[16])<<8 | uint16(data[17])
|
||||
s.resumptionSecret = data[18:]
|
||||
|
||||
return int(l) == len(s.resumptionSecret)
|
||||
}
|
||||
|
||||
func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) {
|
||||
serialized := state.marshal()
|
||||
encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(serialized)+sha256.Size)
|
||||
|
불러오는 중...
Reference in New Issue
Block a user