Alternative TLS implementation in Go
Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.

conn.go 14 KiB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635
  1. // TLS low level connection and record layer
  2. package tls
  3. import (
  4. "bytes"
  5. "crypto/subtle"
  6. "hash"
  7. "io"
  8. "net"
  9. "os"
  10. "sync"
  11. )
  12. // A Conn represents a secured connection.
  13. // It implements the net.Conn interface.
  14. type Conn struct {
  15. // constant
  16. conn net.Conn
  17. isClient bool
  18. // constant after handshake; protected by handshakeMutex
  19. handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
  20. vers uint16 // TLS version
  21. haveVers bool // version has been negotiated
  22. config *Config // configuration passed to constructor
  23. handshakeComplete bool
  24. cipherSuite uint16
  25. clientProtocol string
  26. // first permanent error
  27. errMutex sync.Mutex
  28. err os.Error
  29. // input/output
  30. in, out halfConn // in.Mutex < out.Mutex
  31. rawInput *block // raw input, right off the wire
  32. input *block // application data waiting to be read
  33. hand bytes.Buffer // handshake data waiting to be read
  34. tmp [16]byte
  35. }
  36. func (c *Conn) setError(err os.Error) os.Error {
  37. c.errMutex.Lock()
  38. defer c.errMutex.Unlock()
  39. if c.err == nil {
  40. c.err = err
  41. }
  42. return err
  43. }
  44. func (c *Conn) error() os.Error {
  45. c.errMutex.Lock()
  46. defer c.errMutex.Unlock()
  47. return c.err
  48. }
  49. // Access to net.Conn methods.
  50. // Cannot just embed net.Conn because that would
  51. // export the struct field too.
  52. // LocalAddr returns the local network address.
  53. func (c *Conn) LocalAddr() net.Addr {
  54. return c.conn.LocalAddr()
  55. }
  56. // RemoteAddr returns the remote network address.
  57. func (c *Conn) RemoteAddr() net.Addr {
  58. return c.conn.RemoteAddr()
  59. }
  60. // SetTimeout sets the read deadline associated with the connection.
  61. // There is no write deadline.
  62. func (c *Conn) SetTimeout(nsec int64) os.Error {
  63. return c.conn.SetTimeout(nsec)
  64. }
  65. // SetReadTimeout sets the time (in nanoseconds) that
  66. // Read will wait for data before returning os.EAGAIN.
  67. // Setting nsec == 0 (the default) disables the deadline.
  68. func (c *Conn) SetReadTimeout(nsec int64) os.Error {
  69. return c.conn.SetReadTimeout(nsec)
  70. }
  71. // SetWriteTimeout exists to satisfy the net.Conn interface
  72. // but is not implemented by TLS. It always returns an error.
  73. func (c *Conn) SetWriteTimeout(nsec int64) os.Error {
  74. return os.NewError("TLS does not support SetWriteTimeout")
  75. }
  76. // A halfConn represents one direction of the record layer
  77. // connection, either sending or receiving.
  78. type halfConn struct {
  79. sync.Mutex
  80. crypt encryptor // encryption state
  81. mac hash.Hash // MAC algorithm
  82. seq [8]byte // 64-bit sequence number
  83. bfree *block // list of free blocks
  84. nextCrypt encryptor // next encryption state
  85. nextMac hash.Hash // next MAC algorithm
  86. }
  87. // prepareCipherSpec sets the encryption and MAC states
  88. // that a subsequent changeCipherSpec will use.
  89. func (hc *halfConn) prepareCipherSpec(crypt encryptor, mac hash.Hash) {
  90. hc.nextCrypt = crypt
  91. hc.nextMac = mac
  92. }
  93. // changeCipherSpec changes the encryption and MAC states
  94. // to the ones previously passed to prepareCipherSpec.
  95. func (hc *halfConn) changeCipherSpec() os.Error {
  96. if hc.nextCrypt == nil {
  97. return alertInternalError
  98. }
  99. hc.crypt = hc.nextCrypt
  100. hc.mac = hc.nextMac
  101. hc.nextCrypt = nil
  102. hc.nextMac = nil
  103. return nil
  104. }
  105. // incSeq increments the sequence number.
  106. func (hc *halfConn) incSeq() {
  107. for i := 7; i >= 0; i-- {
  108. hc.seq[i]++
  109. if hc.seq[i] != 0 {
  110. return
  111. }
  112. }
  113. // Not allowed to let sequence number wrap.
  114. // Instead, must renegotiate before it does.
  115. // Not likely enough to bother.
  116. panic("TLS: sequence number wraparound")
  117. }
  118. // resetSeq resets the sequence number to zero.
  119. func (hc *halfConn) resetSeq() {
  120. for i := range hc.seq {
  121. hc.seq[i] = 0
  122. }
  123. }
  124. // decrypt checks and strips the mac and decrypts the data in b.
  125. func (hc *halfConn) decrypt(b *block) (bool, alert) {
  126. // pull out payload
  127. payload := b.data[recordHeaderLen:]
  128. // decrypt
  129. if hc.crypt != nil {
  130. hc.crypt.XORKeyStream(payload)
  131. }
  132. // check, strip mac
  133. if hc.mac != nil {
  134. if len(payload) < hc.mac.Size() {
  135. return false, alertBadRecordMAC
  136. }
  137. // strip mac off payload, b.data
  138. n := len(payload) - hc.mac.Size()
  139. b.data[3] = byte(n >> 8)
  140. b.data[4] = byte(n)
  141. b.data = b.data[0 : recordHeaderLen+n]
  142. remoteMAC := payload[n:]
  143. hc.mac.Reset()
  144. hc.mac.Write(hc.seq[0:])
  145. hc.incSeq()
  146. hc.mac.Write(b.data)
  147. if subtle.ConstantTimeCompare(hc.mac.Sum(), remoteMAC) != 1 {
  148. return false, alertBadRecordMAC
  149. }
  150. }
  151. return true, 0
  152. }
  153. // encrypt encrypts and macs the data in b.
  154. func (hc *halfConn) encrypt(b *block) (bool, alert) {
  155. // mac
  156. if hc.mac != nil {
  157. hc.mac.Reset()
  158. hc.mac.Write(hc.seq[0:])
  159. hc.incSeq()
  160. hc.mac.Write(b.data)
  161. mac := hc.mac.Sum()
  162. n := len(b.data)
  163. b.resize(n + len(mac))
  164. copy(b.data[n:], mac)
  165. // update length to include mac
  166. n = len(b.data) - recordHeaderLen
  167. b.data[3] = byte(n >> 8)
  168. b.data[4] = byte(n)
  169. }
  170. // encrypt
  171. if hc.crypt != nil {
  172. hc.crypt.XORKeyStream(b.data[recordHeaderLen:])
  173. }
  174. return true, 0
  175. }
  // A block is a simple growable byte buffer used by the record layer.
type block struct {
	data []byte
	off  int    // read position used by Read
	link *block // next block on a halfConn free list
}

// resize sets the length of b to n bytes, growing its capacity if needed.
// Existing contents up to min(old length, n) are kept.
func (b *block) resize(n int) {
	if cap(b.data) < n {
		b.reserve(n)
	}
	b.data = b.data[:n]
}

// reserve ensures b has capacity for at least n bytes without changing its
// length. Capacity starts at 1024 and doubles until it is large enough.
func (b *block) reserve(n int) {
	size := cap(b.data)
	if size >= n {
		return
	}
	if size == 0 {
		size = 1024
	}
	for size < n {
		size <<= 1
	}
	grown := make([]byte, len(b.data), size)
	copy(grown, b.data)
	b.data = grown
}
  205. // readFromUntil reads from r into b until b contains at least n bytes
  206. // or else returns an error.
  207. func (b *block) readFromUntil(r io.Reader, n int) os.Error {
  208. // quick case
  209. if len(b.data) >= n {
  210. return nil
  211. }
  212. // read until have enough.
  213. b.reserve(n)
  214. for {
  215. m, err := r.Read(b.data[len(b.data):cap(b.data)])
  216. b.data = b.data[0 : len(b.data)+m]
  217. if len(b.data) >= n {
  218. break
  219. }
  220. if err != nil {
  221. return err
  222. }
  223. }
  224. return nil
  225. }
  226. func (b *block) Read(p []byte) (n int, err os.Error) {
  227. n = copy(p, b.data[b.off:])
  228. b.off += n
  229. return
  230. }
  231. // newBlock allocates a new block, from hc's free list if possible.
  232. func (hc *halfConn) newBlock() *block {
  233. b := hc.bfree
  234. if b == nil {
  235. return new(block)
  236. }
  237. hc.bfree = b.link
  238. b.link = nil
  239. b.resize(0)
  240. return b
  241. }
  242. // freeBlock returns a block to hc's free list.
  243. // The protocol is such that each side only has a block or two on
  244. // its free list at a time, so there's no need to worry about
  245. // trimming the list, etc.
  246. func (hc *halfConn) freeBlock(b *block) {
  247. b.link = hc.bfree
  248. hc.bfree = b
  249. }
  250. // splitBlock splits a block after the first n bytes,
  251. // returning a block with those n bytes and a
  252. // block with the remaindec. the latter may be nil.
  253. func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
  254. if len(b.data) <= n {
  255. return b, nil
  256. }
  257. bb := hc.newBlock()
  258. bb.resize(len(b.data) - n)
  259. copy(bb.data, b.data[n:])
  260. b.data = b.data[0:n]
  261. return b, bb
  262. }
  263. // readRecord reads the next TLS record from the connection
  264. // and updates the record layer state.
  265. // c.in.Mutex <= L; c.input == nil.
  266. func (c *Conn) readRecord(want recordType) os.Error {
  267. // Caller must be in sync with connection:
  268. // handshake data if handshake not yet completed,
  269. // else application data. (We don't support renegotiation.)
  270. switch want {
  271. default:
  272. return c.sendAlert(alertInternalError)
  273. case recordTypeHandshake, recordTypeChangeCipherSpec:
  274. if c.handshakeComplete {
  275. return c.sendAlert(alertInternalError)
  276. }
  277. case recordTypeApplicationData:
  278. if !c.handshakeComplete {
  279. return c.sendAlert(alertInternalError)
  280. }
  281. }
  282. Again:
  283. if c.rawInput == nil {
  284. c.rawInput = c.in.newBlock()
  285. }
  286. b := c.rawInput
  287. // Read header, payload.
  288. if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
  289. // RFC suggests that EOF without an alertCloseNotify is
  290. // an error, but popular web sites seem to do this,
  291. // so we can't make it an error.
  292. // if err == os.EOF {
  293. // err = io.ErrUnexpectedEOF
  294. // }
  295. if e, ok := err.(net.Error); !ok || !e.Temporary() {
  296. c.setError(err)
  297. }
  298. return err
  299. }
  300. typ := recordType(b.data[0])
  301. vers := uint16(b.data[1])<<8 | uint16(b.data[2])
  302. n := int(b.data[3])<<8 | int(b.data[4])
  303. if c.haveVers && vers != c.vers {
  304. return c.sendAlert(alertProtocolVersion)
  305. }
  306. if n > maxCiphertext {
  307. return c.sendAlert(alertRecordOverflow)
  308. }
  309. if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
  310. if err == os.EOF {
  311. err = io.ErrUnexpectedEOF
  312. }
  313. if e, ok := err.(net.Error); !ok || !e.Temporary() {
  314. c.setError(err)
  315. }
  316. return err
  317. }
  318. // Process message.
  319. b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
  320. b.off = recordHeaderLen
  321. if ok, err := c.in.decrypt(b); !ok {
  322. return c.sendAlert(err)
  323. }
  324. data := b.data[b.off:]
  325. if len(data) > maxPlaintext {
  326. c.sendAlert(alertRecordOverflow)
  327. c.in.freeBlock(b)
  328. return c.error()
  329. }
  330. switch typ {
  331. default:
  332. c.sendAlert(alertUnexpectedMessage)
  333. case recordTypeAlert:
  334. if len(data) != 2 {
  335. c.sendAlert(alertUnexpectedMessage)
  336. break
  337. }
  338. if alert(data[1]) == alertCloseNotify {
  339. c.setError(os.EOF)
  340. break
  341. }
  342. switch data[0] {
  343. case alertLevelWarning:
  344. // drop on the floor
  345. c.in.freeBlock(b)
  346. goto Again
  347. case alertLevelError:
  348. c.setError(&net.OpError{Op: "remote error", Error: alert(data[1])})
  349. default:
  350. c.sendAlert(alertUnexpectedMessage)
  351. }
  352. case recordTypeChangeCipherSpec:
  353. if typ != want || len(data) != 1 || data[0] != 1 {
  354. c.sendAlert(alertUnexpectedMessage)
  355. break
  356. }
  357. err := c.in.changeCipherSpec()
  358. if err != nil {
  359. c.sendAlert(err.(alert))
  360. }
  361. case recordTypeApplicationData:
  362. if typ != want {
  363. c.sendAlert(alertUnexpectedMessage)
  364. break
  365. }
  366. c.input = b
  367. b = nil
  368. case recordTypeHandshake:
  369. // TODO(rsc): Should at least pick off connection close.
  370. if typ != want {
  371. return c.sendAlert(alertNoRenegotiation)
  372. }
  373. c.hand.Write(data)
  374. }
  375. if b != nil {
  376. c.in.freeBlock(b)
  377. }
  378. return c.error()
  379. }
  380. // sendAlert sends a TLS alert message.
  381. // c.out.Mutex <= L.
  382. func (c *Conn) sendAlertLocked(err alert) os.Error {
  383. c.tmp[0] = alertLevelError
  384. if err == alertNoRenegotiation {
  385. c.tmp[0] = alertLevelWarning
  386. }
  387. c.tmp[1] = byte(err)
  388. c.writeRecord(recordTypeAlert, c.tmp[0:2])
  389. return c.setError(&net.OpError{Op: "local error", Error: err})
  390. }
  391. // sendAlert sends a TLS alert message.
  392. // L < c.out.Mutex.
  393. func (c *Conn) sendAlert(err alert) os.Error {
  394. c.out.Lock()
  395. defer c.out.Unlock()
  396. return c.sendAlertLocked(err)
  397. }
  398. // writeRecord writes a TLS record with the given type and payload
  399. // to the connection and updates the record layer state.
  400. // c.out.Mutex <= L.
  401. func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err os.Error) {
  402. b := c.out.newBlock()
  403. for len(data) > 0 {
  404. m := len(data)
  405. if m > maxPlaintext {
  406. m = maxPlaintext
  407. }
  408. b.resize(recordHeaderLen + m)
  409. b.data[0] = byte(typ)
  410. vers := c.vers
  411. if vers == 0 {
  412. vers = maxVersion
  413. }
  414. b.data[1] = byte(vers >> 8)
  415. b.data[2] = byte(vers)
  416. b.data[3] = byte(m >> 8)
  417. b.data[4] = byte(m)
  418. copy(b.data[recordHeaderLen:], data)
  419. c.out.encrypt(b)
  420. _, err = c.conn.Write(b.data)
  421. if err != nil {
  422. break
  423. }
  424. n += m
  425. data = data[m:]
  426. }
  427. c.out.freeBlock(b)
  428. if typ == recordTypeChangeCipherSpec {
  429. err = c.out.changeCipherSpec()
  430. if err != nil {
  431. // Cannot call sendAlert directly,
  432. // because we already hold c.out.Mutex.
  433. c.tmp[0] = alertLevelError
  434. c.tmp[1] = byte(err.(alert))
  435. c.writeRecord(recordTypeAlert, c.tmp[0:2])
  436. c.err = &net.OpError{Op: "local error", Error: err}
  437. return n, c.err
  438. }
  439. }
  440. return
  441. }
  442. // readHandshake reads the next handshake message from
  443. // the record layer.
  444. // c.in.Mutex < L; c.out.Mutex < L.
  445. func (c *Conn) readHandshake() (interface{}, os.Error) {
  446. for c.hand.Len() < 4 {
  447. if c.err != nil {
  448. return nil, c.err
  449. }
  450. c.readRecord(recordTypeHandshake)
  451. }
  452. data := c.hand.Bytes()
  453. n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  454. if n > maxHandshake {
  455. c.sendAlert(alertInternalError)
  456. return nil, c.err
  457. }
  458. for c.hand.Len() < 4+n {
  459. if c.err != nil {
  460. return nil, c.err
  461. }
  462. c.readRecord(recordTypeHandshake)
  463. }
  464. data = c.hand.Next(4 + n)
  465. var m handshakeMessage
  466. switch data[0] {
  467. case typeClientHello:
  468. m = new(clientHelloMsg)
  469. case typeServerHello:
  470. m = new(serverHelloMsg)
  471. case typeCertificate:
  472. m = new(certificateMsg)
  473. case typeServerHelloDone:
  474. m = new(serverHelloDoneMsg)
  475. case typeClientKeyExchange:
  476. m = new(clientKeyExchangeMsg)
  477. case typeNextProtocol:
  478. m = new(nextProtoMsg)
  479. case typeFinished:
  480. m = new(finishedMsg)
  481. default:
  482. c.sendAlert(alertUnexpectedMessage)
  483. return nil, alertUnexpectedMessage
  484. }
  485. // The handshake message unmarshallers
  486. // expect to be able to keep references to data,
  487. // so pass in a fresh copy that won't be overwritten.
  488. data = bytes.Add(nil, data)
  489. if !m.unmarshal(data) {
  490. c.sendAlert(alertUnexpectedMessage)
  491. return nil, alertUnexpectedMessage
  492. }
  493. return m, nil
  494. }
  495. // Write writes data to the connection.
  496. func (c *Conn) Write(b []byte) (n int, err os.Error) {
  497. if err = c.Handshake(); err != nil {
  498. return
  499. }
  500. c.out.Lock()
  501. defer c.out.Unlock()
  502. if !c.handshakeComplete {
  503. return 0, alertInternalError
  504. }
  505. if c.err != nil {
  506. return 0, c.err
  507. }
  508. return c.writeRecord(recordTypeApplicationData, b)
  509. }
  510. // Read can be made to time out and return err == os.EAGAIN
  511. // after a fixed time limit; see SetTimeout and SetReadTimeout.
  512. func (c *Conn) Read(b []byte) (n int, err os.Error) {
  513. if err = c.Handshake(); err != nil {
  514. return
  515. }
  516. c.in.Lock()
  517. defer c.in.Unlock()
  518. for c.input == nil && c.err == nil {
  519. c.readRecord(recordTypeApplicationData)
  520. }
  521. if c.err != nil {
  522. return 0, c.err
  523. }
  524. n, err = c.input.Read(b)
  525. if c.input.off >= len(c.input.data) {
  526. c.in.freeBlock(c.input)
  527. c.input = nil
  528. }
  529. return n, nil
  530. }
  531. // Close closes the connection.
  532. func (c *Conn) Close() os.Error {
  533. if err := c.Handshake(); err != nil {
  534. return err
  535. }
  536. return c.sendAlert(alertCloseNotify)
  537. }
  538. // Handshake runs the client or server handshake
  539. // protocol if it has not yet been run.
  540. // Most uses of this packge need not call Handshake
  541. // explicitly: the first Read or Write will call it automatically.
  542. func (c *Conn) Handshake() os.Error {
  543. c.handshakeMutex.Lock()
  544. defer c.handshakeMutex.Unlock()
  545. if err := c.error(); err != nil {
  546. return err
  547. }
  548. if c.handshakeComplete {
  549. return nil
  550. }
  551. if c.isClient {
  552. return c.clientHandshake()
  553. }
  554. return c.serverHandshake()
  555. }
  556. // If c is a TLS server, ClientConnection returns the protocol
  557. // requested by the client during the TLS handshake.
  558. // Handshake must have been called already.
  559. func (c *Conn) ClientConnection() string {
  560. c.handshakeMutex.Lock()
  561. defer c.handshakeMutex.Unlock()
  562. return c.clientProtocol
  563. }