crypto/tls: add client OCSP stapling support.
R=r, rsc CC=golang-dev https://golang.org/cl/1750042
This commit is contained in:
parent
08ec2a8e44
commit
a54a4371e7
18
common.go
18
common.go
@ -38,6 +38,7 @@ const (
|
|||||||
typeClientHello uint8 = 1
|
typeClientHello uint8 = 1
|
||||||
typeServerHello uint8 = 2
|
typeServerHello uint8 = 2
|
||||||
typeCertificate uint8 = 11
|
typeCertificate uint8 = 11
|
||||||
|
typeCertificateStatus uint8 = 22
|
||||||
typeServerHelloDone uint8 = 14
|
typeServerHelloDone uint8 = 14
|
||||||
typeClientKeyExchange uint8 = 16
|
typeClientKeyExchange uint8 = 16
|
||||||
typeFinished uint8 = 20
|
typeFinished uint8 = 20
|
||||||
@ -45,25 +46,30 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// TLS cipher suites.
|
// TLS cipher suites.
|
||||||
var (
|
const (
|
||||||
TLS_RSA_WITH_RC4_128_SHA uint16 = 5
|
TLS_RSA_WITH_RC4_128_SHA uint16 = 5
|
||||||
)
|
)
|
||||||
|
|
||||||
// TLS compression types.
|
// TLS compression types.
|
||||||
var (
|
const (
|
||||||
compressionNone uint8 = 0
|
compressionNone uint8 = 0
|
||||||
)
|
)
|
||||||
|
|
||||||
// TLS extension numbers
|
// TLS extension numbers
|
||||||
var (
|
var (
|
||||||
extensionServerName uint16 = 0
|
extensionServerName uint16 = 0
|
||||||
extensionNextProtoNeg uint16 = 13172 // not IANA assigned
|
extensionStatusRequest uint16 = 5
|
||||||
|
extensionNextProtoNeg uint16 = 13172 // not IANA assigned
|
||||||
|
)
|
||||||
|
|
||||||
|
// TLS CertificateStatusType (RFC 3546)
|
||||||
|
const (
|
||||||
|
statusTypeOCSP uint8 = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConnectionState struct {
|
type ConnectionState struct {
|
||||||
HandshakeComplete bool
|
HandshakeComplete bool
|
||||||
CipherSuite string
|
CipherSuite uint16
|
||||||
Error alert
|
|
||||||
NegotiatedProtocol string
|
NegotiatedProtocol string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
28
conn.go
28
conn.go
@ -26,6 +26,7 @@ type Conn struct {
|
|||||||
config *Config // configuration passed to constructor
|
config *Config // configuration passed to constructor
|
||||||
handshakeComplete bool
|
handshakeComplete bool
|
||||||
cipherSuite uint16
|
cipherSuite uint16
|
||||||
|
ocspResponse []byte // stapled OCSP response
|
||||||
|
|
||||||
clientProtocol string
|
clientProtocol string
|
||||||
|
|
||||||
@ -531,6 +532,8 @@ func (c *Conn) readHandshake() (interface{}, os.Error) {
|
|||||||
m = new(serverHelloMsg)
|
m = new(serverHelloMsg)
|
||||||
case typeCertificate:
|
case typeCertificate:
|
||||||
m = new(certificateMsg)
|
m = new(certificateMsg)
|
||||||
|
case typeCertificateStatus:
|
||||||
|
m = new(certificateStatusMsg)
|
||||||
case typeServerHelloDone:
|
case typeServerHelloDone:
|
||||||
m = new(serverHelloDoneMsg)
|
m = new(serverHelloDoneMsg)
|
||||||
case typeClientKeyExchange:
|
case typeClientKeyExchange:
|
||||||
@ -625,11 +628,26 @@ func (c *Conn) Handshake() os.Error {
|
|||||||
return c.serverHandshake()
|
return c.serverHandshake()
|
||||||
}
|
}
|
||||||
|
|
||||||
// If c is a TLS server, ClientConnection returns the protocol
|
// ConnectionState returns basic TLS details about the connection.
|
||||||
// requested by the client during the TLS handshake.
|
func (c *Conn) ConnectionState() ConnectionState {
|
||||||
// Handshake must have been called already.
|
|
||||||
func (c *Conn) ClientConnection() string {
|
|
||||||
c.handshakeMutex.Lock()
|
c.handshakeMutex.Lock()
|
||||||
defer c.handshakeMutex.Unlock()
|
defer c.handshakeMutex.Unlock()
|
||||||
return c.clientProtocol
|
|
||||||
|
var state ConnectionState
|
||||||
|
state.HandshakeComplete = c.handshakeComplete
|
||||||
|
if c.handshakeComplete {
|
||||||
|
state.NegotiatedProtocol = c.clientProtocol
|
||||||
|
state.CipherSuite = c.cipherSuite
|
||||||
|
}
|
||||||
|
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|
||||||
|
// OCSPResponse returns the stapled OCSP response from the TLS server, if
|
||||||
|
// any. (Only valid for client connections.)
|
||||||
|
func (c *Conn) OCSPResponse() []byte {
|
||||||
|
c.handshakeMutex.Lock()
|
||||||
|
defer c.handshakeMutex.Unlock()
|
||||||
|
|
||||||
|
return c.ocspResponse
|
||||||
}
|
}
|
||||||
|
@ -18,21 +18,24 @@ import (
|
|||||||
func (c *Conn) clientHandshake() os.Error {
|
func (c *Conn) clientHandshake() os.Error {
|
||||||
finishedHash := newFinishedHash()
|
finishedHash := newFinishedHash()
|
||||||
|
|
||||||
config := defaultConfig()
|
if c.config == nil {
|
||||||
|
c.config = defaultConfig()
|
||||||
|
}
|
||||||
|
|
||||||
hello := &clientHelloMsg{
|
hello := &clientHelloMsg{
|
||||||
vers: maxVersion,
|
vers: maxVersion,
|
||||||
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
|
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
|
||||||
compressionMethods: []uint8{compressionNone},
|
compressionMethods: []uint8{compressionNone},
|
||||||
random: make([]byte, 32),
|
random: make([]byte, 32),
|
||||||
|
ocspStapling: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
t := uint32(config.Time())
|
t := uint32(c.config.Time())
|
||||||
hello.random[0] = byte(t >> 24)
|
hello.random[0] = byte(t >> 24)
|
||||||
hello.random[1] = byte(t >> 16)
|
hello.random[1] = byte(t >> 16)
|
||||||
hello.random[2] = byte(t >> 8)
|
hello.random[2] = byte(t >> 8)
|
||||||
hello.random[3] = byte(t)
|
hello.random[3] = byte(t)
|
||||||
_, err := io.ReadFull(config.Rand, hello.random[4:])
|
_, err := io.ReadFull(c.config.Rand, hello.random[4:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.sendAlert(alertInternalError)
|
return c.sendAlert(alertInternalError)
|
||||||
}
|
}
|
||||||
@ -89,8 +92,8 @@ func (c *Conn) clientHandshake() os.Error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO(rsc): Find certificates for OS X 10.6.
|
// TODO(rsc): Find certificates for OS X 10.6.
|
||||||
if false && config.RootCAs != nil {
|
if false && c.config.RootCAs != nil {
|
||||||
root := config.RootCAs.FindParent(certs[len(certs)-1])
|
root := c.config.RootCAs.FindParent(certs[len(certs)-1])
|
||||||
if root == nil {
|
if root == nil {
|
||||||
return c.sendAlert(alertBadCertificate)
|
return c.sendAlert(alertBadCertificate)
|
||||||
}
|
}
|
||||||
@ -104,6 +107,22 @@ func (c *Conn) clientHandshake() os.Error {
|
|||||||
return c.sendAlert(alertUnsupportedCertificate)
|
return c.sendAlert(alertUnsupportedCertificate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if serverHello.certStatus {
|
||||||
|
msg, err = c.readHandshake()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cs, ok := msg.(*certificateStatusMsg)
|
||||||
|
if !ok {
|
||||||
|
return c.sendAlert(alertUnexpectedMessage)
|
||||||
|
}
|
||||||
|
finishedHash.Write(cs.marshal())
|
||||||
|
|
||||||
|
if cs.statusType == statusTypeOCSP {
|
||||||
|
c.ocspResponse = cs.response
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
msg, err = c.readHandshake()
|
msg, err = c.readHandshake()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -118,12 +137,12 @@ func (c *Conn) clientHandshake() os.Error {
|
|||||||
preMasterSecret := make([]byte, 48)
|
preMasterSecret := make([]byte, 48)
|
||||||
preMasterSecret[0] = byte(hello.vers >> 8)
|
preMasterSecret[0] = byte(hello.vers >> 8)
|
||||||
preMasterSecret[1] = byte(hello.vers)
|
preMasterSecret[1] = byte(hello.vers)
|
||||||
_, err = io.ReadFull(config.Rand, preMasterSecret[2:])
|
_, err = io.ReadFull(c.config.Rand, preMasterSecret[2:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.sendAlert(alertInternalError)
|
return c.sendAlert(alertInternalError)
|
||||||
}
|
}
|
||||||
|
|
||||||
ckx.ciphertext, err = rsa.EncryptPKCS1v15(config.Rand, pub, preMasterSecret)
|
ckx.ciphertext, err = rsa.EncryptPKCS1v15(c.config.Rand, pub, preMasterSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.sendAlert(alertInternalError)
|
return c.sendAlert(alertInternalError)
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,7 @@ type clientHelloMsg struct {
|
|||||||
compressionMethods []uint8
|
compressionMethods []uint8
|
||||||
nextProtoNeg bool
|
nextProtoNeg bool
|
||||||
serverName string
|
serverName string
|
||||||
|
ocspStapling bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *clientHelloMsg) marshal() []byte {
|
func (m *clientHelloMsg) marshal() []byte {
|
||||||
@ -26,6 +27,10 @@ func (m *clientHelloMsg) marshal() []byte {
|
|||||||
if m.nextProtoNeg {
|
if m.nextProtoNeg {
|
||||||
numExtensions++
|
numExtensions++
|
||||||
}
|
}
|
||||||
|
if m.ocspStapling {
|
||||||
|
extensionsLength += 1 + 2 + 2
|
||||||
|
numExtensions++
|
||||||
|
}
|
||||||
if len(m.serverName) > 0 {
|
if len(m.serverName) > 0 {
|
||||||
extensionsLength += 5 + len(m.serverName)
|
extensionsLength += 5 + len(m.serverName)
|
||||||
numExtensions++
|
numExtensions++
|
||||||
@ -101,6 +106,16 @@ func (m *clientHelloMsg) marshal() []byte {
|
|||||||
copy(z[5:], []byte(m.serverName))
|
copy(z[5:], []byte(m.serverName))
|
||||||
z = z[l:]
|
z = z[l:]
|
||||||
}
|
}
|
||||||
|
if m.ocspStapling {
|
||||||
|
// RFC 4366, section 3.6
|
||||||
|
z[0] = byte(extensionStatusRequest >> 8)
|
||||||
|
z[1] = byte(extensionStatusRequest)
|
||||||
|
z[2] = 0
|
||||||
|
z[3] = 5
|
||||||
|
z[4] = 1 // OCSP type
|
||||||
|
// Two zero valued uint16s for the two lengths.
|
||||||
|
z = z[9:]
|
||||||
|
}
|
||||||
|
|
||||||
m.raw = x
|
m.raw = x
|
||||||
|
|
||||||
@ -148,6 +163,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
|
|||||||
|
|
||||||
m.nextProtoNeg = false
|
m.nextProtoNeg = false
|
||||||
m.serverName = ""
|
m.serverName = ""
|
||||||
|
m.ocspStapling = false
|
||||||
|
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
// ClientHello is optionally followed by extension data
|
// ClientHello is optionally followed by extension data
|
||||||
@ -202,6 +218,8 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
m.nextProtoNeg = true
|
m.nextProtoNeg = true
|
||||||
|
case extensionStatusRequest:
|
||||||
|
m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
|
||||||
}
|
}
|
||||||
data = data[length:]
|
data = data[length:]
|
||||||
}
|
}
|
||||||
@ -218,6 +236,7 @@ type serverHelloMsg struct {
|
|||||||
compressionMethod uint8
|
compressionMethod uint8
|
||||||
nextProtoNeg bool
|
nextProtoNeg bool
|
||||||
nextProtos []string
|
nextProtos []string
|
||||||
|
certStatus bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *serverHelloMsg) marshal() []byte {
|
func (m *serverHelloMsg) marshal() []byte {
|
||||||
@ -238,6 +257,9 @@ func (m *serverHelloMsg) marshal() []byte {
|
|||||||
nextProtoLen += len(m.nextProtos)
|
nextProtoLen += len(m.nextProtos)
|
||||||
extensionsLength += nextProtoLen
|
extensionsLength += nextProtoLen
|
||||||
}
|
}
|
||||||
|
if m.certStatus {
|
||||||
|
numExtensions++
|
||||||
|
}
|
||||||
if numExtensions > 0 {
|
if numExtensions > 0 {
|
||||||
extensionsLength += 4 * numExtensions
|
extensionsLength += 4 * numExtensions
|
||||||
length += 2 + extensionsLength
|
length += 2 + extensionsLength
|
||||||
@ -281,6 +303,11 @@ func (m *serverHelloMsg) marshal() []byte {
|
|||||||
z = z[1+l:]
|
z = z[1+l:]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if m.certStatus {
|
||||||
|
z[0] = byte(extensionStatusRequest >> 8)
|
||||||
|
z[1] = byte(extensionStatusRequest)
|
||||||
|
z = z[4:]
|
||||||
|
}
|
||||||
|
|
||||||
m.raw = x
|
m.raw = x
|
||||||
|
|
||||||
@ -322,6 +349,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
|
|||||||
|
|
||||||
m.nextProtoNeg = false
|
m.nextProtoNeg = false
|
||||||
m.nextProtos = nil
|
m.nextProtos = nil
|
||||||
|
m.certStatus = false
|
||||||
|
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
// ServerHello is optionally followed by extension data
|
// ServerHello is optionally followed by extension data
|
||||||
@ -361,6 +389,11 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
|
|||||||
m.nextProtos = append(m.nextProtos, string(d[0:l]))
|
m.nextProtos = append(m.nextProtos, string(d[0:l]))
|
||||||
d = d[l:]
|
d = d[l:]
|
||||||
}
|
}
|
||||||
|
case extensionStatusRequest:
|
||||||
|
if length > 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
m.certStatus = true
|
||||||
}
|
}
|
||||||
data = data[length:]
|
data = data[length:]
|
||||||
}
|
}
|
||||||
@ -445,6 +478,61 @@ func (m *certificateMsg) unmarshal(data []byte) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type certificateStatusMsg struct {
|
||||||
|
raw []byte
|
||||||
|
statusType uint8
|
||||||
|
response []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *certificateStatusMsg) marshal() []byte {
|
||||||
|
if m.raw != nil {
|
||||||
|
return m.raw
|
||||||
|
}
|
||||||
|
|
||||||
|
var x []byte
|
||||||
|
if m.statusType == statusTypeOCSP {
|
||||||
|
x = make([]byte, 4+4+len(m.response))
|
||||||
|
x[0] = typeCertificateStatus
|
||||||
|
l := len(m.response) + 4
|
||||||
|
x[1] = byte(l >> 16)
|
||||||
|
x[2] = byte(l >> 8)
|
||||||
|
x[3] = byte(l)
|
||||||
|
x[4] = statusTypeOCSP
|
||||||
|
|
||||||
|
l -= 4
|
||||||
|
x[5] = byte(l >> 16)
|
||||||
|
x[6] = byte(l >> 8)
|
||||||
|
x[7] = byte(l)
|
||||||
|
copy(x[8:], m.response)
|
||||||
|
} else {
|
||||||
|
x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.raw = x
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *certificateStatusMsg) unmarshal(data []byte) bool {
|
||||||
|
m.raw = data
|
||||||
|
if len(data) < 5 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
m.statusType = data[4]
|
||||||
|
|
||||||
|
m.response = nil
|
||||||
|
if m.statusType == statusTypeOCSP {
|
||||||
|
if len(data) < 8 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
|
||||||
|
if uint32(len(data)) != 4+4+respLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
m.response = data[8:]
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
type serverHelloDoneMsg struct{}
|
type serverHelloDoneMsg struct{}
|
||||||
|
|
||||||
func (m *serverHelloDoneMsg) marshal() []byte {
|
func (m *serverHelloDoneMsg) marshal() []byte {
|
||||||
|
@ -16,6 +16,7 @@ var tests = []interface{}{
|
|||||||
&serverHelloMsg{},
|
&serverHelloMsg{},
|
||||||
|
|
||||||
&certificateMsg{},
|
&certificateMsg{},
|
||||||
|
&certificateStatusMsg{},
|
||||||
&clientKeyExchangeMsg{},
|
&clientKeyExchangeMsg{},
|
||||||
&finishedMsg{},
|
&finishedMsg{},
|
||||||
&nextProtoMsg{},
|
&nextProtoMsg{},
|
||||||
@ -111,6 +112,7 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
|||||||
if rand.Intn(10) > 5 {
|
if rand.Intn(10) > 5 {
|
||||||
m.serverName = randomString(rand.Intn(255), rand)
|
m.serverName = randomString(rand.Intn(255), rand)
|
||||||
}
|
}
|
||||||
|
m.ocspStapling = rand.Intn(10) > 5
|
||||||
|
|
||||||
return reflect.NewValue(m)
|
return reflect.NewValue(m)
|
||||||
}
|
}
|
||||||
@ -146,6 +148,17 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
|||||||
return reflect.NewValue(m)
|
return reflect.NewValue(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||||
|
m := &certificateStatusMsg{}
|
||||||
|
if rand.Intn(10) > 5 {
|
||||||
|
m.statusType = statusTypeOCSP
|
||||||
|
m.response = randomBytes(rand.Intn(10)+1, rand)
|
||||||
|
} else {
|
||||||
|
m.statusType = 42
|
||||||
|
}
|
||||||
|
return reflect.NewValue(m)
|
||||||
|
}
|
||||||
|
|
||||||
func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||||
m := &clientKeyExchangeMsg{}
|
m := &clientKeyExchangeMsg{}
|
||||||
m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
|
m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
|
||||||
|
@ -71,13 +71,13 @@ func TestRejectBadProtocolVersion(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNoSuiteOverlap(t *testing.T) {
|
func TestNoSuiteOverlap(t *testing.T) {
|
||||||
clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{0xff00}, []uint8{0}, false, ""}
|
clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{0xff00}, []uint8{0}, false, "", false}
|
||||||
testClientHelloFailure(t, clientHello, alertHandshakeFailure)
|
testClientHelloFailure(t, clientHello, alertHandshakeFailure)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNoCompressionOverlap(t *testing.T) {
|
func TestNoCompressionOverlap(t *testing.T) {
|
||||||
clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, ""}
|
clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, "", false}
|
||||||
testClientHelloFailure(t, clientHello, alertHandshakeFailure)
|
testClientHelloFailure(t, clientHello, alertHandshakeFailure)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user