diff --git a/handshake_messages.go b/handshake_messages.go new file mode 100644 index 0000000..f2d88d8 --- /dev/null +++ b/handshake_messages.go @@ -0,0 +1,236 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes"; +) + +type clientHelloMsg struct { + raw []byte; + major, minor uint8; + random []byte; + sessionId []byte; + cipherSuites []uint16; + compressionMethods []uint8; +} + +func (m *clientHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw; + } + + length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods); + x := make([]byte, 4+length); + x[0] = typeClientHello; + x[1] = uint8(length>>16); + x[2] = uint8(length>>8); + x[3] = uint8(length); + x[4] = m.major; + x[5] = m.minor; + bytes.Copy(x[6:38], m.random); + x[38] = uint8(len(m.sessionId)); + bytes.Copy(x[39 : 39+len(m.sessionId)], m.sessionId); + y := x[39+len(m.sessionId) : len(x)]; + y[0] = uint8(len(m.cipherSuites)>>7); + y[1] = uint8(len(m.cipherSuites)<<1); + for i, suite := range m.cipherSuites { + y[2 + i*2] = uint8(suite>>8); + y[3 + i*2] = uint8(suite); + } + z := y[2 + len(m.cipherSuites)*2 : len(y)]; + z[0] = uint8(len(m.compressionMethods)); + bytes.Copy(z[1:len(z)], m.compressionMethods); + m.raw = x; + + return x; +} + +func (m *clientHelloMsg) unmarshal(data []byte) bool { + if len(data) < 39 { + return false; + } + m.raw = data; + m.major = data[4]; + m.minor = data[5]; + m.random = data[6:38]; + sessionIdLen := int(data[38]); + if sessionIdLen > 32 || len(data) < 39 + sessionIdLen { + return false; + } + m.sessionId = data[39 : 39 + sessionIdLen]; + data = data[39 + sessionIdLen : len(data)]; + if len(data) < 2 { + return false; + } + // cipherSuiteLen is the number of bytes of cipher suite numbers. Since + // they are uint16s, the number must be even. + cipherSuiteLen := int(data[0])<<8 | int(data[1]); + if cipherSuiteLen % 2 == 1 || len(data) < 2 + cipherSuiteLen { + return false; + } + numCipherSuites := cipherSuiteLen / 2; + m.cipherSuites = make([]uint16, numCipherSuites); + for i := 0; i < numCipherSuites; i++ { + m.cipherSuites[i] = uint16(data[2 + 2*i])<<8 | uint16(data[3 + 2*i]); + } + data = data[2 + cipherSuiteLen : len(data)]; + if len(data) < 2 { + return false; + } + compressionMethodsLen := int(data[0]); + if len(data) < 1 + compressionMethodsLen { + return false; + } + m.compressionMethods = data[1 : 1 + compressionMethodsLen]; + + // A ClientHello may be following by trailing data: RFC 4346 section 7.4.1.2 + return true; +} + +type serverHelloMsg struct { + raw []byte; + major, minor uint8; + random []byte; + sessionId []byte; + cipherSuite uint16; + compressionMethod uint8; +} + +func (m *serverHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw; + } + + length := 38+len(m.sessionId); + x := make([]byte, 4+length); + x[0] = typeServerHello; + x[1] = uint8(length>>16); + x[2] = uint8(length>>8); + x[3] = uint8(length); + x[4] = m.major; + x[5] = m.minor; + bytes.Copy(x[6:38], m.random); + x[38] = uint8(len(m.sessionId)); + bytes.Copy(x[39 : 39+len(m.sessionId)], m.sessionId); + z := x[39+len(m.sessionId) : len(x)]; + z[0] = uint8(m.cipherSuite >> 8); + z[1] = uint8(m.cipherSuite); + z[2] = uint8(m.compressionMethod); + m.raw = x; + + return x; +} + +type certificateMsg struct { + raw []byte; + certificates [][]byte; +} + +func (m *certificateMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw; + } + + var i int; + for _, slice := range m.certificates { + i += len(slice); + } + + length := 3 + 3*len(m.certificates) + i; + x = make([]byte, 4+length); + x[0] = typeCertificate; + x[1] = uint8(length>>16); + x[2] = uint8(length>>8); + x[3] = uint8(length); + + certificateOctets := length-3; + x[4] = uint8(certificateOctets >> 16); + x[5] = uint8(certificateOctets >> 8); + x[6] = uint8(certificateOctets); + + y := x[7:len(x)]; + for _, slice := range m.certificates { + y[0] = uint8(len(slice)>>16); + y[1] = uint8(len(slice)>>8); + y[2] = uint8(len(slice)); + bytes.Copy(y[3:len(y)], slice); + y = y[3+len(slice) : len(y)]; + } + + m.raw = x; + return; +} + +type serverHelloDoneMsg struct{} + +func (m *serverHelloDoneMsg) marshal() []byte { + x := make([]byte, 4); + x[0] = typeServerHelloDone; + return x; +} + +type clientKeyExchangeMsg struct { + raw []byte; + ciphertext []byte; +} + +func (m *clientKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw; + } + length := len(m.ciphertext)+2; + x := make([]byte, length+4); + x[0] = typeClientKeyExchange; + x[1] = uint8(length>>16); + x[2] = uint8(length>>8); + x[3] = uint8(length); + x[4] = uint8(len(m.ciphertext)>>8); + x[5] = uint8(len(m.ciphertext)); + bytes.Copy(x[6:len(x)], m.ciphertext); + + m.raw = x; + return x; +} + +func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data; + if len(data) < 7 { + return false; + } + cipherTextLen := int(data[4])<<8 | int(data[5]); + if len(data) != 6 + cipherTextLen { + return false; + } + m.ciphertext = data[6:len(data)]; + return true; +} + +type finishedMsg struct { + raw []byte; + verifyData []byte; +} + +func (m *finishedMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw; + } + + x = make([]byte, 16); + x[0] = typeFinished; + x[3] = 12; + bytes.Copy(x[4:len(x)], m.verifyData); + m.raw = x; + return; +} + +func (m *finishedMsg) unmarshal(data []byte) bool { + m.raw = data; + if len(data) != 4+12 { + return false; + } + m.verifyData = data[4:len(data)]; + return true; +} diff --git a/handshake_messages_test.go b/handshake_messages_test.go new file mode 100644 index 0000000..ff4d741 --- /dev/null +++ b/handshake_messages_test.go @@ -0,0 +1,95 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "rand"; + "reflect"; + "testing"; + "testing/quick"; +) + +var tests = []interface{}{ + &clientHelloMsg{}, + &clientKeyExchangeMsg{}, + &finishedMsg{}, +} + +type testMessage interface { + marshal() []byte; + unmarshal([]byte) bool; +} + +func TestMarshalUnmarshal(t *testing.T) { + rand := rand.New(rand.NewSource(0)); + for i, iface := range tests { + ty := reflect.NewValue(iface).Type(); + + for j := 0; j < 100; j++ { + v, ok := quick.Value(ty, rand); + if !ok { + t.Errorf("#%d: failed to create value", i); + break; + } + + m1 := v.Interface().(testMessage); + marshaled := m1.marshal(); + m2 := iface.(testMessage); + if !m2.unmarshal(marshaled) { + t.Errorf("#%d failed to unmarshal %#v", i, m1); + break; + } + m2.marshal(); // to fill any marshal cache in the message + + if !reflect.DeepEqual(m1, m2) { + t.Errorf("#%d got:%#v want:%#v", i, m1, m2); + break; + } + + // Now check that all prefixes are invalid. + for j := 0; j < len(marshaled); j++ { + if m2.unmarshal(marshaled[0:j]) { + t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1); + break; + } + } + } + } +} + +func randomBytes(n int, rand *rand.Rand) []byte { + r := make([]byte, n); + for i := 0; i < n; i++ { + r[i] = byte(rand.Int31()); + } + return r; +} + +func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &clientHelloMsg{}; + m.major = uint8(rand.Intn(256)); + m.minor = uint8(rand.Intn(256)); + m.random = randomBytes(32, rand); + m.sessionId = randomBytes(rand.Intn(32), rand); + m.cipherSuites = make([]uint16, rand.Intn(63) + 1); + for i := 0; i < len(m.cipherSuites); i++ { + m.cipherSuites[i] = uint16(rand.Int31()); + } + m.compressionMethods = randomBytes(rand.Intn(63) + 1, rand); + + return reflect.NewValue(m); +} + +func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &clientKeyExchangeMsg{}; + m.ciphertext = randomBytes(rand.Intn(1000), rand); + return reflect.NewValue(m); +} + +func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &finishedMsg{}; + m.verifyData = randomBytes(12, rand); + return reflect.NewValue(m); +} diff --git a/record_read.go b/record_read.go new file mode 100644 index 0000000..27b0455 --- /dev/null +++ b/record_read.go @@ -0,0 +1,42 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +// The record reader handles reading from the connection and reassembling TLS +// record structures. It loops forever doing this and writes the TLS records to +// it's outbound channel. On error, it closes its outbound channel. + +import ( + "io"; + "bufio"; +) + +// recordReader loops, reading TLS records from source and writing them to the +// given channel. The channel is closed on EOF or on error. +func recordReader(c chan<- *record, source io.Reader) { + defer close(c); + buf := bufio.NewReader(source); + + for { + var header [5]byte; + n, _ := buf.Read(header[0:len(header)]); + if n != 5 { + return; + } + + recordLength := int(header[3])<<8 | int(header[4]); + if recordLength > maxTLSCiphertext { + return; + } + + payload := make([]byte, recordLength); + n, _ = buf.Read(payload); + if n != recordLength { + return; + } + + c <- &record{recordType(header[0]), header[1], header[2], payload}; + } +} diff --git a/record_read_test.go b/record_read_test.go new file mode 100644 index 0000000..7bd943c --- /dev/null +++ b/record_read_test.go @@ -0,0 +1,73 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes"; + "testing"; + "testing/iotest"; +) + +func matchRecord(r1, r2 *record) bool { + if (r1 == nil) != (r2 == nil) { + return false; + } + if r1 == nil { + return true; + } + return r1.contentType == r2.contentType && + r1.major == r2.major && + r1.minor == r2.minor && + bytes.Compare(r1.payload, r2.payload) == 0; +} + +type recordReaderTest struct { + in []byte; + out []*record; +} + +var recordReaderTests = []recordReaderTest{ + recordReaderTest{nil, nil}, + recordReaderTest{fromHex("01"), nil}, + recordReaderTest{fromHex("0102"), nil}, + recordReaderTest{fromHex("010203"), nil}, + recordReaderTest{fromHex("01020300"), nil}, + recordReaderTest{fromHex("0102030000"), []*record{&record{1, 2, 3, nil}}}, + recordReaderTest{fromHex("01020300000102030000"), []*record{&record{1, 2, 3, nil}, &record{1, 2, 3, nil}}}, + recordReaderTest{fromHex("0102030001fe0102030002feff"), []*record{&record{1, 2, 3, []byte{0xfe}}, &record{1, 2, 3, []byte{0xfe, 0xff}}}}, + recordReaderTest{fromHex("010203000001020300"), []*record{&record{1, 2, 3, nil}}}, +} + +func TestRecordReader(t *testing.T) { + for i, test := range recordReaderTests { + buf := bytes.NewBuffer(test.in); + c := make(chan *record); + go recordReader(c, buf); + matchRecordReaderOutput(t, i, test, c); + + buf = bytes.NewBuffer(test.in); + buf2 := iotest.OneByteReader(buf); + c = make(chan *record); + go recordReader(c, buf2); + matchRecordReaderOutput(t, i*2, test, c); + } +} + +func matchRecordReaderOutput(t *testing.T, i int, test recordReaderTest, c <-chan *record) { + for j, r1 := range test.out { + r2 := <-c; + if r2 == nil { + t.Errorf("#%d truncated after %d values", i, j); + break; + } + if !matchRecord(r1, r2) { + t.Errorf("#%d (%d) got:%#v want:%#v", i, j, r2, r1); + } + } + <-c; + if !closed(c) { + t.Errorf("#%d: channel didn't close", i); + } +} diff --git a/record_write.go b/record_write.go new file mode 100644 index 0000000..241dbec --- /dev/null +++ b/record_write.go @@ -0,0 +1,164 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "fmt"; + "hash"; + "io"; +) + +// writerEnableApplicationData is a message which instructs recordWriter to +// start reading and transmitting data from the application data channel. +type writerEnableApplicationData struct{} + +// writerChangeCipherSpec updates the encryption and MAC functions and resets +// the sequence count. +type writerChangeCipherSpec struct { + encryptor encryptor; + mac hash.Hash; +} + +// writerSetVersion sets the version number bytes that we included in the +// record header for future records. +type writerSetVersion struct { + major, minor uint8; +} + +// A recordWriter accepts messages from the handshake processor and +// application data. It writes them to the outgoing connection and blocks on +// writing. It doesn't read from the application data channel until the +// handshake processor has signaled that the handshake is complete. +type recordWriter struct { + writer io.Writer; + encryptor encryptor; + mac hash.Hash; + seqNum uint64; + major, minor uint8; + shutdown bool; + appChan <-chan []byte; + controlChan <-chan interface{}; + header [13]byte; +} + +func (w *recordWriter) loop(writer io.Writer, appChan <-chan []byte, controlChan <-chan interface{}) { + w.writer = writer; + w.encryptor = nop{}; + w.mac = nop{}; + w.appChan = appChan; + w.controlChan = controlChan; + + for !w.shutdown { + msg := <-controlChan; + if _, ok := msg.(writerEnableApplicationData); ok { + break; + } + w.processControlMessage(msg); + } + + for !w.shutdown { + // Always process control messages first. + if controlMsg, ok := <-controlChan; ok { + w.processControlMessage(controlMsg); + continue; + } + + select { + case controlMsg := <-controlChan: + w.processControlMessage(controlMsg); + case appMsg := <-appChan: + w.processAppMessage(appMsg); + } + } + + if !closed(appChan) { + go func() { for _ = range appChan {} }(); + } + if !closed(controlChan) { + go func() { for _ = range controlChan {} }(); + } +} + +// fillMACHeader generates a MAC header. See RFC 4346, section 6.2.3.1. +func fillMACHeader(header *[13]byte, seqNum uint64, length int, r *record) { + header[0] = uint8(seqNum>>56); + header[1] = uint8(seqNum>>48); + header[2] = uint8(seqNum>>40); + header[3] = uint8(seqNum>>32); + header[4] = uint8(seqNum>>24); + header[5] = uint8(seqNum>>16); + header[6] = uint8(seqNum>>8); + header[7] = uint8(seqNum); + header[8] = uint8(r.contentType); + header[9] = r.major; + header[10] = r.minor; + header[11] = uint8(length>>8); + header[12] = uint8(length); +} + +func (w *recordWriter) writeRecord(r *record) { + w.mac.Reset(); + + fillMACHeader(&w.header, w.seqNum, len(r.payload), r); + + w.mac.Write(w.header[0:13]); + w.mac.Write(r.payload); + macBytes := w.mac.Sum(); + + w.encryptor.XORKeyStream(r.payload); + w.encryptor.XORKeyStream(macBytes); + + length := len(r.payload)+len(macBytes); + w.header[11] = uint8(length>>8); + w.header[12] = uint8(length); + w.writer.Write(w.header[8:13]); + w.writer.Write(r.payload); + w.writer.Write(macBytes); + + w.seqNum++; +} + +func (w *recordWriter) processControlMessage(controlMsg interface{}) { + if controlMsg == nil { + w.shutdown = true; + return; + } + + switch msg := controlMsg.(type) { + case writerChangeCipherSpec: + w.writeRecord(&record{recordTypeChangeCipherSpec, w.major, w.minor, []byte{0x01}}); + w.encryptor = msg.encryptor; + w.mac = msg.mac; + w.seqNum = 0; + case writerSetVersion: + w.major = msg.major; + w.minor = msg.minor; + case alert: + w.writeRecord(&record{recordTypeAlert, w.major, w.minor, []byte{byte(msg.level), byte(msg.error)}}); + case handshakeMessage: + // TODO(agl): marshal may return a slice too large for a single record. + w.writeRecord(&record{recordTypeHandshake, w.major, w.minor, msg.marshal()}); + default: + fmt.Printf("processControlMessage: unknown %#v\n", msg); + } +} + +func (w *recordWriter) processAppMessage(appMsg []byte) { + if closed(w.appChan) { + w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, []byte{byte(alertCloseNotify)}}); + w.shutdown = true; + return; + } + + var done int; + for done < len(appMsg) { + todo := len(appMsg); + if todo > maxTLSPlaintext { + todo = maxTLSPlaintext; + } + w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, appMsg[done : done+todo]}); + done += todo; + } +}