// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // This package partially implements the TLS 1.1 protocol, as specified in RFC 4346. package tls import ( "io" "os" "net" "time" ) // A Conn represents a secure connection. type Conn struct { net.Conn writeChan chan<- []byte readChan <-chan []byte requestChan chan<- interface{} readBuf []byte eof bool readTimeout, writeTimeout int64 } func timeout(c chan<- bool, nsecs int64) { time.Sleep(nsecs) c <- true } func (tls *Conn) Read(p []byte) (int, os.Error) { if len(tls.readBuf) == 0 { if tls.eof { return 0, os.EOF } var timeoutChan chan bool if tls.readTimeout > 0 { timeoutChan = make(chan bool) go timeout(timeoutChan, tls.readTimeout) } select { case b := <-tls.readChan: tls.readBuf = b case <-timeoutChan: return 0, os.EAGAIN } // TLS distinguishes between orderly closes and truncations. An // orderly close is represented by a zero length slice. if closed(tls.readChan) { return 0, io.ErrUnexpectedEOF } if len(tls.readBuf) == 0 { tls.eof = true return 0, os.EOF } } n := copy(p, tls.readBuf) tls.readBuf = tls.readBuf[n:] return n, nil } func (tls *Conn) Write(p []byte) (int, os.Error) { if tls.eof || closed(tls.readChan) { return 0, os.EOF } var timeoutChan chan bool if tls.writeTimeout > 0 { timeoutChan = make(chan bool) go timeout(timeoutChan, tls.writeTimeout) } select { case tls.writeChan <- p: case <-timeoutChan: return 0, os.EAGAIN } return len(p), nil } func (tls *Conn) Close() os.Error { close(tls.writeChan) close(tls.requestChan) tls.eof = true return nil } func (tls *Conn) SetTimeout(nsec int64) os.Error { tls.readTimeout = nsec tls.writeTimeout = nsec return nil } func (tls *Conn) SetReadTimeout(nsec int64) os.Error { tls.readTimeout = nsec return nil } func (tls *Conn) SetWriteTimeout(nsec int64) os.Error { tls.writeTimeout = nsec return nil } func (tls *Conn) GetConnectionState() ConnectionState { replyChan := make(chan ConnectionState) tls.requestChan <- getConnectionState{replyChan} return <-replyChan } func (tls *Conn) WaitConnectionState() ConnectionState { replyChan := make(chan ConnectionState) tls.requestChan <- waitConnectionState{replyChan} return <-replyChan } type handshaker interface { loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) } // Server establishes a secure connection over the given connection and acts // as a TLS server. func startTLSGoroutines(conn net.Conn, h handshaker, config *Config) *Conn { if config == nil { config = defaultConfig() } tls := new(Conn) tls.Conn = conn writeChan := make(chan []byte) readChan := make(chan []byte) requestChan := make(chan interface{}) tls.writeChan = writeChan tls.readChan = readChan tls.requestChan = requestChan handshakeWriterChan := make(chan interface{}) processorHandshakeChan := make(chan interface{}) handshakeProcessorChan := make(chan interface{}) readerProcessorChan := make(chan *record) go new(recordWriter).loop(conn, writeChan, handshakeWriterChan) go recordReader(readerProcessorChan, conn) go new(recordProcessor).loop(readChan, requestChan, handshakeProcessorChan, readerProcessorChan, processorHandshakeChan) go h.loop(handshakeWriterChan, handshakeProcessorChan, processorHandshakeChan, config) return tls } func Server(conn net.Conn, config *Config) *Conn { return startTLSGoroutines(conn, new(serverHandshake), config) } func Client(conn net.Conn, config *Config) *Conn { return startTLSGoroutines(conn, new(clientHandshake), config) } type Listener struct { listener net.Listener config *Config } func (l *Listener) Accept() (c net.Conn, err os.Error) { c, err = l.listener.Accept() if err != nil { return } c = Server(c, l.config) return } func (l *Listener) Close() os.Error { return l.listener.Close() } func (l *Listener) Addr() net.Addr { return l.listener.Addr() } // NewListener creates a Listener which accepts connections from an inner // Listener and wraps each connection with Server. func NewListener(listener net.Listener, config *Config) (l *Listener) { if config == nil { config = defaultConfig() } l = new(Listener) l.listener = listener l.config = config return } func Listen(network, laddr string) (net.Listener, os.Error) { l, err := net.Listen(network, laddr) if err != nil { return nil, err } return NewListener(l, nil), nil } func Dial(network, laddr, raddr string) (net.Conn, os.Error) { c, err := net.Dial(network, laddr, raddr) if err != nil { return nil, err } return Client(c, nil), nil }