You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

331 lines
9.6 KiB

  1. // Copyright 2014 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // DTLS implementation.
  5. //
  6. // NOTE: This is not even a remotely production-quality DTLS
  7. // implementation. It is the bare minimum necessary to be able to
  8. // achieve coverage on BoringSSL's implementation. Of note is that
  9. // this implementation assumes the underlying net.PacketConn is not
  10. // only reliable but also ordered. BoringSSL will be expected to deal
  11. // with simulated loss, but there is no point in forcing the test
  12. // driver to.
  13. package main
  14. import (
  15. "bytes"
  16. "crypto/cipher"
  17. "errors"
  18. "fmt"
  19. "io"
  20. "net"
  21. )
  22. func versionToWire(vers uint16, isDTLS bool) uint16 {
  23. if isDTLS {
  24. return ^(vers - 0x0201)
  25. }
  26. return vers
  27. }
  28. func wireToVersion(vers uint16, isDTLS bool) uint16 {
  29. if isDTLS {
  30. return ^vers + 0x0201
  31. }
  32. return vers
  33. }
  34. func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) {
  35. recordHeaderLen := dtlsRecordHeaderLen
  36. if c.rawInput == nil {
  37. c.rawInput = c.in.newBlock()
  38. }
  39. b := c.rawInput
  40. // Read a new packet only if the current one is empty.
  41. if len(b.data) == 0 {
  42. // Pick some absurdly large buffer size.
  43. b.resize(maxCiphertext + recordHeaderLen)
  44. n, err := c.conn.Read(c.rawInput.data)
  45. if err != nil {
  46. return 0, nil, err
  47. }
  48. if c.config.Bugs.MaxPacketLength != 0 && n > c.config.Bugs.MaxPacketLength {
  49. return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length")
  50. }
  51. c.rawInput.resize(n)
  52. }
  53. // Read out one record.
  54. //
  55. // A real DTLS implementation should be tolerant of errors,
  56. // but this is test code. We should not be tolerant of our
  57. // peer sending garbage.
  58. if len(b.data) < recordHeaderLen {
  59. return 0, nil, errors.New("dtls: failed to read record header")
  60. }
  61. typ := recordType(b.data[0])
  62. vers := wireToVersion(uint16(b.data[1])<<8|uint16(b.data[2]), c.isDTLS)
  63. if c.haveVers {
  64. if vers != c.vers {
  65. c.sendAlert(alertProtocolVersion)
  66. return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.vers))
  67. }
  68. } else {
  69. if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect {
  70. c.sendAlert(alertProtocolVersion)
  71. return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect))
  72. }
  73. }
  74. seq := b.data[3:11]
  75. // For test purposes, we assume a reliable channel. Require
  76. // that the explicit sequence number matches the incrementing
  77. // one we maintain. A real implementation would maintain a
  78. // replay window and such.
  79. if !bytes.Equal(seq, c.in.seq[:]) {
  80. c.sendAlert(alertIllegalParameter)
  81. return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number"))
  82. }
  83. n := int(b.data[11])<<8 | int(b.data[12])
  84. if n > maxCiphertext || len(b.data) < recordHeaderLen+n {
  85. c.sendAlert(alertRecordOverflow)
  86. return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n))
  87. }
  88. // Process message.
  89. b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
  90. ok, off, err := c.in.decrypt(b)
  91. if !ok {
  92. c.in.setErrorLocked(c.sendAlert(err))
  93. }
  94. b.off = off
  95. return typ, b, nil
  96. }
  97. func (c *Conn) dtlsWriteRecord(typ recordType, data []byte) (n int, err error) {
  98. if typ != recordTypeHandshake {
  99. // Only handshake messages are fragmented.
  100. return c.dtlsWriteRawRecord(typ, data)
  101. }
  102. maxLen := c.config.Bugs.MaxHandshakeRecordLength
  103. if maxLen <= 0 {
  104. maxLen = 1024
  105. }
  106. // Handshake messages have to be modified to include fragment
  107. // offset and length and with the header replicated. Save the
  108. // TLS header here.
  109. //
  110. // TODO(davidben): This assumes that data contains exactly one
  111. // handshake message. This is incompatible with
  112. // FragmentAcrossChangeCipherSpec. (Which is unfortunate
  113. // because OpenSSL's DTLS implementation will probably accept
  114. // such fragmentation and could do with a fix + tests.)
  115. if len(data) < 4 {
  116. // This should not happen.
  117. panic(data)
  118. }
  119. header := data[:4]
  120. data = data[4:]
  121. firstRun := true
  122. for firstRun || len(data) > 0 {
  123. firstRun = false
  124. m := len(data)
  125. if m > maxLen {
  126. m = maxLen
  127. }
  128. // Standard TLS handshake header.
  129. fragment := make([]byte, 0, 12+m)
  130. fragment = append(fragment, header...)
  131. // message_seq
  132. fragment = append(fragment, byte(c.sendHandshakeSeq>>8), byte(c.sendHandshakeSeq))
  133. // fragment_offset
  134. fragment = append(fragment, byte(n>>16), byte(n>>8), byte(n))
  135. // fragment_length
  136. fragment = append(fragment, byte(m>>16), byte(m>>8), byte(m))
  137. fragment = append(fragment, data[:m]...)
  138. // TODO(davidben): A real DTLS implementation needs to
  139. // retransmit handshake messages. For testing purposes, we don't
  140. // actually care.
  141. _, err = c.dtlsWriteRawRecord(recordTypeHandshake, fragment)
  142. if err != nil {
  143. break
  144. }
  145. n += m
  146. data = data[m:]
  147. }
  148. // Increment the handshake sequence number for the next
  149. // handshake message.
  150. c.sendHandshakeSeq++
  151. return
  152. }
  153. func (c *Conn) dtlsWriteRawRecord(typ recordType, data []byte) (n int, err error) {
  154. recordHeaderLen := dtlsRecordHeaderLen
  155. maxLen := c.config.Bugs.MaxHandshakeRecordLength
  156. if maxLen <= 0 {
  157. maxLen = 1024
  158. }
  159. b := c.out.newBlock()
  160. explicitIVLen := 0
  161. explicitIVIsSeq := false
  162. if cbc, ok := c.out.cipher.(cbcMode); ok {
  163. // Block cipher modes have an explicit IV.
  164. explicitIVLen = cbc.BlockSize()
  165. } else if _, ok := c.out.cipher.(cipher.AEAD); ok {
  166. explicitIVLen = 8
  167. // The AES-GCM construction in TLS has an explicit nonce so that
  168. // the nonce can be random. However, the nonce is only 8 bytes
  169. // which is too small for a secure, random nonce. Therefore we
  170. // use the sequence number as the nonce.
  171. explicitIVIsSeq = true
  172. } else if c.out.cipher != nil {
  173. panic("Unknown cipher")
  174. }
  175. b.resize(recordHeaderLen + explicitIVLen + len(data))
  176. b.data[0] = byte(typ)
  177. vers := c.vers
  178. if vers == 0 {
  179. // Some TLS servers fail if the record version is greater than
  180. // TLS 1.0 for the initial ClientHello.
  181. vers = VersionTLS10
  182. }
  183. vers = versionToWire(vers, c.isDTLS)
  184. b.data[1] = byte(vers >> 8)
  185. b.data[2] = byte(vers)
  186. // DTLS records include an explicit sequence number.
  187. copy(b.data[3:11], c.out.seq[0:])
  188. b.data[11] = byte(len(data) >> 8)
  189. b.data[12] = byte(len(data))
  190. if explicitIVLen > 0 {
  191. explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
  192. if explicitIVIsSeq {
  193. copy(explicitIV, c.out.seq[:])
  194. } else {
  195. if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil {
  196. return
  197. }
  198. }
  199. }
  200. copy(b.data[recordHeaderLen+explicitIVLen:], data)
  201. c.out.encrypt(b, explicitIVLen)
  202. _, err = c.conn.Write(b.data)
  203. if err != nil {
  204. return
  205. }
  206. n = len(data)
  207. c.out.freeBlock(b)
  208. if typ == recordTypeChangeCipherSpec {
  209. err = c.out.changeCipherSpec(c.config)
  210. if err != nil {
  211. // Cannot call sendAlert directly,
  212. // because we already hold c.out.Mutex.
  213. c.tmp[0] = alertLevelError
  214. c.tmp[1] = byte(err.(alert))
  215. c.writeRecord(recordTypeAlert, c.tmp[0:2])
  216. return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
  217. }
  218. }
  219. return
  220. }
  221. func (c *Conn) dtlsDoReadHandshake() ([]byte, error) {
  222. // Assemble a full handshake message. For test purposes, this
  223. // implementation assumes fragments arrive in order. It may
  224. // need to be cleverer if we ever test BoringSSL's retransmit
  225. // behavior.
  226. for len(c.handMsg) < 4+c.handMsgLen {
  227. // Get a new handshake record if the previous has been
  228. // exhausted.
  229. if c.hand.Len() == 0 {
  230. if err := c.in.err; err != nil {
  231. return nil, err
  232. }
  233. if err := c.readRecord(recordTypeHandshake); err != nil {
  234. return nil, err
  235. }
  236. }
  237. // Read the next fragment. It must fit entirely within
  238. // the record.
  239. if c.hand.Len() < 12 {
  240. return nil, errors.New("dtls: bad handshake record")
  241. }
  242. header := c.hand.Next(12)
  243. fragN := int(header[1])<<16 | int(header[2])<<8 | int(header[3])
  244. fragSeq := uint16(header[4])<<8 | uint16(header[5])
  245. fragOff := int(header[6])<<16 | int(header[7])<<8 | int(header[8])
  246. fragLen := int(header[9])<<16 | int(header[10])<<8 | int(header[11])
  247. if c.hand.Len() < fragLen {
  248. return nil, errors.New("dtls: fragment length too long")
  249. }
  250. fragment := c.hand.Next(fragLen)
  251. // Check it's a fragment for the right message.
  252. if fragSeq != c.recvHandshakeSeq {
  253. return nil, errors.New("dtls: bad handshake sequence number")
  254. }
  255. // Check that the length is consistent.
  256. if c.handMsg == nil {
  257. c.handMsgLen = fragN
  258. if c.handMsgLen > maxHandshake {
  259. return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
  260. }
  261. // Start with the TLS handshake header,
  262. // without the DTLS bits.
  263. c.handMsg = append([]byte{}, header[:4]...)
  264. } else if fragN != c.handMsgLen {
  265. return nil, errors.New("dtls: bad handshake length")
  266. }
  267. // Add the fragment to the pending message.
  268. if 4+fragOff != len(c.handMsg) {
  269. return nil, errors.New("dtls: bad fragment offset")
  270. }
  271. if fragOff+fragLen > c.handMsgLen {
  272. return nil, errors.New("dtls: bad fragment length")
  273. }
  274. c.handMsg = append(c.handMsg, fragment...)
  275. }
  276. c.recvHandshakeSeq++
  277. ret := c.handMsg
  278. c.handMsg, c.handMsgLen = nil, 0
  279. return ret, nil
  280. }
  281. // DTLSServer returns a new DTLS server side connection
  282. // using conn as the underlying transport.
  283. // The configuration config must be non-nil and must have
  284. // at least one certificate.
  285. func DTLSServer(conn net.Conn, config *Config) *Conn {
  286. c := &Conn{config: config, isDTLS: true, conn: conn}
  287. c.init()
  288. return c
  289. }
  290. // DTLSClient returns a new DTLS client side connection
  291. // using conn as the underlying transport.
  292. // The config cannot be nil: users must set either ServerHostname or
  293. // InsecureSkipVerify in the config.
  294. func DTLSClient(conn net.Conn, config *Config) *Conn {
  295. c := &Conn{config: config, isClient: true, isDTLS: true, conn: conn}
  296. c.init()
  297. return c
  298. }