crypto/tls: add *Config argument to Dial

Document undocumented exported names.
Allow nil Rand, Time, RootCAs in Config.

Fixes #1248.

R=agl1
CC=golang-dev
https://golang.org/cl/3481042
This commit is contained in:
Russ Cox 2010-12-07 16:15:15 -05:00
parent 24b6f5d63e
commit f98d01fb7e
5 changed files with 108 additions and 39 deletions

View File

@ -16,6 +16,7 @@ type CASet struct {
byName map[string][]*x509.Certificate byName map[string][]*x509.Certificate
} }
// NewCASet returns a new, empty CASet.
func NewCASet() *CASet { func NewCASet() *CASet {
return &CASet{ return &CASet{
make(map[string][]*x509.Certificate), make(map[string][]*x509.Certificate),

View File

@ -78,6 +78,7 @@ const (
// Rest of these are reserved by the TLS spec // Rest of these are reserved by the TLS spec
) )
// ConnectionState records basic TLS details about the connection.
type ConnectionState struct { type ConnectionState struct {
HandshakeComplete bool HandshakeComplete bool
CipherSuite uint16 CipherSuite uint16
@ -88,28 +89,65 @@ type ConnectionState struct {
// has been passed to a TLS function it must not be modified. // has been passed to a TLS function it must not be modified.
type Config struct { type Config struct {
// Rand provides the source of entropy for nonces and RSA blinding. // Rand provides the source of entropy for nonces and RSA blinding.
// If Rand is nil, TLS uses the cryptographic random reader in package
// crypto/rand.
Rand io.Reader Rand io.Reader
// Time returns the current time as the number of seconds since the epoch. // Time returns the current time as the number of seconds since the epoch.
// If Time is nil, TLS uses the system time.Seconds.
Time func() int64 Time func() int64
// Certificates contains one or more certificate chains.
// Certificates contains one or more certificate chains
// to present to the other side of the connection.
// Server configurations must include at least one certificate.
Certificates []Certificate Certificates []Certificate
// RootCAs defines the set of root certificate authorities
// that clients use when verifying server certificates.
// If RootCAs is nil, TLS uses the host's root CA set.
RootCAs *CASet RootCAs *CASet
// NextProtos is a list of supported, application level protocols. // NextProtos is a list of supported, application level protocols.
// Currently only server-side handling is supported. // Currently only server-side handling is supported.
NextProtos []string NextProtos []string
// ServerName is included in the client's handshake to support virtual // ServerName is included in the client's handshake to support virtual
// hosting. // hosting.
ServerName string ServerName string
// AuthenticateClient determines if a server will request a certificate
// AuthenticateClient controls whether a server will request a certificate
// from the client. It does not require that the client send a // from the client. It does not require that the client send a
// certificate nor, if it does, that the certificate is anything more // certificate nor does it require that the certificate sent be
// than self-signed. // anything more than self-signed.
AuthenticateClient bool AuthenticateClient bool
} }
func (c *Config) rand() io.Reader {
r := c.Rand
if r == nil {
return rand.Reader
}
return r
}
func (c *Config) time() int64 {
t := c.Time
if t == nil {
t = time.Seconds
}
return t()
}
func (c *Config) rootCAs() *CASet {
s := c.RootCAs
if s == nil {
s = defaultRoots()
}
return s
}
// A Certificate is a chain of one or more certificates, leaf first.
type Certificate struct { type Certificate struct {
// Certificate contains a chain of one or more certificates. Leaf
// certificate first.
Certificate [][]byte Certificate [][]byte
PrivateKey *rsa.PrivateKey PrivateKey *rsa.PrivateKey
} }
@ -143,14 +181,10 @@ func mutualVersion(vers uint16) (uint16, bool) {
return vers, true return vers, true
} }
// The defaultConfig is used in place of a nil *Config in the TLS server and client. var emptyConfig Config
var varDefaultConfig *Config
var once sync.Once
func defaultConfig() *Config { func defaultConfig() *Config {
once.Do(initDefaultConfig) return &emptyConfig
return varDefaultConfig
} }
// Possible certificate files; stop after finding one. // Possible certificate files; stop after finding one.
@ -162,7 +196,16 @@ var certFiles = []string{
"/usr/share/curl/curl-ca-bundle.crt", // OS X "/usr/share/curl/curl-ca-bundle.crt", // OS X
} }
func initDefaultConfig() { var once sync.Once
func defaultRoots() *CASet {
once.Do(initDefaultRoots)
return varDefaultRoots
}
var varDefaultRoots *CASet
func initDefaultRoots() {
roots := NewCASet() roots := NewCASet()
for _, file := range certFiles { for _, file := range certFiles {
data, err := ioutil.ReadFile(file) data, err := ioutil.ReadFile(file)
@ -171,10 +214,5 @@ func initDefaultConfig() {
break break
} }
} }
varDefaultRoots = roots
varDefaultConfig = &Config{
Rand: rand.Reader,
Time: time.Seconds,
RootCAs: roots,
}
} }

View File

@ -30,12 +30,12 @@ func (c *Conn) clientHandshake() os.Error {
serverName: c.config.ServerName, serverName: c.config.ServerName,
} }
t := uint32(c.config.Time()) t := uint32(c.config.time())
hello.random[0] = byte(t >> 24) hello.random[0] = byte(t >> 24)
hello.random[1] = byte(t >> 16) hello.random[1] = byte(t >> 16)
hello.random[2] = byte(t >> 8) hello.random[2] = byte(t >> 8)
hello.random[3] = byte(t) hello.random[3] = byte(t)
_, err := io.ReadFull(c.config.Rand, hello.random[4:]) _, err := io.ReadFull(c.config.rand(), hello.random[4:])
if err != nil { if err != nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return os.ErrorString("short read from Rand") return os.ErrorString("short read from Rand")
@ -217,12 +217,12 @@ func (c *Conn) clientHandshake() os.Error {
preMasterSecret := make([]byte, 48) preMasterSecret := make([]byte, 48)
preMasterSecret[0] = byte(hello.vers >> 8) preMasterSecret[0] = byte(hello.vers >> 8)
preMasterSecret[1] = byte(hello.vers) preMasterSecret[1] = byte(hello.vers)
_, err = io.ReadFull(c.config.Rand, preMasterSecret[2:]) _, err = io.ReadFull(c.config.rand(), preMasterSecret[2:])
if err != nil { if err != nil {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }
ckx.ciphertext, err = rsa.EncryptPKCS1v15(c.config.Rand, pub, preMasterSecret) ckx.ciphertext, err = rsa.EncryptPKCS1v15(c.config.rand(), pub, preMasterSecret)
if err != nil { if err != nil {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }
@ -235,7 +235,7 @@ func (c *Conn) clientHandshake() os.Error {
var digest [36]byte var digest [36]byte
copy(digest[0:16], finishedHash.serverMD5.Sum()) copy(digest[0:16], finishedHash.serverMD5.Sum())
copy(digest[16:36], finishedHash.serverSHA1.Sum()) copy(digest[16:36], finishedHash.serverSHA1.Sum())
signed, err := rsa.SignPKCS1v15(c.config.Rand, c.config.Certificates[0].PrivateKey, rsa.HashMD5SHA1, digest[0:]) signed, err := rsa.SignPKCS1v15(c.config.rand(), c.config.Certificates[0].PrivateKey, rsa.HashMD5SHA1, digest[0:])
if err != nil { if err != nil {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }

View File

@ -84,13 +84,13 @@ func (c *Conn) serverHandshake() os.Error {
hello.vers = vers hello.vers = vers
hello.cipherSuite = suite.id hello.cipherSuite = suite.id
t := uint32(config.Time()) t := uint32(config.time())
hello.random = make([]byte, 32) hello.random = make([]byte, 32)
hello.random[0] = byte(t >> 24) hello.random[0] = byte(t >> 24)
hello.random[1] = byte(t >> 16) hello.random[1] = byte(t >> 16)
hello.random[2] = byte(t >> 8) hello.random[2] = byte(t >> 8)
hello.random[3] = byte(t) hello.random[3] = byte(t)
_, err = io.ReadFull(config.Rand, hello.random[4:]) _, err = io.ReadFull(config.rand(), hello.random[4:])
if err != nil { if err != nil {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }
@ -209,12 +209,12 @@ func (c *Conn) serverHandshake() os.Error {
} }
preMasterSecret := make([]byte, 48) preMasterSecret := make([]byte, 48)
_, err = io.ReadFull(config.Rand, preMasterSecret[2:]) _, err = io.ReadFull(config.rand(), preMasterSecret[2:])
if err != nil { if err != nil {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }
err = rsa.DecryptPKCS1v15SessionKey(config.Rand, config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret) err = rsa.DecryptPKCS1v15SessionKey(config.rand(), config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret)
if err != nil { if err != nil {
return c.sendAlert(alertHandshakeFailure) return c.sendAlert(alertHandshakeFailure)
} }

46
tls.go
View File

@ -15,19 +15,31 @@ import (
"strings" "strings"
) )
// Server returns a new TLS server side connection
// using conn as the underlying transport.
// The configuration config must be non-nil and must have
// at least one certificate.
func Server(conn net.Conn, config *Config) *Conn { func Server(conn net.Conn, config *Config) *Conn {
return &Conn{conn: conn, config: config} return &Conn{conn: conn, config: config}
} }
// Client returns a new TLS client side connection
// using conn as the underlying transport.
// Client interprets a nil configuration as equivalent to
// the zero configuration; see the documentation of Config
// for the defaults.
func Client(conn net.Conn, config *Config) *Conn { func Client(conn net.Conn, config *Config) *Conn {
return &Conn{conn: conn, config: config, isClient: true} return &Conn{conn: conn, config: config, isClient: true}
} }
// A Listener implements a network listener (net.Listener) for TLS connections.
type Listener struct { type Listener struct {
listener net.Listener listener net.Listener
config *Config config *Config
} }
// Accept waits for and returns the next incoming TLS connection.
// The returned connection c is a *tls.Conn.
func (l *Listener) Accept() (c net.Conn, err os.Error) { func (l *Listener) Accept() (c net.Conn, err os.Error) {
c, err = l.listener.Accept() c, err = l.listener.Accept()
if err != nil { if err != nil {
@ -37,8 +49,10 @@ func (l *Listener) Accept() (c net.Conn, err os.Error) {
return return
} }
// Close closes the listener.
func (l *Listener) Close() os.Error { return l.listener.Close() } func (l *Listener) Close() os.Error { return l.listener.Close() }
// Addr returns the listener's network address.
func (l *Listener) Addr() net.Addr { return l.listener.Addr() } func (l *Listener) Addr() net.Addr { return l.listener.Addr() }
// NewListener creates a Listener which accepts connections from an inner // NewListener creates a Listener which accepts connections from an inner
@ -52,7 +66,11 @@ func NewListener(listener net.Listener, config *Config) (l *Listener) {
return return
} }
func Listen(network, laddr string, config *Config) (net.Listener, os.Error) { // Listen creates a TLS listener accepting connections on the
// given network address using net.Listen.
// The configuration config must be non-nil and must have
// at least one certificate.
func Listen(network, laddr string, config *Config) (*Listener, os.Error) {
if config == nil || len(config.Certificates) == 0 { if config == nil || len(config.Certificates) == 0 {
return nil, os.NewError("tls.Listen: no certificates in configuration") return nil, os.NewError("tls.Listen: no certificates in configuration")
} }
@ -63,7 +81,13 @@ func Listen(network, laddr string, config *Config) (net.Listener, os.Error) {
return NewListener(l, config), nil return NewListener(l, config), nil
} }
func Dial(network, laddr, raddr string) (net.Conn, os.Error) { // Dial connects to the given network address using net.Dial
// and then initiates a TLS handshake, returning the resulting
// TLS connection.
// Dial interprets a nil configuration as equivalent to
// the zero configuration; see the documentation of Config
// for the defaults.
func Dial(network, laddr, raddr string, config *Config) (*Conn, os.Error) {
c, err := net.Dial(network, laddr, raddr) c, err := net.Dial(network, laddr, raddr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -75,16 +99,22 @@ func Dial(network, laddr, raddr string) (net.Conn, os.Error) {
} }
hostname := raddr[:colonPos] hostname := raddr[:colonPos]
config := defaultConfig() if config == nil {
config.ServerName = hostname config = defaultConfig()
conn := Client(c, config)
err = conn.Handshake()
if err == nil {
return conn, nil
} }
if config.ServerName != "" {
// Make a copy to avoid polluting argument or default.
c := *config
c.ServerName = hostname
config = &c
}
conn := Client(c, config)
if err = conn.Handshake(); err != nil {
c.Close() c.Close()
return nil, err return nil, err
} }
return conn, nil
}
// LoadX509KeyPair reads and parses a public/private key pair from a pair of // LoadX509KeyPair reads and parses a public/private key pair from a pair of
// files. The files must contain PEM encoded data. // files. The files must contain PEM encoded data.