Alternative TLS implementation in Go
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

1584 lignes
43 KiB

  1. // Copyright 2010 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package tls
  5. import (
  6. "bytes"
  7. "crypto/ecdsa"
  8. "crypto/rsa"
  9. "crypto/x509"
  10. "encoding/base64"
  11. "encoding/binary"
  12. "encoding/pem"
  13. "errors"
  14. "fmt"
  15. "io"
  16. "math/big"
  17. "net"
  18. "os"
  19. "os/exec"
  20. "path/filepath"
  21. "strconv"
  22. "strings"
  23. "sync"
  24. "testing"
  25. "time"
  26. )
  27. // Note: see comment in handshake_test.go for details of how the reference
  28. // tests work.
  29. // opensslInputEvent enumerates possible inputs that can be sent to an `openssl
  30. // s_client` process.
  31. type opensslInputEvent int
  32. const (
  33. // opensslRenegotiate causes OpenSSL to request a renegotiation of the
  34. // connection.
  35. opensslRenegotiate opensslInputEvent = iota
  36. // opensslSendBanner causes OpenSSL to send the contents of
  37. // opensslSentinel on the connection.
  38. opensslSendSentinel
  39. )
  40. const opensslSentinel = "SENTINEL\n"
  41. type opensslInput chan opensslInputEvent
  42. func (i opensslInput) Read(buf []byte) (n int, err error) {
  43. for event := range i {
  44. switch event {
  45. case opensslRenegotiate:
  46. return copy(buf, []byte("R\n")), nil
  47. case opensslSendSentinel:
  48. return copy(buf, []byte(opensslSentinel)), nil
  49. default:
  50. panic("unknown event")
  51. }
  52. }
  53. return 0, io.EOF
  54. }
  55. // opensslOutputSink is an io.Writer that receives the stdout and stderr from
  56. // an `openssl` process and sends a value to handshakeConfirmed when it sees a
  57. // log message from a completed server handshake.
  58. type opensslOutputSink struct {
  59. handshakeConfirmed chan struct{}
  60. all []byte
  61. line []byte
  62. }
  63. func newOpensslOutputSink() *opensslOutputSink {
  64. return &opensslOutputSink{make(chan struct{}), nil, nil}
  65. }
  66. // opensslEndOfHandshake is a message that the “openssl s_server” tool will
  67. // print when a handshake completes if run with “-state”.
  68. const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished"
  69. func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
  70. o.line = append(o.line, data...)
  71. o.all = append(o.all, data...)
  72. for {
  73. i := bytes.IndexByte(o.line, '\n')
  74. if i < 0 {
  75. break
  76. }
  77. if bytes.Equal([]byte(opensslEndOfHandshake), o.line[:i]) {
  78. o.handshakeConfirmed <- struct{}{}
  79. }
  80. o.line = o.line[i+1:]
  81. }
  82. return len(data), nil
  83. }
  84. func (o *opensslOutputSink) WriteTo(w io.Writer) (int64, error) {
  85. n, err := w.Write(o.all)
  86. return int64(n), err
  87. }
// clientTest represents a test of the TLS client handshake against a reference
// implementation.
type clientTest struct {
	// name is a freeform string identifying the test and the file in which
	// the expected results will be stored (see dataPath).
	name string
	// command, if not empty, contains a series of arguments for the
	// command to run for the reference server. When empty,
	// connFromCommand falls back to defaultServerCommand.
	command []string
	// config, if not nil, contains a custom Config to use for this test.
	// When nil, run uses testConfig.
	config *Config
	// cert, if not empty, contains a DER-encoded certificate for the
	// reference server.
	cert []byte
	// key, if not nil, contains either a *rsa.PrivateKey or
	// *ecdsa.PrivateKey which is the private key for the reference server.
	// Any other dynamic type makes connFromCommand panic.
	key interface{}
	// extensions, if not nil, contains a list of extension data to be returned
	// from the ServerHello. The data should be in standard TLS format with
	// a 2-byte uint16 type, 2-byte data length, followed by the extension data.
	extensions [][]byte
	// validate, if not nil, is a function that will be called with the
	// ConnectionState of the resulting connection. It returns a non-nil
	// error if the ConnectionState is unacceptable.
	validate func(ConnectionState) error
	// numRenegotiations is the number of times that the connection will be
	// renegotiated. A non-zero value requires "-state" in the command
	// used when recording.
	numRenegotiations int
	// renegotiationExpectedToFail, if not zero, is the number of the
	// renegotiation attempt that is expected to fail.
	renegotiationExpectedToFail int
	// checkRenegotiationError, if not nil, is called with any error
	// arising from renegotiation. It can map expected errors to nil to
	// ignore them.
	checkRenegotiationError func(renegotiationNum int, err error) error
}
  124. var defaultServerCommand = []string{"openssl", "s_server"}
  125. // connFromCommand starts the reference server process, connects to it and
  126. // returns a recordingConn for the connection. The stdin return value is an
  127. // opensslInput for the stdin of the child process. It must be closed before
  128. // Waiting for child.
  129. func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) {
  130. cert := testRSACertificate
  131. if len(test.cert) > 0 {
  132. cert = test.cert
  133. }
  134. certPath := tempFile(string(cert))
  135. defer os.Remove(certPath)
  136. var key interface{} = testRSAPrivateKey
  137. if test.key != nil {
  138. key = test.key
  139. }
  140. var pemType string
  141. var derBytes []byte
  142. switch key := key.(type) {
  143. case *rsa.PrivateKey:
  144. pemType = "RSA"
  145. derBytes = x509.MarshalPKCS1PrivateKey(key)
  146. case *ecdsa.PrivateKey:
  147. pemType = "EC"
  148. var err error
  149. derBytes, err = x509.MarshalECPrivateKey(key)
  150. if err != nil {
  151. panic(err)
  152. }
  153. default:
  154. panic("unknown key type")
  155. }
  156. var pemOut bytes.Buffer
  157. pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes})
  158. keyPath := tempFile(string(pemOut.Bytes()))
  159. defer os.Remove(keyPath)
  160. var command []string
  161. if len(test.command) > 0 {
  162. command = append(command, test.command...)
  163. } else {
  164. command = append(command, defaultServerCommand...)
  165. }
  166. command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
  167. // serverPort contains the port that OpenSSL will listen on. OpenSSL
  168. // can't take "0" as an argument here so we have to pick a number and
  169. // hope that it's not in use on the machine. Since this only occurs
  170. // when -update is given and thus when there's a human watching the
  171. // test, this isn't too bad.
  172. const serverPort = 24323
  173. command = append(command, "-accept", strconv.Itoa(serverPort))
  174. if len(test.extensions) > 0 {
  175. var serverInfo bytes.Buffer
  176. for _, ext := range test.extensions {
  177. pem.Encode(&serverInfo, &pem.Block{
  178. Type: fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
  179. Bytes: ext,
  180. })
  181. }
  182. serverInfoPath := tempFile(serverInfo.String())
  183. defer os.Remove(serverInfoPath)
  184. command = append(command, "-serverinfo", serverInfoPath)
  185. }
  186. if test.numRenegotiations > 0 {
  187. found := false
  188. for _, flag := range command[1:] {
  189. if flag == "-state" {
  190. found = true
  191. break
  192. }
  193. }
  194. if !found {
  195. panic("-state flag missing to OpenSSL. You need this if testing renegotiation")
  196. }
  197. }
  198. cmd := exec.Command(command[0], command[1:]...)
  199. stdin = opensslInput(make(chan opensslInputEvent))
  200. cmd.Stdin = stdin
  201. out := newOpensslOutputSink()
  202. cmd.Stdout = out
  203. cmd.Stderr = out
  204. if err := cmd.Start(); err != nil {
  205. return nil, nil, nil, nil, err
  206. }
  207. // OpenSSL does print an "ACCEPT" banner, but it does so *before*
  208. // opening the listening socket, so we can't use that to wait until it
  209. // has started listening. Thus we are forced to poll until we get a
  210. // connection.
  211. var tcpConn net.Conn
  212. for i := uint(0); i < 5; i++ {
  213. tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
  214. IP: net.IPv4(127, 0, 0, 1),
  215. Port: serverPort,
  216. })
  217. if err == nil {
  218. break
  219. }
  220. time.Sleep((1 << i) * 5 * time.Millisecond)
  221. }
  222. if err != nil {
  223. close(stdin)
  224. out.WriteTo(os.Stdout)
  225. cmd.Process.Kill()
  226. return nil, nil, nil, nil, cmd.Wait()
  227. }
  228. record := &recordingConn{
  229. Conn: tcpConn,
  230. }
  231. return record, cmd, stdin, out, nil
  232. }
  233. func (test *clientTest) dataPath() string {
  234. return filepath.Join("testdata", "Client-"+test.name)
  235. }
  236. func (test *clientTest) loadData() (flows [][]byte, err error) {
  237. in, err := os.Open(test.dataPath())
  238. if err != nil {
  239. return nil, err
  240. }
  241. defer in.Close()
  242. return parseTestData(in)
  243. }
  244. func (test *clientTest) run(t *testing.T, write bool) {
  245. checkOpenSSLVersion(t)
  246. var clientConn, serverConn net.Conn
  247. var recordingConn *recordingConn
  248. var childProcess *exec.Cmd
  249. var stdin opensslInput
  250. var stdout *opensslOutputSink
  251. if write {
  252. var err error
  253. recordingConn, childProcess, stdin, stdout, err = test.connFromCommand()
  254. if err != nil {
  255. t.Fatalf("Failed to start subcommand: %s", err)
  256. }
  257. clientConn = recordingConn
  258. } else {
  259. clientConn, serverConn = net.Pipe()
  260. }
  261. config := test.config
  262. if config == nil {
  263. config = testConfig
  264. }
  265. client := Client(clientConn, config)
  266. doneChan := make(chan bool)
  267. go func() {
  268. defer func() { doneChan <- true }()
  269. defer clientConn.Close()
  270. defer client.Close()
  271. if _, err := client.Write([]byte("hello\n")); err != nil {
  272. t.Errorf("Client.Write failed: %s", err)
  273. return
  274. }
  275. for i := 1; i <= test.numRenegotiations; i++ {
  276. // The initial handshake will generate a
  277. // handshakeConfirmed signal which needs to be quashed.
  278. if i == 1 && write {
  279. <-stdout.handshakeConfirmed
  280. }
  281. // OpenSSL will try to interleave application data and
  282. // a renegotiation if we send both concurrently.
  283. // Therefore: ask OpensSSL to start a renegotiation, run
  284. // a goroutine to call client.Read and thus process the
  285. // renegotiation request, watch for OpenSSL's stdout to
  286. // indicate that the handshake is complete and,
  287. // finally, have OpenSSL write something to cause
  288. // client.Read to complete.
  289. if write {
  290. stdin <- opensslRenegotiate
  291. }
  292. signalChan := make(chan struct{})
  293. go func() {
  294. defer func() { signalChan <- struct{}{} }()
  295. buf := make([]byte, 256)
  296. n, err := client.Read(buf)
  297. if test.checkRenegotiationError != nil {
  298. newErr := test.checkRenegotiationError(i, err)
  299. if err != nil && newErr == nil {
  300. return
  301. }
  302. err = newErr
  303. }
  304. if err != nil {
  305. t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
  306. return
  307. }
  308. buf = buf[:n]
  309. if !bytes.Equal([]byte(opensslSentinel), buf) {
  310. t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
  311. }
  312. if expected := i + 1; client.handshakes != expected {
  313. t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
  314. }
  315. }()
  316. if write && test.renegotiationExpectedToFail != i {
  317. <-stdout.handshakeConfirmed
  318. stdin <- opensslSendSentinel
  319. }
  320. <-signalChan
  321. }
  322. if test.validate != nil {
  323. if err := test.validate(client.ConnectionState()); err != nil {
  324. t.Errorf("validate callback returned error: %s", err)
  325. }
  326. }
  327. }()
  328. if !write {
  329. flows, err := test.loadData()
  330. if err != nil {
  331. t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
  332. }
  333. for i, b := range flows {
  334. if i%2 == 1 {
  335. serverConn.Write(b)
  336. continue
  337. }
  338. bb := make([]byte, len(b))
  339. _, err := io.ReadFull(serverConn, bb)
  340. if err != nil {
  341. t.Fatalf("%s #%d: %s", test.name, i, err)
  342. }
  343. if !bytes.Equal(b, bb) {
  344. t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i, bb, b)
  345. }
  346. }
  347. serverConn.Close()
  348. }
  349. <-doneChan
  350. if write {
  351. path := test.dataPath()
  352. out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
  353. if err != nil {
  354. t.Fatalf("Failed to create output file: %s", err)
  355. }
  356. defer out.Close()
  357. recordingConn.Close()
  358. close(stdin)
  359. childProcess.Process.Kill()
  360. childProcess.Wait()
  361. if len(recordingConn.flows) < 3 {
  362. os.Stdout.Write(childProcess.Stdout.(*opensslOutputSink).all)
  363. t.Fatalf("Client connection didn't work")
  364. }
  365. recordingConn.WriteTo(out)
  366. fmt.Printf("Wrote %s\n", path)
  367. }
  368. }
  369. var (
  370. didParMu sync.Mutex
  371. didPar = map[*testing.T]bool{}
  372. )
  373. // setParallel calls t.Parallel once. If you call it twice, it would
  374. // panic.
  375. func setParallel(t *testing.T) {
  376. didParMu.Lock()
  377. v := didPar[t]
  378. didPar[t] = true
  379. didParMu.Unlock()
  380. if !v {
  381. t.Parallel()
  382. }
  383. }
  384. func runClientTestForVersion(t *testing.T, template *clientTest, prefix, option string) {
  385. setParallel(t)
  386. test := *template
  387. test.name = prefix + test.name
  388. if len(test.command) == 0 {
  389. test.command = defaultClientCommand
  390. }
  391. test.command = append([]string(nil), test.command...)
  392. test.command = append(test.command, option)
  393. test.run(t, *update)
  394. }
  395. func runClientTestTLS10(t *testing.T, template *clientTest) {
  396. runClientTestForVersion(t, template, "TLSv10-", "-tls1")
  397. }
  398. func runClientTestTLS11(t *testing.T, template *clientTest) {
  399. runClientTestForVersion(t, template, "TLSv11-", "-tls1_1")
  400. }
  401. func runClientTestTLS12(t *testing.T, template *clientTest) {
  402. runClientTestForVersion(t, template, "TLSv12-", "-tls1_2")
  403. }
  404. func TestHandshakeClientRSARC4(t *testing.T) {
  405. test := &clientTest{
  406. name: "RSA-RC4",
  407. command: []string{"openssl", "s_server", "-cipher", "RC4-SHA"},
  408. }
  409. runClientTestTLS10(t, test)
  410. runClientTestTLS11(t, test)
  411. runClientTestTLS12(t, test)
  412. }
  413. func TestHandshakeClientRSAAES128GCM(t *testing.T) {
  414. test := &clientTest{
  415. name: "AES128-GCM-SHA256",
  416. command: []string{"openssl", "s_server", "-cipher", "AES128-GCM-SHA256"},
  417. }
  418. runClientTestTLS12(t, test)
  419. }
  420. func TestHandshakeClientRSAAES256GCM(t *testing.T) {
  421. test := &clientTest{
  422. name: "AES256-GCM-SHA384",
  423. command: []string{"openssl", "s_server", "-cipher", "AES256-GCM-SHA384"},
  424. }
  425. runClientTestTLS12(t, test)
  426. }
  427. func TestHandshakeClientECDHERSAAES(t *testing.T) {
  428. test := &clientTest{
  429. name: "ECDHE-RSA-AES",
  430. command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA"},
  431. }
  432. runClientTestTLS10(t, test)
  433. runClientTestTLS11(t, test)
  434. runClientTestTLS12(t, test)
  435. }
  436. func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
  437. test := &clientTest{
  438. name: "ECDHE-ECDSA-AES",
  439. command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA"},
  440. cert: testECDSACertificate,
  441. key: testECDSAPrivateKey,
  442. }
  443. runClientTestTLS10(t, test)
  444. runClientTestTLS11(t, test)
  445. runClientTestTLS12(t, test)
  446. }
  447. func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
  448. test := &clientTest{
  449. name: "ECDHE-ECDSA-AES-GCM",
  450. command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
  451. cert: testECDSACertificate,
  452. key: testECDSAPrivateKey,
  453. }
  454. runClientTestTLS12(t, test)
  455. }
  456. func TestHandshakeClientAES256GCMSHA384(t *testing.T) {
  457. test := &clientTest{
  458. name: "ECDHE-ECDSA-AES256-GCM-SHA384",
  459. command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"},
  460. cert: testECDSACertificate,
  461. key: testECDSAPrivateKey,
  462. }
  463. runClientTestTLS12(t, test)
  464. }
  465. func TestHandshakeClientAES128CBCSHA256(t *testing.T) {
  466. test := &clientTest{
  467. name: "AES128-SHA256",
  468. command: []string{"openssl", "s_server", "-cipher", "AES128-SHA256"},
  469. }
  470. runClientTestTLS12(t, test)
  471. }
  472. func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) {
  473. test := &clientTest{
  474. name: "ECDHE-RSA-AES128-SHA256",
  475. command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA256"},
  476. }
  477. runClientTestTLS12(t, test)
  478. }
  479. func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) {
  480. test := &clientTest{
  481. name: "ECDHE-ECDSA-AES128-SHA256",
  482. command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA256"},
  483. cert: testECDSACertificate,
  484. key: testECDSAPrivateKey,
  485. }
  486. runClientTestTLS12(t, test)
  487. }
  488. func TestHandshakeClientX25519(t *testing.T) {
  489. config := testConfig.Clone()
  490. config.CurvePreferences = []CurveID{X25519}
  491. test := &clientTest{
  492. name: "X25519-ECDHE-RSA-AES-GCM",
  493. command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-GCM-SHA256"},
  494. config: config,
  495. }
  496. runClientTestTLS12(t, test)
  497. }
  498. func TestHandshakeClientECDHERSAChaCha20(t *testing.T) {
  499. config := testConfig.Clone()
  500. config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}
  501. test := &clientTest{
  502. name: "ECDHE-RSA-CHACHA20-POLY1305",
  503. command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-CHACHA20-POLY1305"},
  504. config: config,
  505. }
  506. runClientTestTLS12(t, test)
  507. }
  508. func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) {
  509. config := testConfig.Clone()
  510. config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305}
  511. test := &clientTest{
  512. name: "ECDHE-ECDSA-CHACHA20-POLY1305",
  513. command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"},
  514. config: config,
  515. cert: testECDSACertificate,
  516. key: testECDSAPrivateKey,
  517. }
  518. runClientTestTLS12(t, test)
  519. }
  520. func TestHandshakeClientCertRSA(t *testing.T) {
  521. config := testConfig.Clone()
  522. cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
  523. config.Certificates = []Certificate{cert}
  524. test := &clientTest{
  525. name: "ClientCert-RSA-RSA",
  526. command: []string{"openssl", "s_server", "-cipher", "AES128", "-verify", "1"},
  527. config: config,
  528. }
  529. runClientTestTLS10(t, test)
  530. runClientTestTLS12(t, test)
  531. test = &clientTest{
  532. name: "ClientCert-RSA-ECDSA",
  533. command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
  534. config: config,
  535. cert: testECDSACertificate,
  536. key: testECDSAPrivateKey,
  537. }
  538. runClientTestTLS10(t, test)
  539. runClientTestTLS12(t, test)
  540. test = &clientTest{
  541. name: "ClientCert-RSA-AES256-GCM-SHA384",
  542. command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-verify", "1"},
  543. config: config,
  544. cert: testRSACertificate,
  545. key: testRSAPrivateKey,
  546. }
  547. runClientTestTLS12(t, test)
  548. }
  549. func TestHandshakeClientCertECDSA(t *testing.T) {
  550. config := testConfig.Clone()
  551. cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
  552. config.Certificates = []Certificate{cert}
  553. test := &clientTest{
  554. name: "ClientCert-ECDSA-RSA",
  555. command: []string{"openssl", "s_server", "-cipher", "AES128", "-verify", "1"},
  556. config: config,
  557. }
  558. runClientTestTLS10(t, test)
  559. runClientTestTLS12(t, test)
  560. test = &clientTest{
  561. name: "ClientCert-ECDSA-ECDSA",
  562. command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
  563. config: config,
  564. cert: testECDSACertificate,
  565. key: testECDSAPrivateKey,
  566. }
  567. runClientTestTLS10(t, test)
  568. runClientTestTLS12(t, test)
  569. }
  570. // This test is specific to TLS versions which support session tickets (TLSv1.2 and below).
  571. // Session tickets are obsolete in TLSv1.3 (see 2.2 of TLS RFC)
  572. func TestClientResumption(t *testing.T) {
  573. serverConfig := &Config{
  574. CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
  575. Certificates: testConfig.Certificates,
  576. }
  577. issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  578. if err != nil {
  579. panic(err)
  580. }
  581. rootCAs := x509.NewCertPool()
  582. rootCAs.AddCert(issuer)
  583. clientConfig := &Config{
  584. CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
  585. ClientSessionCache: NewLRUClientSessionCache(32),
  586. RootCAs: rootCAs,
  587. ServerName: "example.golang",
  588. MaxVersion: VersionTLS12, // Enforce TLSv1.2
  589. }
  590. testResumeState := func(test string, didResume bool) {
  591. _, hs, err := testHandshake(clientConfig, serverConfig)
  592. if err != nil {
  593. t.Fatalf("%s: handshake failed: %s", test, err)
  594. }
  595. if hs.DidResume != didResume {
  596. t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
  597. }
  598. if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
  599. t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
  600. }
  601. }
  602. getTicket := func() []byte {
  603. return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket
  604. }
  605. randomKey := func() [32]byte {
  606. var k [32]byte
  607. if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil {
  608. t.Fatalf("Failed to read new SessionTicketKey: %s", err)
  609. }
  610. return k
  611. }
  612. testResumeState("Handshake", false)
  613. ticket := getTicket()
  614. testResumeState("Resume", true)
  615. if !bytes.Equal(ticket, getTicket()) {
  616. t.Fatal("first ticket doesn't match ticket after resumption")
  617. }
  618. key1 := randomKey()
  619. serverConfig.SetSessionTicketKeys([][32]byte{key1})
  620. testResumeState("InvalidSessionTicketKey", false)
  621. testResumeState("ResumeAfterInvalidSessionTicketKey", true)
  622. key2 := randomKey()
  623. serverConfig.SetSessionTicketKeys([][32]byte{key2, key1})
  624. ticket = getTicket()
  625. testResumeState("KeyChange", true)
  626. if bytes.Equal(ticket, getTicket()) {
  627. t.Fatal("new ticket wasn't included while resuming")
  628. }
  629. testResumeState("KeyChangeFinish", true)
  630. // Reset serverConfig to ensure that calling SetSessionTicketKeys
  631. // before the serverConfig is used works.
  632. serverConfig = &Config{
  633. CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
  634. Certificates: testConfig.Certificates,
  635. }
  636. serverConfig.SetSessionTicketKeys([][32]byte{key2})
  637. testResumeState("FreshConfig", true)
  638. clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
  639. testResumeState("DifferentCipherSuite", false)
  640. testResumeState("DifferentCipherSuiteRecovers", true)
  641. clientConfig.ClientSessionCache = nil
  642. testResumeState("WithoutSessionCache", false)
  643. }
  644. func TestLRUClientSessionCache(t *testing.T) {
  645. // Initialize cache of capacity 4.
  646. cache := NewLRUClientSessionCache(4)
  647. cs := make([]ClientSessionState, 6)
  648. keys := []string{"0", "1", "2", "3", "4", "5", "6"}
  649. // Add 4 entries to the cache and look them up.
  650. for i := 0; i < 4; i++ {
  651. cache.Put(keys[i], &cs[i])
  652. }
  653. for i := 0; i < 4; i++ {
  654. if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
  655. t.Fatalf("session cache failed lookup for added key: %s", keys[i])
  656. }
  657. }
  658. // Add 2 more entries to the cache. First 2 should be evicted.
  659. for i := 4; i < 6; i++ {
  660. cache.Put(keys[i], &cs[i])
  661. }
  662. for i := 0; i < 2; i++ {
  663. if s, ok := cache.Get(keys[i]); ok || s != nil {
  664. t.Fatalf("session cache should have evicted key: %s", keys[i])
  665. }
  666. }
  667. // Touch entry 2. LRU should evict 3 next.
  668. cache.Get(keys[2])
  669. cache.Put(keys[0], &cs[0])
  670. if s, ok := cache.Get(keys[3]); ok || s != nil {
  671. t.Fatalf("session cache should have evicted key 3")
  672. }
  673. // Update entry 0 in place.
  674. cache.Put(keys[0], &cs[3])
  675. if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
  676. t.Fatalf("session cache failed update for key 0")
  677. }
  678. // Adding a nil entry is valid.
  679. cache.Put(keys[0], nil)
  680. if s, ok := cache.Get(keys[0]); !ok || s != nil {
  681. t.Fatalf("failed to add nil entry to cache")
  682. }
  683. }
  684. func TestKeyLog(t *testing.T) {
  685. var serverBuf, clientBuf bytes.Buffer
  686. clientConfig := testConfig.Clone()
  687. clientConfig.KeyLogWriter = &clientBuf
  688. serverConfig := testConfig.Clone()
  689. serverConfig.KeyLogWriter = &serverBuf
  690. c, s := net.Pipe()
  691. done := make(chan bool)
  692. go func() {
  693. defer close(done)
  694. if err := Server(s, serverConfig).Handshake(); err != nil {
  695. t.Errorf("server: %s", err)
  696. return
  697. }
  698. s.Close()
  699. }()
  700. if err := Client(c, clientConfig).Handshake(); err != nil {
  701. t.Fatalf("client: %s", err)
  702. }
  703. c.Close()
  704. <-done
  705. checkKeylogLine := func(side, loggedLine string) {
  706. if len(loggedLine) == 0 {
  707. t.Fatalf("%s: no keylog line was produced", side)
  708. }
  709. const expectedLen = 13 /* "CLIENT_RANDOM" */ +
  710. 1 /* space */ +
  711. 32*2 /* hex client nonce */ +
  712. 1 /* space */ +
  713. 48*2 /* hex master secret */ +
  714. 1 /* new line */
  715. if len(loggedLine) != expectedLen {
  716. t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine)
  717. }
  718. if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") {
  719. t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine)
  720. }
  721. }
  722. checkKeylogLine("client", string(clientBuf.Bytes()))
  723. checkKeylogLine("server", string(serverBuf.Bytes()))
  724. }
  725. func TestHandshakeClientALPNMatch(t *testing.T) {
  726. config := testConfig.Clone()
  727. config.NextProtos = []string{"proto2", "proto1"}
  728. test := &clientTest{
  729. name: "ALPN",
  730. // Note that this needs OpenSSL 1.0.2 because that is the first
  731. // version that supports the -alpn flag.
  732. command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"},
  733. config: config,
  734. validate: func(state ConnectionState) error {
  735. // The server's preferences should override the client.
  736. if state.NegotiatedProtocol != "proto1" {
  737. return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol)
  738. }
  739. return nil
  740. },
  741. }
  742. runClientTestTLS12(t, test)
  743. }
  744. // sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443`
  745. const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0="
  746. func TestHandshakeClientSCTs(t *testing.T) {
  747. config := testConfig.Clone()
  748. scts, err := base64.StdEncoding.DecodeString(sctsBase64)
  749. if err != nil {
  750. t.Fatal(err)
  751. }
  752. test := &clientTest{
  753. name: "SCT",
  754. // Note that this needs OpenSSL 1.0.2 because that is the first
  755. // version that supports the -serverinfo flag.
  756. command: []string{"openssl", "s_server"},
  757. config: config,
  758. extensions: [][]byte{scts},
  759. validate: func(state ConnectionState) error {
  760. expectedSCTs := [][]byte{
  761. scts[8:125],
  762. scts[127:245],
  763. scts[247:],
  764. }
  765. if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) {
  766. return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs))
  767. }
  768. for i, expected := range expectedSCTs {
  769. if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) {
  770. return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected)
  771. }
  772. }
  773. return nil
  774. },
  775. }
  776. runClientTestTLS12(t, test)
  777. }
  778. func TestRenegotiationRejected(t *testing.T) {
  779. config := testConfig.Clone()
  780. test := &clientTest{
  781. name: "RenegotiationRejected",
  782. command: []string{"openssl", "s_server", "-state"},
  783. config: config,
  784. numRenegotiations: 1,
  785. renegotiationExpectedToFail: 1,
  786. checkRenegotiationError: func(renegotiationNum int, err error) error {
  787. if err == nil {
  788. return errors.New("expected error from renegotiation but got nil")
  789. }
  790. if !strings.Contains(err.Error(), "no renegotiation") {
  791. return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
  792. }
  793. return nil
  794. },
  795. }
  796. runClientTestTLS12(t, test)
  797. }
  798. func TestRenegotiateOnce(t *testing.T) {
  799. config := testConfig.Clone()
  800. config.Renegotiation = RenegotiateOnceAsClient
  801. test := &clientTest{
  802. name: "RenegotiateOnce",
  803. command: []string{"openssl", "s_server", "-state"},
  804. config: config,
  805. numRenegotiations: 1,
  806. }
  807. runClientTestTLS12(t, test)
  808. }
  809. func TestRenegotiateTwice(t *testing.T) {
  810. config := testConfig.Clone()
  811. config.Renegotiation = RenegotiateFreelyAsClient
  812. test := &clientTest{
  813. name: "RenegotiateTwice",
  814. command: []string{"openssl", "s_server", "-state"},
  815. config: config,
  816. numRenegotiations: 2,
  817. }
  818. runClientTestTLS12(t, test)
  819. }
  820. func TestRenegotiateTwiceRejected(t *testing.T) {
  821. config := testConfig.Clone()
  822. config.Renegotiation = RenegotiateOnceAsClient
  823. test := &clientTest{
  824. name: "RenegotiateTwiceRejected",
  825. command: []string{"openssl", "s_server", "-state"},
  826. config: config,
  827. numRenegotiations: 2,
  828. renegotiationExpectedToFail: 2,
  829. checkRenegotiationError: func(renegotiationNum int, err error) error {
  830. if renegotiationNum == 1 {
  831. return err
  832. }
  833. if err == nil {
  834. return errors.New("expected error from renegotiation but got nil")
  835. }
  836. if !strings.Contains(err.Error(), "no renegotiation") {
  837. return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
  838. }
  839. return nil
  840. },
  841. }
  842. runClientTestTLS12(t, test)
  843. }
  844. var hostnameInSNITests = []struct {
  845. in, out string
  846. }{
  847. // Opaque string
  848. {"", ""},
  849. {"localhost", "localhost"},
  850. {"foo, bar, baz and qux", "foo, bar, baz and qux"},
  851. // DNS hostname
  852. {"golang.org", "golang.org"},
  853. {"golang.org.", "golang.org"},
  854. // Literal IPv4 address
  855. {"1.2.3.4", ""},
  856. // Literal IPv6 address
  857. {"::1", ""},
  858. {"::1%lo0", ""}, // with zone identifier
  859. {"[::1]", ""}, // as per RFC 5952 we allow the [] style as IPv6 literal
  860. {"[::1%lo0]", ""},
  861. }
  862. func TestHostnameInSNI(t *testing.T) {
  863. for _, tt := range hostnameInSNITests {
  864. c, s := net.Pipe()
  865. go func(host string) {
  866. Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake()
  867. }(tt.in)
  868. var header [5]byte
  869. if _, err := io.ReadFull(s, header[:]); err != nil {
  870. t.Fatal(err)
  871. }
  872. recordLen := int(header[3])<<8 | int(header[4])
  873. record := make([]byte, recordLen)
  874. if _, err := io.ReadFull(s, record[:]); err != nil {
  875. t.Fatal(err)
  876. }
  877. c.Close()
  878. s.Close()
  879. var m clientHelloMsg
  880. if m.unmarshal(record) != alertSuccess {
  881. t.Errorf("unmarshaling ClientHello for %q failed", tt.in)
  882. continue
  883. }
  884. if tt.in != tt.out && m.serverName == tt.in {
  885. t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record)
  886. }
  887. if m.serverName != tt.out {
  888. t.Errorf("expected %q not found in ClientHello: %x", tt.out, record)
  889. }
  890. }
  891. }
  892. func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
  893. // This checks that the server can't select a cipher suite that the
  894. // client didn't offer. See #13174.
  895. c, s := net.Pipe()
  896. errChan := make(chan error, 1)
  897. go func() {
  898. client := Client(c, &Config{
  899. ServerName: "foo",
  900. CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
  901. })
  902. errChan <- client.Handshake()
  903. }()
  904. var header [5]byte
  905. if _, err := io.ReadFull(s, header[:]); err != nil {
  906. t.Fatal(err)
  907. }
  908. recordLen := int(header[3])<<8 | int(header[4])
  909. record := make([]byte, recordLen)
  910. if _, err := io.ReadFull(s, record); err != nil {
  911. t.Fatal(err)
  912. }
  913. // Create a ServerHello that selects a different cipher suite than the
  914. // sole one that the client offered.
  915. serverHello := &serverHelloMsg{
  916. vers: VersionTLS12,
  917. random: make([]byte, 32),
  918. cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
  919. }
  920. serverHelloBytes := serverHello.marshal()
  921. s.Write([]byte{
  922. byte(recordTypeHandshake),
  923. byte(VersionTLS12 >> 8),
  924. byte(VersionTLS12 & 0xff),
  925. byte(len(serverHelloBytes) >> 8),
  926. byte(len(serverHelloBytes)),
  927. })
  928. s.Write(serverHelloBytes)
  929. s.Close()
  930. if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") {
  931. t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
  932. }
  933. }
  934. func TestVerifyPeerCertificate(t *testing.T) {
  935. issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  936. if err != nil {
  937. panic(err)
  938. }
  939. rootCAs := x509.NewCertPool()
  940. rootCAs.AddCert(issuer)
  941. now := func() time.Time { return time.Unix(1476984729, 0) }
  942. sentinelErr := errors.New("TestVerifyPeerCertificate")
  943. verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  944. if l := len(rawCerts); l != 1 {
  945. return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
  946. }
  947. if len(validatedChains) == 0 {
  948. return errors.New("got len(validatedChains) = 0, wanted non-zero")
  949. }
  950. *called = true
  951. return nil
  952. }
  953. tests := []struct {
  954. configureServer func(*Config, *bool)
  955. configureClient func(*Config, *bool)
  956. validate func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
  957. }{
  958. {
  959. configureServer: func(config *Config, called *bool) {
  960. config.InsecureSkipVerify = false
  961. config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  962. return verifyCallback(called, rawCerts, validatedChains)
  963. }
  964. },
  965. configureClient: func(config *Config, called *bool) {
  966. config.InsecureSkipVerify = false
  967. config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  968. return verifyCallback(called, rawCerts, validatedChains)
  969. }
  970. },
  971. validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  972. if clientErr != nil {
  973. t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
  974. }
  975. if serverErr != nil {
  976. t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
  977. }
  978. if !clientCalled {
  979. t.Errorf("test[%d]: client did not call callback", testNo)
  980. }
  981. if !serverCalled {
  982. t.Errorf("test[%d]: server did not call callback", testNo)
  983. }
  984. },
  985. },
  986. {
  987. configureServer: func(config *Config, called *bool) {
  988. config.InsecureSkipVerify = false
  989. config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  990. return sentinelErr
  991. }
  992. },
  993. configureClient: func(config *Config, called *bool) {
  994. config.VerifyPeerCertificate = nil
  995. },
  996. validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  997. if serverErr != sentinelErr {
  998. t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
  999. }
  1000. },
  1001. },
  1002. {
  1003. configureServer: func(config *Config, called *bool) {
  1004. config.InsecureSkipVerify = false
  1005. },
  1006. configureClient: func(config *Config, called *bool) {
  1007. config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1008. return sentinelErr
  1009. }
  1010. },
  1011. validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1012. if clientErr != sentinelErr {
  1013. t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
  1014. }
  1015. },
  1016. },
  1017. {
  1018. configureServer: func(config *Config, called *bool) {
  1019. config.InsecureSkipVerify = false
  1020. },
  1021. configureClient: func(config *Config, called *bool) {
  1022. config.InsecureSkipVerify = true
  1023. config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1024. if l := len(rawCerts); l != 1 {
  1025. return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
  1026. }
  1027. // With InsecureSkipVerify set, this
  1028. // callback should still be called but
  1029. // validatedChains must be empty.
  1030. if l := len(validatedChains); l != 0 {
  1031. return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l)
  1032. }
  1033. *called = true
  1034. return nil
  1035. }
  1036. },
  1037. validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1038. if clientErr != nil {
  1039. t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
  1040. }
  1041. if serverErr != nil {
  1042. t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
  1043. }
  1044. if !clientCalled {
  1045. t.Errorf("test[%d]: client did not call callback", testNo)
  1046. }
  1047. },
  1048. },
  1049. }
  1050. for i, test := range tests {
  1051. c, s := net.Pipe()
  1052. done := make(chan error)
  1053. var clientCalled, serverCalled bool
  1054. go func() {
  1055. config := testConfig.Clone()
  1056. config.ServerName = "example.golang"
  1057. config.ClientAuth = RequireAndVerifyClientCert
  1058. config.ClientCAs = rootCAs
  1059. config.Time = now
  1060. test.configureServer(config, &serverCalled)
  1061. err = Server(s, config).Handshake()
  1062. s.Close()
  1063. done <- err
  1064. }()
  1065. config := testConfig.Clone()
  1066. config.ServerName = "example.golang"
  1067. config.RootCAs = rootCAs
  1068. config.Time = now
  1069. test.configureClient(config, &clientCalled)
  1070. clientErr := Client(c, config).Handshake()
  1071. c.Close()
  1072. serverErr := <-done
  1073. test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
  1074. }
  1075. }
  1076. // brokenConn wraps a net.Conn and causes all Writes after a certain number to
  1077. // fail with brokenConnErr.
  1078. type brokenConn struct {
  1079. net.Conn
  1080. // breakAfter is the number of successful writes that will be allowed
  1081. // before all subsequent writes fail.
  1082. breakAfter int
  1083. // numWrites is the number of writes that have been done.
  1084. numWrites int
  1085. }
  1086. // brokenConnErr is the error that brokenConn returns once exhausted.
  1087. var brokenConnErr = errors.New("too many writes to brokenConn")
  1088. func (b *brokenConn) Write(data []byte) (int, error) {
  1089. if b.numWrites >= b.breakAfter {
  1090. return 0, brokenConnErr
  1091. }
  1092. b.numWrites++
  1093. return b.Conn.Write(data)
  1094. }
  1095. func TestFailedWrite(t *testing.T) {
  1096. // Test that a write error during the handshake is returned.
  1097. for _, breakAfter := range []int{0, 1} {
  1098. c, s := net.Pipe()
  1099. done := make(chan bool)
  1100. go func() {
  1101. Server(s, testConfig).Handshake()
  1102. s.Close()
  1103. done <- true
  1104. }()
  1105. brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
  1106. err := Client(brokenC, testConfig).Handshake()
  1107. if err != brokenConnErr {
  1108. t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
  1109. }
  1110. brokenC.Close()
  1111. <-done
  1112. }
  1113. }
  1114. // writeCountingConn wraps a net.Conn and counts the number of Write calls.
  1115. type writeCountingConn struct {
  1116. net.Conn
  1117. // numWrites is the number of writes that have been done.
  1118. numWrites int
  1119. }
  1120. func (wcc *writeCountingConn) Write(data []byte) (int, error) {
  1121. wcc.numWrites++
  1122. return wcc.Conn.Write(data)
  1123. }
  1124. func TestBuffering(t *testing.T) {
  1125. c, s := net.Pipe()
  1126. done := make(chan bool)
  1127. clientWCC := &writeCountingConn{Conn: c}
  1128. serverWCC := &writeCountingConn{Conn: s}
  1129. go func() {
  1130. Server(serverWCC, testConfig).Handshake()
  1131. serverWCC.Close()
  1132. done <- true
  1133. }()
  1134. err := Client(clientWCC, testConfig).Handshake()
  1135. if err != nil {
  1136. t.Fatal(err)
  1137. }
  1138. clientWCC.Close()
  1139. <-done
  1140. if n := clientWCC.numWrites; n != 2 {
  1141. t.Errorf("expected client handshake to complete with only two writes, but saw %d", n)
  1142. }
  1143. if n := serverWCC.numWrites; n != 2 {
  1144. t.Errorf("expected server handshake to complete with only two writes, but saw %d", n)
  1145. }
  1146. }
  1147. func TestAlertFlushing(t *testing.T) {
  1148. c, s := net.Pipe()
  1149. done := make(chan bool)
  1150. clientWCC := &writeCountingConn{Conn: c}
  1151. serverWCC := &writeCountingConn{Conn: s}
  1152. serverConfig := testConfig.Clone()
  1153. // Cause a signature-time error
  1154. brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey}
  1155. brokenKey.D = big.NewInt(42)
  1156. serverConfig.Certificates = []Certificate{{
  1157. Certificate: [][]byte{testRSACertificate},
  1158. PrivateKey: &brokenKey,
  1159. }}
  1160. go func() {
  1161. Server(serverWCC, serverConfig).Handshake()
  1162. serverWCC.Close()
  1163. done <- true
  1164. }()
  1165. err := Client(clientWCC, testConfig).Handshake()
  1166. if err == nil {
  1167. t.Fatal("client unexpectedly returned no error")
  1168. }
  1169. const expectedError = "remote error: tls: handshake failure"
  1170. if e := err.Error(); !strings.Contains(e, expectedError) {
  1171. t.Fatalf("expected to find %q in error but error was %q", expectedError, e)
  1172. }
  1173. clientWCC.Close()
  1174. <-done
  1175. if n := clientWCC.numWrites; n != 1 {
  1176. t.Errorf("expected client handshake to complete with one write, but saw %d", n)
  1177. }
  1178. if n := serverWCC.numWrites; n != 1 {
  1179. t.Errorf("expected server handshake to complete with one write, but saw %d", n)
  1180. }
  1181. }
  1182. func TestHandshakeRace(t *testing.T) {
  1183. t.Parallel()
  1184. // This test races a Read and Write to try and complete a handshake in
  1185. // order to provide some evidence that there are no races or deadlocks
  1186. // in the handshake locking.
  1187. for i := 0; i < 32; i++ {
  1188. c, s := net.Pipe()
  1189. go func() {
  1190. server := Server(s, testConfig)
  1191. if err := server.Handshake(); err != nil {
  1192. panic(err)
  1193. }
  1194. var request [1]byte
  1195. if n, err := server.Read(request[:]); err != nil || n != 1 {
  1196. panic(err)
  1197. }
  1198. server.Write(request[:])
  1199. server.Close()
  1200. }()
  1201. startWrite := make(chan struct{})
  1202. startRead := make(chan struct{})
  1203. readDone := make(chan struct{})
  1204. client := Client(c, testConfig)
  1205. go func() {
  1206. <-startWrite
  1207. var request [1]byte
  1208. client.Write(request[:])
  1209. }()
  1210. go func() {
  1211. <-startRead
  1212. var reply [1]byte
  1213. if n, err := client.Read(reply[:]); err != nil || n != 1 {
  1214. panic(err)
  1215. }
  1216. c.Close()
  1217. readDone <- struct{}{}
  1218. }()
  1219. if i&1 == 1 {
  1220. startWrite <- struct{}{}
  1221. startRead <- struct{}{}
  1222. } else {
  1223. startRead <- struct{}{}
  1224. startWrite <- struct{}{}
  1225. }
  1226. <-readDone
  1227. }
  1228. }
  1229. func TestTLS11SignatureSchemes(t *testing.T) {
  1230. expected := tls11SignatureSchemesNumECDSA + tls11SignatureSchemesNumRSA
  1231. if expected != len(tls11SignatureSchemes) {
  1232. t.Errorf("expected to find %d TLS 1.1 signature schemes, but found %d", expected, len(tls11SignatureSchemes))
  1233. }
  1234. }
  1235. var getClientCertificateTests = []struct {
  1236. setup func(*Config, *Config)
  1237. expectedClientError string
  1238. verify func(*testing.T, int, *ConnectionState)
  1239. }{
  1240. {
  1241. func(clientConfig, serverConfig *Config) {
  1242. // Returning a Certificate with no certificate data
  1243. // should result in an empty message being sent to the
  1244. // server.
  1245. serverConfig.ClientCAs = nil
  1246. clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  1247. if len(cri.SignatureSchemes) == 0 {
  1248. panic("empty SignatureSchemes")
  1249. }
  1250. if len(cri.AcceptableCAs) != 0 {
  1251. panic("AcceptableCAs should have been empty")
  1252. }
  1253. return new(Certificate), nil
  1254. }
  1255. },
  1256. "",
  1257. func(t *testing.T, testNum int, cs *ConnectionState) {
  1258. if l := len(cs.PeerCertificates); l != 0 {
  1259. t.Errorf("#%d: expected no certificates but got %d", testNum, l)
  1260. }
  1261. },
  1262. },
  1263. {
  1264. func(clientConfig, serverConfig *Config) {
  1265. // With TLS 1.1, the SignatureSchemes should be
  1266. // synthesised from the supported certificate types.
  1267. clientConfig.MaxVersion = VersionTLS11
  1268. clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  1269. if len(cri.SignatureSchemes) == 0 {
  1270. panic("empty SignatureSchemes")
  1271. }
  1272. return new(Certificate), nil
  1273. }
  1274. },
  1275. "",
  1276. func(t *testing.T, testNum int, cs *ConnectionState) {
  1277. if l := len(cs.PeerCertificates); l != 0 {
  1278. t.Errorf("#%d: expected no certificates but got %d", testNum, l)
  1279. }
  1280. },
  1281. },
  1282. {
  1283. func(clientConfig, serverConfig *Config) {
  1284. // Returning an error should abort the handshake with
  1285. // that error.
  1286. clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  1287. return nil, errors.New("GetClientCertificate")
  1288. }
  1289. },
  1290. "GetClientCertificate",
  1291. func(t *testing.T, testNum int, cs *ConnectionState) {
  1292. },
  1293. },
  1294. {
  1295. func(clientConfig, serverConfig *Config) {
  1296. clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  1297. if len(cri.AcceptableCAs) == 0 {
  1298. panic("empty AcceptableCAs")
  1299. }
  1300. cert := &Certificate{
  1301. Certificate: [][]byte{testRSACertificate},
  1302. PrivateKey: testRSAPrivateKey,
  1303. }
  1304. return cert, nil
  1305. }
  1306. },
  1307. "",
  1308. func(t *testing.T, testNum int, cs *ConnectionState) {
  1309. if len(cs.VerifiedChains) == 0 {
  1310. t.Errorf("#%d: expected some verified chains, but found none", testNum)
  1311. }
  1312. },
  1313. },
  1314. }
  1315. func TestGetClientCertificate(t *testing.T) {
  1316. issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  1317. if err != nil {
  1318. panic(err)
  1319. }
  1320. for i, test := range getClientCertificateTests {
  1321. serverConfig := testConfig.Clone()
  1322. serverConfig.ClientAuth = VerifyClientCertIfGiven
  1323. serverConfig.RootCAs = x509.NewCertPool()
  1324. serverConfig.RootCAs.AddCert(issuer)
  1325. serverConfig.ClientCAs = serverConfig.RootCAs
  1326. serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
  1327. clientConfig := testConfig.Clone()
  1328. test.setup(clientConfig, serverConfig)
  1329. type serverResult struct {
  1330. cs ConnectionState
  1331. err error
  1332. }
  1333. c, s := net.Pipe()
  1334. done := make(chan serverResult)
  1335. go func() {
  1336. defer s.Close()
  1337. server := Server(s, serverConfig)
  1338. err := server.Handshake()
  1339. var cs ConnectionState
  1340. if err == nil {
  1341. cs = server.ConnectionState()
  1342. }
  1343. done <- serverResult{cs, err}
  1344. }()
  1345. clientErr := Client(c, clientConfig).Handshake()
  1346. c.Close()
  1347. result := <-done
  1348. if clientErr != nil {
  1349. if len(test.expectedClientError) == 0 {
  1350. t.Errorf("#%d: client error: %v", i, clientErr)
  1351. } else if got := clientErr.Error(); got != test.expectedClientError {
  1352. t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
  1353. } else {
  1354. test.verify(t, i, &result.cs)
  1355. }
  1356. } else if len(test.expectedClientError) > 0 {
  1357. t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
  1358. } else if err := result.err; err != nil {
  1359. t.Errorf("#%d: server error: %v", i, err)
  1360. } else {
  1361. test.verify(t, i, &result.cs)
  1362. }
  1363. }
  1364. }