diff --git a/alert.go b/alert.go index 4cf62e7..2f740b3 100644 --- a/alert.go +++ b/alert.go @@ -8,36 +8,36 @@ type alertLevel int type alertType int const ( - alertLevelWarning alertLevel = 1; - alertLevelError alertLevel = 2; + alertLevelWarning alertLevel = 1 + alertLevelError alertLevel = 2 ) const ( - alertCloseNotify alertType = 0; - alertUnexpectedMessage alertType = 10; - alertBadRecordMAC alertType = 20; - alertDecryptionFailed alertType = 21; - alertRecordOverflow alertType = 22; - alertDecompressionFailure alertType = 30; - alertHandshakeFailure alertType = 40; - alertBadCertificate alertType = 42; - alertUnsupportedCertificate alertType = 43; - alertCertificateRevoked alertType = 44; - alertCertificateExpired alertType = 45; - alertCertificateUnknown alertType = 46; - alertIllegalParameter alertType = 47; - alertUnknownCA alertType = 48; - alertAccessDenied alertType = 49; - alertDecodeError alertType = 50; - alertDecryptError alertType = 51; - alertProtocolVersion alertType = 70; - alertInsufficientSecurity alertType = 71; - alertInternalError alertType = 80; - alertUserCanceled alertType = 90; - alertNoRenegotiation alertType = 100; + alertCloseNotify alertType = 0 + alertUnexpectedMessage alertType = 10 + alertBadRecordMAC alertType = 20 + alertDecryptionFailed alertType = 21 + alertRecordOverflow alertType = 22 + alertDecompressionFailure alertType = 30 + alertHandshakeFailure alertType = 40 + alertBadCertificate alertType = 42 + alertUnsupportedCertificate alertType = 43 + alertCertificateRevoked alertType = 44 + alertCertificateExpired alertType = 45 + alertCertificateUnknown alertType = 46 + alertIllegalParameter alertType = 47 + alertUnknownCA alertType = 48 + alertAccessDenied alertType = 49 + alertDecodeError alertType = 50 + alertDecryptError alertType = 51 + alertProtocolVersion alertType = 70 + alertInsufficientSecurity alertType = 71 + alertInternalError alertType = 80 + alertUserCanceled alertType = 90 + alertNoRenegotiation alertType = 100 ) type alert struct { - level alertLevel; - error alertType; + level alertLevel + error alertType } diff --git a/ca_set.go b/ca_set.go index e8cddd6..00f6a87 100644 --- a/ca_set.go +++ b/ca_set.go @@ -5,14 +5,14 @@ package tls import ( - "crypto/x509"; - "encoding/pem"; + "crypto/x509" + "encoding/pem" ) // A CASet is a set of certificates. type CASet struct { - bySubjectKeyId map[string]*x509.Certificate; - byName map[string]*x509.Certificate; + bySubjectKeyId map[string]*x509.Certificate + byName map[string]*x509.Certificate } func NewCASet() *CASet { @@ -29,7 +29,7 @@ func nameToKey(name *x509.Name) string { // FindParent attempts to find the certificate in s which signs the given // certificate. If no such certificate can be found, it returns nil. func (s *CASet) FindParent(cert *x509.Certificate) (parent *x509.Certificate) { - var ok bool; + var ok bool if len(cert.AuthorityKeyId) > 0 { parent, ok = s.bySubjectKeyId[string(cert.AuthorityKeyId)] @@ -40,7 +40,7 @@ func (s *CASet) FindParent(cert *x509.Certificate) (parent *x509.Certificate) { if !ok { return nil } - return parent; + return parent } // SetFromPEM attempts to parse a series of PEM encoded root certificates. It @@ -50,8 +50,8 @@ func (s *CASet) FindParent(cert *x509.Certificate) (parent *x509.Certificate) { // function. func (s *CASet) SetFromPEM(pemCerts []byte) (ok bool) { for len(pemCerts) > 0 { - var block *pem.Block; - block, pemCerts = pem.Decode(pemCerts); + var block *pem.Block + block, pemCerts = pem.Decode(pemCerts) if block == nil { break } @@ -59,7 +59,7 @@ func (s *CASet) SetFromPEM(pemCerts []byte) (ok bool) { continue } - cert, err := x509.ParseCertificate(block.Bytes); + cert, err := x509.ParseCertificate(block.Bytes) if err != nil { continue } @@ -67,9 +67,9 @@ func (s *CASet) SetFromPEM(pemCerts []byte) (ok bool) { if len(cert.SubjectKeyId) > 0 { s.bySubjectKeyId[string(cert.SubjectKeyId)] = cert } - s.byName[nameToKey(&cert.Subject)] = cert; - ok = true; + s.byName[nameToKey(&cert.Subject)] = cert + ok = true } - return; + return } diff --git a/common.go b/common.go index e1318a8..51de533 100644 --- a/common.go +++ b/common.go @@ -5,21 +5,21 @@ package tls import ( - "crypto/rsa"; - "io"; - "os"; + "crypto/rsa" + "io" + "os" ) const ( // maxTLSCiphertext is the maximum length of a plaintext payload. - maxTLSPlaintext = 16384; + maxTLSPlaintext = 16384 // maxTLSCiphertext is the maximum length payload after compression and encryption. - maxTLSCiphertext = 16384 + 2048; + maxTLSCiphertext = 16384 + 2048 // maxHandshakeMsg is the largest single handshake message that we'll buffer. - maxHandshakeMsg = 65536; + maxHandshakeMsg = 65536 // defaultMajor and defaultMinor are the maximum TLS version that we support. - defaultMajor = 3; - defaultMinor = 2; + defaultMajor = 3 + defaultMinor = 2 ) @@ -27,68 +27,68 @@ const ( type recordType uint8 const ( - recordTypeChangeCipherSpec recordType = 20; - recordTypeAlert recordType = 21; - recordTypeHandshake recordType = 22; - recordTypeApplicationData recordType = 23; + recordTypeChangeCipherSpec recordType = 20 + recordTypeAlert recordType = 21 + recordTypeHandshake recordType = 22 + recordTypeApplicationData recordType = 23 ) // TLS handshake message types. const ( - typeClientHello uint8 = 1; - typeServerHello uint8 = 2; - typeCertificate uint8 = 11; - typeServerHelloDone uint8 = 14; - typeClientKeyExchange uint8 = 16; - typeFinished uint8 = 20; + typeClientHello uint8 = 1 + typeServerHello uint8 = 2 + typeCertificate uint8 = 11 + typeServerHelloDone uint8 = 14 + typeClientKeyExchange uint8 = 16 + typeFinished uint8 = 20 ) // TLS cipher suites. var ( - TLS_RSA_WITH_RC4_128_SHA uint16 = 5; + TLS_RSA_WITH_RC4_128_SHA uint16 = 5 ) // TLS compression types. var ( - compressionNone uint8 = 0; + compressionNone uint8 = 0 ) type ConnectionState struct { - HandshakeComplete bool; - CipherSuite string; - Error alertType; + HandshakeComplete bool + CipherSuite string + Error alertType } // A Config structure is used to configure a TLS client or server. After one // has been passed to a TLS function it must not be modified. type Config struct { // Rand provides the source of entropy for nonces and RSA blinding. - Rand io.Reader; + Rand io.Reader // Time returns the current time as the number of seconds since the epoch. - Time func() int64; - Certificates []Certificate; - RootCAs *CASet; + Time func() int64 + Certificates []Certificate + RootCAs *CASet } type Certificate struct { - Certificate [][]byte; - PrivateKey *rsa.PrivateKey; + Certificate [][]byte + PrivateKey *rsa.PrivateKey } // A TLS record. type record struct { - contentType recordType; - major, minor uint8; - payload []byte; + contentType recordType + major, minor uint8 + payload []byte } type handshakeMessage interface { - marshal() []byte; + marshal() []byte } type encryptor interface { // XORKeyStream xors the contents of the slice with bytes from the key stream. - XORKeyStream(buf []byte); + XORKeyStream(buf []byte) } // mutualVersion returns the protocol version to use given the advertised @@ -98,24 +98,24 @@ func mutualVersion(theirMajor, theirMinor uint8) (major, minor uint8, ok bool) { if theirMajor < 3 || theirMajor == 3 && theirMinor < 1 { return 0, 0, false } - major = 3; - minor = 2; + major = 3 + minor = 2 if theirMinor < minor { minor = theirMinor } - ok = true; - return; + ok = true + return } // A nop implements the NULL encryption and MAC algorithms. type nop struct{} -func (nop) XORKeyStream(buf []byte) {} +func (nop) XORKeyStream(buf []byte) {} -func (nop) Write(buf []byte) (int, os.Error) { return len(buf), nil } +func (nop) Write(buf []byte) (int, os.Error) { return len(buf), nil } -func (nop) Sum() []byte { return nil } +func (nop) Sum() []byte { return nil } -func (nop) Reset() {} +func (nop) Reset() {} -func (nop) Size() int { return 0 } +func (nop) Size() int { return 0 } diff --git a/handshake_client.go b/handshake_client.go index 1c6bd4b..4e31e70 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -5,33 +5,33 @@ package tls import ( - "crypto/hmac"; - "crypto/rc4"; - "crypto/rsa"; - "crypto/sha1"; - "crypto/subtle"; - "crypto/x509"; - "io"; + "crypto/hmac" + "crypto/rc4" + "crypto/rsa" + "crypto/sha1" + "crypto/subtle" + "crypto/x509" + "io" ) // A serverHandshake performs the server side of the TLS 1.1 handshake protocol. type clientHandshake struct { - writeChan chan<- interface{}; - controlChan chan<- interface{}; - msgChan <-chan interface{}; - config *Config; + writeChan chan<- interface{} + controlChan chan<- interface{} + msgChan <-chan interface{} + config *Config } func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) { - h.writeChan = writeChan; - h.controlChan = controlChan; - h.msgChan = msgChan; - h.config = config; + h.writeChan = writeChan + h.controlChan = controlChan + h.msgChan = msgChan + h.config = config - defer close(writeChan); - defer close(controlChan); + defer close(writeChan) + defer close(controlChan) - finishedHash := newFinishedHash(); + finishedHash := newFinishedHash() hello := &clientHelloMsg{ major: defaultMajor, @@ -39,175 +39,175 @@ func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<- cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, compressionMethods: []uint8{compressionNone}, random: make([]byte, 32), - }; - - currentTime := uint32(config.Time()); - hello.random[0] = byte(currentTime >> 24); - hello.random[1] = byte(currentTime >> 16); - hello.random[2] = byte(currentTime >> 8); - hello.random[3] = byte(currentTime); - _, err := io.ReadFull(config.Rand, hello.random[4:]); + } + + currentTime := uint32(config.Time()) + hello.random[0] = byte(currentTime >> 24) + hello.random[1] = byte(currentTime >> 16) + hello.random[2] = byte(currentTime >> 8) + hello.random[3] = byte(currentTime) + _, err := io.ReadFull(config.Rand, hello.random[4:]) if err != nil { - h.error(alertInternalError); - return; + h.error(alertInternalError) + return } - finishedHash.Write(hello.marshal()); - writeChan <- writerSetVersion{defaultMajor, defaultMinor}; - writeChan <- hello; + finishedHash.Write(hello.marshal()) + writeChan <- writerSetVersion{defaultMajor, defaultMinor} + writeChan <- hello - serverHello, ok := h.readHandshakeMsg().(*serverHelloMsg); + serverHello, ok := h.readHandshakeMsg().(*serverHelloMsg) if !ok { - h.error(alertUnexpectedMessage); - return; + h.error(alertUnexpectedMessage) + return } - finishedHash.Write(serverHello.marshal()); - major, minor, ok := mutualVersion(serverHello.major, serverHello.minor); + finishedHash.Write(serverHello.marshal()) + major, minor, ok := mutualVersion(serverHello.major, serverHello.minor) if !ok { - h.error(alertProtocolVersion); - return; + h.error(alertProtocolVersion) + return } - writeChan <- writerSetVersion{major, minor}; + writeChan <- writerSetVersion{major, minor} if serverHello.cipherSuite != TLS_RSA_WITH_RC4_128_SHA || serverHello.compressionMethod != compressionNone { - h.error(alertUnexpectedMessage); - return; + h.error(alertUnexpectedMessage) + return } - certMsg, ok := h.readHandshakeMsg().(*certificateMsg); + certMsg, ok := h.readHandshakeMsg().(*certificateMsg) if !ok || len(certMsg.certificates) == 0 { - h.error(alertUnexpectedMessage); - return; + h.error(alertUnexpectedMessage) + return } - finishedHash.Write(certMsg.marshal()); + finishedHash.Write(certMsg.marshal()) - certs := make([]*x509.Certificate, len(certMsg.certificates)); + certs := make([]*x509.Certificate, len(certMsg.certificates)) for i, asn1Data := range certMsg.certificates { - cert, err := x509.ParseCertificate(asn1Data); + cert, err := x509.ParseCertificate(asn1Data) if err != nil { - h.error(alertBadCertificate); - return; + h.error(alertBadCertificate) + return } - certs[i] = cert; + certs[i] = cert } // TODO(agl): do better validation of certs: max path length, name restrictions etc. for i := 1; i < len(certs); i++ { if certs[i-1].CheckSignatureFrom(certs[i]) != nil { - h.error(alertBadCertificate); - return; + h.error(alertBadCertificate) + return } } if config.RootCAs != nil { - root := config.RootCAs.FindParent(certs[len(certs)-1]); + root := config.RootCAs.FindParent(certs[len(certs)-1]) if root == nil { - h.error(alertBadCertificate); - return; + h.error(alertBadCertificate) + return } if certs[len(certs)-1].CheckSignatureFrom(root) != nil { - h.error(alertBadCertificate); - return; + h.error(alertBadCertificate) + return } } - pub, ok := certs[0].PublicKey.(*rsa.PublicKey); + pub, ok := certs[0].PublicKey.(*rsa.PublicKey) if !ok { - h.error(alertUnsupportedCertificate); - return; + h.error(alertUnsupportedCertificate) + return } - shd, ok := h.readHandshakeMsg().(*serverHelloDoneMsg); + shd, ok := h.readHandshakeMsg().(*serverHelloDoneMsg) if !ok { - h.error(alertUnexpectedMessage); - return; + h.error(alertUnexpectedMessage) + return } - finishedHash.Write(shd.marshal()); + finishedHash.Write(shd.marshal()) - ckx := new(clientKeyExchangeMsg); - preMasterSecret := make([]byte, 48); + ckx := new(clientKeyExchangeMsg) + preMasterSecret := make([]byte, 48) // Note that the version number in the preMasterSecret must be the // version offered in the ClientHello. - preMasterSecret[0] = defaultMajor; - preMasterSecret[1] = defaultMinor; - _, err = io.ReadFull(config.Rand, preMasterSecret[2:]); + preMasterSecret[0] = defaultMajor + preMasterSecret[1] = defaultMinor + _, err = io.ReadFull(config.Rand, preMasterSecret[2:]) if err != nil { - h.error(alertInternalError); - return; + h.error(alertInternalError) + return } - ckx.ciphertext, err = rsa.EncryptPKCS1v15(config.Rand, pub, preMasterSecret); + ckx.ciphertext, err = rsa.EncryptPKCS1v15(config.Rand, pub, preMasterSecret) if err != nil { - h.error(alertInternalError); - return; + h.error(alertInternalError) + return } - finishedHash.Write(ckx.marshal()); - writeChan <- ckx; + finishedHash.Write(ckx.marshal()) + writeChan <- ckx - suite := cipherSuites[0]; + suite := cipherSuites[0] masterSecret, clientMAC, serverMAC, clientKey, serverKey := - keysFromPreMasterSecret11(preMasterSecret, hello.random, serverHello.random, suite.hashLength, suite.cipherKeyLength); + keysFromPreMasterSecret11(preMasterSecret, hello.random, serverHello.random, suite.hashLength, suite.cipherKeyLength) - cipher, _ := rc4.NewCipher(clientKey); - writeChan <- writerChangeCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)}; + cipher, _ := rc4.NewCipher(clientKey) + writeChan <- writerChangeCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)} - finished := new(finishedMsg); - finished.verifyData = finishedHash.clientSum(masterSecret); - finishedHash.Write(finished.marshal()); - writeChan <- finished; + finished := new(finishedMsg) + finished.verifyData = finishedHash.clientSum(masterSecret) + finishedHash.Write(finished.marshal()) + writeChan <- finished // TODO(agl): this is cut-through mode which should probably be an option. - writeChan <- writerEnableApplicationData{}; + writeChan <- writerEnableApplicationData{} - _, ok = h.readHandshakeMsg().(changeCipherSpec); + _, ok = h.readHandshakeMsg().(changeCipherSpec) if !ok { - h.error(alertUnexpectedMessage); - return; + h.error(alertUnexpectedMessage) + return } - cipher2, _ := rc4.NewCipher(serverKey); - controlChan <- &newCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)}; + cipher2, _ := rc4.NewCipher(serverKey) + controlChan <- &newCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)} - serverFinished, ok := h.readHandshakeMsg().(*finishedMsg); + serverFinished, ok := h.readHandshakeMsg().(*finishedMsg) if !ok { - h.error(alertUnexpectedMessage); - return; + h.error(alertUnexpectedMessage) + return } - verify := finishedHash.serverSum(masterSecret); + verify := finishedHash.serverSum(masterSecret) if len(verify) != len(serverFinished.verifyData) || subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { - h.error(alertHandshakeFailure); - return; + h.error(alertHandshakeFailure) + return } - controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0}; + controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0} // This should just block forever. - _ = h.readHandshakeMsg(); - h.error(alertUnexpectedMessage); - return; + _ = h.readHandshakeMsg() + h.error(alertUnexpectedMessage) + return } func (h *clientHandshake) readHandshakeMsg() interface{} { - v := <-h.msgChan; + v := <-h.msgChan if closed(h.msgChan) { // If the channel closed then the processor received an error // from the peer and we don't want to echo it back to them. - h.msgChan = nil; - return 0; + h.msgChan = nil + return 0 } if _, ok := v.(alert); ok { // We got an alert from the processor. We forward to the writer // and shutdown. - h.writeChan <- v; - h.msgChan = nil; - return 0; + h.writeChan <- v + h.msgChan = nil + return 0 } - return v; + return v } func (h *clientHandshake) error(e alertType) { @@ -217,9 +217,9 @@ func (h *clientHandshake) error(e alertType) { go func() { for _ = range h.msgChan { } - }(); - h.controlChan <- ConnectionState{false, "", e}; - close(h.controlChan); - h.writeChan <- alert{alertLevelError, e}; + }() + h.controlChan <- ConnectionState{false, "", e} + close(h.controlChan) + h.writeChan <- alert{alertLevelError, e} } } diff --git a/handshake_messages.go b/handshake_messages.go index b5f2aa7..2870969 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -5,12 +5,12 @@ package tls type clientHelloMsg struct { - raw []byte; - major, minor uint8; - random []byte; - sessionId []byte; - cipherSuites []uint16; - compressionMethods []uint8; + raw []byte + major, minor uint8 + random []byte + sessionId []byte + cipherSuites []uint16 + compressionMethods []uint8 } func (m *clientHelloMsg) marshal() []byte { @@ -18,81 +18,81 @@ func (m *clientHelloMsg) marshal() []byte { return m.raw } - length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods); - x := make([]byte, 4+length); - x[0] = typeClientHello; - x[1] = uint8(length >> 16); - x[2] = uint8(length >> 8); - x[3] = uint8(length); - x[4] = m.major; - x[5] = m.minor; - copy(x[6:38], m.random); - x[38] = uint8(len(m.sessionId)); - copy(x[39:39+len(m.sessionId)], m.sessionId); - y := x[39+len(m.sessionId):]; - y[0] = uint8(len(m.cipherSuites) >> 7); - y[1] = uint8(len(m.cipherSuites) << 1); + length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods) + x := make([]byte, 4+length) + x[0] = typeClientHello + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[4] = m.major + x[5] = m.minor + copy(x[6:38], m.random) + x[38] = uint8(len(m.sessionId)) + copy(x[39:39+len(m.sessionId)], m.sessionId) + y := x[39+len(m.sessionId):] + y[0] = uint8(len(m.cipherSuites) >> 7) + y[1] = uint8(len(m.cipherSuites) << 1) for i, suite := range m.cipherSuites { - y[2+i*2] = uint8(suite >> 8); - y[3+i*2] = uint8(suite); + y[2+i*2] = uint8(suite >> 8) + y[3+i*2] = uint8(suite) } - z := y[2+len(m.cipherSuites)*2:]; - z[0] = uint8(len(m.compressionMethods)); - copy(z[1:], m.compressionMethods); - m.raw = x; + z := y[2+len(m.cipherSuites)*2:] + z[0] = uint8(len(m.compressionMethods)) + copy(z[1:], m.compressionMethods) + m.raw = x - return x; + return x } func (m *clientHelloMsg) unmarshal(data []byte) bool { if len(data) < 43 { return false } - m.raw = data; - m.major = data[4]; - m.minor = data[5]; - m.random = data[6:38]; - sessionIdLen := int(data[38]); + m.raw = data + m.major = data[4] + m.minor = data[5] + m.random = data[6:38] + sessionIdLen := int(data[38]) if sessionIdLen > 32 || len(data) < 39+sessionIdLen { return false } - m.sessionId = data[39 : 39+sessionIdLen]; - data = data[39+sessionIdLen:]; + m.sessionId = data[39 : 39+sessionIdLen] + data = data[39+sessionIdLen:] if len(data) < 2 { return false } // cipherSuiteLen is the number of bytes of cipher suite numbers. Since // they are uint16s, the number must be even. - cipherSuiteLen := int(data[0])<<8 | int(data[1]); + cipherSuiteLen := int(data[0])<<8 | int(data[1]) if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { return false } - numCipherSuites := cipherSuiteLen / 2; - m.cipherSuites = make([]uint16, numCipherSuites); + numCipherSuites := cipherSuiteLen / 2 + m.cipherSuites = make([]uint16, numCipherSuites) for i := 0; i < numCipherSuites; i++ { m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i]) } - data = data[2+cipherSuiteLen:]; + data = data[2+cipherSuiteLen:] if len(data) < 2 { return false } - compressionMethodsLen := int(data[0]); + compressionMethodsLen := int(data[0]) if len(data) < 1+compressionMethodsLen { return false } - m.compressionMethods = data[1 : 1+compressionMethodsLen]; + m.compressionMethods = data[1 : 1+compressionMethodsLen] // A ClientHello may be following by trailing data: RFC 4346 section 7.4.1.2 - return true; + return true } type serverHelloMsg struct { - raw []byte; - major, minor uint8; - random []byte; - sessionId []byte; - cipherSuite uint16; - compressionMethod uint8; + raw []byte + major, minor uint8 + random []byte + sessionId []byte + cipherSuite uint16 + compressionMethod uint8 } func (m *serverHelloMsg) marshal() []byte { @@ -100,53 +100,53 @@ func (m *serverHelloMsg) marshal() []byte { return m.raw } - length := 38 + len(m.sessionId); - x := make([]byte, 4+length); - x[0] = typeServerHello; - x[1] = uint8(length >> 16); - x[2] = uint8(length >> 8); - x[3] = uint8(length); - x[4] = m.major; - x[5] = m.minor; - copy(x[6:38], m.random); - x[38] = uint8(len(m.sessionId)); - copy(x[39:39+len(m.sessionId)], m.sessionId); - z := x[39+len(m.sessionId):]; - z[0] = uint8(m.cipherSuite >> 8); - z[1] = uint8(m.cipherSuite); - z[2] = uint8(m.compressionMethod); - m.raw = x; - - return x; + length := 38 + len(m.sessionId) + x := make([]byte, 4+length) + x[0] = typeServerHello + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[4] = m.major + x[5] = m.minor + copy(x[6:38], m.random) + x[38] = uint8(len(m.sessionId)) + copy(x[39:39+len(m.sessionId)], m.sessionId) + z := x[39+len(m.sessionId):] + z[0] = uint8(m.cipherSuite >> 8) + z[1] = uint8(m.cipherSuite) + z[2] = uint8(m.compressionMethod) + m.raw = x + + return x } func (m *serverHelloMsg) unmarshal(data []byte) bool { if len(data) < 42 { return false } - m.raw = data; - m.major = data[4]; - m.minor = data[5]; - m.random = data[6:38]; - sessionIdLen := int(data[38]); + m.raw = data + m.major = data[4] + m.minor = data[5] + m.random = data[6:38] + sessionIdLen := int(data[38]) if sessionIdLen > 32 || len(data) < 39+sessionIdLen { return false } - m.sessionId = data[39 : 39+sessionIdLen]; - data = data[39+sessionIdLen:]; + m.sessionId = data[39 : 39+sessionIdLen] + data = data[39+sessionIdLen:] if len(data) < 3 { return false } - m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]); - m.compressionMethod = data[2]; + m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]) + m.compressionMethod = data[2] // Trailing data is allowed because extensions may be present. - return true; + return true } type certificateMsg struct { - raw []byte; - certificates [][]byte; + raw []byte + certificates [][]byte } func (m *certificateMsg) marshal() (x []byte) { @@ -154,34 +154,34 @@ func (m *certificateMsg) marshal() (x []byte) { return m.raw } - var i int; + var i int for _, slice := range m.certificates { i += len(slice) } - length := 3 + 3*len(m.certificates) + i; - x = make([]byte, 4+length); - x[0] = typeCertificate; - x[1] = uint8(length >> 16); - x[2] = uint8(length >> 8); - x[3] = uint8(length); + length := 3 + 3*len(m.certificates) + i + x = make([]byte, 4+length) + x[0] = typeCertificate + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) - certificateOctets := length - 3; - x[4] = uint8(certificateOctets >> 16); - x[5] = uint8(certificateOctets >> 8); - x[6] = uint8(certificateOctets); + certificateOctets := length - 3 + x[4] = uint8(certificateOctets >> 16) + x[5] = uint8(certificateOctets >> 8) + x[6] = uint8(certificateOctets) - y := x[7:]; + y := x[7:] for _, slice := range m.certificates { - y[0] = uint8(len(slice) >> 16); - y[1] = uint8(len(slice) >> 8); - y[2] = uint8(len(slice)); - copy(y[3:], slice); - y = y[3+len(slice):]; + y[0] = uint8(len(slice) >> 16) + y[1] = uint8(len(slice) >> 8) + y[2] = uint8(len(slice)) + copy(y[3:], slice) + y = y[3+len(slice):] } - m.raw = x; - return; + m.raw = x + return } func (m *certificateMsg) unmarshal(data []byte) bool { @@ -189,44 +189,44 @@ func (m *certificateMsg) unmarshal(data []byte) bool { return false } - m.raw = data; - certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]); + m.raw = data + certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) if uint32(len(data)) != certsLen+7 { return false } - numCerts := 0; - d := data[7:]; + numCerts := 0 + d := data[7:] for certsLen > 0 { if len(d) < 4 { return false } - certLen := uint32(d[0])<<24 | uint32(d[1])<<8 | uint32(d[2]); + certLen := uint32(d[0])<<24 | uint32(d[1])<<8 | uint32(d[2]) if uint32(len(d)) < 3+certLen { return false } - d = d[3+certLen:]; - certsLen -= 3 + certLen; - numCerts++; + d = d[3+certLen:] + certsLen -= 3 + certLen + numCerts++ } - m.certificates = make([][]byte, numCerts); - d = data[7:]; + m.certificates = make([][]byte, numCerts) + d = data[7:] for i := 0; i < numCerts; i++ { - certLen := uint32(d[0])<<24 | uint32(d[1])<<8 | uint32(d[2]); - m.certificates[i] = d[3 : 3+certLen]; - d = d[3+certLen:]; + certLen := uint32(d[0])<<24 | uint32(d[1])<<8 | uint32(d[2]) + m.certificates[i] = d[3 : 3+certLen] + d = d[3+certLen:] } - return true; + return true } type serverHelloDoneMsg struct{} func (m *serverHelloDoneMsg) marshal() []byte { - x := make([]byte, 4); - x[0] = typeServerHelloDone; - return x; + x := make([]byte, 4) + x[0] = typeServerHelloDone + return x } func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { @@ -234,44 +234,44 @@ func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { } type clientKeyExchangeMsg struct { - raw []byte; - ciphertext []byte; + raw []byte + ciphertext []byte } func (m *clientKeyExchangeMsg) marshal() []byte { if m.raw != nil { return m.raw } - length := len(m.ciphertext) + 2; - x := make([]byte, length+4); - x[0] = typeClientKeyExchange; - x[1] = uint8(length >> 16); - x[2] = uint8(length >> 8); - x[3] = uint8(length); - x[4] = uint8(len(m.ciphertext) >> 8); - x[5] = uint8(len(m.ciphertext)); - copy(x[6:], m.ciphertext); - - m.raw = x; - return x; + length := len(m.ciphertext) + 2 + x := make([]byte, length+4) + x[0] = typeClientKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[4] = uint8(len(m.ciphertext) >> 8) + x[5] = uint8(len(m.ciphertext)) + copy(x[6:], m.ciphertext) + + m.raw = x + return x } func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { - m.raw = data; + m.raw = data if len(data) < 7 { return false } - cipherTextLen := int(data[4])<<8 | int(data[5]); + cipherTextLen := int(data[4])<<8 | int(data[5]) if len(data) != 6+cipherTextLen { return false } - m.ciphertext = data[6:]; - return true; + m.ciphertext = data[6:] + return true } type finishedMsg struct { - raw []byte; - verifyData []byte; + raw []byte + verifyData []byte } func (m *finishedMsg) marshal() (x []byte) { @@ -279,19 +279,19 @@ func (m *finishedMsg) marshal() (x []byte) { return m.raw } - x = make([]byte, 16); - x[0] = typeFinished; - x[3] = 12; - copy(x[4:], m.verifyData); - m.raw = x; - return; + x = make([]byte, 16) + x[0] = typeFinished + x[3] = 12 + copy(x[4:], m.verifyData) + m.raw = x + return } func (m *finishedMsg) unmarshal(data []byte) bool { - m.raw = data; + m.raw = data if len(data) != 4+12 { return false } - m.verifyData = data[4:]; - return true; + m.verifyData = data[4:] + return true } diff --git a/handshake_messages_test.go b/handshake_messages_test.go index c580f65..4bfdd6c 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -5,10 +5,10 @@ package tls import ( - "rand"; - "reflect"; - "testing"; - "testing/quick"; + "rand" + "reflect" + "testing" + "testing/quick" ) var tests = []interface{}{ @@ -20,41 +20,41 @@ var tests = []interface{}{ } type testMessage interface { - marshal() []byte; - unmarshal([]byte) bool; + marshal() []byte + unmarshal([]byte) bool } func TestMarshalUnmarshal(t *testing.T) { - rand := rand.New(rand.NewSource(0)); + rand := rand.New(rand.NewSource(0)) for i, iface := range tests { - ty := reflect.NewValue(iface).Type(); + ty := reflect.NewValue(iface).Type() for j := 0; j < 100; j++ { - v, ok := quick.Value(ty, rand); + v, ok := quick.Value(ty, rand) if !ok { - t.Errorf("#%d: failed to create value", i); - break; + t.Errorf("#%d: failed to create value", i) + break } - m1 := v.Interface().(testMessage); - marshaled := m1.marshal(); - m2 := iface.(testMessage); + m1 := v.Interface().(testMessage) + marshaled := m1.marshal() + m2 := iface.(testMessage) if !m2.unmarshal(marshaled) { - t.Errorf("#%d failed to unmarshal %#v", i, m1); - break; + t.Errorf("#%d failed to unmarshal %#v", i, m1) + break } - m2.marshal(); // to fill any marshal cache in the message + m2.marshal() // to fill any marshal cache in the message if !reflect.DeepEqual(m1, m2) { - t.Errorf("#%d got:%#v want:%#v", i, m1, m2); - break; + t.Errorf("#%d got:%#v want:%#v", i, m1, m2) + break } // Now check that all prefixes are invalid. for j := 0; j < len(marshaled); j++ { if m2.unmarshal(marshaled[0:j]) { - t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1); - break; + t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) + break } } } @@ -62,71 +62,71 @@ func TestMarshalUnmarshal(t *testing.T) { } func TestFuzz(t *testing.T) { - rand := rand.New(rand.NewSource(0)); + rand := rand.New(rand.NewSource(0)) for _, iface := range tests { - m := iface.(testMessage); + m := iface.(testMessage) for j := 0; j < 1000; j++ { - len := rand.Intn(100); - bytes := randomBytes(len, rand); + len := rand.Intn(100) + bytes := randomBytes(len, rand) // This just looks for crashes due to bounds errors etc. - m.unmarshal(bytes); + m.unmarshal(bytes) } } } func randomBytes(n int, rand *rand.Rand) []byte { - r := make([]byte, n); + r := make([]byte, n) for i := 0; i < n; i++ { r[i] = byte(rand.Int31()) } - return r; + return r } func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &clientHelloMsg{}; - m.major = uint8(rand.Intn(256)); - m.minor = uint8(rand.Intn(256)); - m.random = randomBytes(32, rand); - m.sessionId = randomBytes(rand.Intn(32), rand); - m.cipherSuites = make([]uint16, rand.Intn(63)+1); + m := &clientHelloMsg{} + m.major = uint8(rand.Intn(256)) + m.minor = uint8(rand.Intn(256)) + m.random = randomBytes(32, rand) + m.sessionId = randomBytes(rand.Intn(32), rand) + m.cipherSuites = make([]uint16, rand.Intn(63)+1) for i := 0; i < len(m.cipherSuites); i++ { m.cipherSuites[i] = uint16(rand.Int31()) } - m.compressionMethods = randomBytes(rand.Intn(63)+1, rand); + m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) - return reflect.NewValue(m); + return reflect.NewValue(m) } func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &serverHelloMsg{}; - m.major = uint8(rand.Intn(256)); - m.minor = uint8(rand.Intn(256)); - m.random = randomBytes(32, rand); - m.sessionId = randomBytes(rand.Intn(32), rand); - m.cipherSuite = uint16(rand.Int31()); - m.compressionMethod = uint8(rand.Intn(256)); - return reflect.NewValue(m); + m := &serverHelloMsg{} + m.major = uint8(rand.Intn(256)) + m.minor = uint8(rand.Intn(256)) + m.random = randomBytes(32, rand) + m.sessionId = randomBytes(rand.Intn(32), rand) + m.cipherSuite = uint16(rand.Int31()) + m.compressionMethod = uint8(rand.Intn(256)) + return reflect.NewValue(m) } func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &certificateMsg{}; - numCerts := rand.Intn(20); - m.certificates = make([][]byte, numCerts); + m := &certificateMsg{} + numCerts := rand.Intn(20) + m.certificates = make([][]byte, numCerts) for i := 0; i < numCerts; i++ { m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) } - return reflect.NewValue(m); + return reflect.NewValue(m) } func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &clientKeyExchangeMsg{}; - m.ciphertext = randomBytes(rand.Intn(1000)+1, rand); - return reflect.NewValue(m); + m := &clientKeyExchangeMsg{} + m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) + return reflect.NewValue(m) } func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { - m := &finishedMsg{}; - m.verifyData = randomBytes(12, rand); - return reflect.NewValue(m); + m := &finishedMsg{} + m.verifyData = randomBytes(12, rand) + return reflect.NewValue(m) } diff --git a/handshake_server.go b/handshake_server.go index 2e77603..5314e5c 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -13,17 +13,17 @@ package tls // channel in the message (ChangeCipherSpec). import ( - "crypto/hmac"; - "crypto/rc4"; - "crypto/rsa"; - "crypto/sha1"; - "crypto/subtle"; - "io"; + "crypto/hmac" + "crypto/rc4" + "crypto/rsa" + "crypto/sha1" + "crypto/subtle" + "io" ) type cipherSuite struct { - id uint16; // The number of this suite on the wire. - hashLength, cipherKeyLength int; + id uint16 // The number of this suite on the wire. + hashLength, cipherKeyLength int // TODO(agl): need a method to create the cipher and hash interfaces. } @@ -33,118 +33,118 @@ var cipherSuites = []cipherSuite{ // A serverHandshake performs the server side of the TLS 1.1 handshake protocol. type serverHandshake struct { - writeChan chan<- interface{}; - controlChan chan<- interface{}; - msgChan <-chan interface{}; - config *Config; + writeChan chan<- interface{} + controlChan chan<- interface{} + msgChan <-chan interface{} + config *Config } func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) { - h.writeChan = writeChan; - h.controlChan = controlChan; - h.msgChan = msgChan; - h.config = config; + h.writeChan = writeChan + h.controlChan = controlChan + h.msgChan = msgChan + h.config = config - defer close(writeChan); - defer close(controlChan); + defer close(writeChan) + defer close(controlChan) - clientHello, ok := h.readHandshakeMsg().(*clientHelloMsg); + clientHello, ok := h.readHandshakeMsg().(*clientHelloMsg) if !ok { - h.error(alertUnexpectedMessage); - return; + h.error(alertUnexpectedMessage) + return } - major, minor, ok := mutualVersion(clientHello.major, clientHello.minor); + major, minor, ok := mutualVersion(clientHello.major, clientHello.minor) if !ok { - h.error(alertProtocolVersion); - return; + h.error(alertProtocolVersion) + return } - finishedHash := newFinishedHash(); - finishedHash.Write(clientHello.marshal()); + finishedHash := newFinishedHash() + finishedHash.Write(clientHello.marshal()) - hello := new(serverHelloMsg); + hello := new(serverHelloMsg) // We only support a single ciphersuite so we look for it in the list // of client supported suites. // // TODO(agl): Add additional cipher suites. - var suite *cipherSuite; + var suite *cipherSuite for _, id := range clientHello.cipherSuites { for _, supported := range cipherSuites { if supported.id == id { - suite = &supported; - break; + suite = &supported + break } } } - foundCompression := false; + foundCompression := false // We only support null compression, so check that the client offered it. for _, compression := range clientHello.compressionMethods { if compression == compressionNone { - foundCompression = true; - break; + foundCompression = true + break } } if suite == nil || !foundCompression { - h.error(alertHandshakeFailure); - return; - } - - hello.major = major; - hello.minor = minor; - hello.cipherSuite = suite.id; - currentTime := uint32(config.Time()); - hello.random = make([]byte, 32); - hello.random[0] = byte(currentTime >> 24); - hello.random[1] = byte(currentTime >> 16); - hello.random[2] = byte(currentTime >> 8); - hello.random[3] = byte(currentTime); - _, err := io.ReadFull(config.Rand, hello.random[4:]); + h.error(alertHandshakeFailure) + return + } + + hello.major = major + hello.minor = minor + hello.cipherSuite = suite.id + currentTime := uint32(config.Time()) + hello.random = make([]byte, 32) + hello.random[0] = byte(currentTime >> 24) + hello.random[1] = byte(currentTime >> 16) + hello.random[2] = byte(currentTime >> 8) + hello.random[3] = byte(currentTime) + _, err := io.ReadFull(config.Rand, hello.random[4:]) if err != nil { - h.error(alertInternalError); - return; + h.error(alertInternalError) + return } - hello.compressionMethod = compressionNone; + hello.compressionMethod = compressionNone - finishedHash.Write(hello.marshal()); - writeChan <- writerSetVersion{major, minor}; - writeChan <- hello; + finishedHash.Write(hello.marshal()) + writeChan <- writerSetVersion{major, minor} + writeChan <- hello if len(config.Certificates) == 0 { - h.error(alertInternalError); - return; + h.error(alertInternalError) + return } - certMsg := new(certificateMsg); - certMsg.certificates = config.Certificates[0].Certificate; - finishedHash.Write(certMsg.marshal()); - writeChan <- certMsg; + certMsg := new(certificateMsg) + certMsg.certificates = config.Certificates[0].Certificate + finishedHash.Write(certMsg.marshal()) + writeChan <- certMsg - helloDone := new(serverHelloDoneMsg); - finishedHash.Write(helloDone.marshal()); - writeChan <- helloDone; + helloDone := new(serverHelloDoneMsg) + finishedHash.Write(helloDone.marshal()) + writeChan <- helloDone - ckx, ok := h.readHandshakeMsg().(*clientKeyExchangeMsg); + ckx, ok := h.readHandshakeMsg().(*clientKeyExchangeMsg) if !ok { - h.error(alertUnexpectedMessage); - return; + h.error(alertUnexpectedMessage) + return } - finishedHash.Write(ckx.marshal()); + finishedHash.Write(ckx.marshal()) - preMasterSecret := make([]byte, 48); - _, err = io.ReadFull(config.Rand, preMasterSecret[2:]); + preMasterSecret := make([]byte, 48) + _, err = io.ReadFull(config.Rand, preMasterSecret[2:]) if err != nil { - h.error(alertInternalError); - return; + h.error(alertInternalError) + return } - err = rsa.DecryptPKCS1v15SessionKey(config.Rand, config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret); + err = rsa.DecryptPKCS1v15SessionKey(config.Rand, config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret) if err != nil { - h.error(alertHandshakeFailure); - return; + h.error(alertHandshakeFailure) + return } // We don't check the version number in the premaster secret. For one, // by checking it, we would leak information about the validity of the @@ -154,70 +154,70 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<- // 7.4.7.1 of RFC 4346. masterSecret, clientMAC, serverMAC, clientKey, serverKey := - keysFromPreMasterSecret11(preMasterSecret, clientHello.random, hello.random, suite.hashLength, suite.cipherKeyLength); + keysFromPreMasterSecret11(preMasterSecret, clientHello.random, hello.random, suite.hashLength, suite.cipherKeyLength) - _, ok = h.readHandshakeMsg().(changeCipherSpec); + _, ok = h.readHandshakeMsg().(changeCipherSpec) if !ok { - h.error(alertUnexpectedMessage); - return; + h.error(alertUnexpectedMessage) + return } - cipher, _ := rc4.NewCipher(clientKey); - controlChan <- &newCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)}; + cipher, _ := rc4.NewCipher(clientKey) + controlChan <- &newCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)} - clientFinished, ok := h.readHandshakeMsg().(*finishedMsg); + clientFinished, ok := h.readHandshakeMsg().(*finishedMsg) if !ok { - h.error(alertUnexpectedMessage); - return; + h.error(alertUnexpectedMessage) + return } - verify := finishedHash.clientSum(masterSecret); + verify := finishedHash.clientSum(masterSecret) if len(verify) != len(clientFinished.verifyData) || subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { - h.error(alertHandshakeFailure); - return; + h.error(alertHandshakeFailure) + return } - controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0}; + controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0} - finishedHash.Write(clientFinished.marshal()); + finishedHash.Write(clientFinished.marshal()) - cipher2, _ := rc4.NewCipher(serverKey); - writeChan <- writerChangeCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)}; + cipher2, _ := rc4.NewCipher(serverKey) + writeChan <- writerChangeCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)} - finished := new(finishedMsg); - finished.verifyData = finishedHash.serverSum(masterSecret); - writeChan <- finished; + finished := new(finishedMsg) + finished.verifyData = finishedHash.serverSum(masterSecret) + writeChan <- finished - writeChan <- writerEnableApplicationData{}; + writeChan <- writerEnableApplicationData{} for { - _, ok := h.readHandshakeMsg().(*clientHelloMsg); + _, ok := h.readHandshakeMsg().(*clientHelloMsg) if !ok { - h.error(alertUnexpectedMessage); - return; + h.error(alertUnexpectedMessage) + return } // We reject all renegotication requests. - writeChan <- alert{alertLevelWarning, alertNoRenegotiation}; + writeChan <- alert{alertLevelWarning, alertNoRenegotiation} } } func (h *serverHandshake) readHandshakeMsg() interface{} { - v := <-h.msgChan; + v := <-h.msgChan if closed(h.msgChan) { // If the channel closed then the processor received an error // from the peer and we don't want to echo it back to them. - h.msgChan = nil; - return 0; + h.msgChan = nil + return 0 } if _, ok := v.(alert); ok { // We got an alert from the processor. We forward to the writer // and shutdown. - h.writeChan <- v; - h.msgChan = nil; - return 0; + h.writeChan <- v + h.msgChan = nil + return 0 } - return v; + return v } func (h *serverHandshake) error(e alertType) { @@ -227,9 +227,9 @@ func (h *serverHandshake) error(e alertType) { go func() { for _ = range h.msgChan { } - }(); - h.controlChan <- ConnectionState{false, "", e}; - close(h.controlChan); - h.writeChan <- alert{alertLevelError, e}; + }() + h.controlChan <- ConnectionState{false, "", e} + close(h.controlChan) + h.writeChan <- alert{alertLevelError, e} } } diff --git a/handshake_server_test.go b/handshake_server_test.go index 91583d2..7160985 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -5,12 +5,12 @@ package tls import ( - "bytes"; - "big"; - "crypto/rsa"; - "os"; - "testing"; - "testing/script"; + "bytes" + "big" + "crypto/rsa" + "os" + "testing" + "testing/script" ) type zeroSource struct{} @@ -20,41 +20,41 @@ func (zeroSource) Read(b []byte) (n int, err os.Error) { b[i] = 0 } - return len(b), nil; + return len(b), nil } var testConfig *Config func init() { - testConfig = new(Config); - testConfig.Time = func() int64 { return 0 }; - testConfig.Rand = zeroSource{}; - testConfig.Certificates = make([]Certificate, 1); - testConfig.Certificates[0].Certificate = [][]byte{testCertificate}; - testConfig.Certificates[0].PrivateKey = testPrivateKey; + testConfig = new(Config) + testConfig.Time = func() int64 { return 0 } + testConfig.Rand = zeroSource{} + testConfig.Certificates = make([]Certificate, 1) + testConfig.Certificates[0].Certificate = [][]byte{testCertificate} + testConfig.Certificates[0].PrivateKey = testPrivateKey } func setupServerHandshake() (writeChan chan interface{}, controlChan chan interface{}, msgChan chan interface{}) { - sh := new(serverHandshake); - writeChan = make(chan interface{}); - controlChan = make(chan interface{}); - msgChan = make(chan interface{}); + sh := new(serverHandshake) + writeChan = make(chan interface{}) + controlChan = make(chan interface{}) + msgChan = make(chan interface{}) - go sh.loop(writeChan, controlChan, msgChan, testConfig); - return; + go sh.loop(writeChan, controlChan, msgChan, testConfig) + return } func testClientHelloFailure(t *testing.T, clientHello interface{}, expectedAlert alertType) { - writeChan, controlChan, msgChan := setupServerHandshake(); - defer close(msgChan); + writeChan, controlChan, msgChan := setupServerHandshake() + defer close(msgChan) - send := script.NewEvent("send", nil, script.Send{msgChan, clientHello}); - recvAlert := script.NewEvent("recv alert", []*script.Event{send}, script.Recv{writeChan, alert{alertLevelError, expectedAlert}}); - close1 := script.NewEvent("msgChan close", []*script.Event{recvAlert}, script.Closed{writeChan}); - recvState := script.NewEvent("recv state", []*script.Event{send}, script.Recv{controlChan, ConnectionState{false, "", expectedAlert}}); - close2 := script.NewEvent("controlChan close", []*script.Event{recvState}, script.Closed{controlChan}); + send := script.NewEvent("send", nil, script.Send{msgChan, clientHello}) + recvAlert := script.NewEvent("recv alert", []*script.Event{send}, script.Recv{writeChan, alert{alertLevelError, expectedAlert}}) + close1 := script.NewEvent("msgChan close", []*script.Event{recvAlert}, script.Closed{writeChan}) + recvState := script.NewEvent("recv state", []*script.Event{send}, script.Recv{controlChan, ConnectionState{false, "", expectedAlert}}) + close2 := script.NewEvent("controlChan close", []*script.Event{recvState}, script.Closed{controlChan}) - err := script.Perform(0, []*script.Event{send, recvAlert, close1, recvState, close2}); + err := script.Perform(0, []*script.Event{send, recvAlert, close1, recvState, close2}) if err != nil { t.Errorf("Got error: %s", err) } @@ -67,125 +67,125 @@ func TestSimpleError(t *testing.T) { var badProtocolVersions = []uint8{0, 0, 0, 5, 1, 0, 1, 5, 2, 0, 2, 5, 3, 0} func TestRejectBadProtocolVersion(t *testing.T) { - clientHello := new(clientHelloMsg); + clientHello := new(clientHelloMsg) for i := 0; i < len(badProtocolVersions); i += 2 { - clientHello.major = badProtocolVersions[i]; - clientHello.minor = badProtocolVersions[i+1]; + clientHello.major = badProtocolVersions[i] + clientHello.minor = badProtocolVersions[i+1] - testClientHelloFailure(t, clientHello, alertProtocolVersion); + testClientHelloFailure(t, clientHello, alertProtocolVersion) } } func TestNoSuiteOverlap(t *testing.T) { - clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{0xff00}, []uint8{0}}; - testClientHelloFailure(t, clientHello, alertHandshakeFailure); + clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{0xff00}, []uint8{0}} + testClientHelloFailure(t, clientHello, alertHandshakeFailure) } func TestNoCompressionOverlap(t *testing.T) { - clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}}; - testClientHelloFailure(t, clientHello, alertHandshakeFailure); + clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}} + testClientHelloFailure(t, clientHello, alertHandshakeFailure) } func matchServerHello(v interface{}) bool { - serverHello, ok := v.(*serverHelloMsg); + serverHello, ok := v.(*serverHelloMsg) if !ok { return false } return serverHello.major == 3 && serverHello.minor == 2 && serverHello.cipherSuite == TLS_RSA_WITH_RC4_128_SHA && - serverHello.compressionMethod == compressionNone; + serverHello.compressionMethod == compressionNone } func TestAlertForwarding(t *testing.T) { - writeChan, controlChan, msgChan := setupServerHandshake(); - defer close(msgChan); + writeChan, controlChan, msgChan := setupServerHandshake() + defer close(msgChan) - a := alert{alertLevelError, alertNoRenegotiation}; - sendAlert := script.NewEvent("send alert", nil, script.Send{msgChan, a}); - recvAlert := script.NewEvent("recv alert", []*script.Event{sendAlert}, script.Recv{writeChan, a}); - closeWriter := script.NewEvent("close writer", []*script.Event{recvAlert}, script.Closed{writeChan}); - closeControl := script.NewEvent("close control", []*script.Event{recvAlert}, script.Closed{controlChan}); + a := alert{alertLevelError, alertNoRenegotiation} + sendAlert := script.NewEvent("send alert", nil, script.Send{msgChan, a}) + recvAlert := script.NewEvent("recv alert", []*script.Event{sendAlert}, script.Recv{writeChan, a}) + closeWriter := script.NewEvent("close writer", []*script.Event{recvAlert}, script.Closed{writeChan}) + closeControl := script.NewEvent("close control", []*script.Event{recvAlert}, script.Closed{controlChan}) - err := script.Perform(0, []*script.Event{sendAlert, recvAlert, closeWriter, closeControl}); + err := script.Perform(0, []*script.Event{sendAlert, recvAlert, closeWriter, closeControl}) if err != nil { t.Errorf("Got error: %s", err) } } func TestClose(t *testing.T) { - writeChan, controlChan, msgChan := setupServerHandshake(); + writeChan, controlChan, msgChan := setupServerHandshake() - close := script.NewEvent("close", nil, script.Close{msgChan}); - closed1 := script.NewEvent("closed1", []*script.Event{close}, script.Closed{writeChan}); - closed2 := script.NewEvent("closed2", []*script.Event{close}, script.Closed{controlChan}); + close := script.NewEvent("close", nil, script.Close{msgChan}) + closed1 := script.NewEvent("closed1", []*script.Event{close}, script.Closed{writeChan}) + closed2 := script.NewEvent("closed2", []*script.Event{close}, script.Closed{controlChan}) - err := script.Perform(0, []*script.Event{close, closed1, closed2}); + err := script.Perform(0, []*script.Event{close, closed1, closed2}) if err != nil { t.Errorf("Got error: %s", err) } } func matchCertificate(v interface{}) bool { - cert, ok := v.(*certificateMsg); + cert, ok := v.(*certificateMsg) if !ok { return false } return len(cert.certificates) == 1 && - bytes.Compare(cert.certificates[0], testCertificate) == 0; + bytes.Compare(cert.certificates[0], testCertificate) == 0 } func matchSetCipher(v interface{}) bool { - _, ok := v.(writerChangeCipherSpec); - return ok; + _, ok := v.(writerChangeCipherSpec) + return ok } func matchDone(v interface{}) bool { - _, ok := v.(*serverHelloDoneMsg); - return ok; + _, ok := v.(*serverHelloDoneMsg) + return ok } func matchFinished(v interface{}) bool { - finished, ok := v.(*finishedMsg); + finished, ok := v.(*finishedMsg) if !ok { return false } - return bytes.Compare(finished.verifyData, fromHex("29122ae11453e631487b02ed")) == 0; + return bytes.Compare(finished.verifyData, fromHex("29122ae11453e631487b02ed")) == 0 } func matchNewCipherSpec(v interface{}) bool { - _, ok := v.(*newCipherSpec); - return ok; + _, ok := v.(*newCipherSpec) + return ok } func TestFullHandshake(t *testing.T) { - writeChan, controlChan, msgChan := setupServerHandshake(); - defer close(msgChan); + writeChan, controlChan, msgChan := setupServerHandshake() + defer close(msgChan) // The values for this test were obtained from running `gnutls-cli --insecure --debug 9` - clientHello := &clientHelloMsg{fromHex("0100007603024aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b310000340033004500390088001600320044003800870013006600900091008f008e002f004100350084000a00050004008c008d008b008a01000019000900030200010000000e000c0000093132372e302e302e31"), 3, 2, fromHex("4aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b31"), nil, []uint16{0x33, 0x45, 0x39, 0x88, 0x16, 0x32, 0x44, 0x38, 0x87, 0x13, 0x66, 0x90, 0x91, 0x8f, 0x8e, 0x2f, 0x41, 0x35, 0x84, 0xa, 0x5, 0x4, 0x8c, 0x8d, 0x8b, 0x8a}, []uint8{0x0}}; + clientHello := &clientHelloMsg{fromHex("0100007603024aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b310000340033004500390088001600320044003800870013006600900091008f008e002f004100350084000a00050004008c008d008b008a01000019000900030200010000000e000c0000093132372e302e302e31"), 3, 2, fromHex("4aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b31"), nil, []uint16{0x33, 0x45, 0x39, 0x88, 0x16, 0x32, 0x44, 0x38, 0x87, 0x13, 0x66, 0x90, 0x91, 0x8f, 0x8e, 0x2f, 0x41, 0x35, 0x84, 0xa, 0x5, 0x4, 0x8c, 0x8d, 0x8b, 0x8a}, []uint8{0x0}} - sendHello := script.NewEvent("send hello", nil, script.Send{msgChan, clientHello}); - setVersion := script.NewEvent("set version", []*script.Event{sendHello}, script.Recv{writeChan, writerSetVersion{3, 2}}); - recvHello := script.NewEvent("recv hello", []*script.Event{setVersion}, script.RecvMatch{writeChan, matchServerHello}); - recvCert := script.NewEvent("recv cert", []*script.Event{recvHello}, script.RecvMatch{writeChan, matchCertificate}); - recvDone := script.NewEvent("recv done", []*script.Event{recvCert}, script.RecvMatch{writeChan, matchDone}); + sendHello := script.NewEvent("send hello", nil, script.Send{msgChan, clientHello}) + setVersion := script.NewEvent("set version", []*script.Event{sendHello}, script.Recv{writeChan, writerSetVersion{3, 2}}) + recvHello := script.NewEvent("recv hello", []*script.Event{setVersion}, script.RecvMatch{writeChan, matchServerHello}) + recvCert := script.NewEvent("recv cert", []*script.Event{recvHello}, script.RecvMatch{writeChan, matchCertificate}) + recvDone := script.NewEvent("recv done", []*script.Event{recvCert}, script.RecvMatch{writeChan, matchDone}) - ckx := &clientKeyExchangeMsg{nil, fromHex("872e1fee5f37dd86f3215938ac8de20b302b90074e9fb93097e6b7d1286d0f45abf2daf179deb618bb3c70ed0afee6ee24476ee4649e5a23358143c0f1d9c251")}; - sendCKX := script.NewEvent("send ckx", []*script.Event{recvDone}, script.Send{msgChan, ckx}); + ckx := &clientKeyExchangeMsg{nil, fromHex("872e1fee5f37dd86f3215938ac8de20b302b90074e9fb93097e6b7d1286d0f45abf2daf179deb618bb3c70ed0afee6ee24476ee4649e5a23358143c0f1d9c251")} + sendCKX := script.NewEvent("send ckx", []*script.Event{recvDone}, script.Send{msgChan, ckx}) - sendCCS := script.NewEvent("send ccs", []*script.Event{sendCKX}, script.Send{msgChan, changeCipherSpec{}}); - recvNCS := script.NewEvent("recv done", []*script.Event{sendCCS}, script.RecvMatch{controlChan, matchNewCipherSpec}); + sendCCS := script.NewEvent("send ccs", []*script.Event{sendCKX}, script.Send{msgChan, changeCipherSpec{}}) + recvNCS := script.NewEvent("recv done", []*script.Event{sendCCS}, script.RecvMatch{controlChan, matchNewCipherSpec}) - finished := &finishedMsg{nil, fromHex("c8faca5d242f4423325c5b1a")}; - sendFinished := script.NewEvent("send finished", []*script.Event{recvNCS}, script.Send{msgChan, finished}); - recvFinished := script.NewEvent("recv finished", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchFinished}); - setCipher := script.NewEvent("set cipher", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchSetCipher}); - recvConnectionState := script.NewEvent("recv state", []*script.Event{sendFinished}, script.Recv{controlChan, ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0}}); + finished := &finishedMsg{nil, fromHex("c8faca5d242f4423325c5b1a")} + sendFinished := script.NewEvent("send finished", []*script.Event{recvNCS}, script.Send{msgChan, finished}) + recvFinished := script.NewEvent("recv finished", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchFinished}) + setCipher := script.NewEvent("set cipher", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchSetCipher}) + recvConnectionState := script.NewEvent("recv state", []*script.Event{sendFinished}, script.Recv{controlChan, ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0}}) - err := script.Perform(0, []*script.Event{sendHello, setVersion, recvHello, recvCert, recvDone, sendCKX, sendCCS, recvNCS, sendFinished, setCipher, recvConnectionState, recvFinished}); + err := script.Perform(0, []*script.Event{sendHello, setVersion, recvHello, recvCert, recvDone, sendCKX, sendCCS, recvNCS, sendFinished, setCipher, recvConnectionState, recvFinished}) if err != nil { t.Errorf("Got error: %s", err) } @@ -194,9 +194,9 @@ func TestFullHandshake(t *testing.T) { var testCertificate = fromHex("3082025930820203a003020102020900c2ec326b95228959300d06092a864886f70d01010505003054310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464310d300b0603550403130474657374301e170d3039313032303232323434355a170d3130313032303232323434355a3054310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464310d300b0603550403130474657374305c300d06092a864886f70d0101010500034b003048024100b2990f49c47dfa8cd400ae6a4d1b8a3b6a13642b23f28b003bfb97790ade9a4cc82b8b2a81747ddec08b6296e53a08c331687ef25c4bf4936ba1c0e6041e9d150203010001a381b73081b4301d0603551d0e0416041478a06086837c9293a8c9b70c0bdabdb9d77eeedf3081840603551d23047d307b801478a06086837c9293a8c9b70c0bdabdb9d77eeedfa158a4563054310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464310d300b0603550403130474657374820900c2ec326b95228959300c0603551d13040530030101ff300d06092a864886f70d0101050500034100ac23761ae1349d85a439caad4d0b932b09ea96de1917c3e0507c446f4838cb3076fb4d431db8c1987e96f1d7a8a2054dea3a64ec99a3f0eda4d47a163bf1f6ac") func bigFromString(s string) *big.Int { - ret := new(big.Int); - ret.SetString(s, 10); - return ret; + ret := new(big.Int) + ret.SetString(s, 10) + return ret } var testPrivateKey = &rsa.PrivateKey{ diff --git a/prf.go b/prf.go index b89b59c..6b9c44c 100644 --- a/prf.go +++ b/prf.go @@ -5,59 +5,59 @@ package tls import ( - "crypto/hmac"; - "crypto/md5"; - "crypto/sha1"; - "hash"; - "os"; - "strings"; + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "hash" + "os" + "strings" ) // Split a premaster secret in two as specified in RFC 4346, section 5. func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { - s1 = secret[0 : (len(secret)+1)/2]; - s2 = secret[len(secret)/2:]; - return; + s1 = secret[0 : (len(secret)+1)/2] + s2 = secret[len(secret)/2:] + return } // pHash implements the P_hash function, as defined in RFC 4346, section 5. func pHash(result, secret, seed []byte, hash hash.Hash) { - h := hmac.New(hash, secret); - h.Write(seed); - a := h.Sum(); + h := hmac.New(hash, secret) + h.Write(seed) + a := h.Sum() - j := 0; + j := 0 for j < len(result) { - h.Reset(); - h.Write(a); - h.Write(seed); - b := h.Sum(); - todo := len(b); + h.Reset() + h.Write(a) + h.Write(seed) + b := h.Sum() + todo := len(b) if j+todo > len(result) { todo = len(result) - j } - copy(result[j:j+todo], b); - j += todo; + copy(result[j:j+todo], b) + j += todo - h.Reset(); - h.Write(a); - a = h.Sum(); + h.Reset() + h.Write(a) + a = h.Sum() } } // pRF11 implements the TLS 1.1 pseudo-random function, as defined in RFC 4346, section 5. func pRF11(result, secret, label, seed []byte) { - hashSHA1 := sha1.New(); - hashMD5 := md5.New(); + hashSHA1 := sha1.New() + hashMD5 := md5.New() - labelAndSeed := make([]byte, len(label)+len(seed)); - copy(labelAndSeed, label); - copy(labelAndSeed[len(label):], seed); + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) - s1, s2 := splitPreMasterSecret(secret); - pHash(result, s1, labelAndSeed, hashMD5); - result2 := make([]byte, len(result)); - pHash(result2, s2, labelAndSeed, hashSHA1); + s1, s2 := splitPreMasterSecret(secret) + pHash(result, s1, labelAndSeed, hashMD5) + result2 := make([]byte, len(result)) + pHash(result2, s2, labelAndSeed, hashSHA1) for i, b := range result2 { result[i] ^= b @@ -65,9 +65,9 @@ func pRF11(result, secret, label, seed []byte) { } const ( - tlsRandomLength = 32; // Length of a random nonce in TLS 1.1. - masterSecretLength = 48; // Length of a master secret in TLS 1.1. - finishedVerifyLength = 12; // Length of verify_data in a Finished message. + tlsRandomLength = 32 // Length of a random nonce in TLS 1.1. + masterSecretLength = 48 // Length of a master secret in TLS 1.1. + finishedVerifyLength = 12 // Length of verify_data in a Finished message. ) var masterSecretLabel = strings.Bytes("master secret") @@ -79,32 +79,32 @@ var serverFinishedLabel = strings.Bytes("server finished") // secret, given the lengths of the MAC and cipher keys, as defined in RFC // 4346, section 6.3. func keysFromPreMasterSecret11(preMasterSecret, clientRandom, serverRandom []byte, macLen, keyLen int) (masterSecret, clientMAC, serverMAC, clientKey, serverKey []byte) { - var seed [tlsRandomLength * 2]byte; - copy(seed[0:len(clientRandom)], clientRandom); - copy(seed[len(clientRandom):], serverRandom); - masterSecret = make([]byte, masterSecretLength); - pRF11(masterSecret, preMasterSecret, masterSecretLabel, seed[0:]); - - copy(seed[0:len(clientRandom)], serverRandom); - copy(seed[len(serverRandom):], clientRandom); - - n := 2*macLen + 2*keyLen; - keyMaterial := make([]byte, n); - pRF11(keyMaterial, masterSecret, keyExpansionLabel, seed[0:]); - clientMAC = keyMaterial[0:macLen]; - serverMAC = keyMaterial[macLen : macLen*2]; - clientKey = keyMaterial[macLen*2 : macLen*2+keyLen]; - serverKey = keyMaterial[macLen*2+keyLen:]; - return; + var seed [tlsRandomLength * 2]byte + copy(seed[0:len(clientRandom)], clientRandom) + copy(seed[len(clientRandom):], serverRandom) + masterSecret = make([]byte, masterSecretLength) + pRF11(masterSecret, preMasterSecret, masterSecretLabel, seed[0:]) + + copy(seed[0:len(clientRandom)], serverRandom) + copy(seed[len(serverRandom):], clientRandom) + + n := 2*macLen + 2*keyLen + keyMaterial := make([]byte, n) + pRF11(keyMaterial, masterSecret, keyExpansionLabel, seed[0:]) + clientMAC = keyMaterial[0:macLen] + serverMAC = keyMaterial[macLen : macLen*2] + clientKey = keyMaterial[macLen*2 : macLen*2+keyLen] + serverKey = keyMaterial[macLen*2+keyLen:] + return } // A finishedHash calculates the hash of a set of handshake messages suitable // for including in a Finished message. type finishedHash struct { - clientMD5 hash.Hash; - clientSHA1 hash.Hash; - serverMD5 hash.Hash; - serverSHA1 hash.Hash; + clientMD5 hash.Hash + clientSHA1 hash.Hash + serverMD5 hash.Hash + serverSHA1 hash.Hash } func newFinishedHash() finishedHash { @@ -112,36 +112,36 @@ func newFinishedHash() finishedHash { } func (h finishedHash) Write(msg []byte) (n int, err os.Error) { - h.clientMD5.Write(msg); - h.clientSHA1.Write(msg); - h.serverMD5.Write(msg); - h.serverSHA1.Write(msg); - return len(msg), nil; + h.clientMD5.Write(msg) + h.clientSHA1.Write(msg) + h.serverMD5.Write(msg) + h.serverSHA1.Write(msg) + return len(msg), nil } // finishedSum calculates the contents of the verify_data member of a Finished // message given the MD5 and SHA1 hashes of a set of handshake messages. func finishedSum(md5, sha1, label, masterSecret []byte) []byte { - seed := make([]byte, len(md5)+len(sha1)); - copy(seed, md5); - copy(seed[len(md5):], sha1); - out := make([]byte, finishedVerifyLength); - pRF11(out, masterSecret, label, seed); - return out; + seed := make([]byte, len(md5)+len(sha1)) + copy(seed, md5) + copy(seed[len(md5):], sha1) + out := make([]byte, finishedVerifyLength) + pRF11(out, masterSecret, label, seed) + return out } // clientSum returns the contents of the verify_data member of a client's // Finished message. func (h finishedHash) clientSum(masterSecret []byte) []byte { - md5 := h.clientMD5.Sum(); - sha1 := h.clientSHA1.Sum(); - return finishedSum(md5, sha1, clientFinishedLabel, masterSecret); + md5 := h.clientMD5.Sum() + sha1 := h.clientSHA1.Sum() + return finishedSum(md5, sha1, clientFinishedLabel, masterSecret) } // serverSum returns the contents of the verify_data member of a server's // Finished message. func (h finishedHash) serverSum(masterSecret []byte) []byte { - md5 := h.serverMD5.Sum(); - sha1 := h.serverSHA1.Sum(); - return finishedSum(md5, sha1, serverFinishedLabel, masterSecret); + md5 := h.serverMD5.Sum() + sha1 := h.serverSHA1.Sum() + return finishedSum(md5, sha1, serverFinishedLabel, masterSecret) } diff --git a/prf_test.go b/prf_test.go index 0d4a4db..5c23f36 100644 --- a/prf_test.go +++ b/prf_test.go @@ -5,12 +5,12 @@ package tls import ( - "encoding/hex"; - "testing"; + "encoding/hex" + "testing" ) type testSplitPreMasterSecretTest struct { - in, out1, out2 string; + in, out1, out2 string } var testSplitPreMasterSecretTests = []testSplitPreMasterSecretTest{ @@ -23,10 +23,10 @@ var testSplitPreMasterSecretTests = []testSplitPreMasterSecretTest{ func TestSplitPreMasterSecret(t *testing.T) { for i, test := range testSplitPreMasterSecretTests { - in, _ := hex.DecodeString(test.in); - out1, out2 := splitPreMasterSecret(in); - s1 := hex.EncodeToString(out1); - s2 := hex.EncodeToString(out2); + in, _ := hex.DecodeString(test.in) + out1, out2 := splitPreMasterSecret(in) + s1 := hex.EncodeToString(out1) + s2 := hex.EncodeToString(out2) if s1 != test.out1 || s2 != test.out2 { t.Errorf("#%d: got: (%s, %s) want: (%s, %s)", i, s1, s2, test.out1, test.out2) } @@ -34,25 +34,25 @@ func TestSplitPreMasterSecret(t *testing.T) { } type testKeysFromTest struct { - preMasterSecret string; - clientRandom, serverRandom string; - masterSecret string; - clientMAC, serverMAC string; - clientKey, serverKey string; - macLen, keyLen int; + preMasterSecret string + clientRandom, serverRandom string + masterSecret string + clientMAC, serverMAC string + clientKey, serverKey string + macLen, keyLen int } func TestKeysFromPreMasterSecret(t *testing.T) { for i, test := range testKeysFromTests { - in, _ := hex.DecodeString(test.preMasterSecret); - clientRandom, _ := hex.DecodeString(test.clientRandom); - serverRandom, _ := hex.DecodeString(test.serverRandom); - master, clientMAC, serverMAC, clientKey, serverKey := keysFromPreMasterSecret11(in, clientRandom, serverRandom, test.macLen, test.keyLen); - masterString := hex.EncodeToString(master); - clientMACString := hex.EncodeToString(clientMAC); - serverMACString := hex.EncodeToString(serverMAC); - clientKeyString := hex.EncodeToString(clientKey); - serverKeyString := hex.EncodeToString(serverKey); + in, _ := hex.DecodeString(test.preMasterSecret) + clientRandom, _ := hex.DecodeString(test.clientRandom) + serverRandom, _ := hex.DecodeString(test.serverRandom) + master, clientMAC, serverMAC, clientKey, serverKey := keysFromPreMasterSecret11(in, clientRandom, serverRandom, test.macLen, test.keyLen) + masterString := hex.EncodeToString(master) + clientMACString := hex.EncodeToString(clientMAC) + serverMACString := hex.EncodeToString(serverMAC) + clientKeyString := hex.EncodeToString(clientKey) + serverKeyString := hex.EncodeToString(serverKey) if masterString != test.masterSecret || clientMACString != test.clientMAC || serverMACString != test.serverMAC || diff --git a/record_process.go b/record_process.go index e356d67..ddeca0e 100644 --- a/record_process.go +++ b/record_process.go @@ -10,27 +10,27 @@ package tls // state, or for a notification when the state changes. import ( - "container/list"; - "crypto/subtle"; - "hash"; + "container/list" + "crypto/subtle" + "hash" ) // getConnectionState is a request from the application to get the current // ConnectionState. type getConnectionState struct { - reply chan<- ConnectionState; + reply chan<- ConnectionState } // waitConnectionState is a request from the application to be notified when // the connection state changes. type waitConnectionState struct { - reply chan<- ConnectionState; + reply chan<- ConnectionState } // connectionStateChange is a message from the handshake processor that the // connection state has changed. type connectionStateChange struct { - connState ConnectionState; + connState ConnectionState } // changeCipherSpec is a message send to the handshake processor to signal that @@ -40,32 +40,32 @@ type changeCipherSpec struct{} // newCipherSpec is a message from the handshake processor that future // records should be processed with a new cipher and MAC function. type newCipherSpec struct { - encrypt encryptor; - mac hash.Hash; + encrypt encryptor + mac hash.Hash } type recordProcessor struct { - decrypt encryptor; - mac hash.Hash; - seqNum uint64; - handshakeBuf []byte; - appDataChan chan<- []byte; - requestChan <-chan interface{}; - controlChan <-chan interface{}; - recordChan <-chan *record; - handshakeChan chan<- interface{}; + decrypt encryptor + mac hash.Hash + seqNum uint64 + handshakeBuf []byte + appDataChan chan<- []byte + requestChan <-chan interface{} + controlChan <-chan interface{} + recordChan <-chan *record + handshakeChan chan<- interface{} // recordRead is nil when we don't wish to read any more. - recordRead <-chan *record; + recordRead <-chan *record // appDataSend is nil when len(appData) == 0. - appDataSend chan<- []byte; + appDataSend chan<- []byte // appData contains any application data queued for upstream. - appData []byte; + appData []byte // A list of channels waiting for connState to change. - waitQueue *list.List; - connState ConnectionState; - shutdown bool; - header [13]byte; + waitQueue *list.List + connState ConnectionState + shutdown bool + header [13]byte } // drainRequestChannel processes messages from the request channel until it's closed. @@ -84,24 +84,24 @@ func drainRequestChannel(requestChan <-chan interface{}, c ConnectionState) { } func (p *recordProcessor) loop(appDataChan chan<- []byte, requestChan <-chan interface{}, controlChan <-chan interface{}, recordChan <-chan *record, handshakeChan chan<- interface{}) { - noop := nop{}; - p.decrypt = noop; - p.mac = noop; - p.waitQueue = list.New(); - - p.appDataChan = appDataChan; - p.requestChan = requestChan; - p.controlChan = controlChan; - p.recordChan = recordChan; - p.handshakeChan = handshakeChan; - p.recordRead = recordChan; + noop := nop{} + p.decrypt = noop + p.mac = noop + p.waitQueue = list.New() + + p.appDataChan = appDataChan + p.requestChan = requestChan + p.controlChan = controlChan + p.recordChan = recordChan + p.handshakeChan = handshakeChan + p.recordRead = recordChan for !p.shutdown { select { case p.appDataSend <- p.appData: - p.appData = nil; - p.appDataSend = nil; - p.recordRead = p.recordChan; + p.appData = nil + p.appDataSend = nil + p.recordRead = p.recordChan case c := <-controlChan: p.processControlMsg(c) case r := <-requestChan: @@ -111,24 +111,24 @@ func (p *recordProcessor) loop(appDataChan chan<- []byte, requestChan <-chan int } } - p.wakeWaiters(); - go drainRequestChannel(p.requestChan, p.connState); + p.wakeWaiters() + go drainRequestChannel(p.requestChan, p.connState) go func() { for _ = range controlChan { } - }(); + }() - close(handshakeChan); + close(handshakeChan) if len(p.appData) > 0 { appDataChan <- p.appData } - close(appDataChan); + close(appDataChan) } func (p *recordProcessor) processRequestMsg(requestMsg interface{}) { if closed(p.requestChan) { - p.shutdown = true; - return; + p.shutdown = true + return } switch r := requestMsg.(type) { @@ -138,51 +138,51 @@ func (p *recordProcessor) processRequestMsg(requestMsg interface{}) { if p.connState.HandshakeComplete { r.reply <- p.connState } - p.waitQueue.PushBack(r.reply); + p.waitQueue.PushBack(r.reply) } } func (p *recordProcessor) processControlMsg(msg interface{}) { - connState, ok := msg.(ConnectionState); + connState, ok := msg.(ConnectionState) if !ok || closed(p.controlChan) { - p.shutdown = true; - return; + p.shutdown = true + return } - p.connState = connState; - p.wakeWaiters(); + p.connState = connState + p.wakeWaiters() } func (p *recordProcessor) wakeWaiters() { for i := p.waitQueue.Front(); i != nil; i = i.Next() { i.Value.(chan<- ConnectionState) <- p.connState } - p.waitQueue.Init(); + p.waitQueue.Init() } func (p *recordProcessor) processRecord(r *record) { if closed(p.recordChan) { - p.shutdown = true; - return; + p.shutdown = true + return } - p.decrypt.XORKeyStream(r.payload); + p.decrypt.XORKeyStream(r.payload) if len(r.payload) < p.mac.Size() { - p.error(alertBadRecordMAC); - return; + p.error(alertBadRecordMAC) + return } - fillMACHeader(&p.header, p.seqNum, len(r.payload)-p.mac.Size(), r); - p.seqNum++; + fillMACHeader(&p.header, p.seqNum, len(r.payload)-p.mac.Size(), r) + p.seqNum++ - p.mac.Reset(); - p.mac.Write(p.header[0:13]); - p.mac.Write(r.payload[0 : len(r.payload)-p.mac.Size()]); - macBytes := p.mac.Sum(); + p.mac.Reset() + p.mac.Write(p.header[0:13]) + p.mac.Write(r.payload[0 : len(r.payload)-p.mac.Size()]) + macBytes := p.mac.Sum() if subtle.ConstantTimeCompare(macBytes, r.payload[len(r.payload)-p.mac.Size():]) != 1 { - p.error(alertBadRecordMAC); - return; + p.error(alertBadRecordMAC) + return } switch r.contentType { @@ -190,31 +190,31 @@ func (p *recordProcessor) processRecord(r *record) { p.processHandshakeRecord(r.payload[0 : len(r.payload)-p.mac.Size()]) case recordTypeChangeCipherSpec: if len(r.payload) != 1 || r.payload[0] != 1 { - p.error(alertUnexpectedMessage); - return; + p.error(alertUnexpectedMessage) + return } - p.handshakeChan <- changeCipherSpec{}; - newSpec, ok := (<-p.controlChan).(*newCipherSpec); + p.handshakeChan <- changeCipherSpec{} + newSpec, ok := (<-p.controlChan).(*newCipherSpec) if !ok { - p.connState.Error = alertUnexpectedMessage; - p.shutdown = true; - return; + p.connState.Error = alertUnexpectedMessage + p.shutdown = true + return } - p.decrypt = newSpec.encrypt; - p.mac = newSpec.mac; - p.seqNum = 0; + p.decrypt = newSpec.encrypt + p.mac = newSpec.mac + p.seqNum = 0 case recordTypeApplicationData: if p.connState.HandshakeComplete == false { - p.error(alertUnexpectedMessage); - return; + p.error(alertUnexpectedMessage) + return } - p.recordRead = nil; - p.appData = r.payload[0 : len(r.payload)-p.mac.Size()]; - p.appDataSend = p.appDataChan; + p.recordRead = nil + p.appData = r.payload[0 : len(r.payload)-p.mac.Size()] + p.appDataSend = p.appDataChan default: - p.error(alertUnexpectedMessage); - return; + p.error(alertUnexpectedMessage) + return } } @@ -223,61 +223,61 @@ func (p *recordProcessor) processHandshakeRecord(data []byte) { p.handshakeBuf = data } else { if len(p.handshakeBuf) > maxHandshakeMsg { - p.error(alertInternalError); - return; + p.error(alertInternalError) + return } - newBuf := make([]byte, len(p.handshakeBuf)+len(data)); - copy(newBuf, p.handshakeBuf); - copy(newBuf[len(p.handshakeBuf):], data); - p.handshakeBuf = newBuf; + newBuf := make([]byte, len(p.handshakeBuf)+len(data)) + copy(newBuf, p.handshakeBuf) + copy(newBuf[len(p.handshakeBuf):], data) + p.handshakeBuf = newBuf } for len(p.handshakeBuf) >= 4 { handshakeLen := int(p.handshakeBuf[1])<<16 | int(p.handshakeBuf[2])<<8 | - int(p.handshakeBuf[3]); + int(p.handshakeBuf[3]) if handshakeLen+4 > len(p.handshakeBuf) { break } - bytes := p.handshakeBuf[0 : handshakeLen+4]; - p.handshakeBuf = p.handshakeBuf[handshakeLen+4:]; + bytes := p.handshakeBuf[0 : handshakeLen+4] + p.handshakeBuf = p.handshakeBuf[handshakeLen+4:] if bytes[0] == typeFinished { // Special case because Finished is synchronous: the // handshake handler has to tell us if it's ok to start // forwarding application data. - m := new(finishedMsg); + m := new(finishedMsg) if !m.unmarshal(bytes) { p.error(alertUnexpectedMessage) } - p.handshakeChan <- m; - var ok bool; - p.connState, ok = (<-p.controlChan).(ConnectionState); + p.handshakeChan <- m + var ok bool + p.connState, ok = (<-p.controlChan).(ConnectionState) if !ok || p.connState.Error != 0 { - p.shutdown = true; - return; + p.shutdown = true + return } } else { - msg, ok := parseHandshakeMsg(bytes); + msg, ok := parseHandshakeMsg(bytes) if !ok { - p.error(alertUnexpectedMessage); - return; + p.error(alertUnexpectedMessage) + return } - p.handshakeChan <- msg; + p.handshakeChan <- msg } } } func (p *recordProcessor) error(err alertType) { - close(p.handshakeChan); - p.connState.Error = err; - p.wakeWaiters(); - p.shutdown = true; + close(p.handshakeChan) + p.connState.Error = err + p.wakeWaiters() + p.shutdown = true } func parseHandshakeMsg(data []byte) (interface{}, bool) { var m interface { - unmarshal([]byte) bool; + unmarshal([]byte) bool } switch data[0] { @@ -295,6 +295,6 @@ func parseHandshakeMsg(data []byte) (interface{}, bool) { return nil, false } - ok := m.unmarshal(data); - return m, ok; + ok := m.unmarshal(data) + return m, ok } diff --git a/record_process_test.go b/record_process_test.go index 1d019e3..65ce3eb 100644 --- a/record_process_test.go +++ b/record_process_test.go @@ -5,132 +5,132 @@ package tls import ( - "encoding/hex"; - "testing"; - "testing/script"; + "encoding/hex" + "testing" + "testing/script" ) func setup() (appDataChan chan []byte, requestChan chan interface{}, controlChan chan interface{}, recordChan chan *record, handshakeChan chan interface{}) { - rp := new(recordProcessor); - appDataChan = make(chan []byte); - requestChan = make(chan interface{}); - controlChan = make(chan interface{}); - recordChan = make(chan *record); - handshakeChan = make(chan interface{}); - - go rp.loop(appDataChan, requestChan, controlChan, recordChan, handshakeChan); - return; + rp := new(recordProcessor) + appDataChan = make(chan []byte) + requestChan = make(chan interface{}) + controlChan = make(chan interface{}) + recordChan = make(chan *record) + handshakeChan = make(chan interface{}) + + go rp.loop(appDataChan, requestChan, controlChan, recordChan, handshakeChan) + return } func fromHex(s string) []byte { - b, _ := hex.DecodeString(s); - return b; + b, _ := hex.DecodeString(s) + return b } func TestNullConnectionState(t *testing.T) { - _, requestChan, controlChan, recordChan, _ := setup(); - defer close(requestChan); - defer close(controlChan); - defer close(recordChan); + _, requestChan, controlChan, recordChan, _ := setup() + defer close(requestChan) + defer close(controlChan) + defer close(recordChan) // Test a simple request for the connection state. - replyChan := make(chan ConnectionState); - sendReq := script.NewEvent("send request", nil, script.Send{requestChan, getConnectionState{replyChan}}); - getReply := script.NewEvent("get reply", []*script.Event{sendReq}, script.Recv{replyChan, ConnectionState{false, "", 0}}); + replyChan := make(chan ConnectionState) + sendReq := script.NewEvent("send request", nil, script.Send{requestChan, getConnectionState{replyChan}}) + getReply := script.NewEvent("get reply", []*script.Event{sendReq}, script.Recv{replyChan, ConnectionState{false, "", 0}}) - err := script.Perform(0, []*script.Event{sendReq, getReply}); + err := script.Perform(0, []*script.Event{sendReq, getReply}) if err != nil { t.Errorf("Got error: %s", err) } } func TestWaitConnectionState(t *testing.T) { - _, requestChan, controlChan, recordChan, _ := setup(); - defer close(requestChan); - defer close(controlChan); - defer close(recordChan); + _, requestChan, controlChan, recordChan, _ := setup() + defer close(requestChan) + defer close(controlChan) + defer close(recordChan) // Test that waitConnectionState doesn't get a reply until the connection state changes. - replyChan := make(chan ConnectionState); - sendReq := script.NewEvent("send request", nil, script.Send{requestChan, waitConnectionState{replyChan}}); - replyChan2 := make(chan ConnectionState); - sendReq2 := script.NewEvent("send request 2", []*script.Event{sendReq}, script.Send{requestChan, getConnectionState{replyChan2}}); - getReply2 := script.NewEvent("get reply 2", []*script.Event{sendReq2}, script.Recv{replyChan2, ConnectionState{false, "", 0}}); - sendState := script.NewEvent("send state", []*script.Event{getReply2}, script.Send{controlChan, ConnectionState{true, "test", 1}}); - getReply := script.NewEvent("get reply", []*script.Event{sendState}, script.Recv{replyChan, ConnectionState{true, "test", 1}}); - - err := script.Perform(0, []*script.Event{sendReq, sendReq2, getReply2, sendState, getReply}); + replyChan := make(chan ConnectionState) + sendReq := script.NewEvent("send request", nil, script.Send{requestChan, waitConnectionState{replyChan}}) + replyChan2 := make(chan ConnectionState) + sendReq2 := script.NewEvent("send request 2", []*script.Event{sendReq}, script.Send{requestChan, getConnectionState{replyChan2}}) + getReply2 := script.NewEvent("get reply 2", []*script.Event{sendReq2}, script.Recv{replyChan2, ConnectionState{false, "", 0}}) + sendState := script.NewEvent("send state", []*script.Event{getReply2}, script.Send{controlChan, ConnectionState{true, "test", 1}}) + getReply := script.NewEvent("get reply", []*script.Event{sendState}, script.Recv{replyChan, ConnectionState{true, "test", 1}}) + + err := script.Perform(0, []*script.Event{sendReq, sendReq2, getReply2, sendState, getReply}) if err != nil { t.Errorf("Got error: %s", err) } } func TestHandshakeAssembly(t *testing.T) { - _, requestChan, controlChan, recordChan, handshakeChan := setup(); - defer close(requestChan); - defer close(controlChan); - defer close(recordChan); + _, requestChan, controlChan, recordChan, handshakeChan := setup() + defer close(requestChan) + defer close(controlChan) + defer close(recordChan) // Test the reassembly of a fragmented handshake message. - send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("10000003")}}); - send2 := script.NewEvent("send 2", []*script.Event{send1}, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("0001")}}); - send3 := script.NewEvent("send 3", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("42")}}); - recvMsg := script.NewEvent("recv", []*script.Event{send3}, script.Recv{handshakeChan, &clientKeyExchangeMsg{fromHex("10000003000142"), fromHex("42")}}); + send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("10000003")}}) + send2 := script.NewEvent("send 2", []*script.Event{send1}, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("0001")}}) + send3 := script.NewEvent("send 3", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("42")}}) + recvMsg := script.NewEvent("recv", []*script.Event{send3}, script.Recv{handshakeChan, &clientKeyExchangeMsg{fromHex("10000003000142"), fromHex("42")}}) - err := script.Perform(0, []*script.Event{send1, send2, send3, recvMsg}); + err := script.Perform(0, []*script.Event{send1, send2, send3, recvMsg}) if err != nil { t.Errorf("Got error: %s", err) } } func TestEarlyApplicationData(t *testing.T) { - _, requestChan, controlChan, recordChan, handshakeChan := setup(); - defer close(requestChan); - defer close(controlChan); - defer close(recordChan); + _, requestChan, controlChan, recordChan, handshakeChan := setup() + defer close(requestChan) + defer close(controlChan) + defer close(recordChan) // Test that applicaton data received before the handshake has completed results in an error. - send := script.NewEvent("send", nil, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("")}}); - recv := script.NewEvent("recv", []*script.Event{send}, script.Closed{handshakeChan}); + send := script.NewEvent("send", nil, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("")}}) + recv := script.NewEvent("recv", []*script.Event{send}, script.Closed{handshakeChan}) - err := script.Perform(0, []*script.Event{send, recv}); + err := script.Perform(0, []*script.Event{send, recv}) if err != nil { t.Errorf("Got error: %s", err) } } func TestApplicationData(t *testing.T) { - appDataChan, requestChan, controlChan, recordChan, handshakeChan := setup(); - defer close(requestChan); - defer close(controlChan); - defer close(recordChan); + appDataChan, requestChan, controlChan, recordChan, handshakeChan := setup() + defer close(requestChan) + defer close(controlChan) + defer close(recordChan) // Test that the application data is forwarded after a successful Finished message. - send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("1400000c000000000000000000000000")}}); - recv1 := script.NewEvent("recv finished", []*script.Event{send1}, script.Recv{handshakeChan, &finishedMsg{fromHex("1400000c000000000000000000000000"), fromHex("000000000000000000000000")}}); - send2 := script.NewEvent("send connState", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{true, "", 0}}); - send3 := script.NewEvent("send 2", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("0102")}}); - recv2 := script.NewEvent("recv data", []*script.Event{send3}, script.Recv{appDataChan, []byte{0x01, 0x02}}); + send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("1400000c000000000000000000000000")}}) + recv1 := script.NewEvent("recv finished", []*script.Event{send1}, script.Recv{handshakeChan, &finishedMsg{fromHex("1400000c000000000000000000000000"), fromHex("000000000000000000000000")}}) + send2 := script.NewEvent("send connState", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{true, "", 0}}) + send3 := script.NewEvent("send 2", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("0102")}}) + recv2 := script.NewEvent("recv data", []*script.Event{send3}, script.Recv{appDataChan, []byte{0x01, 0x02}}) - err := script.Perform(0, []*script.Event{send1, recv1, send2, send3, recv2}); + err := script.Perform(0, []*script.Event{send1, recv1, send2, send3, recv2}) if err != nil { t.Errorf("Got error: %s", err) } } func TestInvalidChangeCipherSpec(t *testing.T) { - appDataChan, requestChan, controlChan, recordChan, handshakeChan := setup(); - defer close(requestChan); - defer close(controlChan); - defer close(recordChan); - - send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeChangeCipherSpec, 0, 0, []byte{1}}}); - recv1 := script.NewEvent("recv 1", []*script.Event{send1}, script.Recv{handshakeChan, changeCipherSpec{}}); - send2 := script.NewEvent("send 2", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{false, "", 42}}); - close := script.NewEvent("close 1", []*script.Event{send2}, script.Closed{appDataChan}); - close2 := script.NewEvent("close 2", []*script.Event{send2}, script.Closed{handshakeChan}); - - err := script.Perform(0, []*script.Event{send1, recv1, send2, close, close2}); + appDataChan, requestChan, controlChan, recordChan, handshakeChan := setup() + defer close(requestChan) + defer close(controlChan) + defer close(recordChan) + + send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeChangeCipherSpec, 0, 0, []byte{1}}}) + recv1 := script.NewEvent("recv 1", []*script.Event{send1}, script.Recv{handshakeChan, changeCipherSpec{}}) + send2 := script.NewEvent("send 2", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{false, "", 42}}) + close := script.NewEvent("close 1", []*script.Event{send2}, script.Closed{appDataChan}) + close2 := script.NewEvent("close 2", []*script.Event{send2}, script.Closed{handshakeChan}) + + err := script.Perform(0, []*script.Event{send1, recv1, send2, close, close2}) if err != nil { t.Errorf("Got error: %s", err) } diff --git a/record_read.go b/record_read.go index 518a136..0ddd884 100644 --- a/record_read.go +++ b/record_read.go @@ -9,34 +9,34 @@ package tls // it's outbound channel. On error, it closes its outbound channel. import ( - "io"; - "bufio"; + "io" + "bufio" ) // recordReader loops, reading TLS records from source and writing them to the // given channel. The channel is closed on EOF or on error. func recordReader(c chan<- *record, source io.Reader) { - defer close(c); - buf := bufio.NewReader(source); + defer close(c) + buf := bufio.NewReader(source) for { - var header [5]byte; - n, _ := buf.Read(header[0:]); + var header [5]byte + n, _ := buf.Read(header[0:]) if n != 5 { return } - recordLength := int(header[3])<<8 | int(header[4]); + recordLength := int(header[3])<<8 | int(header[4]) if recordLength > maxTLSCiphertext { return } - payload := make([]byte, recordLength); - n, _ = buf.Read(payload); + payload := make([]byte, recordLength) + n, _ = buf.Read(payload) if n != recordLength { return } - c <- &record{recordType(header[0]), header[1], header[2], payload}; + c <- &record{recordType(header[0]), header[1], header[2], payload} } } diff --git a/record_read_test.go b/record_read_test.go index ce1a8e6..f897599 100644 --- a/record_read_test.go +++ b/record_read_test.go @@ -5,9 +5,9 @@ package tls import ( - "bytes"; - "testing"; - "testing/iotest"; + "bytes" + "testing" + "testing/iotest" ) func matchRecord(r1, r2 *record) bool { @@ -20,12 +20,12 @@ func matchRecord(r1, r2 *record) bool { return r1.contentType == r2.contentType && r1.major == r2.major && r1.minor == r2.minor && - bytes.Compare(r1.payload, r2.payload) == 0; + bytes.Compare(r1.payload, r2.payload) == 0 } type recordReaderTest struct { - in []byte; - out []*record; + in []byte + out []*record } var recordReaderTests = []recordReaderTest{ @@ -42,31 +42,31 @@ var recordReaderTests = []recordReaderTest{ func TestRecordReader(t *testing.T) { for i, test := range recordReaderTests { - buf := bytes.NewBuffer(test.in); - c := make(chan *record); - go recordReader(c, buf); - matchRecordReaderOutput(t, i, test, c); + buf := bytes.NewBuffer(test.in) + c := make(chan *record) + go recordReader(c, buf) + matchRecordReaderOutput(t, i, test, c) - buf = bytes.NewBuffer(test.in); - buf2 := iotest.OneByteReader(buf); - c = make(chan *record); - go recordReader(c, buf2); - matchRecordReaderOutput(t, i*2, test, c); + buf = bytes.NewBuffer(test.in) + buf2 := iotest.OneByteReader(buf) + c = make(chan *record) + go recordReader(c, buf2) + matchRecordReaderOutput(t, i*2, test, c) } } func matchRecordReaderOutput(t *testing.T, i int, test recordReaderTest, c <-chan *record) { for j, r1 := range test.out { - r2 := <-c; + r2 := <-c if r2 == nil { - t.Errorf("#%d truncated after %d values", i, j); - break; + t.Errorf("#%d truncated after %d values", i, j) + break } if !matchRecord(r1, r2) { t.Errorf("#%d (%d) got:%#v want:%#v", i, j, r2, r1) } } - <-c; + <-c if !closed(c) { t.Errorf("#%d: channel didn't close", i) } diff --git a/record_write.go b/record_write.go index f55a214..5f3fb5b 100644 --- a/record_write.go +++ b/record_write.go @@ -5,9 +5,9 @@ package tls import ( - "fmt"; - "hash"; - "io"; + "fmt" + "hash" + "io" ) // writerEnableApplicationData is a message which instructs recordWriter to @@ -17,14 +17,14 @@ type writerEnableApplicationData struct{} // writerChangeCipherSpec updates the encryption and MAC functions and resets // the sequence count. type writerChangeCipherSpec struct { - encryptor encryptor; - mac hash.Hash; + encryptor encryptor + mac hash.Hash } // writerSetVersion sets the version number bytes that we included in the // record header for future records. type writerSetVersion struct { - major, minor uint8; + major, minor uint8 } // A recordWriter accepts messages from the handshake processor and @@ -32,37 +32,37 @@ type writerSetVersion struct { // writing. It doesn't read from the application data channel until the // handshake processor has signaled that the handshake is complete. type recordWriter struct { - writer io.Writer; - encryptor encryptor; - mac hash.Hash; - seqNum uint64; - major, minor uint8; - shutdown bool; - appChan <-chan []byte; - controlChan <-chan interface{}; - header [13]byte; + writer io.Writer + encryptor encryptor + mac hash.Hash + seqNum uint64 + major, minor uint8 + shutdown bool + appChan <-chan []byte + controlChan <-chan interface{} + header [13]byte } func (w *recordWriter) loop(writer io.Writer, appChan <-chan []byte, controlChan <-chan interface{}) { - w.writer = writer; - w.encryptor = nop{}; - w.mac = nop{}; - w.appChan = appChan; - w.controlChan = controlChan; + w.writer = writer + w.encryptor = nop{} + w.mac = nop{} + w.appChan = appChan + w.controlChan = controlChan for !w.shutdown { - msg := <-controlChan; + msg := <-controlChan if _, ok := msg.(writerEnableApplicationData); ok { break } - w.processControlMessage(msg); + w.processControlMessage(msg) } for !w.shutdown { // Always process control messages first. if controlMsg, ok := <-controlChan; ok { - w.processControlMessage(controlMsg); - continue; + w.processControlMessage(controlMsg) + continue } select { @@ -89,58 +89,58 @@ func (w *recordWriter) loop(writer io.Writer, appChan <-chan []byte, controlChan // fillMACHeader generates a MAC header. See RFC 4346, section 6.2.3.1. func fillMACHeader(header *[13]byte, seqNum uint64, length int, r *record) { - header[0] = uint8(seqNum >> 56); - header[1] = uint8(seqNum >> 48); - header[2] = uint8(seqNum >> 40); - header[3] = uint8(seqNum >> 32); - header[4] = uint8(seqNum >> 24); - header[5] = uint8(seqNum >> 16); - header[6] = uint8(seqNum >> 8); - header[7] = uint8(seqNum); - header[8] = uint8(r.contentType); - header[9] = r.major; - header[10] = r.minor; - header[11] = uint8(length >> 8); - header[12] = uint8(length); + header[0] = uint8(seqNum >> 56) + header[1] = uint8(seqNum >> 48) + header[2] = uint8(seqNum >> 40) + header[3] = uint8(seqNum >> 32) + header[4] = uint8(seqNum >> 24) + header[5] = uint8(seqNum >> 16) + header[6] = uint8(seqNum >> 8) + header[7] = uint8(seqNum) + header[8] = uint8(r.contentType) + header[9] = r.major + header[10] = r.minor + header[11] = uint8(length >> 8) + header[12] = uint8(length) } func (w *recordWriter) writeRecord(r *record) { - w.mac.Reset(); + w.mac.Reset() - fillMACHeader(&w.header, w.seqNum, len(r.payload), r); + fillMACHeader(&w.header, w.seqNum, len(r.payload), r) - w.mac.Write(w.header[0:13]); - w.mac.Write(r.payload); - macBytes := w.mac.Sum(); + w.mac.Write(w.header[0:13]) + w.mac.Write(r.payload) + macBytes := w.mac.Sum() - w.encryptor.XORKeyStream(r.payload); - w.encryptor.XORKeyStream(macBytes); + w.encryptor.XORKeyStream(r.payload) + w.encryptor.XORKeyStream(macBytes) - length := len(r.payload) + len(macBytes); - w.header[11] = uint8(length >> 8); - w.header[12] = uint8(length); - w.writer.Write(w.header[8:13]); - w.writer.Write(r.payload); - w.writer.Write(macBytes); + length := len(r.payload) + len(macBytes) + w.header[11] = uint8(length >> 8) + w.header[12] = uint8(length) + w.writer.Write(w.header[8:13]) + w.writer.Write(r.payload) + w.writer.Write(macBytes) - w.seqNum++; + w.seqNum++ } func (w *recordWriter) processControlMessage(controlMsg interface{}) { if controlMsg == nil { - w.shutdown = true; - return; + w.shutdown = true + return } switch msg := controlMsg.(type) { case writerChangeCipherSpec: - w.writeRecord(&record{recordTypeChangeCipherSpec, w.major, w.minor, []byte{0x01}}); - w.encryptor = msg.encryptor; - w.mac = msg.mac; - w.seqNum = 0; + w.writeRecord(&record{recordTypeChangeCipherSpec, w.major, w.minor, []byte{0x01}}) + w.encryptor = msg.encryptor + w.mac = msg.mac + w.seqNum = 0 case writerSetVersion: - w.major = msg.major; - w.minor = msg.minor; + w.major = msg.major + w.minor = msg.minor case alert: w.writeRecord(&record{recordTypeAlert, w.major, w.minor, []byte{byte(msg.level), byte(msg.error)}}) case handshakeMessage: @@ -153,18 +153,18 @@ func (w *recordWriter) processControlMessage(controlMsg interface{}) { func (w *recordWriter) processAppMessage(appMsg []byte) { if closed(w.appChan) { - w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, []byte{byte(alertCloseNotify)}}); - w.shutdown = true; - return; + w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, []byte{byte(alertCloseNotify)}}) + w.shutdown = true + return } - var done int; + var done int for done < len(appMsg) { - todo := len(appMsg); + todo := len(appMsg) if todo > maxTLSPlaintext { todo = maxTLSPlaintext } - w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, appMsg[done : done+todo]}); - done += todo; + w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, appMsg[done : done+todo]}) + done += todo } } diff --git a/tls.go b/tls.go index c5a0f69..29d918e 100644 --- a/tls.go +++ b/tls.go @@ -6,26 +6,26 @@ package tls import ( - "io"; - "os"; - "net"; - "time"; + "io" + "os" + "net" + "time" ) // A Conn represents a secure connection. type Conn struct { - net.Conn; - writeChan chan<- []byte; - readChan <-chan []byte; - requestChan chan<- interface{}; - readBuf []byte; - eof bool; - readTimeout, writeTimeout int64; + net.Conn + writeChan chan<- []byte + readChan <-chan []byte + requestChan chan<- interface{} + readBuf []byte + eof bool + readTimeout, writeTimeout int64 } func timeout(c chan<- bool, nsecs int64) { - time.Sleep(nsecs); - c <- true; + time.Sleep(nsecs) + c <- true } func (tls *Conn) Read(p []byte) (int, os.Error) { @@ -34,10 +34,10 @@ func (tls *Conn) Read(p []byte) (int, os.Error) { return 0, os.EOF } - var timeoutChan chan bool; + var timeoutChan chan bool if tls.readTimeout > 0 { - timeoutChan = make(chan bool); - go timeout(timeoutChan, tls.readTimeout); + timeoutChan = make(chan bool) + go timeout(timeoutChan, tls.readTimeout) } select { @@ -53,14 +53,14 @@ func (tls *Conn) Read(p []byte) (int, os.Error) { return 0, io.ErrUnexpectedEOF } if len(tls.readBuf) == 0 { - tls.eof = true; - return 0, os.EOF; + tls.eof = true + return 0, os.EOF } } - n := copy(p, tls.readBuf); - tls.readBuf = tls.readBuf[n:]; - return n, nil; + n := copy(p, tls.readBuf) + tls.readBuf = tls.readBuf[n:] + return n, nil } func (tls *Conn) Write(p []byte) (int, os.Error) { @@ -68,10 +68,10 @@ func (tls *Conn) Write(p []byte) (int, os.Error) { return 0, os.EOF } - var timeoutChan chan bool; + var timeoutChan chan bool if tls.writeTimeout > 0 { - timeoutChan = make(chan bool); - go timeout(timeoutChan, tls.writeTimeout); + timeoutChan = make(chan bool) + go timeout(timeoutChan, tls.writeTimeout) } select { @@ -80,73 +80,73 @@ func (tls *Conn) Write(p []byte) (int, os.Error) { return 0, os.EAGAIN } - return len(p), nil; + return len(p), nil } func (tls *Conn) Close() os.Error { - close(tls.writeChan); - close(tls.requestChan); - tls.eof = true; - return nil; + close(tls.writeChan) + close(tls.requestChan) + tls.eof = true + return nil } func (tls *Conn) SetTimeout(nsec int64) os.Error { - tls.readTimeout = nsec; - tls.writeTimeout = nsec; - return nil; + tls.readTimeout = nsec + tls.writeTimeout = nsec + return nil } func (tls *Conn) SetReadTimeout(nsec int64) os.Error { - tls.readTimeout = nsec; - return nil; + tls.readTimeout = nsec + return nil } func (tls *Conn) SetWriteTimeout(nsec int64) os.Error { - tls.writeTimeout = nsec; - return nil; + tls.writeTimeout = nsec + return nil } func (tls *Conn) GetConnectionState() ConnectionState { - replyChan := make(chan ConnectionState); - tls.requestChan <- getConnectionState{replyChan}; - return <-replyChan; + replyChan := make(chan ConnectionState) + tls.requestChan <- getConnectionState{replyChan} + return <-replyChan } func (tls *Conn) WaitConnectionState() ConnectionState { - replyChan := make(chan ConnectionState); - tls.requestChan <- waitConnectionState{replyChan}; - return <-replyChan; + replyChan := make(chan ConnectionState) + tls.requestChan <- waitConnectionState{replyChan} + return <-replyChan } type handshaker interface { - loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config); + loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) } // Server establishes a secure connection over the given connection and acts // as a TLS server. func startTLSGoroutines(conn net.Conn, h handshaker, config *Config) *Conn { - tls := new(Conn); - tls.Conn = conn; + tls := new(Conn) + tls.Conn = conn - writeChan := make(chan []byte); - readChan := make(chan []byte); - requestChan := make(chan interface{}); + writeChan := make(chan []byte) + readChan := make(chan []byte) + requestChan := make(chan interface{}) - tls.writeChan = writeChan; - tls.readChan = readChan; - tls.requestChan = requestChan; + tls.writeChan = writeChan + tls.readChan = readChan + tls.requestChan = requestChan - handshakeWriterChan := make(chan interface{}); - processorHandshakeChan := make(chan interface{}); - handshakeProcessorChan := make(chan interface{}); - readerProcessorChan := make(chan *record); + handshakeWriterChan := make(chan interface{}) + processorHandshakeChan := make(chan interface{}) + handshakeProcessorChan := make(chan interface{}) + readerProcessorChan := make(chan *record) - go new(recordWriter).loop(conn, writeChan, handshakeWriterChan); - go recordReader(readerProcessorChan, conn); - go new(recordProcessor).loop(readChan, requestChan, handshakeProcessorChan, readerProcessorChan, processorHandshakeChan); - go h.loop(handshakeWriterChan, handshakeProcessorChan, processorHandshakeChan, config); + go new(recordWriter).loop(conn, writeChan, handshakeWriterChan) + go recordReader(readerProcessorChan, conn) + go new(recordProcessor).loop(readChan, requestChan, handshakeProcessorChan, readerProcessorChan, processorHandshakeChan) + go h.loop(handshakeWriterChan, handshakeProcessorChan, processorHandshakeChan, config) - return tls; + return tls } func Server(conn net.Conn, config *Config) *Conn { @@ -158,28 +158,28 @@ func Client(conn net.Conn, config *Config) *Conn { } type Listener struct { - listener net.Listener; - config *Config; + listener net.Listener + config *Config } func (l Listener) Accept() (c net.Conn, err os.Error) { - c, err = l.listener.Accept(); + c, err = l.listener.Accept() if err != nil { return } - c = Server(c, l.config); - return; + c = Server(c, l.config) + return } -func (l Listener) Close() os.Error { return l.listener.Close() } +func (l Listener) Close() os.Error { return l.listener.Close() } -func (l Listener) Addr() net.Addr { return l.listener.Addr() } +func (l Listener) Addr() net.Addr { return l.listener.Addr() } // NewListener creates a Listener which accepts connections from an inner // Listener and wraps each connection with Server. func NewListener(listener net.Listener, config *Config) (l Listener) { - l.listener = listener; - l.config = config; - return; + l.listener = listener + l.config = config + return }