Alternative TLS implementation in Go
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1623 lines
45 KiB

  1. // Copyright 2010 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package trs
  5. import (
  6. "bytes"
  7. "crypto/ecdsa"
  8. "crypto/rsa"
  9. "crypto/x509"
  10. "encoding/base64"
  11. "encoding/binary"
  12. "encoding/pem"
  13. "errors"
  14. "fmt"
  15. "io"
  16. "math/big"
  17. "net"
  18. "os"
  19. "os/exec"
  20. "path/filepath"
  21. "strconv"
  22. "strings"
  23. "sync"
  24. "testing"
  25. "time"
  26. )
  27. // Note: see comment in handshake_test.go for details of how the reference
  28. // tests work.
  // opensslInputEvent enumerates the inputs that can be sent to an `openssl`
// child process (for example the `openssl s_server` started by
// connFromCommand) via its stdin.
type opensslInputEvent int

const (
	// opensslRenegotiate causes OpenSSL to request a renegotiation of the
	// connection. It is delivered to OpenSSL as the line "R\n".
	opensslRenegotiate opensslInputEvent = iota

	// opensslSendSentinel causes OpenSSL to send the contents of
	// opensslSentinel on the connection.
	opensslSendSentinel
)

// opensslSentinel is the text OpenSSL is asked to write to the connection in
// response to opensslSendSentinel.
const opensslSentinel = "SENTINEL\n"

// opensslInput is an io.Reader that turns events sent on the channel into
// the text written to OpenSSL's stdin. Closing the channel ends the input.
type opensslInput chan opensslInputEvent

// Read blocks until an event arrives and copies the corresponding text into
// buf. Each call consumes exactly one event; if buf is shorter than the text,
// the text is truncated. Read returns io.EOF once the channel has been closed
// and drained, and panics on an unknown event.
func (i opensslInput) Read(buf []byte) (n int, err error) {
	event, ok := <-i
	if !ok {
		return 0, io.EOF
	}
	switch event {
	case opensslRenegotiate:
		return copy(buf, "R\n"), nil
	case opensslSendSentinel:
		return copy(buf, opensslSentinel), nil
	default:
		panic("unknown event")
	}
}
  // opensslOutputSink is an io.Writer that collects everything an `openssl`
// process writes to stdout and stderr. Every complete line equal to
// opensslEndOfHandshake triggers a send on handshakeConfirmed. That channel is
// unbuffered, so Write blocks until the signal has been received.
type opensslOutputSink struct {
	handshakeConfirmed chan struct{}
	all                []byte // every byte written so far
	line               []byte // trailing bytes not yet terminated by '\n'
}

// newOpensslOutputSink returns an empty sink with an unbuffered
// handshakeConfirmed channel.
func newOpensslOutputSink() *opensslOutputSink {
	return &opensslOutputSink{
		handshakeConfirmed: make(chan struct{}),
	}
}

// opensslEndOfHandshake is a message that the “openssl s_server” tool will
// print when a handshake completes if run with “-state”.
const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished"

// Write records data and scans it for complete lines. It signals
// handshakeConfirmed once per line matching opensslEndOfHandshake, and it
// always consumes all of data without error.
func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
	o.all = append(o.all, data...)
	o.line = append(o.line, data...)
	for {
		nl := bytes.IndexByte(o.line, '\n')
		if nl == -1 {
			return len(data), nil
		}
		if string(o.line[:nl]) == opensslEndOfHandshake {
			o.handshakeConfirmed <- struct{}{}
		}
		o.line = o.line[nl+1:]
	}
}

// WriteTo copies everything captured so far to w.
func (o *opensslOutputSink) WriteTo(w io.Writer) (int64, error) {
	written, err := w.Write(o.all)
	return int64(written), err
}
  88. // clientTest represents a test of the TLS client handshake against a reference
  89. // implementation.
  90. type clientTest struct {
  91. // name is a freeform string identifying the test and the file in which
  92. // the expected results will be stored.
  93. name string
  94. // command, if not empty, contains a series of arguments for the
  95. // command to run for the reference server.
  96. command []string
  97. // config, if not nil, contains a custom Config to use for this test.
  98. config *Config
  99. // cert, if not empty, contains a DER-encoded certificate for the
  100. // reference server.
  101. cert []byte
  102. // key, if not nil, contains either a *rsa.PrivateKey or
  103. // *ecdsa.PrivateKey which is the private key for the reference server.
  104. key interface{}
  105. // extensions, if not nil, contains a list of extension data to be returned
  106. // from the ServerHello. The data should be in standard TLS format with
  107. // a 2-byte uint16 type, 2-byte data length, followed by the extension data.
  108. extensions [][]byte
  109. // validate, if not nil, is a function that will be called with the
  110. // ConnectionState of the resulting connection. It returns a non-nil
  111. // error if the ConnectionState is unacceptable.
  112. validate func(ConnectionState) error
  113. // numRenegotiations is the number of times that the connection will be
  114. // renegotiated.
  115. numRenegotiations int
  116. // renegotiationExpectedToFail, if not zero, is the number of the
  117. // renegotiation attempt that is expected to fail.
  118. renegotiationExpectedToFail int
  119. // checkRenegotiationError, if not nil, is called with any error
  120. // arising from renegotiation. It can map expected errors to nil to
  121. // ignore them.
  122. checkRenegotiationError func(renegotiationNum int, err error) error
  123. }
  124. var defaultServerCommand = []string{"openssl", "s_server"}
  125. // connFromCommand starts the reference server process, connects to it and
  126. // returns a recordingConn for the connection. The stdin return value is an
  127. // opensslInput for the stdin of the child process. It must be closed before
  128. // Waiting for child.
  129. func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) {
  130. cert := testRSACertificate
  131. if len(test.cert) > 0 {
  132. cert = test.cert
  133. }
  134. certPath := tempFile(string(cert))
  135. defer os.Remove(certPath)
  136. var key interface{} = testRSAPrivateKey
  137. if test.key != nil {
  138. key = test.key
  139. }
  140. var pemType string
  141. var derBytes []byte
  142. switch key := key.(type) {
  143. case *rsa.PrivateKey:
  144. pemType = "RSA"
  145. derBytes = x509.MarshalPKCS1PrivateKey(key)
  146. case *ecdsa.PrivateKey:
  147. pemType = "EC"
  148. var err error
  149. derBytes, err = x509.MarshalECPrivateKey(key)
  150. if err != nil {
  151. panic(err)
  152. }
  153. default:
  154. panic("unknown key type")
  155. }
  156. var pemOut bytes.Buffer
  157. pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes})
  158. keyPath := tempFile(string(pemOut.Bytes()))
  159. defer os.Remove(keyPath)
  160. var command []string
  161. if len(test.command) > 0 {
  162. command = append(command, test.command...)
  163. } else {
  164. command = append(command, defaultServerCommand...)
  165. }
  166. command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
  167. // serverPort contains the port that OpenSSL will listen on. OpenSSL
  168. // can't take "0" as an argument here so we have to pick a number and
  169. // hope that it's not in use on the machine. Since this only occurs
  170. // when -update is given and thus when there's a human watching the
  171. // test, this isn't too bad.
  172. const serverPort = 24323
  173. command = append(command, "-accept", strconv.Itoa(serverPort))
  174. if len(test.extensions) > 0 {
  175. var serverInfo bytes.Buffer
  176. for _, ext := range test.extensions {
  177. pem.Encode(&serverInfo, &pem.Block{
  178. Type: fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
  179. Bytes: ext,
  180. })
  181. }
  182. serverInfoPath := tempFile(serverInfo.String())
  183. defer os.Remove(serverInfoPath)
  184. command = append(command, "-serverinfo", serverInfoPath)
  185. }
  186. if test.numRenegotiations > 0 {
  187. found := false
  188. for _, flag := range command[1:] {
  189. if flag == "-state" {
  190. found = true
  191. break
  192. }
  193. }
  194. if !found {
  195. panic("-state flag missing to OpenSSL. You need this if testing renegotiation")
  196. }
  197. }
  198. cmd := exec.Command(command[0], command[1:]...)
  199. stdin = opensslInput(make(chan opensslInputEvent))
  200. cmd.Stdin = stdin
  201. out := newOpensslOutputSink()
  202. cmd.Stdout = out
  203. cmd.Stderr = out
  204. if err := cmd.Start(); err != nil {
  205. return nil, nil, nil, nil, err
  206. }
  207. // OpenSSL does print an "ACCEPT" banner, but it does so *before*
  208. // opening the listening socket, so we can't use that to wait until it
  209. // has started listening. Thus we are forced to poll until we get a
  210. // connection.
  211. var tcpConn net.Conn
  212. for i := uint(0); i < 5; i++ {
  213. tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
  214. IP: net.IPv4(127, 0, 0, 1),
  215. Port: serverPort,
  216. })
  217. if err == nil {
  218. break
  219. }
  220. time.Sleep((1 << i) * 5 * time.Millisecond)
  221. }
  222. if err != nil {
  223. close(stdin)
  224. out.WriteTo(os.Stdout)
  225. cmd.Process.Kill()
  226. return nil, nil, nil, nil, cmd.Wait()
  227. }
  228. record := &recordingConn{
  229. Conn: tcpConn,
  230. }
  231. return record, cmd, stdin, out, nil
  232. }
  233. func (test *clientTest) dataPath() string {
  234. return filepath.Join("testdata", "Client-"+test.name)
  235. }
  236. func (test *clientTest) loadData() (flows [][]byte, err error) {
  237. in, err := os.Open(test.dataPath())
  238. if err != nil {
  239. return nil, err
  240. }
  241. defer in.Close()
  242. return parseTestData(in)
  243. }
  244. func (test *clientTest) run(t *testing.T, write bool) {
  245. checkOpenSSLVersion(t)
  246. var clientConn, serverConn net.Conn
  247. var recordingConn *recordingConn
  248. var childProcess *exec.Cmd
  249. var stdin opensslInput
  250. var stdout *opensslOutputSink
  251. if write {
  252. var err error
  253. recordingConn, childProcess, stdin, stdout, err = test.connFromCommand()
  254. if err != nil {
  255. t.Fatalf("Failed to start subcommand: %s", err)
  256. }
  257. clientConn = recordingConn
  258. } else {
  259. clientConn, serverConn = net.Pipe()
  260. }
  261. config := test.config
  262. if config == nil {
  263. config = testConfig
  264. }
  265. client := Client(clientConn, config)
  266. doneChan := make(chan bool)
  267. go func() {
  268. defer func() { doneChan <- true }()
  269. defer clientConn.Close()
  270. defer client.Close()
  271. if _, err := client.Write([]byte("hello\n")); err != nil {
  272. t.Errorf("Client.Write failed: %s", err)
  273. return
  274. }
  275. for i := 1; i <= test.numRenegotiations; i++ {
  276. // The initial handshake will generate a
  277. // handshakeConfirmed signal which needs to be quashed.
  278. if i == 1 && write {
  279. <-stdout.handshakeConfirmed
  280. }
  281. // OpenSSL will try to interleave application data and
  282. // a renegotiation if we send both concurrently.
  283. // Therefore: ask OpensSSL to start a renegotiation, run
  284. // a goroutine to call client.Read and thus process the
  285. // renegotiation request, watch for OpenSSL's stdout to
  286. // indicate that the handshake is complete and,
  287. // finally, have OpenSSL write something to cause
  288. // client.Read to complete.
  289. if write {
  290. stdin <- opensslRenegotiate
  291. }
  292. signalChan := make(chan struct{})
  293. go func() {
  294. defer func() { signalChan <- struct{}{} }()
  295. buf := make([]byte, 256)
  296. n, err := client.Read(buf)
  297. if test.checkRenegotiationError != nil {
  298. newErr := test.checkRenegotiationError(i, err)
  299. if err != nil && newErr == nil {
  300. return
  301. }
  302. err = newErr
  303. }
  304. if err != nil {
  305. t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
  306. return
  307. }
  308. buf = buf[:n]
  309. if !bytes.Equal([]byte(opensslSentinel), buf) {
  310. t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
  311. }
  312. if expected := i + 1; client.handshakes != expected {
  313. t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
  314. }
  315. }()
  316. if write && test.renegotiationExpectedToFail != i {
  317. <-stdout.handshakeConfirmed
  318. stdin <- opensslSendSentinel
  319. }
  320. <-signalChan
  321. }
  322. if test.validate != nil {
  323. if err := test.validate(client.ConnectionState()); err != nil {
  324. t.Errorf("validate callback returned error: %s", err)
  325. }
  326. }
  327. }()
  328. if !write {
  329. flows, err := test.loadData()
  330. if err != nil {
  331. t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
  332. }
  333. for i, b := range flows {
  334. if i%2 == 1 {
  335. serverConn.Write(b)
  336. continue
  337. }
  338. bb := make([]byte, len(b))
  339. _, err := io.ReadFull(serverConn, bb)
  340. if err != nil {
  341. t.Fatalf("%s #%d: %s", test.name, i, err)
  342. }
  343. if !bytes.Equal(b, bb) {
  344. t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i, bb, b)
  345. }
  346. }
  347. serverConn.Close()
  348. }
  349. <-doneChan
  350. if write {
  351. path := test.dataPath()
  352. out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
  353. if err != nil {
  354. t.Fatalf("Failed to create output file: %s", err)
  355. }
  356. defer out.Close()
  357. recordingConn.Close()
  358. close(stdin)
  359. childProcess.Process.Kill()
  360. childProcess.Wait()
  361. if len(recordingConn.flows) < 3 {
  362. os.Stdout.Write(childProcess.Stdout.(*opensslOutputSink).all)
  363. t.Fatalf("Client connection didn't work")
  364. }
  365. recordingConn.WriteTo(out)
  366. fmt.Printf("Wrote %s\n", path)
  367. }
  368. }
  // didPar records, under didParMu, every test that setParallel has already
// marked as parallel.
var (
	didParMu sync.Mutex
	didPar   = map[*testing.T]bool{}
)

// setParallel calls t.Parallel the first time it is invoked for a given t and
// does nothing on later calls. Calling t.Parallel itself twice on the same
// test would panic. Helpers such as runClientTestForVersion run several times
// on one t, so they rely on this guard.
func setParallel(t *testing.T) {
	didParMu.Lock()
	already := didPar[t]
	if !already {
		didPar[t] = true
	}
	didParMu.Unlock()

	if already {
		return
	}
	t.Parallel()
}
  384. func runClientTestForVersion(t *testing.T, template *clientTest, prefix, option string) {
  385. setParallel(t)
  386. test := *template
  387. test.name = prefix + test.name
  388. if len(test.command) == 0 {
  389. test.command = defaultClientCommand
  390. }
  391. test.command = append([]string(nil), test.command...)
  392. test.command = append(test.command, option)
  393. test.run(t, *update)
  394. }
  395. func runClientTestTLS10(t *testing.T, template *clientTest) {
  396. runClientTestForVersion(t, template, "TLSv10-", "-tls1")
  397. }
  398. func runClientTestTLS11(t *testing.T, template *clientTest) {
  399. runClientTestForVersion(t, template, "TLSv11-", "-tls1_1")
  400. }
  401. func runClientTestTLS12(t *testing.T, template *clientTest) {
  402. runClientTestForVersion(t, template, "TLSv12-", "-tls1_2")
  403. }
  404. func TestHandshakeClientRSARC4(t *testing.T) {
  405. test := &clientTest{
  406. name: "RSA-RC4",
  407. command: []string{"openssl", "s_server", "-cipher", "RC4-SHA"},
  408. }
  409. runClientTestTLS10(t, test)
  410. runClientTestTLS11(t, test)
  411. runClientTestTLS12(t, test)
  412. }
  413. func TestHandshakeClientRSAAES128GCM(t *testing.T) {
  414. test := &clientTest{
  415. name: "AES128-GCM-SHA256",
  416. command: []string{"openssl", "s_server", "-cipher", "AES128-GCM-SHA256"},
  417. }
  418. runClientTestTLS12(t, test)
  419. }
  420. func TestHandshakeClientRSAAES256GCM(t *testing.T) {
  421. test := &clientTest{
  422. name: "AES256-GCM-SHA384",
  423. command: []string{"openssl", "s_server", "-cipher", "AES256-GCM-SHA384"},
  424. }
  425. runClientTestTLS12(t, test)
  426. }
  427. func TestHandshakeClientECDHERSAAES(t *testing.T) {
  428. test := &clientTest{
  429. name: "ECDHE-RSA-AES",
  430. command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA"},
  431. }
  432. runClientTestTLS10(t, test)
  433. runClientTestTLS11(t, test)
  434. runClientTestTLS12(t, test)
  435. }
  436. func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
  437. test := &clientTest{
  438. name: "ECDHE-ECDSA-AES",
  439. command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA"},
  440. cert: testECDSACertificate,
  441. key: testECDSAPrivateKey,
  442. }
  443. runClientTestTLS10(t, test)
  444. runClientTestTLS11(t, test)
  445. runClientTestTLS12(t, test)
  446. }
  447. func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
  448. test := &clientTest{
  449. name: "ECDHE-ECDSA-AES-GCM",
  450. command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
  451. cert: testECDSACertificate,
  452. key: testECDSAPrivateKey,
  453. }
  454. runClientTestTLS12(t, test)
  455. }
  456. func TestHandshakeClientAES256GCMSHA384(t *testing.T) {
  457. test := &clientTest{
  458. name: "ECDHE-ECDSA-AES256-GCM-SHA384",
  459. command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"},
  460. cert: testECDSACertificate,
  461. key: testECDSAPrivateKey,
  462. }
  463. runClientTestTLS12(t, test)
  464. }
  465. func TestHandshakeClientAES128CBCSHA256(t *testing.T) {
  466. test := &clientTest{
  467. name: "AES128-SHA256",
  468. command: []string{"openssl", "s_server", "-cipher", "AES128-SHA256"},
  469. }
  470. runClientTestTLS12(t, test)
  471. }
  472. func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) {
  473. test := &clientTest{
  474. name: "ECDHE-RSA-AES128-SHA256",
  475. command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA256"},
  476. }
  477. runClientTestTLS12(t, test)
  478. }
  479. func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) {
  480. test := &clientTest{
  481. name: "ECDHE-ECDSA-AES128-SHA256",
  482. command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA256"},
  483. cert: testECDSACertificate,
  484. key: testECDSAPrivateKey,
  485. }
  486. runClientTestTLS12(t, test)
  487. }
  488. func TestHandshakeClientX25519(t *testing.T) {
  489. config := testConfig.Clone()
  490. config.CurvePreferences = []CurveID{X25519}
  491. test := &clientTest{
  492. name: "X25519-ECDHE-RSA-AES-GCM",
  493. command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-GCM-SHA256"},
  494. config: config,
  495. }
  496. runClientTestTLS12(t, test)
  497. }
  498. func TestHandshakeClientECDHERSAChaCha20(t *testing.T) {
  499. config := testConfig.Clone()
  500. config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}
  501. test := &clientTest{
  502. name: "ECDHE-RSA-CHACHA20-POLY1305",
  503. command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-CHACHA20-POLY1305"},
  504. config: config,
  505. }
  506. runClientTestTLS12(t, test)
  507. }
  508. func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) {
  509. config := testConfig.Clone()
  510. config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305}
  511. test := &clientTest{
  512. name: "ECDHE-ECDSA-CHACHA20-POLY1305",
  513. command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"},
  514. config: config,
  515. cert: testECDSACertificate,
  516. key: testECDSAPrivateKey,
  517. }
  518. runClientTestTLS12(t, test)
  519. }
  520. func TestHandshakeClientCertRSA(t *testing.T) {
  521. config := testConfig.Clone()
  522. cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
  523. config.Certificates = []Certificate{cert}
  524. test := &clientTest{
  525. name: "ClientCert-RSA-RSA",
  526. command: []string{"openssl", "s_server", "-cipher", "AES128", "-verify", "1"},
  527. config: config,
  528. }
  529. runClientTestTLS10(t, test)
  530. runClientTestTLS12(t, test)
  531. test = &clientTest{
  532. name: "ClientCert-RSA-ECDSA",
  533. command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
  534. config: config,
  535. cert: testECDSACertificate,
  536. key: testECDSAPrivateKey,
  537. }
  538. runClientTestTLS10(t, test)
  539. runClientTestTLS12(t, test)
  540. test = &clientTest{
  541. name: "ClientCert-RSA-AES256-GCM-SHA384",
  542. command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-verify", "1"},
  543. config: config,
  544. cert: testRSACertificate,
  545. key: testRSAPrivateKey,
  546. }
  547. runClientTestTLS12(t, test)
  548. }
  549. func TestHandshakeClientCertECDSA(t *testing.T) {
  550. config := testConfig.Clone()
  551. cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
  552. config.Certificates = []Certificate{cert}
  553. test := &clientTest{
  554. name: "ClientCert-ECDSA-RSA",
  555. command: []string{"openssl", "s_server", "-cipher", "AES128", "-verify", "1"},
  556. config: config,
  557. }
  558. runClientTestTLS10(t, test)
  559. runClientTestTLS12(t, test)
  560. test = &clientTest{
  561. name: "ClientCert-ECDSA-ECDSA",
  562. command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
  563. config: config,
  564. cert: testECDSACertificate,
  565. key: testECDSAPrivateKey,
  566. }
  567. runClientTestTLS10(t, test)
  568. runClientTestTLS12(t, test)
  569. }
  570. // This test is specific to TLS versions which support session tickets (TLSv1.2 and below).
  571. // Session tickets are obsolete in TLSv1.3 (see 2.2 of TLS RFC)
  572. func TestClientResumption(t *testing.T) {
  573. serverConfig := &Config{
  574. CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
  575. Certificates: testConfig.Certificates,
  576. }
  577. issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  578. if err != nil {
  579. panic(err)
  580. }
  581. rootCAs := x509.NewCertPool()
  582. rootCAs.AddCert(issuer)
  583. clientConfig := &Config{
  584. CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
  585. ClientSessionCache: NewLRUClientSessionCache(32),
  586. RootCAs: rootCAs,
  587. ServerName: "example.golang",
  588. MaxVersion: VersionTLS12, // Enforce TLSv1.2
  589. }
  590. testResumeState := func(test string, didResume bool) {
  591. _, hs, err := testHandshake(clientConfig, serverConfig)
  592. if err != nil {
  593. t.Fatalf("%s: handshake failed: %s", test, err)
  594. }
  595. if hs.DidResume != didResume {
  596. t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
  597. }
  598. if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
  599. t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
  600. }
  601. }
  602. getTicket := func() []byte {
  603. return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket
  604. }
  605. randomKey := func() [32]byte {
  606. var k [32]byte
  607. if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil {
  608. t.Fatalf("Failed to read new SessionTicketKey: %s", err)
  609. }
  610. return k
  611. }
  612. testResumeState("Handshake", false)
  613. ticket := getTicket()
  614. testResumeState("Resume", true)
  615. if !bytes.Equal(ticket, getTicket()) {
  616. t.Fatal("first ticket doesn't match ticket after resumption")
  617. }
  618. key1 := randomKey()
  619. serverConfig.SetSessionTicketKeys([][32]byte{key1})
  620. testResumeState("InvalidSessionTicketKey", false)
  621. testResumeState("ResumeAfterInvalidSessionTicketKey", true)
  622. key2 := randomKey()
  623. serverConfig.SetSessionTicketKeys([][32]byte{key2, key1})
  624. ticket = getTicket()
  625. testResumeState("KeyChange", true)
  626. if bytes.Equal(ticket, getTicket()) {
  627. t.Fatal("new ticket wasn't included while resuming")
  628. }
  629. testResumeState("KeyChangeFinish", true)
  630. // Reset serverConfig to ensure that calling SetSessionTicketKeys
  631. // before the serverConfig is used works.
  632. serverConfig = &Config{
  633. CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
  634. Certificates: testConfig.Certificates,
  635. }
  636. serverConfig.SetSessionTicketKeys([][32]byte{key2})
  637. testResumeState("FreshConfig", true)
  638. clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
  639. testResumeState("DifferentCipherSuite", false)
  640. testResumeState("DifferentCipherSuiteRecovers", true)
  641. clientConfig.ClientSessionCache = nil
  642. testResumeState("WithoutSessionCache", false)
  643. }
  644. func TestLRUClientSessionCache(t *testing.T) {
  645. // Initialize cache of capacity 4.
  646. cache := NewLRUClientSessionCache(4)
  647. cs := make([]ClientSessionState, 6)
  648. keys := []string{"0", "1", "2", "3", "4", "5", "6"}
  649. // Add 4 entries to the cache and look them up.
  650. for i := 0; i < 4; i++ {
  651. cache.Put(keys[i], &cs[i])
  652. }
  653. for i := 0; i < 4; i++ {
  654. if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
  655. t.Fatalf("session cache failed lookup for added key: %s", keys[i])
  656. }
  657. }
  658. // Add 2 more entries to the cache. First 2 should be evicted.
  659. for i := 4; i < 6; i++ {
  660. cache.Put(keys[i], &cs[i])
  661. }
  662. for i := 0; i < 2; i++ {
  663. if s, ok := cache.Get(keys[i]); ok || s != nil {
  664. t.Fatalf("session cache should have evicted key: %s", keys[i])
  665. }
  666. }
  667. // Touch entry 2. LRU should evict 3 next.
  668. cache.Get(keys[2])
  669. cache.Put(keys[0], &cs[0])
  670. if s, ok := cache.Get(keys[3]); ok || s != nil {
  671. t.Fatalf("session cache should have evicted key 3")
  672. }
  673. // Update entry 0 in place.
  674. cache.Put(keys[0], &cs[3])
  675. if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
  676. t.Fatalf("session cache failed update for key 0")
  677. }
  678. // Adding a nil entry is valid.
  679. cache.Put(keys[0], nil)
  680. if s, ok := cache.Get(keys[0]); !ok || s != nil {
  681. t.Fatalf("failed to add nil entry to cache")
  682. }
  683. }
  684. func TestKeyLog(t *testing.T) {
  685. var serverBuf, clientBuf bytes.Buffer
  686. clientConfig := testConfig.Clone()
  687. clientConfig.KeyLogWriter = &clientBuf
  688. serverConfig := testConfig.Clone()
  689. serverConfig.KeyLogWriter = &serverBuf
  690. c, s := net.Pipe()
  691. done := make(chan bool)
  692. go func() {
  693. defer close(done)
  694. if err := Server(s, serverConfig).Handshake(); err != nil {
  695. t.Errorf("server: %s", err)
  696. return
  697. }
  698. s.Close()
  699. }()
  700. if err := Client(c, clientConfig).Handshake(); err != nil {
  701. t.Fatalf("client: %s", err)
  702. }
  703. c.Close()
  704. <-done
  705. checkKeylogLine := func(side, loggedLine string) {
  706. if len(loggedLine) == 0 {
  707. t.Fatalf("%s: no keylog line was produced", side)
  708. }
  709. const expectedLen = 13 /* "CLIENT_RANDOM" */ +
  710. 1 /* space */ +
  711. 32*2 /* hex client nonce */ +
  712. 1 /* space */ +
  713. 48*2 /* hex master secret */ +
  714. 1 /* new line */
  715. if len(loggedLine) != expectedLen {
  716. t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine)
  717. }
  718. if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") {
  719. t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine)
  720. }
  721. }
  722. checkKeylogLine("client", string(clientBuf.Bytes()))
  723. checkKeylogLine("server", string(serverBuf.Bytes()))
  724. }
  725. func TestHandshakeClientALPNMatch(t *testing.T) {
  726. config := testConfig.Clone()
  727. config.NextProtos = []string{"proto2", "proto1"}
  728. test := &clientTest{
  729. name: "ALPN",
  730. // Note that this needs OpenSSL 1.0.2 because that is the first
  731. // version that supports the -alpn flag.
  732. command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"},
  733. config: config,
  734. validate: func(state ConnectionState) error {
  735. // The server's preferences should override the client.
  736. if state.NegotiatedProtocol != "proto1" {
  737. return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol)
  738. }
  739. return nil
  740. },
  741. }
  742. runClientTestTLS12(t, test)
  743. }
  744. // sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443`
  745. const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0="
  746. func TestHandshakeClientSCTs(t *testing.T) {
  747. config := testConfig.Clone()
  748. scts, err := base64.StdEncoding.DecodeString(sctsBase64)
  749. if err != nil {
  750. t.Fatal(err)
  751. }
  752. test := &clientTest{
  753. name: "SCT",
  754. // Note that this needs OpenSSL 1.0.2 because that is the first
  755. // version that supports the -serverinfo flag.
  756. command: []string{"openssl", "s_server"},
  757. config: config,
  758. extensions: [][]byte{scts},
  759. validate: func(state ConnectionState) error {
  760. expectedSCTs := [][]byte{
  761. scts[8:125],
  762. scts[127:245],
  763. scts[247:],
  764. }
  765. if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) {
  766. return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs))
  767. }
  768. for i, expected := range expectedSCTs {
  769. if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) {
  770. return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected)
  771. }
  772. }
  773. return nil
  774. },
  775. }
  776. runClientTestTLS12(t, test)
  777. }
  778. func TestRenegotiationRejected(t *testing.T) {
  779. config := testConfig.Clone()
  780. test := &clientTest{
  781. name: "RenegotiationRejected",
  782. command: []string{"openssl", "s_server", "-state"},
  783. config: config,
  784. numRenegotiations: 1,
  785. renegotiationExpectedToFail: 1,
  786. checkRenegotiationError: func(renegotiationNum int, err error) error {
  787. if err == nil {
  788. return errors.New("expected error from renegotiation but got nil")
  789. }
  790. if !strings.Contains(err.Error(), "no renegotiation") {
  791. return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
  792. }
  793. return nil
  794. },
  795. }
  796. runClientTestTLS12(t, test)
  797. }
  798. func TestRenegotiateOnce(t *testing.T) {
  799. config := testConfig.Clone()
  800. config.Renegotiation = RenegotiateOnceAsClient
  801. test := &clientTest{
  802. name: "RenegotiateOnce",
  803. command: []string{"openssl", "s_server", "-state"},
  804. config: config,
  805. numRenegotiations: 1,
  806. }
  807. runClientTestTLS12(t, test)
  808. }
  809. func TestRenegotiateTwice(t *testing.T) {
  810. config := testConfig.Clone()
  811. config.Renegotiation = RenegotiateFreelyAsClient
  812. test := &clientTest{
  813. name: "RenegotiateTwice",
  814. command: []string{"openssl", "s_server", "-state"},
  815. config: config,
  816. numRenegotiations: 2,
  817. }
  818. runClientTestTLS12(t, test)
  819. }
  820. func TestRenegotiateTwiceRejected(t *testing.T) {
  821. config := testConfig.Clone()
  822. config.Renegotiation = RenegotiateOnceAsClient
  823. test := &clientTest{
  824. name: "RenegotiateTwiceRejected",
  825. command: []string{"openssl", "s_server", "-state"},
  826. config: config,
  827. numRenegotiations: 2,
  828. renegotiationExpectedToFail: 2,
  829. checkRenegotiationError: func(renegotiationNum int, err error) error {
  830. if renegotiationNum == 1 {
  831. return err
  832. }
  833. if err == nil {
  834. return errors.New("expected error from renegotiation but got nil")
  835. }
  836. if !strings.Contains(err.Error(), "no renegotiation") {
  837. return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
  838. }
  839. return nil
  840. },
  841. }
  842. runClientTestTLS12(t, test)
  843. }
  844. var hostnameInSNITests = []struct {
  845. in, out string
  846. }{
  847. // Opaque string
  848. {"", ""},
  849. {"localhost", "localhost"},
  850. {"foo, bar, baz and qux", "foo, bar, baz and qux"},
  851. // DNS hostname
  852. {"golang.org", "golang.org"},
  853. {"golang.org.", "golang.org"},
  854. // Literal IPv4 address
  855. {"1.2.3.4", ""},
  856. // Literal IPv6 address
  857. {"::1", ""},
  858. {"::1%lo0", ""}, // with zone identifier
  859. {"[::1]", ""}, // as per RFC 5952 we allow the [] style as IPv6 literal
  860. {"[::1%lo0]", ""},
  861. }
  862. func TestHostnameInSNI(t *testing.T) {
  863. for _, tt := range hostnameInSNITests {
  864. c, s := net.Pipe()
  865. go func(host string) {
  866. Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake()
  867. }(tt.in)
  868. var header [5]byte
  869. if _, err := io.ReadFull(s, header[:]); err != nil {
  870. t.Fatal(err)
  871. }
  872. recordLen := int(header[3])<<8 | int(header[4])
  873. record := make([]byte, recordLen)
  874. if _, err := io.ReadFull(s, record[:]); err != nil {
  875. t.Fatal(err)
  876. }
  877. c.Close()
  878. s.Close()
  879. var m clientHelloMsg
  880. if m.unmarshal(record) != alertSuccess {
  881. t.Errorf("unmarshaling ClientHello for %q failed", tt.in)
  882. continue
  883. }
  884. if tt.in != tt.out && m.serverName == tt.in {
  885. t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record)
  886. }
  887. if m.serverName != tt.out {
  888. t.Errorf("expected %q not found in ClientHello: %x", tt.out, record)
  889. }
  890. }
  891. }
  892. func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
  893. // This checks that the server can't select a cipher suite that the
  894. // client didn't offer. See #13174.
  895. c, s := net.Pipe()
  896. errChan := make(chan error, 1)
  897. go func() {
  898. client := Client(c, &Config{
  899. ServerName: "foo",
  900. CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
  901. })
  902. errChan <- client.Handshake()
  903. }()
  904. var header [5]byte
  905. if _, err := io.ReadFull(s, header[:]); err != nil {
  906. t.Fatal(err)
  907. }
  908. recordLen := int(header[3])<<8 | int(header[4])
  909. record := make([]byte, recordLen)
  910. if _, err := io.ReadFull(s, record); err != nil {
  911. t.Fatal(err)
  912. }
  913. // Create a ServerHello that selects a different cipher suite than the
  914. // sole one that the client offered.
  915. serverHello := &serverHelloMsg{
  916. vers: VersionTLS12,
  917. random: make([]byte, 32),
  918. cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
  919. }
  920. serverHelloBytes := serverHello.marshal()
  921. s.Write([]byte{
  922. byte(recordTypeHandshake),
  923. byte(VersionTLS12 >> 8),
  924. byte(VersionTLS12 & 0xff),
  925. byte(len(serverHelloBytes) >> 8),
  926. byte(len(serverHelloBytes)),
  927. })
  928. s.Write(serverHelloBytes)
  929. s.Close()
  930. if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") {
  931. t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
  932. }
  933. }
  934. func TestVerifyPeerCertificate(t *testing.T) {
  935. issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  936. if err != nil {
  937. panic(err)
  938. }
  939. rootCAs := x509.NewCertPool()
  940. rootCAs.AddCert(issuer)
  941. now := func() time.Time { return time.Unix(1476984729, 0) }
  942. sentinelErr := errors.New("TestVerifyPeerCertificate")
  943. verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  944. if l := len(rawCerts); l != 1 {
  945. return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
  946. }
  947. if len(validatedChains) == 0 {
  948. return errors.New("got len(validatedChains) = 0, wanted non-zero")
  949. }
  950. *called = true
  951. return nil
  952. }
  953. tests := []struct {
  954. configureServer func(*Config, *bool)
  955. configureClient func(*Config, *bool)
  956. validate func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
  957. }{
  958. {
  959. configureServer: func(config *Config, called *bool) {
  960. config.InsecureSkipVerify = false
  961. config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  962. return verifyCallback(called, rawCerts, validatedChains)
  963. }
  964. },
  965. configureClient: func(config *Config, called *bool) {
  966. config.InsecureSkipVerify = false
  967. config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  968. return verifyCallback(called, rawCerts, validatedChains)
  969. }
  970. },
  971. validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  972. if clientErr != nil {
  973. t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
  974. }
  975. if serverErr != nil {
  976. t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
  977. }
  978. if !clientCalled {
  979. t.Errorf("test[%d]: client did not call callback", testNo)
  980. }
  981. if !serverCalled {
  982. t.Errorf("test[%d]: server did not call callback", testNo)
  983. }
  984. },
  985. },
  986. {
  987. configureServer: func(config *Config, called *bool) {
  988. config.InsecureSkipVerify = false
  989. config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  990. return sentinelErr
  991. }
  992. },
  993. configureClient: func(config *Config, called *bool) {
  994. config.VerifyPeerCertificate = nil
  995. },
  996. validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  997. if serverErr != sentinelErr {
  998. t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
  999. }
  1000. },
  1001. },
  1002. {
  1003. configureServer: func(config *Config, called *bool) {
  1004. config.InsecureSkipVerify = false
  1005. },
  1006. configureClient: func(config *Config, called *bool) {
  1007. config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1008. return sentinelErr
  1009. }
  1010. },
  1011. validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1012. if clientErr != sentinelErr {
  1013. t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
  1014. }
  1015. },
  1016. },
  1017. {
  1018. configureServer: func(config *Config, called *bool) {
  1019. config.InsecureSkipVerify = false
  1020. },
  1021. configureClient: func(config *Config, called *bool) {
  1022. config.InsecureSkipVerify = true
  1023. config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1024. if l := len(rawCerts); l != 1 {
  1025. return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
  1026. }
  1027. // With InsecureSkipVerify set, this
  1028. // callback should still be called but
  1029. // validatedChains must be empty.
  1030. if l := len(validatedChains); l != 0 {
  1031. return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l)
  1032. }
  1033. *called = true
  1034. return nil
  1035. }
  1036. },
  1037. validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1038. if clientErr != nil {
  1039. t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
  1040. }
  1041. if serverErr != nil {
  1042. t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
  1043. }
  1044. if !clientCalled {
  1045. t.Errorf("test[%d]: client did not call callback", testNo)
  1046. }
  1047. },
  1048. },
  1049. }
  1050. for i, test := range tests {
  1051. c, s := net.Pipe()
  1052. done := make(chan error)
  1053. var clientCalled, serverCalled bool
  1054. go func() {
  1055. config := testConfig.Clone()
  1056. config.ServerName = "example.golang"
  1057. config.ClientAuth = RequireAndVerifyClientCert
  1058. config.ClientCAs = rootCAs
  1059. config.Time = now
  1060. test.configureServer(config, &serverCalled)
  1061. err = Server(s, config).Handshake()
  1062. s.Close()
  1063. done <- err
  1064. }()
  1065. config := testConfig.Clone()
  1066. config.ServerName = "example.golang"
  1067. config.RootCAs = rootCAs
  1068. config.Time = now
  1069. test.configureClient(config, &clientCalled)
  1070. clientErr := Client(c, config).Handshake()
  1071. c.Close()
  1072. serverErr := <-done
  1073. test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
  1074. }
  1075. }
  1076. // brokenConn wraps a net.Conn and causes all Writes after a certain number to
  1077. // fail with brokenConnErr.
  1078. type brokenConn struct {
  1079. net.Conn
  1080. // breakAfter is the number of successful writes that will be allowed
  1081. // before all subsequent writes fail.
  1082. breakAfter int
  1083. // numWrites is the number of writes that have been done.
  1084. numWrites int
  1085. }
  1086. // brokenConnErr is the error that brokenConn returns once exhausted.
  1087. var brokenConnErr = errors.New("too many writes to brokenConn")
  1088. func (b *brokenConn) Write(data []byte) (int, error) {
  1089. if b.numWrites >= b.breakAfter {
  1090. return 0, brokenConnErr
  1091. }
  1092. b.numWrites++
  1093. return b.Conn.Write(data)
  1094. }
  1095. func TestFailedWrite(t *testing.T) {
  1096. // Test that a write error during the handshake is returned.
  1097. for _, breakAfter := range []int{0, 1} {
  1098. c, s := net.Pipe()
  1099. done := make(chan bool)
  1100. go func() {
  1101. Server(s, testConfig).Handshake()
  1102. s.Close()
  1103. done <- true
  1104. }()
  1105. brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
  1106. err := Client(brokenC, testConfig).Handshake()
  1107. if err != brokenConnErr {
  1108. t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
  1109. }
  1110. brokenC.Close()
  1111. <-done
  1112. }
  1113. }
  1114. // writeCountingConn wraps a net.Conn and counts the number of Write calls.
  1115. type writeCountingConn struct {
  1116. net.Conn
  1117. // numWrites is the number of writes that have been done.
  1118. numWrites int
  1119. }
  1120. func (wcc *writeCountingConn) Write(data []byte) (int, error) {
  1121. wcc.numWrites++
  1122. return wcc.Conn.Write(data)
  1123. }
  1124. func TestBuffering(t *testing.T) {
  1125. c, s := net.Pipe()
  1126. done := make(chan bool)
  1127. clientWCC := &writeCountingConn{Conn: c}
  1128. serverWCC := &writeCountingConn{Conn: s}
  1129. go func() {
  1130. Server(serverWCC, testConfig).Handshake()
  1131. serverWCC.Close()
  1132. done <- true
  1133. }()
  1134. err := Client(clientWCC, testConfig).Handshake()
  1135. if err != nil {
  1136. t.Fatal(err)
  1137. }
  1138. clientWCC.Close()
  1139. <-done
  1140. if n := clientWCC.numWrites; n != 2 {
  1141. t.Errorf("expected client handshake to complete with only two writes, but saw %d", n)
  1142. }
  1143. if n := serverWCC.numWrites; n != 2 {
  1144. t.Errorf("expected server handshake to complete with only two writes, but saw %d", n)
  1145. }
  1146. }
  1147. func TestAlertFlushing(t *testing.T) {
  1148. c, s := net.Pipe()
  1149. done := make(chan bool)
  1150. clientWCC := &writeCountingConn{Conn: c}
  1151. serverWCC := &writeCountingConn{Conn: s}
  1152. serverConfig := testConfig.Clone()
  1153. // Cause a signature-time error
  1154. brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey}
  1155. brokenKey.D = big.NewInt(42)
  1156. serverConfig.Certificates = []Certificate{{
  1157. Certificate: [][]byte{testRSACertificate},
  1158. PrivateKey: &brokenKey,
  1159. }}
  1160. go func() {
  1161. Server(serverWCC, serverConfig).Handshake()
  1162. serverWCC.Close()
  1163. done <- true
  1164. }()
  1165. err := Client(clientWCC, testConfig).Handshake()
  1166. if err == nil {
  1167. t.Fatal("client unexpectedly returned no error")
  1168. }
  1169. const expectedError = "remote error: tls: handshake failure"
  1170. if e := err.Error(); !strings.Contains(e, expectedError) {
  1171. t.Fatalf("expected to find %q in error but error was %q", expectedError, e)
  1172. }
  1173. clientWCC.Close()
  1174. <-done
  1175. if n := clientWCC.numWrites; n != 1 {
  1176. t.Errorf("expected client handshake to complete with one write, but saw %d", n)
  1177. }
  1178. if n := serverWCC.numWrites; n != 1 {
  1179. t.Errorf("expected server handshake to complete with one write, but saw %d", n)
  1180. }
  1181. }
  1182. func TestHandshakeRace(t *testing.T) {
  1183. t.Parallel()
  1184. // This test races a Read and Write to try and complete a handshake in
  1185. // order to provide some evidence that there are no races or deadlocks
  1186. // in the handshake locking.
  1187. for i := 0; i < 32; i++ {
  1188. c, s := net.Pipe()
  1189. go func() {
  1190. server := Server(s, testConfig)
  1191. if err := server.Handshake(); err != nil {
  1192. panic(err)
  1193. }
  1194. var request [1]byte
  1195. if n, err := server.Read(request[:]); err != nil || n != 1 {
  1196. panic(err)
  1197. }
  1198. server.Write(request[:])
  1199. server.Close()
  1200. }()
  1201. startWrite := make(chan struct{})
  1202. startRead := make(chan struct{})
  1203. readDone := make(chan struct{})
  1204. client := Client(c, testConfig)
  1205. go func() {
  1206. <-startWrite
  1207. var request [1]byte
  1208. client.Write(request[:])
  1209. }()
  1210. go func() {
  1211. <-startRead
  1212. var reply [1]byte
  1213. if n, err := client.Read(reply[:]); err != nil || n != 1 {
  1214. panic(err)
  1215. }
  1216. c.Close()
  1217. readDone <- struct{}{}
  1218. }()
  1219. if i&1 == 1 {
  1220. startWrite <- struct{}{}
  1221. startRead <- struct{}{}
  1222. } else {
  1223. startRead <- struct{}{}
  1224. startWrite <- struct{}{}
  1225. }
  1226. <-readDone
  1227. }
  1228. }
  1229. func TestTLS11SignatureSchemes(t *testing.T) {
  1230. expected := tls11SignatureSchemesNumECDSA + tls11SignatureSchemesNumRSA
  1231. if expected != len(tls11SignatureSchemes) {
  1232. t.Errorf("expected to find %d TLS 1.1 signature schemes, but found %d", expected, len(tls11SignatureSchemes))
  1233. }
  1234. }
  1235. var getClientCertificateTests = []struct {
  1236. setup func(*Config, *Config)
  1237. expectedClientError string
  1238. verify func(*testing.T, int, *ConnectionState)
  1239. }{
  1240. {
  1241. func(clientConfig, serverConfig *Config) {
  1242. // Returning a Certificate with no certificate data
  1243. // should result in an empty message being sent to the
  1244. // server.
  1245. serverConfig.ClientCAs = nil
  1246. clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  1247. if len(cri.SignatureSchemes) == 0 {
  1248. panic("empty SignatureSchemes")
  1249. }
  1250. if len(cri.AcceptableCAs) != 0 {
  1251. panic("AcceptableCAs should have been empty")
  1252. }
  1253. return new(Certificate), nil
  1254. }
  1255. },
  1256. "",
  1257. func(t *testing.T, testNum int, cs *ConnectionState) {
  1258. if l := len(cs.PeerCertificates); l != 0 {
  1259. t.Errorf("#%d: expected no certificates but got %d", testNum, l)
  1260. }
  1261. },
  1262. },
  1263. {
  1264. func(clientConfig, serverConfig *Config) {
  1265. // With TLS 1.1, the SignatureSchemes should be
  1266. // synthesised from the supported certificate types.
  1267. clientConfig.MaxVersion = VersionTLS11
  1268. clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  1269. if len(cri.SignatureSchemes) == 0 {
  1270. panic("empty SignatureSchemes")
  1271. }
  1272. return new(Certificate), nil
  1273. }
  1274. },
  1275. "",
  1276. func(t *testing.T, testNum int, cs *ConnectionState) {
  1277. if l := len(cs.PeerCertificates); l != 0 {
  1278. t.Errorf("#%d: expected no certificates but got %d", testNum, l)
  1279. }
  1280. },
  1281. },
  1282. {
  1283. func(clientConfig, serverConfig *Config) {
  1284. // Returning an error should abort the handshake with
  1285. // that error.
  1286. clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  1287. return nil, errors.New("GetClientCertificate")
  1288. }
  1289. },
  1290. "GetClientCertificate",
  1291. func(t *testing.T, testNum int, cs *ConnectionState) {
  1292. },
  1293. },
  1294. {
  1295. func(clientConfig, serverConfig *Config) {
  1296. clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  1297. if len(cri.AcceptableCAs) == 0 {
  1298. panic("empty AcceptableCAs")
  1299. }
  1300. cert := &Certificate{
  1301. Certificate: [][]byte{testRSACertificate},
  1302. PrivateKey: testRSAPrivateKey,
  1303. }
  1304. return cert, nil
  1305. }
  1306. },
  1307. "",
  1308. func(t *testing.T, testNum int, cs *ConnectionState) {
  1309. if len(cs.VerifiedChains) == 0 {
  1310. t.Errorf("#%d: expected some verified chains, but found none", testNum)
  1311. }
  1312. },
  1313. },
  1314. }
  1315. func TestGetClientCertificate(t *testing.T) {
  1316. issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  1317. if err != nil {
  1318. panic(err)
  1319. }
  1320. for i, test := range getClientCertificateTests {
  1321. serverConfig := testConfig.Clone()
  1322. serverConfig.ClientAuth = VerifyClientCertIfGiven
  1323. serverConfig.RootCAs = x509.NewCertPool()
  1324. serverConfig.RootCAs.AddCert(issuer)
  1325. serverConfig.ClientCAs = serverConfig.RootCAs
  1326. serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
  1327. clientConfig := testConfig.Clone()
  1328. test.setup(clientConfig, serverConfig)
  1329. type serverResult struct {
  1330. cs ConnectionState
  1331. err error
  1332. }
  1333. c, s := net.Pipe()
  1334. done := make(chan serverResult)
  1335. go func() {
  1336. defer s.Close()
  1337. server := Server(s, serverConfig)
  1338. err := server.Handshake()
  1339. var cs ConnectionState
  1340. if err == nil {
  1341. cs = server.ConnectionState()
  1342. }
  1343. done <- serverResult{cs, err}
  1344. }()
  1345. clientErr := Client(c, clientConfig).Handshake()
  1346. c.Close()
  1347. result := <-done
  1348. if clientErr != nil {
  1349. if len(test.expectedClientError) == 0 {
  1350. t.Errorf("#%d: client error: %v", i, clientErr)
  1351. } else if got := clientErr.Error(); got != test.expectedClientError {
  1352. t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
  1353. } else {
  1354. test.verify(t, i, &result.cs)
  1355. }
  1356. } else if len(test.expectedClientError) > 0 {
  1357. t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
  1358. } else if err := result.err; err != nil {
  1359. t.Errorf("#%d: server error: %v", i, err)
  1360. } else {
  1361. test.verify(t, i, &result.cs)
  1362. }
  1363. }
  1364. }
  1365. func TestRSAPSSKeyError(t *testing.T) {
  1366. // crypto/tls does not support the rsa_pss_pss_xxx SignatureSchemes. If support for
  1367. // public keys with OID RSASSA-PSS is added to crypto/x509, they will be misused with
  1368. // the rsa_pss_rsae_xxx SignatureSchemes. Assert that RSASSA-PSS certificates don't
  1369. // parse, or that they don't carry *rsa.PublicKey keys.
  1370. b, _ := pem.Decode([]byte(`
  1371. -----BEGIN CERTIFICATE-----
  1372. MIIDZTCCAhygAwIBAgIUCF2x0FyTgZG0CC9QTDjGWkB5vgEwPgYJKoZIhvcNAQEK
  1373. MDGgDTALBglghkgBZQMEAgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQC
  1374. AgDeMBIxEDAOBgNVBAMMB1JTQS1QU1MwHhcNMTgwNjI3MjI0NDM2WhcNMTgwNzI3
  1375. MjI0NDM2WjASMRAwDgYDVQQDDAdSU0EtUFNTMIIBIDALBgkqhkiG9w0BAQoDggEP
  1376. ADCCAQoCggEBANxDm0f76JdI06YzsjB3AmmjIYkwUEGxePlafmIASFjDZl/elD0Z
  1377. /a7xLX468b0qGxLS5al7XCcEprSdsDR6DF5L520+pCbpfLyPOjuOvGmk9KzVX4x5
  1378. b05YXYuXdsQ0Kjxcx2i3jjCday6scIhMJVgBZxTEyMj1thPQM14SHzKCd/m6HmCL
  1379. QmswpH2yMAAcBRWzRpp/vdH5DeOJEB3aelq7094no731mrLUCHRiZ1htq8BDB3ou
  1380. czwqgwspbqZ4dnMXl2MvfySQ5wJUxQwILbiuAKO2lVVPUbFXHE9pgtznNoPvKwQT
  1381. JNcX8ee8WIZc2SEGzofjk3NpjR+2ADB2u3sCAwEAAaNTMFEwHQYDVR0OBBYEFNEz
  1382. AdyJ2f+fU+vSCS6QzohnOnprMB8GA1UdIwQYMBaAFNEzAdyJ2f+fU+vSCS6Qzohn
  1383. OnprMA8GA1UdEwEB/wQFMAMBAf8wPgYJKoZIhvcNAQEKMDGgDTALBglghkgBZQME
  1384. AgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQCAgDeA4IBAQCjEdrR5aab
  1385. sZmCwrMeKidXgfkmWvfuLDE+TCbaqDZp7BMWcMQXT9O0UoUT5kqgKj2ARm2pEW0Z
  1386. H3Z1vj3bbds72qcDIJXp+l0fekyLGeCrX/CbgnMZXEP7+/+P416p34ChR1Wz4dU1
  1387. KD3gdsUuTKKeMUog3plxlxQDhRQmiL25ygH1LmjLd6dtIt0GVRGr8lj3euVeprqZ
  1388. bZ3Uq5eLfsn8oPgfC57gpO6yiN+UURRTlK3bgYvLh4VWB3XXk9UaQZ7Mq1tpXjoD
  1389. HYFybkWzibkZp4WRo+Fa28rirH+/wHt0vfeN7UCceURZEx4JaxIIfe4ku7uDRhJi
  1390. RwBA9Xk1KBNF
  1391. -----END CERTIFICATE-----`))
  1392. if b == nil {
  1393. t.Fatal("Failed to decode certificate")
  1394. }
  1395. cert, err := x509.ParseCertificate(b.Bytes)
  1396. if err != nil {
  1397. return
  1398. }
  1399. if _, ok := cert.PublicKey.(*rsa.PublicKey); ok {
  1400. t.Error("A RSA-PSS certificate was parsed like a PKCS1 one, and it will be mistakenly used with rsa_pss_rsae_xxx signature algorithms")
  1401. }
  1402. }