You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

2435 lines
53 KiB

  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package tls
  5. import (
  6. "bytes"
  7. "strings"
  8. )
  // clientHelloMsg holds the contents of a ClientHello handshake message,
// including the extensions handled by its marshal and unmarshal methods.
// Several slice fields set by unmarshal (for example random, sessionId and
// sessionTicket) are sub-slices of the buffer that was parsed.
type clientHelloMsg struct {
	// raw is the complete encoded message including the 4-byte handshake
	// header. marshal returns it unchanged when non-nil and caches its
	// output here; unmarshal stores its input here.
	raw []byte
	// rawTruncated is raw cut just before the PSK binders list (before the
	// binders length field). unmarshal sets it only when a pre_shared_key
	// extension is parsed.
	rawTruncated       []byte // for PSK binding
	vers               uint16 // version field, message bytes 4-5
	random             []byte // copied into the 32-byte random field
	sessionId          []byte // unmarshal rejects IDs longer than 32 bytes
	cipherSuites       []uint16
	compressionMethods []uint8
	nextProtoNeg       bool   // NPN extension present (always empty payload here)
	serverName         string // SNI host name; unmarshal rejects a trailing dot
	ocspStapling       bool   // status_request extension with the OCSP status type
	scts               bool   // empty signed_certificate_timestamp extension present
	supportedCurves    []CurveID
	supportedPoints    []uint8
	ticketSupported    bool    // session_ticket extension present
	sessionTicket      []uint8 // session_ticket payload (may be empty)
	supportedSignatureAlgorithms []SignatureScheme
	secureRenegotiation          []byte // renegotiation_info payload
	// secureRenegotiationSupported is set when renegotiation_info is present,
	// or (in unmarshal) when the scsvRenegotiation cipher suite is offered.
	secureRenegotiationSupported bool
	alpnProtocols                []string // each name must be 1..255 bytes; marshal panics otherwise
	// TLS 1.3 (draft) extensions.
	keyShares         []keyShare
	supportedVersions []uint16
	// psks and pskKeyExchangeModes are filled in by unmarshal only; marshal
	// does not encode them.
	psks                []psk
	pskKeyExchangeModes []uint8
	earlyData           bool // early_data extension present
}
  35. func (m *clientHelloMsg) equal(i interface{}) bool {
  36. m1, ok := i.(*clientHelloMsg)
  37. if !ok {
  38. return false
  39. }
  40. return bytes.Equal(m.raw, m1.raw) &&
  41. m.vers == m1.vers &&
  42. bytes.Equal(m.random, m1.random) &&
  43. bytes.Equal(m.sessionId, m1.sessionId) &&
  44. eqUint16s(m.cipherSuites, m1.cipherSuites) &&
  45. bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
  46. m.nextProtoNeg == m1.nextProtoNeg &&
  47. m.serverName == m1.serverName &&
  48. m.ocspStapling == m1.ocspStapling &&
  49. m.scts == m1.scts &&
  50. eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
  51. bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
  52. m.ticketSupported == m1.ticketSupported &&
  53. bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
  54. eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) &&
  55. m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
  56. bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
  57. eqStrings(m.alpnProtocols, m1.alpnProtocols) &&
  58. eqKeyShares(m.keyShares, m1.keyShares) &&
  59. eqUint16s(m.supportedVersions, m1.supportedVersions) &&
  60. m.earlyData == m1.earlyData
  61. }
  // marshal serializes m as a ClientHello handshake message: a 4-byte header
// (message type plus 24-bit body length) followed by the body. The encoding
// is cached in m.raw, and an already cached (or unmarshaled) encoding is
// returned unchanged.
//
// The body size is computed in a first pass from the same fields that the
// second pass writes, so the two passes must stay in sync: every byte
// counted in extensionsLength must be written, and vice versa. psks and
// pskKeyExchangeModes are not encoded.
//
// It panics if any ALPN protocol name is empty or longer than 255 bytes.
//
// NOTE(review): single-byte length fields (session ID, compression methods,
// supported points, renegotiation info, supported versions) are not
// range-checked here; oversized values would be truncated by the byte
// conversions.
func (m *clientHelloMsg) marshal() []byte {
	if m.raw != nil {
		return m.raw
	}

	// Body without extensions: version (2), random (32), session ID length
	// (1) plus ID, cipher suites length (2) plus 2 bytes per suite,
	// compression methods count (1) plus 1 byte per method.
	length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
	// First pass: tally extension payload sizes. The 4-byte type+length
	// header of every extension is added at the end as 4*numExtensions.
	numExtensions := 0
	extensionsLength := 0
	if m.nextProtoNeg {
		// Empty payload.
		numExtensions++
	}
	if m.ocspStapling {
		// Status type (1) plus two empty uint16-length vectors (2 + 2).
		extensionsLength += 1 + 2 + 2
		numExtensions++
	}
	if len(m.serverName) > 0 {
		// List length (2), name type (1), name length (2), name.
		extensionsLength += 5 + len(m.serverName)
		numExtensions++
	}
	if len(m.supportedCurves) > 0 {
		extensionsLength += 2 + 2*len(m.supportedCurves)
		numExtensions++
	}
	if len(m.supportedPoints) > 0 {
		extensionsLength += 1 + len(m.supportedPoints)
		numExtensions++
	}
	if m.ticketSupported {
		extensionsLength += len(m.sessionTicket)
		numExtensions++
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		extensionsLength += 2 + 2*len(m.supportedSignatureAlgorithms)
		numExtensions++
	}
	if m.secureRenegotiationSupported {
		extensionsLength += 1 + len(m.secureRenegotiation)
		numExtensions++
	}
	if len(m.alpnProtocols) > 0 {
		extensionsLength += 2
		for _, s := range m.alpnProtocols {
			// Each name gets a one-byte length prefix below, so it must
			// be 1..255 bytes long.
			if l := len(s); l == 0 || l > 255 {
				panic("invalid ALPN protocol")
			}
			extensionsLength++
			extensionsLength += len(s)
		}
		numExtensions++
	}
	if m.scts {
		// Empty payload.
		numExtensions++
	}
	if len(m.keyShares) > 0 {
		extensionsLength += 2
		for _, k := range m.keyShares {
			// Group (2), key exchange length (2), key exchange data.
			extensionsLength += 4 + len(k.data)
		}
		numExtensions++
	}
	if len(m.supportedVersions) > 0 {
		extensionsLength += 1 + 2*len(m.supportedVersions)
		numExtensions++
	}
	if m.earlyData {
		// Empty payload.
		numExtensions++
	}
	if numExtensions > 0 {
		extensionsLength += 4 * numExtensions
		length += 2 + extensionsLength
	}

	// Second pass: write the header and body. Bytes not assigned explicitly
	// keep the zero value from make.
	x := make([]byte, 4+length)
	x[0] = typeClientHello
	x[1] = uint8(length >> 16)
	x[2] = uint8(length >> 8)
	x[3] = uint8(length)
	x[4] = uint8(m.vers >> 8)
	x[5] = uint8(m.vers)
	copy(x[6:38], m.random)
	x[38] = uint8(len(m.sessionId))
	copy(x[39:39+len(m.sessionId)], m.sessionId)
	y := x[39+len(m.sessionId):]
	// Cipher suites length in bytes (2 per suite), big-endian:
	// (n*2)>>8 == n>>7 and byte(n*2) == byte(n<<1).
	y[0] = uint8(len(m.cipherSuites) >> 7)
	y[1] = uint8(len(m.cipherSuites) << 1)
	for i, suite := range m.cipherSuites {
		y[2+i*2] = uint8(suite >> 8)
		y[3+i*2] = uint8(suite)
	}
	z := y[2+len(m.cipherSuites)*2:]
	z[0] = uint8(len(m.compressionMethods))
	copy(z[1:], m.compressionMethods)

	z = z[1+len(m.compressionMethods):]
	if numExtensions > 0 {
		z[0] = byte(extensionsLength >> 8)
		z[1] = byte(extensionsLength)
		z = z[2:]
	}
	// Extensions are written in a fixed order, each starting with its 2-byte
	// type and 2-byte payload length.
	if m.nextProtoNeg {
		z[0] = byte(extensionNextProtoNeg >> 8)
		z[1] = byte(extensionNextProtoNeg & 0xff)
		// The length is always 0
		z = z[4:]
	}
	if len(m.serverName) > 0 {
		z[0] = byte(extensionServerName >> 8)
		z[1] = byte(extensionServerName & 0xff)
		l := len(m.serverName) + 5
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		// RFC 3546, section 3.1
		//
		// struct {
		//     NameType name_type;
		//     select (name_type) {
		//         case host_name: HostName;
		//     } name;
		// } ServerName;
		//
		// enum {
		//     host_name(0), (255)
		// } NameType;
		//
		// opaque HostName<1..2^16-1>;
		//
		// struct {
		//     ServerName server_name_list<1..2^16-1>
		// } ServerNameList;
		//
		// Written below: list length (2), name type at z[2] left as 0
		// (host_name), name length (2), name.
		z[0] = byte((len(m.serverName) + 3) >> 8)
		z[1] = byte(len(m.serverName) + 3)
		z[3] = byte(len(m.serverName) >> 8)
		z[4] = byte(len(m.serverName))
		copy(z[5:], []byte(m.serverName))
		z = z[l:]
	}
	if m.ocspStapling {
		// RFC 4366, section 3.6
		z[0] = byte(extensionStatusRequest >> 8)
		z[1] = byte(extensionStatusRequest)
		z[2] = 0
		z[3] = 5
		z[4] = 1 // OCSP type
		// Two zero valued uint16s for the two lengths.
		z = z[9:]
	}
	if len(m.supportedCurves) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.1
		// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.4
		z[0] = byte(extensionSupportedCurves >> 8)
		z[1] = byte(extensionSupportedCurves)
		l := 2 + 2*len(m.supportedCurves)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l -= 2
		z[4] = byte(l >> 8)
		z[5] = byte(l)
		z = z[6:]
		for _, curve := range m.supportedCurves {
			z[0] = byte(curve >> 8)
			z[1] = byte(curve)
			z = z[2:]
		}
	}
	if len(m.supportedPoints) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.2
		z[0] = byte(extensionSupportedPoints >> 8)
		z[1] = byte(extensionSupportedPoints)
		l := 1 + len(m.supportedPoints)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l--
		z[4] = byte(l)
		z = z[5:]
		for _, pointFormat := range m.supportedPoints {
			z[0] = pointFormat
			z = z[1:]
		}
	}
	if m.ticketSupported {
		// http://tools.ietf.org/html/rfc5077#section-3.2
		z[0] = byte(extensionSessionTicket >> 8)
		z[1] = byte(extensionSessionTicket)
		l := len(m.sessionTicket)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]
		copy(z, m.sessionTicket)
		z = z[len(m.sessionTicket):]
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
		// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.3
		z[0] = byte(extensionSignatureAlgorithms >> 8)
		z[1] = byte(extensionSignatureAlgorithms)
		l := 2 + 2*len(m.supportedSignatureAlgorithms)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]
		l -= 2
		z[0] = byte(l >> 8)
		z[1] = byte(l)
		z = z[2:]
		for _, sigAlgo := range m.supportedSignatureAlgorithms {
			z[0] = byte(sigAlgo >> 8)
			z[1] = byte(sigAlgo)
			z = z[2:]
		}
	}
	if m.secureRenegotiationSupported {
		// Payload: one length byte followed by secureRenegotiation.
		z[0] = byte(extensionRenegotiationInfo >> 8)
		z[1] = byte(extensionRenegotiationInfo & 0xff)
		z[2] = 0
		z[3] = byte(len(m.secureRenegotiation) + 1)
		z[4] = byte(len(m.secureRenegotiation))
		z = z[5:]
		copy(z, m.secureRenegotiation)
		z = z[len(m.secureRenegotiation):]
	}
	if len(m.alpnProtocols) > 0 {
		z[0] = byte(extensionALPN >> 8)
		z[1] = byte(extensionALPN & 0xff)
		// lengths aliases the extension length (lengths[0:2]) and the
		// protocol list length (lengths[2:4]); both are filled in once the
		// names have been written.
		lengths := z[2:]
		z = z[6:]

		stringsLength := 0
		for _, s := range m.alpnProtocols {
			l := len(s)
			z[0] = byte(l)
			copy(z[1:], s)
			z = z[1+l:]
			stringsLength += 1 + l
		}

		lengths[2] = byte(stringsLength >> 8)
		lengths[3] = byte(stringsLength)
		stringsLength += 2
		lengths[0] = byte(stringsLength >> 8)
		lengths[1] = byte(stringsLength)
	}
	if m.scts {
		// https://tools.ietf.org/html/rfc6962#section-3.3.1
		z[0] = byte(extensionSCT >> 8)
		z[1] = byte(extensionSCT)
		// zero uint16 for the zero-length extension_data
		z = z[4:]
	}
	if len(m.keyShares) > 0 {
		// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.5
		z[0] = byte(extensionKeyShare >> 8)
		z[1] = byte(extensionKeyShare)
		// Same back-patching scheme as ALPN: extension length at
		// lengths[0:2], client_shares length at lengths[2:4].
		lengths := z[2:]
		z = z[6:]

		totalLength := 0
		for _, ks := range m.keyShares {
			z[0] = byte(ks.group >> 8)
			z[1] = byte(ks.group)
			z[2] = byte(len(ks.data) >> 8)
			z[3] = byte(len(ks.data))
			copy(z[4:], ks.data)
			z = z[4+len(ks.data):]
			totalLength += 4 + len(ks.data)
		}

		lengths[2] = byte(totalLength >> 8)
		lengths[3] = byte(totalLength)
		totalLength += 2
		lengths[0] = byte(totalLength >> 8)
		lengths[1] = byte(totalLength)
	}
	if len(m.supportedVersions) > 0 {
		// Payload: one length byte, then 2 bytes per version.
		z[0] = byte(extensionSupportedVersions >> 8)
		z[1] = byte(extensionSupportedVersions)
		l := 1 + 2*len(m.supportedVersions)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l -= 1
		z[4] = byte(l)
		z = z[5:]
		for _, v := range m.supportedVersions {
			z[0] = byte(v >> 8)
			z[1] = byte(v)
			z = z[2:]
		}
	}
	if m.earlyData {
		// Empty payload.
		z[0] = byte(extensionEarlyData >> 8)
		z[1] = byte(extensionEarlyData)
		z = z[4:]
	}

	m.raw = x

	return x
}
  // unmarshal parses a complete ClientHello handshake message (including its
// 4-byte header) into m. It returns alertSuccess, or the alert to send when
// the message is malformed. data is retained: m.raw is data itself, and
// random, sessionId, compressionMethods, sessionTicket, secureRenegotiation,
// key share data, PSK identities and binders and pskKeyExchangeModes are
// sub-slices of it.
//
// Extension types not handled by the switch below are skipped. A
// pre_shared_key extension must be the last extension. When one is present,
// m.rawTruncated is set to the prefix of data that ends just before the PSK
// binders list.
//
// NOTE(review): supportedCurves, supportedPoints, secureRenegotiation and
// secureRenegotiationSupported are not reset before parsing. Values from an
// earlier use of m therefore survive when the matching extension is absent.
func (m *clientHelloMsg) unmarshal(data []byte) alert {
	// Minimum size: header (4), version (2), random (32), session ID length
	// (1), cipher suites length (2), compression methods length (1).
	if len(data) < 42 {
		return alertDecodeError
	}
	m.raw = data
	m.vers = uint16(data[4])<<8 | uint16(data[5])
	m.random = data[6:38]
	sessionIdLen := int(data[38])
	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
		return alertDecodeError
	}
	m.sessionId = data[39 : 39+sessionIdLen]
	data = data[39+sessionIdLen:]
	// bindersOffset tracks the offset of data[0] within m.raw as data is
	// consumed. The pre_shared_key case uses it to locate the binders list.
	bindersOffset := 39 + sessionIdLen
	if len(data) < 2 {
		return alertDecodeError
	}
	// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
	// they are uint16s, the number must be even.
	cipherSuiteLen := int(data[0])<<8 | int(data[1])
	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
		return alertDecodeError
	}
	numCipherSuites := cipherSuiteLen / 2
	m.cipherSuites = make([]uint16, numCipherSuites)
	for i := 0; i < numCipherSuites; i++ {
		m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
		// Offering the renegotiation SCSV counts as secure renegotiation
		// support even without a renegotiation_info extension.
		if m.cipherSuites[i] == scsvRenegotiation {
			m.secureRenegotiationSupported = true
		}
	}
	data = data[2+cipherSuiteLen:]
	bindersOffset += 2 + cipherSuiteLen
	if len(data) < 1 {
		return alertDecodeError
	}
	compressionMethodsLen := int(data[0])
	if len(data) < 1+compressionMethodsLen {
		return alertDecodeError
	}
	m.compressionMethods = data[1 : 1+compressionMethodsLen]
	data = data[1+compressionMethodsLen:]
	bindersOffset += 1 + compressionMethodsLen

	// Reset extension-derived fields so that absent extensions read as not
	// offered.
	m.nextProtoNeg = false
	m.serverName = ""
	m.ocspStapling = false
	m.ticketSupported = false
	m.sessionTicket = nil
	m.supportedSignatureAlgorithms = nil
	m.alpnProtocols = nil
	m.scts = false
	m.keyShares = nil
	m.supportedVersions = nil
	m.psks = nil
	m.pskKeyExchangeModes = nil
	m.earlyData = false

	if len(data) == 0 {
		// ClientHello is optionally followed by extension data
		return alertSuccess
	}
	if len(data) < 2 {
		return alertDecodeError
	}

	// The extensions block must exactly fill the rest of the message.
	extensionsLength := int(data[0])<<8 | int(data[1])
	data = data[2:]
	bindersOffset += 2
	if extensionsLength != len(data) {
		return alertDecodeError
	}

	for len(data) != 0 {
		// Each extension: type (2), payload length (2), payload.
		if len(data) < 4 {
			return alertDecodeError
		}
		extension := uint16(data[0])<<8 | uint16(data[1])
		length := int(data[2])<<8 | int(data[3])
		data = data[4:]
		bindersOffset += 4
		if len(data) < length {
			return alertDecodeError
		}

		// The payload is data[:length]; it is skipped after the switch.
		switch extension {
		case extensionServerName:
			d := data[:length]
			if len(d) < 2 {
				return alertDecodeError
			}
			// The server_name_list length must cover the whole payload.
			namesLen := int(d[0])<<8 | int(d[1])
			d = d[2:]
			if len(d) != namesLen {
				return alertDecodeError
			}
			for len(d) > 0 {
				// Entry: name type (1), name length (2), name.
				if len(d) < 3 {
					return alertDecodeError
				}
				nameType := d[0]
				nameLen := int(d[1])<<8 | int(d[2])
				d = d[3:]
				if len(d) < nameLen {
					return alertDecodeError
				}
				if nameType == 0 {
					m.serverName = string(d[:nameLen])
					// An SNI value may not include a
					// trailing dot. See
					// https://tools.ietf.org/html/rfc6066#section-3.
					if strings.HasSuffix(m.serverName, ".") {
						// TODO use alertDecodeError?
						return alertUnexpectedMessage
					}
					// Only the first host_name entry is used. This break
					// leaves the name loop, not the switch.
					break
				}
				d = d[nameLen:]
			}
		case extensionNextProtoNeg:
			if length > 0 {
				return alertDecodeError
			}
			m.nextProtoNeg = true
		case extensionStatusRequest:
			m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
		case extensionSupportedCurves:
			// http://tools.ietf.org/html/rfc4492#section-5.5.1
			// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.4
			if length < 2 {
				return alertDecodeError
			}
			l := int(data[0])<<8 | int(data[1])
			if l%2 == 1 || length != l+2 {
				return alertDecodeError
			}
			numCurves := l / 2
			m.supportedCurves = make([]CurveID, numCurves)
			d := data[2:]
			for i := 0; i < numCurves; i++ {
				m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
				d = d[2:]
			}
		case extensionSupportedPoints:
			// http://tools.ietf.org/html/rfc4492#section-5.5.2
			if length < 1 {
				return alertDecodeError
			}
			l := int(data[0])
			if length != l+1 {
				return alertDecodeError
			}
			// Copied (not aliased): copy stops after l bytes.
			m.supportedPoints = make([]uint8, l)
			copy(m.supportedPoints, data[1:])
		case extensionSessionTicket:
			// http://tools.ietf.org/html/rfc5077#section-3.2
			m.ticketSupported = true
			m.sessionTicket = data[:length]
		case extensionSignatureAlgorithms:
			// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
			// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.3
			if length < 2 || length&1 != 0 {
				return alertDecodeError
			}
			l := int(data[0])<<8 | int(data[1])
			if l != length-2 {
				return alertDecodeError
			}
			n := l / 2
			d := data[2:]
			m.supportedSignatureAlgorithms = make([]SignatureScheme, n)
			for i := range m.supportedSignatureAlgorithms {
				m.supportedSignatureAlgorithms[i] = SignatureScheme(d[0])<<8 | SignatureScheme(d[1])
				d = d[2:]
			}
		case extensionRenegotiationInfo:
			// Payload: one length byte that must cover the rest.
			if length == 0 {
				return alertDecodeError
			}
			d := data[:length]
			l := int(d[0])
			d = d[1:]
			if l != len(d) {
				return alertDecodeError
			}

			m.secureRenegotiation = d
			m.secureRenegotiationSupported = true
		case extensionALPN:
			// Protocol list length (2), then entries of length (1) + name;
			// empty names are rejected.
			if length < 2 {
				return alertDecodeError
			}
			l := int(data[0])<<8 | int(data[1])
			if l != length-2 {
				return alertDecodeError
			}
			d := data[2:length]
			for len(d) != 0 {
				stringLen := int(d[0])
				d = d[1:]
				if stringLen == 0 || stringLen > len(d) {
					return alertDecodeError
				}
				m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
				d = d[stringLen:]
			}
		case extensionSCT:
			m.scts = true
			if length != 0 {
				return alertDecodeError
			}
		case extensionKeyShare:
			// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.5
			if length < 2 {
				return alertDecodeError
			}
			l := int(data[0])<<8 | int(data[1])
			if l != length-2 {
				return alertDecodeError
			}
			d := data[2:length]
			for len(d) != 0 {
				// Entry: group (2), key exchange length (2), key exchange
				// data (must be non-empty).
				if len(d) < 4 {
					return alertDecodeError
				}
				dataLen := int(d[2])<<8 | int(d[3])
				if dataLen == 0 || 4+dataLen > len(d) {
					return alertDecodeError
				}
				m.keyShares = append(m.keyShares, keyShare{
					group: CurveID(d[0])<<8 | CurveID(d[1]),
					data:  d[4 : 4+dataLen],
				})
				d = d[4+dataLen:]
			}
		case extensionSupportedVersions:
			// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.1
			// Payload: one length byte, then 2 bytes per version.
			if length < 1 {
				return alertDecodeError
			}
			l := int(data[0])
			if l%2 == 1 || length != l+1 {
				return alertDecodeError
			}
			n := l / 2
			d := data[1:]
			for i := 0; i < n; i++ {
				v := uint16(d[0])<<8 + uint16(d[1])
				m.supportedVersions = append(m.supportedVersions, v)
				d = d[2:]
			}
		case extensionPreSharedKey:
			// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6
			if length < 2 {
				return alertDecodeError
			}
			// Ensure this extension is the last one in the Client Hello
			if len(data) != length {
				return alertIllegalParameter
			}
			// Identities vector: 2-byte length li, then entries of identity
			// length (2), identity, obfuscated ticket age (4).
			li := int(data[0])<<8 | int(data[1])
			// Also leave room for the 2-byte binders length read below.
			if 2+li+2 > length {
				return alertDecodeError
			}
			d := data[2 : 2+li]
			// After this, bindersOffset is the offset of the binders
			// length field within m.raw.
			bindersOffset += 2 + li
			for len(d) > 0 {
				if len(d) < 6 {
					return alertDecodeError
				}
				l := int(d[0])<<8 | int(d[1])
				if len(d) < 2+l+4 {
					return alertDecodeError
				}
				m.psks = append(m.psks, psk{
					identity: d[2 : 2+l],
					obfTicketAge: uint32(d[l+2])<<24 | uint32(d[l+3])<<16 |
						uint32(d[l+4])<<8 | uint32(d[l+5]),
				})
				d = d[2+l+4:]
			}
			// Binders vector: 2-byte length lb covering the rest of the
			// payload, then one entry of length (1) + binder per identity.
			lb := int(data[li+2])<<8 | int(data[li+3])
			d = data[2+li+2:]
			if lb != len(d) || lb == 0 {
				return alertDecodeError
			}
			i := 0
			for len(d) > 0 {
				// More binders than identities.
				if i >= len(m.psks) {
					return alertIllegalParameter
				}
				// NOTE(review): redundant; the loop condition already
				// guarantees len(d) >= 1.
				if len(d) < 1 {
					return alertDecodeError
				}
				l := int(d[0])
				if l > len(d)-1 {
					return alertDecodeError
				}
				// NOTE(review): duplicate of the check at the top of the
				// loop; i has not changed since.
				if i >= len(m.psks) {
					return alertIllegalParameter
				}
				m.psks[i].binder = d[1 : 1+l]
				d = d[1+l:]
				i++
			}
			// Fewer binders than identities.
			if i != len(m.psks) {
				return alertIllegalParameter
			}
			// Everything up to, but not including, the binders length.
			m.rawTruncated = m.raw[:bindersOffset]
		case extensionPSKKeyExchangeModes:
			// Payload: one length byte, then one byte per mode.
			if length < 2 {
				return alertDecodeError
			}
			l := int(data[0])
			if length != l+1 {
				return alertDecodeError
			}
			m.pskKeyExchangeModes = data[1:length]
		case extensionEarlyData:
			// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8
			m.earlyData = true
		}
		data = data[length:]
		bindersOffset += length
	}

	return alertSuccess
}
  // serverHelloMsg holds the contents of a ServerHello handshake message,
// including the TLS 1.3 (draft) key_share and pre_shared_key fields.
type serverHelloMsg struct {
	// raw is the complete encoded message including the 4-byte handshake
	// header. marshal caches its output here; unmarshal stores its input.
	raw []byte
	// vers is the version. When a supported_versions extension is present,
	// unmarshal replaces the legacy version field with the value from that
	// extension. For vers >= VersionTLS13, marshal writes the value into that
	// extension and puts 0x0303 in the legacy field.
	vers              uint16
	random            []byte // copied into the 32-byte random field
	sessionId         []byte // not on the wire for VersionTLS13Draft18..VersionTLS13Draft21
	cipherSuite       uint16
	compressionMethod uint8 // not on the wire for VersionTLS13Draft18..VersionTLS13Draft21
	nextProtoNeg      bool
	nextProtos        []string
	ocspStapling      bool
	scts              [][]byte // signed certificate timestamps; unmarshal rejects empty entries
	ticketSupported   bool
	secureRenegotiation          []byte
	secureRenegotiationSupported bool
	alpnProtocol                 string // marshal panics if 256 bytes or longer

	// TLS 1.3
	keyShare    keyShare // encoded only when keyShare.group != 0
	psk         bool     // pre_shared_key extension present
	pskIdentity uint16   // selected identity (presumably an index into the client's PSK list — confirm)
}
  691. func (m *serverHelloMsg) equal(i interface{}) bool {
  692. m1, ok := i.(*serverHelloMsg)
  693. if !ok {
  694. return false
  695. }
  696. if len(m.scts) != len(m1.scts) {
  697. return false
  698. }
  699. for i, sct := range m.scts {
  700. if !bytes.Equal(sct, m1.scts[i]) {
  701. return false
  702. }
  703. }
  704. return bytes.Equal(m.raw, m1.raw) &&
  705. m.vers == m1.vers &&
  706. bytes.Equal(m.random, m1.random) &&
  707. bytes.Equal(m.sessionId, m1.sessionId) &&
  708. m.cipherSuite == m1.cipherSuite &&
  709. m.compressionMethod == m1.compressionMethod &&
  710. m.nextProtoNeg == m1.nextProtoNeg &&
  711. eqStrings(m.nextProtos, m1.nextProtos) &&
  712. m.ocspStapling == m1.ocspStapling &&
  713. m.ticketSupported == m1.ticketSupported &&
  714. m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
  715. bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
  716. m.alpnProtocol == m1.alpnProtocol &&
  717. m.keyShare.group == m1.keyShare.group &&
  718. bytes.Equal(m.keyShare.data, m1.keyShare.data) &&
  719. m.psk == m1.psk &&
  720. m.pskIdentity == m1.pskIdentity
  721. }
  722. func (m *serverHelloMsg) marshal() []byte {
  723. if m.raw != nil {
  724. return m.raw
  725. }
  726. oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21
  727. length := 38 + len(m.sessionId)
  728. if oldTLS13Draft {
  729. // no compression method, no session ID.
  730. length = 36
  731. }
  732. numExtensions := 0
  733. extensionsLength := 0
  734. nextProtoLen := 0
  735. if m.nextProtoNeg {
  736. numExtensions++
  737. for _, v := range m.nextProtos {
  738. nextProtoLen += len(v)
  739. }
  740. nextProtoLen += len(m.nextProtos)
  741. extensionsLength += nextProtoLen
  742. }
  743. if m.ocspStapling {
  744. numExtensions++
  745. }
  746. if m.ticketSupported {
  747. numExtensions++
  748. }
  749. if m.secureRenegotiationSupported {
  750. extensionsLength += 1 + len(m.secureRenegotiation)
  751. numExtensions++
  752. }
  753. if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
  754. if alpnLen >= 256 {
  755. panic("invalid ALPN protocol")
  756. }
  757. extensionsLength += 2 + 1 + alpnLen
  758. numExtensions++
  759. }
  760. sctLen := 0
  761. if len(m.scts) > 0 {
  762. for _, sct := range m.scts {
  763. sctLen += len(sct) + 2
  764. }
  765. extensionsLength += 2 + sctLen
  766. numExtensions++
  767. }
  768. if m.keyShare.group != 0 {
  769. extensionsLength += 4 + len(m.keyShare.data)
  770. numExtensions++
  771. }
  772. if m.psk {
  773. extensionsLength += 2
  774. numExtensions++
  775. }
  776. // supported_versions extension
  777. if m.vers >= VersionTLS13 {
  778. extensionsLength += 2
  779. numExtensions++
  780. }
  781. if numExtensions > 0 {
  782. extensionsLength += 4 * numExtensions
  783. length += 2 + extensionsLength
  784. }
  785. x := make([]byte, 4+length)
  786. x[0] = typeServerHello
  787. x[1] = uint8(length >> 16)
  788. x[2] = uint8(length >> 8)
  789. x[3] = uint8(length)
  790. if m.vers >= VersionTLS13 {
  791. x[4] = 3
  792. x[5] = 3
  793. } else {
  794. x[4] = uint8(m.vers >> 8)
  795. x[5] = uint8(m.vers)
  796. }
  797. copy(x[6:38], m.random)
  798. z := x[38:]
  799. if !oldTLS13Draft {
  800. x[38] = uint8(len(m.sessionId))
  801. copy(x[39:39+len(m.sessionId)], m.sessionId)
  802. z = x[39+len(m.sessionId):]
  803. }
  804. z[0] = uint8(m.cipherSuite >> 8)
  805. z[1] = uint8(m.cipherSuite)
  806. if oldTLS13Draft {
  807. // no compression method in older TLS 1.3 drafts.
  808. z = z[2:]
  809. } else {
  810. z[2] = m.compressionMethod
  811. z = z[3:]
  812. }
  813. if numExtensions > 0 {
  814. z[0] = byte(extensionsLength >> 8)
  815. z[1] = byte(extensionsLength)
  816. z = z[2:]
  817. }
  818. if m.vers >= VersionTLS13 {
  819. z[0] = byte(extensionSupportedVersions >> 8)
  820. z[1] = byte(extensionSupportedVersions)
  821. z[3] = 2
  822. z[4] = uint8(m.vers >> 8)
  823. z[5] = uint8(m.vers)
  824. z = z[6:]
  825. }
  826. if m.nextProtoNeg {
  827. z[0] = byte(extensionNextProtoNeg >> 8)
  828. z[1] = byte(extensionNextProtoNeg & 0xff)
  829. z[2] = byte(nextProtoLen >> 8)
  830. z[3] = byte(nextProtoLen)
  831. z = z[4:]
  832. for _, v := range m.nextProtos {
  833. l := len(v)
  834. if l > 255 {
  835. l = 255
  836. }
  837. z[0] = byte(l)
  838. copy(z[1:], []byte(v[0:l]))
  839. z = z[1+l:]
  840. }
  841. }
  842. if m.ocspStapling {
  843. z[0] = byte(extensionStatusRequest >> 8)
  844. z[1] = byte(extensionStatusRequest)
  845. z = z[4:]
  846. }
  847. if m.ticketSupported {
  848. z[0] = byte(extensionSessionTicket >> 8)
  849. z[1] = byte(extensionSessionTicket)
  850. z = z[4:]
  851. }
  852. if m.secureRenegotiationSupported {
  853. z[0] = byte(extensionRenegotiationInfo >> 8)
  854. z[1] = byte(extensionRenegotiationInfo & 0xff)
  855. z[2] = 0
  856. z[3] = byte(len(m.secureRenegotiation) + 1)
  857. z[4] = byte(len(m.secureRenegotiation))
  858. z = z[5:]
  859. copy(z, m.secureRenegotiation)
  860. z = z[len(m.secureRenegotiation):]
  861. }
  862. if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
  863. z[0] = byte(extensionALPN >> 8)
  864. z[1] = byte(extensionALPN & 0xff)
  865. l := 2 + 1 + alpnLen
  866. z[2] = byte(l >> 8)
  867. z[3] = byte(l)
  868. l -= 2
  869. z[4] = byte(l >> 8)
  870. z[5] = byte(l)
  871. l -= 1
  872. z[6] = byte(l)
  873. copy(z[7:], []byte(m.alpnProtocol))
  874. z = z[7+alpnLen:]
  875. }
  876. if sctLen > 0 {
  877. z[0] = byte(extensionSCT >> 8)
  878. z[1] = byte(extensionSCT)
  879. l := sctLen + 2
  880. z[2] = byte(l >> 8)
  881. z[3] = byte(l)
  882. z[4] = byte(sctLen >> 8)
  883. z[5] = byte(sctLen)
  884. z = z[6:]
  885. for _, sct := range m.scts {
  886. z[0] = byte(len(sct) >> 8)
  887. z[1] = byte(len(sct))
  888. copy(z[2:], sct)
  889. z = z[len(sct)+2:]
  890. }
  891. }
  892. if m.keyShare.group != 0 {
  893. z[0] = uint8(extensionKeyShare >> 8)
  894. z[1] = uint8(extensionKeyShare)
  895. l := 4 + len(m.keyShare.data)
  896. z[2] = uint8(l >> 8)
  897. z[3] = uint8(l)
  898. z[4] = uint8(m.keyShare.group >> 8)
  899. z[5] = uint8(m.keyShare.group)
  900. l -= 4
  901. z[6] = uint8(l >> 8)
  902. z[7] = uint8(l)
  903. copy(z[8:], m.keyShare.data)
  904. z = z[8+l:]
  905. }
  906. if m.psk {
  907. z[0] = byte(extensionPreSharedKey >> 8)
  908. z[1] = byte(extensionPreSharedKey)
  909. z[3] = 2
  910. z[4] = byte(m.pskIdentity >> 8)
  911. z[5] = byte(m.pskIdentity)
  912. z = z[6:]
  913. }
  914. m.raw = x
  915. return x
  916. }
  // unmarshal parses a complete ServerHello handshake message (including its
// 4-byte header) into m. It returns alertSuccess, or the alert to send when
// the message is malformed. m.raw is data itself, and random, sessionId,
// secureRenegotiation, each scts entry and keyShare.data are sub-slices of
// it.
//
// For versions VersionTLS13Draft18..VersionTLS13Draft21 the message has no
// session ID and no compression method. If a supported_versions extension is
// present, the legacy version field must be VersionTLS12, and m.vers is
// replaced by the version carried in the extension.
//
// NOTE(review): sessionId, compressionMethod, secureRenegotiation and
// secureRenegotiationSupported are not reset. Values from an earlier use of
// m therefore survive when data does not supply them.
func (m *serverHelloMsg) unmarshal(data []byte) alert {
	if len(data) < 42 {
		return alertDecodeError
	}
	m.raw = data
	m.vers = uint16(data[4])<<8 | uint16(data[5])
	// Drafts 18-21 carry their version directly in the version field and
	// omit the session ID and compression method.
	oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21
	m.random = data[6:38]
	if !oldTLS13Draft {
		sessionIdLen := int(data[38])
		if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
			return alertDecodeError
		}
		m.sessionId = data[39 : 39+sessionIdLen]
		data = data[39+sessionIdLen:]
	} else {
		data = data[38:]
	}
	// NOTE(review): the old-draft path consumes only 2 bytes below (no
	// compression method), yet 3 are required here.
	if len(data) < 3 {
		return alertDecodeError
	}
	m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
	if oldTLS13Draft {
		// no compression method in older TLS 1.3 drafts.
		data = data[2:]
	} else {
		m.compressionMethod = data[2]
		data = data[3:]
	}

	// Reset extension-derived fields so absent extensions read as not
	// present.
	m.nextProtoNeg = false
	m.nextProtos = nil
	m.ocspStapling = false
	m.scts = nil
	m.ticketSupported = false
	m.alpnProtocol = ""
	m.keyShare.group = 0
	m.keyShare.data = nil
	m.psk = false
	m.pskIdentity = 0

	if len(data) == 0 {
		// ServerHello is optionally followed by extension data
		return alertSuccess
	}
	if len(data) < 2 {
		return alertDecodeError
	}

	// The extensions block must exactly fill the rest of the message.
	extensionsLength := int(data[0])<<8 | int(data[1])
	data = data[2:]
	if len(data) != extensionsLength {
		return alertDecodeError
	}

	// supported_versions is looked up before the main loop because it
	// replaces m.vers. Its payload must be exactly one 2-byte version, and it
	// is only accepted when the legacy version field is VersionTLS12.
	// NOTE(review): findExtension is defined elsewhere; this relies on it
	// returning nil when the extension is absent.
	svData := findExtension(data, extensionSupportedVersions)
	if svData != nil {
		if len(svData) != 2 {
			return alertDecodeError
		}
		if m.vers != VersionTLS12 {
			return alertDecodeError
		}
		m.vers = uint16(svData[0])<<8 | uint16(svData[1])
	}

	for len(data) != 0 {
		// Each extension: type (2), payload length (2), payload.
		if len(data) < 4 {
			return alertDecodeError
		}
		extension := uint16(data[0])<<8 | uint16(data[1])
		length := int(data[2])<<8 | int(data[3])
		data = data[4:]
		if len(data) < length {
			return alertDecodeError
		}

		// supported_versions (handled above) and unrecognized types match
		// no case and are skipped.
		switch extension {
		case extensionNextProtoNeg:
			// Entries of length (1) + name; empty names are rejected.
			m.nextProtoNeg = true
			d := data[:length]
			for len(d) > 0 {
				l := int(d[0])
				d = d[1:]
				if l == 0 || l > len(d) {
					return alertDecodeError
				}
				m.nextProtos = append(m.nextProtos, string(d[:l]))
				d = d[l:]
			}
		case extensionStatusRequest:
			if length > 0 {
				return alertDecodeError
			}
			m.ocspStapling = true
		case extensionSessionTicket:
			if length > 0 {
				return alertDecodeError
			}
			m.ticketSupported = true
		case extensionRenegotiationInfo:
			// Payload: one length byte that must cover the rest.
			if length == 0 {
				return alertDecodeError
			}
			d := data[:length]
			l := int(d[0])
			d = d[1:]
			if l != len(d) {
				return alertDecodeError
			}

			m.secureRenegotiation = d
			m.secureRenegotiationSupported = true
		case extensionALPN:
			// Protocol list length (2), then exactly one non-empty name:
			// length (1) + name.
			d := data[:length]
			if len(d) < 3 {
				return alertDecodeError
			}
			l := int(d[0])<<8 | int(d[1])
			if l != len(d)-2 {
				return alertDecodeError
			}
			d = d[2:]
			l = int(d[0])
			if l != len(d)-1 {
				return alertDecodeError
			}
			d = d[1:]
			if len(d) == 0 {
				// ALPN protocols must not be empty.
				return alertDecodeError
			}
			m.alpnProtocol = string(d)
		case extensionSCT:
			// List length (2), then entries of length (2) + SCT. Both the
			// list and each entry must be non-empty.
			d := data[:length]

			if len(d) < 2 {
				return alertDecodeError
			}
			l := int(d[0])<<8 | int(d[1])
			d = d[2:]
			if len(d) != l || l == 0 {
				return alertDecodeError
			}

			m.scts = make([][]byte, 0, 3)
			for len(d) != 0 {
				if len(d) < 2 {
					return alertDecodeError
				}
				sctLen := int(d[0])<<8 | int(d[1])
				d = d[2:]
				if sctLen == 0 || len(d) < sctLen {
					return alertDecodeError
				}
				m.scts = append(m.scts, d[:sctLen])
				d = d[sctLen:]
			}
		case extensionKeyShare:
			// Group (2), key exchange length (2), key exchange data, which
			// must fill the payload exactly.
			d := data[:length]

			if len(d) < 4 {
				return alertDecodeError
			}
			m.keyShare.group = CurveID(d[0])<<8 | CurveID(d[1])
			l := int(d[2])<<8 | int(d[3])
			d = d[4:]
			if len(d) != l {
				return alertDecodeError
			}
			m.keyShare.data = d[:l]
		case extensionPreSharedKey:
			// Payload: the 2-byte selected identity.
			if length != 2 {
				return alertDecodeError
			}
			m.psk = true
			m.pskIdentity = uint16(data[0])<<8 | uint16(data[1])
		}
		data = data[length:]
	}

	return alertSuccess
}
  1089. type encryptedExtensionsMsg struct {
  1090. raw []byte
  1091. alpnProtocol string
  1092. earlyData bool
  1093. }
  1094. func (m *encryptedExtensionsMsg) equal(i interface{}) bool {
  1095. m1, ok := i.(*encryptedExtensionsMsg)
  1096. if !ok {
  1097. return false
  1098. }
  1099. return bytes.Equal(m.raw, m1.raw) &&
  1100. m.alpnProtocol == m1.alpnProtocol &&
  1101. m.earlyData == m1.earlyData
  1102. }
  1103. func (m *encryptedExtensionsMsg) marshal() []byte {
  1104. if m.raw != nil {
  1105. return m.raw
  1106. }
  1107. length := 2
  1108. if m.earlyData {
  1109. length += 4
  1110. }
  1111. alpnLen := len(m.alpnProtocol)
  1112. if alpnLen > 0 {
  1113. if alpnLen >= 256 {
  1114. panic("invalid ALPN protocol")
  1115. }
  1116. length += 2 + 2 + 2 + 1 + alpnLen
  1117. }
  1118. x := make([]byte, 4+length)
  1119. x[0] = typeEncryptedExtensions
  1120. x[1] = uint8(length >> 16)
  1121. x[2] = uint8(length >> 8)
  1122. x[3] = uint8(length)
  1123. length -= 2
  1124. x[4] = uint8(length >> 8)
  1125. x[5] = uint8(length)
  1126. z := x[6:]
  1127. if alpnLen > 0 {
  1128. z[0] = byte(extensionALPN >> 8)
  1129. z[1] = byte(extensionALPN)
  1130. l := 2 + 1 + alpnLen
  1131. z[2] = byte(l >> 8)
  1132. z[3] = byte(l)
  1133. l -= 2
  1134. z[4] = byte(l >> 8)
  1135. z[5] = byte(l)
  1136. l -= 1
  1137. z[6] = byte(l)
  1138. copy(z[7:], []byte(m.alpnProtocol))
  1139. z = z[7+alpnLen:]
  1140. }
  1141. if m.earlyData {
  1142. z[0] = byte(extensionEarlyData >> 8)
  1143. z[1] = byte(extensionEarlyData)
  1144. z = z[4:]
  1145. }
  1146. m.raw = x
  1147. return x
  1148. }
  1149. func (m *encryptedExtensionsMsg) unmarshal(data []byte) alert {
  1150. if len(data) < 6 {
  1151. return alertDecodeError
  1152. }
  1153. m.raw = data
  1154. m.alpnProtocol = ""
  1155. m.earlyData = false
  1156. extensionsLength := int(data[4])<<8 | int(data[5])
  1157. data = data[6:]
  1158. if len(data) != extensionsLength {
  1159. return alertDecodeError
  1160. }
  1161. for len(data) != 0 {
  1162. if len(data) < 4 {
  1163. return alertDecodeError
  1164. }
  1165. extension := uint16(data[0])<<8 | uint16(data[1])
  1166. length := int(data[2])<<8 | int(data[3])
  1167. data = data[4:]
  1168. if len(data) < length {
  1169. return alertDecodeError
  1170. }
  1171. switch extension {
  1172. case extensionALPN:
  1173. d := data[:length]
  1174. if len(d) < 3 {
  1175. return alertDecodeError
  1176. }
  1177. l := int(d[0])<<8 | int(d[1])
  1178. if l != len(d)-2 {
  1179. return alertDecodeError
  1180. }
  1181. d = d[2:]
  1182. l = int(d[0])
  1183. if l != len(d)-1 {
  1184. return alertDecodeError
  1185. }
  1186. d = d[1:]
  1187. if len(d) == 0 {
  1188. // ALPN protocols must not be empty.
  1189. return alertDecodeError
  1190. }
  1191. m.alpnProtocol = string(d)
  1192. case extensionEarlyData:
  1193. // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8
  1194. m.earlyData = true
  1195. }
  1196. data = data[length:]
  1197. }
  1198. return alertSuccess
  1199. }
  1200. type certificateMsg struct {
  1201. raw []byte
  1202. certificates [][]byte
  1203. }
  1204. func (m *certificateMsg) equal(i interface{}) bool {
  1205. m1, ok := i.(*certificateMsg)
  1206. if !ok {
  1207. return false
  1208. }
  1209. return bytes.Equal(m.raw, m1.raw) &&
  1210. eqByteSlices(m.certificates, m1.certificates)
  1211. }
  1212. func (m *certificateMsg) marshal() (x []byte) {
  1213. if m.raw != nil {
  1214. return m.raw
  1215. }
  1216. var i int
  1217. for _, slice := range m.certificates {
  1218. i += len(slice)
  1219. }
  1220. length := 3 + 3*len(m.certificates) + i
  1221. x = make([]byte, 4+length)
  1222. x[0] = typeCertificate
  1223. x[1] = uint8(length >> 16)
  1224. x[2] = uint8(length >> 8)
  1225. x[3] = uint8(length)
  1226. certificateOctets := length - 3
  1227. x[4] = uint8(certificateOctets >> 16)
  1228. x[5] = uint8(certificateOctets >> 8)
  1229. x[6] = uint8(certificateOctets)
  1230. y := x[7:]
  1231. for _, slice := range m.certificates {
  1232. y[0] = uint8(len(slice) >> 16)
  1233. y[1] = uint8(len(slice) >> 8)
  1234. y[2] = uint8(len(slice))
  1235. copy(y[3:], slice)
  1236. y = y[3+len(slice):]
  1237. }
  1238. m.raw = x
  1239. return
  1240. }
  1241. func (m *certificateMsg) unmarshal(data []byte) alert {
  1242. if len(data) < 7 {
  1243. return alertDecodeError
  1244. }
  1245. m.raw = data
  1246. certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
  1247. if uint32(len(data)) != certsLen+7 {
  1248. return alertDecodeError
  1249. }
  1250. numCerts := 0
  1251. d := data[7:]
  1252. for certsLen > 0 {
  1253. if len(d) < 4 {
  1254. return alertDecodeError
  1255. }
  1256. certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1257. if uint32(len(d)) < 3+certLen {
  1258. return alertDecodeError
  1259. }
  1260. d = d[3+certLen:]
  1261. certsLen -= 3 + certLen
  1262. numCerts++
  1263. }
  1264. m.certificates = make([][]byte, numCerts)
  1265. d = data[7:]
  1266. for i := 0; i < numCerts; i++ {
  1267. certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1268. m.certificates[i] = d[3 : 3+certLen]
  1269. d = d[3+certLen:]
  1270. }
  1271. return alertSuccess
  1272. }
  1273. type certificateEntry struct {
  1274. data []byte
  1275. ocspStaple []byte
  1276. sctList [][]byte
  1277. }
  1278. type certificateMsg13 struct {
  1279. raw []byte
  1280. requestContext []byte
  1281. certificates []certificateEntry
  1282. }
  1283. func (m *certificateMsg13) equal(i interface{}) bool {
  1284. m1, ok := i.(*certificateMsg13)
  1285. if !ok {
  1286. return false
  1287. }
  1288. if len(m.certificates) != len(m1.certificates) {
  1289. return false
  1290. }
  1291. for i, _ := range m.certificates {
  1292. ok := bytes.Equal(m.certificates[i].data, m1.certificates[i].data)
  1293. ok = ok && bytes.Equal(m.certificates[i].ocspStaple, m1.certificates[i].ocspStaple)
  1294. ok = ok && eqByteSlices(m.certificates[i].sctList, m1.certificates[i].sctList)
  1295. if !ok {
  1296. return false
  1297. }
  1298. }
  1299. return bytes.Equal(m.raw, m1.raw) &&
  1300. bytes.Equal(m.requestContext, m1.requestContext)
  1301. }
  1302. func (m *certificateMsg13) marshal() (x []byte) {
  1303. if m.raw != nil {
  1304. return m.raw
  1305. }
  1306. var i int
  1307. for _, cert := range m.certificates {
  1308. i += len(cert.data)
  1309. if len(cert.ocspStaple) != 0 {
  1310. i += 8 + len(cert.ocspStaple)
  1311. }
  1312. if len(cert.sctList) != 0 {
  1313. i += 6
  1314. for _, sct := range cert.sctList {
  1315. i += 2 + len(sct)
  1316. }
  1317. }
  1318. }
  1319. length := 3 + 3*len(m.certificates) + i
  1320. length += 2 * len(m.certificates) // extensions
  1321. length += 1 + len(m.requestContext)
  1322. x = make([]byte, 4+length)
  1323. x[0] = typeCertificate
  1324. x[1] = uint8(length >> 16)
  1325. x[2] = uint8(length >> 8)
  1326. x[3] = uint8(length)
  1327. z := x[4:]
  1328. z[0] = byte(len(m.requestContext))
  1329. copy(z[1:], m.requestContext)
  1330. z = z[1+len(m.requestContext):]
  1331. certificateOctets := len(z) - 3
  1332. z[0] = uint8(certificateOctets >> 16)
  1333. z[1] = uint8(certificateOctets >> 8)
  1334. z[2] = uint8(certificateOctets)
  1335. z = z[3:]
  1336. for _, cert := range m.certificates {
  1337. z[0] = uint8(len(cert.data) >> 16)
  1338. z[1] = uint8(len(cert.data) >> 8)
  1339. z[2] = uint8(len(cert.data))
  1340. copy(z[3:], cert.data)
  1341. z = z[3+len(cert.data):]
  1342. extLenPos := z[:2]
  1343. z = z[2:]
  1344. extensionLen := 0
  1345. if len(cert.ocspStaple) != 0 {
  1346. stapleLen := 4 + len(cert.ocspStaple)
  1347. z[0] = uint8(extensionStatusRequest >> 8)
  1348. z[1] = uint8(extensionStatusRequest)
  1349. z[2] = uint8(stapleLen >> 8)
  1350. z[3] = uint8(stapleLen)
  1351. stapleLen -= 4
  1352. z[4] = statusTypeOCSP
  1353. z[5] = uint8(stapleLen >> 16)
  1354. z[6] = uint8(stapleLen >> 8)
  1355. z[7] = uint8(stapleLen)
  1356. copy(z[8:], cert.ocspStaple)
  1357. z = z[8+stapleLen:]
  1358. extensionLen += 8 + stapleLen
  1359. }
  1360. if len(cert.sctList) != 0 {
  1361. z[0] = uint8(extensionSCT >> 8)
  1362. z[1] = uint8(extensionSCT)
  1363. sctLenPos := z[2:6]
  1364. z = z[6:]
  1365. extensionLen += 6
  1366. sctLen := 2
  1367. for _, sct := range cert.sctList {
  1368. z[0] = uint8(len(sct) >> 8)
  1369. z[1] = uint8(len(sct))
  1370. copy(z[2:], sct)
  1371. z = z[2+len(sct):]
  1372. extensionLen += 2 + len(sct)
  1373. sctLen += 2 + len(sct)
  1374. }
  1375. sctLenPos[0] = uint8(sctLen >> 8)
  1376. sctLenPos[1] = uint8(sctLen)
  1377. sctLen -= 2
  1378. sctLenPos[2] = uint8(sctLen >> 8)
  1379. sctLenPos[3] = uint8(sctLen)
  1380. }
  1381. extLenPos[0] = uint8(extensionLen >> 8)
  1382. extLenPos[1] = uint8(extensionLen)
  1383. }
  1384. m.raw = x
  1385. return
  1386. }
  1387. func (m *certificateMsg13) unmarshal(data []byte) alert {
  1388. if len(data) < 5 {
  1389. return alertDecodeError
  1390. }
  1391. m.raw = data
  1392. ctxLen := data[4]
  1393. if len(data) < int(ctxLen)+5+3 {
  1394. return alertDecodeError
  1395. }
  1396. m.requestContext = data[5 : 5+ctxLen]
  1397. d := data[5+ctxLen:]
  1398. certsLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1399. if uint32(len(d)) != certsLen+3 {
  1400. return alertDecodeError
  1401. }
  1402. numCerts := 0
  1403. d = d[3:]
  1404. for certsLen > 0 {
  1405. if len(d) < 4 {
  1406. return alertDecodeError
  1407. }
  1408. certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1409. if uint32(len(d)) < 3+certLen {
  1410. return alertDecodeError
  1411. }
  1412. d = d[3+certLen:]
  1413. if len(d) < 2 {
  1414. return alertDecodeError
  1415. }
  1416. extLen := uint16(d[0])<<8 | uint16(d[1])
  1417. if uint16(len(d)) < 2+extLen {
  1418. return alertDecodeError
  1419. }
  1420. d = d[2+extLen:]
  1421. certsLen -= 3 + certLen + 2 + uint32(extLen)
  1422. numCerts++
  1423. }
  1424. m.certificates = make([]certificateEntry, numCerts)
  1425. d = data[8+ctxLen:]
  1426. for i := 0; i < numCerts; i++ {
  1427. certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1428. m.certificates[i].data = d[3 : 3+certLen]
  1429. d = d[3+certLen:]
  1430. extLen := uint16(d[0])<<8 | uint16(d[1])
  1431. d = d[2:]
  1432. for extLen > 0 {
  1433. if extLen < 4 {
  1434. return alertDecodeError
  1435. }
  1436. typ := uint16(d[0])<<8 | uint16(d[1])
  1437. bodyLen := uint16(d[2])<<8 | uint16(d[3])
  1438. if extLen < 4+bodyLen {
  1439. return alertDecodeError
  1440. }
  1441. body := d[4 : 4+bodyLen]
  1442. d = d[4+bodyLen:]
  1443. extLen -= 4 + bodyLen
  1444. switch typ {
  1445. case extensionStatusRequest:
  1446. if len(body) < 4 || body[0] != 0x01 {
  1447. return alertDecodeError
  1448. }
  1449. ocspLen := int(body[1])<<16 | int(body[2])<<8 | int(body[3])
  1450. if len(body) != 4+ocspLen {
  1451. return alertDecodeError
  1452. }
  1453. m.certificates[i].ocspStaple = body[4:]
  1454. case extensionSCT:
  1455. if len(body) < 2 {
  1456. return alertDecodeError
  1457. }
  1458. listLen := int(body[0])<<8 | int(body[1])
  1459. body = body[2:]
  1460. if len(body) != listLen {
  1461. return alertDecodeError
  1462. }
  1463. for len(body) > 0 {
  1464. if len(body) < 2 {
  1465. return alertDecodeError
  1466. }
  1467. sctLen := int(body[0])<<8 | int(body[1])
  1468. if len(body) < 2+sctLen {
  1469. return alertDecodeError
  1470. }
  1471. m.certificates[i].sctList = append(m.certificates[i].sctList, body[2:2+sctLen])
  1472. body = body[2+sctLen:]
  1473. }
  1474. }
  1475. }
  1476. }
  1477. return alertSuccess
  1478. }
  1479. type serverKeyExchangeMsg struct {
  1480. raw []byte
  1481. key []byte
  1482. }
  1483. func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
  1484. m1, ok := i.(*serverKeyExchangeMsg)
  1485. if !ok {
  1486. return false
  1487. }
  1488. return bytes.Equal(m.raw, m1.raw) &&
  1489. bytes.Equal(m.key, m1.key)
  1490. }
  1491. func (m *serverKeyExchangeMsg) marshal() []byte {
  1492. if m.raw != nil {
  1493. return m.raw
  1494. }
  1495. length := len(m.key)
  1496. x := make([]byte, length+4)
  1497. x[0] = typeServerKeyExchange
  1498. x[1] = uint8(length >> 16)
  1499. x[2] = uint8(length >> 8)
  1500. x[3] = uint8(length)
  1501. copy(x[4:], m.key)
  1502. m.raw = x
  1503. return x
  1504. }
  1505. func (m *serverKeyExchangeMsg) unmarshal(data []byte) alert {
  1506. m.raw = data
  1507. if len(data) < 4 {
  1508. return alertDecodeError
  1509. }
  1510. m.key = data[4:]
  1511. return alertSuccess
  1512. }
  1513. type certificateStatusMsg struct {
  1514. raw []byte
  1515. statusType uint8
  1516. response []byte
  1517. }
  1518. func (m *certificateStatusMsg) equal(i interface{}) bool {
  1519. m1, ok := i.(*certificateStatusMsg)
  1520. if !ok {
  1521. return false
  1522. }
  1523. return bytes.Equal(m.raw, m1.raw) &&
  1524. m.statusType == m1.statusType &&
  1525. bytes.Equal(m.response, m1.response)
  1526. }
  1527. func (m *certificateStatusMsg) marshal() []byte {
  1528. if m.raw != nil {
  1529. return m.raw
  1530. }
  1531. var x []byte
  1532. if m.statusType == statusTypeOCSP {
  1533. x = make([]byte, 4+4+len(m.response))
  1534. x[0] = typeCertificateStatus
  1535. l := len(m.response) + 4
  1536. x[1] = byte(l >> 16)
  1537. x[2] = byte(l >> 8)
  1538. x[3] = byte(l)
  1539. x[4] = statusTypeOCSP
  1540. l -= 4
  1541. x[5] = byte(l >> 16)
  1542. x[6] = byte(l >> 8)
  1543. x[7] = byte(l)
  1544. copy(x[8:], m.response)
  1545. } else {
  1546. x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
  1547. }
  1548. m.raw = x
  1549. return x
  1550. }
  1551. func (m *certificateStatusMsg) unmarshal(data []byte) alert {
  1552. m.raw = data
  1553. if len(data) < 5 {
  1554. return alertDecodeError
  1555. }
  1556. m.statusType = data[4]
  1557. m.response = nil
  1558. if m.statusType == statusTypeOCSP {
  1559. if len(data) < 8 {
  1560. return alertDecodeError
  1561. }
  1562. respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
  1563. if uint32(len(data)) != 4+4+respLen {
  1564. return alertDecodeError
  1565. }
  1566. m.response = data[8:]
  1567. }
  1568. return alertSuccess
  1569. }
  1570. type serverHelloDoneMsg struct{}
  1571. func (m *serverHelloDoneMsg) equal(i interface{}) bool {
  1572. _, ok := i.(*serverHelloDoneMsg)
  1573. return ok
  1574. }
  1575. func (m *serverHelloDoneMsg) marshal() []byte {
  1576. x := make([]byte, 4)
  1577. x[0] = typeServerHelloDone
  1578. return x
  1579. }
  1580. func (m *serverHelloDoneMsg) unmarshal(data []byte) alert {
  1581. if len(data) != 4 {
  1582. return alertDecodeError
  1583. }
  1584. return alertSuccess
  1585. }
  1586. type clientKeyExchangeMsg struct {
  1587. raw []byte
  1588. ciphertext []byte
  1589. }
  1590. func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
  1591. m1, ok := i.(*clientKeyExchangeMsg)
  1592. if !ok {
  1593. return false
  1594. }
  1595. return bytes.Equal(m.raw, m1.raw) &&
  1596. bytes.Equal(m.ciphertext, m1.ciphertext)
  1597. }
  1598. func (m *clientKeyExchangeMsg) marshal() []byte {
  1599. if m.raw != nil {
  1600. return m.raw
  1601. }
  1602. length := len(m.ciphertext)
  1603. x := make([]byte, length+4)
  1604. x[0] = typeClientKeyExchange
  1605. x[1] = uint8(length >> 16)
  1606. x[2] = uint8(length >> 8)
  1607. x[3] = uint8(length)
  1608. copy(x[4:], m.ciphertext)
  1609. m.raw = x
  1610. return x
  1611. }
  1612. func (m *clientKeyExchangeMsg) unmarshal(data []byte) alert {
  1613. m.raw = data
  1614. if len(data) < 4 {
  1615. return alertDecodeError
  1616. }
  1617. l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  1618. if l != len(data)-4 {
  1619. return alertDecodeError
  1620. }
  1621. m.ciphertext = data[4:]
  1622. return alertSuccess
  1623. }
  1624. type finishedMsg struct {
  1625. raw []byte
  1626. verifyData []byte
  1627. }
  1628. func (m *finishedMsg) equal(i interface{}) bool {
  1629. m1, ok := i.(*finishedMsg)
  1630. if !ok {
  1631. return false
  1632. }
  1633. return bytes.Equal(m.raw, m1.raw) &&
  1634. bytes.Equal(m.verifyData, m1.verifyData)
  1635. }
  1636. func (m *finishedMsg) marshal() (x []byte) {
  1637. if m.raw != nil {
  1638. return m.raw
  1639. }
  1640. x = make([]byte, 4+len(m.verifyData))
  1641. x[0] = typeFinished
  1642. x[3] = byte(len(m.verifyData))
  1643. copy(x[4:], m.verifyData)
  1644. m.raw = x
  1645. return
  1646. }
  1647. func (m *finishedMsg) unmarshal(data []byte) alert {
  1648. m.raw = data
  1649. if len(data) < 4 {
  1650. return alertDecodeError
  1651. }
  1652. m.verifyData = data[4:]
  1653. return alertSuccess
  1654. }
  1655. type nextProtoMsg struct {
  1656. raw []byte
  1657. proto string
  1658. }
  1659. func (m *nextProtoMsg) equal(i interface{}) bool {
  1660. m1, ok := i.(*nextProtoMsg)
  1661. if !ok {
  1662. return false
  1663. }
  1664. return bytes.Equal(m.raw, m1.raw) &&
  1665. m.proto == m1.proto
  1666. }
  1667. func (m *nextProtoMsg) marshal() []byte {
  1668. if m.raw != nil {
  1669. return m.raw
  1670. }
  1671. l := len(m.proto)
  1672. if l > 255 {
  1673. l = 255
  1674. }
  1675. padding := 32 - (l+2)%32
  1676. length := l + padding + 2
  1677. x := make([]byte, length+4)
  1678. x[0] = typeNextProtocol
  1679. x[1] = uint8(length >> 16)
  1680. x[2] = uint8(length >> 8)
  1681. x[3] = uint8(length)
  1682. y := x[4:]
  1683. y[0] = byte(l)
  1684. copy(y[1:], []byte(m.proto[0:l]))
  1685. y = y[1+l:]
  1686. y[0] = byte(padding)
  1687. m.raw = x
  1688. return x
  1689. }
  1690. func (m *nextProtoMsg) unmarshal(data []byte) alert {
  1691. m.raw = data
  1692. if len(data) < 5 {
  1693. return alertDecodeError
  1694. }
  1695. data = data[4:]
  1696. protoLen := int(data[0])
  1697. data = data[1:]
  1698. if len(data) < protoLen {
  1699. return alertDecodeError
  1700. }
  1701. m.proto = string(data[0:protoLen])
  1702. data = data[protoLen:]
  1703. if len(data) < 1 {
  1704. return alertDecodeError
  1705. }
  1706. paddingLen := int(data[0])
  1707. data = data[1:]
  1708. if len(data) != paddingLen {
  1709. return alertDecodeError
  1710. }
  1711. return alertSuccess
  1712. }
  1713. type certificateRequestMsg struct {
  1714. raw []byte
  1715. // hasSignatureAndHash indicates whether this message includes a list
  1716. // of signature and hash functions. This change was introduced with TLS
  1717. // 1.2.
  1718. hasSignatureAndHash bool
  1719. certificateTypes []byte
  1720. supportedSignatureAlgorithms []SignatureScheme
  1721. certificateAuthorities [][]byte
  1722. }
  1723. func (m *certificateRequestMsg) equal(i interface{}) bool {
  1724. m1, ok := i.(*certificateRequestMsg)
  1725. if !ok {
  1726. return false
  1727. }
  1728. return bytes.Equal(m.raw, m1.raw) &&
  1729. bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
  1730. eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
  1731. eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms)
  1732. }
  1733. func (m *certificateRequestMsg) marshal() (x []byte) {
  1734. if m.raw != nil {
  1735. return m.raw
  1736. }
  1737. // See http://tools.ietf.org/html/rfc4346#section-7.4.4
  1738. length := 1 + len(m.certificateTypes) + 2
  1739. casLength := 0
  1740. for _, ca := range m.certificateAuthorities {
  1741. casLength += 2 + len(ca)
  1742. }
  1743. length += casLength
  1744. if m.hasSignatureAndHash {
  1745. length += 2 + 2*len(m.supportedSignatureAlgorithms)
  1746. }
  1747. x = make([]byte, 4+length)
  1748. x[0] = typeCertificateRequest
  1749. x[1] = uint8(length >> 16)
  1750. x[2] = uint8(length >> 8)
  1751. x[3] = uint8(length)
  1752. x[4] = uint8(len(m.certificateTypes))
  1753. copy(x[5:], m.certificateTypes)
  1754. y := x[5+len(m.certificateTypes):]
  1755. if m.hasSignatureAndHash {
  1756. n := len(m.supportedSignatureAlgorithms) * 2
  1757. y[0] = uint8(n >> 8)
  1758. y[1] = uint8(n)
  1759. y = y[2:]
  1760. for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1761. y[0] = uint8(sigAlgo >> 8)
  1762. y[1] = uint8(sigAlgo)
  1763. y = y[2:]
  1764. }
  1765. }
  1766. y[0] = uint8(casLength >> 8)
  1767. y[1] = uint8(casLength)
  1768. y = y[2:]
  1769. for _, ca := range m.certificateAuthorities {
  1770. y[0] = uint8(len(ca) >> 8)
  1771. y[1] = uint8(len(ca))
  1772. y = y[2:]
  1773. copy(y, ca)
  1774. y = y[len(ca):]
  1775. }
  1776. m.raw = x
  1777. return
  1778. }
  1779. func (m *certificateRequestMsg) unmarshal(data []byte) alert {
  1780. m.raw = data
  1781. if len(data) < 5 {
  1782. return alertDecodeError
  1783. }
  1784. length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1785. if uint32(len(data))-4 != length {
  1786. return alertDecodeError
  1787. }
  1788. numCertTypes := int(data[4])
  1789. data = data[5:]
  1790. if numCertTypes == 0 || len(data) <= numCertTypes {
  1791. return alertDecodeError
  1792. }
  1793. m.certificateTypes = make([]byte, numCertTypes)
  1794. if copy(m.certificateTypes, data) != numCertTypes {
  1795. return alertDecodeError
  1796. }
  1797. data = data[numCertTypes:]
  1798. if m.hasSignatureAndHash {
  1799. if len(data) < 2 {
  1800. return alertDecodeError
  1801. }
  1802. sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
  1803. data = data[2:]
  1804. if sigAndHashLen&1 != 0 {
  1805. return alertDecodeError
  1806. }
  1807. if len(data) < int(sigAndHashLen) {
  1808. return alertDecodeError
  1809. }
  1810. numSigAlgos := sigAndHashLen / 2
  1811. m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
  1812. for i := range m.supportedSignatureAlgorithms {
  1813. m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
  1814. data = data[2:]
  1815. }
  1816. }
  1817. if len(data) < 2 {
  1818. return alertDecodeError
  1819. }
  1820. casLength := uint16(data[0])<<8 | uint16(data[1])
  1821. data = data[2:]
  1822. if len(data) < int(casLength) {
  1823. return alertDecodeError
  1824. }
  1825. cas := make([]byte, casLength)
  1826. copy(cas, data)
  1827. data = data[casLength:]
  1828. m.certificateAuthorities = nil
  1829. for len(cas) > 0 {
  1830. if len(cas) < 2 {
  1831. return alertDecodeError
  1832. }
  1833. caLen := uint16(cas[0])<<8 | uint16(cas[1])
  1834. cas = cas[2:]
  1835. if len(cas) < int(caLen) {
  1836. return alertDecodeError
  1837. }
  1838. m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
  1839. cas = cas[caLen:]
  1840. }
  1841. if len(data) != 0 {
  1842. return alertDecodeError
  1843. }
  1844. return alertSuccess
  1845. }
  1846. type certificateVerifyMsg struct {
  1847. raw []byte
  1848. hasSignatureAndHash bool
  1849. signatureAlgorithm SignatureScheme
  1850. signature []byte
  1851. }
  1852. func (m *certificateVerifyMsg) equal(i interface{}) bool {
  1853. m1, ok := i.(*certificateVerifyMsg)
  1854. if !ok {
  1855. return false
  1856. }
  1857. return bytes.Equal(m.raw, m1.raw) &&
  1858. m.hasSignatureAndHash == m1.hasSignatureAndHash &&
  1859. m.signatureAlgorithm == m1.signatureAlgorithm &&
  1860. bytes.Equal(m.signature, m1.signature)
  1861. }
  1862. func (m *certificateVerifyMsg) marshal() (x []byte) {
  1863. if m.raw != nil {
  1864. return m.raw
  1865. }
  1866. // See http://tools.ietf.org/html/rfc4346#section-7.4.8
  1867. siglength := len(m.signature)
  1868. length := 2 + siglength
  1869. if m.hasSignatureAndHash {
  1870. length += 2
  1871. }
  1872. x = make([]byte, 4+length)
  1873. x[0] = typeCertificateVerify
  1874. x[1] = uint8(length >> 16)
  1875. x[2] = uint8(length >> 8)
  1876. x[3] = uint8(length)
  1877. y := x[4:]
  1878. if m.hasSignatureAndHash {
  1879. y[0] = uint8(m.signatureAlgorithm >> 8)
  1880. y[1] = uint8(m.signatureAlgorithm)
  1881. y = y[2:]
  1882. }
  1883. y[0] = uint8(siglength >> 8)
  1884. y[1] = uint8(siglength)
  1885. copy(y[2:], m.signature)
  1886. m.raw = x
  1887. return
  1888. }
  1889. func (m *certificateVerifyMsg) unmarshal(data []byte) alert {
  1890. m.raw = data
  1891. if len(data) < 6 {
  1892. return alertDecodeError
  1893. }
  1894. length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1895. if uint32(len(data))-4 != length {
  1896. return alertDecodeError
  1897. }
  1898. data = data[4:]
  1899. if m.hasSignatureAndHash {
  1900. m.signatureAlgorithm = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
  1901. data = data[2:]
  1902. }
  1903. if len(data) < 2 {
  1904. return alertDecodeError
  1905. }
  1906. siglength := int(data[0])<<8 + int(data[1])
  1907. data = data[2:]
  1908. if len(data) != siglength {
  1909. return alertDecodeError
  1910. }
  1911. m.signature = data
  1912. return alertSuccess
  1913. }
  1914. type newSessionTicketMsg struct {
  1915. raw []byte
  1916. ticket []byte
  1917. }
  1918. func (m *newSessionTicketMsg) equal(i interface{}) bool {
  1919. m1, ok := i.(*newSessionTicketMsg)
  1920. if !ok {
  1921. return false
  1922. }
  1923. return bytes.Equal(m.raw, m1.raw) &&
  1924. bytes.Equal(m.ticket, m1.ticket)
  1925. }
  1926. func (m *newSessionTicketMsg) marshal() (x []byte) {
  1927. if m.raw != nil {
  1928. return m.raw
  1929. }
  1930. // See http://tools.ietf.org/html/rfc5077#section-3.3
  1931. ticketLen := len(m.ticket)
  1932. length := 2 + 4 + ticketLen
  1933. x = make([]byte, 4+length)
  1934. x[0] = typeNewSessionTicket
  1935. x[1] = uint8(length >> 16)
  1936. x[2] = uint8(length >> 8)
  1937. x[3] = uint8(length)
  1938. x[8] = uint8(ticketLen >> 8)
  1939. x[9] = uint8(ticketLen)
  1940. copy(x[10:], m.ticket)
  1941. m.raw = x
  1942. return
  1943. }
  1944. func (m *newSessionTicketMsg) unmarshal(data []byte) alert {
  1945. m.raw = data
  1946. if len(data) < 10 {
  1947. return alertDecodeError
  1948. }
  1949. length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1950. if uint32(len(data))-4 != length {
  1951. return alertDecodeError
  1952. }
  1953. ticketLen := int(data[8])<<8 + int(data[9])
  1954. if len(data)-10 != ticketLen {
  1955. return alertDecodeError
  1956. }
  1957. m.ticket = data[10:]
  1958. return alertSuccess
  1959. }
  1960. type newSessionTicketMsg13 struct {
  1961. raw []byte
  1962. lifetime uint32
  1963. ageAdd uint32
  1964. nonce []byte
  1965. ticket []byte
  1966. withEarlyDataInfo bool
  1967. maxEarlyDataLength uint32
  1968. }
  1969. func (m *newSessionTicketMsg13) equal(i interface{}) bool {
  1970. m1, ok := i.(*newSessionTicketMsg13)
  1971. if !ok {
  1972. return false
  1973. }
  1974. return bytes.Equal(m.raw, m1.raw) &&
  1975. m.lifetime == m1.lifetime &&
  1976. m.ageAdd == m1.ageAdd &&
  1977. bytes.Equal(m.nonce, m1.nonce) &&
  1978. bytes.Equal(m.ticket, m1.ticket) &&
  1979. m.withEarlyDataInfo == m1.withEarlyDataInfo &&
  1980. m.maxEarlyDataLength == m1.maxEarlyDataLength
  1981. }
  1982. func (m *newSessionTicketMsg13) marshal() (x []byte) {
  1983. if m.raw != nil {
  1984. return m.raw
  1985. }
  1986. // See https://tools.ietf.org/html/draft-ietf-tls-tls13-21#section-4.6.1
  1987. nonceLen := len(m.nonce)
  1988. ticketLen := len(m.ticket)
  1989. length := 13 + nonceLen + ticketLen
  1990. if m.withEarlyDataInfo {
  1991. length += 8
  1992. }
  1993. x = make([]byte, 4+length)
  1994. x[0] = typeNewSessionTicket
  1995. x[1] = uint8(length >> 16)
  1996. x[2] = uint8(length >> 8)
  1997. x[3] = uint8(length)
  1998. x[4] = uint8(m.lifetime >> 24)
  1999. x[5] = uint8(m.lifetime >> 16)
  2000. x[6] = uint8(m.lifetime >> 8)
  2001. x[7] = uint8(m.lifetime)
  2002. x[8] = uint8(m.ageAdd >> 24)
  2003. x[9] = uint8(m.ageAdd >> 16)
  2004. x[10] = uint8(m.ageAdd >> 8)
  2005. x[11] = uint8(m.ageAdd)
  2006. x[12] = uint8(nonceLen)
  2007. copy(x[13:13+nonceLen], m.nonce)
  2008. y := x[13+nonceLen:]
  2009. y[0] = uint8(ticketLen >> 8)
  2010. y[1] = uint8(ticketLen)
  2011. copy(y[2:2+ticketLen], m.ticket)
  2012. if m.withEarlyDataInfo {
  2013. z := y[2+ticketLen:]
  2014. // z[0] is already 0, this is the extensions vector length.
  2015. z[1] = 8
  2016. z[2] = uint8(extensionEarlyData >> 8)
  2017. z[3] = uint8(extensionEarlyData)
  2018. z[5] = 4
  2019. z[6] = uint8(m.maxEarlyDataLength >> 24)
  2020. z[7] = uint8(m.maxEarlyDataLength >> 16)
  2021. z[8] = uint8(m.maxEarlyDataLength >> 8)
  2022. z[9] = uint8(m.maxEarlyDataLength)
  2023. }
  2024. m.raw = x
  2025. return
  2026. }
  2027. func (m *newSessionTicketMsg13) unmarshal(data []byte) alert {
  2028. m.raw = data
  2029. m.maxEarlyDataLength = 0
  2030. m.withEarlyDataInfo = false
  2031. if len(data) < 17 {
  2032. return alertDecodeError
  2033. }
  2034. length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  2035. if uint32(len(data))-4 != length {
  2036. return alertDecodeError
  2037. }
  2038. m.lifetime = uint32(data[4])<<24 | uint32(data[5])<<16 |
  2039. uint32(data[6])<<8 | uint32(data[7])
  2040. m.ageAdd = uint32(data[8])<<24 | uint32(data[9])<<16 |
  2041. uint32(data[10])<<8 | uint32(data[11])
  2042. nonceLen := int(data[12])
  2043. if nonceLen == 0 || 13+nonceLen+2 > len(data) {
  2044. return alertDecodeError
  2045. }
  2046. m.nonce = data[13 : 13+nonceLen]
  2047. data = data[13+nonceLen:]
  2048. ticketLen := int(data[0])<<8 + int(data[1])
  2049. if ticketLen == 0 || 2+ticketLen+2 > len(data) {
  2050. return alertDecodeError
  2051. }
  2052. m.ticket = data[2 : 2+ticketLen]
  2053. data = data[2+ticketLen:]
  2054. extLen := int(data[0])<<8 + int(data[1])
  2055. if extLen != len(data)-2 {
  2056. return alertDecodeError
  2057. }
  2058. data = data[2:]
  2059. for len(data) > 0 {
  2060. if len(data) < 4 {
  2061. return alertDecodeError
  2062. }
  2063. extType := uint16(data[0])<<8 + uint16(data[1])
  2064. length := int(data[2])<<8 + int(data[3])
  2065. data = data[4:]
  2066. switch extType {
  2067. case extensionEarlyData:
  2068. if length != 4 {
  2069. return alertDecodeError
  2070. }
  2071. m.withEarlyDataInfo = true
  2072. m.maxEarlyDataLength = uint32(data[0])<<24 | uint32(data[1])<<16 |
  2073. uint32(data[2])<<8 | uint32(data[3])
  2074. }
  2075. data = data[length:]
  2076. }
  2077. return alertSuccess
  2078. }
  2079. type endOfEarlyDataMsg struct {
  2080. }
  2081. func (*endOfEarlyDataMsg) marshal() []byte {
  2082. return []byte{typeEndOfEarlyData, 0, 0, 0}
  2083. }
  2084. func (*endOfEarlyDataMsg) unmarshal(data []byte) alert {
  2085. if len(data) != 4 {
  2086. return alertDecodeError
  2087. }
  2088. return alertSuccess
  2089. }
  2090. type helloRequestMsg struct {
  2091. }
  2092. func (*helloRequestMsg) marshal() []byte {
  2093. return []byte{typeHelloRequest, 0, 0, 0}
  2094. }
  2095. func (*helloRequestMsg) unmarshal(data []byte) alert {
  2096. if len(data) != 4 {
  2097. return alertDecodeError
  2098. }
  2099. return alertSuccess
  2100. }
  2101. func eqUint16s(x, y []uint16) bool {
  2102. if len(x) != len(y) {
  2103. return false
  2104. }
  2105. for i, v := range x {
  2106. if y[i] != v {
  2107. return false
  2108. }
  2109. }
  2110. return true
  2111. }
  2112. func eqCurveIDs(x, y []CurveID) bool {
  2113. if len(x) != len(y) {
  2114. return false
  2115. }
  2116. for i, v := range x {
  2117. if y[i] != v {
  2118. return false
  2119. }
  2120. }
  2121. return true
  2122. }
  2123. func eqStrings(x, y []string) bool {
  2124. if len(x) != len(y) {
  2125. return false
  2126. }
  2127. for i, v := range x {
  2128. if y[i] != v {
  2129. return false
  2130. }
  2131. }
  2132. return true
  2133. }
  2134. func eqByteSlices(x, y [][]byte) bool {
  2135. if len(x) != len(y) {
  2136. return false
  2137. }
  2138. for i, v := range x {
  2139. if !bytes.Equal(v, y[i]) {
  2140. return false
  2141. }
  2142. }
  2143. return true
  2144. }
  2145. func eqSignatureAlgorithms(x, y []SignatureScheme) bool {
  2146. if len(x) != len(y) {
  2147. return false
  2148. }
  2149. for i, v := range x {
  2150. if v != y[i] {
  2151. return false
  2152. }
  2153. }
  2154. return true
  2155. }
  2156. func eqKeyShares(x, y []keyShare) bool {
  2157. if len(x) != len(y) {
  2158. return false
  2159. }
  2160. for i := range x {
  2161. if x[i].group != y[i].group {
  2162. return false
  2163. }
  2164. if !bytes.Equal(x[i].data, y[i].data) {
  2165. return false
  2166. }
  2167. }
  2168. return true
  2169. }
  2170. func findExtension(data []byte, extensionType uint16) []byte {
  2171. for len(data) != 0 {
  2172. if len(data) < 4 {
  2173. return nil
  2174. }
  2175. extension := uint16(data[0])<<8 | uint16(data[1])
  2176. length := int(data[2])<<8 | int(data[3])
  2177. data = data[4:]
  2178. if len(data) < length {
  2179. return nil
  2180. }
  2181. if extension == extensionType {
  2182. return data[:length]
  2183. }
  2184. data = data[length:]
  2185. }
  2186. return nil
  2187. }