Alternative TLS implementation in Go
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

688 lines
16 KiB

  1. // TLS low level connection and record layer
  2. package tls
  3. import (
  4. "bytes"
  5. "crypto/subtle"
  6. "crypto/x509"
  7. "hash"
  8. "io"
  9. "net"
  10. "os"
  11. "sync"
  12. )
  13. // A Conn represents a secured connection.
  14. // It implements the net.Conn interface.
  15. type Conn struct {
  16. // constant
  17. conn net.Conn
  18. isClient bool
  19. // constant after handshake; protected by handshakeMutex
  20. handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
  21. vers uint16 // TLS version
  22. haveVers bool // version has been negotiated
  23. config *Config // configuration passed to constructor
  24. handshakeComplete bool
  25. cipherSuite uint16
  26. ocspResponse []byte // stapled OCSP response
  27. peerCertificates []*x509.Certificate
  28. clientProtocol string
  29. // first permanent error
  30. errMutex sync.Mutex
  31. err os.Error
  32. // input/output
  33. in, out halfConn // in.Mutex < out.Mutex
  34. rawInput *block // raw input, right off the wire
  35. input *block // application data waiting to be read
  36. hand bytes.Buffer // handshake data waiting to be read
  37. tmp [16]byte
  38. }
  39. func (c *Conn) setError(err os.Error) os.Error {
  40. c.errMutex.Lock()
  41. defer c.errMutex.Unlock()
  42. if c.err == nil {
  43. c.err = err
  44. }
  45. return err
  46. }
  47. func (c *Conn) error() os.Error {
  48. c.errMutex.Lock()
  49. defer c.errMutex.Unlock()
  50. return c.err
  51. }
  52. // Access to net.Conn methods.
  53. // Cannot just embed net.Conn because that would
  54. // export the struct field too.
  55. // LocalAddr returns the local network address.
  56. func (c *Conn) LocalAddr() net.Addr {
  57. return c.conn.LocalAddr()
  58. }
  59. // RemoteAddr returns the remote network address.
  60. func (c *Conn) RemoteAddr() net.Addr {
  61. return c.conn.RemoteAddr()
  62. }
  63. // SetTimeout sets the read deadline associated with the connection.
  64. // There is no write deadline.
  65. func (c *Conn) SetTimeout(nsec int64) os.Error {
  66. return c.conn.SetTimeout(nsec)
  67. }
  68. // SetReadTimeout sets the time (in nanoseconds) that
  69. // Read will wait for data before returning os.EAGAIN.
  70. // Setting nsec == 0 (the default) disables the deadline.
  71. func (c *Conn) SetReadTimeout(nsec int64) os.Error {
  72. return c.conn.SetReadTimeout(nsec)
  73. }
  74. // SetWriteTimeout exists to satisfy the net.Conn interface
  75. // but is not implemented by TLS. It always returns an error.
  76. func (c *Conn) SetWriteTimeout(nsec int64) os.Error {
  77. return os.NewError("TLS does not support SetWriteTimeout")
  78. }
  79. // A halfConn represents one direction of the record layer
  80. // connection, either sending or receiving.
  81. type halfConn struct {
  82. sync.Mutex
  83. crypt encryptor // encryption state
  84. mac hash.Hash // MAC algorithm
  85. seq [8]byte // 64-bit sequence number
  86. bfree *block // list of free blocks
  87. nextCrypt encryptor // next encryption state
  88. nextMac hash.Hash // next MAC algorithm
  89. }
  90. // prepareCipherSpec sets the encryption and MAC states
  91. // that a subsequent changeCipherSpec will use.
  92. func (hc *halfConn) prepareCipherSpec(crypt encryptor, mac hash.Hash) {
  93. hc.nextCrypt = crypt
  94. hc.nextMac = mac
  95. }
  96. // changeCipherSpec changes the encryption and MAC states
  97. // to the ones previously passed to prepareCipherSpec.
  98. func (hc *halfConn) changeCipherSpec() os.Error {
  99. if hc.nextCrypt == nil {
  100. return alertInternalError
  101. }
  102. hc.crypt = hc.nextCrypt
  103. hc.mac = hc.nextMac
  104. hc.nextCrypt = nil
  105. hc.nextMac = nil
  106. return nil
  107. }
  108. // incSeq increments the sequence number.
  109. func (hc *halfConn) incSeq() {
  110. for i := 7; i >= 0; i-- {
  111. hc.seq[i]++
  112. if hc.seq[i] != 0 {
  113. return
  114. }
  115. }
  116. // Not allowed to let sequence number wrap.
  117. // Instead, must renegotiate before it does.
  118. // Not likely enough to bother.
  119. panic("TLS: sequence number wraparound")
  120. }
  121. // resetSeq resets the sequence number to zero.
  122. func (hc *halfConn) resetSeq() {
  123. for i := range hc.seq {
  124. hc.seq[i] = 0
  125. }
  126. }
  127. // decrypt checks and strips the mac and decrypts the data in b.
  128. func (hc *halfConn) decrypt(b *block) (bool, alert) {
  129. // pull out payload
  130. payload := b.data[recordHeaderLen:]
  131. // decrypt
  132. if hc.crypt != nil {
  133. hc.crypt.XORKeyStream(payload)
  134. }
  135. // check, strip mac
  136. if hc.mac != nil {
  137. if len(payload) < hc.mac.Size() {
  138. return false, alertBadRecordMAC
  139. }
  140. // strip mac off payload, b.data
  141. n := len(payload) - hc.mac.Size()
  142. b.data[3] = byte(n >> 8)
  143. b.data[4] = byte(n)
  144. b.data = b.data[0 : recordHeaderLen+n]
  145. remoteMAC := payload[n:]
  146. hc.mac.Reset()
  147. hc.mac.Write(hc.seq[0:])
  148. hc.incSeq()
  149. hc.mac.Write(b.data)
  150. if subtle.ConstantTimeCompare(hc.mac.Sum(), remoteMAC) != 1 {
  151. return false, alertBadRecordMAC
  152. }
  153. }
  154. return true, 0
  155. }
  156. // encrypt encrypts and macs the data in b.
  157. func (hc *halfConn) encrypt(b *block) (bool, alert) {
  158. // mac
  159. if hc.mac != nil {
  160. hc.mac.Reset()
  161. hc.mac.Write(hc.seq[0:])
  162. hc.incSeq()
  163. hc.mac.Write(b.data)
  164. mac := hc.mac.Sum()
  165. n := len(b.data)
  166. b.resize(n + len(mac))
  167. copy(b.data[n:], mac)
  168. // update length to include mac
  169. n = len(b.data) - recordHeaderLen
  170. b.data[3] = byte(n >> 8)
  171. b.data[4] = byte(n)
  172. }
  173. // encrypt
  174. if hc.crypt != nil {
  175. hc.crypt.XORKeyStream(b.data[recordHeaderLen:])
  176. }
  177. return true, 0
  178. }
  // A block is a simple byte buffer. off is the read position used by Read,
  // and link chains blocks together on a halfConn's free list.
  type block struct {
  	data []byte
  	off  int
  	link *block
  }

  // resize sets len(b.data) to n, growing the backing storage first if its
  // capacity is smaller than n.
  func (b *block) resize(n int) {
  	if cap(b.data) < n {
  		b.reserve(n)
  	}
  	b.data = b.data[:n]
  }

  // reserve ensures cap(b.data) >= n while keeping the current contents and
  // length. Capacity starts at 1024 and grows by doubling.
  func (b *block) reserve(n int) {
  	have := cap(b.data)
  	if have >= n {
  		return
  	}
  	newCap := have
  	if newCap == 0 {
  		newCap = 1024
  	}
  	for newCap < n {
  		newCap *= 2
  	}
  	grown := make([]byte, len(b.data), newCap)
  	copy(grown, b.data)
  	b.data = grown
  }
  208. // readFromUntil reads from r into b until b contains at least n bytes
  209. // or else returns an error.
  210. func (b *block) readFromUntil(r io.Reader, n int) os.Error {
  211. // quick case
  212. if len(b.data) >= n {
  213. return nil
  214. }
  215. // read until have enough.
  216. b.reserve(n)
  217. for {
  218. m, err := r.Read(b.data[len(b.data):cap(b.data)])
  219. b.data = b.data[0 : len(b.data)+m]
  220. if len(b.data) >= n {
  221. break
  222. }
  223. if err != nil {
  224. return err
  225. }
  226. }
  227. return nil
  228. }
  229. func (b *block) Read(p []byte) (n int, err os.Error) {
  230. n = copy(p, b.data[b.off:])
  231. b.off += n
  232. return
  233. }
  234. // newBlock allocates a new block, from hc's free list if possible.
  235. func (hc *halfConn) newBlock() *block {
  236. b := hc.bfree
  237. if b == nil {
  238. return new(block)
  239. }
  240. hc.bfree = b.link
  241. b.link = nil
  242. b.resize(0)
  243. return b
  244. }
  245. // freeBlock returns a block to hc's free list.
  246. // The protocol is such that each side only has a block or two on
  247. // its free list at a time, so there's no need to worry about
  248. // trimming the list, etc.
  249. func (hc *halfConn) freeBlock(b *block) {
  250. b.link = hc.bfree
  251. hc.bfree = b
  252. }
  253. // splitBlock splits a block after the first n bytes,
  254. // returning a block with those n bytes and a
  255. // block with the remaindec. the latter may be nil.
  256. func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
  257. if len(b.data) <= n {
  258. return b, nil
  259. }
  260. bb := hc.newBlock()
  261. bb.resize(len(b.data) - n)
  262. copy(bb.data, b.data[n:])
  263. b.data = b.data[0:n]
  264. return b, bb
  265. }
  266. // readRecord reads the next TLS record from the connection
  267. // and updates the record layer state.
  268. // c.in.Mutex <= L; c.input == nil.
  269. func (c *Conn) readRecord(want recordType) os.Error {
  270. // Caller must be in sync with connection:
  271. // handshake data if handshake not yet completed,
  272. // else application data. (We don't support renegotiation.)
  273. switch want {
  274. default:
  275. return c.sendAlert(alertInternalError)
  276. case recordTypeHandshake, recordTypeChangeCipherSpec:
  277. if c.handshakeComplete {
  278. return c.sendAlert(alertInternalError)
  279. }
  280. case recordTypeApplicationData:
  281. if !c.handshakeComplete {
  282. return c.sendAlert(alertInternalError)
  283. }
  284. }
  285. Again:
  286. if c.rawInput == nil {
  287. c.rawInput = c.in.newBlock()
  288. }
  289. b := c.rawInput
  290. // Read header, payload.
  291. if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
  292. // RFC suggests that EOF without an alertCloseNotify is
  293. // an error, but popular web sites seem to do this,
  294. // so we can't make it an error.
  295. // if err == os.EOF {
  296. // err = io.ErrUnexpectedEOF
  297. // }
  298. if e, ok := err.(net.Error); !ok || !e.Temporary() {
  299. c.setError(err)
  300. }
  301. return err
  302. }
  303. typ := recordType(b.data[0])
  304. vers := uint16(b.data[1])<<8 | uint16(b.data[2])
  305. n := int(b.data[3])<<8 | int(b.data[4])
  306. if c.haveVers && vers != c.vers {
  307. return c.sendAlert(alertProtocolVersion)
  308. }
  309. if n > maxCiphertext {
  310. return c.sendAlert(alertRecordOverflow)
  311. }
  312. if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
  313. if err == os.EOF {
  314. err = io.ErrUnexpectedEOF
  315. }
  316. if e, ok := err.(net.Error); !ok || !e.Temporary() {
  317. c.setError(err)
  318. }
  319. return err
  320. }
  321. // Process message.
  322. b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
  323. b.off = recordHeaderLen
  324. if ok, err := c.in.decrypt(b); !ok {
  325. return c.sendAlert(err)
  326. }
  327. data := b.data[b.off:]
  328. if len(data) > maxPlaintext {
  329. c.sendAlert(alertRecordOverflow)
  330. c.in.freeBlock(b)
  331. return c.error()
  332. }
  333. switch typ {
  334. default:
  335. c.sendAlert(alertUnexpectedMessage)
  336. case recordTypeAlert:
  337. if len(data) != 2 {
  338. c.sendAlert(alertUnexpectedMessage)
  339. break
  340. }
  341. if alert(data[1]) == alertCloseNotify {
  342. c.setError(os.EOF)
  343. break
  344. }
  345. switch data[0] {
  346. case alertLevelWarning:
  347. // drop on the floor
  348. c.in.freeBlock(b)
  349. goto Again
  350. case alertLevelError:
  351. c.setError(&net.OpError{Op: "remote error", Error: alert(data[1])})
  352. default:
  353. c.sendAlert(alertUnexpectedMessage)
  354. }
  355. case recordTypeChangeCipherSpec:
  356. if typ != want || len(data) != 1 || data[0] != 1 {
  357. c.sendAlert(alertUnexpectedMessage)
  358. break
  359. }
  360. err := c.in.changeCipherSpec()
  361. if err != nil {
  362. c.sendAlert(err.(alert))
  363. }
  364. case recordTypeApplicationData:
  365. if typ != want {
  366. c.sendAlert(alertUnexpectedMessage)
  367. break
  368. }
  369. c.input = b
  370. b = nil
  371. case recordTypeHandshake:
  372. // TODO(rsc): Should at least pick off connection close.
  373. if typ != want {
  374. return c.sendAlert(alertNoRenegotiation)
  375. }
  376. c.hand.Write(data)
  377. }
  378. if b != nil {
  379. c.in.freeBlock(b)
  380. }
  381. return c.error()
  382. }
  383. // sendAlert sends a TLS alert message.
  384. // c.out.Mutex <= L.
  385. func (c *Conn) sendAlertLocked(err alert) os.Error {
  386. c.tmp[0] = alertLevelError
  387. if err == alertNoRenegotiation {
  388. c.tmp[0] = alertLevelWarning
  389. }
  390. c.tmp[1] = byte(err)
  391. c.writeRecord(recordTypeAlert, c.tmp[0:2])
  392. // closeNotify is a special case in that it isn't an error:
  393. if err != alertCloseNotify {
  394. return c.setError(&net.OpError{Op: "local error", Error: err})
  395. }
  396. return nil
  397. }
  398. // sendAlert sends a TLS alert message.
  399. // L < c.out.Mutex.
  400. func (c *Conn) sendAlert(err alert) os.Error {
  401. c.out.Lock()
  402. defer c.out.Unlock()
  403. return c.sendAlertLocked(err)
  404. }
  405. // writeRecord writes a TLS record with the given type and payload
  406. // to the connection and updates the record layer state.
  407. // c.out.Mutex <= L.
  408. func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err os.Error) {
  409. b := c.out.newBlock()
  410. for len(data) > 0 {
  411. m := len(data)
  412. if m > maxPlaintext {
  413. m = maxPlaintext
  414. }
  415. b.resize(recordHeaderLen + m)
  416. b.data[0] = byte(typ)
  417. vers := c.vers
  418. if vers == 0 {
  419. vers = maxVersion
  420. }
  421. b.data[1] = byte(vers >> 8)
  422. b.data[2] = byte(vers)
  423. b.data[3] = byte(m >> 8)
  424. b.data[4] = byte(m)
  425. copy(b.data[recordHeaderLen:], data)
  426. c.out.encrypt(b)
  427. _, err = c.conn.Write(b.data)
  428. if err != nil {
  429. break
  430. }
  431. n += m
  432. data = data[m:]
  433. }
  434. c.out.freeBlock(b)
  435. if typ == recordTypeChangeCipherSpec {
  436. err = c.out.changeCipherSpec()
  437. if err != nil {
  438. // Cannot call sendAlert directly,
  439. // because we already hold c.out.Mutex.
  440. c.tmp[0] = alertLevelError
  441. c.tmp[1] = byte(err.(alert))
  442. c.writeRecord(recordTypeAlert, c.tmp[0:2])
  443. c.err = &net.OpError{Op: "local error", Error: err}
  444. return n, c.err
  445. }
  446. }
  447. return
  448. }
  449. // readHandshake reads the next handshake message from
  450. // the record layer.
  451. // c.in.Mutex < L; c.out.Mutex < L.
  452. func (c *Conn) readHandshake() (interface{}, os.Error) {
  453. for c.hand.Len() < 4 {
  454. if c.err != nil {
  455. return nil, c.err
  456. }
  457. c.readRecord(recordTypeHandshake)
  458. }
  459. data := c.hand.Bytes()
  460. n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  461. if n > maxHandshake {
  462. c.sendAlert(alertInternalError)
  463. return nil, c.err
  464. }
  465. for c.hand.Len() < 4+n {
  466. if c.err != nil {
  467. return nil, c.err
  468. }
  469. c.readRecord(recordTypeHandshake)
  470. }
  471. data = c.hand.Next(4 + n)
  472. var m handshakeMessage
  473. switch data[0] {
  474. case typeClientHello:
  475. m = new(clientHelloMsg)
  476. case typeServerHello:
  477. m = new(serverHelloMsg)
  478. case typeCertificate:
  479. m = new(certificateMsg)
  480. case typeCertificateRequest:
  481. m = new(certificateRequestMsg)
  482. case typeCertificateStatus:
  483. m = new(certificateStatusMsg)
  484. case typeServerHelloDone:
  485. m = new(serverHelloDoneMsg)
  486. case typeClientKeyExchange:
  487. m = new(clientKeyExchangeMsg)
  488. case typeCertificateVerify:
  489. m = new(certificateVerifyMsg)
  490. case typeNextProtocol:
  491. m = new(nextProtoMsg)
  492. case typeFinished:
  493. m = new(finishedMsg)
  494. default:
  495. c.sendAlert(alertUnexpectedMessage)
  496. return nil, alertUnexpectedMessage
  497. }
  498. // The handshake message unmarshallers
  499. // expect to be able to keep references to data,
  500. // so pass in a fresh copy that won't be overwritten.
  501. data = bytes.Add(nil, data)
  502. if !m.unmarshal(data) {
  503. c.sendAlert(alertUnexpectedMessage)
  504. return nil, alertUnexpectedMessage
  505. }
  506. return m, nil
  507. }
  508. // Write writes data to the connection.
  509. func (c *Conn) Write(b []byte) (n int, err os.Error) {
  510. if err = c.Handshake(); err != nil {
  511. return
  512. }
  513. c.out.Lock()
  514. defer c.out.Unlock()
  515. if !c.handshakeComplete {
  516. return 0, alertInternalError
  517. }
  518. if c.err != nil {
  519. return 0, c.err
  520. }
  521. return c.writeRecord(recordTypeApplicationData, b)
  522. }
  523. // Read can be made to time out and return err == os.EAGAIN
  524. // after a fixed time limit; see SetTimeout and SetReadTimeout.
  525. func (c *Conn) Read(b []byte) (n int, err os.Error) {
  526. if err = c.Handshake(); err != nil {
  527. return
  528. }
  529. c.in.Lock()
  530. defer c.in.Unlock()
  531. for c.input == nil && c.err == nil {
  532. c.readRecord(recordTypeApplicationData)
  533. }
  534. if c.err != nil {
  535. return 0, c.err
  536. }
  537. n, err = c.input.Read(b)
  538. if c.input.off >= len(c.input.data) {
  539. c.in.freeBlock(c.input)
  540. c.input = nil
  541. }
  542. return n, nil
  543. }
  544. // Close closes the connection.
  545. func (c *Conn) Close() os.Error {
  546. if err := c.Handshake(); err != nil {
  547. return err
  548. }
  549. return c.sendAlert(alertCloseNotify)
  550. }
  551. // Handshake runs the client or server handshake
  552. // protocol if it has not yet been run.
  553. // Most uses of this package need not call Handshake
  554. // explicitly: the first Read or Write will call it automatically.
  555. func (c *Conn) Handshake() os.Error {
  556. c.handshakeMutex.Lock()
  557. defer c.handshakeMutex.Unlock()
  558. if err := c.error(); err != nil {
  559. return err
  560. }
  561. if c.handshakeComplete {
  562. return nil
  563. }
  564. if c.isClient {
  565. return c.clientHandshake()
  566. }
  567. return c.serverHandshake()
  568. }
  569. // ConnectionState returns basic TLS details about the connection.
  570. func (c *Conn) ConnectionState() ConnectionState {
  571. c.handshakeMutex.Lock()
  572. defer c.handshakeMutex.Unlock()
  573. var state ConnectionState
  574. state.HandshakeComplete = c.handshakeComplete
  575. if c.handshakeComplete {
  576. state.NegotiatedProtocol = c.clientProtocol
  577. state.CipherSuite = c.cipherSuite
  578. }
  579. return state
  580. }
  581. // OCSPResponse returns the stapled OCSP response from the TLS server, if
  582. // any. (Only valid for client connections.)
  583. func (c *Conn) OCSPResponse() []byte {
  584. c.handshakeMutex.Lock()
  585. defer c.handshakeMutex.Unlock()
  586. return c.ocspResponse
  587. }
  588. // PeerCertificates returns the certificate chain that was presented by the
  589. // other side.
  590. func (c *Conn) PeerCertificates() []*x509.Certificate {
  591. c.handshakeMutex.Lock()
  592. defer c.handshakeMutex.Unlock()
  593. return c.peerCertificates
  594. }
  595. // VerifyHostname checks that the peer certificate chain is valid for
  596. // connecting to host. If so, it returns nil; if not, it returns an os.Error
  597. // describing the problem.
  598. func (c *Conn) VerifyHostname(host string) os.Error {
  599. c.handshakeMutex.Lock()
  600. defer c.handshakeMutex.Unlock()
  601. if !c.isClient {
  602. return os.ErrorString("VerifyHostname called on TLS server connection")
  603. }
  604. if !c.handshakeComplete {
  605. return os.ErrorString("TLS handshake has not yet been performed")
  606. }
  607. return c.peerCertificates[0].VerifyHostname(host)
  608. }