diff --git a/13.go b/13.go index d8490bb..aacb821 100644 --- a/13.go +++ b/13.go @@ -434,11 +434,14 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) { // This enforces the stricter 0-RTT requirements on all ticket uses. // The benefit of using PSK+ECDHE without 0-RTT are small enough that - // we can give them up in the edge case of changed suite or ALPN. + // we can give them up in the edge case of changed suite or ALPN or SNI. if s.suite != hs.suite.id { continue } - if s.alpnProtocol != hs.hello13Enc.alpnProtocol { + if s.alpnProtocol != hs.c.clientProtocol { + continue + } + if s.SNI != hs.c.serverName { continue } @@ -451,11 +454,20 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) { expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey) if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 { + if i == 0 && hs.clientHello.earlyData { + // This is a ticket intended to be used for 0-RTT + if s.maxEarlyDataLen == 0 { + // But we had not tagged it as such. We could close the connection + // here, but instead we just ignore the ticket and the 0-RTT data. + continue + } + if hs.c.config.Accept0RTTData { + hs.c.ticketMaxEarlyData = int64(s.maxEarlyDataLen) + hs.hello13Enc.earlyData = true + } + } hs.hello13.psk = true hs.hello13.pskIdentity = uint16(i) - if i == 0 && hs.clientHello.earlyData && hs.c.config.Accept0RTTData { - hs.hello13Enc.earlyData = true - } return earlySecret, true } } @@ -496,6 +508,9 @@ func (hs *serverHandshakeState) sendSessionTicket13() error { uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3]), createdAt: uint64(time.Now().Unix()), resumptionSecret: resumptionSecret, + alpnProtocol: c.clientProtocol, + SNI: c.serverName, + maxEarlyDataLen: c.config.Max0RTTDataSize, } ticket, err := c.encryptTicket(sessionState.marshal()) diff --git a/conn.go b/conn.go index 6afe10c..9223f1a 100644 --- a/conn.go +++ b/conn.go @@ -83,6 +83,10 @@ type Conn struct { clientProtocol string clientProtocolFallback bool + // ticketMaxEarlyData is the maximum bytes of 0-RTT application data + // that the client is allowed to send on the ticket it used. + ticketMaxEarlyData int64 + // input/output in, out halfConn // in.Mutex < out.Mutex rawInput *block // raw input, right off the wire @@ -106,6 +110,8 @@ type Conn struct { // earlyDataBytes is the number of bytes of early data received so // far. Tracked to enforce max_early_data_size. + // We don't keep track of rejected 0-RTT data since there's no need + // to ever buffer it. in.Mutex. earlyDataBytes int64 tmp [16]byte @@ -778,6 +784,12 @@ func (c *Conn) readRecord(want recordType) error { c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) break } + if c.phase == readingEarlyData { + c.earlyDataBytes += int64(len(b.data) - b.off) + if c.earlyDataBytes > c.ticketMaxEarlyData { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + } c.input = b b = nil diff --git a/handshake_messages_test.go b/handshake_messages_test.go index c0d8ac3..4835960 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -338,9 +338,11 @@ func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value { s.vers = uint16(rand.Intn(10000)) s.suite = uint16(rand.Intn(10000)) s.ageAdd = uint32(rand.Intn(0xffffffff)) + s.maxEarlyDataLen = uint32(rand.Intn(0xffffffff)) s.createdAt = uint64(rand.Int63n(0xfffffffffffffff)) s.resumptionSecret = randomBytes(rand.Intn(100), rand) s.alpnProtocol = randomString(rand.Intn(100), rand) + s.SNI = randomString(rand.Intn(100), rand) return reflect.ValueOf(s) } diff --git a/ticket.go b/ticket.go index df4c219..c2bd628 100644 --- a/ticket.go +++ b/ticket.go @@ -134,10 +134,10 @@ type sessionState13 struct { suite uint16 ageAdd uint32 createdAt uint64 + maxEarlyDataLen uint32 resumptionSecret []byte alpnProtocol string - // TODO(filippo): add and check SNI - // TODO(filippo): add and check maxEarlyDataLength + SNI string } func (s *sessionState13) equal(i interface{}) bool { @@ -148,13 +148,16 @@ func (s *sessionState13) equal(i interface{}) bool { return s.vers == s1.vers && s.suite == s1.suite && - s.alpnProtocol == s1.alpnProtocol && s.ageAdd == s1.ageAdd && - bytes.Equal(s.resumptionSecret, s1.resumptionSecret) + s.createdAt == s1.createdAt && + s.maxEarlyDataLen == s1.maxEarlyDataLen && + bytes.Equal(s.resumptionSecret, s1.resumptionSecret) && + s.alpnProtocol == s1.alpnProtocol && + s.SNI == s1.SNI } func (s *sessionState13) marshal() []byte { - length := 2 + 2 + 4 + 8 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol) + length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI) x := make([]byte, length) x[0] = byte(s.vers >> 8) @@ -173,19 +176,27 @@ func (s *sessionState13) marshal() []byte { x[13] = byte(s.createdAt >> 16) x[14] = byte(s.createdAt >> 8) x[15] = byte(s.createdAt) - x[16] = byte(len(s.resumptionSecret) >> 8) - x[17] = byte(len(s.resumptionSecret)) - copy(x[18:], s.resumptionSecret) - z := x[18+len(s.resumptionSecret):] + x[16] = byte(s.maxEarlyDataLen >> 24) + x[17] = byte(s.maxEarlyDataLen >> 16) + x[18] = byte(s.maxEarlyDataLen >> 8) + x[19] = byte(s.maxEarlyDataLen) + x[20] = byte(len(s.resumptionSecret) >> 8) + x[21] = byte(len(s.resumptionSecret)) + copy(x[22:], s.resumptionSecret) + z := x[22+len(s.resumptionSecret):] z[0] = byte(len(s.alpnProtocol) >> 8) z[1] = byte(len(s.alpnProtocol)) copy(z[2:], s.alpnProtocol) + z = z[2+len(s.alpnProtocol):] + z[0] = byte(len(s.SNI) >> 8) + z[1] = byte(len(s.SNI)) + copy(z[2:], s.SNI) return x } func (s *sessionState13) unmarshal(data []byte) bool { - if len(data) < 18 { + if len(data) < 24 { return false } @@ -194,15 +205,29 @@ func (s *sessionState13) unmarshal(data []byte) bool { s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7]) s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 | uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15]) - l := int(data[16])<<8 | int(data[17]) - if len(data) < 18+l+2 { + s.maxEarlyDataLen = uint32(data[16])<<24 | uint32(data[17])<<16 | uint32(data[18])<<8 | uint32(data[19]) + + l := int(data[20])<<8 | int(data[21]) + if len(data) < 22+l+2 { return false } - s.resumptionSecret = data[18 : 18+l] - z := data[18+l:] + s.resumptionSecret = data[22 : 22+l] + z := data[22+l:] + l = int(z[0])<<8 | int(z[1]) - s.alpnProtocol = string(z[2:]) - return l == len(s.alpnProtocol) + if len(z) < 2+l+2 { + return false + } + s.alpnProtocol = string(z[2 : 2+l]) + z = z[2+l:] + + l = int(z[0])<<8 | int(z[1]) + if len(z) != 2+l { + return false + } + s.SNI = string(z[2 : 2+l]) + + return true } func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) {