diff --git a/common.go b/common.go index d47dc61..c68ebfe 100644 --- a/common.go +++ b/common.go @@ -286,7 +286,8 @@ type Config struct { // ServerName is used to verify the hostname on the returned // certificates unless InsecureSkipVerify is given. It is also included - // in the client's handshake to support virtual hosting. + // in the client's handshake to support virtual hosting unless it is + // an IP address. ServerName string // ClientAuth determines the server's policy for diff --git a/handshake_client.go b/handshake_client.go index 0b591d7..462acfd 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -49,13 +49,20 @@ func (c *Conn) clientHandshake() error { return errors.New("tls: NextProtos values too large") } + sni := c.config.ServerName + // IP address literals are not permitted as SNI values. See + // https://tools.ietf.org/html/rfc6066#section-3. + if net.ParseIP(sni) != nil { + sni = "" + } + hello := &clientHelloMsg{ vers: c.config.maxVersion(), compressionMethods: []uint8{compressionNone}, random: make([]byte, 32), ocspStapling: true, scts: true, - serverName: c.config.ServerName, + serverName: sni, supportedCurves: c.config.curvePreferences(), supportedPoints: []uint8{pointFormatUncompressed}, nextProtoNeg: len(c.config.NextProtos) > 0, diff --git a/handshake_client_test.go b/handshake_client_test.go index 664fe8d..b275da1 100644 --- a/handshake_client_test.go +++ b/handshake_client_test.go @@ -600,3 +600,30 @@ func TestHandshakClientSCTs(t *testing.T) { } runClientTestTLS12(t, test) } + +func TestNoIPAddressesInSNI(t *testing.T) { + for _, ipLiteral := range []string{"1.2.3.4", "::1"} { + c, s := net.Pipe() + + go func() { + client := Client(c, &Config{ServerName: ipLiteral}) + client.Handshake() + }() + + var header [5]byte + if _, err := io.ReadFull(s, header[:]); err != nil { + t.Fatal(err) + } + recordLen := int(header[3])<<8 | int(header[4]) + + record := make([]byte, recordLen) + if _, err := io.ReadFull(s, record[:]); err != nil { + t.Fatal(err) + } + s.Close() + + if bytes.Index(record, []byte(ipLiteral)) != -1 { + t.Errorf("IP literal %q found in ClientHello: %x", ipLiteral, record) + } + } +}