crypto/tls: finish the session ticket state checks

This commit is contained in:
Filippo Valsorda 2016-11-30 00:11:10 +00:00 committed by Peter Wu
parent 6ca044cede
commit 180bfdbd68
4 changed files with 75 additions and 21 deletions

25
13.go
View File

@ -434,11 +434,14 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
// This enforces the stricter 0-RTT requirements on all ticket uses.
// The benefit of using PSK+ECDHE without 0-RTT are small enough that
// we can give them up in the edge case of changed suite or ALPN.
// we can give them up in the edge case of changed suite or ALPN or SNI.
if s.suite != hs.suite.id {
continue
}
if s.alpnProtocol != hs.hello13Enc.alpnProtocol {
if s.alpnProtocol != hs.c.clientProtocol {
continue
}
if s.SNI != hs.c.serverName {
continue
}
@ -451,11 +454,20 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 {
if i == 0 && hs.clientHello.earlyData {
// This is a ticket intended to be used for 0-RTT
if s.maxEarlyDataLen == 0 {
// But we had not tagged it as such. We could close the connection
// here, but instead we just ignore the ticket and the 0-RTT data.
continue
}
if hs.c.config.Accept0RTTData {
hs.c.ticketMaxEarlyData = int64(s.maxEarlyDataLen)
hs.hello13Enc.earlyData = true
}
}
hs.hello13.psk = true
hs.hello13.pskIdentity = uint16(i)
if i == 0 && hs.clientHello.earlyData && hs.c.config.Accept0RTTData {
hs.hello13Enc.earlyData = true
}
return earlySecret, true
}
}
@ -496,6 +508,9 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3]),
createdAt: uint64(time.Now().Unix()),
resumptionSecret: resumptionSecret,
alpnProtocol: c.clientProtocol,
SNI: c.serverName,
maxEarlyDataLen: c.config.Max0RTTDataSize,
}
ticket, err := c.encryptTicket(sessionState.marshal())

12
conn.go
View File

@ -83,6 +83,10 @@ type Conn struct {
clientProtocol string
clientProtocolFallback bool
// ticketMaxEarlyData is the maximum bytes of 0-RTT application data
// that the client is allowed to send on the ticket it used.
ticketMaxEarlyData int64
// input/output
in, out halfConn // in.Mutex < out.Mutex
rawInput *block // raw input, right off the wire
@ -106,6 +110,8 @@ type Conn struct {
// earlyDataBytes is the number of bytes of early data received so
// far. Tracked to enforce max_early_data_size.
// We don't keep track of rejected 0-RTT data since there's no need
// to ever buffer it. in.Mutex.
earlyDataBytes int64
tmp [16]byte
@ -778,6 +784,12 @@ func (c *Conn) readRecord(want recordType) error {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
if c.phase == readingEarlyData {
c.earlyDataBytes += int64(len(b.data) - b.off)
if c.earlyDataBytes > c.ticketMaxEarlyData {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
}
c.input = b
b = nil

View File

@ -338,9 +338,11 @@ func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value {
s.vers = uint16(rand.Intn(10000))
s.suite = uint16(rand.Intn(10000))
s.ageAdd = uint32(rand.Intn(0xffffffff))
s.maxEarlyDataLen = uint32(rand.Intn(0xffffffff))
s.createdAt = uint64(rand.Int63n(0xfffffffffffffff))
s.resumptionSecret = randomBytes(rand.Intn(100), rand)
s.alpnProtocol = randomString(rand.Intn(100), rand)
s.SNI = randomString(rand.Intn(100), rand)
return reflect.ValueOf(s)
}

View File

@ -134,10 +134,10 @@ type sessionState13 struct {
suite uint16
ageAdd uint32
createdAt uint64
maxEarlyDataLen uint32
resumptionSecret []byte
alpnProtocol string
// TODO(filippo): add and check SNI
// TODO(filippo): add and check maxEarlyDataLength
SNI string
}
func (s *sessionState13) equal(i interface{}) bool {
@ -148,13 +148,16 @@ func (s *sessionState13) equal(i interface{}) bool {
return s.vers == s1.vers &&
s.suite == s1.suite &&
s.alpnProtocol == s1.alpnProtocol &&
s.ageAdd == s1.ageAdd &&
bytes.Equal(s.resumptionSecret, s1.resumptionSecret)
s.createdAt == s1.createdAt &&
s.maxEarlyDataLen == s1.maxEarlyDataLen &&
bytes.Equal(s.resumptionSecret, s1.resumptionSecret) &&
s.alpnProtocol == s1.alpnProtocol &&
s.SNI == s1.SNI
}
func (s *sessionState13) marshal() []byte {
length := 2 + 2 + 4 + 8 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol)
length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI)
x := make([]byte, length)
x[0] = byte(s.vers >> 8)
@ -173,19 +176,27 @@ func (s *sessionState13) marshal() []byte {
x[13] = byte(s.createdAt >> 16)
x[14] = byte(s.createdAt >> 8)
x[15] = byte(s.createdAt)
x[16] = byte(len(s.resumptionSecret) >> 8)
x[17] = byte(len(s.resumptionSecret))
copy(x[18:], s.resumptionSecret)
z := x[18+len(s.resumptionSecret):]
x[16] = byte(s.maxEarlyDataLen >> 24)
x[17] = byte(s.maxEarlyDataLen >> 16)
x[18] = byte(s.maxEarlyDataLen >> 8)
x[19] = byte(s.maxEarlyDataLen)
x[20] = byte(len(s.resumptionSecret) >> 8)
x[21] = byte(len(s.resumptionSecret))
copy(x[22:], s.resumptionSecret)
z := x[22+len(s.resumptionSecret):]
z[0] = byte(len(s.alpnProtocol) >> 8)
z[1] = byte(len(s.alpnProtocol))
copy(z[2:], s.alpnProtocol)
z = z[2+len(s.alpnProtocol):]
z[0] = byte(len(s.SNI) >> 8)
z[1] = byte(len(s.SNI))
copy(z[2:], s.SNI)
return x
}
func (s *sessionState13) unmarshal(data []byte) bool {
if len(data) < 18 {
if len(data) < 24 {
return false
}
@ -194,15 +205,29 @@ func (s *sessionState13) unmarshal(data []byte) bool {
s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 |
uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15])
l := int(data[16])<<8 | int(data[17])
if len(data) < 18+l+2 {
s.maxEarlyDataLen = uint32(data[16])<<24 | uint32(data[17])<<16 | uint32(data[18])<<8 | uint32(data[19])
l := int(data[20])<<8 | int(data[21])
if len(data) < 22+l+2 {
return false
}
s.resumptionSecret = data[18 : 18+l]
z := data[18+l:]
s.resumptionSecret = data[22 : 22+l]
z := data[22+l:]
l = int(z[0])<<8 | int(z[1])
s.alpnProtocol = string(z[2:])
return l == len(s.alpnProtocol)
if len(z) < 2+l+2 {
return false
}
s.alpnProtocol = string(z[2 : 2+l])
z = z[2+l:]
l = int(z[0])<<8 | int(z[1])
if len(z) != 2+l {
return false
}
s.SNI = string(z[2 : 2+l])
return true
}
func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) {