Alternative TLS implementation in Go
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

2645 lignes
58 KiB

  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package tls
  5. import (
  6. "bytes"
  7. "strings"
  8. )
// clientHelloMsg is the in-memory form of a ClientHello handshake message.
// marshal encodes it and unmarshal decodes it; both keep the encoded bytes
// in raw.
type clientHelloMsg struct {
	// raw is the complete encoded message, including the 4-byte handshake
	// header; marshal returns it unchanged when it is already set.
	raw []byte
	// rawTruncated is the prefix of raw that precedes the pre_shared_key
	// binders list (for PSK binding). unmarshal sets it only when that
	// extension is present.
	rawTruncated []byte

	// Fixed-position fields preceding the extensions block. random is the
	// 32-byte client random; unmarshal rejects session IDs longer than
	// 32 bytes.
	vers               uint16
	random             []byte
	sessionId          []byte
	cipherSuites       []uint16
	compressionMethods []uint8

	// Fields encoded into, or decoded from, the extensions block.
	//   - nextProtoNeg, scts and earlyData record the presence of extensions
	//     that carry no data.
	//   - ocspStapling is set when status_request carries the OCSP status type.
	//   - ticketSupported records a session_ticket extension; sessionTicket is
	//     its (possibly empty) content.
	//   - secureRenegotiationSupported is set either by a renegotiation_info
	//     extension (whose payload is secureRenegotiation) or by the
	//     renegotiation SCSV appearing in cipherSuites.
	//   - psks and pskKeyExchangeModes are parsed by unmarshal but are not
	//     written by marshal.
	nextProtoNeg                 bool
	serverName                   string
	ocspStapling                 bool
	scts                         bool
	supportedCurves              []CurveID
	supportedPoints              []uint8
	ticketSupported              bool
	sessionTicket                []uint8
	supportedSignatureAlgorithms []SignatureScheme
	secureRenegotiation          []byte
	secureRenegotiationSupported bool
	alpnProtocols                []string
	keyShares                    []keyShare
	supportedVersions            []uint16
	psks                         []psk
	pskKeyExchangeModes          []uint8
	earlyData                    bool
}
  35. // Helpers
  36. // Marshalling of signature_algorithms extension see https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
  37. // for more details. Extension is serialized in data buffer
  38. // Function advances data slice and returns it, so that it can be used for further processing
  39. func marshallExtensionSignatureAlgorithms(data []byte, sigSchemes []SignatureScheme) ([]byte) {
  40. data[0] = byte(extensionSignatureAlgorithms >> 8)
  41. data[1] = byte(extensionSignatureAlgorithms)
  42. l := 2 + 2*len(sigSchemes)
  43. data[2] = byte(l >> 8)
  44. data[3] = byte(l)
  45. data = data[4:]
  46. l -= 2
  47. data[0] = byte(l >> 8)
  48. data[1] = byte(l)
  49. data = data[2:]
  50. for _, sigAlgo := range sigSchemes {
  51. data[0] = byte(sigAlgo >> 8)
  52. data[1] = byte(sigAlgo)
  53. data = data[2:]
  54. }
  55. return data
  56. }
  57. // Unmrshalling of signature_algorithms extension see https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
  58. // for more details.
  59. // In case of error function returns alertDecoderError otherwise filled SignatureScheme slice and alertSuccess
  60. func unmarshallExtensionSignatureAlgorithms(data []byte, length int) ([]SignatureScheme, alert) {
  61. if length < 2 || length&1 != 0 {
  62. return nil, alertDecodeError
  63. }
  64. l := int(data[0])<<8 | int(data[1])
  65. if l != length-2 {
  66. return nil, alertDecodeError
  67. }
  68. n := l / 2
  69. d := data[2:]
  70. sigSchemes := make([]SignatureScheme, n)
  71. for i := range sigSchemes {
  72. sigSchemes[i] = SignatureScheme(d[0])<<8 | SignatureScheme(d[1])
  73. d = d[2:]
  74. }
  75. return sigSchemes, alertSuccess
  76. }
  77. func (m *clientHelloMsg) equal(i interface{}) bool {
  78. m1, ok := i.(*clientHelloMsg)
  79. if !ok {
  80. return false
  81. }
  82. return bytes.Equal(m.raw, m1.raw) &&
  83. m.vers == m1.vers &&
  84. bytes.Equal(m.random, m1.random) &&
  85. bytes.Equal(m.sessionId, m1.sessionId) &&
  86. eqUint16s(m.cipherSuites, m1.cipherSuites) &&
  87. bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
  88. m.nextProtoNeg == m1.nextProtoNeg &&
  89. m.serverName == m1.serverName &&
  90. m.ocspStapling == m1.ocspStapling &&
  91. m.scts == m1.scts &&
  92. eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
  93. bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
  94. m.ticketSupported == m1.ticketSupported &&
  95. bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
  96. eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) &&
  97. m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
  98. bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
  99. eqStrings(m.alpnProtocols, m1.alpnProtocols) &&
  100. eqKeyShares(m.keyShares, m1.keyShares) &&
  101. eqUint16s(m.supportedVersions, m1.supportedVersions) &&
  102. m.earlyData == m1.earlyData
  103. }
  104. func (m *clientHelloMsg) marshal() []byte {
  105. if m.raw != nil {
  106. return m.raw
  107. }
  108. length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
  109. numExtensions := 0
  110. extensionsLength := 0
  111. if m.nextProtoNeg {
  112. numExtensions++
  113. }
  114. if m.ocspStapling {
  115. extensionsLength += 1 + 2 + 2
  116. numExtensions++
  117. }
  118. if len(m.serverName) > 0 {
  119. extensionsLength += 5 + len(m.serverName)
  120. numExtensions++
  121. }
  122. if len(m.supportedCurves) > 0 {
  123. extensionsLength += 2 + 2*len(m.supportedCurves)
  124. numExtensions++
  125. }
  126. if len(m.supportedPoints) > 0 {
  127. extensionsLength += 1 + len(m.supportedPoints)
  128. numExtensions++
  129. }
  130. if m.ticketSupported {
  131. extensionsLength += len(m.sessionTicket)
  132. numExtensions++
  133. }
  134. if len(m.supportedSignatureAlgorithms) > 0 {
  135. extensionsLength += 2 + 2*len(m.supportedSignatureAlgorithms)
  136. numExtensions++
  137. }
  138. if m.secureRenegotiationSupported {
  139. extensionsLength += 1 + len(m.secureRenegotiation)
  140. numExtensions++
  141. }
  142. if len(m.alpnProtocols) > 0 {
  143. extensionsLength += 2
  144. for _, s := range m.alpnProtocols {
  145. if l := len(s); l == 0 || l > 255 {
  146. panic("invalid ALPN protocol")
  147. }
  148. extensionsLength++
  149. extensionsLength += len(s)
  150. }
  151. numExtensions++
  152. }
  153. if m.scts {
  154. numExtensions++
  155. }
  156. if len(m.keyShares) > 0 {
  157. extensionsLength += 2
  158. for _, k := range m.keyShares {
  159. extensionsLength += 4 + len(k.data)
  160. }
  161. numExtensions++
  162. }
  163. if len(m.supportedVersions) > 0 {
  164. extensionsLength += 1 + 2*len(m.supportedVersions)
  165. numExtensions++
  166. }
  167. if m.earlyData {
  168. numExtensions++
  169. }
  170. if numExtensions > 0 {
  171. extensionsLength += 4 * numExtensions
  172. length += 2 + extensionsLength
  173. }
  174. x := make([]byte, 4+length)
  175. x[0] = typeClientHello
  176. x[1] = uint8(length >> 16)
  177. x[2] = uint8(length >> 8)
  178. x[3] = uint8(length)
  179. x[4] = uint8(m.vers >> 8)
  180. x[5] = uint8(m.vers)
  181. copy(x[6:38], m.random)
  182. x[38] = uint8(len(m.sessionId))
  183. copy(x[39:39+len(m.sessionId)], m.sessionId)
  184. y := x[39+len(m.sessionId):]
  185. y[0] = uint8(len(m.cipherSuites) >> 7)
  186. y[1] = uint8(len(m.cipherSuites) << 1)
  187. for i, suite := range m.cipherSuites {
  188. y[2+i*2] = uint8(suite >> 8)
  189. y[3+i*2] = uint8(suite)
  190. }
  191. z := y[2+len(m.cipherSuites)*2:]
  192. z[0] = uint8(len(m.compressionMethods))
  193. copy(z[1:], m.compressionMethods)
  194. z = z[1+len(m.compressionMethods):]
  195. if numExtensions > 0 {
  196. z[0] = byte(extensionsLength >> 8)
  197. z[1] = byte(extensionsLength)
  198. z = z[2:]
  199. }
  200. if m.nextProtoNeg {
  201. z[0] = byte(extensionNextProtoNeg >> 8)
  202. z[1] = byte(extensionNextProtoNeg & 0xff)
  203. // The length is always 0
  204. z = z[4:]
  205. }
  206. if len(m.serverName) > 0 {
  207. z[0] = byte(extensionServerName >> 8)
  208. z[1] = byte(extensionServerName & 0xff)
  209. l := len(m.serverName) + 5
  210. z[2] = byte(l >> 8)
  211. z[3] = byte(l)
  212. z = z[4:]
  213. // RFC 3546, section 3.1
  214. //
  215. // struct {
  216. // NameType name_type;
  217. // select (name_type) {
  218. // case host_name: HostName;
  219. // } name;
  220. // } ServerName;
  221. //
  222. // enum {
  223. // host_name(0), (255)
  224. // } NameType;
  225. //
  226. // opaque HostName<1..2^16-1>;
  227. //
  228. // struct {
  229. // ServerName server_name_list<1..2^16-1>
  230. // } ServerNameList;
  231. z[0] = byte((len(m.serverName) + 3) >> 8)
  232. z[1] = byte(len(m.serverName) + 3)
  233. z[3] = byte(len(m.serverName) >> 8)
  234. z[4] = byte(len(m.serverName))
  235. copy(z[5:], []byte(m.serverName))
  236. z = z[l:]
  237. }
  238. if m.ocspStapling {
  239. // RFC 4366, section 3.6
  240. z[0] = byte(extensionStatusRequest >> 8)
  241. z[1] = byte(extensionStatusRequest)
  242. z[2] = 0
  243. z[3] = 5
  244. z[4] = 1 // OCSP type
  245. // Two zero valued uint16s for the two lengths.
  246. z = z[9:]
  247. }
  248. if len(m.supportedCurves) > 0 {
  249. // http://tools.ietf.org/html/rfc4492#section-5.5.1
  250. // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.4
  251. z[0] = byte(extensionSupportedCurves >> 8)
  252. z[1] = byte(extensionSupportedCurves)
  253. l := 2 + 2*len(m.supportedCurves)
  254. z[2] = byte(l >> 8)
  255. z[3] = byte(l)
  256. l -= 2
  257. z[4] = byte(l >> 8)
  258. z[5] = byte(l)
  259. z = z[6:]
  260. for _, curve := range m.supportedCurves {
  261. z[0] = byte(curve >> 8)
  262. z[1] = byte(curve)
  263. z = z[2:]
  264. }
  265. }
  266. if len(m.supportedPoints) > 0 {
  267. // http://tools.ietf.org/html/rfc4492#section-5.5.2
  268. z[0] = byte(extensionSupportedPoints >> 8)
  269. z[1] = byte(extensionSupportedPoints)
  270. l := 1 + len(m.supportedPoints)
  271. z[2] = byte(l >> 8)
  272. z[3] = byte(l)
  273. l--
  274. z[4] = byte(l)
  275. z = z[5:]
  276. for _, pointFormat := range m.supportedPoints {
  277. z[0] = pointFormat
  278. z = z[1:]
  279. }
  280. }
  281. if m.ticketSupported {
  282. // http://tools.ietf.org/html/rfc5077#section-3.2
  283. z[0] = byte(extensionSessionTicket >> 8)
  284. z[1] = byte(extensionSessionTicket)
  285. l := len(m.sessionTicket)
  286. z[2] = byte(l >> 8)
  287. z[3] = byte(l)
  288. z = z[4:]
  289. copy(z, m.sessionTicket)
  290. z = z[len(m.sessionTicket):]
  291. }
  292. if len(m.supportedSignatureAlgorithms) > 0 {
  293. // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
  294. // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.3
  295. z[0] = byte(extensionSignatureAlgorithms >> 8)
  296. z[1] = byte(extensionSignatureAlgorithms)
  297. l := 2 + 2*len(m.supportedSignatureAlgorithms)
  298. z[2] = byte(l >> 8)
  299. z[3] = byte(l)
  300. z = z[4:]
  301. l -= 2
  302. z[0] = byte(l >> 8)
  303. z[1] = byte(l)
  304. z = z[2:]
  305. for _, sigAlgo := range m.supportedSignatureAlgorithms {
  306. z[0] = byte(sigAlgo >> 8)
  307. z[1] = byte(sigAlgo)
  308. z = z[2:]
  309. }
  310. }
  311. if m.secureRenegotiationSupported {
  312. z[0] = byte(extensionRenegotiationInfo >> 8)
  313. z[1] = byte(extensionRenegotiationInfo & 0xff)
  314. z[2] = 0
  315. z[3] = byte(len(m.secureRenegotiation) + 1)
  316. z[4] = byte(len(m.secureRenegotiation))
  317. z = z[5:]
  318. copy(z, m.secureRenegotiation)
  319. z = z[len(m.secureRenegotiation):]
  320. }
  321. if len(m.alpnProtocols) > 0 {
  322. z[0] = byte(extensionALPN >> 8)
  323. z[1] = byte(extensionALPN & 0xff)
  324. lengths := z[2:]
  325. z = z[6:]
  326. stringsLength := 0
  327. for _, s := range m.alpnProtocols {
  328. l := len(s)
  329. z[0] = byte(l)
  330. copy(z[1:], s)
  331. z = z[1+l:]
  332. stringsLength += 1 + l
  333. }
  334. lengths[2] = byte(stringsLength >> 8)
  335. lengths[3] = byte(stringsLength)
  336. stringsLength += 2
  337. lengths[0] = byte(stringsLength >> 8)
  338. lengths[1] = byte(stringsLength)
  339. }
  340. if m.scts {
  341. // https://tools.ietf.org/html/rfc6962#section-3.3.1
  342. z[0] = byte(extensionSCT >> 8)
  343. z[1] = byte(extensionSCT)
  344. // zero uint16 for the zero-length extension_data
  345. z = z[4:]
  346. }
  347. if len(m.keyShares) > 0 {
  348. // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.5
  349. z[0] = byte(extensionKeyShare >> 8)
  350. z[1] = byte(extensionKeyShare)
  351. lengths := z[2:]
  352. z = z[6:]
  353. totalLength := 0
  354. for _, ks := range m.keyShares {
  355. z[0] = byte(ks.group >> 8)
  356. z[1] = byte(ks.group)
  357. z[2] = byte(len(ks.data) >> 8)
  358. z[3] = byte(len(ks.data))
  359. copy(z[4:], ks.data)
  360. z = z[4+len(ks.data):]
  361. totalLength += 4 + len(ks.data)
  362. }
  363. lengths[2] = byte(totalLength >> 8)
  364. lengths[3] = byte(totalLength)
  365. totalLength += 2
  366. lengths[0] = byte(totalLength >> 8)
  367. lengths[1] = byte(totalLength)
  368. }
  369. if len(m.supportedVersions) > 0 {
  370. z[0] = byte(extensionSupportedVersions >> 8)
  371. z[1] = byte(extensionSupportedVersions)
  372. l := 1 + 2*len(m.supportedVersions)
  373. z[2] = byte(l >> 8)
  374. z[3] = byte(l)
  375. l -= 1
  376. z[4] = byte(l)
  377. z = z[5:]
  378. for _, v := range m.supportedVersions {
  379. z[0] = byte(v >> 8)
  380. z[1] = byte(v)
  381. z = z[2:]
  382. }
  383. }
  384. if m.earlyData {
  385. z[0] = byte(extensionEarlyData >> 8)
  386. z[1] = byte(extensionEarlyData)
  387. z = z[4:]
  388. }
  389. m.raw = x
  390. return x
  391. }
  392. func (m *clientHelloMsg) unmarshal(data []byte) alert {
  393. if len(data) < 42 {
  394. return alertDecodeError
  395. }
  396. m.raw = data
  397. m.vers = uint16(data[4])<<8 | uint16(data[5])
  398. m.random = data[6:38]
  399. sessionIdLen := int(data[38])
  400. if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
  401. return alertDecodeError
  402. }
  403. m.sessionId = data[39 : 39+sessionIdLen]
  404. data = data[39+sessionIdLen:]
  405. bindersOffset := 39 + sessionIdLen
  406. if len(data) < 2 {
  407. return alertDecodeError
  408. }
  409. // cipherSuiteLen is the number of bytes of cipher suite numbers. Since
  410. // they are uint16s, the number must be even.
  411. cipherSuiteLen := int(data[0])<<8 | int(data[1])
  412. if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
  413. return alertDecodeError
  414. }
  415. numCipherSuites := cipherSuiteLen / 2
  416. m.cipherSuites = make([]uint16, numCipherSuites)
  417. for i := 0; i < numCipherSuites; i++ {
  418. m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
  419. if m.cipherSuites[i] == scsvRenegotiation {
  420. m.secureRenegotiationSupported = true
  421. }
  422. }
  423. data = data[2+cipherSuiteLen:]
  424. bindersOffset += 2 + cipherSuiteLen
  425. if len(data) < 1 {
  426. return alertDecodeError
  427. }
  428. compressionMethodsLen := int(data[0])
  429. if len(data) < 1+compressionMethodsLen {
  430. return alertDecodeError
  431. }
  432. m.compressionMethods = data[1 : 1+compressionMethodsLen]
  433. data = data[1+compressionMethodsLen:]
  434. bindersOffset += 1 + compressionMethodsLen
  435. m.nextProtoNeg = false
  436. m.serverName = ""
  437. m.ocspStapling = false
  438. m.ticketSupported = false
  439. m.sessionTicket = nil
  440. m.supportedSignatureAlgorithms = nil
  441. m.alpnProtocols = nil
  442. m.scts = false
  443. m.keyShares = nil
  444. m.supportedVersions = nil
  445. m.psks = nil
  446. m.pskKeyExchangeModes = nil
  447. m.earlyData = false
  448. if len(data) == 0 {
  449. // ClientHello is optionally followed by extension data
  450. return alertSuccess
  451. }
  452. if len(data) < 2 {
  453. return alertDecodeError
  454. }
  455. extensionsLength := int(data[0])<<8 | int(data[1])
  456. data = data[2:]
  457. bindersOffset += 2
  458. if extensionsLength != len(data) {
  459. return alertDecodeError
  460. }
  461. for len(data) != 0 {
  462. if len(data) < 4 {
  463. return alertDecodeError
  464. }
  465. extension := uint16(data[0])<<8 | uint16(data[1])
  466. length := int(data[2])<<8 | int(data[3])
  467. data = data[4:]
  468. bindersOffset += 4
  469. if len(data) < length {
  470. return alertDecodeError
  471. }
  472. switch extension {
  473. case extensionServerName:
  474. d := data[:length]
  475. if len(d) < 2 {
  476. return alertDecodeError
  477. }
  478. namesLen := int(d[0])<<8 | int(d[1])
  479. d = d[2:]
  480. if len(d) != namesLen {
  481. return alertDecodeError
  482. }
  483. for len(d) > 0 {
  484. if len(d) < 3 {
  485. return alertDecodeError
  486. }
  487. nameType := d[0]
  488. nameLen := int(d[1])<<8 | int(d[2])
  489. d = d[3:]
  490. if len(d) < nameLen {
  491. return alertDecodeError
  492. }
  493. if nameType == 0 {
  494. m.serverName = string(d[:nameLen])
  495. // An SNI value may not include a
  496. // trailing dot. See
  497. // https://tools.ietf.org/html/rfc6066#section-3.
  498. if strings.HasSuffix(m.serverName, ".") {
  499. // TODO use alertDecodeError?
  500. return alertUnexpectedMessage
  501. }
  502. break
  503. }
  504. d = d[nameLen:]
  505. }
  506. case extensionNextProtoNeg:
  507. if length > 0 {
  508. return alertDecodeError
  509. }
  510. m.nextProtoNeg = true
  511. case extensionStatusRequest:
  512. m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
  513. case extensionSupportedCurves:
  514. // http://tools.ietf.org/html/rfc4492#section-5.5.1
  515. // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.4
  516. if length < 2 {
  517. return alertDecodeError
  518. }
  519. l := int(data[0])<<8 | int(data[1])
  520. if l%2 == 1 || length != l+2 {
  521. return alertDecodeError
  522. }
  523. numCurves := l / 2
  524. m.supportedCurves = make([]CurveID, numCurves)
  525. d := data[2:]
  526. for i := 0; i < numCurves; i++ {
  527. m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
  528. d = d[2:]
  529. }
  530. case extensionSupportedPoints:
  531. // http://tools.ietf.org/html/rfc4492#section-5.5.2
  532. if length < 1 {
  533. return alertDecodeError
  534. }
  535. l := int(data[0])
  536. if length != l+1 {
  537. return alertDecodeError
  538. }
  539. m.supportedPoints = make([]uint8, l)
  540. copy(m.supportedPoints, data[1:])
  541. case extensionSessionTicket:
  542. // http://tools.ietf.org/html/rfc5077#section-3.2
  543. m.ticketSupported = true
  544. m.sessionTicket = data[:length]
  545. case extensionSignatureAlgorithms:
  546. // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
  547. // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.3
  548. if length < 2 || length&1 != 0 {
  549. return alertDecodeError
  550. }
  551. l := int(data[0])<<8 | int(data[1])
  552. if l != length-2 {
  553. return alertDecodeError
  554. }
  555. n := l / 2
  556. d := data[2:]
  557. m.supportedSignatureAlgorithms = make([]SignatureScheme, n)
  558. for i := range m.supportedSignatureAlgorithms {
  559. m.supportedSignatureAlgorithms[i] = SignatureScheme(d[0])<<8 | SignatureScheme(d[1])
  560. d = d[2:]
  561. }
  562. case extensionRenegotiationInfo:
  563. if length == 0 {
  564. return alertDecodeError
  565. }
  566. d := data[:length]
  567. l := int(d[0])
  568. d = d[1:]
  569. if l != len(d) {
  570. return alertDecodeError
  571. }
  572. m.secureRenegotiation = d
  573. m.secureRenegotiationSupported = true
  574. case extensionALPN:
  575. if length < 2 {
  576. return alertDecodeError
  577. }
  578. l := int(data[0])<<8 | int(data[1])
  579. if l != length-2 {
  580. return alertDecodeError
  581. }
  582. d := data[2:length]
  583. for len(d) != 0 {
  584. stringLen := int(d[0])
  585. d = d[1:]
  586. if stringLen == 0 || stringLen > len(d) {
  587. return alertDecodeError
  588. }
  589. m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
  590. d = d[stringLen:]
  591. }
  592. case extensionSCT:
  593. m.scts = true
  594. if length != 0 {
  595. return alertDecodeError
  596. }
  597. case extensionKeyShare:
  598. // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.5
  599. if length < 2 {
  600. return alertDecodeError
  601. }
  602. l := int(data[0])<<8 | int(data[1])
  603. if l != length-2 {
  604. return alertDecodeError
  605. }
  606. d := data[2:length]
  607. for len(d) != 0 {
  608. if len(d) < 4 {
  609. return alertDecodeError
  610. }
  611. dataLen := int(d[2])<<8 | int(d[3])
  612. if dataLen == 0 || 4+dataLen > len(d) {
  613. return alertDecodeError
  614. }
  615. m.keyShares = append(m.keyShares, keyShare{
  616. group: CurveID(d[0])<<8 | CurveID(d[1]),
  617. data: d[4 : 4+dataLen],
  618. })
  619. d = d[4+dataLen:]
  620. }
  621. case extensionSupportedVersions:
  622. // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.1
  623. if length < 1 {
  624. return alertDecodeError
  625. }
  626. l := int(data[0])
  627. if l%2 == 1 || length != l+1 {
  628. return alertDecodeError
  629. }
  630. n := l / 2
  631. d := data[1:]
  632. for i := 0; i < n; i++ {
  633. v := uint16(d[0])<<8 + uint16(d[1])
  634. m.supportedVersions = append(m.supportedVersions, v)
  635. d = d[2:]
  636. }
  637. case extensionPreSharedKey:
  638. // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6
  639. if length < 2 {
  640. return alertDecodeError
  641. }
  642. // Ensure this extension is the last one in the Client Hello
  643. if len(data) != length {
  644. return alertIllegalParameter
  645. }
  646. li := int(data[0])<<8 | int(data[1])
  647. if 2+li+2 > length {
  648. return alertDecodeError
  649. }
  650. d := data[2 : 2+li]
  651. bindersOffset += 2 + li
  652. for len(d) > 0 {
  653. if len(d) < 6 {
  654. return alertDecodeError
  655. }
  656. l := int(d[0])<<8 | int(d[1])
  657. if len(d) < 2+l+4 {
  658. return alertDecodeError
  659. }
  660. m.psks = append(m.psks, psk{
  661. identity: d[2 : 2+l],
  662. obfTicketAge: uint32(d[l+2])<<24 | uint32(d[l+3])<<16 |
  663. uint32(d[l+4])<<8 | uint32(d[l+5]),
  664. })
  665. d = d[2+l+4:]
  666. }
  667. lb := int(data[li+2])<<8 | int(data[li+3])
  668. d = data[2+li+2:]
  669. if lb != len(d) || lb == 0 {
  670. return alertDecodeError
  671. }
  672. i := 0
  673. for len(d) > 0 {
  674. if i >= len(m.psks) {
  675. return alertIllegalParameter
  676. }
  677. if len(d) < 1 {
  678. return alertDecodeError
  679. }
  680. l := int(d[0])
  681. if l > len(d)-1 {
  682. return alertDecodeError
  683. }
  684. if i >= len(m.psks) {
  685. return alertIllegalParameter
  686. }
  687. m.psks[i].binder = d[1 : 1+l]
  688. d = d[1+l:]
  689. i++
  690. }
  691. if i != len(m.psks) {
  692. return alertIllegalParameter
  693. }
  694. m.rawTruncated = m.raw[:bindersOffset]
  695. case extensionPSKKeyExchangeModes:
  696. if length < 2 {
  697. return alertDecodeError
  698. }
  699. l := int(data[0])
  700. if length != l+1 {
  701. return alertDecodeError
  702. }
  703. m.pskKeyExchangeModes = data[1:length]
  704. case extensionEarlyData:
  705. // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8
  706. m.earlyData = true
  707. }
  708. data = data[length:]
  709. bindersOffset += length
  710. }
  711. return alertSuccess
  712. }
// serverHelloMsg is the in-memory form of a ServerHello handshake message.
// marshal encodes it and unmarshal decodes it; both keep the encoded bytes
// in raw.
type serverHelloMsg struct {
	// raw is the complete encoded message, including the 4-byte handshake
	// header; marshal returns it unchanged when it is already set.
	raw []byte
	// vers is the negotiated version. For TLS 1.3 versions marshal may write
	// 0x0303 in the version field and carry vers in a supported_versions
	// extension; unmarshal reads it back from that extension.
	vers uint16

	// Fixed-position fields. sessionId and compressionMethod are not encoded
	// for the TLS 1.3 draft range VersionTLS13Draft18..VersionTLS13Draft21.
	random            []byte
	sessionId         []byte
	cipherSuite       uint16
	compressionMethod uint8

	// Extension fields. nextProtos is encoded only when nextProtoNeg is set;
	// scts holds one serialized entry of the signed_certificate_timestamp
	// list per element; secureRenegotiation is the renegotiation_info payload.
	nextProtoNeg                 bool
	nextProtos                   []string
	ocspStapling                 bool
	scts                         [][]byte
	ticketSupported              bool
	secureRenegotiation          []byte
	secureRenegotiationSupported bool
	alpnProtocol                 string

	// TLS 1.3. keyShare is encoded only when keyShare.group is non-zero;
	// pskIdentity is the value written in the pre_shared_key extension when
	// psk is set.
	keyShare    keyShare
	psk         bool
	pskIdentity uint16
}
  733. func (m *serverHelloMsg) equal(i interface{}) bool {
  734. m1, ok := i.(*serverHelloMsg)
  735. if !ok {
  736. return false
  737. }
  738. if len(m.scts) != len(m1.scts) {
  739. return false
  740. }
  741. for i, sct := range m.scts {
  742. if !bytes.Equal(sct, m1.scts[i]) {
  743. return false
  744. }
  745. }
  746. return bytes.Equal(m.raw, m1.raw) &&
  747. m.vers == m1.vers &&
  748. bytes.Equal(m.random, m1.random) &&
  749. bytes.Equal(m.sessionId, m1.sessionId) &&
  750. m.cipherSuite == m1.cipherSuite &&
  751. m.compressionMethod == m1.compressionMethod &&
  752. m.nextProtoNeg == m1.nextProtoNeg &&
  753. eqStrings(m.nextProtos, m1.nextProtos) &&
  754. m.ocspStapling == m1.ocspStapling &&
  755. m.ticketSupported == m1.ticketSupported &&
  756. m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
  757. bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
  758. m.alpnProtocol == m1.alpnProtocol &&
  759. m.keyShare.group == m1.keyShare.group &&
  760. bytes.Equal(m.keyShare.data, m1.keyShare.data) &&
  761. m.psk == m1.psk &&
  762. m.pskIdentity == m1.pskIdentity
  763. }
  764. func (m *serverHelloMsg) marshal() []byte {
  765. if m.raw != nil {
  766. return m.raw
  767. }
  768. oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21
  769. length := 38 + len(m.sessionId)
  770. if oldTLS13Draft {
  771. // no compression method, no session ID.
  772. length = 36
  773. }
  774. numExtensions := 0
  775. extensionsLength := 0
  776. nextProtoLen := 0
  777. if m.nextProtoNeg {
  778. numExtensions++
  779. for _, v := range m.nextProtos {
  780. nextProtoLen += len(v)
  781. }
  782. nextProtoLen += len(m.nextProtos)
  783. extensionsLength += nextProtoLen
  784. }
  785. if m.ocspStapling {
  786. numExtensions++
  787. }
  788. if m.ticketSupported {
  789. numExtensions++
  790. }
  791. if m.secureRenegotiationSupported {
  792. extensionsLength += 1 + len(m.secureRenegotiation)
  793. numExtensions++
  794. }
  795. if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
  796. if alpnLen >= 256 {
  797. panic("invalid ALPN protocol")
  798. }
  799. extensionsLength += 2 + 1 + alpnLen
  800. numExtensions++
  801. }
  802. sctLen := 0
  803. if len(m.scts) > 0 {
  804. for _, sct := range m.scts {
  805. sctLen += len(sct) + 2
  806. }
  807. extensionsLength += 2 + sctLen
  808. numExtensions++
  809. }
  810. if m.keyShare.group != 0 {
  811. extensionsLength += 4 + len(m.keyShare.data)
  812. numExtensions++
  813. }
  814. if m.psk {
  815. extensionsLength += 2
  816. numExtensions++
  817. }
  818. // supported_versions extension
  819. if m.vers >= VersionTLS13 {
  820. extensionsLength += 2
  821. numExtensions++
  822. }
  823. if numExtensions > 0 {
  824. extensionsLength += 4 * numExtensions
  825. length += 2 + extensionsLength
  826. }
  827. x := make([]byte, 4+length)
  828. x[0] = typeServerHello
  829. x[1] = uint8(length >> 16)
  830. x[2] = uint8(length >> 8)
  831. x[3] = uint8(length)
  832. if m.vers >= VersionTLS13 {
  833. x[4] = 3
  834. x[5] = 3
  835. } else {
  836. x[4] = uint8(m.vers >> 8)
  837. x[5] = uint8(m.vers)
  838. }
  839. copy(x[6:38], m.random)
  840. z := x[38:]
  841. if !oldTLS13Draft {
  842. x[38] = uint8(len(m.sessionId))
  843. copy(x[39:39+len(m.sessionId)], m.sessionId)
  844. z = x[39+len(m.sessionId):]
  845. }
  846. z[0] = uint8(m.cipherSuite >> 8)
  847. z[1] = uint8(m.cipherSuite)
  848. if oldTLS13Draft {
  849. // no compression method in older TLS 1.3 drafts.
  850. z = z[2:]
  851. } else {
  852. z[2] = m.compressionMethod
  853. z = z[3:]
  854. }
  855. if numExtensions > 0 {
  856. z[0] = byte(extensionsLength >> 8)
  857. z[1] = byte(extensionsLength)
  858. z = z[2:]
  859. }
  860. if m.vers >= VersionTLS13 {
  861. z[0] = byte(extensionSupportedVersions >> 8)
  862. z[1] = byte(extensionSupportedVersions)
  863. z[3] = 2
  864. z[4] = uint8(m.vers >> 8)
  865. z[5] = uint8(m.vers)
  866. z = z[6:]
  867. }
  868. if m.nextProtoNeg {
  869. z[0] = byte(extensionNextProtoNeg >> 8)
  870. z[1] = byte(extensionNextProtoNeg & 0xff)
  871. z[2] = byte(nextProtoLen >> 8)
  872. z[3] = byte(nextProtoLen)
  873. z = z[4:]
  874. for _, v := range m.nextProtos {
  875. l := len(v)
  876. if l > 255 {
  877. l = 255
  878. }
  879. z[0] = byte(l)
  880. copy(z[1:], []byte(v[0:l]))
  881. z = z[1+l:]
  882. }
  883. }
  884. if m.ocspStapling {
  885. z[0] = byte(extensionStatusRequest >> 8)
  886. z[1] = byte(extensionStatusRequest)
  887. z = z[4:]
  888. }
  889. if m.ticketSupported {
  890. z[0] = byte(extensionSessionTicket >> 8)
  891. z[1] = byte(extensionSessionTicket)
  892. z = z[4:]
  893. }
  894. if m.secureRenegotiationSupported {
  895. z[0] = byte(extensionRenegotiationInfo >> 8)
  896. z[1] = byte(extensionRenegotiationInfo & 0xff)
  897. z[2] = 0
  898. z[3] = byte(len(m.secureRenegotiation) + 1)
  899. z[4] = byte(len(m.secureRenegotiation))
  900. z = z[5:]
  901. copy(z, m.secureRenegotiation)
  902. z = z[len(m.secureRenegotiation):]
  903. }
  904. if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
  905. z[0] = byte(extensionALPN >> 8)
  906. z[1] = byte(extensionALPN & 0xff)
  907. l := 2 + 1 + alpnLen
  908. z[2] = byte(l >> 8)
  909. z[3] = byte(l)
  910. l -= 2
  911. z[4] = byte(l >> 8)
  912. z[5] = byte(l)
  913. l -= 1
  914. z[6] = byte(l)
  915. copy(z[7:], []byte(m.alpnProtocol))
  916. z = z[7+alpnLen:]
  917. }
  918. if sctLen > 0 {
  919. z[0] = byte(extensionSCT >> 8)
  920. z[1] = byte(extensionSCT)
  921. l := sctLen + 2
  922. z[2] = byte(l >> 8)
  923. z[3] = byte(l)
  924. z[4] = byte(sctLen >> 8)
  925. z[5] = byte(sctLen)
  926. z = z[6:]
  927. for _, sct := range m.scts {
  928. z[0] = byte(len(sct) >> 8)
  929. z[1] = byte(len(sct))
  930. copy(z[2:], sct)
  931. z = z[len(sct)+2:]
  932. }
  933. }
  934. if m.keyShare.group != 0 {
  935. z[0] = uint8(extensionKeyShare >> 8)
  936. z[1] = uint8(extensionKeyShare)
  937. l := 4 + len(m.keyShare.data)
  938. z[2] = uint8(l >> 8)
  939. z[3] = uint8(l)
  940. z[4] = uint8(m.keyShare.group >> 8)
  941. z[5] = uint8(m.keyShare.group)
  942. l -= 4
  943. z[6] = uint8(l >> 8)
  944. z[7] = uint8(l)
  945. copy(z[8:], m.keyShare.data)
  946. z = z[8+l:]
  947. }
  948. if m.psk {
  949. z[0] = byte(extensionPreSharedKey >> 8)
  950. z[1] = byte(extensionPreSharedKey)
  951. z[3] = 2
  952. z[4] = byte(m.pskIdentity >> 8)
  953. z[5] = byte(m.pskIdentity)
  954. z = z[6:]
  955. }
  956. m.raw = x
  957. return x
  958. }
  959. func (m *serverHelloMsg) unmarshal(data []byte) alert {
  960. if len(data) < 42 {
  961. return alertDecodeError
  962. }
  963. m.raw = data
  964. m.vers = uint16(data[4])<<8 | uint16(data[5])
  965. oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21
  966. m.random = data[6:38]
  967. if !oldTLS13Draft {
  968. sessionIdLen := int(data[38])
  969. if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
  970. return alertDecodeError
  971. }
  972. m.sessionId = data[39 : 39+sessionIdLen]
  973. data = data[39+sessionIdLen:]
  974. } else {
  975. data = data[38:]
  976. }
  977. if len(data) < 3 {
  978. return alertDecodeError
  979. }
  980. m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
  981. if oldTLS13Draft {
  982. // no compression method in older TLS 1.3 drafts.
  983. data = data[2:]
  984. } else {
  985. m.compressionMethod = data[2]
  986. data = data[3:]
  987. }
  988. m.nextProtoNeg = false
  989. m.nextProtos = nil
  990. m.ocspStapling = false
  991. m.scts = nil
  992. m.ticketSupported = false
  993. m.alpnProtocol = ""
  994. m.keyShare.group = 0
  995. m.keyShare.data = nil
  996. m.psk = false
  997. m.pskIdentity = 0
  998. if len(data) == 0 {
  999. // ServerHello is optionally followed by extension data
  1000. return alertSuccess
  1001. }
  1002. if len(data) < 2 {
  1003. return alertDecodeError
  1004. }
  1005. extensionsLength := int(data[0])<<8 | int(data[1])
  1006. data = data[2:]
  1007. if len(data) != extensionsLength {
  1008. return alertDecodeError
  1009. }
  1010. svData := findExtension(data, extensionSupportedVersions)
  1011. if svData != nil {
  1012. if len(svData) != 2 {
  1013. return alertDecodeError
  1014. }
  1015. if m.vers != VersionTLS12 {
  1016. return alertDecodeError
  1017. }
  1018. m.vers = uint16(svData[0])<<8 | uint16(svData[1])
  1019. }
  1020. for len(data) != 0 {
  1021. if len(data) < 4 {
  1022. return alertDecodeError
  1023. }
  1024. extension := uint16(data[0])<<8 | uint16(data[1])
  1025. length := int(data[2])<<8 | int(data[3])
  1026. data = data[4:]
  1027. if len(data) < length {
  1028. return alertDecodeError
  1029. }
  1030. switch extension {
  1031. case extensionNextProtoNeg:
  1032. m.nextProtoNeg = true
  1033. d := data[:length]
  1034. for len(d) > 0 {
  1035. l := int(d[0])
  1036. d = d[1:]
  1037. if l == 0 || l > len(d) {
  1038. return alertDecodeError
  1039. }
  1040. m.nextProtos = append(m.nextProtos, string(d[:l]))
  1041. d = d[l:]
  1042. }
  1043. case extensionStatusRequest:
  1044. if length > 0 {
  1045. return alertDecodeError
  1046. }
  1047. m.ocspStapling = true
  1048. case extensionSessionTicket:
  1049. if length > 0 {
  1050. return alertDecodeError
  1051. }
  1052. m.ticketSupported = true
  1053. case extensionRenegotiationInfo:
  1054. if length == 0 {
  1055. return alertDecodeError
  1056. }
  1057. d := data[:length]
  1058. l := int(d[0])
  1059. d = d[1:]
  1060. if l != len(d) {
  1061. return alertDecodeError
  1062. }
  1063. m.secureRenegotiation = d
  1064. m.secureRenegotiationSupported = true
  1065. case extensionALPN:
  1066. d := data[:length]
  1067. if len(d) < 3 {
  1068. return alertDecodeError
  1069. }
  1070. l := int(d[0])<<8 | int(d[1])
  1071. if l != len(d)-2 {
  1072. return alertDecodeError
  1073. }
  1074. d = d[2:]
  1075. l = int(d[0])
  1076. if l != len(d)-1 {
  1077. return alertDecodeError
  1078. }
  1079. d = d[1:]
  1080. if len(d) == 0 {
  1081. // ALPN protocols must not be empty.
  1082. return alertDecodeError
  1083. }
  1084. m.alpnProtocol = string(d)
  1085. case extensionSCT:
  1086. d := data[:length]
  1087. if len(d) < 2 {
  1088. return alertDecodeError
  1089. }
  1090. l := int(d[0])<<8 | int(d[1])
  1091. d = d[2:]
  1092. if len(d) != l || l == 0 {
  1093. return alertDecodeError
  1094. }
  1095. m.scts = make([][]byte, 0, 3)
  1096. for len(d) != 0 {
  1097. if len(d) < 2 {
  1098. return alertDecodeError
  1099. }
  1100. sctLen := int(d[0])<<8 | int(d[1])
  1101. d = d[2:]
  1102. if sctLen == 0 || len(d) < sctLen {
  1103. return alertDecodeError
  1104. }
  1105. m.scts = append(m.scts, d[:sctLen])
  1106. d = d[sctLen:]
  1107. }
  1108. case extensionKeyShare:
  1109. d := data[:length]
  1110. if len(d) < 4 {
  1111. return alertDecodeError
  1112. }
  1113. m.keyShare.group = CurveID(d[0])<<8 | CurveID(d[1])
  1114. l := int(d[2])<<8 | int(d[3])
  1115. d = d[4:]
  1116. if len(d) != l {
  1117. return alertDecodeError
  1118. }
  1119. m.keyShare.data = d[:l]
  1120. case extensionPreSharedKey:
  1121. if length != 2 {
  1122. return alertDecodeError
  1123. }
  1124. m.psk = true
  1125. m.pskIdentity = uint16(data[0])<<8 | uint16(data[1])
  1126. }
  1127. data = data[length:]
  1128. }
  1129. return alertSuccess
  1130. }
  1131. type encryptedExtensionsMsg struct {
  1132. raw []byte
  1133. alpnProtocol string
  1134. earlyData bool
  1135. }
  1136. func (m *encryptedExtensionsMsg) equal(i interface{}) bool {
  1137. m1, ok := i.(*encryptedExtensionsMsg)
  1138. if !ok {
  1139. return false
  1140. }
  1141. return bytes.Equal(m.raw, m1.raw) &&
  1142. m.alpnProtocol == m1.alpnProtocol &&
  1143. m.earlyData == m1.earlyData
  1144. }
  1145. func (m *encryptedExtensionsMsg) marshal() []byte {
  1146. if m.raw != nil {
  1147. return m.raw
  1148. }
  1149. length := 2
  1150. if m.earlyData {
  1151. length += 4
  1152. }
  1153. alpnLen := len(m.alpnProtocol)
  1154. if alpnLen > 0 {
  1155. if alpnLen >= 256 {
  1156. panic("invalid ALPN protocol")
  1157. }
  1158. length += 2 + 2 + 2 + 1 + alpnLen
  1159. }
  1160. x := make([]byte, 4+length)
  1161. x[0] = typeEncryptedExtensions
  1162. x[1] = uint8(length >> 16)
  1163. x[2] = uint8(length >> 8)
  1164. x[3] = uint8(length)
  1165. length -= 2
  1166. x[4] = uint8(length >> 8)
  1167. x[5] = uint8(length)
  1168. z := x[6:]
  1169. if alpnLen > 0 {
  1170. z[0] = byte(extensionALPN >> 8)
  1171. z[1] = byte(extensionALPN)
  1172. l := 2 + 1 + alpnLen
  1173. z[2] = byte(l >> 8)
  1174. z[3] = byte(l)
  1175. l -= 2
  1176. z[4] = byte(l >> 8)
  1177. z[5] = byte(l)
  1178. l -= 1
  1179. z[6] = byte(l)
  1180. copy(z[7:], []byte(m.alpnProtocol))
  1181. z = z[7+alpnLen:]
  1182. }
  1183. if m.earlyData {
  1184. z[0] = byte(extensionEarlyData >> 8)
  1185. z[1] = byte(extensionEarlyData)
  1186. z = z[4:]
  1187. }
  1188. m.raw = x
  1189. return x
  1190. }
  1191. func (m *encryptedExtensionsMsg) unmarshal(data []byte) alert {
  1192. if len(data) < 6 {
  1193. return alertDecodeError
  1194. }
  1195. m.raw = data
  1196. m.alpnProtocol = ""
  1197. m.earlyData = false
  1198. extensionsLength := int(data[4])<<8 | int(data[5])
  1199. data = data[6:]
  1200. if len(data) != extensionsLength {
  1201. return alertDecodeError
  1202. }
  1203. for len(data) != 0 {
  1204. if len(data) < 4 {
  1205. return alertDecodeError
  1206. }
  1207. extension := uint16(data[0])<<8 | uint16(data[1])
  1208. length := int(data[2])<<8 | int(data[3])
  1209. data = data[4:]
  1210. if len(data) < length {
  1211. return alertDecodeError
  1212. }
  1213. switch extension {
  1214. case extensionALPN:
  1215. d := data[:length]
  1216. if len(d) < 3 {
  1217. return alertDecodeError
  1218. }
  1219. l := int(d[0])<<8 | int(d[1])
  1220. if l != len(d)-2 {
  1221. return alertDecodeError
  1222. }
  1223. d = d[2:]
  1224. l = int(d[0])
  1225. if l != len(d)-1 {
  1226. return alertDecodeError
  1227. }
  1228. d = d[1:]
  1229. if len(d) == 0 {
  1230. // ALPN protocols must not be empty.
  1231. return alertDecodeError
  1232. }
  1233. m.alpnProtocol = string(d)
  1234. case extensionEarlyData:
  1235. // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8
  1236. m.earlyData = true
  1237. }
  1238. data = data[length:]
  1239. }
  1240. return alertSuccess
  1241. }
  1242. type certificateMsg struct {
  1243. raw []byte
  1244. certificates [][]byte
  1245. }
  1246. func (m *certificateMsg) equal(i interface{}) bool {
  1247. m1, ok := i.(*certificateMsg)
  1248. if !ok {
  1249. return false
  1250. }
  1251. return bytes.Equal(m.raw, m1.raw) &&
  1252. eqByteSlices(m.certificates, m1.certificates)
  1253. }
  1254. func (m *certificateMsg) marshal() (x []byte) {
  1255. if m.raw != nil {
  1256. return m.raw
  1257. }
  1258. var i int
  1259. for _, slice := range m.certificates {
  1260. i += len(slice)
  1261. }
  1262. length := 3 + 3*len(m.certificates) + i
  1263. x = make([]byte, 4+length)
  1264. x[0] = typeCertificate
  1265. x[1] = uint8(length >> 16)
  1266. x[2] = uint8(length >> 8)
  1267. x[3] = uint8(length)
  1268. certificateOctets := length - 3
  1269. x[4] = uint8(certificateOctets >> 16)
  1270. x[5] = uint8(certificateOctets >> 8)
  1271. x[6] = uint8(certificateOctets)
  1272. y := x[7:]
  1273. for _, slice := range m.certificates {
  1274. y[0] = uint8(len(slice) >> 16)
  1275. y[1] = uint8(len(slice) >> 8)
  1276. y[2] = uint8(len(slice))
  1277. copy(y[3:], slice)
  1278. y = y[3+len(slice):]
  1279. }
  1280. m.raw = x
  1281. return
  1282. }
  1283. func (m *certificateMsg) unmarshal(data []byte) alert {
  1284. if len(data) < 7 {
  1285. return alertDecodeError
  1286. }
  1287. m.raw = data
  1288. certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
  1289. if uint32(len(data)) != certsLen+7 {
  1290. return alertDecodeError
  1291. }
  1292. numCerts := 0
  1293. d := data[7:]
  1294. for certsLen > 0 {
  1295. if len(d) < 4 {
  1296. return alertDecodeError
  1297. }
  1298. certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1299. if uint32(len(d)) < 3+certLen {
  1300. return alertDecodeError
  1301. }
  1302. d = d[3+certLen:]
  1303. certsLen -= 3 + certLen
  1304. numCerts++
  1305. }
  1306. m.certificates = make([][]byte, numCerts)
  1307. d = data[7:]
  1308. for i := 0; i < numCerts; i++ {
  1309. certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1310. m.certificates[i] = d[3 : 3+certLen]
  1311. d = d[3+certLen:]
  1312. }
  1313. return alertSuccess
  1314. }
  1315. type certificateEntry struct {
  1316. data []byte
  1317. ocspStaple []byte
  1318. sctList [][]byte
  1319. }
  1320. type certificateMsg13 struct {
  1321. raw []byte
  1322. requestContext []byte
  1323. certificates []certificateEntry
  1324. }
  1325. func (m *certificateMsg13) equal(i interface{}) bool {
  1326. m1, ok := i.(*certificateMsg13)
  1327. if !ok {
  1328. return false
  1329. }
  1330. if len(m.certificates) != len(m1.certificates) {
  1331. return false
  1332. }
  1333. for i, _ := range m.certificates {
  1334. ok := bytes.Equal(m.certificates[i].data, m1.certificates[i].data)
  1335. ok = ok && bytes.Equal(m.certificates[i].ocspStaple, m1.certificates[i].ocspStaple)
  1336. ok = ok && eqByteSlices(m.certificates[i].sctList, m1.certificates[i].sctList)
  1337. if !ok {
  1338. return false
  1339. }
  1340. }
  1341. return bytes.Equal(m.raw, m1.raw) &&
  1342. bytes.Equal(m.requestContext, m1.requestContext)
  1343. }
  1344. func (m *certificateMsg13) marshal() (x []byte) {
  1345. if m.raw != nil {
  1346. return m.raw
  1347. }
  1348. var i int
  1349. for _, cert := range m.certificates {
  1350. i += len(cert.data)
  1351. if len(cert.ocspStaple) != 0 {
  1352. i += 8 + len(cert.ocspStaple)
  1353. }
  1354. if len(cert.sctList) != 0 {
  1355. i += 6
  1356. for _, sct := range cert.sctList {
  1357. i += 2 + len(sct)
  1358. }
  1359. }
  1360. }
  1361. length := 3 + 3*len(m.certificates) + i
  1362. length += 2 * len(m.certificates) // extensions
  1363. length += 1 + len(m.requestContext)
  1364. x = make([]byte, 4+length)
  1365. x[0] = typeCertificate
  1366. x[1] = uint8(length >> 16)
  1367. x[2] = uint8(length >> 8)
  1368. x[3] = uint8(length)
  1369. z := x[4:]
  1370. z[0] = byte(len(m.requestContext))
  1371. copy(z[1:], m.requestContext)
  1372. z = z[1+len(m.requestContext):]
  1373. certificateOctets := len(z) - 3
  1374. z[0] = uint8(certificateOctets >> 16)
  1375. z[1] = uint8(certificateOctets >> 8)
  1376. z[2] = uint8(certificateOctets)
  1377. z = z[3:]
  1378. for _, cert := range m.certificates {
  1379. z[0] = uint8(len(cert.data) >> 16)
  1380. z[1] = uint8(len(cert.data) >> 8)
  1381. z[2] = uint8(len(cert.data))
  1382. copy(z[3:], cert.data)
  1383. z = z[3+len(cert.data):]
  1384. extLenPos := z[:2]
  1385. z = z[2:]
  1386. extensionLen := 0
  1387. if len(cert.ocspStaple) != 0 {
  1388. stapleLen := 4 + len(cert.ocspStaple)
  1389. z[0] = uint8(extensionStatusRequest >> 8)
  1390. z[1] = uint8(extensionStatusRequest)
  1391. z[2] = uint8(stapleLen >> 8)
  1392. z[3] = uint8(stapleLen)
  1393. stapleLen -= 4
  1394. z[4] = statusTypeOCSP
  1395. z[5] = uint8(stapleLen >> 16)
  1396. z[6] = uint8(stapleLen >> 8)
  1397. z[7] = uint8(stapleLen)
  1398. copy(z[8:], cert.ocspStaple)
  1399. z = z[8+stapleLen:]
  1400. extensionLen += 8 + stapleLen
  1401. }
  1402. if len(cert.sctList) != 0 {
  1403. z[0] = uint8(extensionSCT >> 8)
  1404. z[1] = uint8(extensionSCT)
  1405. sctLenPos := z[2:6]
  1406. z = z[6:]
  1407. extensionLen += 6
  1408. sctLen := 2
  1409. for _, sct := range cert.sctList {
  1410. z[0] = uint8(len(sct) >> 8)
  1411. z[1] = uint8(len(sct))
  1412. copy(z[2:], sct)
  1413. z = z[2+len(sct):]
  1414. extensionLen += 2 + len(sct)
  1415. sctLen += 2 + len(sct)
  1416. }
  1417. sctLenPos[0] = uint8(sctLen >> 8)
  1418. sctLenPos[1] = uint8(sctLen)
  1419. sctLen -= 2
  1420. sctLenPos[2] = uint8(sctLen >> 8)
  1421. sctLenPos[3] = uint8(sctLen)
  1422. }
  1423. extLenPos[0] = uint8(extensionLen >> 8)
  1424. extLenPos[1] = uint8(extensionLen)
  1425. }
  1426. m.raw = x
  1427. return
  1428. }
  1429. func (m *certificateMsg13) unmarshal(data []byte) alert {
  1430. if len(data) < 5 {
  1431. return alertDecodeError
  1432. }
  1433. m.raw = data
  1434. ctxLen := data[4]
  1435. if len(data) < int(ctxLen)+5+3 {
  1436. return alertDecodeError
  1437. }
  1438. m.requestContext = data[5 : 5+ctxLen]
  1439. d := data[5+ctxLen:]
  1440. certsLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1441. if uint32(len(d)) != certsLen+3 {
  1442. return alertDecodeError
  1443. }
  1444. numCerts := 0
  1445. d = d[3:]
  1446. for certsLen > 0 {
  1447. if len(d) < 4 {
  1448. return alertDecodeError
  1449. }
  1450. certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1451. if uint32(len(d)) < 3+certLen {
  1452. return alertDecodeError
  1453. }
  1454. d = d[3+certLen:]
  1455. if len(d) < 2 {
  1456. return alertDecodeError
  1457. }
  1458. extLen := uint16(d[0])<<8 | uint16(d[1])
  1459. if uint16(len(d)) < 2+extLen {
  1460. return alertDecodeError
  1461. }
  1462. d = d[2+extLen:]
  1463. certsLen -= 3 + certLen + 2 + uint32(extLen)
  1464. numCerts++
  1465. }
  1466. m.certificates = make([]certificateEntry, numCerts)
  1467. d = data[8+ctxLen:]
  1468. for i := 0; i < numCerts; i++ {
  1469. certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1470. m.certificates[i].data = d[3 : 3+certLen]
  1471. d = d[3+certLen:]
  1472. extLen := uint16(d[0])<<8 | uint16(d[1])
  1473. d = d[2:]
  1474. for extLen > 0 {
  1475. if extLen < 4 {
  1476. return alertDecodeError
  1477. }
  1478. typ := uint16(d[0])<<8 | uint16(d[1])
  1479. bodyLen := uint16(d[2])<<8 | uint16(d[3])
  1480. if extLen < 4+bodyLen {
  1481. return alertDecodeError
  1482. }
  1483. body := d[4 : 4+bodyLen]
  1484. d = d[4+bodyLen:]
  1485. extLen -= 4 + bodyLen
  1486. switch typ {
  1487. case extensionStatusRequest:
  1488. if len(body) < 4 || body[0] != 0x01 {
  1489. return alertDecodeError
  1490. }
  1491. ocspLen := int(body[1])<<16 | int(body[2])<<8 | int(body[3])
  1492. if len(body) != 4+ocspLen {
  1493. return alertDecodeError
  1494. }
  1495. m.certificates[i].ocspStaple = body[4:]
  1496. case extensionSCT:
  1497. if len(body) < 2 {
  1498. return alertDecodeError
  1499. }
  1500. listLen := int(body[0])<<8 | int(body[1])
  1501. body = body[2:]
  1502. if len(body) != listLen {
  1503. return alertDecodeError
  1504. }
  1505. for len(body) > 0 {
  1506. if len(body) < 2 {
  1507. return alertDecodeError
  1508. }
  1509. sctLen := int(body[0])<<8 | int(body[1])
  1510. if len(body) < 2+sctLen {
  1511. return alertDecodeError
  1512. }
  1513. m.certificates[i].sctList = append(m.certificates[i].sctList, body[2:2+sctLen])
  1514. body = body[2+sctLen:]
  1515. }
  1516. }
  1517. }
  1518. }
  1519. return alertSuccess
  1520. }
  1521. type serverKeyExchangeMsg struct {
  1522. raw []byte
  1523. key []byte
  1524. }
  1525. func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
  1526. m1, ok := i.(*serverKeyExchangeMsg)
  1527. if !ok {
  1528. return false
  1529. }
  1530. return bytes.Equal(m.raw, m1.raw) &&
  1531. bytes.Equal(m.key, m1.key)
  1532. }
  1533. func (m *serverKeyExchangeMsg) marshal() []byte {
  1534. if m.raw != nil {
  1535. return m.raw
  1536. }
  1537. length := len(m.key)
  1538. x := make([]byte, length+4)
  1539. x[0] = typeServerKeyExchange
  1540. x[1] = uint8(length >> 16)
  1541. x[2] = uint8(length >> 8)
  1542. x[3] = uint8(length)
  1543. copy(x[4:], m.key)
  1544. m.raw = x
  1545. return x
  1546. }
  1547. func (m *serverKeyExchangeMsg) unmarshal(data []byte) alert {
  1548. m.raw = data
  1549. if len(data) < 4 {
  1550. return alertDecodeError
  1551. }
  1552. m.key = data[4:]
  1553. return alertSuccess
  1554. }
  1555. type certificateStatusMsg struct {
  1556. raw []byte
  1557. statusType uint8
  1558. response []byte
  1559. }
  1560. func (m *certificateStatusMsg) equal(i interface{}) bool {
  1561. m1, ok := i.(*certificateStatusMsg)
  1562. if !ok {
  1563. return false
  1564. }
  1565. return bytes.Equal(m.raw, m1.raw) &&
  1566. m.statusType == m1.statusType &&
  1567. bytes.Equal(m.response, m1.response)
  1568. }
  1569. func (m *certificateStatusMsg) marshal() []byte {
  1570. if m.raw != nil {
  1571. return m.raw
  1572. }
  1573. var x []byte
  1574. if m.statusType == statusTypeOCSP {
  1575. x = make([]byte, 4+4+len(m.response))
  1576. x[0] = typeCertificateStatus
  1577. l := len(m.response) + 4
  1578. x[1] = byte(l >> 16)
  1579. x[2] = byte(l >> 8)
  1580. x[3] = byte(l)
  1581. x[4] = statusTypeOCSP
  1582. l -= 4
  1583. x[5] = byte(l >> 16)
  1584. x[6] = byte(l >> 8)
  1585. x[7] = byte(l)
  1586. copy(x[8:], m.response)
  1587. } else {
  1588. x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
  1589. }
  1590. m.raw = x
  1591. return x
  1592. }
  1593. func (m *certificateStatusMsg) unmarshal(data []byte) alert {
  1594. m.raw = data
  1595. if len(data) < 5 {
  1596. return alertDecodeError
  1597. }
  1598. m.statusType = data[4]
  1599. m.response = nil
  1600. if m.statusType == statusTypeOCSP {
  1601. if len(data) < 8 {
  1602. return alertDecodeError
  1603. }
  1604. respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
  1605. if uint32(len(data)) != 4+4+respLen {
  1606. return alertDecodeError
  1607. }
  1608. m.response = data[8:]
  1609. }
  1610. return alertSuccess
  1611. }
  1612. type serverHelloDoneMsg struct{}
  1613. func (m *serverHelloDoneMsg) equal(i interface{}) bool {
  1614. _, ok := i.(*serverHelloDoneMsg)
  1615. return ok
  1616. }
  1617. func (m *serverHelloDoneMsg) marshal() []byte {
  1618. x := make([]byte, 4)
  1619. x[0] = typeServerHelloDone
  1620. return x
  1621. }
  1622. func (m *serverHelloDoneMsg) unmarshal(data []byte) alert {
  1623. if len(data) != 4 {
  1624. return alertDecodeError
  1625. }
  1626. return alertSuccess
  1627. }
  1628. type clientKeyExchangeMsg struct {
  1629. raw []byte
  1630. ciphertext []byte
  1631. }
  1632. func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
  1633. m1, ok := i.(*clientKeyExchangeMsg)
  1634. if !ok {
  1635. return false
  1636. }
  1637. return bytes.Equal(m.raw, m1.raw) &&
  1638. bytes.Equal(m.ciphertext, m1.ciphertext)
  1639. }
  1640. func (m *clientKeyExchangeMsg) marshal() []byte {
  1641. if m.raw != nil {
  1642. return m.raw
  1643. }
  1644. length := len(m.ciphertext)
  1645. x := make([]byte, length+4)
  1646. x[0] = typeClientKeyExchange
  1647. x[1] = uint8(length >> 16)
  1648. x[2] = uint8(length >> 8)
  1649. x[3] = uint8(length)
  1650. copy(x[4:], m.ciphertext)
  1651. m.raw = x
  1652. return x
  1653. }
  1654. func (m *clientKeyExchangeMsg) unmarshal(data []byte) alert {
  1655. m.raw = data
  1656. if len(data) < 4 {
  1657. return alertDecodeError
  1658. }
  1659. l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  1660. if l != len(data)-4 {
  1661. return alertDecodeError
  1662. }
  1663. m.ciphertext = data[4:]
  1664. return alertSuccess
  1665. }
  1666. type finishedMsg struct {
  1667. raw []byte
  1668. verifyData []byte
  1669. }
  1670. func (m *finishedMsg) equal(i interface{}) bool {
  1671. m1, ok := i.(*finishedMsg)
  1672. if !ok {
  1673. return false
  1674. }
  1675. return bytes.Equal(m.raw, m1.raw) &&
  1676. bytes.Equal(m.verifyData, m1.verifyData)
  1677. }
  1678. func (m *finishedMsg) marshal() (x []byte) {
  1679. if m.raw != nil {
  1680. return m.raw
  1681. }
  1682. x = make([]byte, 4+len(m.verifyData))
  1683. x[0] = typeFinished
  1684. x[3] = byte(len(m.verifyData))
  1685. copy(x[4:], m.verifyData)
  1686. m.raw = x
  1687. return
  1688. }
  1689. func (m *finishedMsg) unmarshal(data []byte) alert {
  1690. m.raw = data
  1691. if len(data) < 4 {
  1692. return alertDecodeError
  1693. }
  1694. m.verifyData = data[4:]
  1695. return alertSuccess
  1696. }
  1697. type nextProtoMsg struct {
  1698. raw []byte
  1699. proto string
  1700. }
  1701. func (m *nextProtoMsg) equal(i interface{}) bool {
  1702. m1, ok := i.(*nextProtoMsg)
  1703. if !ok {
  1704. return false
  1705. }
  1706. return bytes.Equal(m.raw, m1.raw) &&
  1707. m.proto == m1.proto
  1708. }
  1709. func (m *nextProtoMsg) marshal() []byte {
  1710. if m.raw != nil {
  1711. return m.raw
  1712. }
  1713. l := len(m.proto)
  1714. if l > 255 {
  1715. l = 255
  1716. }
  1717. padding := 32 - (l+2)%32
  1718. length := l + padding + 2
  1719. x := make([]byte, length+4)
  1720. x[0] = typeNextProtocol
  1721. x[1] = uint8(length >> 16)
  1722. x[2] = uint8(length >> 8)
  1723. x[3] = uint8(length)
  1724. y := x[4:]
  1725. y[0] = byte(l)
  1726. copy(y[1:], []byte(m.proto[0:l]))
  1727. y = y[1+l:]
  1728. y[0] = byte(padding)
  1729. m.raw = x
  1730. return x
  1731. }
  1732. func (m *nextProtoMsg) unmarshal(data []byte) alert {
  1733. m.raw = data
  1734. if len(data) < 5 {
  1735. return alertDecodeError
  1736. }
  1737. data = data[4:]
  1738. protoLen := int(data[0])
  1739. data = data[1:]
  1740. if len(data) < protoLen {
  1741. return alertDecodeError
  1742. }
  1743. m.proto = string(data[0:protoLen])
  1744. data = data[protoLen:]
  1745. if len(data) < 1 {
  1746. return alertDecodeError
  1747. }
  1748. paddingLen := int(data[0])
  1749. data = data[1:]
  1750. if len(data) != paddingLen {
  1751. return alertDecodeError
  1752. }
  1753. return alertSuccess
  1754. }
  1755. type certificateRequestMsg struct {
  1756. raw []byte
  1757. // hasSignatureAndHash indicates whether this message includes a list
  1758. // of signature and hash functions. This change was introduced with TLS
  1759. // 1.2.
  1760. hasSignatureAndHash bool
  1761. certificateTypes []byte
  1762. supportedSignatureAlgorithms []SignatureScheme
  1763. certificateAuthorities [][]byte
  1764. }
  1765. func (m *certificateRequestMsg) equal(i interface{}) bool {
  1766. m1, ok := i.(*certificateRequestMsg)
  1767. if !ok {
  1768. return false
  1769. }
  1770. return bytes.Equal(m.raw, m1.raw) &&
  1771. bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
  1772. eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
  1773. eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms)
  1774. }
  1775. func (m *certificateRequestMsg) marshal() (x []byte) {
  1776. if m.raw != nil {
  1777. return m.raw
  1778. }
  1779. // See http://tools.ietf.org/html/rfc4346#section-7.4.4
  1780. length := 1 + len(m.certificateTypes) + 2
  1781. casLength := 0
  1782. for _, ca := range m.certificateAuthorities {
  1783. casLength += 2 + len(ca)
  1784. }
  1785. length += casLength
  1786. if m.hasSignatureAndHash {
  1787. length += 2 + 2*len(m.supportedSignatureAlgorithms)
  1788. }
  1789. x = make([]byte, 4+length)
  1790. x[0] = typeCertificateRequest
  1791. x[1] = uint8(length >> 16)
  1792. x[2] = uint8(length >> 8)
  1793. x[3] = uint8(length)
  1794. x[4] = uint8(len(m.certificateTypes))
  1795. copy(x[5:], m.certificateTypes)
  1796. y := x[5+len(m.certificateTypes):]
  1797. if m.hasSignatureAndHash {
  1798. n := len(m.supportedSignatureAlgorithms) * 2
  1799. y[0] = uint8(n >> 8)
  1800. y[1] = uint8(n)
  1801. y = y[2:]
  1802. for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1803. y[0] = uint8(sigAlgo >> 8)
  1804. y[1] = uint8(sigAlgo)
  1805. y = y[2:]
  1806. }
  1807. }
  1808. y[0] = uint8(casLength >> 8)
  1809. y[1] = uint8(casLength)
  1810. y = y[2:]
  1811. for _, ca := range m.certificateAuthorities {
  1812. y[0] = uint8(len(ca) >> 8)
  1813. y[1] = uint8(len(ca))
  1814. y = y[2:]
  1815. copy(y, ca)
  1816. y = y[len(ca):]
  1817. }
  1818. m.raw = x
  1819. return
  1820. }
  1821. func (m *certificateRequestMsg) unmarshal(data []byte) alert {
  1822. m.raw = data
  1823. if len(data) < 5 {
  1824. return alertDecodeError
  1825. }
  1826. length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1827. if uint32(len(data))-4 != length {
  1828. return alertDecodeError
  1829. }
  1830. numCertTypes := int(data[4])
  1831. data = data[5:]
  1832. if numCertTypes == 0 || len(data) <= numCertTypes {
  1833. return alertDecodeError
  1834. }
  1835. m.certificateTypes = make([]byte, numCertTypes)
  1836. if copy(m.certificateTypes, data) != numCertTypes {
  1837. return alertDecodeError
  1838. }
  1839. data = data[numCertTypes:]
  1840. if m.hasSignatureAndHash {
  1841. if len(data) < 2 {
  1842. return alertDecodeError
  1843. }
  1844. sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
  1845. data = data[2:]
  1846. if sigAndHashLen&1 != 0 {
  1847. return alertDecodeError
  1848. }
  1849. if len(data) < int(sigAndHashLen) {
  1850. return alertDecodeError
  1851. }
  1852. numSigAlgos := sigAndHashLen / 2
  1853. m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
  1854. for i := range m.supportedSignatureAlgorithms {
  1855. m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
  1856. data = data[2:]
  1857. }
  1858. }
  1859. if len(data) < 2 {
  1860. return alertDecodeError
  1861. }
  1862. casLength := uint16(data[0])<<8 | uint16(data[1])
  1863. data = data[2:]
  1864. if len(data) < int(casLength) {
  1865. return alertDecodeError
  1866. }
  1867. cas := make([]byte, casLength)
  1868. copy(cas, data)
  1869. data = data[casLength:]
  1870. m.certificateAuthorities = nil
  1871. for len(cas) > 0 {
  1872. if len(cas) < 2 {
  1873. return alertDecodeError
  1874. }
  1875. caLen := uint16(cas[0])<<8 | uint16(cas[1])
  1876. cas = cas[2:]
  1877. if len(cas) < int(caLen) {
  1878. return alertDecodeError
  1879. }
  1880. m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
  1881. cas = cas[caLen:]
  1882. }
  1883. if len(data) != 0 {
  1884. return alertDecodeError
  1885. }
  1886. return alertSuccess
  1887. }
  1888. type certificateRequestMsg13 struct {
  1889. raw []byte
  1890. requestContext []byte
  1891. supportedSignatureAlgorithms []SignatureScheme
  1892. certificateAuthorities [][]byte
  1893. }
  1894. func (m *certificateRequestMsg13) equal(i interface{}) bool {
  1895. m1, ok := i.(*certificateRequestMsg13)
  1896. return ok &&
  1897. bytes.Equal(m.raw, m1.raw) &&
  1898. bytes.Equal(m.requestContext, m1.requestContext) &&
  1899. eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
  1900. eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms)
  1901. }
  1902. func (m *certificateRequestMsg13) marshal() (x []byte) {
  1903. if m.raw != nil {
  1904. return m.raw
  1905. }
  1906. // See https://tools.ietf.org/html/draft-ietf-tls-tls13-21#section-4.3.2
  1907. length := 1 + len(m.requestContext)
  1908. numExtensions := 1
  1909. extensionsLength := 2 + 2*len(m.supportedSignatureAlgorithms)
  1910. casLength := 0
  1911. if len(m.certificateAuthorities) > 0 {
  1912. for _, ca := range m.certificateAuthorities {
  1913. casLength += 2 + len(ca)
  1914. }
  1915. extensionsLength += 2 + casLength
  1916. numExtensions++
  1917. }
  1918. extensionsLength += 4 * numExtensions
  1919. length += 2 + extensionsLength
  1920. x = make([]byte, 4+length)
  1921. x[0] = typeCertificateRequest
  1922. x[1] = uint8(length >> 16)
  1923. x[2] = uint8(length >> 8)
  1924. x[3] = uint8(length)
  1925. x[4] = uint8(len(m.requestContext))
  1926. copy(x[5:], m.requestContext)
  1927. z := x[5+len(m.requestContext):]
  1928. z[0] = byte(extensionsLength >> 8)
  1929. z[1] = byte(extensionsLength)
  1930. z = z[2:]
  1931. // TODO: this function should be reused by CH
  1932. z = marshallExtensionSignatureAlgorithms(z, m.supportedSignatureAlgorithms)
  1933. // certificate_authorities
  1934. if casLength > 0 {
  1935. z[0] = byte(extensionCAs >> 8)
  1936. z[1] = byte(extensionCAs)
  1937. l := 2 + casLength
  1938. z[2] = byte(l >> 8)
  1939. z[3] = byte(l)
  1940. z = z[4:]
  1941. z[0] = uint8(casLength >> 8)
  1942. z[1] = uint8(casLength)
  1943. z = z[2:]
  1944. for _, ca := range m.certificateAuthorities {
  1945. z[0] = uint8(len(ca) >> 8)
  1946. z[1] = uint8(len(ca))
  1947. z = z[2:]
  1948. copy(z, ca)
  1949. z = z[len(ca):]
  1950. }
  1951. }
  1952. m.raw = x
  1953. return
  1954. }
  1955. func (m *certificateRequestMsg13) unmarshal(data []byte) alert {
  1956. m.raw = data
  1957. m.supportedSignatureAlgorithms = nil
  1958. m.certificateAuthorities = nil
  1959. if len(data) < 5 {
  1960. return alertDecodeError
  1961. }
  1962. length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1963. if uint32(len(data))-4 != length {
  1964. return alertDecodeError
  1965. }
  1966. ctxLen := data[4]
  1967. if len(data) < 5+int(ctxLen)+2 {
  1968. return alertDecodeError
  1969. }
  1970. m.requestContext = data[5 : 5+ctxLen]
  1971. data = data[5+ctxLen:]
  1972. extensionsLength := int(data[0])<<8 | int(data[1])
  1973. data = data[2:]
  1974. if len(data) != extensionsLength {
  1975. return alertDecodeError
  1976. }
  1977. for len(data) != 0 {
  1978. if len(data) < 4 {
  1979. return alertDecodeError
  1980. }
  1981. extension := uint16(data[0])<<8 | uint16(data[1])
  1982. length := int(data[2])<<8 | int(data[3])
  1983. data = data[4:]
  1984. if len(data) < length {
  1985. return alertDecodeError
  1986. }
  1987. switch extension {
  1988. case extensionSignatureAlgorithms:
  1989. // TODO: unmarshallExtensionSignatureAlgorithms should be shared with CH and pre-1.3 CV
  1990. // https://tools.ietf.org/html/draft-ietf-tls-tls13-21#section-4.2.3
  1991. var err alert
  1992. m.supportedSignatureAlgorithms, err = unmarshallExtensionSignatureAlgorithms(data, length)
  1993. if err != alertSuccess {
  1994. return err
  1995. }
  1996. case extensionCAs:
  1997. // TODO DRY: share code with CH
  1998. if length < 2 {
  1999. return alertDecodeError
  2000. }
  2001. l := int(data[0])<<8 | int(data[1])
  2002. if l != length-2 || l < 3 {
  2003. return alertDecodeError
  2004. }
  2005. cas := make([]byte, l)
  2006. copy(cas, data[2:])
  2007. m.certificateAuthorities = nil
  2008. for len(cas) > 0 {
  2009. if len(cas) < 2 {
  2010. return alertDecodeError
  2011. }
  2012. caLen := uint16(cas[0])<<8 | uint16(cas[1])
  2013. cas = cas[2:]
  2014. if len(cas) < int(caLen) {
  2015. return alertDecodeError
  2016. }
  2017. m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
  2018. cas = cas[caLen:]
  2019. }
  2020. }
  2021. data = data[length:]
  2022. }
  2023. if len(m.supportedSignatureAlgorithms) == 0 {
  2024. return alertDecodeError
  2025. }
  2026. return alertSuccess
  2027. }
  2028. type certificateVerifyMsg struct {
  2029. raw []byte
  2030. hasSignatureAndHash bool
  2031. signatureAlgorithm SignatureScheme
  2032. signature []byte
  2033. }
  2034. func (m *certificateVerifyMsg) equal(i interface{}) bool {
  2035. m1, ok := i.(*certificateVerifyMsg)
  2036. if !ok {
  2037. return false
  2038. }
  2039. return bytes.Equal(m.raw, m1.raw) &&
  2040. m.hasSignatureAndHash == m1.hasSignatureAndHash &&
  2041. m.signatureAlgorithm == m1.signatureAlgorithm &&
  2042. bytes.Equal(m.signature, m1.signature)
  2043. }
  2044. func (m *certificateVerifyMsg) marshal() (x []byte) {
  2045. if m.raw != nil {
  2046. return m.raw
  2047. }
  2048. // See http://tools.ietf.org/html/rfc4346#section-7.4.8
  2049. siglength := len(m.signature)
  2050. length := 2 + siglength
  2051. if m.hasSignatureAndHash {
  2052. length += 2
  2053. }
  2054. x = make([]byte, 4+length)
  2055. x[0] = typeCertificateVerify
  2056. x[1] = uint8(length >> 16)
  2057. x[2] = uint8(length >> 8)
  2058. x[3] = uint8(length)
  2059. y := x[4:]
  2060. if m.hasSignatureAndHash {
  2061. y[0] = uint8(m.signatureAlgorithm >> 8)
  2062. y[1] = uint8(m.signatureAlgorithm)
  2063. y = y[2:]
  2064. }
  2065. y[0] = uint8(siglength >> 8)
  2066. y[1] = uint8(siglength)
  2067. copy(y[2:], m.signature)
  2068. m.raw = x
  2069. return
  2070. }
  2071. func (m *certificateVerifyMsg) unmarshal(data []byte) alert {
  2072. m.raw = data
  2073. if len(data) < 6 {
  2074. return alertDecodeError
  2075. }
  2076. length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  2077. if uint32(len(data))-4 != length {
  2078. return alertDecodeError
  2079. }
  2080. data = data[4:]
  2081. if m.hasSignatureAndHash {
  2082. m.signatureAlgorithm = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
  2083. data = data[2:]
  2084. }
  2085. if len(data) < 2 {
  2086. return alertDecodeError
  2087. }
  2088. siglength := int(data[0])<<8 + int(data[1])
  2089. data = data[2:]
  2090. if len(data) != siglength {
  2091. return alertDecodeError
  2092. }
  2093. m.signature = data
  2094. return alertSuccess
  2095. }
  2096. type newSessionTicketMsg struct {
  2097. raw []byte
  2098. ticket []byte
  2099. }
  2100. func (m *newSessionTicketMsg) equal(i interface{}) bool {
  2101. m1, ok := i.(*newSessionTicketMsg)
  2102. if !ok {
  2103. return false
  2104. }
  2105. return bytes.Equal(m.raw, m1.raw) &&
  2106. bytes.Equal(m.ticket, m1.ticket)
  2107. }
  2108. func (m *newSessionTicketMsg) marshal() (x []byte) {
  2109. if m.raw != nil {
  2110. return m.raw
  2111. }
  2112. // See http://tools.ietf.org/html/rfc5077#section-3.3
  2113. ticketLen := len(m.ticket)
  2114. length := 2 + 4 + ticketLen
  2115. x = make([]byte, 4+length)
  2116. x[0] = typeNewSessionTicket
  2117. x[1] = uint8(length >> 16)
  2118. x[2] = uint8(length >> 8)
  2119. x[3] = uint8(length)
  2120. x[8] = uint8(ticketLen >> 8)
  2121. x[9] = uint8(ticketLen)
  2122. copy(x[10:], m.ticket)
  2123. m.raw = x
  2124. return
  2125. }
  2126. func (m *newSessionTicketMsg) unmarshal(data []byte) alert {
  2127. m.raw = data
  2128. if len(data) < 10 {
  2129. return alertDecodeError
  2130. }
  2131. length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  2132. if uint32(len(data))-4 != length {
  2133. return alertDecodeError
  2134. }
  2135. ticketLen := int(data[8])<<8 + int(data[9])
  2136. if len(data)-10 != ticketLen {
  2137. return alertDecodeError
  2138. }
  2139. m.ticket = data[10:]
  2140. return alertSuccess
  2141. }
  2142. type newSessionTicketMsg13 struct {
  2143. raw []byte
  2144. lifetime uint32
  2145. ageAdd uint32
  2146. nonce []byte
  2147. ticket []byte
  2148. withEarlyDataInfo bool
  2149. maxEarlyDataLength uint32
  2150. }
  2151. func (m *newSessionTicketMsg13) equal(i interface{}) bool {
  2152. m1, ok := i.(*newSessionTicketMsg13)
  2153. if !ok {
  2154. return false
  2155. }
  2156. return bytes.Equal(m.raw, m1.raw) &&
  2157. m.lifetime == m1.lifetime &&
  2158. m.ageAdd == m1.ageAdd &&
  2159. bytes.Equal(m.nonce, m1.nonce) &&
  2160. bytes.Equal(m.ticket, m1.ticket) &&
  2161. m.withEarlyDataInfo == m1.withEarlyDataInfo &&
  2162. m.maxEarlyDataLength == m1.maxEarlyDataLength
  2163. }
  2164. func (m *newSessionTicketMsg13) marshal() (x []byte) {
  2165. if m.raw != nil {
  2166. return m.raw
  2167. }
  2168. // See https://tools.ietf.org/html/draft-ietf-tls-tls13-21#section-4.6.1
  2169. nonceLen := len(m.nonce)
  2170. ticketLen := len(m.ticket)
  2171. length := 13 + nonceLen + ticketLen
  2172. if m.withEarlyDataInfo {
  2173. length += 8
  2174. }
  2175. x = make([]byte, 4+length)
  2176. x[0] = typeNewSessionTicket
  2177. x[1] = uint8(length >> 16)
  2178. x[2] = uint8(length >> 8)
  2179. x[3] = uint8(length)
  2180. x[4] = uint8(m.lifetime >> 24)
  2181. x[5] = uint8(m.lifetime >> 16)
  2182. x[6] = uint8(m.lifetime >> 8)
  2183. x[7] = uint8(m.lifetime)
  2184. x[8] = uint8(m.ageAdd >> 24)
  2185. x[9] = uint8(m.ageAdd >> 16)
  2186. x[10] = uint8(m.ageAdd >> 8)
  2187. x[11] = uint8(m.ageAdd)
  2188. x[12] = uint8(nonceLen)
  2189. copy(x[13:13+nonceLen], m.nonce)
  2190. y := x[13+nonceLen:]
  2191. y[0] = uint8(ticketLen >> 8)
  2192. y[1] = uint8(ticketLen)
  2193. copy(y[2:2+ticketLen], m.ticket)
  2194. if m.withEarlyDataInfo {
  2195. z := y[2+ticketLen:]
  2196. // z[0] is already 0, this is the extensions vector length.
  2197. z[1] = 8
  2198. z[2] = uint8(extensionEarlyData >> 8)
  2199. z[3] = uint8(extensionEarlyData)
  2200. z[5] = 4
  2201. z[6] = uint8(m.maxEarlyDataLength >> 24)
  2202. z[7] = uint8(m.maxEarlyDataLength >> 16)
  2203. z[8] = uint8(m.maxEarlyDataLength >> 8)
  2204. z[9] = uint8(m.maxEarlyDataLength)
  2205. }
  2206. m.raw = x
  2207. return
  2208. }
  2209. func (m *newSessionTicketMsg13) unmarshal(data []byte) alert {
  2210. m.raw = data
  2211. m.maxEarlyDataLength = 0
  2212. m.withEarlyDataInfo = false
  2213. if len(data) < 17 {
  2214. return alertDecodeError
  2215. }
  2216. length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  2217. if uint32(len(data))-4 != length {
  2218. return alertDecodeError
  2219. }
  2220. m.lifetime = uint32(data[4])<<24 | uint32(data[5])<<16 |
  2221. uint32(data[6])<<8 | uint32(data[7])
  2222. m.ageAdd = uint32(data[8])<<24 | uint32(data[9])<<16 |
  2223. uint32(data[10])<<8 | uint32(data[11])
  2224. nonceLen := int(data[12])
  2225. if nonceLen == 0 || 13+nonceLen+2 > len(data) {
  2226. return alertDecodeError
  2227. }
  2228. m.nonce = data[13 : 13+nonceLen]
  2229. data = data[13+nonceLen:]
  2230. ticketLen := int(data[0])<<8 + int(data[1])
  2231. if ticketLen == 0 || 2+ticketLen+2 > len(data) {
  2232. return alertDecodeError
  2233. }
  2234. m.ticket = data[2 : 2+ticketLen]
  2235. data = data[2+ticketLen:]
  2236. extLen := int(data[0])<<8 + int(data[1])
  2237. if extLen != len(data)-2 {
  2238. return alertDecodeError
  2239. }
  2240. data = data[2:]
  2241. for len(data) > 0 {
  2242. if len(data) < 4 {
  2243. return alertDecodeError
  2244. }
  2245. extType := uint16(data[0])<<8 + uint16(data[1])
  2246. length := int(data[2])<<8 + int(data[3])
  2247. data = data[4:]
  2248. switch extType {
  2249. case extensionEarlyData:
  2250. if length != 4 {
  2251. return alertDecodeError
  2252. }
  2253. m.withEarlyDataInfo = true
  2254. m.maxEarlyDataLength = uint32(data[0])<<24 | uint32(data[1])<<16 |
  2255. uint32(data[2])<<8 | uint32(data[3])
  2256. }
  2257. data = data[length:]
  2258. }
  2259. return alertSuccess
  2260. }
  2261. type endOfEarlyDataMsg struct {
  2262. }
  2263. func (*endOfEarlyDataMsg) marshal() []byte {
  2264. return []byte{typeEndOfEarlyData, 0, 0, 0}
  2265. }
  2266. func (*endOfEarlyDataMsg) unmarshal(data []byte) alert {
  2267. if len(data) != 4 {
  2268. return alertDecodeError
  2269. }
  2270. return alertSuccess
  2271. }
  2272. type helloRequestMsg struct {
  2273. }
  2274. func (*helloRequestMsg) marshal() []byte {
  2275. return []byte{typeHelloRequest, 0, 0, 0}
  2276. }
  2277. func (*helloRequestMsg) unmarshal(data []byte) alert {
  2278. if len(data) != 4 {
  2279. return alertDecodeError
  2280. }
  2281. return alertSuccess
  2282. }
  // eqUint16s reports whether x and y have the same length and the same value
  // at every index. A nil slice and an empty slice compare equal.
  func eqUint16s(x, y []uint16) bool {
  	n := len(x)
  	if n != len(y) {
  		return false
  	}
  	for k := 0; k < n; k++ {
  		if x[k] != y[k] {
  			return false
  		}
  	}
  	return true
  }
  2294. func eqCurveIDs(x, y []CurveID) bool {
  2295. if len(x) != len(y) {
  2296. return false
  2297. }
  2298. for i, v := range x {
  2299. if y[i] != v {
  2300. return false
  2301. }
  2302. }
  2303. return true
  2304. }
  // eqStrings reports whether x and y contain identical strings in the same
  // order. A nil slice and an empty slice compare equal.
  func eqStrings(x, y []string) bool {
  	same := len(x) == len(y)
  	for i := 0; same && i < len(x); i++ {
  		same = x[i] == y[i]
  	}
  	return same
  }
  // eqByteSlices reports whether x and y have the same number of elements and
  // each pair of elements is bytes.Equal (so nil and empty elements match).
  func eqByteSlices(x, y [][]byte) bool {
  	if len(x) != len(y) {
  		return false
  	}
  	for i := range x {
  		if !bytes.Equal(x[i], y[i]) {
  			return false
  		}
  	}
  	return true
  }
  2327. func eqSignatureAlgorithms(x, y []SignatureScheme) bool {
  2328. if len(x) != len(y) {
  2329. return false
  2330. }
  2331. for i, v := range x {
  2332. if v != y[i] {
  2333. return false
  2334. }
  2335. }
  2336. return true
  2337. }
  2338. func eqKeyShares(x, y []keyShare) bool {
  2339. if len(x) != len(y) {
  2340. return false
  2341. }
  2342. for i := range x {
  2343. if x[i].group != y[i].group {
  2344. return false
  2345. }
  2346. if !bytes.Equal(x[i].data, y[i].data) {
  2347. return false
  2348. }
  2349. }
  2350. return true
  2351. }
  // findExtension scans data as a sequence of extensions. Each extension is a
  // 2-byte type, a 2-byte length, then that many body bytes. It returns the
  // body of the first extension whose type equals extensionType; the result
  // aliases data. It returns nil when no extension matches, or when the list
  // is truncated before a match is reached.
  func findExtension(data []byte, extensionType uint16) []byte {
  	rest := data
  	for len(rest) > 0 {
  		if len(rest) < 4 {
  			return nil
  		}
  		typ := uint16(rest[0])<<8 | uint16(rest[1])
  		n := int(rest[2])<<8 | int(rest[3])
  		body := rest[4:]
  		if n > len(body) {
  			return nil
  		}
  		if typ == extensionType {
  			return body[:n]
  		}
  		rest = body[n:]
  	}
  	return nil
  }