Alternative TLS implementation in Go
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

870 lignes
22 KiB

  1. // Copyright 2010 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // TLS low level connection and record layer
  5. package tls
  6. import (
  7. "bytes"
  8. "crypto/cipher"
  9. "crypto/subtle"
  10. "crypto/x509"
  11. "errors"
  12. "io"
  13. "net"
  14. "sync"
  15. "time"
  16. )
// A Conn represents a secured connection.
// It implements the net.Conn interface.
type Conn struct {
	// constant
	conn     net.Conn // underlying connection that records are read from and written to
	isClient bool     // Handshake runs clientHandshake when set, serverHandshake otherwise

	// constant after handshake; protected by handshakeMutex
	handshakeMutex    sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
	vers              uint16     // TLS version
	haveVers          bool       // version has been negotiated
	config            *Config    // configuration passed to constructor
	handshakeComplete bool       // gates application data in Read/Write and the fields reported by ConnectionState
	didResume         bool       // whether this connection was a session resumption
	cipherSuite       uint16     // reported as ConnectionState.CipherSuite
	ocspResponse      []byte     // stapled OCSP response
	peerCertificates  []*x509.Certificate
	// verifiedChains contains the certificate chains that we built, as
	// opposed to the ones presented by the server.
	verifiedChains [][]*x509.Certificate
	// serverName contains the server name indicated by the client, if any.
	serverName string

	clientProtocol         string // reported as ConnectionState.NegotiatedProtocol
	clientProtocolFallback bool   // ConnectionState.NegotiatedProtocolIsMutual is the negation of this

	// first permanent error
	connErr

	// input/output
	in, out  halfConn     // in.Mutex < out.Mutex
	rawInput *block       // raw input, right off the wire
	input    *block       // application data waiting to be read
	hand     bytes.Buffer // handshake data waiting to be read

	tmp [16]byte // scratch space; its first two bytes hold outgoing alert records
}
// connErr stores the first permanent error seen on a connection.
// It is embedded in Conn, which gains its setError and error methods.
type connErr struct {
	mu    sync.Mutex // guards value
	value error      // first error recorded by setError; nil until then
}
  53. func (e *connErr) setError(err error) error {
  54. e.mu.Lock()
  55. defer e.mu.Unlock()
  56. if e.value == nil {
  57. e.value = err
  58. }
  59. return err
  60. }
  61. func (e *connErr) error() error {
  62. e.mu.Lock()
  63. defer e.mu.Unlock()
  64. return e.value
  65. }
  66. // Access to net.Conn methods.
  67. // Cannot just embed net.Conn because that would
  68. // export the struct field too.
// LocalAddr returns the local network address.
// It delegates directly to the underlying connection.
func (c *Conn) LocalAddr() net.Addr {
	return c.conn.LocalAddr()
}
// RemoteAddr returns the remote network address.
// It delegates directly to the underlying connection.
func (c *Conn) RemoteAddr() net.Addr {
	return c.conn.RemoteAddr()
}
// SetDeadline sets the read and write deadlines associated with the connection.
// A zero value for t means Read and Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
// The deadline is applied to the underlying connection, and its error is returned unchanged.
func (c *Conn) SetDeadline(t time.Time) error {
	return c.conn.SetDeadline(t)
}
// SetReadDeadline sets the read deadline on the underlying connection.
// A zero value for t means Read will not time out.
// The underlying connection's error is returned unchanged.
func (c *Conn) SetReadDeadline(t time.Time) error {
	return c.conn.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline on the underlying connection.
// A zero value for t means Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
// The underlying connection's error is returned unchanged.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	return c.conn.SetWriteDeadline(t)
}
// A halfConn represents one direction of the record layer
// connection, either sending or receiving.
type halfConn struct {
	sync.Mutex
	version uint16      // protocol version
	cipher  interface{} // cipher algorithm; decrypt/encrypt accept cipher.Stream or cipher.BlockMode and panic on anything else
	mac     macFunction // MAC in use; nil means records carry no MAC
	seq     [8]byte     // 64-bit sequence number, big-endian (incSeq carries from index 7 upward)
	bfree   *block      // list of free blocks

	nextCipher interface{} // next encryption state
	nextMac    macFunction // next MAC algorithm

	// used to save allocating a new buffer for each MAC.
	inDigestBuf, outDigestBuf []byte
}
  108. // prepareCipherSpec sets the encryption and MAC states
  109. // that a subsequent changeCipherSpec will use.
  110. func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) {
  111. hc.version = version
  112. hc.nextCipher = cipher
  113. hc.nextMac = mac
  114. }
  115. // changeCipherSpec changes the encryption and MAC states
  116. // to the ones previously passed to prepareCipherSpec.
  117. func (hc *halfConn) changeCipherSpec() error {
  118. if hc.nextCipher == nil {
  119. return alertInternalError
  120. }
  121. hc.cipher = hc.nextCipher
  122. hc.mac = hc.nextMac
  123. hc.nextCipher = nil
  124. hc.nextMac = nil
  125. return nil
  126. }
  127. // incSeq increments the sequence number.
  128. func (hc *halfConn) incSeq() {
  129. for i := 7; i >= 0; i-- {
  130. hc.seq[i]++
  131. if hc.seq[i] != 0 {
  132. return
  133. }
  134. }
  135. // Not allowed to let sequence number wrap.
  136. // Instead, must renegotiate before it does.
  137. // Not likely enough to bother.
  138. panic("TLS: sequence number wraparound")
  139. }
  140. // resetSeq resets the sequence number to zero.
  141. func (hc *halfConn) resetSeq() {
  142. for i := range hc.seq {
  143. hc.seq[i] = 0
  144. }
  145. }
  146. // removePadding returns an unpadded slice, in constant time, which is a prefix
  147. // of the input. It also returns a byte which is equal to 255 if the padding
  148. // was valid and 0 otherwise. See RFC 2246, section 6.2.3.2
  149. func removePadding(payload []byte) ([]byte, byte) {
  150. if len(payload) < 1 {
  151. return payload, 0
  152. }
  153. paddingLen := payload[len(payload)-1]
  154. t := uint(len(payload)-1) - uint(paddingLen)
  155. // if len(payload) >= (paddingLen - 1) then the MSB of t is zero
  156. good := byte(int32(^t) >> 31)
  157. toCheck := 255 // the maximum possible padding length
  158. // The length of the padded data is public, so we can use an if here
  159. if toCheck+1 > len(payload) {
  160. toCheck = len(payload) - 1
  161. }
  162. for i := 0; i < toCheck; i++ {
  163. t := uint(paddingLen) - uint(i)
  164. // if i <= paddingLen then the MSB of t is zero
  165. mask := byte(int32(^t) >> 31)
  166. b := payload[len(payload)-1-i]
  167. good &^= mask&paddingLen ^ mask&b
  168. }
  169. // We AND together the bits of good and replicate the result across
  170. // all the bits.
  171. good &= good << 4
  172. good &= good << 2
  173. good &= good << 1
  174. good = uint8(int8(good) >> 7)
  175. toRemove := good&paddingLen + 1
  176. return payload[:len(payload)-int(toRemove)], good
  177. }
  178. // removePaddingSSL30 is a replacement for removePadding in the case that the
  179. // protocol version is SSLv3. In this version, the contents of the padding
  180. // are random and cannot be checked.
  181. func removePaddingSSL30(payload []byte) ([]byte, byte) {
  182. if len(payload) < 1 {
  183. return payload, 0
  184. }
  185. paddingLen := int(payload[len(payload)-1]) + 1
  186. if paddingLen > len(payload) {
  187. return payload, 0
  188. }
  189. return payload[:len(payload)-paddingLen], 255
  190. }
  191. func roundUp(a, b int) int {
  192. return a + (b-a%b)%b
  193. }
// decrypt checks and strips the mac and decrypts the data in b.
//
// b.data holds a complete record: a header of recordHeaderLen bytes
// followed by the payload. The payload is decrypted in place; padding and
// MAC are then removed by shrinking b and rewriting the two length bytes
// in the header. It returns (true, 0) on success and
// (false, alertBadRecordMAC) when the record is malformed or the MAC or
// padding check fails. It panics if hc.cipher is neither a cipher.Stream
// nor a cipher.BlockMode.
func (hc *halfConn) decrypt(b *block) (bool, alert) {
	// pull out payload
	payload := b.data[recordHeaderLen:]

	macSize := 0
	if hc.mac != nil {
		macSize = hc.mac.Size()
	}

	// Stays 255 unless a block cipher's padding check fails.
	paddingGood := byte(255)

	// decrypt
	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case cipher.BlockMode:
			blockSize := c.BlockSize()

			// The ciphertext must be a whole number of blocks and
			// large enough to hold the MAC plus one padding byte.
			if len(payload)%blockSize != 0 || len(payload) < roundUp(macSize+1, blockSize) {
				return false, alertBadRecordMAC
			}

			c.CryptBlocks(payload, payload)
			if hc.version == versionSSL30 {
				payload, paddingGood = removePaddingSSL30(payload)
			} else {
				payload, paddingGood = removePadding(payload)
			}
			b.resize(recordHeaderLen + len(payload))

			// note that we still have a timing side-channel in the
			// MAC check, below. An attacker can align the record
			// so that a correct padding will cause one less hash
			// block to be calculated. Then they can iteratively
			// decrypt a record by breaking each byte. See
			// "Password Interception in a SSL/TLS Channel", Brice
			// Canvel et al.
			//
			// However, our behavior matches OpenSSL, so we leak
			// only as much as they do.
		default:
			panic("unknown cipher type")
		}
	}

	// check, strip mac
	if hc.mac != nil {
		if len(payload) < macSize {
			return false, alertBadRecordMAC
		}

		// strip mac off payload, b.data
		n := len(payload) - macSize
		// Rewrite the header length so the MAC is computed over the
		// plaintext length.
		b.data[3] = byte(n >> 8)
		b.data[4] = byte(n)
		b.resize(recordHeaderLen + n)
		// payload still refers to the bytes past the new end of b.data.
		remoteMAC := payload[n:]
		localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data)
		hc.incSeq()

		// A bad padding is reported with the same alert as a bad MAC,
		// and only after the MAC has been computed.
		if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
			return false, alertBadRecordMAC
		}
		// Keep the digest slice so the next MAC call can reuse it.
		hc.inDigestBuf = localMAC
	}

	return true, 0
}
  254. // padToBlockSize calculates the needed padding block, if any, for a payload.
  255. // On exit, prefix aliases payload and extends to the end of the last full
  256. // block of payload. finalBlock is a fresh slice which contains the contents of
  257. // any suffix of payload as well as the needed padding to make finalBlock a
  258. // full block.
  259. func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) {
  260. overrun := len(payload) % blockSize
  261. paddingLen := blockSize - overrun
  262. prefix = payload[:len(payload)-overrun]
  263. finalBlock = make([]byte, blockSize)
  264. copy(finalBlock, payload[len(payload)-overrun:])
  265. for i := overrun; i < blockSize; i++ {
  266. finalBlock[i] = byte(paddingLen - 1)
  267. }
  268. return
  269. }
// encrypt encrypts and macs the data in b.
//
// b.data holds a record header of recordHeaderLen bytes followed by the
// plaintext. The MAC is appended first, then the payload (plaintext plus
// MAC) is encrypted, and finally the two length bytes of the header are
// rewritten to cover the MAC and any block padding. It always returns
// (true, 0), and panics if hc.cipher is neither a cipher.Stream nor a
// cipher.BlockMode.
func (hc *halfConn) encrypt(b *block) (bool, alert) {
	// mac
	if hc.mac != nil {
		mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data)
		hc.incSeq()

		n := len(b.data)
		b.resize(n + len(mac))
		copy(b.data[n:], mac)
		// Keep the digest slice so the next MAC call can reuse it.
		hc.outDigestBuf = mac
	}

	payload := b.data[recordHeaderLen:]

	// encrypt
	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case cipher.BlockMode:
			// The complete blocks are encrypted from prefix and the
			// padded last block from its own buffer, both written into
			// b after it has grown to the padded size.
			prefix, finalBlock := padToBlockSize(payload, c.BlockSize())
			b.resize(recordHeaderLen + len(prefix) + len(finalBlock))
			c.CryptBlocks(b.data[recordHeaderLen:], prefix)
			c.CryptBlocks(b.data[recordHeaderLen+len(prefix):], finalBlock)
		default:
			panic("unknown cipher type")
		}
	}

	// update length to include MAC and any block padding needed.
	n := len(b.data) - recordHeaderLen
	b.data[3] = byte(n >> 8)
	b.data[4] = byte(n)

	return true, 0
}
// A block is a simple data buffer.
// Blocks can be chained through link, which is how halfConn keeps its
// free list.
type block struct {
	data []byte // buffered bytes
	off  int    // index for Read
	link *block // next block on a halfConn free list
}
  308. // resize resizes block to be n bytes, growing if necessary.
  309. func (b *block) resize(n int) {
  310. if n > cap(b.data) {
  311. b.reserve(n)
  312. }
  313. b.data = b.data[0:n]
  314. }
  315. // reserve makes sure that block contains a capacity of at least n bytes.
  316. func (b *block) reserve(n int) {
  317. if cap(b.data) >= n {
  318. return
  319. }
  320. m := cap(b.data)
  321. if m == 0 {
  322. m = 1024
  323. }
  324. for m < n {
  325. m *= 2
  326. }
  327. data := make([]byte, len(b.data), m)
  328. copy(data, b.data)
  329. b.data = data
  330. }
  331. // readFromUntil reads from r into b until b contains at least n bytes
  332. // or else returns an error.
  333. func (b *block) readFromUntil(r io.Reader, n int) error {
  334. // quick case
  335. if len(b.data) >= n {
  336. return nil
  337. }
  338. // read until have enough.
  339. b.reserve(n)
  340. for {
  341. m, err := r.Read(b.data[len(b.data):cap(b.data)])
  342. b.data = b.data[0 : len(b.data)+m]
  343. if len(b.data) >= n {
  344. break
  345. }
  346. if err != nil {
  347. return err
  348. }
  349. }
  350. return nil
  351. }
  352. func (b *block) Read(p []byte) (n int, err error) {
  353. n = copy(p, b.data[b.off:])
  354. b.off += n
  355. return
  356. }
  357. // newBlock allocates a new block, from hc's free list if possible.
  358. func (hc *halfConn) newBlock() *block {
  359. b := hc.bfree
  360. if b == nil {
  361. return new(block)
  362. }
  363. hc.bfree = b.link
  364. b.link = nil
  365. b.resize(0)
  366. return b
  367. }
  368. // freeBlock returns a block to hc's free list.
  369. // The protocol is such that each side only has a block or two on
  370. // its free list at a time, so there's no need to worry about
  371. // trimming the list, etc.
  372. func (hc *halfConn) freeBlock(b *block) {
  373. b.link = hc.bfree
  374. hc.bfree = b
  375. }
  376. // splitBlock splits a block after the first n bytes,
  377. // returning a block with those n bytes and a
  378. // block with the remainder. the latter may be nil.
  379. func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
  380. if len(b.data) <= n {
  381. return b, nil
  382. }
  383. bb := hc.newBlock()
  384. bb.resize(len(b.data) - n)
  385. copy(bb.data, b.data[n:])
  386. b.data = b.data[0:n]
  387. return b, bb
  388. }
// readRecord reads the next TLS record from the connection
// and updates the record layer state.
// c.in.Mutex <= L; c.input == nil.
//
// want is the record type the caller is waiting for. Application data is
// stored in c.input and handshake data is appended to c.hand. A
// ChangeCipherSpec record switches the inbound cipher state, and warning
// alerts are skipped by reading the next record. Apart from a few early
// returns, the result is the connection's permanent error as reported by
// c.error().
func (c *Conn) readRecord(want recordType) error {
	// Caller must be in sync with connection:
	// handshake data if handshake not yet completed,
	// else application data. (We don't support renegotiation.)
	switch want {
	default:
		return c.sendAlert(alertInternalError)
	case recordTypeHandshake, recordTypeChangeCipherSpec:
		if c.handshakeComplete {
			return c.sendAlert(alertInternalError)
		}
	case recordTypeApplicationData:
		if !c.handshakeComplete {
			return c.sendAlert(alertInternalError)
		}
	}

Again:
	// rawInput keeps bytes read past the end of the previous record, so
	// it is only allocated when nothing is left over.
	if c.rawInput == nil {
		c.rawInput = c.in.newBlock()
	}
	b := c.rawInput

	// Read header, payload.
	if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// RFC suggests that EOF without an alertCloseNotify is
		// an error, but popular web sites seem to do this,
		// so we can't make it an error.
		// if err == io.EOF {
		// 	err = io.ErrUnexpectedEOF
		// }
		// Temporary network errors are returned without being recorded
		// as the permanent error.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.setError(err)
		}
		return err
	}
	typ := recordType(b.data[0])

	// No valid TLS record has a type of 0x80, however SSLv2 handshakes
	// start with a uint16 length where the MSB is set and the first record
	// is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
	// an SSLv2 client.
	if want == recordTypeHandshake && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return errors.New("tls: unsupported SSLv2 handshake received")
	}

	// Header layout: type (1 byte), version (2 bytes), length (2 bytes).
	vers := uint16(b.data[1])<<8 | uint16(b.data[2])
	n := int(b.data[3])<<8 | int(b.data[4])
	if c.haveVers && vers != c.vers {
		return c.sendAlert(alertProtocolVersion)
	}
	if n > maxCiphertext {
		return c.sendAlert(alertRecordOverflow)
	}
	if !c.haveVers {
		// First message, be extra suspicious:
		// this might not be a TLS client.
		// Bail out before reading a full 'body', if possible.
		// The current max version is 3.1.
		// If the version is >= 16.0, it's probably not real.
		// Similarly, a clientHello message encodes in
		// well under a kilobyte.  If the length is >= 12 kB,
		// it's probably not real.
		if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
			return c.sendAlert(alertUnexpectedMessage)
		}
	}
	if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		// The header promised n more bytes, so EOF here means the
		// record was cut short.
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.setError(err)
		}
		return err
	}

	// Process message.
	// Anything read beyond this record stays in c.rawInput for next time.
	b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
	b.off = recordHeaderLen
	if ok, err := c.in.decrypt(b); !ok {
		return c.sendAlert(err)
	}
	data := b.data[b.off:]
	if len(data) > maxPlaintext {
		c.sendAlert(alertRecordOverflow)
		c.in.freeBlock(b)
		return c.error()
	}

	switch typ {
	default:
		c.sendAlert(alertUnexpectedMessage)

	case recordTypeAlert:
		// An alert is two bytes: level, then description.
		if len(data) != 2 {
			c.sendAlert(alertUnexpectedMessage)
			break
		}
		if alert(data[1]) == alertCloseNotify {
			c.setError(io.EOF)
			break
		}
		switch data[0] {
		case alertLevelWarning:
			// drop on the floor
			c.in.freeBlock(b)
			goto Again
		case alertLevelError:
			c.setError(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			c.sendAlert(alertUnexpectedMessage)
		}

	case recordTypeChangeCipherSpec:
		// The record must be the single byte 1.
		if typ != want || len(data) != 1 || data[0] != 1 {
			c.sendAlert(alertUnexpectedMessage)
			break
		}
		err := c.in.changeCipherSpec()
		if err != nil {
			c.sendAlert(err.(alert))
		}

	case recordTypeApplicationData:
		if typ != want {
			c.sendAlert(alertUnexpectedMessage)
			break
		}
		// Ownership of b passes to c.input; clearing b keeps it off
		// the free list below.
		c.input = b
		b = nil

	case recordTypeHandshake:
		// TODO(rsc): Should at least pick off connection close.
		if typ != want {
			return c.sendAlert(alertNoRenegotiation)
		}
		c.hand.Write(data)
	}

	if b != nil {
		c.in.freeBlock(b)
	}
	return c.error()
}
// sendAlertLocked sends a TLS alert message.
// c.out.Mutex <= L.
//
// alertNoRenegotiation and alertCloseNotify are sent at warning level,
// every other alert at error level. Except for alertCloseNotify, the
// alert is also recorded as the connection's permanent error, wrapped in
// a *net.OpError with Op "local error", and that error is returned.
func (c *Conn) sendAlertLocked(err alert) error {
	switch err {
	case alertNoRenegotiation, alertCloseNotify:
		c.tmp[0] = alertLevelWarning
	default:
		c.tmp[0] = alertLevelError
	}
	c.tmp[1] = byte(err)
	// The result of the write is not checked.
	c.writeRecord(recordTypeAlert, c.tmp[0:2])
	// closeNotify is a special case in that it isn't an error:
	if err != alertCloseNotify {
		return c.setError(&net.OpError{Op: "local error", Err: err})
	}
	return nil
}
// sendAlert sends a TLS alert message.
// L < c.out.Mutex.
//
// It acquires c.out's lock for the duration of the call and delegates to
// sendAlertLocked, so it must not be called while that lock is held.
func (c *Conn) sendAlert(err alert) error {
	c.out.Lock()
	defer c.out.Unlock()
	return c.sendAlertLocked(err)
}
  551. // writeRecord writes a TLS record with the given type and payload
  552. // to the connection and updates the record layer state.
  553. // c.out.Mutex <= L.
  554. func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
  555. b := c.out.newBlock()
  556. for len(data) > 0 {
  557. m := len(data)
  558. if m > maxPlaintext {
  559. m = maxPlaintext
  560. }
  561. b.resize(recordHeaderLen + m)
  562. b.data[0] = byte(typ)
  563. vers := c.vers
  564. if vers == 0 {
  565. vers = maxVersion
  566. }
  567. b.data[1] = byte(vers >> 8)
  568. b.data[2] = byte(vers)
  569. b.data[3] = byte(m >> 8)
  570. b.data[4] = byte(m)
  571. copy(b.data[recordHeaderLen:], data)
  572. c.out.encrypt(b)
  573. _, err = c.conn.Write(b.data)
  574. if err != nil {
  575. break
  576. }
  577. n += m
  578. data = data[m:]
  579. }
  580. c.out.freeBlock(b)
  581. if typ == recordTypeChangeCipherSpec {
  582. err = c.out.changeCipherSpec()
  583. if err != nil {
  584. // Cannot call sendAlert directly,
  585. // because we already hold c.out.Mutex.
  586. c.tmp[0] = alertLevelError
  587. c.tmp[1] = byte(err.(alert))
  588. c.writeRecord(recordTypeAlert, c.tmp[0:2])
  589. return n, c.setError(&net.OpError{Op: "local error", Err: err})
  590. }
  591. }
  592. return
  593. }
// readHandshake reads the next handshake message from
// the record layer.
// c.in.Mutex < L; c.out.Mutex < L.
//
// Records are read until c.hand holds one complete message: a 4-byte
// header (type byte plus 24-bit length) and the body. The message is then
// removed from c.hand and unmarshalled into the struct matching its type.
// An unknown type or a failed unmarshal sends alertUnexpectedMessage and
// returns that alert as the error.
func (c *Conn) readHandshake() (interface{}, error) {
	// Wait for the 4-byte message header.
	for c.hand.Len() < 4 {
		if err := c.error(); err != nil {
			return nil, err
		}
		if err := c.readRecord(recordTypeHandshake); err != nil {
			return nil, err
		}
	}

	// Bytes only peeks at the buffer; the header is consumed together
	// with the body by Next below.
	data := c.hand.Bytes()
	n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
	if n > maxHandshake {
		c.sendAlert(alertInternalError)
		return nil, c.error()
	}
	// Wait for the message body.
	for c.hand.Len() < 4+n {
		if err := c.error(); err != nil {
			return nil, err
		}
		if err := c.readRecord(recordTypeHandshake); err != nil {
			return nil, err
		}
	}
	data = c.hand.Next(4 + n)
	var m handshakeMessage
	switch data[0] {
	case typeClientHello:
		m = new(clientHelloMsg)
	case typeServerHello:
		m = new(serverHelloMsg)
	case typeCertificate:
		m = new(certificateMsg)
	case typeCertificateRequest:
		m = new(certificateRequestMsg)
	case typeCertificateStatus:
		m = new(certificateStatusMsg)
	case typeServerKeyExchange:
		m = new(serverKeyExchangeMsg)
	case typeServerHelloDone:
		m = new(serverHelloDoneMsg)
	case typeClientKeyExchange:
		m = new(clientKeyExchangeMsg)
	case typeCertificateVerify:
		m = new(certificateVerifyMsg)
	case typeNextProtocol:
		m = new(nextProtoMsg)
	case typeFinished:
		m = new(finishedMsg)
	default:
		c.sendAlert(alertUnexpectedMessage)
		return nil, alertUnexpectedMessage
	}

	// The handshake message unmarshallers
	// expect to be able to keep references to data,
	// so pass in a fresh copy that won't be overwritten.
	data = append([]byte(nil), data...)

	if !m.unmarshal(data) {
		c.sendAlert(alertUnexpectedMessage)
		return nil, alertUnexpectedMessage
	}
	return m, nil
}
  659. // Write writes data to the connection.
  660. func (c *Conn) Write(b []byte) (int, error) {
  661. if err := c.error(); err != nil {
  662. return 0, err
  663. }
  664. if err := c.Handshake(); err != nil {
  665. return 0, c.setError(err)
  666. }
  667. c.out.Lock()
  668. defer c.out.Unlock()
  669. if !c.handshakeComplete {
  670. return 0, alertInternalError
  671. }
  672. n, err := c.writeRecord(recordTypeApplicationData, b)
  673. return n, c.setError(err)
  674. }
  675. // Read can be made to time out and return a net.Error with Timeout() == true
  676. // after a fixed time limit; see SetDeadline and SetReadDeadline.
  677. func (c *Conn) Read(b []byte) (n int, err error) {
  678. if err = c.Handshake(); err != nil {
  679. return
  680. }
  681. c.in.Lock()
  682. defer c.in.Unlock()
  683. for c.input == nil && c.error() == nil {
  684. if err := c.readRecord(recordTypeApplicationData); err != nil {
  685. // Soft error, like EAGAIN
  686. return 0, err
  687. }
  688. }
  689. if err := c.error(); err != nil {
  690. return 0, err
  691. }
  692. n, err = c.input.Read(b)
  693. if c.input.off >= len(c.input.data) {
  694. c.in.freeBlock(c.input)
  695. c.input = nil
  696. }
  697. return n, nil
  698. }
  699. // Close closes the connection.
  700. func (c *Conn) Close() error {
  701. var alertErr error
  702. c.handshakeMutex.Lock()
  703. defer c.handshakeMutex.Unlock()
  704. if c.handshakeComplete {
  705. alertErr = c.sendAlert(alertCloseNotify)
  706. }
  707. if err := c.conn.Close(); err != nil {
  708. return err
  709. }
  710. return alertErr
  711. }
  712. // Handshake runs the client or server handshake
  713. // protocol if it has not yet been run.
  714. // Most uses of this package need not call Handshake
  715. // explicitly: the first Read or Write will call it automatically.
  716. func (c *Conn) Handshake() error {
  717. c.handshakeMutex.Lock()
  718. defer c.handshakeMutex.Unlock()
  719. if err := c.error(); err != nil {
  720. return err
  721. }
  722. if c.handshakeComplete {
  723. return nil
  724. }
  725. if c.isClient {
  726. return c.clientHandshake()
  727. }
  728. return c.serverHandshake()
  729. }
  730. // ConnectionState returns basic TLS details about the connection.
  731. func (c *Conn) ConnectionState() ConnectionState {
  732. c.handshakeMutex.Lock()
  733. defer c.handshakeMutex.Unlock()
  734. var state ConnectionState
  735. state.HandshakeComplete = c.handshakeComplete
  736. if c.handshakeComplete {
  737. state.NegotiatedProtocol = c.clientProtocol
  738. state.DidResume = c.didResume
  739. state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
  740. state.CipherSuite = c.cipherSuite
  741. state.PeerCertificates = c.peerCertificates
  742. state.VerifiedChains = c.verifiedChains
  743. state.ServerName = c.serverName
  744. }
  745. return state
  746. }
// OCSPResponse returns the stapled OCSP response from the TLS server, if
// any. (Only valid for client connections.)
// The stored slice is returned as it is, without copying, and is read
// while holding handshakeMutex.
func (c *Conn) OCSPResponse() []byte {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	return c.ocspResponse
}
  754. // VerifyHostname checks that the peer certificate chain is valid for
  755. // connecting to host. If so, it returns nil; if not, it returns an error
  756. // describing the problem.
  757. func (c *Conn) VerifyHostname(host string) error {
  758. c.handshakeMutex.Lock()
  759. defer c.handshakeMutex.Unlock()
  760. if !c.isClient {
  761. return errors.New("VerifyHostname called on TLS server connection")
  762. }
  763. if !c.handshakeComplete {
  764. return errors.New("TLS handshake has not yet been performed")
  765. }
  766. return c.peerCertificates[0].VerifyHostname(host)
  767. }