crypto/tls: add SignedCertificateTimestamps and OCSPStaple to 1.3

This commit is contained in:
Brendan Mc 2017-02-03 11:36:10 -08:00 committed by Peter Wu
parent 9b94b65b7b
commit ed105dc308
3 changed files with 144 additions and 18 deletions

13
13.go
View File

@ -191,9 +191,18 @@ func (hs *serverHandshakeState) readClientFinished13() error {
func (hs *serverHandshakeState) sendCertificate13() error { func (hs *serverHandshakeState) sendCertificate13() error {
c := hs.c c := hs.c
certMsg := &certificateMsg13{ certEntries := []certificateEntry{}
certificates: hs.cert.Certificate, for _, cert := range hs.cert.Certificate {
certEntries = append(certEntries, certificateEntry{data: cert})
} }
if len(certEntries) > 0 && hs.clientHello.ocspStapling {
certEntries[0].ocspStaple = hs.cert.OCSPStaple
}
if len(certEntries) > 0 && hs.clientHello.scts {
certEntries[0].sctList = hs.cert.SignedCertificateTimestamps
}
certMsg := &certificateMsg13{certificates: certEntries}
hs.finishedHash13.Write(certMsg.marshal()) hs.finishedHash13.Write(certMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
return err return err

View File

@ -1377,10 +1377,16 @@ func (m *certificateMsg) unmarshal(data []byte) alert {
return alertSuccess return alertSuccess
} }
type certificateEntry struct {
data []byte
ocspStaple []byte
sctList [][]byte
}
type certificateMsg13 struct { type certificateMsg13 struct {
raw []byte raw []byte
requestContext []byte requestContext []byte
certificates [][]byte certificates []certificateEntry
} }
func (m *certificateMsg13) equal(i interface{}) bool { func (m *certificateMsg13) equal(i interface{}) bool {
@ -1389,9 +1395,20 @@ func (m *certificateMsg13) equal(i interface{}) bool {
return false return false
} }
if len(m.certificates) != len(m1.certificates) {
return false
}
for i, _ := range m.certificates {
ok := bytes.Equal(m.certificates[i].data, m1.certificates[i].data)
ok = ok && bytes.Equal(m.certificates[i].ocspStaple, m1.certificates[i].ocspStaple)
ok = ok && eqByteSlices(m.certificates[i].sctList, m1.certificates[i].sctList)
if !ok {
return false
}
}
return bytes.Equal(m.raw, m1.raw) && return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.requestContext, m1.requestContext) && bytes.Equal(m.requestContext, m1.requestContext)
eqByteSlices(m.certificates, m1.certificates)
} }
func (m *certificateMsg13) marshal() (x []byte) { func (m *certificateMsg13) marshal() (x []byte) {
@ -1400,8 +1417,17 @@ func (m *certificateMsg13) marshal() (x []byte) {
} }
var i int var i int
for _, slice := range m.certificates { for _, cert := range m.certificates {
i += len(slice) i += len(cert.data)
if cert.ocspStaple != nil {
i += 8 + len(cert.ocspStaple)
}
if cert.sctList != nil {
i += 4
for _, sct := range cert.sctList {
i += 2 + len(sct)
}
}
} }
length := 3 + 3*len(m.certificates) + i length := 3 + 3*len(m.certificates) + i
@ -1425,12 +1451,56 @@ func (m *certificateMsg13) marshal() (x []byte) {
z[2] = uint8(certificateOctets) z[2] = uint8(certificateOctets)
z = z[3:] z = z[3:]
for _, slice := range m.certificates { for _, cert := range m.certificates {
z[0] = uint8(len(slice) >> 16) z[0] = uint8(len(cert.data) >> 16)
z[1] = uint8(len(slice) >> 8) z[1] = uint8(len(cert.data) >> 8)
z[2] = uint8(len(slice)) z[2] = uint8(len(cert.data))
copy(z[3:], slice) copy(z[3:], cert.data)
z = z[3+len(slice)+2:] z = z[3+len(cert.data):]
extLenPos := z[:2]
z = z[2:]
extensionLen := 0
if cert.ocspStaple != nil {
stapleLen := 4 + len(cert.ocspStaple)
z[0] = uint8(extensionStatusRequest >> 8)
z[1] = uint8(extensionStatusRequest)
z[2] = uint8(stapleLen >> 8)
z[3] = uint8(stapleLen)
stapleLen -= 4
z[4] = statusTypeOCSP
z[5] = uint8(stapleLen >> 16)
z[6] = uint8(stapleLen >> 8)
z[7] = uint8(stapleLen)
copy(z[8:], cert.ocspStaple)
z = z[8+stapleLen:]
extensionLen += 8 + stapleLen
}
if cert.sctList != nil {
z[0] = uint8(extensionSCT >> 8)
z[1] = uint8(extensionSCT)
sctLenPos := z[2:4]
z = z[4:]
extensionLen += 4
sctLen := 0
for _, sct := range cert.sctList {
z[0] = uint8(len(sct) >> 8)
z[1] = uint8(len(sct))
copy(z[2:], sct)
z = z[2+len(sct):]
extensionLen += 2 + len(sct)
sctLen += 2 + len(sct)
}
sctLenPos[0] = uint8(sctLen >> 8)
sctLenPos[1] = uint8(sctLen)
}
extLenPos[0] = uint8(extensionLen >> 8)
extLenPos[1] = uint8(extensionLen)
} }
m.raw = x m.raw = x
@ -1481,14 +1551,53 @@ func (m *certificateMsg13) unmarshal(data []byte) alert {
numCerts++ numCerts++
} }
m.certificates = make([][]byte, numCerts) m.certificates = make([]certificateEntry, numCerts)
d = data[8+ctxLen:] d = data[8+ctxLen:]
for i := 0; i < numCerts; i++ { for i := 0; i < numCerts; i++ {
certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
m.certificates[i] = d[3 : 3+certLen] m.certificates[i].data = d[3 : 3+certLen]
d = d[3+certLen:] d = d[3+certLen:]
extLen := uint16(d[0])<<8 | uint16(d[1]) extLen := uint16(d[0])<<8 | uint16(d[1])
d = d[2+extLen:] d = d[2:]
for extLen > 0 {
if extLen < 4 {
return alertDecodeError
}
typ := uint16(d[0])<<8 | uint16(d[1])
bodyLen := uint16(d[2])<<8 | uint16(d[3])
if extLen < 4+bodyLen {
return alertDecodeError
}
body := d[4 : 4+bodyLen]
d = d[4+bodyLen:]
extLen -= 4 + bodyLen
switch typ {
case extensionStatusRequest:
if len(body) < 4 || body[0] != 0x01 {
return alertDecodeError
}
ocspLen := int(body[1])<<16 | int(body[2])<<8 | int(body[3])
if len(body) != 4+ocspLen {
return alertDecodeError
}
m.certificates[i].ocspStaple = body[4:]
case extensionSCT:
for len(body) > 0 {
if len(body) < 2 {
return alertDecodeError
}
sctLen := int(body[0]<<8) | int(body[1])
if len(body) < 2+sctLen {
return alertDecodeError
}
m.certificates[i].sctList = append(m.certificates[i].sctList, body[2:2+sctLen])
body = body[2+sctLen:]
}
}
}
} }
return alertSuccess return alertSuccess

View File

@ -252,9 +252,17 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
func (*certificateMsg13) Generate(rand *rand.Rand, size int) reflect.Value { func (*certificateMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateMsg13{} m := &certificateMsg13{}
numCerts := rand.Intn(20) numCerts := rand.Intn(20)
m.certificates = make([][]byte, numCerts) m.certificates = make([]certificateEntry, numCerts)
for i := 0; i < numCerts; i++ { for i := 0; i < numCerts; i++ {
m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) m.certificates[i].data = randomBytes(rand.Intn(10)+1, rand)
if rand.Intn(2) == 1 {
m.certificates[i].ocspStaple = randomBytes(rand.Intn(10)+1, rand)
}
numScts := rand.Intn(3)
for j := 0; j < numScts; j++ {
m.certificates[i].sctList = append(m.certificates[i].sctList, randomBytes(rand.Intn(10)+1, rand))
}
} }
m.requestContext = randomBytes(rand.Intn(5), rand) m.requestContext = randomBytes(rand.Intn(5), rand)
return reflect.ValueOf(m) return reflect.ValueOf(m)