crypto/tls: add *Config argument to Dial
Document undocumented exported names. Allow nil Rand, Time, RootCAs in Config. Fixes #1248. R=agl1 CC=golang-dev https://golang.org/cl/3481042
This commit is contained in:
parent
24b6f5d63e
commit
f98d01fb7e
@ -16,6 +16,7 @@ type CASet struct {
|
|||||||
byName map[string][]*x509.Certificate
|
byName map[string][]*x509.Certificate
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewCASet returns a new, empty CASet.
|
||||||
func NewCASet() *CASet {
|
func NewCASet() *CASet {
|
||||||
return &CASet{
|
return &CASet{
|
||||||
make(map[string][]*x509.Certificate),
|
make(map[string][]*x509.Certificate),
|
||||||
|
76
common.go
76
common.go
@ -78,6 +78,7 @@ const (
|
|||||||
// Rest of these are reserved by the TLS spec
|
// Rest of these are reserved by the TLS spec
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ConnectionState records basic TLS details about the connection.
|
||||||
type ConnectionState struct {
|
type ConnectionState struct {
|
||||||
HandshakeComplete bool
|
HandshakeComplete bool
|
||||||
CipherSuite uint16
|
CipherSuite uint16
|
||||||
@ -88,28 +89,65 @@ type ConnectionState struct {
|
|||||||
// has been passed to a TLS function it must not be modified.
|
// has been passed to a TLS function it must not be modified.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
// Rand provides the source of entropy for nonces and RSA blinding.
|
// Rand provides the source of entropy for nonces and RSA blinding.
|
||||||
|
// If Rand is nil, TLS uses the cryptographic random reader in package
|
||||||
|
// crypto/rand.
|
||||||
Rand io.Reader
|
Rand io.Reader
|
||||||
|
|
||||||
// Time returns the current time as the number of seconds since the epoch.
|
// Time returns the current time as the number of seconds since the epoch.
|
||||||
|
// If Time is nil, TLS uses the system time.Seconds.
|
||||||
Time func() int64
|
Time func() int64
|
||||||
// Certificates contains one or more certificate chains.
|
|
||||||
|
// Certificates contains one or more certificate chains
|
||||||
|
// to present to the other side of the connection.
|
||||||
|
// Server configurations must include at least one certificate.
|
||||||
Certificates []Certificate
|
Certificates []Certificate
|
||||||
|
|
||||||
|
// RootCAs defines the set of root certificate authorities
|
||||||
|
// that clients use when verifying server certificates.
|
||||||
|
// If RootCAs is nil, TLS uses the host's root CA set.
|
||||||
RootCAs *CASet
|
RootCAs *CASet
|
||||||
|
|
||||||
// NextProtos is a list of supported, application level protocols.
|
// NextProtos is a list of supported, application level protocols.
|
||||||
// Currently only server-side handling is supported.
|
// Currently only server-side handling is supported.
|
||||||
NextProtos []string
|
NextProtos []string
|
||||||
|
|
||||||
// ServerName is included in the client's handshake to support virtual
|
// ServerName is included in the client's handshake to support virtual
|
||||||
// hosting.
|
// hosting.
|
||||||
ServerName string
|
ServerName string
|
||||||
// AuthenticateClient determines if a server will request a certificate
|
|
||||||
|
// AuthenticateClient controls whether a server will request a certificate
|
||||||
// from the client. It does not require that the client send a
|
// from the client. It does not require that the client send a
|
||||||
// certificate nor, if it does, that the certificate is anything more
|
// certificate nor does it require that the certificate sent be
|
||||||
// than self-signed.
|
// anything more than self-signed.
|
||||||
AuthenticateClient bool
|
AuthenticateClient bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Config) rand() io.Reader {
|
||||||
|
r := c.Rand
|
||||||
|
if r == nil {
|
||||||
|
return rand.Reader
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) time() int64 {
|
||||||
|
t := c.Time
|
||||||
|
if t == nil {
|
||||||
|
t = time.Seconds
|
||||||
|
}
|
||||||
|
return t()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) rootCAs() *CASet {
|
||||||
|
s := c.RootCAs
|
||||||
|
if s == nil {
|
||||||
|
s = defaultRoots()
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// A Certificate is a chain of one or more certificates, leaf first.
|
||||||
type Certificate struct {
|
type Certificate struct {
|
||||||
// Certificate contains a chain of one or more certificates. Leaf
|
|
||||||
// certificate first.
|
|
||||||
Certificate [][]byte
|
Certificate [][]byte
|
||||||
PrivateKey *rsa.PrivateKey
|
PrivateKey *rsa.PrivateKey
|
||||||
}
|
}
|
||||||
@ -143,14 +181,10 @@ func mutualVersion(vers uint16) (uint16, bool) {
|
|||||||
return vers, true
|
return vers, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// The defaultConfig is used in place of a nil *Config in the TLS server and client.
|
var emptyConfig Config
|
||||||
var varDefaultConfig *Config
|
|
||||||
|
|
||||||
var once sync.Once
|
|
||||||
|
|
||||||
func defaultConfig() *Config {
|
func defaultConfig() *Config {
|
||||||
once.Do(initDefaultConfig)
|
return &emptyConfig
|
||||||
return varDefaultConfig
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Possible certificate files; stop after finding one.
|
// Possible certificate files; stop after finding one.
|
||||||
@ -162,7 +196,16 @@ var certFiles = []string{
|
|||||||
"/usr/share/curl/curl-ca-bundle.crt", // OS X
|
"/usr/share/curl/curl-ca-bundle.crt", // OS X
|
||||||
}
|
}
|
||||||
|
|
||||||
func initDefaultConfig() {
|
var once sync.Once
|
||||||
|
|
||||||
|
func defaultRoots() *CASet {
|
||||||
|
once.Do(initDefaultRoots)
|
||||||
|
return varDefaultRoots
|
||||||
|
}
|
||||||
|
|
||||||
|
var varDefaultRoots *CASet
|
||||||
|
|
||||||
|
func initDefaultRoots() {
|
||||||
roots := NewCASet()
|
roots := NewCASet()
|
||||||
for _, file := range certFiles {
|
for _, file := range certFiles {
|
||||||
data, err := ioutil.ReadFile(file)
|
data, err := ioutil.ReadFile(file)
|
||||||
@ -171,10 +214,5 @@ func initDefaultConfig() {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
varDefaultRoots = roots
|
||||||
varDefaultConfig = &Config{
|
|
||||||
Rand: rand.Reader,
|
|
||||||
Time: time.Seconds,
|
|
||||||
RootCAs: roots,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -30,12 +30,12 @@ func (c *Conn) clientHandshake() os.Error {
|
|||||||
serverName: c.config.ServerName,
|
serverName: c.config.ServerName,
|
||||||
}
|
}
|
||||||
|
|
||||||
t := uint32(c.config.Time())
|
t := uint32(c.config.time())
|
||||||
hello.random[0] = byte(t >> 24)
|
hello.random[0] = byte(t >> 24)
|
||||||
hello.random[1] = byte(t >> 16)
|
hello.random[1] = byte(t >> 16)
|
||||||
hello.random[2] = byte(t >> 8)
|
hello.random[2] = byte(t >> 8)
|
||||||
hello.random[3] = byte(t)
|
hello.random[3] = byte(t)
|
||||||
_, err := io.ReadFull(c.config.Rand, hello.random[4:])
|
_, err := io.ReadFull(c.config.rand(), hello.random[4:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.sendAlert(alertInternalError)
|
c.sendAlert(alertInternalError)
|
||||||
return os.ErrorString("short read from Rand")
|
return os.ErrorString("short read from Rand")
|
||||||
@ -217,12 +217,12 @@ func (c *Conn) clientHandshake() os.Error {
|
|||||||
preMasterSecret := make([]byte, 48)
|
preMasterSecret := make([]byte, 48)
|
||||||
preMasterSecret[0] = byte(hello.vers >> 8)
|
preMasterSecret[0] = byte(hello.vers >> 8)
|
||||||
preMasterSecret[1] = byte(hello.vers)
|
preMasterSecret[1] = byte(hello.vers)
|
||||||
_, err = io.ReadFull(c.config.Rand, preMasterSecret[2:])
|
_, err = io.ReadFull(c.config.rand(), preMasterSecret[2:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.sendAlert(alertInternalError)
|
return c.sendAlert(alertInternalError)
|
||||||
}
|
}
|
||||||
|
|
||||||
ckx.ciphertext, err = rsa.EncryptPKCS1v15(c.config.Rand, pub, preMasterSecret)
|
ckx.ciphertext, err = rsa.EncryptPKCS1v15(c.config.rand(), pub, preMasterSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.sendAlert(alertInternalError)
|
return c.sendAlert(alertInternalError)
|
||||||
}
|
}
|
||||||
@ -235,7 +235,7 @@ func (c *Conn) clientHandshake() os.Error {
|
|||||||
var digest [36]byte
|
var digest [36]byte
|
||||||
copy(digest[0:16], finishedHash.serverMD5.Sum())
|
copy(digest[0:16], finishedHash.serverMD5.Sum())
|
||||||
copy(digest[16:36], finishedHash.serverSHA1.Sum())
|
copy(digest[16:36], finishedHash.serverSHA1.Sum())
|
||||||
signed, err := rsa.SignPKCS1v15(c.config.Rand, c.config.Certificates[0].PrivateKey, rsa.HashMD5SHA1, digest[0:])
|
signed, err := rsa.SignPKCS1v15(c.config.rand(), c.config.Certificates[0].PrivateKey, rsa.HashMD5SHA1, digest[0:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.sendAlert(alertInternalError)
|
return c.sendAlert(alertInternalError)
|
||||||
}
|
}
|
||||||
|
@ -84,13 +84,13 @@ func (c *Conn) serverHandshake() os.Error {
|
|||||||
|
|
||||||
hello.vers = vers
|
hello.vers = vers
|
||||||
hello.cipherSuite = suite.id
|
hello.cipherSuite = suite.id
|
||||||
t := uint32(config.Time())
|
t := uint32(config.time())
|
||||||
hello.random = make([]byte, 32)
|
hello.random = make([]byte, 32)
|
||||||
hello.random[0] = byte(t >> 24)
|
hello.random[0] = byte(t >> 24)
|
||||||
hello.random[1] = byte(t >> 16)
|
hello.random[1] = byte(t >> 16)
|
||||||
hello.random[2] = byte(t >> 8)
|
hello.random[2] = byte(t >> 8)
|
||||||
hello.random[3] = byte(t)
|
hello.random[3] = byte(t)
|
||||||
_, err = io.ReadFull(config.Rand, hello.random[4:])
|
_, err = io.ReadFull(config.rand(), hello.random[4:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.sendAlert(alertInternalError)
|
return c.sendAlert(alertInternalError)
|
||||||
}
|
}
|
||||||
@ -209,12 +209,12 @@ func (c *Conn) serverHandshake() os.Error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
preMasterSecret := make([]byte, 48)
|
preMasterSecret := make([]byte, 48)
|
||||||
_, err = io.ReadFull(config.Rand, preMasterSecret[2:])
|
_, err = io.ReadFull(config.rand(), preMasterSecret[2:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.sendAlert(alertInternalError)
|
return c.sendAlert(alertInternalError)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = rsa.DecryptPKCS1v15SessionKey(config.Rand, config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret)
|
err = rsa.DecryptPKCS1v15SessionKey(config.rand(), config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.sendAlert(alertHandshakeFailure)
|
return c.sendAlert(alertHandshakeFailure)
|
||||||
}
|
}
|
||||||
|
46
tls.go
46
tls.go
@ -15,19 +15,31 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Server returns a new TLS server side connection
|
||||||
|
// using conn as the underlying transport.
|
||||||
|
// The configuration config must be non-nil and must have
|
||||||
|
// at least one certificate.
|
||||||
func Server(conn net.Conn, config *Config) *Conn {
|
func Server(conn net.Conn, config *Config) *Conn {
|
||||||
return &Conn{conn: conn, config: config}
|
return &Conn{conn: conn, config: config}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Client returns a new TLS client side connection
|
||||||
|
// using conn as the underlying transport.
|
||||||
|
// Client interprets a nil configuration as equivalent to
|
||||||
|
// the zero configuration; see the documentation of Config
|
||||||
|
// for the defaults.
|
||||||
func Client(conn net.Conn, config *Config) *Conn {
|
func Client(conn net.Conn, config *Config) *Conn {
|
||||||
return &Conn{conn: conn, config: config, isClient: true}
|
return &Conn{conn: conn, config: config, isClient: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A Listener implements a network listener (net.Listener) for TLS connections.
|
||||||
type Listener struct {
|
type Listener struct {
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
config *Config
|
config *Config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Accept waits for and returns the next incoming TLS connection.
|
||||||
|
// The returned connection c is a *tls.Conn.
|
||||||
func (l *Listener) Accept() (c net.Conn, err os.Error) {
|
func (l *Listener) Accept() (c net.Conn, err os.Error) {
|
||||||
c, err = l.listener.Accept()
|
c, err = l.listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -37,8 +49,10 @@ func (l *Listener) Accept() (c net.Conn, err os.Error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the listener.
|
||||||
func (l *Listener) Close() os.Error { return l.listener.Close() }
|
func (l *Listener) Close() os.Error { return l.listener.Close() }
|
||||||
|
|
||||||
|
// Addr returns the listener's network address.
|
||||||
func (l *Listener) Addr() net.Addr { return l.listener.Addr() }
|
func (l *Listener) Addr() net.Addr { return l.listener.Addr() }
|
||||||
|
|
||||||
// NewListener creates a Listener which accepts connections from an inner
|
// NewListener creates a Listener which accepts connections from an inner
|
||||||
@ -52,7 +66,11 @@ func NewListener(listener net.Listener, config *Config) (l *Listener) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func Listen(network, laddr string, config *Config) (net.Listener, os.Error) {
|
// Listen creates a TLS listener accepting connections on the
|
||||||
|
// given network address using net.Listen.
|
||||||
|
// The configuration config must be non-nil and must have
|
||||||
|
// at least one certificate.
|
||||||
|
func Listen(network, laddr string, config *Config) (*Listener, os.Error) {
|
||||||
if config == nil || len(config.Certificates) == 0 {
|
if config == nil || len(config.Certificates) == 0 {
|
||||||
return nil, os.NewError("tls.Listen: no certificates in configuration")
|
return nil, os.NewError("tls.Listen: no certificates in configuration")
|
||||||
}
|
}
|
||||||
@ -63,7 +81,13 @@ func Listen(network, laddr string, config *Config) (net.Listener, os.Error) {
|
|||||||
return NewListener(l, config), nil
|
return NewListener(l, config), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Dial(network, laddr, raddr string) (net.Conn, os.Error) {
|
// Dial connects to the given network address using net.Dial
|
||||||
|
// and then initiates a TLS handshake, returning the resulting
|
||||||
|
// TLS connection.
|
||||||
|
// Dial interprets a nil configuration as equivalent to
|
||||||
|
// the zero configuration; see the documentation of Config
|
||||||
|
// for the defaults.
|
||||||
|
func Dial(network, laddr, raddr string, config *Config) (*Conn, os.Error) {
|
||||||
c, err := net.Dial(network, laddr, raddr)
|
c, err := net.Dial(network, laddr, raddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -75,16 +99,22 @@ func Dial(network, laddr, raddr string) (net.Conn, os.Error) {
|
|||||||
}
|
}
|
||||||
hostname := raddr[:colonPos]
|
hostname := raddr[:colonPos]
|
||||||
|
|
||||||
config := defaultConfig()
|
if config == nil {
|
||||||
config.ServerName = hostname
|
config = defaultConfig()
|
||||||
conn := Client(c, config)
|
|
||||||
err = conn.Handshake()
|
|
||||||
if err == nil {
|
|
||||||
return conn, nil
|
|
||||||
}
|
}
|
||||||
|
if config.ServerName != "" {
|
||||||
|
// Make a copy to avoid polluting argument or default.
|
||||||
|
c := *config
|
||||||
|
c.ServerName = hostname
|
||||||
|
config = &c
|
||||||
|
}
|
||||||
|
conn := Client(c, config)
|
||||||
|
if err = conn.Handshake(); err != nil {
|
||||||
c.Close()
|
c.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
// LoadX509KeyPair reads and parses a public/private key pair from a pair of
|
// LoadX509KeyPair reads and parses a public/private key pair from a pair of
|
||||||
// files. The files must contain PEM encoded data.
|
// files. The files must contain PEM encoded data.
|
||||||
|
Loading…
Reference in New Issue
Block a user