From ebe78b393dcb8dd8f75fd81983f9e0a372994b96 Mon Sep 17 00:00:00 2001 From: Adam Langley Date: Thu, 5 Nov 2009 15:44:32 -0800 Subject: [PATCH] crypto/tls (part 3) (With hindsight, record_process might have been designed wrong, but it works for now. It'll get redrawn when client support is added.) R=rsc CC=r http://go/go-review/1018032 --- handshake_server.go | 232 +++++++++++++++++++++++++++++++ handshake_server_test.go | 210 ++++++++++++++++++++++++++++ record_process.go | 292 +++++++++++++++++++++++++++++++++++++++ record_process_test.go | 137 ++++++++++++++++++ 4 files changed, 871 insertions(+) create mode 100644 handshake_server.go create mode 100644 handshake_server_test.go create mode 100644 record_process.go create mode 100644 record_process_test.go diff --git a/handshake_server.go b/handshake_server.go new file mode 100644 index 0000000..7303189 --- /dev/null +++ b/handshake_server.go @@ -0,0 +1,232 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +// The handshake goroutine reads handshake messages from the record processor +// and outputs messages to be written on another channel. It updates the record +// processor with the state of the connection via the control channel. In the +// case of handshake messages that need synchronous processing (because they +// affect the handling of the next record) the record processor knows about +// them and either waits for a control message (Finished) or includes a reply +// channel in the message (ChangeCipherSpec). + +import ( + "crypto/hmac"; + "crypto/rc4"; + "crypto/rsa"; + "crypto/sha1"; + "crypto/subtle"; + "io"; +) + +type cipherSuite struct { + id uint16; // The number of this suite on the wire. + hashLength, cipherKeyLength int; + // TODO(agl): need a method to create the cipher and hash interfaces. +} + +var cipherSuites = []cipherSuite{ + cipherSuite{TLS_RSA_WITH_RC4_128_SHA, 20, 16}, +} + +// A serverHandshake performs the server side of the TLS 1.1 handshake protocol. +type serverHandshake struct { + writeChan chan<- interface{}; + controlChan chan<- interface{}; + msgChan <-chan interface{}; + config *Config; +} + +func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) { + h.writeChan = writeChan; + h.controlChan = controlChan; + h.msgChan = msgChan; + h.config = config; + + defer close(writeChan); + defer close(controlChan); + + clientHello, ok := h.readHandshakeMsg().(*clientHelloMsg); + if !ok { + h.error(alertUnexpectedMessage); + return; + } + major, minor, ok := mutualVersion(clientHello.major, clientHello.minor); + if !ok { + h.error(alertProtocolVersion); + return; + } + + finishedHash := newFinishedHash(); + finishedHash.Write(clientHello.marshal()); + + hello := new(serverHelloMsg); + + // We only support a single ciphersuite so we look for it in the list + // of client supported suites. + // + // TODO(agl): Add additional cipher suites. + var suite *cipherSuite; + + for _, id := range clientHello.cipherSuites { + for _, supported := range cipherSuites { + if supported.id == id { + suite = &supported; + break; + } + } + } + + foundCompression := false; + // We only support null compression, so check that the client offered it. + for _, compression := range clientHello.compressionMethods { + if compression == compressionNone { + foundCompression = true; + break; + } + } + + if suite == nil || !foundCompression { + h.error(alertHandshakeFailure); + return; + } + + hello.major = major; + hello.minor = minor; + hello.cipherSuite = suite.id; + currentTime := uint32(config.Time()); + hello.random = make([]byte, 32); + hello.random[0] = byte(currentTime >> 24); + hello.random[1] = byte(currentTime >> 16); + hello.random[2] = byte(currentTime >> 8); + hello.random[3] = byte(currentTime); + _, err := io.ReadFull(config.Rand, hello.random[4:len(hello.random)]); + if err != nil { + h.error(alertInternalError); + return; + } + hello.compressionMethod = compressionNone; + + finishedHash.Write(hello.marshal()); + writeChan <- writerSetVersion{major, minor}; + writeChan <- hello; + + if len(config.Certificates) == 0 { + h.error(alertInternalError); + return; + } + + certMsg := new(certificateMsg); + certMsg.certificates = config.Certificates[0].Certificate; + finishedHash.Write(certMsg.marshal()); + writeChan <- certMsg; + + helloDone := new(serverHelloDoneMsg); + finishedHash.Write(helloDone.marshal()); + writeChan <- helloDone; + + ckx, ok := h.readHandshakeMsg().(*clientKeyExchangeMsg); + if !ok { + h.error(alertUnexpectedMessage); + return; + } + finishedHash.Write(ckx.marshal()); + + preMasterSecret := make([]byte, 48); + _, err = io.ReadFull(config.Rand, preMasterSecret[2:len(preMasterSecret)]); + if err != nil { + h.error(alertInternalError); + return; + } + + err = rsa.DecryptPKCS1v15SessionKey(config.Rand, config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret); + if err != nil { + h.error(alertHandshakeFailure); + return; + } + // We don't check the version number in the premaster secret. For one, + // by checking it, we would leak information about the validity of the + // encrypted pre-master secret. Secondly, it provides only a small + // benefit against a downgrade attack and some implementations send the + // wrong version anyway. See the discussion at the end of section + // 7.4.7.1 of RFC 4346. + + masterSecret, clientMAC, serverMAC, clientKey, serverKey := + keysFromPreMasterSecret11(preMasterSecret, clientHello.random, hello.random, suite.hashLength, suite.cipherKeyLength); + + _, ok = h.readHandshakeMsg().(changeCipherSpec); + if !ok { + h.error(alertUnexpectedMessage); + return; + } + + cipher, _ := rc4.NewCipher(clientKey); + controlChan <- &newCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)}; + + clientFinished, ok := h.readHandshakeMsg().(*finishedMsg); + if !ok { + h.error(alertUnexpectedMessage); + return; + } + + verify := finishedHash.clientSum(masterSecret); + if len(verify) != len(clientFinished.verifyData) || + subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { + h.error(alertHandshakeFailure); + return; + } + + controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0}; + + finishedHash.Write(clientFinished.marshal()); + + cipher2, _ := rc4.NewCipher(serverKey); + writeChan <- writerChangeCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)}; + + finished := new(finishedMsg); + finished.verifyData = finishedHash.serverSum(masterSecret); + writeChan <- finished; + + writeChan <- writerEnableApplicationData{}; + + for { + _, ok := h.readHandshakeMsg().(*clientHelloMsg); + if !ok { + h.error(alertUnexpectedMessage); + return; + } + // We reject all renegotication requests. + writeChan <- alert{alertLevelWarning, alertNoRenegotiation}; + } +} + +func (h *serverHandshake) readHandshakeMsg() interface{} { + v := <-h.msgChan; + if closed(h.msgChan) { + // If the channel closed then the processor received an error + // from the peer and we don't want to echo it back to them. + h.msgChan = nil; + return 0; + } + if _, ok := v.(alert); ok { + // We got an alert from the processor. We forward to the writer + // and shutdown. + h.writeChan <- v; + h.msgChan = nil; + return 0; + } + return v; +} + +func (h *serverHandshake) error(e alertType) { + if h.msgChan != nil { + // If we didn't get an error from the processor, then we need + // to tell it about the error. + h.controlChan <- ConnectionState{false, "", e}; + close(h.controlChan); + go func() { for _ = range h.msgChan {} }(); + h.writeChan <- alert{alertLevelError, e}; + } +} diff --git a/handshake_server_test.go b/handshake_server_test.go new file mode 100644 index 0000000..4e7507e --- /dev/null +++ b/handshake_server_test.go @@ -0,0 +1,210 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes"; + "big"; + "crypto/rsa"; + "os"; + "testing"; + "testing/script"; +) + +type zeroSource struct{} + +func (zeroSource) Read(b []byte) (n int, err os.Error) { + for i := range b { + b[i] = 0; + } + + return len(b), nil; +} + +var testConfig *Config + +func init() { + testConfig = new(Config); + testConfig.Time = func() int64 { return 0 }; + testConfig.Rand = zeroSource{}; + testConfig.Certificates = make([]Certificate, 1); + testConfig.Certificates[0].Certificate = [][]byte{testCertificate}; + testConfig.Certificates[0].PrivateKey = testPrivateKey; +} + +func setupServerHandshake() (writeChan chan interface{}, controlChan chan interface{}, msgChan chan interface{}) { + sh := new(serverHandshake); + writeChan = make(chan interface{}); + controlChan = make(chan interface{}); + msgChan = make(chan interface{}); + + go sh.loop(writeChan, controlChan, msgChan, testConfig); + return; +} + +func testClientHelloFailure(t *testing.T, clientHello interface{}, expectedAlert alertType) { + writeChan, controlChan, msgChan := setupServerHandshake(); + defer close(msgChan); + + send := script.NewEvent("send", nil, script.Send{msgChan, clientHello}); + recvAlert := script.NewEvent("recv alert", []*script.Event{send}, script.Recv{writeChan, alert{alertLevelError, expectedAlert}}); + close1 := script.NewEvent("msgChan close", []*script.Event{recvAlert}, script.Closed{writeChan}); + recvState := script.NewEvent("recv state", []*script.Event{send}, script.Recv{controlChan, ConnectionState{false, "", expectedAlert}}); + close2 := script.NewEvent("controlChan close", []*script.Event{recvState}, script.Closed{controlChan}); + + err := script.Perform(0, []*script.Event{send, recvAlert, close1, recvState, close2}); + if err != nil { + t.Errorf("Got error: %s", err); + } +} + +func TestSimpleError(t *testing.T) { + testClientHelloFailure(t, &serverHelloDoneMsg{}, alertUnexpectedMessage); +} + +var badProtocolVersions = []uint8{0, 0, 0, 5, 1, 0, 1, 5, 2, 0, 2, 5, 3, 0} + +func TestRejectBadProtocolVersion(t *testing.T) { + clientHello := new(clientHelloMsg); + + for i := 0; i < len(badProtocolVersions); i += 2 { + clientHello.major = badProtocolVersions[i]; + clientHello.minor = badProtocolVersions[i+1]; + + testClientHelloFailure(t, clientHello, alertProtocolVersion); + } +} + +func TestNoSuiteOverlap(t *testing.T) { + clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{0xff00}, []uint8{0}}; + testClientHelloFailure(t, clientHello, alertHandshakeFailure); + +} + +func TestNoCompressionOverlap(t *testing.T) { + clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}}; + testClientHelloFailure(t, clientHello, alertHandshakeFailure); +} + +func matchServerHello(v interface{}) bool { + serverHello, ok := v.(*serverHelloMsg); + if !ok { + return false; + } + return serverHello.major == 3 && + serverHello.minor == 2 && + serverHello.cipherSuite == TLS_RSA_WITH_RC4_128_SHA && + serverHello.compressionMethod == compressionNone; +} + +func TestAlertForwarding(t *testing.T) { + writeChan, controlChan, msgChan := setupServerHandshake(); + defer close(msgChan); + + a := alert{alertLevelError, alertNoRenegotiation}; + sendAlert := script.NewEvent("send alert", nil, script.Send{msgChan, a}); + recvAlert := script.NewEvent("recv alert", []*script.Event{sendAlert}, script.Recv{writeChan, a}); + closeWriter := script.NewEvent("close writer", []*script.Event{recvAlert}, script.Closed{writeChan}); + closeControl := script.NewEvent("close control", []*script.Event{recvAlert}, script.Closed{controlChan}); + + err := script.Perform(0, []*script.Event{sendAlert, recvAlert, closeWriter, closeControl}); + if err != nil { + t.Errorf("Got error: %s", err); + } +} + +func TestClose(t *testing.T) { + writeChan, controlChan, msgChan := setupServerHandshake(); + + close := script.NewEvent("close", nil, script.Close{msgChan}); + closed1 := script.NewEvent("closed1", []*script.Event{close}, script.Closed{writeChan}); + closed2 := script.NewEvent("closed2", []*script.Event{close}, script.Closed{controlChan}); + + err := script.Perform(0, []*script.Event{close, closed1, closed2}); + if err != nil { + t.Errorf("Got error: %s", err); + } +} + +func matchCertificate(v interface{}) bool { + cert, ok := v.(*certificateMsg); + if !ok { + return false; + } + return len(cert.certificates) == 1 && + bytes.Compare(cert.certificates[0], testCertificate) == 0; +} + +func matchSetCipher(v interface{}) bool { + _, ok := v.(writerChangeCipherSpec); + return ok; +} + +func matchDone(v interface{}) bool { + _, ok := v.(*serverHelloDoneMsg); + return ok; +} + +func matchFinished(v interface{}) bool { + finished, ok := v.(*finishedMsg); + if !ok { + return false; + } + return bytes.Compare(finished.verifyData, fromHex("29122ae11453e631487b02ed")) == 0; +} + +func matchNewCipherSpec(v interface{}) bool { + _, ok := v.(*newCipherSpec); + return ok; +} + +func TestFullHandshake(t *testing.T) { + writeChan, controlChan, msgChan := setupServerHandshake(); + defer close(msgChan); + + // The values for this test were obtained from running `gnutls-cli --insecure --debug 9` + clientHello := &clientHelloMsg{fromHex("0100007603024aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b310000340033004500390088001600320044003800870013006600900091008f008e002f004100350084000a00050004008c008d008b008a01000019000900030200010000000e000c0000093132372e302e302e31"), 3, 2, fromHex("4aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b31"), nil, []uint16{0x33, 0x45, 0x39, 0x88, 0x16, 0x32, 0x44, 0x38, 0x87, 0x13, 0x66, 0x90, 0x91, 0x8f, 0x8e, 0x2f, 0x41, 0x35, 0x84, 0xa, 0x5, 0x4, 0x8c, 0x8d, 0x8b, 0x8a}, []uint8{0x0}}; + + sendHello := script.NewEvent("send hello", nil, script.Send{msgChan, clientHello}); + setVersion := script.NewEvent("set version", []*script.Event{sendHello}, script.Recv{writeChan, writerSetVersion{3, 2}}); + recvHello := script.NewEvent("recv hello", []*script.Event{setVersion}, script.RecvMatch{writeChan, matchServerHello}); + recvCert := script.NewEvent("recv cert", []*script.Event{recvHello}, script.RecvMatch{writeChan, matchCertificate}); + recvDone := script.NewEvent("recv done", []*script.Event{recvCert}, script.RecvMatch{writeChan, matchDone}); + + ckx := &clientKeyExchangeMsg{nil, fromHex("872e1fee5f37dd86f3215938ac8de20b302b90074e9fb93097e6b7d1286d0f45abf2daf179deb618bb3c70ed0afee6ee24476ee4649e5a23358143c0f1d9c251")}; + sendCKX := script.NewEvent("send ckx", []*script.Event{recvDone}, script.Send{msgChan, ckx}); + + sendCCS := script.NewEvent("send ccs", []*script.Event{sendCKX}, script.Send{msgChan, changeCipherSpec{}}); + recvNCS := script.NewEvent("recv done", []*script.Event{sendCCS}, script.RecvMatch{controlChan, matchNewCipherSpec}); + + finished := &finishedMsg{nil, fromHex("c8faca5d242f4423325c5b1a")}; + sendFinished := script.NewEvent("send finished", []*script.Event{recvNCS}, script.Send{msgChan, finished}); + recvFinished := script.NewEvent("recv finished", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchFinished}); + setCipher := script.NewEvent("set cipher", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchSetCipher}); + recvConnectionState := script.NewEvent("recv state", []*script.Event{sendFinished}, script.Recv{controlChan, ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0}}); + + err := script.Perform(0, []*script.Event{sendHello, setVersion, recvHello, recvCert, recvDone, sendCKX, sendCCS, recvNCS, sendFinished, setCipher, recvConnectionState, recvFinished}); + if err != nil { + t.Errorf("Got error: %s", err); + } +} + +var testCertificate = fromHex("3082025930820203a003020102020900c2ec326b95228959300d06092a864886f70d01010505003054310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464310d300b0603550403130474657374301e170d3039313032303232323434355a170d3130313032303232323434355a3054310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464310d300b0603550403130474657374305c300d06092a864886f70d0101010500034b003048024100b2990f49c47dfa8cd400ae6a4d1b8a3b6a13642b23f28b003bfb97790ade9a4cc82b8b2a81747ddec08b6296e53a08c331687ef25c4bf4936ba1c0e6041e9d150203010001a381b73081b4301d0603551d0e0416041478a06086837c9293a8c9b70c0bdabdb9d77eeedf3081840603551d23047d307b801478a06086837c9293a8c9b70c0bdabdb9d77eeedfa158a4563054310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464310d300b0603550403130474657374820900c2ec326b95228959300c0603551d13040530030101ff300d06092a864886f70d0101050500034100ac23761ae1349d85a439caad4d0b932b09ea96de1917c3e0507c446f4838cb3076fb4d431db8c1987e96f1d7a8a2054dea3a64ec99a3f0eda4d47a163bf1f6ac") + +func bigFromString(s string) *big.Int { + ret := new(big.Int); + ret.SetString(s, 10); + return ret; +} + +var testPrivateKey = &rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + N: bigFromString("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"), + E: 65537, + }, + D: bigFromString("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"), + P: bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"), + Q: bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"), +} diff --git a/record_process.go b/record_process.go new file mode 100644 index 0000000..4c69319 --- /dev/null +++ b/record_process.go @@ -0,0 +1,292 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +// A recordProcessor accepts reassembled records, decrypts and verifies them +// and routes them either to the handshake processor, to up to the application. +// It also accepts requests from the application for the current connection +// state, or for a notification when the state changes. + +import ( + "bytes"; + "container/list"; + "crypto/subtle"; + "hash"; +) + +// getConnectionState is a request from the application to get the current +// ConnectionState. +type getConnectionState struct { + reply chan<- ConnectionState; +} + +// waitConnectionState is a request from the application to be notified when +// the connection state changes. +type waitConnectionState struct { + reply chan<- ConnectionState; +} + +// connectionStateChange is a message from the handshake processor that the +// connection state has changed. +type connectionStateChange struct { + connState ConnectionState; +} + +// changeCipherSpec is a message send to the handshake processor to signal that +// the peer is switching ciphers. +type changeCipherSpec struct{} + +// newCipherSpec is a message from the handshake processor that future +// records should be processed with a new cipher and MAC function. +type newCipherSpec struct { + encrypt encryptor; + mac hash.Hash; +} + +type recordProcessor struct { + decrypt encryptor; + mac hash.Hash; + seqNum uint64; + handshakeBuf []byte; + appDataChan chan<- []byte; + requestChan <-chan interface{}; + controlChan <-chan interface{}; + recordChan <-chan *record; + handshakeChan chan<- interface{}; + + // recordRead is nil when we don't wish to read any more. + recordRead <-chan *record; + // appDataSend is nil when len(appData) == 0. + appDataSend chan<- []byte; + // appData contains any application data queued for upstream. + appData []byte; + // A list of channels waiting for connState to change. + waitQueue *list.List; + connState ConnectionState; + shutdown bool; + header [13]byte; +} + +// drainRequestChannel processes messages from the request channel until it's closed. +func drainRequestChannel(requestChan <-chan interface{}, c ConnectionState) { + for v := range requestChan { + if closed(requestChan) { + return; + } + switch r := v.(type) { + case getConnectionState: + r.reply <- c; + case waitConnectionState: + r.reply <- c; + } + } +} + +func (p *recordProcessor) loop(appDataChan chan<- []byte, requestChan <-chan interface{}, controlChan <-chan interface{}, recordChan <-chan *record, handshakeChan chan<- interface{}) { + noop := nop{}; + p.decrypt = noop; + p.mac = noop; + p.waitQueue = list.New(); + + p.appDataChan = appDataChan; + p.requestChan = requestChan; + p.controlChan = controlChan; + p.recordChan = recordChan; + p.handshakeChan = handshakeChan; + p.recordRead = recordChan; + + for !p.shutdown { + select { + case p.appDataSend <- p.appData: + p.appData = nil; + p.appDataSend = nil; + p.recordRead = p.recordChan; + case c := <-controlChan: + p.processControlMsg(c); + case r := <-requestChan: + p.processRequestMsg(r); + case r := <-p.recordRead: + p.processRecord(r); + } + } + + p.wakeWaiters(); + go drainRequestChannel(p.requestChan, p.connState); + go func() { for _ = range controlChan {} }(); + + close(handshakeChan); + if len(p.appData) > 0 { + appDataChan <- p.appData; + } + close(appDataChan); +} + +func (p *recordProcessor) processRequestMsg(requestMsg interface{}) { + if closed(p.requestChan) { + p.shutdown = true; + return; + } + + switch r := requestMsg.(type) { + case getConnectionState: + r.reply <- p.connState; + case waitConnectionState: + if p.connState.HandshakeComplete { + r.reply <- p.connState; + } + p.waitQueue.PushBack(r.reply); + } +} + +func (p *recordProcessor) processControlMsg(msg interface{}) { + connState, ok := msg.(ConnectionState); + if !ok || closed(p.controlChan) { + p.shutdown = true; + return; + } + + p.connState = connState; + p.wakeWaiters(); +} + +func (p *recordProcessor) wakeWaiters() { + for i := p.waitQueue.Front(); i != nil; i = i.Next() { + i.Value.(chan<- ConnectionState) <- p.connState; + } + p.waitQueue.Init(); +} + +func (p *recordProcessor) processRecord(r *record) { + if closed(p.recordChan) { + p.shutdown = true; + return; + } + + p.decrypt.XORKeyStream(r.payload); + if len(r.payload) < p.mac.Size() { + p.error(alertBadRecordMAC); + return; + } + + fillMACHeader(&p.header, p.seqNum, len(r.payload) - p.mac.Size(), r); + p.seqNum++; + + p.mac.Reset(); + p.mac.Write(p.header[0:13]); + p.mac.Write(r.payload[0 : len(r.payload) - p.mac.Size()]); + macBytes := p.mac.Sum(); + + if subtle.ConstantTimeCompare(macBytes, r.payload[len(r.payload) - p.mac.Size() : len(r.payload)]) != 1 { + p.error(alertBadRecordMAC); + return; + } + + switch r.contentType { + case recordTypeHandshake: + p.processHandshakeRecord(r.payload[0 : len(r.payload) - p.mac.Size()]); + case recordTypeChangeCipherSpec: + if len(r.payload) != 1 || r.payload[0] != 1 { + p.error(alertUnexpectedMessage); + return; + } + + p.handshakeChan <- changeCipherSpec{}; + newSpec, ok := (<-p.controlChan).(*newCipherSpec); + if !ok { + p.connState.Error = alertUnexpectedMessage; + p.shutdown = true; + return; + } + p.decrypt = newSpec.encrypt; + p.mac = newSpec.mac; + p.seqNum = 0; + case recordTypeApplicationData: + if p.connState.HandshakeComplete == false { + p.error(alertUnexpectedMessage); + return; + } + p.recordRead = nil; + p.appData = r.payload; + p.appDataSend = p.appDataChan; + default: + p.error(alertUnexpectedMessage); + return; + } +} + +func (p *recordProcessor) processHandshakeRecord(data []byte) { + if p.handshakeBuf == nil { + p.handshakeBuf = data; + } else { + if len(p.handshakeBuf) > maxHandshakeMsg { + p.error(alertInternalError); + return; + } + newBuf := make([]byte, len(p.handshakeBuf)+len(data)); + bytes.Copy(newBuf, p.handshakeBuf); + bytes.Copy(newBuf[len(p.handshakeBuf):len(newBuf)], data); + p.handshakeBuf = newBuf; + } + + for len(p.handshakeBuf) >= 4 { + handshakeLen := int(p.handshakeBuf[1])<<16 | + int(p.handshakeBuf[2])<<8 | + int(p.handshakeBuf[3]); + if handshakeLen + 4 > len(p.handshakeBuf) { + break; + } + + bytes := p.handshakeBuf[0 : handshakeLen + 4]; + p.handshakeBuf = p.handshakeBuf[handshakeLen + 4 : len(p.handshakeBuf)]; + if bytes[0] == typeFinished { + // Special case because Finished is synchronous: the + // handshake handler has to tell us if it's ok to start + // forwarding application data. + m := new(finishedMsg); + if !m.unmarshal(bytes) { + p.error(alertUnexpectedMessage); + } + p.handshakeChan <- m; + var ok bool; + p.connState, ok = (<-p.controlChan).(ConnectionState); + if !ok || p.connState.Error != 0 { + p.shutdown = true; + return; + } + } else { + msg, ok := parseHandshakeMsg(bytes); + if !ok { + p.error(alertUnexpectedMessage); + return; + } + p.handshakeChan <- msg; + } + } +} + +func (p *recordProcessor) error(err alertType) { + close(p.handshakeChan); + p.connState.Error = err; + p.wakeWaiters(); + p.shutdown = true; +} + +func parseHandshakeMsg(data []byte) (interface{}, bool) { + var m interface { + unmarshal([]byte) bool; + } + + switch data[0] { + case typeClientHello: + m = new(clientHelloMsg); + case typeClientKeyExchange: + m = new(clientKeyExchangeMsg); + default: + return nil, false; + } + + ok := m.unmarshal(data); + return m, ok; +} diff --git a/record_process_test.go b/record_process_test.go new file mode 100644 index 0000000..30e2126 --- /dev/null +++ b/record_process_test.go @@ -0,0 +1,137 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "encoding/hex"; + "testing"; + "testing/script"; +) + +func setup() (appDataChan chan []byte, requestChan chan interface{}, controlChan chan interface{}, recordChan chan *record, handshakeChan chan interface{}) { + rp := new(recordProcessor); + appDataChan = make(chan []byte); + requestChan = make(chan interface{}); + controlChan = make(chan interface{}); + recordChan = make(chan *record); + handshakeChan = make(chan interface{}); + + go rp.loop(appDataChan, requestChan, controlChan, recordChan, handshakeChan); + return; +} + +func fromHex(s string) []byte { + b, _ := hex.DecodeString(s); + return b; +} + +func TestNullConnectionState(t *testing.T) { + _, requestChan, controlChan, recordChan, _ := setup(); + defer close(requestChan); + defer close(controlChan); + defer close(recordChan); + + // Test a simple request for the connection state. + replyChan := make(chan ConnectionState); + sendReq := script.NewEvent("send request", nil, script.Send{requestChan, getConnectionState{replyChan}}); + getReply := script.NewEvent("get reply", []*script.Event{sendReq}, script.Recv{replyChan, ConnectionState{false, "", 0}}); + + err := script.Perform(0, []*script.Event{sendReq, getReply}); + if err != nil { + t.Errorf("Got error: %s", err); + } +} + +func TestWaitConnectionState(t *testing.T) { + _, requestChan, controlChan, recordChan, _ := setup(); + defer close(requestChan); + defer close(controlChan); + defer close(recordChan); + + // Test that waitConnectionState doesn't get a reply until the connection state changes. + replyChan := make(chan ConnectionState); + sendReq := script.NewEvent("send request", nil, script.Send{requestChan, waitConnectionState{replyChan}}); + replyChan2 := make(chan ConnectionState); + sendReq2 := script.NewEvent("send request 2", []*script.Event{sendReq}, script.Send{requestChan, getConnectionState{replyChan2}}); + getReply2 := script.NewEvent("get reply 2", []*script.Event{sendReq2}, script.Recv{replyChan2, ConnectionState{false, "", 0}}); + sendState := script.NewEvent("send state", []*script.Event{getReply2}, script.Send{controlChan, ConnectionState{true, "test", 1}}); + getReply := script.NewEvent("get reply", []*script.Event{sendState}, script.Recv{replyChan, ConnectionState{true, "test", 1}}); + + err := script.Perform(0, []*script.Event{sendReq, sendReq2, getReply2, sendState, getReply}); + if err != nil { + t.Errorf("Got error: %s", err); + } +} + +func TestHandshakeAssembly(t *testing.T) { + _, requestChan, controlChan, recordChan, handshakeChan := setup(); + defer close(requestChan); + defer close(controlChan); + defer close(recordChan); + + // Test the reassembly of a fragmented handshake message. + send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("10000003")}}); + send2 := script.NewEvent("send 2", []*script.Event{send1}, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("0001")}}); + send3 := script.NewEvent("send 3", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("42")}}); + recvMsg := script.NewEvent("recv", []*script.Event{send3}, script.Recv{handshakeChan, &clientKeyExchangeMsg{fromHex("10000003000142"), fromHex("42")}}); + + err := script.Perform(0, []*script.Event{send1, send2, send3, recvMsg}); + if err != nil { + t.Errorf("Got error: %s", err); + } +} + +func TestEarlyApplicationData(t *testing.T) { + _, requestChan, controlChan, recordChan, handshakeChan := setup(); + defer close(requestChan); + defer close(controlChan); + defer close(recordChan); + + // Test that applicaton data received before the handshake has completed results in an error. + send := script.NewEvent("send", nil, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("")}}); + recv := script.NewEvent("recv", []*script.Event{send}, script.Closed{handshakeChan}); + + err := script.Perform(0, []*script.Event{send, recv}); + if err != nil { + t.Errorf("Got error: %s", err); + } +} + +func TestApplicationData(t *testing.T) { + appDataChan, requestChan, controlChan, recordChan, handshakeChan := setup(); + defer close(requestChan); + defer close(controlChan); + defer close(recordChan); + + // Test that the application data is forwarded after a successful Finished message. + send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("1400000c000000000000000000000000")}}); + recv1 := script.NewEvent("recv finished", []*script.Event{send1}, script.Recv{handshakeChan, &finishedMsg{fromHex("1400000c000000000000000000000000"), fromHex("000000000000000000000000")}}); + send2 := script.NewEvent("send connState", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{true, "", 0}}); + send3 := script.NewEvent("send 2", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("0102")}}); + recv2 := script.NewEvent("recv data", []*script.Event{send3}, script.Recv{appDataChan, []byte{0x01, 0x02}}); + + err := script.Perform(0, []*script.Event{send1, recv1, send2, send3, recv2}); + if err != nil { + t.Errorf("Got error: %s", err); + } +} + +func TestInvalidChangeCipherSpec(t *testing.T) { + appDataChan, requestChan, controlChan, recordChan, handshakeChan := setup(); + defer close(requestChan); + defer close(controlChan); + defer close(recordChan); + + send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeChangeCipherSpec, 0, 0, []byte{1}}}); + recv1 := script.NewEvent("recv 1", []*script.Event{send1}, script.Recv{handshakeChan, changeCipherSpec{}}); + send2 := script.NewEvent("send 2", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{false, "", 42}}); + close := script.NewEvent("close 1", []*script.Event{send2}, script.Closed{appDataChan}); + close2 := script.NewEvent("close 2", []*script.Event{send2}, script.Closed{handshakeChan}); + + err := script.Perform(0, []*script.Event{send1, recv1, send2, close, close2}); + if err != nil { + t.Errorf("Got error: %s", err); + } +}