diff --git a/common.go b/common.go index 56c22cf..7c6940a 100644 --- a/common.go +++ b/common.go @@ -38,6 +38,7 @@ const ( typeClientHello uint8 = 1 typeServerHello uint8 = 2 typeCertificate uint8 = 11 + typeCertificateStatus uint8 = 22 typeServerHelloDone uint8 = 14 typeClientKeyExchange uint8 = 16 typeFinished uint8 = 20 @@ -45,25 +46,30 @@ const ( ) // TLS cipher suites. -var ( +const ( TLS_RSA_WITH_RC4_128_SHA uint16 = 5 ) // TLS compression types. -var ( +const ( compressionNone uint8 = 0 ) // TLS extension numbers var ( - extensionServerName uint16 = 0 - extensionNextProtoNeg uint16 = 13172 // not IANA assigned + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionNextProtoNeg uint16 = 13172 // not IANA assigned +) + +// TLS CertificateStatusType (RFC 3546) +const ( + statusTypeOCSP uint8 = 1 ) type ConnectionState struct { HandshakeComplete bool - CipherSuite string - Error alert + CipherSuite uint16 NegotiatedProtocol string } diff --git a/conn.go b/conn.go index 0798e26..aa224e4 100644 --- a/conn.go +++ b/conn.go @@ -26,6 +26,7 @@ type Conn struct { config *Config // configuration passed to constructor handshakeComplete bool cipherSuite uint16 + ocspResponse []byte // stapled OCSP response clientProtocol string @@ -531,6 +532,8 @@ func (c *Conn) readHandshake() (interface{}, os.Error) { m = new(serverHelloMsg) case typeCertificate: m = new(certificateMsg) + case typeCertificateStatus: + m = new(certificateStatusMsg) case typeServerHelloDone: m = new(serverHelloDoneMsg) case typeClientKeyExchange: @@ -625,11 +628,26 @@ func (c *Conn) Handshake() os.Error { return c.serverHandshake() } -// If c is a TLS server, ClientConnection returns the protocol -// requested by the client during the TLS handshake. -// Handshake must have been called already. -func (c *Conn) ClientConnection() string { +// ConnectionState returns basic TLS details about the connection. +func (c *Conn) ConnectionState() ConnectionState { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() - return c.clientProtocol + + var state ConnectionState + state.HandshakeComplete = c.handshakeComplete + if c.handshakeComplete { + state.NegotiatedProtocol = c.clientProtocol + state.CipherSuite = c.cipherSuite + } + + return state +} + +// OCSPResponse returns the stapled OCSP response from the TLS server, if +// any. (Only valid for client connections.) +func (c *Conn) OCSPResponse() []byte { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + return c.ocspResponse } diff --git a/handshake_client.go b/handshake_client.go index dd30098..b3b5973 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -18,21 +18,24 @@ import ( func (c *Conn) clientHandshake() os.Error { finishedHash := newFinishedHash() - config := defaultConfig() + if c.config == nil { + c.config = defaultConfig() + } hello := &clientHelloMsg{ vers: maxVersion, cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, compressionMethods: []uint8{compressionNone}, random: make([]byte, 32), + ocspStapling: true, } - t := uint32(config.Time()) + t := uint32(c.config.Time()) hello.random[0] = byte(t >> 24) hello.random[1] = byte(t >> 16) hello.random[2] = byte(t >> 8) hello.random[3] = byte(t) - _, err := io.ReadFull(config.Rand, hello.random[4:]) + _, err := io.ReadFull(c.config.Rand, hello.random[4:]) if err != nil { return c.sendAlert(alertInternalError) } @@ -89,8 +92,8 @@ func (c *Conn) clientHandshake() os.Error { } // TODO(rsc): Find certificates for OS X 10.6. - if false && config.RootCAs != nil { - root := config.RootCAs.FindParent(certs[len(certs)-1]) + if false && c.config.RootCAs != nil { + root := c.config.RootCAs.FindParent(certs[len(certs)-1]) if root == nil { return c.sendAlert(alertBadCertificate) } @@ -104,6 +107,22 @@ func (c *Conn) clientHandshake() os.Error { return c.sendAlert(alertUnsupportedCertificate) } + if serverHello.certStatus { + msg, err = c.readHandshake() + if err != nil { + return err + } + cs, ok := msg.(*certificateStatusMsg) + if !ok { + return c.sendAlert(alertUnexpectedMessage) + } + finishedHash.Write(cs.marshal()) + + if cs.statusType == statusTypeOCSP { + c.ocspResponse = cs.response + } + } + msg, err = c.readHandshake() if err != nil { return err @@ -118,12 +137,12 @@ func (c *Conn) clientHandshake() os.Error { preMasterSecret := make([]byte, 48) preMasterSecret[0] = byte(hello.vers >> 8) preMasterSecret[1] = byte(hello.vers) - _, err = io.ReadFull(config.Rand, preMasterSecret[2:]) + _, err = io.ReadFull(c.config.Rand, preMasterSecret[2:]) if err != nil { return c.sendAlert(alertInternalError) } - ckx.ciphertext, err = rsa.EncryptPKCS1v15(config.Rand, pub, preMasterSecret) + ckx.ciphertext, err = rsa.EncryptPKCS1v15(c.config.Rand, pub, preMasterSecret) if err != nil { return c.sendAlert(alertInternalError) } diff --git a/handshake_messages.go b/handshake_messages.go index f0a48c8..13c05fe 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -13,6 +13,7 @@ type clientHelloMsg struct { compressionMethods []uint8 nextProtoNeg bool serverName string + ocspStapling bool } func (m *clientHelloMsg) marshal() []byte { @@ -26,6 +27,10 @@ func (m *clientHelloMsg) marshal() []byte { if m.nextProtoNeg { numExtensions++ } + if m.ocspStapling { + extensionsLength += 1 + 2 + 2 + numExtensions++ + } if len(m.serverName) > 0 { extensionsLength += 5 + len(m.serverName) numExtensions++ @@ -101,6 +106,16 @@ func (m *clientHelloMsg) marshal() []byte { copy(z[5:], []byte(m.serverName)) z = z[l:] } + if m.ocspStapling { + // RFC 4366, section 3.6 + z[0] = byte(extensionStatusRequest >> 8) + z[1] = byte(extensionStatusRequest) + z[2] = 0 + z[3] = 5 + z[4] = 1 // OCSP type + // Two zero valued uint16s for the two lengths. + z = z[9:] + } m.raw = x @@ -148,6 +163,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.nextProtoNeg = false m.serverName = "" + m.ocspStapling = false if len(data) == 0 { // ClientHello is optionally followed by extension data @@ -202,6 +218,8 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return false } m.nextProtoNeg = true + case extensionStatusRequest: + m.ocspStapling = length > 0 && data[0] == statusTypeOCSP } data = data[length:] } @@ -218,6 +236,7 @@ type serverHelloMsg struct { compressionMethod uint8 nextProtoNeg bool nextProtos []string + certStatus bool } func (m *serverHelloMsg) marshal() []byte { @@ -238,6 +257,9 @@ func (m *serverHelloMsg) marshal() []byte { nextProtoLen += len(m.nextProtos) extensionsLength += nextProtoLen } + if m.certStatus { + numExtensions++ + } if numExtensions > 0 { extensionsLength += 4 * numExtensions length += 2 + extensionsLength @@ -281,6 +303,11 @@ func (m *serverHelloMsg) marshal() []byte { z = z[1+l:] } } + if m.certStatus { + z[0] = byte(extensionStatusRequest >> 8) + z[1] = byte(extensionStatusRequest) + z = z[4:] + } m.raw = x @@ -322,6 +349,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { m.nextProtoNeg = false m.nextProtos = nil + m.certStatus = false if len(data) == 0 { // ServerHello is optionally followed by extension data @@ -361,6 +389,11 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { m.nextProtos = append(m.nextProtos, string(d[0:l])) d = d[l:] } + case extensionStatusRequest: + if length > 0 { + return false + } + m.certStatus = true } data = data[length:] } @@ -445,6 +478,61 @@ func (m *certificateMsg) unmarshal(data []byte) bool { return true } +type certificateStatusMsg struct { + raw []byte + statusType uint8 + response []byte +} + +func (m *certificateStatusMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var x []byte + if m.statusType == statusTypeOCSP { + x = make([]byte, 4+4+len(m.response)) + x[0] = typeCertificateStatus + l := len(m.response) + 4 + x[1] = byte(l >> 16) + x[2] = byte(l >> 8) + x[3] = byte(l) + x[4] = statusTypeOCSP + + l -= 4 + x[5] = byte(l >> 16) + x[6] = byte(l >> 8) + x[7] = byte(l) + copy(x[8:], m.response) + } else { + x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType} + } + + m.raw = x + return x +} + +func (m *certificateStatusMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 5 { + return false + } + m.statusType = data[4] + + m.response = nil + if m.statusType == statusTypeOCSP { + if len(data) < 8 { + return false + } + respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7]) + if uint32(len(data)) != 4+4+respLen { + return false + } + m.response = data[8:] + } + return true +} + type serverHelloDoneMsg struct{} func (m *serverHelloDoneMsg) marshal() []byte { diff --git a/handshake_messages_test.go b/handshake_messages_test.go index 2e422cc..274e16f 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -16,6 +16,7 @@ var tests = []interface{}{ &serverHelloMsg{}, &certificateMsg{}, + &certificateStatusMsg{}, &clientKeyExchangeMsg{}, &finishedMsg{}, &nextProtoMsg{}, @@ -111,6 +112,7 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { if rand.Intn(10) > 5 { m.serverName = randomString(rand.Intn(255), rand) } + m.ocspStapling = rand.Intn(10) > 5 return reflect.NewValue(m) } @@ -146,6 +148,17 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { return reflect.NewValue(m) } +func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateStatusMsg{} + if rand.Intn(10) > 5 { + m.statusType = statusTypeOCSP + m.response = randomBytes(rand.Intn(10)+1, rand) + } else { + m.statusType = 42 + } + return reflect.NewValue(m) +} + func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &clientKeyExchangeMsg{} m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) diff --git a/handshake_server_test.go b/handshake_server_test.go index d31dc49..c1a72fc 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -71,13 +71,13 @@ func TestRejectBadProtocolVersion(t *testing.T) { } func TestNoSuiteOverlap(t *testing.T) { - clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{0xff00}, []uint8{0}, false, ""} + clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{0xff00}, []uint8{0}, false, "", false} testClientHelloFailure(t, clientHello, alertHandshakeFailure) } func TestNoCompressionOverlap(t *testing.T) { - clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, ""} + clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, "", false} testClientHelloFailure(t, clientHello, alertHandshakeFailure) }