diff --git a/Makefile b/Makefile index dd3df29..0c2d339 100644 --- a/Makefile +++ b/Makefile @@ -8,12 +8,14 @@ TARG=crypto/tls GOFILES=\ alert.go\ common.go\ + handshake_client.go\ handshake_messages.go\ handshake_server.go\ prf.go\ record_process.go\ record_read.go\ record_write.go\ + ca_set.go\ tls.go\ include $(GOROOT)/src/Make.pkg diff --git a/ca_set.go b/ca_set.go new file mode 100644 index 0000000..e8cddd6 --- /dev/null +++ b/ca_set.go @@ -0,0 +1,75 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto/x509"; + "encoding/pem"; +) + +// A CASet is a set of certificates. +type CASet struct { + bySubjectKeyId map[string]*x509.Certificate; + byName map[string]*x509.Certificate; +} + +func NewCASet() *CASet { + return &CASet{ + make(map[string]*x509.Certificate), + make(map[string]*x509.Certificate), + } +} + +func nameToKey(name *x509.Name) string { + return name.Country + "/" + name.OrganizationalUnit + "/" + name.OrganizationalUnit + "/" + name.CommonName +} + +// FindParent attempts to find the certificate in s which signs the given +// certificate. If no such certificate can be found, it returns nil. +func (s *CASet) FindParent(cert *x509.Certificate) (parent *x509.Certificate) { + var ok bool; + + if len(cert.AuthorityKeyId) > 0 { + parent, ok = s.bySubjectKeyId[string(cert.AuthorityKeyId)] + } else { + parent, ok = s.byName[nameToKey(&cert.Issuer)] + } + + if !ok { + return nil + } + return parent; +} + +// SetFromPEM attempts to parse a series of PEM encoded root certificates. It +// appends any certificates found to s and returns true if any certificates +// were successfully parsed. On many Linux systems, /etc/ssl/cert.pem will +// contains the system wide set of root CAs in a format suitable for this +// function. +func (s *CASet) SetFromPEM(pemCerts []byte) (ok bool) { + for len(pemCerts) > 0 { + var block *pem.Block; + block, pemCerts = pem.Decode(pemCerts); + if block == nil { + break + } + if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { + continue + } + + cert, err := x509.ParseCertificate(block.Bytes); + if err != nil { + continue + } + + if len(cert.SubjectKeyId) > 0 { + s.bySubjectKeyId[string(cert.SubjectKeyId)] = cert + } + s.byName[nameToKey(&cert.Subject)] = cert; + ok = true; + } + + return; +} diff --git a/common.go b/common.go index e295cd4..e1318a8 100644 --- a/common.go +++ b/common.go @@ -17,6 +17,9 @@ const ( maxTLSCiphertext = 16384 + 2048; // maxHandshakeMsg is the largest single handshake message that we'll buffer. maxHandshakeMsg = 65536; + // defaultMajor and defaultMinor are the maximum TLS version that we support. + defaultMajor = 3; + defaultMinor = 2; ) @@ -64,6 +67,7 @@ type Config struct { // Time returns the current time as the number of seconds since the epoch. Time func() int64; Certificates []Certificate; + RootCAs *CASet; } type Certificate struct { diff --git a/handshake_client.go b/handshake_client.go new file mode 100644 index 0000000..db9bb8c --- /dev/null +++ b/handshake_client.go @@ -0,0 +1,225 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto/hmac"; + "crypto/rc4"; + "crypto/rsa"; + "crypto/sha1"; + "crypto/subtle"; + "crypto/x509"; + "io"; +) + +// A serverHandshake performs the server side of the TLS 1.1 handshake protocol. +type clientHandshake struct { + writeChan chan<- interface{}; + controlChan chan<- interface{}; + msgChan <-chan interface{}; + config *Config; +} + +func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) { + h.writeChan = writeChan; + h.controlChan = controlChan; + h.msgChan = msgChan; + h.config = config; + + defer close(writeChan); + defer close(controlChan); + + finishedHash := newFinishedHash(); + + hello := &clientHelloMsg{ + major: defaultMajor, + minor: defaultMinor, + cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + compressionMethods: []uint8{compressionNone}, + random: make([]byte, 32), + }; + + currentTime := uint32(config.Time()); + hello.random[0] = byte(currentTime >> 24); + hello.random[1] = byte(currentTime >> 16); + hello.random[2] = byte(currentTime >> 8); + hello.random[3] = byte(currentTime); + _, err := io.ReadFull(config.Rand, hello.random[4:len(hello.random)]); + if err != nil { + h.error(alertInternalError); + return; + } + + finishedHash.Write(hello.marshal()); + writeChan <- writerSetVersion{defaultMajor, defaultMinor}; + writeChan <- hello; + + serverHello, ok := h.readHandshakeMsg().(*serverHelloMsg); + if !ok { + h.error(alertUnexpectedMessage); + return; + } + finishedHash.Write(serverHello.marshal()); + major, minor, ok := mutualVersion(serverHello.major, serverHello.minor); + if !ok { + h.error(alertProtocolVersion); + return; + } + + writeChan <- writerSetVersion{major, minor}; + + if serverHello.cipherSuite != TLS_RSA_WITH_RC4_128_SHA || + serverHello.compressionMethod != compressionNone { + h.error(alertUnexpectedMessage); + return; + } + + certMsg, ok := h.readHandshakeMsg().(*certificateMsg); + if !ok || len(certMsg.certificates) == 0 { + h.error(alertUnexpectedMessage); + return; + } + finishedHash.Write(certMsg.marshal()); + + certs := make([]*x509.Certificate, len(certMsg.certificates)); + for i, asn1Data := range certMsg.certificates { + cert, err := x509.ParseCertificate(asn1Data); + if err != nil { + h.error(alertBadCertificate); + return; + } + certs[i] = cert; + } + + // TODO(agl): do better validation of certs: max path length, name restrictions etc. + for i := 1; i < len(certs); i++ { + if certs[i-1].CheckSignatureFrom(certs[i]) != nil { + h.error(alertBadCertificate); + return; + } + } + + if config.RootCAs != nil { + root := config.RootCAs.FindParent(certs[len(certs)-1]); + if root == nil { + h.error(alertBadCertificate); + return; + } + if certs[len(certs)-1].CheckSignatureFrom(root) != nil { + h.error(alertBadCertificate); + return; + } + } + + pub, ok := certs[0].PublicKey.(*rsa.PublicKey); + if !ok { + h.error(alertUnsupportedCertificate); + return; + } + + shd, ok := h.readHandshakeMsg().(*serverHelloDoneMsg); + if !ok { + h.error(alertUnexpectedMessage); + return; + } + finishedHash.Write(shd.marshal()); + + ckx := new(clientKeyExchangeMsg); + preMasterSecret := make([]byte, 48); + // Note that the version number in the preMasterSecret must be the + // version offered in the ClientHello. + preMasterSecret[0] = defaultMajor; + preMasterSecret[1] = defaultMinor; + _, err = io.ReadFull(config.Rand, preMasterSecret[2:len(preMasterSecret)]); + if err != nil { + h.error(alertInternalError); + return; + } + + ckx.ciphertext, err = rsa.EncryptPKCS1v15(config.Rand, pub, preMasterSecret); + if err != nil { + h.error(alertInternalError); + return; + } + + finishedHash.Write(ckx.marshal()); + writeChan <- ckx; + + suite := cipherSuites[0]; + masterSecret, clientMAC, serverMAC, clientKey, serverKey := + keysFromPreMasterSecret11(preMasterSecret, hello.random, serverHello.random, suite.hashLength, suite.cipherKeyLength); + + cipher, _ := rc4.NewCipher(clientKey); + writeChan <- writerChangeCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)}; + + finished := new(finishedMsg); + finished.verifyData = finishedHash.clientSum(masterSecret); + finishedHash.Write(finished.marshal()); + writeChan <- finished; + + // TODO(agl): this is cut-through mode which should probably be an option. + writeChan <- writerEnableApplicationData{}; + + _, ok = h.readHandshakeMsg().(changeCipherSpec); + if !ok { + h.error(alertUnexpectedMessage); + return; + } + + cipher2, _ := rc4.NewCipher(serverKey); + controlChan <- &newCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)}; + + serverFinished, ok := h.readHandshakeMsg().(*finishedMsg); + if !ok { + h.error(alertUnexpectedMessage); + return; + } + + verify := finishedHash.serverSum(masterSecret); + if len(verify) != len(serverFinished.verifyData) || + subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { + h.error(alertHandshakeFailure); + return; + } + + controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0}; + + // This should just block forever. + _ = h.readHandshakeMsg(); + h.error(alertUnexpectedMessage); + return; +} + +func (h *clientHandshake) readHandshakeMsg() interface{} { + v := <-h.msgChan; + if closed(h.msgChan) { + // If the channel closed then the processor received an error + // from the peer and we don't want to echo it back to them. + h.msgChan = nil; + return 0; + } + if _, ok := v.(alert); ok { + // We got an alert from the processor. We forward to the writer + // and shutdown. + h.writeChan <- v; + h.msgChan = nil; + return 0; + } + return v; +} + +func (h *clientHandshake) error(e alertType) { + if h.msgChan != nil { + // If we didn't get an error from the processor, then we need + // to tell it about the error. + go func() { + for _ = range h.msgChan { + } + }(); + h.controlChan <- ConnectionState{false, "", e}; + close(h.controlChan); + h.writeChan <- alert{alertLevelError, e}; + } +} diff --git a/handshake_messages.go b/handshake_messages.go index 87e2e77..65dae87 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -45,7 +45,7 @@ func (m *clientHelloMsg) marshal() []byte { } func (m *clientHelloMsg) unmarshal(data []byte) bool { - if len(data) < 39 { + if len(data) < 43 { return false } m.raw = data; @@ -120,6 +120,30 @@ func (m *serverHelloMsg) marshal() []byte { return x; } +func (m *serverHelloMsg) unmarshal(data []byte) bool { + if len(data) < 42 { + return false + } + m.raw = data; + m.major = data[4]; + m.minor = data[5]; + m.random = data[6:38]; + sessionIdLen := int(data[38]); + if sessionIdLen > 32 || len(data) < 39+sessionIdLen { + return false + } + m.sessionId = data[39 : 39+sessionIdLen]; + data = data[39+sessionIdLen : len(data)]; + if len(data) < 3 { + return false + } + m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]); + m.compressionMethod = data[2]; + + // Trailing data is allowed because extensions may be present. + return true; +} + type certificateMsg struct { raw []byte; certificates [][]byte; @@ -160,6 +184,43 @@ func (m *certificateMsg) marshal() (x []byte) { return; } +func (m *certificateMsg) unmarshal(data []byte) bool { + if len(data) < 7 { + return false + } + + m.raw = data; + certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]); + if uint32(len(data)) != certsLen+7 { + return false + } + + numCerts := 0; + d := data[7:len(data)]; + for certsLen > 0 { + if len(d) < 4 { + return false + } + certLen := uint32(d[0])<<24 | uint32(d[1])<<8 | uint32(d[2]); + if uint32(len(d)) < 3+certLen { + return false + } + d = d[3+certLen : len(d)]; + certsLen -= 3 + certLen; + numCerts++; + } + + m.certificates = make([][]byte, numCerts); + d = data[7:len(data)]; + for i := 0; i < numCerts; i++ { + certLen := uint32(d[0])<<24 | uint32(d[1])<<8 | uint32(d[2]); + m.certificates[i] = d[3 : 3+certLen]; + d = d[3+certLen : len(d)]; + } + + return true; +} + type serverHelloDoneMsg struct{} func (m *serverHelloDoneMsg) marshal() []byte { @@ -168,6 +229,10 @@ func (m *serverHelloDoneMsg) marshal() []byte { return x; } +func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + type clientKeyExchangeMsg struct { raw []byte; ciphertext []byte; diff --git a/handshake_messages_test.go b/handshake_messages_test.go index 5dafc38..c580f65 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -13,6 +13,8 @@ import ( var tests = []interface{}{ &clientHelloMsg{}, + &serverHelloMsg{}, + &certificateMsg{}, &clientKeyExchangeMsg{}, &finishedMsg{}, } @@ -59,6 +61,20 @@ func TestMarshalUnmarshal(t *testing.T) { } } +func TestFuzz(t *testing.T) { + rand := rand.New(rand.NewSource(0)); + for _, iface := range tests { + m := iface.(testMessage); + + for j := 0; j < 1000; j++ { + len := rand.Intn(100); + bytes := randomBytes(len, rand); + // This just looks for crashes due to bounds errors etc. + m.unmarshal(bytes); + } + } +} + func randomBytes(n int, rand *rand.Rand) []byte { r := make([]byte, n); for i := 0; i < n; i++ { @@ -82,9 +98,30 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { return reflect.NewValue(m); } +func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &serverHelloMsg{}; + m.major = uint8(rand.Intn(256)); + m.minor = uint8(rand.Intn(256)); + m.random = randomBytes(32, rand); + m.sessionId = randomBytes(rand.Intn(32), rand); + m.cipherSuite = uint16(rand.Int31()); + m.compressionMethod = uint8(rand.Intn(256)); + return reflect.NewValue(m); +} + +func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateMsg{}; + numCerts := rand.Intn(20); + m.certificates = make([][]byte, numCerts); + for i := 0; i < numCerts; i++ { + m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) + } + return reflect.NewValue(m); +} + func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &clientKeyExchangeMsg{}; - m.ciphertext = randomBytes(rand.Intn(1000), rand); + m.ciphertext = randomBytes(rand.Intn(1000)+1, rand); return reflect.NewValue(m); } diff --git a/handshake_server.go b/handshake_server.go index 0e04c42..2e77603 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -224,12 +224,12 @@ func (h *serverHandshake) error(e alertType) { if h.msgChan != nil { // If we didn't get an error from the processor, then we need // to tell it about the error. - h.controlChan <- ConnectionState{false, "", e}; - close(h.controlChan); go func() { for _ = range h.msgChan { } }(); + h.controlChan <- ConnectionState{false, "", e}; + close(h.controlChan); h.writeChan <- alert{alertLevelError, e}; } } diff --git a/record_process.go b/record_process.go index b7edd9f..e356d67 100644 --- a/record_process.go +++ b/record_process.go @@ -210,7 +210,7 @@ func (p *recordProcessor) processRecord(r *record) { return; } p.recordRead = nil; - p.appData = r.payload; + p.appData = r.payload[0 : len(r.payload)-p.mac.Size()]; p.appDataSend = p.appDataChan; default: p.error(alertUnexpectedMessage); @@ -283,6 +283,12 @@ func parseHandshakeMsg(data []byte) (interface{}, bool) { switch data[0] { case typeClientHello: m = new(clientHelloMsg) + case typeServerHello: + m = new(serverHelloMsg) + case typeCertificate: + m = new(certificateMsg) + case typeServerHelloDone: + m = new(serverHelloDoneMsg) case typeClientKeyExchange: m = new(clientKeyExchangeMsg) default: diff --git a/tls.go b/tls.go index a162487..c5a0f69 100644 --- a/tls.go +++ b/tls.go @@ -112,9 +112,19 @@ func (tls *Conn) GetConnectionState() ConnectionState { return <-replyChan; } +func (tls *Conn) WaitConnectionState() ConnectionState { + replyChan := make(chan ConnectionState); + tls.requestChan <- waitConnectionState{replyChan}; + return <-replyChan; +} + +type handshaker interface { + loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config); +} + // Server establishes a secure connection over the given connection and acts // as a TLS server. -func Server(conn net.Conn, config *Config) *Conn { +func startTLSGoroutines(conn net.Conn, h handshaker, config *Config) *Conn { tls := new(Conn); tls.Conn = conn; @@ -134,11 +144,19 @@ func Server(conn net.Conn, config *Config) *Conn { go new(recordWriter).loop(conn, writeChan, handshakeWriterChan); go recordReader(readerProcessorChan, conn); go new(recordProcessor).loop(readChan, requestChan, handshakeProcessorChan, readerProcessorChan, processorHandshakeChan); - go new(serverHandshake).loop(handshakeWriterChan, handshakeProcessorChan, processorHandshakeChan, config); + go h.loop(handshakeWriterChan, handshakeProcessorChan, processorHandshakeChan, config); return tls; } +func Server(conn net.Conn, config *Config) *Conn { + return startTLSGoroutines(conn, new(serverHandshake), config) +} + +func Client(conn net.Conn, config *Config) *Conn { + return startTLSGoroutines(conn, new(clientHandshake), config) +} + type Listener struct { listener net.Listener; config *Config;