crypto/tls: finish the session ticket state checks
This commit is contained in:
vanhempi
6ca044cede
commit
180bfdbd68
25
13.go
25
13.go
@ -434,11 +434,14 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
|
||||
|
||||
// This enforces the stricter 0-RTT requirements on all ticket uses.
|
||||
// The benefit of using PSK+ECDHE without 0-RTT are small enough that
|
||||
// we can give them up in the edge case of changed suite or ALPN.
|
||||
// we can give them up in the edge case of changed suite or ALPN or SNI.
|
||||
if s.suite != hs.suite.id {
|
||||
continue
|
||||
}
|
||||
if s.alpnProtocol != hs.hello13Enc.alpnProtocol {
|
||||
if s.alpnProtocol != hs.c.clientProtocol {
|
||||
continue
|
||||
}
|
||||
if s.SNI != hs.c.serverName {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -451,11 +454,20 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
|
||||
expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)
|
||||
|
||||
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 {
|
||||
hs.hello13.psk = true
|
||||
hs.hello13.pskIdentity = uint16(i)
|
||||
if i == 0 && hs.clientHello.earlyData && hs.c.config.Accept0RTTData {
|
||||
if i == 0 && hs.clientHello.earlyData {
|
||||
// This is a ticket intended to be used for 0-RTT
|
||||
if s.maxEarlyDataLen == 0 {
|
||||
// But we had not tagged it as such. We could close the connection
|
||||
// here, but instead we just ignore the ticket and the 0-RTT data.
|
||||
continue
|
||||
}
|
||||
if hs.c.config.Accept0RTTData {
|
||||
hs.c.ticketMaxEarlyData = int64(s.maxEarlyDataLen)
|
||||
hs.hello13Enc.earlyData = true
|
||||
}
|
||||
}
|
||||
hs.hello13.psk = true
|
||||
hs.hello13.pskIdentity = uint16(i)
|
||||
return earlySecret, true
|
||||
}
|
||||
}
|
||||
@ -496,6 +508,9 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
|
||||
uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3]),
|
||||
createdAt: uint64(time.Now().Unix()),
|
||||
resumptionSecret: resumptionSecret,
|
||||
alpnProtocol: c.clientProtocol,
|
||||
SNI: c.serverName,
|
||||
maxEarlyDataLen: c.config.Max0RTTDataSize,
|
||||
}
|
||||
|
||||
ticket, err := c.encryptTicket(sessionState.marshal())
|
||||
|
12
conn.go
12
conn.go
@ -83,6 +83,10 @@ type Conn struct {
|
||||
clientProtocol string
|
||||
clientProtocolFallback bool
|
||||
|
||||
// ticketMaxEarlyData is the maximum bytes of 0-RTT application data
|
||||
// that the client is allowed to send on the ticket it used.
|
||||
ticketMaxEarlyData int64
|
||||
|
||||
// input/output
|
||||
in, out halfConn // in.Mutex < out.Mutex
|
||||
rawInput *block // raw input, right off the wire
|
||||
@ -106,6 +110,8 @@ type Conn struct {
|
||||
|
||||
// earlyDataBytes is the number of bytes of early data received so
|
||||
// far. Tracked to enforce max_early_data_size.
|
||||
// We don't keep track of rejected 0-RTT data since there's no need
|
||||
// to ever buffer it. in.Mutex.
|
||||
earlyDataBytes int64
|
||||
|
||||
tmp [16]byte
|
||||
@ -778,6 +784,12 @@ func (c *Conn) readRecord(want recordType) error {
|
||||
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
break
|
||||
}
|
||||
if c.phase == readingEarlyData {
|
||||
c.earlyDataBytes += int64(len(b.data) - b.off)
|
||||
if c.earlyDataBytes > c.ticketMaxEarlyData {
|
||||
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
}
|
||||
c.input = b
|
||||
b = nil
|
||||
|
||||
|
@ -338,9 +338,11 @@ func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
s.vers = uint16(rand.Intn(10000))
|
||||
s.suite = uint16(rand.Intn(10000))
|
||||
s.ageAdd = uint32(rand.Intn(0xffffffff))
|
||||
s.maxEarlyDataLen = uint32(rand.Intn(0xffffffff))
|
||||
s.createdAt = uint64(rand.Int63n(0xfffffffffffffff))
|
||||
s.resumptionSecret = randomBytes(rand.Intn(100), rand)
|
||||
s.alpnProtocol = randomString(rand.Intn(100), rand)
|
||||
s.SNI = randomString(rand.Intn(100), rand)
|
||||
return reflect.ValueOf(s)
|
||||
}
|
||||
|
||||
|
57
ticket.go
57
ticket.go
@ -134,10 +134,10 @@ type sessionState13 struct {
|
||||
suite uint16
|
||||
ageAdd uint32
|
||||
createdAt uint64
|
||||
maxEarlyDataLen uint32
|
||||
resumptionSecret []byte
|
||||
alpnProtocol string
|
||||
// TODO(filippo): add and check SNI
|
||||
// TODO(filippo): add and check maxEarlyDataLength
|
||||
SNI string
|
||||
}
|
||||
|
||||
func (s *sessionState13) equal(i interface{}) bool {
|
||||
@ -148,13 +148,16 @@ func (s *sessionState13) equal(i interface{}) bool {
|
||||
|
||||
return s.vers == s1.vers &&
|
||||
s.suite == s1.suite &&
|
||||
s.alpnProtocol == s1.alpnProtocol &&
|
||||
s.ageAdd == s1.ageAdd &&
|
||||
bytes.Equal(s.resumptionSecret, s1.resumptionSecret)
|
||||
s.createdAt == s1.createdAt &&
|
||||
s.maxEarlyDataLen == s1.maxEarlyDataLen &&
|
||||
bytes.Equal(s.resumptionSecret, s1.resumptionSecret) &&
|
||||
s.alpnProtocol == s1.alpnProtocol &&
|
||||
s.SNI == s1.SNI
|
||||
}
|
||||
|
||||
func (s *sessionState13) marshal() []byte {
|
||||
length := 2 + 2 + 4 + 8 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol)
|
||||
length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI)
|
||||
|
||||
x := make([]byte, length)
|
||||
x[0] = byte(s.vers >> 8)
|
||||
@ -173,19 +176,27 @@ func (s *sessionState13) marshal() []byte {
|
||||
x[13] = byte(s.createdAt >> 16)
|
||||
x[14] = byte(s.createdAt >> 8)
|
||||
x[15] = byte(s.createdAt)
|
||||
x[16] = byte(len(s.resumptionSecret) >> 8)
|
||||
x[17] = byte(len(s.resumptionSecret))
|
||||
copy(x[18:], s.resumptionSecret)
|
||||
z := x[18+len(s.resumptionSecret):]
|
||||
x[16] = byte(s.maxEarlyDataLen >> 24)
|
||||
x[17] = byte(s.maxEarlyDataLen >> 16)
|
||||
x[18] = byte(s.maxEarlyDataLen >> 8)
|
||||
x[19] = byte(s.maxEarlyDataLen)
|
||||
x[20] = byte(len(s.resumptionSecret) >> 8)
|
||||
x[21] = byte(len(s.resumptionSecret))
|
||||
copy(x[22:], s.resumptionSecret)
|
||||
z := x[22+len(s.resumptionSecret):]
|
||||
z[0] = byte(len(s.alpnProtocol) >> 8)
|
||||
z[1] = byte(len(s.alpnProtocol))
|
||||
copy(z[2:], s.alpnProtocol)
|
||||
z = z[2+len(s.alpnProtocol):]
|
||||
z[0] = byte(len(s.SNI) >> 8)
|
||||
z[1] = byte(len(s.SNI))
|
||||
copy(z[2:], s.SNI)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
func (s *sessionState13) unmarshal(data []byte) bool {
|
||||
if len(data) < 18 {
|
||||
if len(data) < 24 {
|
||||
return false
|
||||
}
|
||||
|
||||
@ -194,15 +205,29 @@ func (s *sessionState13) unmarshal(data []byte) bool {
|
||||
s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
|
||||
s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 |
|
||||
uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15])
|
||||
l := int(data[16])<<8 | int(data[17])
|
||||
if len(data) < 18+l+2 {
|
||||
s.maxEarlyDataLen = uint32(data[16])<<24 | uint32(data[17])<<16 | uint32(data[18])<<8 | uint32(data[19])
|
||||
|
||||
l := int(data[20])<<8 | int(data[21])
|
||||
if len(data) < 22+l+2 {
|
||||
return false
|
||||
}
|
||||
s.resumptionSecret = data[18 : 18+l]
|
||||
z := data[18+l:]
|
||||
s.resumptionSecret = data[22 : 22+l]
|
||||
z := data[22+l:]
|
||||
|
||||
l = int(z[0])<<8 | int(z[1])
|
||||
s.alpnProtocol = string(z[2:])
|
||||
return l == len(s.alpnProtocol)
|
||||
if len(z) < 2+l+2 {
|
||||
return false
|
||||
}
|
||||
s.alpnProtocol = string(z[2 : 2+l])
|
||||
z = z[2+l:]
|
||||
|
||||
l = int(z[0])<<8 | int(z[1])
|
||||
if len(z) != 2+l {
|
||||
return false
|
||||
}
|
||||
s.SNI = string(z[2 : 2+l])
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) {
|
||||
|
Ladataan…
Viittaa uudesa ongelmassa
Block a user