diff --git a/common.go b/common.go index 8ef8b09..ef54a1d 100644 --- a/common.go +++ b/common.go @@ -5,9 +5,13 @@ package tls import ( + "crypto/rand" "crypto/rsa" "io" + "io/ioutil" + "once" "os" + "time" ) const ( @@ -130,3 +134,38 @@ func (nop) Sum() []byte { return nil } func (nop) Reset() {} func (nop) Size() int { return 0 } + + +// The defaultConfig is used in place of a nil *Config in the TLS server and client. +var varDefaultConfig *Config + +func defaultConfig() *Config { + once.Do(initDefaultConfig) + return varDefaultConfig +} + +// Possible certificate files; stop after finding one. +// On OS X we should really be using the Directory Services keychain +// but that requires a lot of Mach goo to get at. Instead we use +// the same root set that curl uses. +var certFiles = []string{ + "/etc/ssl/certs/ca-certificates.crt", // Linux etc + "/usr/share/curl/curl-ca-bundle.crt", // OS X +} + +func initDefaultConfig() { + roots := NewCASet() + for _, file := range certFiles { + data, err := ioutil.ReadFile(file) + if err == nil { + roots.SetFromPEM(data) + break + } + } + + varDefaultConfig = &Config{ + Rand: rand.Reader, + Time: time.Seconds, + RootCAs: roots, + } +} diff --git a/tls.go b/tls.go index 7c76dde..5fbf850 100644 --- a/tls.go +++ b/tls.go @@ -125,6 +125,9 @@ type handshaker interface { // Server establishes a secure connection over the given connection and acts // as a TLS server. func startTLSGoroutines(conn net.Conn, h handshaker, config *Config) *Conn { + if config == nil { + config = defaultConfig() + } tls := new(Conn) tls.Conn = conn @@ -167,7 +170,6 @@ func (l *Listener) Accept() (c net.Conn, err os.Error) { if err != nil { return } - c = Server(c, l.config) return } @@ -179,8 +181,27 @@ func (l *Listener) Addr() net.Addr { return l.listener.Addr() } // NewListener creates a Listener which accepts connections from an inner // Listener and wraps each connection with Server. func NewListener(listener net.Listener, config *Config) (l *Listener) { + if config == nil { + config = defaultConfig() + } l = new(Listener) l.listener = listener l.config = config return } + +func Listen(network, laddr string) (net.Listener, os.Error) { + l, err := net.Listen(network, laddr) + if err != nil { + return nil, err + } + return NewListener(l, nil), nil +} + +func Dial(network, laddr, raddr string) (net.Conn, os.Error) { + c, err := net.Dial(network, laddr, raddr) + if err != nil { + return nil, err + } + return Client(c, nil), nil +}