You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

1521 lines
31 KiB

  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package tls
  5. import "bytes"
  6. type clientHelloMsg struct {
  7. raw []byte
  8. vers uint16
  9. random []byte
  10. sessionId []byte
  11. cipherSuites []uint16
  12. compressionMethods []uint8
  13. nextProtoNeg bool
  14. serverName string
  15. ocspStapling bool
  16. scts bool
  17. supportedCurves []CurveID
  18. supportedPoints []uint8
  19. ticketSupported bool
  20. sessionTicket []uint8
  21. signatureAndHashes []signatureAndHash
  22. secureRenegotiation bool
  23. alpnProtocols []string
  24. }
  25. func (m *clientHelloMsg) equal(i interface{}) bool {
  26. m1, ok := i.(*clientHelloMsg)
  27. if !ok {
  28. return false
  29. }
  30. return bytes.Equal(m.raw, m1.raw) &&
  31. m.vers == m1.vers &&
  32. bytes.Equal(m.random, m1.random) &&
  33. bytes.Equal(m.sessionId, m1.sessionId) &&
  34. eqUint16s(m.cipherSuites, m1.cipherSuites) &&
  35. bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
  36. m.nextProtoNeg == m1.nextProtoNeg &&
  37. m.serverName == m1.serverName &&
  38. m.ocspStapling == m1.ocspStapling &&
  39. m.scts == m1.scts &&
  40. eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
  41. bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
  42. m.ticketSupported == m1.ticketSupported &&
  43. bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
  44. eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
  45. m.secureRenegotiation == m1.secureRenegotiation &&
  46. eqStrings(m.alpnProtocols, m1.alpnProtocols)
  47. }
// marshal encodes m as a ClientHello handshake message. The result is a
// 4-byte header (message type and 24-bit body length) followed by the body.
// If m.raw is already set, it is returned unchanged. Otherwise the new
// encoding is stored in m.raw and returned.
//
// The encoding is built in two passes. The first pass sizes every enabled
// extension; the second allocates one buffer and writes into it at fixed
// offsets.
//
// It panics if any entry of m.alpnProtocols is empty or longer than 255
// bytes.
func (m *clientHelloMsg) marshal() []byte {
	if m.raw != nil {
		return m.raw
	}
	// Fixed body: version(2) + random(32) + session_id length(1) + session_id
	// + cipher_suites length(2) + suites + compression_methods length(1) +
	// methods.
	length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
	// First pass: count extensions and sum their body sizes. Each
	// extension's 4-byte type/length header is added after the loop.
	numExtensions := 0
	extensionsLength := 0
	if m.nextProtoNeg {
		// Empty extension body.
		numExtensions++
	}
	if m.ocspStapling {
		// status_type(1) followed by two zero-valued uint16 lengths.
		extensionsLength += 1 + 2 + 2
		numExtensions++
	}
	if len(m.serverName) > 0 {
		// list length(2) + name_type(1) + name length(2) + name.
		extensionsLength += 5 + len(m.serverName)
		numExtensions++
	}
	if len(m.supportedCurves) > 0 {
		// list length(2) + 2 bytes per curve.
		extensionsLength += 2 + 2*len(m.supportedCurves)
		numExtensions++
	}
	if len(m.supportedPoints) > 0 {
		// list length(1) + 1 byte per point format.
		extensionsLength += 1 + len(m.supportedPoints)
		numExtensions++
	}
	if m.ticketSupported {
		// The body is the session ticket itself (possibly empty).
		extensionsLength += len(m.sessionTicket)
		numExtensions++
	}
	if len(m.signatureAndHashes) > 0 {
		// list length(2) + (hash, signature) byte pairs.
		extensionsLength += 2 + 2*len(m.signatureAndHashes)
		numExtensions++
	}
	if m.secureRenegotiation {
		// One-byte body holding 0.
		extensionsLength += 1
		numExtensions++
	}
	if len(m.alpnProtocols) > 0 {
		// list length(2) + for each protocol a 1-byte length and the name.
		extensionsLength += 2
		for _, s := range m.alpnProtocols {
			if l := len(s); l == 0 || l > 255 {
				panic("invalid ALPN protocol")
			}
			extensionsLength++
			extensionsLength += len(s)
		}
		numExtensions++
	}
	if m.scts {
		// Empty extension body.
		numExtensions++
	}
	if numExtensions > 0 {
		// Add each extension's type(2)+length(2) header, then the 2-byte
		// length prefix of the whole extensions block.
		extensionsLength += 4 * numExtensions
		length += 2 + extensionsLength
	}
	// Second pass: header and fixed fields. Bytes not explicitly written
	// stay zero from make.
	x := make([]byte, 4+length)
	x[0] = typeClientHello
	x[1] = uint8(length >> 16)
	x[2] = uint8(length >> 8)
	x[3] = uint8(length)
	x[4] = uint8(m.vers >> 8)
	x[5] = uint8(m.vers)
	copy(x[6:38], m.random)
	x[38] = uint8(len(m.sessionId))
	copy(x[39:39+len(m.sessionId)], m.sessionId)
	y := x[39+len(m.sessionId):]
	// Byte length of the suite list (2*count) as a big-endian uint16.
	y[0] = uint8(len(m.cipherSuites) >> 7)
	y[1] = uint8(len(m.cipherSuites) << 1)
	for i, suite := range m.cipherSuites {
		y[2+i*2] = uint8(suite >> 8)
		y[3+i*2] = uint8(suite)
	}
	z := y[2+len(m.cipherSuites)*2:]
	z[0] = uint8(len(m.compressionMethods))
	copy(z[1:], m.compressionMethods)
	z = z[1+len(m.compressionMethods):]
	if numExtensions > 0 {
		z[0] = byte(extensionsLength >> 8)
		z[1] = byte(extensionsLength)
		z = z[2:]
	}
	// Extensions, written in the order below; z always points at the next
	// unwritten byte.
	if m.nextProtoNeg {
		z[0] = byte(extensionNextProtoNeg >> 8)
		z[1] = byte(extensionNextProtoNeg & 0xff)
		// The length is always 0
		z = z[4:]
	}
	if len(m.serverName) > 0 {
		z[0] = byte(extensionServerName >> 8)
		z[1] = byte(extensionServerName & 0xff)
		l := len(m.serverName) + 5
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]
		// RFC 3546, section 3.1
		//
		// struct {
		//     NameType name_type;
		//     select (name_type) {
		//         case host_name: HostName;
		//     } name;
		// } ServerName;
		//
		// enum {
		//     host_name(0), (255)
		// } NameType;
		//
		// opaque HostName<1..2^16-1>;
		//
		// struct {
		//     ServerName server_name_list<1..2^16-1>
		// } ServerNameList;

		// List length, then name_type at z[2] (left as 0 = host_name),
		// then the 2-byte name length and the name.
		z[0] = byte((len(m.serverName) + 3) >> 8)
		z[1] = byte(len(m.serverName) + 3)
		z[3] = byte(len(m.serverName) >> 8)
		z[4] = byte(len(m.serverName))
		copy(z[5:], []byte(m.serverName))
		z = z[l:]
	}
	if m.ocspStapling {
		// RFC 4366, section 3.6
		z[0] = byte(extensionStatusRequest >> 8)
		z[1] = byte(extensionStatusRequest)
		z[2] = 0
		z[3] = 5
		z[4] = 1 // OCSP type
		// Two zero valued uint16s for the two lengths.
		z = z[9:]
	}
	if len(m.supportedCurves) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.1
		z[0] = byte(extensionSupportedCurves >> 8)
		z[1] = byte(extensionSupportedCurves)
		l := 2 + 2*len(m.supportedCurves)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l -= 2
		z[4] = byte(l >> 8)
		z[5] = byte(l)
		z = z[6:]
		for _, curve := range m.supportedCurves {
			z[0] = byte(curve >> 8)
			z[1] = byte(curve)
			z = z[2:]
		}
	}
	if len(m.supportedPoints) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.2
		z[0] = byte(extensionSupportedPoints >> 8)
		z[1] = byte(extensionSupportedPoints)
		l := 1 + len(m.supportedPoints)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l--
		z[4] = byte(l)
		z = z[5:]
		for _, pointFormat := range m.supportedPoints {
			z[0] = byte(pointFormat)
			z = z[1:]
		}
	}
	if m.ticketSupported {
		// http://tools.ietf.org/html/rfc5077#section-3.2
		z[0] = byte(extensionSessionTicket >> 8)
		z[1] = byte(extensionSessionTicket)
		l := len(m.sessionTicket)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]
		copy(z, m.sessionTicket)
		z = z[len(m.sessionTicket):]
	}
	if len(m.signatureAndHashes) > 0 {
		// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
		z[0] = byte(extensionSignatureAlgorithms >> 8)
		z[1] = byte(extensionSignatureAlgorithms)
		l := 2 + 2*len(m.signatureAndHashes)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]
		l -= 2
		z[0] = byte(l >> 8)
		z[1] = byte(l)
		z = z[2:]
		for _, sigAndHash := range m.signatureAndHashes {
			z[0] = sigAndHash.hash
			z[1] = sigAndHash.signature
			z = z[2:]
		}
	}
	if m.secureRenegotiation {
		// Length 1; the single body byte at z[4] is left as 0.
		z[0] = byte(extensionRenegotiationInfo >> 8)
		z[1] = byte(extensionRenegotiationInfo & 0xff)
		z[2] = 0
		z[3] = 1
		z = z[5:]
	}
	if len(m.alpnProtocols) > 0 {
		z[0] = byte(extensionALPN >> 8)
		z[1] = byte(extensionALPN & 0xff)
		// The extension length (lengths[0:2]) and list length
		// (lengths[2:4]) are filled in once the protocols are written.
		lengths := z[2:]
		z = z[6:]
		stringsLength := 0
		for _, s := range m.alpnProtocols {
			l := len(s)
			z[0] = byte(l)
			copy(z[1:], s)
			z = z[1+l:]
			stringsLength += 1 + l
		}
		lengths[2] = byte(stringsLength >> 8)
		lengths[3] = byte(stringsLength)
		stringsLength += 2
		lengths[0] = byte(stringsLength >> 8)
		lengths[1] = byte(stringsLength)
	}
	if m.scts {
		// https://tools.ietf.org/html/rfc6962#section-3.3.1
		z[0] = byte(extensionSCT >> 8)
		z[1] = byte(extensionSCT)
		// zero uint16 for the zero-length extension_data
		z = z[4:]
	}
	m.raw = x
	return x
}
  275. func (m *clientHelloMsg) unmarshal(data []byte) bool {
  276. if len(data) < 42 {
  277. return false
  278. }
  279. m.raw = data
  280. m.vers = uint16(data[4])<<8 | uint16(data[5])
  281. m.random = data[6:38]
  282. sessionIdLen := int(data[38])
  283. if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
  284. return false
  285. }
  286. m.sessionId = data[39 : 39+sessionIdLen]
  287. data = data[39+sessionIdLen:]
  288. if len(data) < 2 {
  289. return false
  290. }
  291. // cipherSuiteLen is the number of bytes of cipher suite numbers. Since
  292. // they are uint16s, the number must be even.
  293. cipherSuiteLen := int(data[0])<<8 | int(data[1])
  294. if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
  295. return false
  296. }
  297. numCipherSuites := cipherSuiteLen / 2
  298. m.cipherSuites = make([]uint16, numCipherSuites)
  299. for i := 0; i < numCipherSuites; i++ {
  300. m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
  301. if m.cipherSuites[i] == scsvRenegotiation {
  302. m.secureRenegotiation = true
  303. }
  304. }
  305. data = data[2+cipherSuiteLen:]
  306. if len(data) < 1 {
  307. return false
  308. }
  309. compressionMethodsLen := int(data[0])
  310. if len(data) < 1+compressionMethodsLen {
  311. return false
  312. }
  313. m.compressionMethods = data[1 : 1+compressionMethodsLen]
  314. data = data[1+compressionMethodsLen:]
  315. m.nextProtoNeg = false
  316. m.serverName = ""
  317. m.ocspStapling = false
  318. m.ticketSupported = false
  319. m.sessionTicket = nil
  320. m.signatureAndHashes = nil
  321. m.alpnProtocols = nil
  322. m.scts = false
  323. if len(data) == 0 {
  324. // ClientHello is optionally followed by extension data
  325. return true
  326. }
  327. if len(data) < 2 {
  328. return false
  329. }
  330. extensionsLength := int(data[0])<<8 | int(data[1])
  331. data = data[2:]
  332. if extensionsLength != len(data) {
  333. return false
  334. }
  335. for len(data) != 0 {
  336. if len(data) < 4 {
  337. return false
  338. }
  339. extension := uint16(data[0])<<8 | uint16(data[1])
  340. length := int(data[2])<<8 | int(data[3])
  341. data = data[4:]
  342. if len(data) < length {
  343. return false
  344. }
  345. switch extension {
  346. case extensionServerName:
  347. if length < 2 {
  348. return false
  349. }
  350. numNames := int(data[0])<<8 | int(data[1])
  351. d := data[2:]
  352. for i := 0; i < numNames; i++ {
  353. if len(d) < 3 {
  354. return false
  355. }
  356. nameType := d[0]
  357. nameLen := int(d[1])<<8 | int(d[2])
  358. d = d[3:]
  359. if len(d) < nameLen {
  360. return false
  361. }
  362. if nameType == 0 {
  363. m.serverName = string(d[0:nameLen])
  364. break
  365. }
  366. d = d[nameLen:]
  367. }
  368. case extensionNextProtoNeg:
  369. if length > 0 {
  370. return false
  371. }
  372. m.nextProtoNeg = true
  373. case extensionStatusRequest:
  374. m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
  375. case extensionSupportedCurves:
  376. // http://tools.ietf.org/html/rfc4492#section-5.5.1
  377. if length < 2 {
  378. return false
  379. }
  380. l := int(data[0])<<8 | int(data[1])
  381. if l%2 == 1 || length != l+2 {
  382. return false
  383. }
  384. numCurves := l / 2
  385. m.supportedCurves = make([]CurveID, numCurves)
  386. d := data[2:]
  387. for i := 0; i < numCurves; i++ {
  388. m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
  389. d = d[2:]
  390. }
  391. case extensionSupportedPoints:
  392. // http://tools.ietf.org/html/rfc4492#section-5.5.2
  393. if length < 1 {
  394. return false
  395. }
  396. l := int(data[0])
  397. if length != l+1 {
  398. return false
  399. }
  400. m.supportedPoints = make([]uint8, l)
  401. copy(m.supportedPoints, data[1:])
  402. case extensionSessionTicket:
  403. // http://tools.ietf.org/html/rfc5077#section-3.2
  404. m.ticketSupported = true
  405. m.sessionTicket = data[:length]
  406. case extensionSignatureAlgorithms:
  407. // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
  408. if length < 2 || length&1 != 0 {
  409. return false
  410. }
  411. l := int(data[0])<<8 | int(data[1])
  412. if l != length-2 {
  413. return false
  414. }
  415. n := l / 2
  416. d := data[2:]
  417. m.signatureAndHashes = make([]signatureAndHash, n)
  418. for i := range m.signatureAndHashes {
  419. m.signatureAndHashes[i].hash = d[0]
  420. m.signatureAndHashes[i].signature = d[1]
  421. d = d[2:]
  422. }
  423. case extensionRenegotiationInfo:
  424. if length != 1 || data[0] != 0 {
  425. return false
  426. }
  427. m.secureRenegotiation = true
  428. case extensionALPN:
  429. if length < 2 {
  430. return false
  431. }
  432. l := int(data[0])<<8 | int(data[1])
  433. if l != length-2 {
  434. return false
  435. }
  436. d := data[2:length]
  437. for len(d) != 0 {
  438. stringLen := int(d[0])
  439. d = d[1:]
  440. if stringLen == 0 || stringLen > len(d) {
  441. return false
  442. }
  443. m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
  444. d = d[stringLen:]
  445. }
  446. case extensionSCT:
  447. m.scts = true
  448. if length != 0 {
  449. return false
  450. }
  451. }
  452. data = data[length:]
  453. }
  454. return true
  455. }
  456. type serverHelloMsg struct {
  457. raw []byte
  458. vers uint16
  459. random []byte
  460. sessionId []byte
  461. cipherSuite uint16
  462. compressionMethod uint8
  463. nextProtoNeg bool
  464. nextProtos []string
  465. ocspStapling bool
  466. scts [][]byte
  467. ticketSupported bool
  468. secureRenegotiation bool
  469. alpnProtocol string
  470. }
  471. func (m *serverHelloMsg) equal(i interface{}) bool {
  472. m1, ok := i.(*serverHelloMsg)
  473. if !ok {
  474. return false
  475. }
  476. if len(m.scts) != len(m1.scts) {
  477. return false
  478. }
  479. for i, sct := range m.scts {
  480. if !bytes.Equal(sct, m1.scts[i]) {
  481. return false
  482. }
  483. }
  484. return bytes.Equal(m.raw, m1.raw) &&
  485. m.vers == m1.vers &&
  486. bytes.Equal(m.random, m1.random) &&
  487. bytes.Equal(m.sessionId, m1.sessionId) &&
  488. m.cipherSuite == m1.cipherSuite &&
  489. m.compressionMethod == m1.compressionMethod &&
  490. m.nextProtoNeg == m1.nextProtoNeg &&
  491. eqStrings(m.nextProtos, m1.nextProtos) &&
  492. m.ocspStapling == m1.ocspStapling &&
  493. m.ticketSupported == m1.ticketSupported &&
  494. m.secureRenegotiation == m1.secureRenegotiation &&
  495. m.alpnProtocol == m1.alpnProtocol
  496. }
  497. func (m *serverHelloMsg) marshal() []byte {
  498. if m.raw != nil {
  499. return m.raw
  500. }
  501. length := 38 + len(m.sessionId)
  502. numExtensions := 0
  503. extensionsLength := 0
  504. nextProtoLen := 0
  505. if m.nextProtoNeg {
  506. numExtensions++
  507. for _, v := range m.nextProtos {
  508. nextProtoLen += len(v)
  509. }
  510. nextProtoLen += len(m.nextProtos)
  511. extensionsLength += nextProtoLen
  512. }
  513. if m.ocspStapling {
  514. numExtensions++
  515. }
  516. if m.ticketSupported {
  517. numExtensions++
  518. }
  519. if m.secureRenegotiation {
  520. extensionsLength += 1
  521. numExtensions++
  522. }
  523. if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
  524. if alpnLen >= 256 {
  525. panic("invalid ALPN protocol")
  526. }
  527. extensionsLength += 2 + 1 + alpnLen
  528. numExtensions++
  529. }
  530. sctLen := 0
  531. if len(m.scts) > 0 {
  532. for _, sct := range m.scts {
  533. sctLen += len(sct) + 2
  534. }
  535. extensionsLength += 2 + sctLen
  536. numExtensions++
  537. }
  538. if numExtensions > 0 {
  539. extensionsLength += 4 * numExtensions
  540. length += 2 + extensionsLength
  541. }
  542. x := make([]byte, 4+length)
  543. x[0] = typeServerHello
  544. x[1] = uint8(length >> 16)
  545. x[2] = uint8(length >> 8)
  546. x[3] = uint8(length)
  547. x[4] = uint8(m.vers >> 8)
  548. x[5] = uint8(m.vers)
  549. copy(x[6:38], m.random)
  550. x[38] = uint8(len(m.sessionId))
  551. copy(x[39:39+len(m.sessionId)], m.sessionId)
  552. z := x[39+len(m.sessionId):]
  553. z[0] = uint8(m.cipherSuite >> 8)
  554. z[1] = uint8(m.cipherSuite)
  555. z[2] = uint8(m.compressionMethod)
  556. z = z[3:]
  557. if numExtensions > 0 {
  558. z[0] = byte(extensionsLength >> 8)
  559. z[1] = byte(extensionsLength)
  560. z = z[2:]
  561. }
  562. if m.nextProtoNeg {
  563. z[0] = byte(extensionNextProtoNeg >> 8)
  564. z[1] = byte(extensionNextProtoNeg & 0xff)
  565. z[2] = byte(nextProtoLen >> 8)
  566. z[3] = byte(nextProtoLen)
  567. z = z[4:]
  568. for _, v := range m.nextProtos {
  569. l := len(v)
  570. if l > 255 {
  571. l = 255
  572. }
  573. z[0] = byte(l)
  574. copy(z[1:], []byte(v[0:l]))
  575. z = z[1+l:]
  576. }
  577. }
  578. if m.ocspStapling {
  579. z[0] = byte(extensionStatusRequest >> 8)
  580. z[1] = byte(extensionStatusRequest)
  581. z = z[4:]
  582. }
  583. if m.ticketSupported {
  584. z[0] = byte(extensionSessionTicket >> 8)
  585. z[1] = byte(extensionSessionTicket)
  586. z = z[4:]
  587. }
  588. if m.secureRenegotiation {
  589. z[0] = byte(extensionRenegotiationInfo >> 8)
  590. z[1] = byte(extensionRenegotiationInfo & 0xff)
  591. z[2] = 0
  592. z[3] = 1
  593. z = z[5:]
  594. }
  595. if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
  596. z[0] = byte(extensionALPN >> 8)
  597. z[1] = byte(extensionALPN & 0xff)
  598. l := 2 + 1 + alpnLen
  599. z[2] = byte(l >> 8)
  600. z[3] = byte(l)
  601. l -= 2
  602. z[4] = byte(l >> 8)
  603. z[5] = byte(l)
  604. l -= 1
  605. z[6] = byte(l)
  606. copy(z[7:], []byte(m.alpnProtocol))
  607. z = z[7+alpnLen:]
  608. }
  609. if sctLen > 0 {
  610. z[0] = byte(extensionSCT >> 8)
  611. z[1] = byte(extensionSCT)
  612. l := sctLen + 2
  613. z[2] = byte(l >> 8)
  614. z[3] = byte(l)
  615. z[4] = byte(sctLen >> 8)
  616. z[5] = byte(sctLen)
  617. z = z[6:]
  618. for _, sct := range m.scts {
  619. z[0] = byte(len(sct) >> 8)
  620. z[1] = byte(len(sct))
  621. copy(z[2:], sct)
  622. z = z[len(sct)+2:]
  623. }
  624. }
  625. m.raw = x
  626. return x
  627. }
  628. func (m *serverHelloMsg) unmarshal(data []byte) bool {
  629. if len(data) < 42 {
  630. return false
  631. }
  632. m.raw = data
  633. m.vers = uint16(data[4])<<8 | uint16(data[5])
  634. m.random = data[6:38]
  635. sessionIdLen := int(data[38])
  636. if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
  637. return false
  638. }
  639. m.sessionId = data[39 : 39+sessionIdLen]
  640. data = data[39+sessionIdLen:]
  641. if len(data) < 3 {
  642. return false
  643. }
  644. m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
  645. m.compressionMethod = data[2]
  646. data = data[3:]
  647. m.nextProtoNeg = false
  648. m.nextProtos = nil
  649. m.ocspStapling = false
  650. m.scts = nil
  651. m.ticketSupported = false
  652. m.alpnProtocol = ""
  653. if len(data) == 0 {
  654. // ServerHello is optionally followed by extension data
  655. return true
  656. }
  657. if len(data) < 2 {
  658. return false
  659. }
  660. extensionsLength := int(data[0])<<8 | int(data[1])
  661. data = data[2:]
  662. if len(data) != extensionsLength {
  663. return false
  664. }
  665. for len(data) != 0 {
  666. if len(data) < 4 {
  667. return false
  668. }
  669. extension := uint16(data[0])<<8 | uint16(data[1])
  670. length := int(data[2])<<8 | int(data[3])
  671. data = data[4:]
  672. if len(data) < length {
  673. return false
  674. }
  675. switch extension {
  676. case extensionNextProtoNeg:
  677. m.nextProtoNeg = true
  678. d := data[:length]
  679. for len(d) > 0 {
  680. l := int(d[0])
  681. d = d[1:]
  682. if l == 0 || l > len(d) {
  683. return false
  684. }
  685. m.nextProtos = append(m.nextProtos, string(d[:l]))
  686. d = d[l:]
  687. }
  688. case extensionStatusRequest:
  689. if length > 0 {
  690. return false
  691. }
  692. m.ocspStapling = true
  693. case extensionSessionTicket:
  694. if length > 0 {
  695. return false
  696. }
  697. m.ticketSupported = true
  698. case extensionRenegotiationInfo:
  699. if length != 1 || data[0] != 0 {
  700. return false
  701. }
  702. m.secureRenegotiation = true
  703. case extensionALPN:
  704. d := data[:length]
  705. if len(d) < 3 {
  706. return false
  707. }
  708. l := int(d[0])<<8 | int(d[1])
  709. if l != len(d)-2 {
  710. return false
  711. }
  712. d = d[2:]
  713. l = int(d[0])
  714. if l != len(d)-1 {
  715. return false
  716. }
  717. d = d[1:]
  718. m.alpnProtocol = string(d)
  719. case extensionSCT:
  720. d := data[:length]
  721. if len(d) < 2 {
  722. return false
  723. }
  724. l := int(d[0])<<8 | int(d[1])
  725. d = d[2:]
  726. if len(d) != l {
  727. return false
  728. }
  729. if l == 0 {
  730. continue
  731. }
  732. m.scts = make([][]byte, 0, 3)
  733. for len(d) != 0 {
  734. if len(d) < 2 {
  735. return false
  736. }
  737. sctLen := int(d[0])<<8 | int(d[1])
  738. d = d[2:]
  739. if len(d) < sctLen {
  740. return false
  741. }
  742. m.scts = append(m.scts, d[:sctLen])
  743. d = d[sctLen:]
  744. }
  745. }
  746. data = data[length:]
  747. }
  748. return true
  749. }
// certificateMsg is a TLS Certificate handshake message.
type certificateMsg struct {
	// raw is the encoded message including the 4-byte handshake header;
	// marshal returns it unchanged when non-nil.
	raw []byte
	// certificates holds each entry of the certificate list, in order. Each
	// entry is encoded with a 3-byte length prefix.
	certificates [][]byte
}
  754. func (m *certificateMsg) equal(i interface{}) bool {
  755. m1, ok := i.(*certificateMsg)
  756. if !ok {
  757. return false
  758. }
  759. return bytes.Equal(m.raw, m1.raw) &&
  760. eqByteSlices(m.certificates, m1.certificates)
  761. }
  762. func (m *certificateMsg) marshal() (x []byte) {
  763. if m.raw != nil {
  764. return m.raw
  765. }
  766. var i int
  767. for _, slice := range m.certificates {
  768. i += len(slice)
  769. }
  770. length := 3 + 3*len(m.certificates) + i
  771. x = make([]byte, 4+length)
  772. x[0] = typeCertificate
  773. x[1] = uint8(length >> 16)
  774. x[2] = uint8(length >> 8)
  775. x[3] = uint8(length)
  776. certificateOctets := length - 3
  777. x[4] = uint8(certificateOctets >> 16)
  778. x[5] = uint8(certificateOctets >> 8)
  779. x[6] = uint8(certificateOctets)
  780. y := x[7:]
  781. for _, slice := range m.certificates {
  782. y[0] = uint8(len(slice) >> 16)
  783. y[1] = uint8(len(slice) >> 8)
  784. y[2] = uint8(len(slice))
  785. copy(y[3:], slice)
  786. y = y[3+len(slice):]
  787. }
  788. m.raw = x
  789. return
  790. }
  791. func (m *certificateMsg) unmarshal(data []byte) bool {
  792. if len(data) < 7 {
  793. return false
  794. }
  795. m.raw = data
  796. certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
  797. if uint32(len(data)) != certsLen+7 {
  798. return false
  799. }
  800. numCerts := 0
  801. d := data[7:]
  802. for certsLen > 0 {
  803. if len(d) < 4 {
  804. return false
  805. }
  806. certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  807. if uint32(len(d)) < 3+certLen {
  808. return false
  809. }
  810. d = d[3+certLen:]
  811. certsLen -= 3 + certLen
  812. numCerts++
  813. }
  814. m.certificates = make([][]byte, numCerts)
  815. d = data[7:]
  816. for i := 0; i < numCerts; i++ {
  817. certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  818. m.certificates[i] = d[3 : 3+certLen]
  819. d = d[3+certLen:]
  820. }
  821. return true
  822. }
  823. type serverKeyExchangeMsg struct {
  824. raw []byte
  825. key []byte
  826. }
  827. func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
  828. m1, ok := i.(*serverKeyExchangeMsg)
  829. if !ok {
  830. return false
  831. }
  832. return bytes.Equal(m.raw, m1.raw) &&
  833. bytes.Equal(m.key, m1.key)
  834. }
  835. func (m *serverKeyExchangeMsg) marshal() []byte {
  836. if m.raw != nil {
  837. return m.raw
  838. }
  839. length := len(m.key)
  840. x := make([]byte, length+4)
  841. x[0] = typeServerKeyExchange
  842. x[1] = uint8(length >> 16)
  843. x[2] = uint8(length >> 8)
  844. x[3] = uint8(length)
  845. copy(x[4:], m.key)
  846. m.raw = x
  847. return x
  848. }
  849. func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
  850. m.raw = data
  851. if len(data) < 4 {
  852. return false
  853. }
  854. m.key = data[4:]
  855. return true
  856. }
// certificateStatusMsg is a TLS CertificateStatus handshake message.
type certificateStatusMsg struct {
	// raw is the encoded message including the 4-byte handshake header;
	// marshal returns it unchanged when non-nil.
	raw []byte
	// statusType is the status_type byte. Only statusTypeOCSP carries a
	// response body in marshal/unmarshal.
	statusType uint8
	// response holds the response bytes when statusType is statusTypeOCSP.
	// unmarshal leaves it nil otherwise.
	response []byte
}
  862. func (m *certificateStatusMsg) equal(i interface{}) bool {
  863. m1, ok := i.(*certificateStatusMsg)
  864. if !ok {
  865. return false
  866. }
  867. return bytes.Equal(m.raw, m1.raw) &&
  868. m.statusType == m1.statusType &&
  869. bytes.Equal(m.response, m1.response)
  870. }
  871. func (m *certificateStatusMsg) marshal() []byte {
  872. if m.raw != nil {
  873. return m.raw
  874. }
  875. var x []byte
  876. if m.statusType == statusTypeOCSP {
  877. x = make([]byte, 4+4+len(m.response))
  878. x[0] = typeCertificateStatus
  879. l := len(m.response) + 4
  880. x[1] = byte(l >> 16)
  881. x[2] = byte(l >> 8)
  882. x[3] = byte(l)
  883. x[4] = statusTypeOCSP
  884. l -= 4
  885. x[5] = byte(l >> 16)
  886. x[6] = byte(l >> 8)
  887. x[7] = byte(l)
  888. copy(x[8:], m.response)
  889. } else {
  890. x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
  891. }
  892. m.raw = x
  893. return x
  894. }
  895. func (m *certificateStatusMsg) unmarshal(data []byte) bool {
  896. m.raw = data
  897. if len(data) < 5 {
  898. return false
  899. }
  900. m.statusType = data[4]
  901. m.response = nil
  902. if m.statusType == statusTypeOCSP {
  903. if len(data) < 8 {
  904. return false
  905. }
  906. respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
  907. if uint32(len(data)) != 4+4+respLen {
  908. return false
  909. }
  910. m.response = data[8:]
  911. }
  912. return true
  913. }
// serverHelloDoneMsg is a ServerHelloDone handshake message. Its encoding is
// just the 4-byte handshake header with an empty body.
type serverHelloDoneMsg struct{}
  915. func (m *serverHelloDoneMsg) equal(i interface{}) bool {
  916. _, ok := i.(*serverHelloDoneMsg)
  917. return ok
  918. }
  919. func (m *serverHelloDoneMsg) marshal() []byte {
  920. x := make([]byte, 4)
  921. x[0] = typeServerHelloDone
  922. return x
  923. }
  924. func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
  925. return len(data) == 4
  926. }
  927. type clientKeyExchangeMsg struct {
  928. raw []byte
  929. ciphertext []byte
  930. }
  931. func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
  932. m1, ok := i.(*clientKeyExchangeMsg)
  933. if !ok {
  934. return false
  935. }
  936. return bytes.Equal(m.raw, m1.raw) &&
  937. bytes.Equal(m.ciphertext, m1.ciphertext)
  938. }
  939. func (m *clientKeyExchangeMsg) marshal() []byte {
  940. if m.raw != nil {
  941. return m.raw
  942. }
  943. length := len(m.ciphertext)
  944. x := make([]byte, length+4)
  945. x[0] = typeClientKeyExchange
  946. x[1] = uint8(length >> 16)
  947. x[2] = uint8(length >> 8)
  948. x[3] = uint8(length)
  949. copy(x[4:], m.ciphertext)
  950. m.raw = x
  951. return x
  952. }
  953. func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
  954. m.raw = data
  955. if len(data) < 4 {
  956. return false
  957. }
  958. l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  959. if l != len(data)-4 {
  960. return false
  961. }
  962. m.ciphertext = data[4:]
  963. return true
  964. }
  965. type finishedMsg struct {
  966. raw []byte
  967. verifyData []byte
  968. }
  969. func (m *finishedMsg) equal(i interface{}) bool {
  970. m1, ok := i.(*finishedMsg)
  971. if !ok {
  972. return false
  973. }
  974. return bytes.Equal(m.raw, m1.raw) &&
  975. bytes.Equal(m.verifyData, m1.verifyData)
  976. }
  977. func (m *finishedMsg) marshal() (x []byte) {
  978. if m.raw != nil {
  979. return m.raw
  980. }
  981. x = make([]byte, 4+len(m.verifyData))
  982. x[0] = typeFinished
  983. x[3] = byte(len(m.verifyData))
  984. copy(x[4:], m.verifyData)
  985. m.raw = x
  986. return
  987. }
  988. func (m *finishedMsg) unmarshal(data []byte) bool {
  989. m.raw = data
  990. if len(data) < 4 {
  991. return false
  992. }
  993. m.verifyData = data[4:]
  994. return true
  995. }
// nextProtoMsg is a NextProtocol handshake message carrying a single
// protocol name.
type nextProtoMsg struct {
	// raw is the encoded message including the 4-byte handshake header;
	// marshal returns it unchanged when non-nil.
	raw []byte
	// proto is the protocol name; marshal encodes at most its first 255
	// bytes.
	proto string
}
  1000. func (m *nextProtoMsg) equal(i interface{}) bool {
  1001. m1, ok := i.(*nextProtoMsg)
  1002. if !ok {
  1003. return false
  1004. }
  1005. return bytes.Equal(m.raw, m1.raw) &&
  1006. m.proto == m1.proto
  1007. }
  1008. func (m *nextProtoMsg) marshal() []byte {
  1009. if m.raw != nil {
  1010. return m.raw
  1011. }
  1012. l := len(m.proto)
  1013. if l > 255 {
  1014. l = 255
  1015. }
  1016. padding := 32 - (l+2)%32
  1017. length := l + padding + 2
  1018. x := make([]byte, length+4)
  1019. x[0] = typeNextProtocol
  1020. x[1] = uint8(length >> 16)
  1021. x[2] = uint8(length >> 8)
  1022. x[3] = uint8(length)
  1023. y := x[4:]
  1024. y[0] = byte(l)
  1025. copy(y[1:], []byte(m.proto[0:l]))
  1026. y = y[1+l:]
  1027. y[0] = byte(padding)
  1028. m.raw = x
  1029. return x
  1030. }
  1031. func (m *nextProtoMsg) unmarshal(data []byte) bool {
  1032. m.raw = data
  1033. if len(data) < 5 {
  1034. return false
  1035. }
  1036. data = data[4:]
  1037. protoLen := int(data[0])
  1038. data = data[1:]
  1039. if len(data) < protoLen {
  1040. return false
  1041. }
  1042. m.proto = string(data[0:protoLen])
  1043. data = data[protoLen:]
  1044. if len(data) < 1 {
  1045. return false
  1046. }
  1047. paddingLen := int(data[0])
  1048. data = data[1:]
  1049. if len(data) != paddingLen {
  1050. return false
  1051. }
  1052. return true
  1053. }
// certificateRequestMsg is a TLS CertificateRequest handshake message.
type certificateRequestMsg struct {
	// raw is the encoded message including the 4-byte handshake header;
	// marshal returns it unchanged when non-nil.
	raw []byte
	// hasSignatureAndHash indicates whether this message includes a list
	// of signature and hash functions. This change was introduced with TLS
	// 1.2.
	hasSignatureAndHash bool
	// certificateTypes is encoded with a 1-byte count prefix.
	certificateTypes []byte
	// signatureAndHashes is only encoded/decoded when hasSignatureAndHash
	// is set.
	signatureAndHashes []signatureAndHash
	// certificateAuthorities holds each entry of the authorities list; each
	// is encoded with a 2-byte length prefix.
	certificateAuthorities [][]byte
}
  1064. func (m *certificateRequestMsg) equal(i interface{}) bool {
  1065. m1, ok := i.(*certificateRequestMsg)
  1066. if !ok {
  1067. return false
  1068. }
  1069. return bytes.Equal(m.raw, m1.raw) &&
  1070. bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
  1071. eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
  1072. eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes)
  1073. }
  1074. func (m *certificateRequestMsg) marshal() (x []byte) {
  1075. if m.raw != nil {
  1076. return m.raw
  1077. }
  1078. // See http://tools.ietf.org/html/rfc4346#section-7.4.4
  1079. length := 1 + len(m.certificateTypes) + 2
  1080. casLength := 0
  1081. for _, ca := range m.certificateAuthorities {
  1082. casLength += 2 + len(ca)
  1083. }
  1084. length += casLength
  1085. if m.hasSignatureAndHash {
  1086. length += 2 + 2*len(m.signatureAndHashes)
  1087. }
  1088. x = make([]byte, 4+length)
  1089. x[0] = typeCertificateRequest
  1090. x[1] = uint8(length >> 16)
  1091. x[2] = uint8(length >> 8)
  1092. x[3] = uint8(length)
  1093. x[4] = uint8(len(m.certificateTypes))
  1094. copy(x[5:], m.certificateTypes)
  1095. y := x[5+len(m.certificateTypes):]
  1096. if m.hasSignatureAndHash {
  1097. n := len(m.signatureAndHashes) * 2
  1098. y[0] = uint8(n >> 8)
  1099. y[1] = uint8(n)
  1100. y = y[2:]
  1101. for _, sigAndHash := range m.signatureAndHashes {
  1102. y[0] = sigAndHash.hash
  1103. y[1] = sigAndHash.signature
  1104. y = y[2:]
  1105. }
  1106. }
  1107. y[0] = uint8(casLength >> 8)
  1108. y[1] = uint8(casLength)
  1109. y = y[2:]
  1110. for _, ca := range m.certificateAuthorities {
  1111. y[0] = uint8(len(ca) >> 8)
  1112. y[1] = uint8(len(ca))
  1113. y = y[2:]
  1114. copy(y, ca)
  1115. y = y[len(ca):]
  1116. }
  1117. m.raw = x
  1118. return
  1119. }
  1120. func (m *certificateRequestMsg) unmarshal(data []byte) bool {
  1121. m.raw = data
  1122. if len(data) < 5 {
  1123. return false
  1124. }
  1125. length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1126. if uint32(len(data))-4 != length {
  1127. return false
  1128. }
  1129. numCertTypes := int(data[4])
  1130. data = data[5:]
  1131. if numCertTypes == 0 || len(data) <= numCertTypes {
  1132. return false
  1133. }
  1134. m.certificateTypes = make([]byte, numCertTypes)
  1135. if copy(m.certificateTypes, data) != numCertTypes {
  1136. return false
  1137. }
  1138. data = data[numCertTypes:]
  1139. if m.hasSignatureAndHash {
  1140. if len(data) < 2 {
  1141. return false
  1142. }
  1143. sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
  1144. data = data[2:]
  1145. if sigAndHashLen&1 != 0 {
  1146. return false
  1147. }
  1148. if len(data) < int(sigAndHashLen) {
  1149. return false
  1150. }
  1151. numSigAndHash := sigAndHashLen / 2
  1152. m.signatureAndHashes = make([]signatureAndHash, numSigAndHash)
  1153. for i := range m.signatureAndHashes {
  1154. m.signatureAndHashes[i].hash = data[0]
  1155. m.signatureAndHashes[i].signature = data[1]
  1156. data = data[2:]
  1157. }
  1158. }
  1159. if len(data) < 2 {
  1160. return false
  1161. }
  1162. casLength := uint16(data[0])<<8 | uint16(data[1])
  1163. data = data[2:]
  1164. if len(data) < int(casLength) {
  1165. return false
  1166. }
  1167. cas := make([]byte, casLength)
  1168. copy(cas, data)
  1169. data = data[casLength:]
  1170. m.certificateAuthorities = nil
  1171. for len(cas) > 0 {
  1172. if len(cas) < 2 {
  1173. return false
  1174. }
  1175. caLen := uint16(cas[0])<<8 | uint16(cas[1])
  1176. cas = cas[2:]
  1177. if len(cas) < int(caLen) {
  1178. return false
  1179. }
  1180. m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
  1181. cas = cas[caLen:]
  1182. }
  1183. if len(data) > 0 {
  1184. return false
  1185. }
  1186. return true
  1187. }
// certificateVerifyMsg is the TLS CertificateVerify handshake message,
// encoded as described in RFC 4346 section 7.4.8.
type certificateVerifyMsg struct {
	// raw caches the complete wire encoding, including the 4-byte handshake
	// header. marshal returns it unchanged when non-nil; unmarshal stores its
	// input here.
	raw []byte

	// hasSignatureAndHash controls whether the two-byte signatureAndHash
	// field precedes the signature on the wire. It is not itself encoded, and
	// unmarshal reads it, so callers must set it first. NOTE(review):
	// presumably set for TLS 1.2 connections, as for certificateRequestMsg;
	// confirm against callers.
	hasSignatureAndHash bool

	// signatureAndHash holds the hash and signature identifiers. It is only
	// encoded or decoded when hasSignatureAndHash is set.
	signatureAndHash signatureAndHash

	// signature holds the signature bytes, encoded with a 2-byte length prefix.
	signature []byte
}
  1194. func (m *certificateVerifyMsg) equal(i interface{}) bool {
  1195. m1, ok := i.(*certificateVerifyMsg)
  1196. if !ok {
  1197. return false
  1198. }
  1199. return bytes.Equal(m.raw, m1.raw) &&
  1200. m.hasSignatureAndHash == m1.hasSignatureAndHash &&
  1201. m.signatureAndHash.hash == m1.signatureAndHash.hash &&
  1202. m.signatureAndHash.signature == m1.signatureAndHash.signature &&
  1203. bytes.Equal(m.signature, m1.signature)
  1204. }
  1205. func (m *certificateVerifyMsg) marshal() (x []byte) {
  1206. if m.raw != nil {
  1207. return m.raw
  1208. }
  1209. // See http://tools.ietf.org/html/rfc4346#section-7.4.8
  1210. siglength := len(m.signature)
  1211. length := 2 + siglength
  1212. if m.hasSignatureAndHash {
  1213. length += 2
  1214. }
  1215. x = make([]byte, 4+length)
  1216. x[0] = typeCertificateVerify
  1217. x[1] = uint8(length >> 16)
  1218. x[2] = uint8(length >> 8)
  1219. x[3] = uint8(length)
  1220. y := x[4:]
  1221. if m.hasSignatureAndHash {
  1222. y[0] = m.signatureAndHash.hash
  1223. y[1] = m.signatureAndHash.signature
  1224. y = y[2:]
  1225. }
  1226. y[0] = uint8(siglength >> 8)
  1227. y[1] = uint8(siglength)
  1228. copy(y[2:], m.signature)
  1229. m.raw = x
  1230. return
  1231. }
  1232. func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
  1233. m.raw = data
  1234. if len(data) < 6 {
  1235. return false
  1236. }
  1237. length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1238. if uint32(len(data))-4 != length {
  1239. return false
  1240. }
  1241. data = data[4:]
  1242. if m.hasSignatureAndHash {
  1243. m.signatureAndHash.hash = data[0]
  1244. m.signatureAndHash.signature = data[1]
  1245. data = data[2:]
  1246. }
  1247. if len(data) < 2 {
  1248. return false
  1249. }
  1250. siglength := int(data[0])<<8 + int(data[1])
  1251. data = data[2:]
  1252. if len(data) != siglength {
  1253. return false
  1254. }
  1255. m.signature = data
  1256. return true
  1257. }
// newSessionTicketMsg is the NewSessionTicket handshake message described in
// RFC 5077 section 3.3.
type newSessionTicketMsg struct {
	// raw caches the complete wire encoding, including the 4-byte handshake
	// header. marshal returns it unchanged when non-nil; unmarshal stores its
	// input here.
	raw []byte

	// ticket holds the opaque ticket bytes, encoded with a 2-byte length
	// prefix. After unmarshal it aliases the input buffer rather than a copy.
	ticket []byte
}
  1262. func (m *newSessionTicketMsg) equal(i interface{}) bool {
  1263. m1, ok := i.(*newSessionTicketMsg)
  1264. if !ok {
  1265. return false
  1266. }
  1267. return bytes.Equal(m.raw, m1.raw) &&
  1268. bytes.Equal(m.ticket, m1.ticket)
  1269. }
  1270. func (m *newSessionTicketMsg) marshal() (x []byte) {
  1271. if m.raw != nil {
  1272. return m.raw
  1273. }
  1274. // See http://tools.ietf.org/html/rfc5077#section-3.3
  1275. ticketLen := len(m.ticket)
  1276. length := 2 + 4 + ticketLen
  1277. x = make([]byte, 4+length)
  1278. x[0] = typeNewSessionTicket
  1279. x[1] = uint8(length >> 16)
  1280. x[2] = uint8(length >> 8)
  1281. x[3] = uint8(length)
  1282. x[8] = uint8(ticketLen >> 8)
  1283. x[9] = uint8(ticketLen)
  1284. copy(x[10:], m.ticket)
  1285. m.raw = x
  1286. return
  1287. }
  1288. func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
  1289. m.raw = data
  1290. if len(data) < 10 {
  1291. return false
  1292. }
  1293. length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1294. if uint32(len(data))-4 != length {
  1295. return false
  1296. }
  1297. ticketLen := int(data[8])<<8 + int(data[9])
  1298. if len(data)-10 != ticketLen {
  1299. return false
  1300. }
  1301. m.ticket = data[10:]
  1302. return true
  1303. }
  1304. func eqUint16s(x, y []uint16) bool {
  1305. if len(x) != len(y) {
  1306. return false
  1307. }
  1308. for i, v := range x {
  1309. if y[i] != v {
  1310. return false
  1311. }
  1312. }
  1313. return true
  1314. }
  1315. func eqCurveIDs(x, y []CurveID) bool {
  1316. if len(x) != len(y) {
  1317. return false
  1318. }
  1319. for i, v := range x {
  1320. if y[i] != v {
  1321. return false
  1322. }
  1323. }
  1324. return true
  1325. }
  1326. func eqStrings(x, y []string) bool {
  1327. if len(x) != len(y) {
  1328. return false
  1329. }
  1330. for i, v := range x {
  1331. if y[i] != v {
  1332. return false
  1333. }
  1334. }
  1335. return true
  1336. }
  1337. func eqByteSlices(x, y [][]byte) bool {
  1338. if len(x) != len(y) {
  1339. return false
  1340. }
  1341. for i, v := range x {
  1342. if !bytes.Equal(v, y[i]) {
  1343. return false
  1344. }
  1345. }
  1346. return true
  1347. }
  1348. func eqSignatureAndHashes(x, y []signatureAndHash) bool {
  1349. if len(x) != len(y) {
  1350. return false
  1351. }
  1352. for i, v := range x {
  1353. v2 := y[i]
  1354. if v.hash != v2.hash || v.signature != v2.signature {
  1355. return false
  1356. }
  1357. }
  1358. return true
  1359. }