crypto/tls: add SignedCertificateTimestamps and OCSPStaple to 1.3
This commit is contained in:
parent
9b94b65b7b
commit
ed105dc308
13
13.go
13
13.go
@ -191,9 +191,18 @@ func (hs *serverHandshakeState) readClientFinished13() error {
|
||||
func (hs *serverHandshakeState) sendCertificate13() error {
|
||||
c := hs.c
|
||||
|
||||
certMsg := &certificateMsg13{
|
||||
certificates: hs.cert.Certificate,
|
||||
certEntries := []certificateEntry{}
|
||||
for _, cert := range hs.cert.Certificate {
|
||||
certEntries = append(certEntries, certificateEntry{data: cert})
|
||||
}
|
||||
if len(certEntries) > 0 && hs.clientHello.ocspStapling {
|
||||
certEntries[0].ocspStaple = hs.cert.OCSPStaple
|
||||
}
|
||||
if len(certEntries) > 0 && hs.clientHello.scts {
|
||||
certEntries[0].sctList = hs.cert.SignedCertificateTimestamps
|
||||
}
|
||||
certMsg := &certificateMsg13{certificates: certEntries}
|
||||
|
||||
hs.finishedHash13.Write(certMsg.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
|
||||
return err
|
||||
|
@ -1377,10 +1377,16 @@ func (m *certificateMsg) unmarshal(data []byte) alert {
|
||||
return alertSuccess
|
||||
}
|
||||
|
||||
type certificateEntry struct {
|
||||
data []byte
|
||||
ocspStaple []byte
|
||||
sctList [][]byte
|
||||
}
|
||||
|
||||
type certificateMsg13 struct {
|
||||
raw []byte
|
||||
requestContext []byte
|
||||
certificates [][]byte
|
||||
certificates []certificateEntry
|
||||
}
|
||||
|
||||
func (m *certificateMsg13) equal(i interface{}) bool {
|
||||
@ -1389,9 +1395,20 @@ func (m *certificateMsg13) equal(i interface{}) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(m.certificates) != len(m1.certificates) {
|
||||
return false
|
||||
}
|
||||
for i, _ := range m.certificates {
|
||||
ok := bytes.Equal(m.certificates[i].data, m1.certificates[i].data)
|
||||
ok = ok && bytes.Equal(m.certificates[i].ocspStaple, m1.certificates[i].ocspStaple)
|
||||
ok = ok && eqByteSlices(m.certificates[i].sctList, m1.certificates[i].sctList)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return bytes.Equal(m.raw, m1.raw) &&
|
||||
bytes.Equal(m.requestContext, m1.requestContext) &&
|
||||
eqByteSlices(m.certificates, m1.certificates)
|
||||
bytes.Equal(m.requestContext, m1.requestContext)
|
||||
}
|
||||
|
||||
func (m *certificateMsg13) marshal() (x []byte) {
|
||||
@ -1400,8 +1417,17 @@ func (m *certificateMsg13) marshal() (x []byte) {
|
||||
}
|
||||
|
||||
var i int
|
||||
for _, slice := range m.certificates {
|
||||
i += len(slice)
|
||||
for _, cert := range m.certificates {
|
||||
i += len(cert.data)
|
||||
if cert.ocspStaple != nil {
|
||||
i += 8 + len(cert.ocspStaple)
|
||||
}
|
||||
if cert.sctList != nil {
|
||||
i += 4
|
||||
for _, sct := range cert.sctList {
|
||||
i += 2 + len(sct)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
length := 3 + 3*len(m.certificates) + i
|
||||
@ -1425,12 +1451,56 @@ func (m *certificateMsg13) marshal() (x []byte) {
|
||||
z[2] = uint8(certificateOctets)
|
||||
|
||||
z = z[3:]
|
||||
for _, slice := range m.certificates {
|
||||
z[0] = uint8(len(slice) >> 16)
|
||||
z[1] = uint8(len(slice) >> 8)
|
||||
z[2] = uint8(len(slice))
|
||||
copy(z[3:], slice)
|
||||
z = z[3+len(slice)+2:]
|
||||
for _, cert := range m.certificates {
|
||||
z[0] = uint8(len(cert.data) >> 16)
|
||||
z[1] = uint8(len(cert.data) >> 8)
|
||||
z[2] = uint8(len(cert.data))
|
||||
copy(z[3:], cert.data)
|
||||
z = z[3+len(cert.data):]
|
||||
|
||||
extLenPos := z[:2]
|
||||
z = z[2:]
|
||||
|
||||
extensionLen := 0
|
||||
if cert.ocspStaple != nil {
|
||||
stapleLen := 4 + len(cert.ocspStaple)
|
||||
z[0] = uint8(extensionStatusRequest >> 8)
|
||||
z[1] = uint8(extensionStatusRequest)
|
||||
z[2] = uint8(stapleLen >> 8)
|
||||
z[3] = uint8(stapleLen)
|
||||
|
||||
stapleLen -= 4
|
||||
z[4] = statusTypeOCSP
|
||||
z[5] = uint8(stapleLen >> 16)
|
||||
z[6] = uint8(stapleLen >> 8)
|
||||
z[7] = uint8(stapleLen)
|
||||
copy(z[8:], cert.ocspStaple)
|
||||
z = z[8+stapleLen:]
|
||||
|
||||
extensionLen += 8 + stapleLen
|
||||
}
|
||||
if cert.sctList != nil {
|
||||
z[0] = uint8(extensionSCT >> 8)
|
||||
z[1] = uint8(extensionSCT)
|
||||
sctLenPos := z[2:4]
|
||||
z = z[4:]
|
||||
extensionLen += 4
|
||||
|
||||
sctLen := 0
|
||||
for _, sct := range cert.sctList {
|
||||
z[0] = uint8(len(sct) >> 8)
|
||||
z[1] = uint8(len(sct))
|
||||
copy(z[2:], sct)
|
||||
z = z[2+len(sct):]
|
||||
|
||||
extensionLen += 2 + len(sct)
|
||||
sctLen += 2 + len(sct)
|
||||
}
|
||||
sctLenPos[0] = uint8(sctLen >> 8)
|
||||
sctLenPos[1] = uint8(sctLen)
|
||||
}
|
||||
extLenPos[0] = uint8(extensionLen >> 8)
|
||||
extLenPos[1] = uint8(extensionLen)
|
||||
}
|
||||
|
||||
m.raw = x
|
||||
@ -1481,14 +1551,53 @@ func (m *certificateMsg13) unmarshal(data []byte) alert {
|
||||
numCerts++
|
||||
}
|
||||
|
||||
m.certificates = make([][]byte, numCerts)
|
||||
m.certificates = make([]certificateEntry, numCerts)
|
||||
d = data[8+ctxLen:]
|
||||
for i := 0; i < numCerts; i++ {
|
||||
certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
|
||||
m.certificates[i] = d[3 : 3+certLen]
|
||||
m.certificates[i].data = d[3 : 3+certLen]
|
||||
d = d[3+certLen:]
|
||||
|
||||
extLen := uint16(d[0])<<8 | uint16(d[1])
|
||||
d = d[2+extLen:]
|
||||
d = d[2:]
|
||||
for extLen > 0 {
|
||||
if extLen < 4 {
|
||||
return alertDecodeError
|
||||
}
|
||||
typ := uint16(d[0])<<8 | uint16(d[1])
|
||||
bodyLen := uint16(d[2])<<8 | uint16(d[3])
|
||||
if extLen < 4+bodyLen {
|
||||
return alertDecodeError
|
||||
}
|
||||
body := d[4 : 4+bodyLen]
|
||||
d = d[4+bodyLen:]
|
||||
extLen -= 4 + bodyLen
|
||||
|
||||
switch typ {
|
||||
case extensionStatusRequest:
|
||||
if len(body) < 4 || body[0] != 0x01 {
|
||||
return alertDecodeError
|
||||
}
|
||||
ocspLen := int(body[1])<<16 | int(body[2])<<8 | int(body[3])
|
||||
if len(body) != 4+ocspLen {
|
||||
return alertDecodeError
|
||||
}
|
||||
m.certificates[i].ocspStaple = body[4:]
|
||||
|
||||
case extensionSCT:
|
||||
for len(body) > 0 {
|
||||
if len(body) < 2 {
|
||||
return alertDecodeError
|
||||
}
|
||||
sctLen := int(body[0]<<8) | int(body[1])
|
||||
if len(body) < 2+sctLen {
|
||||
return alertDecodeError
|
||||
}
|
||||
m.certificates[i].sctList = append(m.certificates[i].sctList, body[2:2+sctLen])
|
||||
body = body[2+sctLen:]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return alertSuccess
|
||||
|
@ -252,9 +252,17 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
func (*certificateMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
m := &certificateMsg13{}
|
||||
numCerts := rand.Intn(20)
|
||||
m.certificates = make([][]byte, numCerts)
|
||||
m.certificates = make([]certificateEntry, numCerts)
|
||||
for i := 0; i < numCerts; i++ {
|
||||
m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
|
||||
m.certificates[i].data = randomBytes(rand.Intn(10)+1, rand)
|
||||
if rand.Intn(2) == 1 {
|
||||
m.certificates[i].ocspStaple = randomBytes(rand.Intn(10)+1, rand)
|
||||
}
|
||||
|
||||
numScts := rand.Intn(3)
|
||||
for j := 0; j < numScts; j++ {
|
||||
m.certificates[i].sctList = append(m.certificates[i].sctList, randomBytes(rand.Intn(10)+1, rand))
|
||||
}
|
||||
}
|
||||
m.requestContext = randomBytes(rand.Intn(5), rand)
|
||||
return reflect.ValueOf(m)
|
||||
|
Loading…
Reference in New Issue
Block a user