@@ -0,0 +1,236 @@ | |||
// Copyright 2009 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package tls | |||
import ( | |||
"bytes"; | |||
) | |||
type clientHelloMsg struct { | |||
raw []byte; | |||
major, minor uint8; | |||
random []byte; | |||
sessionId []byte; | |||
cipherSuites []uint16; | |||
compressionMethods []uint8; | |||
} | |||
func (m *clientHelloMsg) marshal() []byte { | |||
if m.raw != nil { | |||
return m.raw; | |||
} | |||
length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods); | |||
x := make([]byte, 4+length); | |||
x[0] = typeClientHello; | |||
x[1] = uint8(length>>16); | |||
x[2] = uint8(length>>8); | |||
x[3] = uint8(length); | |||
x[4] = m.major; | |||
x[5] = m.minor; | |||
bytes.Copy(x[6:38], m.random); | |||
x[38] = uint8(len(m.sessionId)); | |||
bytes.Copy(x[39 : 39+len(m.sessionId)], m.sessionId); | |||
y := x[39+len(m.sessionId) : len(x)]; | |||
y[0] = uint8(len(m.cipherSuites)>>7); | |||
y[1] = uint8(len(m.cipherSuites)<<1); | |||
for i, suite := range m.cipherSuites { | |||
y[2 + i*2] = uint8(suite>>8); | |||
y[3 + i*2] = uint8(suite); | |||
} | |||
z := y[2 + len(m.cipherSuites)*2 : len(y)]; | |||
z[0] = uint8(len(m.compressionMethods)); | |||
bytes.Copy(z[1:len(z)], m.compressionMethods); | |||
m.raw = x; | |||
return x; | |||
} | |||
func (m *clientHelloMsg) unmarshal(data []byte) bool { | |||
if len(data) < 39 { | |||
return false; | |||
} | |||
m.raw = data; | |||
m.major = data[4]; | |||
m.minor = data[5]; | |||
m.random = data[6:38]; | |||
sessionIdLen := int(data[38]); | |||
if sessionIdLen > 32 || len(data) < 39 + sessionIdLen { | |||
return false; | |||
} | |||
m.sessionId = data[39 : 39 + sessionIdLen]; | |||
data = data[39 + sessionIdLen : len(data)]; | |||
if len(data) < 2 { | |||
return false; | |||
} | |||
// cipherSuiteLen is the number of bytes of cipher suite numbers. Since | |||
// they are uint16s, the number must be even. | |||
cipherSuiteLen := int(data[0])<<8 | int(data[1]); | |||
if cipherSuiteLen % 2 == 1 || len(data) < 2 + cipherSuiteLen { | |||
return false; | |||
} | |||
numCipherSuites := cipherSuiteLen / 2; | |||
m.cipherSuites = make([]uint16, numCipherSuites); | |||
for i := 0; i < numCipherSuites; i++ { | |||
m.cipherSuites[i] = uint16(data[2 + 2*i])<<8 | uint16(data[3 + 2*i]); | |||
} | |||
data = data[2 + cipherSuiteLen : len(data)]; | |||
if len(data) < 2 { | |||
return false; | |||
} | |||
compressionMethodsLen := int(data[0]); | |||
if len(data) < 1 + compressionMethodsLen { | |||
return false; | |||
} | |||
m.compressionMethods = data[1 : 1 + compressionMethodsLen]; | |||
// A ClientHello may be following by trailing data: RFC 4346 section 7.4.1.2 | |||
return true; | |||
} | |||
type serverHelloMsg struct { | |||
raw []byte; | |||
major, minor uint8; | |||
random []byte; | |||
sessionId []byte; | |||
cipherSuite uint16; | |||
compressionMethod uint8; | |||
} | |||
func (m *serverHelloMsg) marshal() []byte { | |||
if m.raw != nil { | |||
return m.raw; | |||
} | |||
length := 38+len(m.sessionId); | |||
x := make([]byte, 4+length); | |||
x[0] = typeServerHello; | |||
x[1] = uint8(length>>16); | |||
x[2] = uint8(length>>8); | |||
x[3] = uint8(length); | |||
x[4] = m.major; | |||
x[5] = m.minor; | |||
bytes.Copy(x[6:38], m.random); | |||
x[38] = uint8(len(m.sessionId)); | |||
bytes.Copy(x[39 : 39+len(m.sessionId)], m.sessionId); | |||
z := x[39+len(m.sessionId) : len(x)]; | |||
z[0] = uint8(m.cipherSuite >> 8); | |||
z[1] = uint8(m.cipherSuite); | |||
z[2] = uint8(m.compressionMethod); | |||
m.raw = x; | |||
return x; | |||
} | |||
type certificateMsg struct { | |||
raw []byte; | |||
certificates [][]byte; | |||
} | |||
func (m *certificateMsg) marshal() (x []byte) { | |||
if m.raw != nil { | |||
return m.raw; | |||
} | |||
var i int; | |||
for _, slice := range m.certificates { | |||
i += len(slice); | |||
} | |||
length := 3 + 3*len(m.certificates) + i; | |||
x = make([]byte, 4+length); | |||
x[0] = typeCertificate; | |||
x[1] = uint8(length>>16); | |||
x[2] = uint8(length>>8); | |||
x[3] = uint8(length); | |||
certificateOctets := length-3; | |||
x[4] = uint8(certificateOctets >> 16); | |||
x[5] = uint8(certificateOctets >> 8); | |||
x[6] = uint8(certificateOctets); | |||
y := x[7:len(x)]; | |||
for _, slice := range m.certificates { | |||
y[0] = uint8(len(slice)>>16); | |||
y[1] = uint8(len(slice)>>8); | |||
y[2] = uint8(len(slice)); | |||
bytes.Copy(y[3:len(y)], slice); | |||
y = y[3+len(slice) : len(y)]; | |||
} | |||
m.raw = x; | |||
return; | |||
} | |||
type serverHelloDoneMsg struct{} | |||
func (m *serverHelloDoneMsg) marshal() []byte { | |||
x := make([]byte, 4); | |||
x[0] = typeServerHelloDone; | |||
return x; | |||
} | |||
type clientKeyExchangeMsg struct { | |||
raw []byte; | |||
ciphertext []byte; | |||
} | |||
func (m *clientKeyExchangeMsg) marshal() []byte { | |||
if m.raw != nil { | |||
return m.raw; | |||
} | |||
length := len(m.ciphertext)+2; | |||
x := make([]byte, length+4); | |||
x[0] = typeClientKeyExchange; | |||
x[1] = uint8(length>>16); | |||
x[2] = uint8(length>>8); | |||
x[3] = uint8(length); | |||
x[4] = uint8(len(m.ciphertext)>>8); | |||
x[5] = uint8(len(m.ciphertext)); | |||
bytes.Copy(x[6:len(x)], m.ciphertext); | |||
m.raw = x; | |||
return x; | |||
} | |||
func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { | |||
m.raw = data; | |||
if len(data) < 7 { | |||
return false; | |||
} | |||
cipherTextLen := int(data[4])<<8 | int(data[5]); | |||
if len(data) != 6 + cipherTextLen { | |||
return false; | |||
} | |||
m.ciphertext = data[6:len(data)]; | |||
return true; | |||
} | |||
type finishedMsg struct { | |||
raw []byte; | |||
verifyData []byte; | |||
} | |||
func (m *finishedMsg) marshal() (x []byte) { | |||
if m.raw != nil { | |||
return m.raw; | |||
} | |||
x = make([]byte, 16); | |||
x[0] = typeFinished; | |||
x[3] = 12; | |||
bytes.Copy(x[4:len(x)], m.verifyData); | |||
m.raw = x; | |||
return; | |||
} | |||
func (m *finishedMsg) unmarshal(data []byte) bool { | |||
m.raw = data; | |||
if len(data) != 4+12 { | |||
return false; | |||
} | |||
m.verifyData = data[4:len(data)]; | |||
return true; | |||
} |
@@ -0,0 +1,95 @@ | |||
// Copyright 2009 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package tls | |||
import ( | |||
"rand"; | |||
"reflect"; | |||
"testing"; | |||
"testing/quick"; | |||
) | |||
var tests = []interface{}{ | |||
&clientHelloMsg{}, | |||
&clientKeyExchangeMsg{}, | |||
&finishedMsg{}, | |||
} | |||
type testMessage interface { | |||
marshal() []byte; | |||
unmarshal([]byte) bool; | |||
} | |||
func TestMarshalUnmarshal(t *testing.T) { | |||
rand := rand.New(rand.NewSource(0)); | |||
for i, iface := range tests { | |||
ty := reflect.NewValue(iface).Type(); | |||
for j := 0; j < 100; j++ { | |||
v, ok := quick.Value(ty, rand); | |||
if !ok { | |||
t.Errorf("#%d: failed to create value", i); | |||
break; | |||
} | |||
m1 := v.Interface().(testMessage); | |||
marshaled := m1.marshal(); | |||
m2 := iface.(testMessage); | |||
if !m2.unmarshal(marshaled) { | |||
t.Errorf("#%d failed to unmarshal %#v", i, m1); | |||
break; | |||
} | |||
m2.marshal(); // to fill any marshal cache in the message | |||
if !reflect.DeepEqual(m1, m2) { | |||
t.Errorf("#%d got:%#v want:%#v", i, m1, m2); | |||
break; | |||
} | |||
// Now check that all prefixes are invalid. | |||
for j := 0; j < len(marshaled); j++ { | |||
if m2.unmarshal(marshaled[0:j]) { | |||
t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
func randomBytes(n int, rand *rand.Rand) []byte { | |||
r := make([]byte, n); | |||
for i := 0; i < n; i++ { | |||
r[i] = byte(rand.Int31()); | |||
} | |||
return r; | |||
} | |||
func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { | |||
m := &clientHelloMsg{}; | |||
m.major = uint8(rand.Intn(256)); | |||
m.minor = uint8(rand.Intn(256)); | |||
m.random = randomBytes(32, rand); | |||
m.sessionId = randomBytes(rand.Intn(32), rand); | |||
m.cipherSuites = make([]uint16, rand.Intn(63) + 1); | |||
for i := 0; i < len(m.cipherSuites); i++ { | |||
m.cipherSuites[i] = uint16(rand.Int31()); | |||
} | |||
m.compressionMethods = randomBytes(rand.Intn(63) + 1, rand); | |||
return reflect.NewValue(m); | |||
} | |||
func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { | |||
m := &clientKeyExchangeMsg{}; | |||
m.ciphertext = randomBytes(rand.Intn(1000), rand); | |||
return reflect.NewValue(m); | |||
} | |||
func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { | |||
m := &finishedMsg{}; | |||
m.verifyData = randomBytes(12, rand); | |||
return reflect.NewValue(m); | |||
} |
@@ -0,0 +1,42 @@ | |||
// Copyright 2009 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package tls | |||
// The record reader handles reading from the connection and reassembling TLS | |||
// record structures. It loops forever doing this and writes the TLS records to | |||
// it's outbound channel. On error, it closes its outbound channel. | |||
import ( | |||
"io"; | |||
"bufio"; | |||
) | |||
// recordReader loops, reading TLS records from source and writing them to the | |||
// given channel. The channel is closed on EOF or on error. | |||
func recordReader(c chan<- *record, source io.Reader) { | |||
defer close(c); | |||
buf := bufio.NewReader(source); | |||
for { | |||
var header [5]byte; | |||
n, _ := buf.Read(header[0:len(header)]); | |||
if n != 5 { | |||
return; | |||
} | |||
recordLength := int(header[3])<<8 | int(header[4]); | |||
if recordLength > maxTLSCiphertext { | |||
return; | |||
} | |||
payload := make([]byte, recordLength); | |||
n, _ = buf.Read(payload); | |||
if n != recordLength { | |||
return; | |||
} | |||
c <- &record{recordType(header[0]), header[1], header[2], payload}; | |||
} | |||
} |
@@ -0,0 +1,73 @@ | |||
// Copyright 2009 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package tls | |||
import ( | |||
"bytes"; | |||
"testing"; | |||
"testing/iotest"; | |||
) | |||
func matchRecord(r1, r2 *record) bool { | |||
if (r1 == nil) != (r2 == nil) { | |||
return false; | |||
} | |||
if r1 == nil { | |||
return true; | |||
} | |||
return r1.contentType == r2.contentType && | |||
r1.major == r2.major && | |||
r1.minor == r2.minor && | |||
bytes.Compare(r1.payload, r2.payload) == 0; | |||
} | |||
type recordReaderTest struct { | |||
in []byte; | |||
out []*record; | |||
} | |||
var recordReaderTests = []recordReaderTest{ | |||
recordReaderTest{nil, nil}, | |||
recordReaderTest{fromHex("01"), nil}, | |||
recordReaderTest{fromHex("0102"), nil}, | |||
recordReaderTest{fromHex("010203"), nil}, | |||
recordReaderTest{fromHex("01020300"), nil}, | |||
recordReaderTest{fromHex("0102030000"), []*record{&record{1, 2, 3, nil}}}, | |||
recordReaderTest{fromHex("01020300000102030000"), []*record{&record{1, 2, 3, nil}, &record{1, 2, 3, nil}}}, | |||
recordReaderTest{fromHex("0102030001fe0102030002feff"), []*record{&record{1, 2, 3, []byte{0xfe}}, &record{1, 2, 3, []byte{0xfe, 0xff}}}}, | |||
recordReaderTest{fromHex("010203000001020300"), []*record{&record{1, 2, 3, nil}}}, | |||
} | |||
func TestRecordReader(t *testing.T) { | |||
for i, test := range recordReaderTests { | |||
buf := bytes.NewBuffer(test.in); | |||
c := make(chan *record); | |||
go recordReader(c, buf); | |||
matchRecordReaderOutput(t, i, test, c); | |||
buf = bytes.NewBuffer(test.in); | |||
buf2 := iotest.OneByteReader(buf); | |||
c = make(chan *record); | |||
go recordReader(c, buf2); | |||
matchRecordReaderOutput(t, i*2, test, c); | |||
} | |||
} | |||
func matchRecordReaderOutput(t *testing.T, i int, test recordReaderTest, c <-chan *record) { | |||
for j, r1 := range test.out { | |||
r2 := <-c; | |||
if r2 == nil { | |||
t.Errorf("#%d truncated after %d values", i, j); | |||
break; | |||
} | |||
if !matchRecord(r1, r2) { | |||
t.Errorf("#%d (%d) got:%#v want:%#v", i, j, r2, r1); | |||
} | |||
} | |||
<-c; | |||
if !closed(c) { | |||
t.Errorf("#%d: channel didn't close", i); | |||
} | |||
} |
@@ -0,0 +1,164 @@ | |||
// Copyright 2009 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package tls | |||
import ( | |||
"fmt"; | |||
"hash"; | |||
"io"; | |||
) | |||
// writerEnableApplicationData is a message which instructs recordWriter to | |||
// start reading and transmitting data from the application data channel. | |||
type writerEnableApplicationData struct{} | |||
// writerChangeCipherSpec updates the encryption and MAC functions and resets | |||
// the sequence count. | |||
type writerChangeCipherSpec struct { | |||
encryptor encryptor; | |||
mac hash.Hash; | |||
} | |||
// writerSetVersion sets the version number bytes that we included in the | |||
// record header for future records. | |||
type writerSetVersion struct { | |||
major, minor uint8; | |||
} | |||
// A recordWriter accepts messages from the handshake processor and | |||
// application data. It writes them to the outgoing connection and blocks on | |||
// writing. It doesn't read from the application data channel until the | |||
// handshake processor has signaled that the handshake is complete. | |||
type recordWriter struct { | |||
writer io.Writer; | |||
encryptor encryptor; | |||
mac hash.Hash; | |||
seqNum uint64; | |||
major, minor uint8; | |||
shutdown bool; | |||
appChan <-chan []byte; | |||
controlChan <-chan interface{}; | |||
header [13]byte; | |||
} | |||
func (w *recordWriter) loop(writer io.Writer, appChan <-chan []byte, controlChan <-chan interface{}) { | |||
w.writer = writer; | |||
w.encryptor = nop{}; | |||
w.mac = nop{}; | |||
w.appChan = appChan; | |||
w.controlChan = controlChan; | |||
for !w.shutdown { | |||
msg := <-controlChan; | |||
if _, ok := msg.(writerEnableApplicationData); ok { | |||
break; | |||
} | |||
w.processControlMessage(msg); | |||
} | |||
for !w.shutdown { | |||
// Always process control messages first. | |||
if controlMsg, ok := <-controlChan; ok { | |||
w.processControlMessage(controlMsg); | |||
continue; | |||
} | |||
select { | |||
case controlMsg := <-controlChan: | |||
w.processControlMessage(controlMsg); | |||
case appMsg := <-appChan: | |||
w.processAppMessage(appMsg); | |||
} | |||
} | |||
if !closed(appChan) { | |||
go func() { for _ = range appChan {} }(); | |||
} | |||
if !closed(controlChan) { | |||
go func() { for _ = range controlChan {} }(); | |||
} | |||
} | |||
// fillMACHeader generates a MAC header. See RFC 4346, section 6.2.3.1. | |||
func fillMACHeader(header *[13]byte, seqNum uint64, length int, r *record) { | |||
header[0] = uint8(seqNum>>56); | |||
header[1] = uint8(seqNum>>48); | |||
header[2] = uint8(seqNum>>40); | |||
header[3] = uint8(seqNum>>32); | |||
header[4] = uint8(seqNum>>24); | |||
header[5] = uint8(seqNum>>16); | |||
header[6] = uint8(seqNum>>8); | |||
header[7] = uint8(seqNum); | |||
header[8] = uint8(r.contentType); | |||
header[9] = r.major; | |||
header[10] = r.minor; | |||
header[11] = uint8(length>>8); | |||
header[12] = uint8(length); | |||
} | |||
func (w *recordWriter) writeRecord(r *record) { | |||
w.mac.Reset(); | |||
fillMACHeader(&w.header, w.seqNum, len(r.payload), r); | |||
w.mac.Write(w.header[0:13]); | |||
w.mac.Write(r.payload); | |||
macBytes := w.mac.Sum(); | |||
w.encryptor.XORKeyStream(r.payload); | |||
w.encryptor.XORKeyStream(macBytes); | |||
length := len(r.payload)+len(macBytes); | |||
w.header[11] = uint8(length>>8); | |||
w.header[12] = uint8(length); | |||
w.writer.Write(w.header[8:13]); | |||
w.writer.Write(r.payload); | |||
w.writer.Write(macBytes); | |||
w.seqNum++; | |||
} | |||
func (w *recordWriter) processControlMessage(controlMsg interface{}) { | |||
if controlMsg == nil { | |||
w.shutdown = true; | |||
return; | |||
} | |||
switch msg := controlMsg.(type) { | |||
case writerChangeCipherSpec: | |||
w.writeRecord(&record{recordTypeChangeCipherSpec, w.major, w.minor, []byte{0x01}}); | |||
w.encryptor = msg.encryptor; | |||
w.mac = msg.mac; | |||
w.seqNum = 0; | |||
case writerSetVersion: | |||
w.major = msg.major; | |||
w.minor = msg.minor; | |||
case alert: | |||
w.writeRecord(&record{recordTypeAlert, w.major, w.minor, []byte{byte(msg.level), byte(msg.error)}}); | |||
case handshakeMessage: | |||
// TODO(agl): marshal may return a slice too large for a single record. | |||
w.writeRecord(&record{recordTypeHandshake, w.major, w.minor, msg.marshal()}); | |||
default: | |||
fmt.Printf("processControlMessage: unknown %#v\n", msg); | |||
} | |||
} | |||
func (w *recordWriter) processAppMessage(appMsg []byte) { | |||
if closed(w.appChan) { | |||
w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, []byte{byte(alertCloseNotify)}}); | |||
w.shutdown = true; | |||
return; | |||
} | |||
var done int; | |||
for done < len(appMsg) { | |||
todo := len(appMsg); | |||
if todo > maxTLSPlaintext { | |||
todo = maxTLSPlaintext; | |||
} | |||
w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, appMsg[done : done+todo]}); | |||
done += todo; | |||
} | |||
} |