Document undocumented exported names. Allow nil Rand, Time, RootCAs in Config. Fixes #1248. R=agl1 CC=golang-dev https://golang.org/cl/3481042v1.2.3
@@ -16,6 +16,7 @@ type CASet struct { | |||
byName map[string][]*x509.Certificate | |||
} | |||
// NewCASet returns a new, empty CASet. | |||
func NewCASet() *CASet { | |||
return &CASet{ | |||
make(map[string][]*x509.Certificate), | |||
@@ -78,6 +78,7 @@ const ( | |||
// Rest of these are reserved by the TLS spec | |||
) | |||
// ConnectionState records basic TLS details about the connection. | |||
type ConnectionState struct { | |||
HandshakeComplete bool | |||
CipherSuite uint16 | |||
@@ -88,28 +89,65 @@ type ConnectionState struct { | |||
// has been passed to a TLS function it must not be modified. | |||
type Config struct { | |||
// Rand provides the source of entropy for nonces and RSA blinding. | |||
// If Rand is nil, TLS uses the cryptographic random reader in package | |||
// crypto/rand. | |||
Rand io.Reader | |||
// Time returns the current time as the number of seconds since the epoch. | |||
// If Time is nil, TLS uses the system time.Seconds. | |||
Time func() int64 | |||
// Certificates contains one or more certificate chains. | |||
// Certificates contains one or more certificate chains | |||
// to present to the other side of the connection. | |||
// Server configurations must include at least one certificate. | |||
Certificates []Certificate | |||
RootCAs *CASet | |||
// RootCAs defines the set of root certificate authorities | |||
// that clients use when verifying server certificates. | |||
// If RootCAs is nil, TLS uses the host's root CA set. | |||
RootCAs *CASet | |||
// NextProtos is a list of supported, application level protocols. | |||
// Currently only server-side handling is supported. | |||
NextProtos []string | |||
// ServerName is included in the client's handshake to support virtual | |||
// hosting. | |||
ServerName string | |||
// AuthenticateClient determines if a server will request a certificate | |||
// AuthenticateClient controls whether a server will request a certificate | |||
// from the client. It does not require that the client send a | |||
// certificate nor, if it does, that the certificate is anything more | |||
// than self-signed. | |||
// certificate nor does it require that the certificate sent be | |||
// anything more than self-signed. | |||
AuthenticateClient bool | |||
} | |||
func (c *Config) rand() io.Reader { | |||
r := c.Rand | |||
if r == nil { | |||
return rand.Reader | |||
} | |||
return r | |||
} | |||
func (c *Config) time() int64 { | |||
t := c.Time | |||
if t == nil { | |||
t = time.Seconds | |||
} | |||
return t() | |||
} | |||
func (c *Config) rootCAs() *CASet { | |||
s := c.RootCAs | |||
if s == nil { | |||
s = defaultRoots() | |||
} | |||
return s | |||
} | |||
// A Certificate is a chain of one or more certificates, leaf first. | |||
type Certificate struct { | |||
// Certificate contains a chain of one or more certificates. Leaf | |||
// certificate first. | |||
Certificate [][]byte | |||
PrivateKey *rsa.PrivateKey | |||
} | |||
@@ -143,14 +181,10 @@ func mutualVersion(vers uint16) (uint16, bool) { | |||
return vers, true | |||
} | |||
// The defaultConfig is used in place of a nil *Config in the TLS server and client. | |||
var varDefaultConfig *Config | |||
var once sync.Once | |||
var emptyConfig Config | |||
func defaultConfig() *Config { | |||
once.Do(initDefaultConfig) | |||
return varDefaultConfig | |||
return &emptyConfig | |||
} | |||
// Possible certificate files; stop after finding one. | |||
@@ -162,7 +196,16 @@ var certFiles = []string{ | |||
"/usr/share/curl/curl-ca-bundle.crt", // OS X | |||
} | |||
func initDefaultConfig() { | |||
var once sync.Once | |||
func defaultRoots() *CASet { | |||
once.Do(initDefaultRoots) | |||
return varDefaultRoots | |||
} | |||
var varDefaultRoots *CASet | |||
func initDefaultRoots() { | |||
roots := NewCASet() | |||
for _, file := range certFiles { | |||
data, err := ioutil.ReadFile(file) | |||
@@ -171,10 +214,5 @@ func initDefaultConfig() { | |||
break | |||
} | |||
} | |||
varDefaultConfig = &Config{ | |||
Rand: rand.Reader, | |||
Time: time.Seconds, | |||
RootCAs: roots, | |||
} | |||
varDefaultRoots = roots | |||
} |
@@ -30,12 +30,12 @@ func (c *Conn) clientHandshake() os.Error { | |||
serverName: c.config.ServerName, | |||
} | |||
t := uint32(c.config.Time()) | |||
t := uint32(c.config.time()) | |||
hello.random[0] = byte(t >> 24) | |||
hello.random[1] = byte(t >> 16) | |||
hello.random[2] = byte(t >> 8) | |||
hello.random[3] = byte(t) | |||
_, err := io.ReadFull(c.config.Rand, hello.random[4:]) | |||
_, err := io.ReadFull(c.config.rand(), hello.random[4:]) | |||
if err != nil { | |||
c.sendAlert(alertInternalError) | |||
return os.ErrorString("short read from Rand") | |||
@@ -217,12 +217,12 @@ func (c *Conn) clientHandshake() os.Error { | |||
preMasterSecret := make([]byte, 48) | |||
preMasterSecret[0] = byte(hello.vers >> 8) | |||
preMasterSecret[1] = byte(hello.vers) | |||
_, err = io.ReadFull(c.config.Rand, preMasterSecret[2:]) | |||
_, err = io.ReadFull(c.config.rand(), preMasterSecret[2:]) | |||
if err != nil { | |||
return c.sendAlert(alertInternalError) | |||
} | |||
ckx.ciphertext, err = rsa.EncryptPKCS1v15(c.config.Rand, pub, preMasterSecret) | |||
ckx.ciphertext, err = rsa.EncryptPKCS1v15(c.config.rand(), pub, preMasterSecret) | |||
if err != nil { | |||
return c.sendAlert(alertInternalError) | |||
} | |||
@@ -235,7 +235,7 @@ func (c *Conn) clientHandshake() os.Error { | |||
var digest [36]byte | |||
copy(digest[0:16], finishedHash.serverMD5.Sum()) | |||
copy(digest[16:36], finishedHash.serverSHA1.Sum()) | |||
signed, err := rsa.SignPKCS1v15(c.config.Rand, c.config.Certificates[0].PrivateKey, rsa.HashMD5SHA1, digest[0:]) | |||
signed, err := rsa.SignPKCS1v15(c.config.rand(), c.config.Certificates[0].PrivateKey, rsa.HashMD5SHA1, digest[0:]) | |||
if err != nil { | |||
return c.sendAlert(alertInternalError) | |||
} | |||
@@ -84,13 +84,13 @@ func (c *Conn) serverHandshake() os.Error { | |||
hello.vers = vers | |||
hello.cipherSuite = suite.id | |||
t := uint32(config.Time()) | |||
t := uint32(config.time()) | |||
hello.random = make([]byte, 32) | |||
hello.random[0] = byte(t >> 24) | |||
hello.random[1] = byte(t >> 16) | |||
hello.random[2] = byte(t >> 8) | |||
hello.random[3] = byte(t) | |||
_, err = io.ReadFull(config.Rand, hello.random[4:]) | |||
_, err = io.ReadFull(config.rand(), hello.random[4:]) | |||
if err != nil { | |||
return c.sendAlert(alertInternalError) | |||
} | |||
@@ -209,12 +209,12 @@ func (c *Conn) serverHandshake() os.Error { | |||
} | |||
preMasterSecret := make([]byte, 48) | |||
_, err = io.ReadFull(config.Rand, preMasterSecret[2:]) | |||
_, err = io.ReadFull(config.rand(), preMasterSecret[2:]) | |||
if err != nil { | |||
return c.sendAlert(alertInternalError) | |||
} | |||
err = rsa.DecryptPKCS1v15SessionKey(config.Rand, config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret) | |||
err = rsa.DecryptPKCS1v15SessionKey(config.rand(), config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret) | |||
if err != nil { | |||
return c.sendAlert(alertHandshakeFailure) | |||
} | |||
@@ -15,19 +15,31 @@ import ( | |||
"strings" | |||
) | |||
// Server returns a new TLS server side connection | |||
// using conn as the underlying transport. | |||
// The configuration config must be non-nil and must have | |||
// at least one certificate. | |||
func Server(conn net.Conn, config *Config) *Conn { | |||
return &Conn{conn: conn, config: config} | |||
} | |||
// Client returns a new TLS client side connection | |||
// using conn as the underlying transport. | |||
// Client interprets a nil configuration as equivalent to | |||
// the zero configuration; see the documentation of Config | |||
// for the defaults. | |||
func Client(conn net.Conn, config *Config) *Conn { | |||
return &Conn{conn: conn, config: config, isClient: true} | |||
} | |||
// A Listener implements a network listener (net.Listener) for TLS connections. | |||
type Listener struct { | |||
listener net.Listener | |||
config *Config | |||
} | |||
// Accept waits for and returns the next incoming TLS connection. | |||
// The returned connection c is a *tls.Conn. | |||
func (l *Listener) Accept() (c net.Conn, err os.Error) { | |||
c, err = l.listener.Accept() | |||
if err != nil { | |||
@@ -37,8 +49,10 @@ func (l *Listener) Accept() (c net.Conn, err os.Error) { | |||
return | |||
} | |||
// Close closes the listener. | |||
func (l *Listener) Close() os.Error { return l.listener.Close() } | |||
// Addr returns the listener's network address. | |||
func (l *Listener) Addr() net.Addr { return l.listener.Addr() } | |||
// NewListener creates a Listener which accepts connections from an inner | |||
@@ -52,7 +66,11 @@ func NewListener(listener net.Listener, config *Config) (l *Listener) { | |||
return | |||
} | |||
func Listen(network, laddr string, config *Config) (net.Listener, os.Error) { | |||
// Listen creates a TLS listener accepting connections on the | |||
// given network address using net.Listen. | |||
// The configuration config must be non-nil and must have | |||
// at least one certificate. | |||
func Listen(network, laddr string, config *Config) (*Listener, os.Error) { | |||
if config == nil || len(config.Certificates) == 0 { | |||
return nil, os.NewError("tls.Listen: no certificates in configuration") | |||
} | |||
@@ -63,7 +81,13 @@ func Listen(network, laddr string, config *Config) (net.Listener, os.Error) { | |||
return NewListener(l, config), nil | |||
} | |||
func Dial(network, laddr, raddr string) (net.Conn, os.Error) { | |||
// Dial connects to the given network address using net.Dial | |||
// and then initiates a TLS handshake, returning the resulting | |||
// TLS connection. | |||
// Dial interprets a nil configuration as equivalent to | |||
// the zero configuration; see the documentation of Config | |||
// for the defaults. | |||
func Dial(network, laddr, raddr string, config *Config) (*Conn, os.Error) { | |||
c, err := net.Dial(network, laddr, raddr) | |||
if err != nil { | |||
return nil, err | |||
@@ -75,15 +99,21 @@ func Dial(network, laddr, raddr string) (net.Conn, os.Error) { | |||
} | |||
hostname := raddr[:colonPos] | |||
config := defaultConfig() | |||
config.ServerName = hostname | |||
if config == nil { | |||
config = defaultConfig() | |||
} | |||
if config.ServerName != "" { | |||
// Make a copy to avoid polluting argument or default. | |||
c := *config | |||
c.ServerName = hostname | |||
config = &c | |||
} | |||
conn := Client(c, config) | |||
err = conn.Handshake() | |||
if err == nil { | |||
return conn, nil | |||
if err = conn.Handshake(); err != nil { | |||
c.Close() | |||
return nil, err | |||
} | |||
c.Close() | |||
return nil, err | |||
return conn, nil | |||
} | |||
// LoadX509KeyPair reads and parses a public/private key pair from a pair of | |||