You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

397 lines
12 KiB

  1. // Copyright 2014 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // DTLS implementation.
  5. //
  6. // NOTE: This is not even a remotely production-quality DTLS
  7. // implementation. It is the bare minimum necessary to be able to
  8. // achieve coverage on BoringSSL's implementation. Of note is that
  9. // this implementation assumes the underlying net.PacketConn is not
  10. // only reliable but also ordered. BoringSSL will be expected to deal
  11. // with simulated loss, but there is no point in forcing the test
  12. // driver to.
  13. package main
  14. import (
  15. "bytes"
  16. "errors"
  17. "fmt"
  18. "io"
  19. "math/rand"
  20. "net"
  21. )
  22. func versionToWire(vers uint16, isDTLS bool) uint16 {
  23. if isDTLS {
  24. return ^(vers - 0x0201)
  25. }
  26. return vers
  27. }
  28. func wireToVersion(vers uint16, isDTLS bool) uint16 {
  29. if isDTLS {
  30. return ^vers + 0x0201
  31. }
  32. return vers
  33. }
  34. func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) {
  35. recordHeaderLen := dtlsRecordHeaderLen
  36. if c.rawInput == nil {
  37. c.rawInput = c.in.newBlock()
  38. }
  39. b := c.rawInput
  40. // Read a new packet only if the current one is empty.
  41. if len(b.data) == 0 {
  42. // Pick some absurdly large buffer size.
  43. b.resize(maxCiphertext + recordHeaderLen)
  44. n, err := c.conn.Read(c.rawInput.data)
  45. if err != nil {
  46. return 0, nil, err
  47. }
  48. if c.config.Bugs.MaxPacketLength != 0 && n > c.config.Bugs.MaxPacketLength {
  49. return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length")
  50. }
  51. c.rawInput.resize(n)
  52. }
  53. // Read out one record.
  54. //
  55. // A real DTLS implementation should be tolerant of errors,
  56. // but this is test code. We should not be tolerant of our
  57. // peer sending garbage.
  58. if len(b.data) < recordHeaderLen {
  59. return 0, nil, errors.New("dtls: failed to read record header")
  60. }
  61. typ := recordType(b.data[0])
  62. vers := wireToVersion(uint16(b.data[1])<<8|uint16(b.data[2]), c.isDTLS)
  63. if c.haveVers {
  64. if vers != c.vers {
  65. c.sendAlert(alertProtocolVersion)
  66. return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.vers))
  67. }
  68. } else {
  69. if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect {
  70. c.sendAlert(alertProtocolVersion)
  71. return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect))
  72. }
  73. }
  74. seq := b.data[3:11]
  75. // For test purposes, we assume a reliable channel. Require
  76. // that the explicit sequence number matches the incrementing
  77. // one we maintain. A real implementation would maintain a
  78. // replay window and such.
  79. if !bytes.Equal(seq, c.in.seq[:]) {
  80. c.sendAlert(alertIllegalParameter)
  81. return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number"))
  82. }
  83. n := int(b.data[11])<<8 | int(b.data[12])
  84. if n > maxCiphertext || len(b.data) < recordHeaderLen+n {
  85. c.sendAlert(alertRecordOverflow)
  86. return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n))
  87. }
  88. // Process message.
  89. b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
  90. ok, off, err := c.in.decrypt(b)
  91. if !ok {
  92. c.in.setErrorLocked(c.sendAlert(err))
  93. }
  94. b.off = off
  95. return typ, b, nil
  96. }
  97. func (c *Conn) makeFragment(header, data []byte, fragOffset, fragLen int) []byte {
  98. fragment := make([]byte, 0, 12+fragLen)
  99. fragment = append(fragment, header...)
  100. fragment = append(fragment, byte(c.sendHandshakeSeq>>8), byte(c.sendHandshakeSeq))
  101. fragment = append(fragment, byte(fragOffset>>16), byte(fragOffset>>8), byte(fragOffset))
  102. fragment = append(fragment, byte(fragLen>>16), byte(fragLen>>8), byte(fragLen))
  103. fragment = append(fragment, data[fragOffset:fragOffset+fragLen]...)
  104. return fragment
  105. }
  106. func (c *Conn) dtlsWriteRecord(typ recordType, data []byte) (n int, err error) {
  107. if typ != recordTypeHandshake {
  108. // Only handshake messages are fragmented.
  109. return c.dtlsWriteRawRecord(typ, data)
  110. }
  111. maxLen := c.config.Bugs.MaxHandshakeRecordLength
  112. if maxLen <= 0 {
  113. maxLen = 1024
  114. }
  115. // Handshake messages have to be modified to include fragment
  116. // offset and length and with the header replicated. Save the
  117. // TLS header here.
  118. //
  119. // TODO(davidben): This assumes that data contains exactly one
  120. // handshake message. This is incompatible with
  121. // FragmentAcrossChangeCipherSpec. (Which is unfortunate
  122. // because OpenSSL's DTLS implementation will probably accept
  123. // such fragmentation and could do with a fix + tests.)
  124. header := data[:4]
  125. data = data[4:]
  126. isFinished := header[0] == typeFinished
  127. if c.config.Bugs.SendEmptyFragments {
  128. fragment := c.makeFragment(header, data, 0, 0)
  129. c.pendingFragments = append(c.pendingFragments, fragment)
  130. }
  131. firstRun := true
  132. fragOffset := 0
  133. for firstRun || fragOffset < len(data) {
  134. firstRun = false
  135. fragLen := len(data) - fragOffset
  136. if fragLen > maxLen {
  137. fragLen = maxLen
  138. }
  139. fragment := c.makeFragment(header, data, fragOffset, fragLen)
  140. if c.config.Bugs.FragmentMessageTypeMismatch && fragOffset > 0 {
  141. fragment[0]++
  142. }
  143. if c.config.Bugs.FragmentMessageLengthMismatch && fragOffset > 0 {
  144. fragment[3]++
  145. }
  146. // Buffer the fragment for later. They will be sent (and
  147. // reordered) on flush.
  148. c.pendingFragments = append(c.pendingFragments, fragment)
  149. if c.config.Bugs.ReorderHandshakeFragments {
  150. // Don't duplicate Finished to avoid the peer
  151. // interpreting it as a retransmit request.
  152. if !isFinished {
  153. c.pendingFragments = append(c.pendingFragments, fragment)
  154. }
  155. if fragLen > (maxLen+1)/2 {
  156. // Overlap each fragment by half.
  157. fragLen = (maxLen + 1) / 2
  158. }
  159. }
  160. fragOffset += fragLen
  161. n += fragLen
  162. }
  163. if !isFinished && c.config.Bugs.MixCompleteMessageWithFragments {
  164. fragment := c.makeFragment(header, data, 0, len(data))
  165. c.pendingFragments = append(c.pendingFragments, fragment)
  166. }
  167. // Increment the handshake sequence number for the next
  168. // handshake message.
  169. c.sendHandshakeSeq++
  170. return
  171. }
  172. func (c *Conn) dtlsFlushHandshake() error {
  173. if !c.isDTLS {
  174. return nil
  175. }
  176. var fragments [][]byte
  177. fragments, c.pendingFragments = c.pendingFragments, fragments
  178. if c.config.Bugs.ReorderHandshakeFragments {
  179. perm := rand.New(rand.NewSource(0)).Perm(len(fragments))
  180. tmp := make([][]byte, len(fragments))
  181. for i := range tmp {
  182. tmp[i] = fragments[perm[i]]
  183. }
  184. fragments = tmp
  185. }
  186. // Send them all.
  187. for _, fragment := range fragments {
  188. if c.config.Bugs.SplitFragmentHeader {
  189. if _, err := c.dtlsWriteRawRecord(recordTypeHandshake, fragment[:2]); err != nil {
  190. return err
  191. }
  192. fragment = fragment[2:]
  193. } else if c.config.Bugs.SplitFragmentBody && len(fragment) > 12 {
  194. if _, err := c.dtlsWriteRawRecord(recordTypeHandshake, fragment[:13]); err != nil {
  195. return err
  196. }
  197. fragment = fragment[13:]
  198. }
  199. // TODO(davidben): A real DTLS implementation needs to
  200. // retransmit handshake messages. For testing purposes, we don't
  201. // actually care.
  202. if _, err := c.dtlsWriteRawRecord(recordTypeHandshake, fragment); err != nil {
  203. return err
  204. }
  205. }
  206. return nil
  207. }
  208. func (c *Conn) dtlsWriteRawRecord(typ recordType, data []byte) (n int, err error) {
  209. recordHeaderLen := dtlsRecordHeaderLen
  210. maxLen := c.config.Bugs.MaxHandshakeRecordLength
  211. if maxLen <= 0 {
  212. maxLen = 1024
  213. }
  214. b := c.out.newBlock()
  215. explicitIVLen := 0
  216. explicitIVIsSeq := false
  217. if cbc, ok := c.out.cipher.(cbcMode); ok {
  218. // Block cipher modes have an explicit IV.
  219. explicitIVLen = cbc.BlockSize()
  220. } else if aead, ok := c.out.cipher.(*tlsAead); ok {
  221. if aead.explicitNonce {
  222. explicitIVLen = 8
  223. // The AES-GCM construction in TLS has an explicit nonce so that
  224. // the nonce can be random. However, the nonce is only 8 bytes
  225. // which is too small for a secure, random nonce. Therefore we
  226. // use the sequence number as the nonce.
  227. explicitIVIsSeq = true
  228. }
  229. } else if c.out.cipher != nil {
  230. panic("Unknown cipher")
  231. }
  232. b.resize(recordHeaderLen + explicitIVLen + len(data))
  233. b.data[0] = byte(typ)
  234. vers := c.vers
  235. if vers == 0 {
  236. // Some TLS servers fail if the record version is greater than
  237. // TLS 1.0 for the initial ClientHello.
  238. vers = VersionTLS10
  239. }
  240. vers = versionToWire(vers, c.isDTLS)
  241. b.data[1] = byte(vers >> 8)
  242. b.data[2] = byte(vers)
  243. // DTLS records include an explicit sequence number.
  244. copy(b.data[3:11], c.out.seq[0:])
  245. b.data[11] = byte(len(data) >> 8)
  246. b.data[12] = byte(len(data))
  247. if explicitIVLen > 0 {
  248. explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
  249. if explicitIVIsSeq {
  250. copy(explicitIV, c.out.seq[:])
  251. } else {
  252. if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil {
  253. return
  254. }
  255. }
  256. }
  257. copy(b.data[recordHeaderLen+explicitIVLen:], data)
  258. c.out.encrypt(b, explicitIVLen)
  259. _, err = c.conn.Write(b.data)
  260. if err != nil {
  261. return
  262. }
  263. n = len(data)
  264. c.out.freeBlock(b)
  265. if typ == recordTypeChangeCipherSpec {
  266. err = c.out.changeCipherSpec(c.config)
  267. if err != nil {
  268. // Cannot call sendAlert directly,
  269. // because we already hold c.out.Mutex.
  270. c.tmp[0] = alertLevelError
  271. c.tmp[1] = byte(err.(alert))
  272. c.writeRecord(recordTypeAlert, c.tmp[0:2])
  273. return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
  274. }
  275. }
  276. return
  277. }
  278. func (c *Conn) dtlsDoReadHandshake() ([]byte, error) {
  279. // Assemble a full handshake message. For test purposes, this
  280. // implementation assumes fragments arrive in order. It may
  281. // need to be cleverer if we ever test BoringSSL's retransmit
  282. // behavior.
  283. for len(c.handMsg) < 4+c.handMsgLen {
  284. // Get a new handshake record if the previous has been
  285. // exhausted.
  286. if c.hand.Len() == 0 {
  287. if err := c.in.err; err != nil {
  288. return nil, err
  289. }
  290. if err := c.readRecord(recordTypeHandshake); err != nil {
  291. return nil, err
  292. }
  293. }
  294. // Read the next fragment. It must fit entirely within
  295. // the record.
  296. if c.hand.Len() < 12 {
  297. return nil, errors.New("dtls: bad handshake record")
  298. }
  299. header := c.hand.Next(12)
  300. fragN := int(header[1])<<16 | int(header[2])<<8 | int(header[3])
  301. fragSeq := uint16(header[4])<<8 | uint16(header[5])
  302. fragOff := int(header[6])<<16 | int(header[7])<<8 | int(header[8])
  303. fragLen := int(header[9])<<16 | int(header[10])<<8 | int(header[11])
  304. if c.hand.Len() < fragLen {
  305. return nil, errors.New("dtls: fragment length too long")
  306. }
  307. fragment := c.hand.Next(fragLen)
  308. // Check it's a fragment for the right message.
  309. if fragSeq != c.recvHandshakeSeq {
  310. return nil, errors.New("dtls: bad handshake sequence number")
  311. }
  312. // Check that the length is consistent.
  313. if c.handMsg == nil {
  314. c.handMsgLen = fragN
  315. if c.handMsgLen > maxHandshake {
  316. return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
  317. }
  318. // Start with the TLS handshake header,
  319. // without the DTLS bits.
  320. c.handMsg = append([]byte{}, header[:4]...)
  321. } else if fragN != c.handMsgLen {
  322. return nil, errors.New("dtls: bad handshake length")
  323. }
  324. // Add the fragment to the pending message.
  325. if 4+fragOff != len(c.handMsg) {
  326. return nil, errors.New("dtls: bad fragment offset")
  327. }
  328. if fragOff+fragLen > c.handMsgLen {
  329. return nil, errors.New("dtls: bad fragment length")
  330. }
  331. c.handMsg = append(c.handMsg, fragment...)
  332. }
  333. c.recvHandshakeSeq++
  334. ret := c.handMsg
  335. c.handMsg, c.handMsgLen = nil, 0
  336. return ret, nil
  337. }
  338. // DTLSServer returns a new DTLS server side connection
  339. // using conn as the underlying transport.
  340. // The configuration config must be non-nil and must have
  341. // at least one certificate.
  342. func DTLSServer(conn net.Conn, config *Config) *Conn {
  343. c := &Conn{config: config, isDTLS: true, conn: conn}
  344. c.init()
  345. return c
  346. }
  347. // DTLSClient returns a new DTLS client side connection
  348. // using conn as the underlying transport.
  349. // The config cannot be nil: users must set either ServerHostname or
  350. // InsecureSkipVerify in the config.
  351. func DTLSClient(conn net.Conn, config *Config) *Conn {
  352. c := &Conn{config: config, isClient: true, isDTLS: true, conn: conn}
  353. c.init()
  354. return c
  355. }