crypto/tls: add client OCSP stapling support.

R=r, rsc
CC=golang-dev
https://golang.org/cl/1750042
This commit is contained in:
Adam Langley 2010-07-14 10:40:15 -04:00
parent 08ec2a8e44
commit a54a4371e7
6 changed files with 164 additions and 20 deletions

View File

@ -38,6 +38,7 @@ const (
typeClientHello uint8 = 1 typeClientHello uint8 = 1
typeServerHello uint8 = 2 typeServerHello uint8 = 2
typeCertificate uint8 = 11 typeCertificate uint8 = 11
typeCertificateStatus uint8 = 22
typeServerHelloDone uint8 = 14 typeServerHelloDone uint8 = 14
typeClientKeyExchange uint8 = 16 typeClientKeyExchange uint8 = 16
typeFinished uint8 = 20 typeFinished uint8 = 20
@ -45,25 +46,30 @@ const (
) )
// TLS cipher suites. // TLS cipher suites.
var ( const (
TLS_RSA_WITH_RC4_128_SHA uint16 = 5 TLS_RSA_WITH_RC4_128_SHA uint16 = 5
) )
// TLS compression types. // TLS compression types.
var ( const (
compressionNone uint8 = 0 compressionNone uint8 = 0
) )
// TLS extension numbers // TLS extension numbers
var ( var (
extensionServerName uint16 = 0 extensionServerName uint16 = 0
extensionNextProtoNeg uint16 = 13172 // not IANA assigned extensionStatusRequest uint16 = 5
extensionNextProtoNeg uint16 = 13172 // not IANA assigned
)
// TLS CertificateStatusType (RFC 3546)
const (
statusTypeOCSP uint8 = 1
) )
type ConnectionState struct { type ConnectionState struct {
HandshakeComplete bool HandshakeComplete bool
CipherSuite string CipherSuite uint16
Error alert
NegotiatedProtocol string NegotiatedProtocol string
} }

28
conn.go
View File

@ -26,6 +26,7 @@ type Conn struct {
config *Config // configuration passed to constructor config *Config // configuration passed to constructor
handshakeComplete bool handshakeComplete bool
cipherSuite uint16 cipherSuite uint16
ocspResponse []byte // stapled OCSP response
clientProtocol string clientProtocol string
@ -531,6 +532,8 @@ func (c *Conn) readHandshake() (interface{}, os.Error) {
m = new(serverHelloMsg) m = new(serverHelloMsg)
case typeCertificate: case typeCertificate:
m = new(certificateMsg) m = new(certificateMsg)
case typeCertificateStatus:
m = new(certificateStatusMsg)
case typeServerHelloDone: case typeServerHelloDone:
m = new(serverHelloDoneMsg) m = new(serverHelloDoneMsg)
case typeClientKeyExchange: case typeClientKeyExchange:
@ -625,11 +628,26 @@ func (c *Conn) Handshake() os.Error {
return c.serverHandshake() return c.serverHandshake()
} }
// If c is a TLS server, ClientConnection returns the protocol // ConnectionState returns basic TLS details about the connection.
// requested by the client during the TLS handshake. func (c *Conn) ConnectionState() ConnectionState {
// Handshake must have been called already.
func (c *Conn) ClientConnection() string {
c.handshakeMutex.Lock() c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock() defer c.handshakeMutex.Unlock()
return c.clientProtocol
var state ConnectionState
state.HandshakeComplete = c.handshakeComplete
if c.handshakeComplete {
state.NegotiatedProtocol = c.clientProtocol
state.CipherSuite = c.cipherSuite
}
return state
}
// OCSPResponse returns the stapled OCSP response from the TLS server, if
// any. (Only valid for client connections.)
func (c *Conn) OCSPResponse() []byte {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
return c.ocspResponse
} }

View File

@ -18,21 +18,24 @@ import (
func (c *Conn) clientHandshake() os.Error { func (c *Conn) clientHandshake() os.Error {
finishedHash := newFinishedHash() finishedHash := newFinishedHash()
config := defaultConfig() if c.config == nil {
c.config = defaultConfig()
}
hello := &clientHelloMsg{ hello := &clientHelloMsg{
vers: maxVersion, vers: maxVersion,
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
compressionMethods: []uint8{compressionNone}, compressionMethods: []uint8{compressionNone},
random: make([]byte, 32), random: make([]byte, 32),
ocspStapling: true,
} }
t := uint32(config.Time()) t := uint32(c.config.Time())
hello.random[0] = byte(t >> 24) hello.random[0] = byte(t >> 24)
hello.random[1] = byte(t >> 16) hello.random[1] = byte(t >> 16)
hello.random[2] = byte(t >> 8) hello.random[2] = byte(t >> 8)
hello.random[3] = byte(t) hello.random[3] = byte(t)
_, err := io.ReadFull(config.Rand, hello.random[4:]) _, err := io.ReadFull(c.config.Rand, hello.random[4:])
if err != nil { if err != nil {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }
@ -89,8 +92,8 @@ func (c *Conn) clientHandshake() os.Error {
} }
// TODO(rsc): Find certificates for OS X 10.6. // TODO(rsc): Find certificates for OS X 10.6.
if false && config.RootCAs != nil { if false && c.config.RootCAs != nil {
root := config.RootCAs.FindParent(certs[len(certs)-1]) root := c.config.RootCAs.FindParent(certs[len(certs)-1])
if root == nil { if root == nil {
return c.sendAlert(alertBadCertificate) return c.sendAlert(alertBadCertificate)
} }
@ -104,6 +107,22 @@ func (c *Conn) clientHandshake() os.Error {
return c.sendAlert(alertUnsupportedCertificate) return c.sendAlert(alertUnsupportedCertificate)
} }
if serverHello.certStatus {
msg, err = c.readHandshake()
if err != nil {
return err
}
cs, ok := msg.(*certificateStatusMsg)
if !ok {
return c.sendAlert(alertUnexpectedMessage)
}
finishedHash.Write(cs.marshal())
if cs.statusType == statusTypeOCSP {
c.ocspResponse = cs.response
}
}
msg, err = c.readHandshake() msg, err = c.readHandshake()
if err != nil { if err != nil {
return err return err
@ -118,12 +137,12 @@ func (c *Conn) clientHandshake() os.Error {
preMasterSecret := make([]byte, 48) preMasterSecret := make([]byte, 48)
preMasterSecret[0] = byte(hello.vers >> 8) preMasterSecret[0] = byte(hello.vers >> 8)
preMasterSecret[1] = byte(hello.vers) preMasterSecret[1] = byte(hello.vers)
_, err = io.ReadFull(config.Rand, preMasterSecret[2:]) _, err = io.ReadFull(c.config.Rand, preMasterSecret[2:])
if err != nil { if err != nil {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }
ckx.ciphertext, err = rsa.EncryptPKCS1v15(config.Rand, pub, preMasterSecret) ckx.ciphertext, err = rsa.EncryptPKCS1v15(c.config.Rand, pub, preMasterSecret)
if err != nil { if err != nil {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }

View File

@ -13,6 +13,7 @@ type clientHelloMsg struct {
compressionMethods []uint8 compressionMethods []uint8
nextProtoNeg bool nextProtoNeg bool
serverName string serverName string
ocspStapling bool
} }
func (m *clientHelloMsg) marshal() []byte { func (m *clientHelloMsg) marshal() []byte {
@ -26,6 +27,10 @@ func (m *clientHelloMsg) marshal() []byte {
if m.nextProtoNeg { if m.nextProtoNeg {
numExtensions++ numExtensions++
} }
if m.ocspStapling {
extensionsLength += 1 + 2 + 2
numExtensions++
}
if len(m.serverName) > 0 { if len(m.serverName) > 0 {
extensionsLength += 5 + len(m.serverName) extensionsLength += 5 + len(m.serverName)
numExtensions++ numExtensions++
@ -101,6 +106,16 @@ func (m *clientHelloMsg) marshal() []byte {
copy(z[5:], []byte(m.serverName)) copy(z[5:], []byte(m.serverName))
z = z[l:] z = z[l:]
} }
if m.ocspStapling {
// RFC 4366, section 3.6
z[0] = byte(extensionStatusRequest >> 8)
z[1] = byte(extensionStatusRequest)
z[2] = 0
z[3] = 5
z[4] = 1 // OCSP type
// Two zero valued uint16s for the two lengths.
z = z[9:]
}
m.raw = x m.raw = x
@ -148,6 +163,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.nextProtoNeg = false m.nextProtoNeg = false
m.serverName = "" m.serverName = ""
m.ocspStapling = false
if len(data) == 0 { if len(data) == 0 {
// ClientHello is optionally followed by extension data // ClientHello is optionally followed by extension data
@ -202,6 +218,8 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
return false return false
} }
m.nextProtoNeg = true m.nextProtoNeg = true
case extensionStatusRequest:
m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
} }
data = data[length:] data = data[length:]
} }
@ -218,6 +236,7 @@ type serverHelloMsg struct {
compressionMethod uint8 compressionMethod uint8
nextProtoNeg bool nextProtoNeg bool
nextProtos []string nextProtos []string
certStatus bool
} }
func (m *serverHelloMsg) marshal() []byte { func (m *serverHelloMsg) marshal() []byte {
@ -238,6 +257,9 @@ func (m *serverHelloMsg) marshal() []byte {
nextProtoLen += len(m.nextProtos) nextProtoLen += len(m.nextProtos)
extensionsLength += nextProtoLen extensionsLength += nextProtoLen
} }
if m.certStatus {
numExtensions++
}
if numExtensions > 0 { if numExtensions > 0 {
extensionsLength += 4 * numExtensions extensionsLength += 4 * numExtensions
length += 2 + extensionsLength length += 2 + extensionsLength
@ -281,6 +303,11 @@ func (m *serverHelloMsg) marshal() []byte {
z = z[1+l:] z = z[1+l:]
} }
} }
if m.certStatus {
z[0] = byte(extensionStatusRequest >> 8)
z[1] = byte(extensionStatusRequest)
z = z[4:]
}
m.raw = x m.raw = x
@ -322,6 +349,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
m.nextProtoNeg = false m.nextProtoNeg = false
m.nextProtos = nil m.nextProtos = nil
m.certStatus = false
if len(data) == 0 { if len(data) == 0 {
// ServerHello is optionally followed by extension data // ServerHello is optionally followed by extension data
@ -361,6 +389,11 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
m.nextProtos = append(m.nextProtos, string(d[0:l])) m.nextProtos = append(m.nextProtos, string(d[0:l]))
d = d[l:] d = d[l:]
} }
case extensionStatusRequest:
if length > 0 {
return false
}
m.certStatus = true
} }
data = data[length:] data = data[length:]
} }
@ -445,6 +478,61 @@ func (m *certificateMsg) unmarshal(data []byte) bool {
return true return true
} }
type certificateStatusMsg struct {
raw []byte
statusType uint8
response []byte
}
func (m *certificateStatusMsg) marshal() []byte {
if m.raw != nil {
return m.raw
}
var x []byte
if m.statusType == statusTypeOCSP {
x = make([]byte, 4+4+len(m.response))
x[0] = typeCertificateStatus
l := len(m.response) + 4
x[1] = byte(l >> 16)
x[2] = byte(l >> 8)
x[3] = byte(l)
x[4] = statusTypeOCSP
l -= 4
x[5] = byte(l >> 16)
x[6] = byte(l >> 8)
x[7] = byte(l)
copy(x[8:], m.response)
} else {
x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
}
m.raw = x
return x
}
func (m *certificateStatusMsg) unmarshal(data []byte) bool {
m.raw = data
if len(data) < 5 {
return false
}
m.statusType = data[4]
m.response = nil
if m.statusType == statusTypeOCSP {
if len(data) < 8 {
return false
}
respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
if uint32(len(data)) != 4+4+respLen {
return false
}
m.response = data[8:]
}
return true
}
type serverHelloDoneMsg struct{} type serverHelloDoneMsg struct{}
func (m *serverHelloDoneMsg) marshal() []byte { func (m *serverHelloDoneMsg) marshal() []byte {

View File

@ -16,6 +16,7 @@ var tests = []interface{}{
&serverHelloMsg{}, &serverHelloMsg{},
&certificateMsg{}, &certificateMsg{},
&certificateStatusMsg{},
&clientKeyExchangeMsg{}, &clientKeyExchangeMsg{},
&finishedMsg{}, &finishedMsg{},
&nextProtoMsg{}, &nextProtoMsg{},
@ -111,6 +112,7 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
if rand.Intn(10) > 5 { if rand.Intn(10) > 5 {
m.serverName = randomString(rand.Intn(255), rand) m.serverName = randomString(rand.Intn(255), rand)
} }
m.ocspStapling = rand.Intn(10) > 5
return reflect.NewValue(m) return reflect.NewValue(m)
} }
@ -146,6 +148,17 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
return reflect.NewValue(m) return reflect.NewValue(m)
} }
func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateStatusMsg{}
if rand.Intn(10) > 5 {
m.statusType = statusTypeOCSP
m.response = randomBytes(rand.Intn(10)+1, rand)
} else {
m.statusType = 42
}
return reflect.NewValue(m)
}
func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &clientKeyExchangeMsg{} m := &clientKeyExchangeMsg{}
m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)

View File

@ -71,13 +71,13 @@ func TestRejectBadProtocolVersion(t *testing.T) {
} }
func TestNoSuiteOverlap(t *testing.T) { func TestNoSuiteOverlap(t *testing.T) {
clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{0xff00}, []uint8{0}, false, ""} clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{0xff00}, []uint8{0}, false, "", false}
testClientHelloFailure(t, clientHello, alertHandshakeFailure) testClientHelloFailure(t, clientHello, alertHandshakeFailure)
} }
func TestNoCompressionOverlap(t *testing.T) { func TestNoCompressionOverlap(t *testing.T) {
clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, ""} clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, "", false}
testClientHelloFailure(t, clientHello, alertHandshakeFailure) testClientHelloFailure(t, clientHello, alertHandshakeFailure)
} }