diff --git a/common.go b/common.go index 81b5a07..c779234 100644 --- a/common.go +++ b/common.go @@ -93,9 +93,10 @@ const ( // ConnectionState records basic TLS details about the connection. type ConnectionState struct { - HandshakeComplete bool - CipherSuite uint16 - NegotiatedProtocol string + HandshakeComplete bool + CipherSuite uint16 + NegotiatedProtocol string + NegotiatedProtocolIsMutual bool // the certificate chain that was presented by the other side PeerCertificates []*x509.Certificate @@ -124,7 +125,6 @@ type Config struct { RootCAs *CASet // NextProtos is a list of supported, application level protocols. - // Currently only server-side handling is supported. NextProtos []string // ServerName is included in the client's handshake to support virtual diff --git a/conn.go b/conn.go index 1e6fe60..b94e235 100644 --- a/conn.go +++ b/conn.go @@ -35,7 +35,8 @@ type Conn struct { ocspResponse []byte // stapled OCSP response peerCertificates []*x509.Certificate - clientProtocol string + clientProtocol string + clientProtocolFallback bool // first permanent error errMutex sync.Mutex @@ -761,6 +762,7 @@ func (c *Conn) ConnectionState() ConnectionState { state.HandshakeComplete = c.handshakeComplete if c.handshakeComplete { state.NegotiatedProtocol = c.clientProtocol + state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback state.CipherSuite = c.cipherSuite state.PeerCertificates = c.peerCertificates } diff --git a/handshake_client.go b/handshake_client.go index a325a9b..540b25c 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -29,6 +29,7 @@ func (c *Conn) clientHandshake() os.Error { serverName: c.config.ServerName, supportedCurves: []uint16{curveP256, curveP384, curveP521}, supportedPoints: []uint8{pointFormatUncompressed}, + nextProtoNeg: len(c.config.NextProtos) > 0, } t := uint32(c.config.time()) @@ -66,6 +67,11 @@ func (c *Conn) clientHandshake() os.Error { return c.sendAlert(alertUnexpectedMessage) } + if !hello.nextProtoNeg && serverHello.nextProtoNeg { + c.sendAlert(alertHandshakeFailure) + return os.ErrorString("server advertised unrequested NPN") + } + suite, suiteId := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite) if suite == nil { return c.sendAlert(alertHandshakeFailure) @@ -267,6 +273,17 @@ func (c *Conn) clientHandshake() os.Error { c.out.prepareCipherSpec(clientCipher, clientHash) c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + if serverHello.nextProtoNeg { + nextProto := new(nextProtoMsg) + proto, fallback := mutualProtocol(c.config.NextProtos, serverHello.nextProtos) + nextProto.proto = proto + c.clientProtocol = proto + c.clientProtocolFallback = fallback + + finishedHash.Write(nextProto.marshal()) + c.writeRecord(recordTypeHandshake, nextProto.marshal()) + } + finished := new(finishedMsg) finished.verifyData = finishedHash.clientSum(masterSecret) finishedHash.Write(finished.marshal()) @@ -299,3 +316,19 @@ func (c *Conn) clientHandshake() os.Error { c.cipherSuite = suiteId return nil } + +// mutualProtocol finds the mutual Next Protocol Negotiation protocol given the +// set of client and server supported protocols. The set of client supported +// protocols must not be empty. It returns the resulting protocol and flag +// indicating if the fallback case was reached. +func mutualProtocol(clientProtos, serverProtos []string) (string, bool) { + for _, s := range serverProtos { + for _, c := range clientProtos { + if s == c { + return s, false + } + } + } + + return clientProtos[0], true +}