crypto/tls: finish the session ticket state checks
This commit is contained in:
parent
6ca044cede
commit
180bfdbd68
25
13.go
25
13.go
@ -434,11 +434,14 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
|
|||||||
|
|
||||||
// This enforces the stricter 0-RTT requirements on all ticket uses.
|
// This enforces the stricter 0-RTT requirements on all ticket uses.
|
||||||
// The benefit of using PSK+ECDHE without 0-RTT are small enough that
|
// The benefit of using PSK+ECDHE without 0-RTT are small enough that
|
||||||
// we can give them up in the edge case of changed suite or ALPN.
|
// we can give them up in the edge case of changed suite or ALPN or SNI.
|
||||||
if s.suite != hs.suite.id {
|
if s.suite != hs.suite.id {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if s.alpnProtocol != hs.hello13Enc.alpnProtocol {
|
if s.alpnProtocol != hs.c.clientProtocol {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if s.SNI != hs.c.serverName {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -451,11 +454,20 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
|
|||||||
expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)
|
expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)
|
||||||
|
|
||||||
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 {
|
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 {
|
||||||
hs.hello13.psk = true
|
if i == 0 && hs.clientHello.earlyData {
|
||||||
hs.hello13.pskIdentity = uint16(i)
|
// This is a ticket intended to be used for 0-RTT
|
||||||
if i == 0 && hs.clientHello.earlyData && hs.c.config.Accept0RTTData {
|
if s.maxEarlyDataLen == 0 {
|
||||||
|
// But we had not tagged it as such. We could close the connection
|
||||||
|
// here, but instead we just ignore the ticket and the 0-RTT data.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if hs.c.config.Accept0RTTData {
|
||||||
|
hs.c.ticketMaxEarlyData = int64(s.maxEarlyDataLen)
|
||||||
hs.hello13Enc.earlyData = true
|
hs.hello13Enc.earlyData = true
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
hs.hello13.psk = true
|
||||||
|
hs.hello13.pskIdentity = uint16(i)
|
||||||
return earlySecret, true
|
return earlySecret, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -496,6 +508,9 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
|
|||||||
uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3]),
|
uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3]),
|
||||||
createdAt: uint64(time.Now().Unix()),
|
createdAt: uint64(time.Now().Unix()),
|
||||||
resumptionSecret: resumptionSecret,
|
resumptionSecret: resumptionSecret,
|
||||||
|
alpnProtocol: c.clientProtocol,
|
||||||
|
SNI: c.serverName,
|
||||||
|
maxEarlyDataLen: c.config.Max0RTTDataSize,
|
||||||
}
|
}
|
||||||
|
|
||||||
ticket, err := c.encryptTicket(sessionState.marshal())
|
ticket, err := c.encryptTicket(sessionState.marshal())
|
||||||
|
12
conn.go
12
conn.go
@ -83,6 +83,10 @@ type Conn struct {
|
|||||||
clientProtocol string
|
clientProtocol string
|
||||||
clientProtocolFallback bool
|
clientProtocolFallback bool
|
||||||
|
|
||||||
|
// ticketMaxEarlyData is the maximum bytes of 0-RTT application data
|
||||||
|
// that the client is allowed to send on the ticket it used.
|
||||||
|
ticketMaxEarlyData int64
|
||||||
|
|
||||||
// input/output
|
// input/output
|
||||||
in, out halfConn // in.Mutex < out.Mutex
|
in, out halfConn // in.Mutex < out.Mutex
|
||||||
rawInput *block // raw input, right off the wire
|
rawInput *block // raw input, right off the wire
|
||||||
@ -106,6 +110,8 @@ type Conn struct {
|
|||||||
|
|
||||||
// earlyDataBytes is the number of bytes of early data received so
|
// earlyDataBytes is the number of bytes of early data received so
|
||||||
// far. Tracked to enforce max_early_data_size.
|
// far. Tracked to enforce max_early_data_size.
|
||||||
|
// We don't keep track of rejected 0-RTT data since there's no need
|
||||||
|
// to ever buffer it. in.Mutex.
|
||||||
earlyDataBytes int64
|
earlyDataBytes int64
|
||||||
|
|
||||||
tmp [16]byte
|
tmp [16]byte
|
||||||
@ -778,6 +784,12 @@ func (c *Conn) readRecord(want recordType) error {
|
|||||||
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
if c.phase == readingEarlyData {
|
||||||
|
c.earlyDataBytes += int64(len(b.data) - b.off)
|
||||||
|
if c.earlyDataBytes > c.ticketMaxEarlyData {
|
||||||
|
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||||
|
}
|
||||||
|
}
|
||||||
c.input = b
|
c.input = b
|
||||||
b = nil
|
b = nil
|
||||||
|
|
||||||
|
@ -338,9 +338,11 @@ func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value {
|
|||||||
s.vers = uint16(rand.Intn(10000))
|
s.vers = uint16(rand.Intn(10000))
|
||||||
s.suite = uint16(rand.Intn(10000))
|
s.suite = uint16(rand.Intn(10000))
|
||||||
s.ageAdd = uint32(rand.Intn(0xffffffff))
|
s.ageAdd = uint32(rand.Intn(0xffffffff))
|
||||||
|
s.maxEarlyDataLen = uint32(rand.Intn(0xffffffff))
|
||||||
s.createdAt = uint64(rand.Int63n(0xfffffffffffffff))
|
s.createdAt = uint64(rand.Int63n(0xfffffffffffffff))
|
||||||
s.resumptionSecret = randomBytes(rand.Intn(100), rand)
|
s.resumptionSecret = randomBytes(rand.Intn(100), rand)
|
||||||
s.alpnProtocol = randomString(rand.Intn(100), rand)
|
s.alpnProtocol = randomString(rand.Intn(100), rand)
|
||||||
|
s.SNI = randomString(rand.Intn(100), rand)
|
||||||
return reflect.ValueOf(s)
|
return reflect.ValueOf(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
57
ticket.go
57
ticket.go
@ -134,10 +134,10 @@ type sessionState13 struct {
|
|||||||
suite uint16
|
suite uint16
|
||||||
ageAdd uint32
|
ageAdd uint32
|
||||||
createdAt uint64
|
createdAt uint64
|
||||||
|
maxEarlyDataLen uint32
|
||||||
resumptionSecret []byte
|
resumptionSecret []byte
|
||||||
alpnProtocol string
|
alpnProtocol string
|
||||||
// TODO(filippo): add and check SNI
|
SNI string
|
||||||
// TODO(filippo): add and check maxEarlyDataLength
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sessionState13) equal(i interface{}) bool {
|
func (s *sessionState13) equal(i interface{}) bool {
|
||||||
@ -148,13 +148,16 @@ func (s *sessionState13) equal(i interface{}) bool {
|
|||||||
|
|
||||||
return s.vers == s1.vers &&
|
return s.vers == s1.vers &&
|
||||||
s.suite == s1.suite &&
|
s.suite == s1.suite &&
|
||||||
s.alpnProtocol == s1.alpnProtocol &&
|
|
||||||
s.ageAdd == s1.ageAdd &&
|
s.ageAdd == s1.ageAdd &&
|
||||||
bytes.Equal(s.resumptionSecret, s1.resumptionSecret)
|
s.createdAt == s1.createdAt &&
|
||||||
|
s.maxEarlyDataLen == s1.maxEarlyDataLen &&
|
||||||
|
bytes.Equal(s.resumptionSecret, s1.resumptionSecret) &&
|
||||||
|
s.alpnProtocol == s1.alpnProtocol &&
|
||||||
|
s.SNI == s1.SNI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sessionState13) marshal() []byte {
|
func (s *sessionState13) marshal() []byte {
|
||||||
length := 2 + 2 + 4 + 8 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol)
|
length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI)
|
||||||
|
|
||||||
x := make([]byte, length)
|
x := make([]byte, length)
|
||||||
x[0] = byte(s.vers >> 8)
|
x[0] = byte(s.vers >> 8)
|
||||||
@ -173,19 +176,27 @@ func (s *sessionState13) marshal() []byte {
|
|||||||
x[13] = byte(s.createdAt >> 16)
|
x[13] = byte(s.createdAt >> 16)
|
||||||
x[14] = byte(s.createdAt >> 8)
|
x[14] = byte(s.createdAt >> 8)
|
||||||
x[15] = byte(s.createdAt)
|
x[15] = byte(s.createdAt)
|
||||||
x[16] = byte(len(s.resumptionSecret) >> 8)
|
x[16] = byte(s.maxEarlyDataLen >> 24)
|
||||||
x[17] = byte(len(s.resumptionSecret))
|
x[17] = byte(s.maxEarlyDataLen >> 16)
|
||||||
copy(x[18:], s.resumptionSecret)
|
x[18] = byte(s.maxEarlyDataLen >> 8)
|
||||||
z := x[18+len(s.resumptionSecret):]
|
x[19] = byte(s.maxEarlyDataLen)
|
||||||
|
x[20] = byte(len(s.resumptionSecret) >> 8)
|
||||||
|
x[21] = byte(len(s.resumptionSecret))
|
||||||
|
copy(x[22:], s.resumptionSecret)
|
||||||
|
z := x[22+len(s.resumptionSecret):]
|
||||||
z[0] = byte(len(s.alpnProtocol) >> 8)
|
z[0] = byte(len(s.alpnProtocol) >> 8)
|
||||||
z[1] = byte(len(s.alpnProtocol))
|
z[1] = byte(len(s.alpnProtocol))
|
||||||
copy(z[2:], s.alpnProtocol)
|
copy(z[2:], s.alpnProtocol)
|
||||||
|
z = z[2+len(s.alpnProtocol):]
|
||||||
|
z[0] = byte(len(s.SNI) >> 8)
|
||||||
|
z[1] = byte(len(s.SNI))
|
||||||
|
copy(z[2:], s.SNI)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sessionState13) unmarshal(data []byte) bool {
|
func (s *sessionState13) unmarshal(data []byte) bool {
|
||||||
if len(data) < 18 {
|
if len(data) < 24 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -194,15 +205,29 @@ func (s *sessionState13) unmarshal(data []byte) bool {
|
|||||||
s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
|
s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
|
||||||
s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 |
|
s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 |
|
||||||
uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15])
|
uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15])
|
||||||
l := int(data[16])<<8 | int(data[17])
|
s.maxEarlyDataLen = uint32(data[16])<<24 | uint32(data[17])<<16 | uint32(data[18])<<8 | uint32(data[19])
|
||||||
if len(data) < 18+l+2 {
|
|
||||||
|
l := int(data[20])<<8 | int(data[21])
|
||||||
|
if len(data) < 22+l+2 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
s.resumptionSecret = data[18 : 18+l]
|
s.resumptionSecret = data[22 : 22+l]
|
||||||
z := data[18+l:]
|
z := data[22+l:]
|
||||||
|
|
||||||
l = int(z[0])<<8 | int(z[1])
|
l = int(z[0])<<8 | int(z[1])
|
||||||
s.alpnProtocol = string(z[2:])
|
if len(z) < 2+l+2 {
|
||||||
return l == len(s.alpnProtocol)
|
return false
|
||||||
|
}
|
||||||
|
s.alpnProtocol = string(z[2 : 2+l])
|
||||||
|
z = z[2+l:]
|
||||||
|
|
||||||
|
l = int(z[0])<<8 | int(z[1])
|
||||||
|
if len(z) != 2+l {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s.SNI = string(z[2 : 2+l])
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) {
|
func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) {
|
||||||
|
Loading…
Reference in New Issue
Block a user