Alternative TLS implementation in Go
Du kannst nicht mehr als 25 Themen auswählen Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

548 Zeilen
16 KiB

  1. package tls
  2. import (
  3. "bytes"
  4. "crypto"
  5. "crypto/ecdsa"
  6. "crypto/elliptic"
  7. "crypto/hmac"
  8. "crypto/rsa"
  9. "crypto/subtle"
  10. "encoding/hex"
  11. "errors"
  12. "fmt"
  13. "hash"
  14. "io"
  15. "os"
  16. "runtime/debug"
  17. "time"
  18. "golang_org/x/crypto/curve25519"
  19. )
  20. func (hs *serverHandshakeState) doTLS13Handshake() error {
  21. config := hs.c.config
  22. c := hs.c
  23. hs.c.cipherSuite, hs.hello13.cipherSuite = hs.suite.id, hs.suite.id
  24. hs.c.clientHello = hs.clientHello.marshal()
  25. // When picking the group for the handshake, priority is given to groups
  26. // that the client provided a keyShare for, so to avoid a round-trip.
  27. // After that the order of CurvePreferences is respected.
  28. var ks keyShare
  29. for _, curveID := range config.curvePreferences() {
  30. for _, keyShare := range hs.clientHello.keyShares {
  31. if curveID == keyShare.group {
  32. ks = keyShare
  33. break
  34. }
  35. }
  36. }
  37. if ks.group == 0 {
  38. c.sendAlert(alertInternalError)
  39. return errors.New("tls: HelloRetryRequest not implemented") // TODO(filippo)
  40. }
  41. privateKey, serverKS, err := config.generateKeyShare(ks.group)
  42. if err != nil {
  43. c.sendAlert(alertInternalError)
  44. return err
  45. }
  46. hs.hello13.keyShare = serverKS
  47. hash := hashForSuite(hs.suite)
  48. hashSize := hash.Size()
  49. earlySecret, isPSK := hs.checkPSK()
  50. if !isPSK {
  51. earlySecret = hkdfExtract(hash, nil, nil)
  52. }
  53. c.didResume = isPSK
  54. hs.finishedHash13 = hash.New()
  55. hs.finishedHash13.Write(hs.clientHello.marshal())
  56. handshakeCtx := hs.finishedHash13.Sum(nil)
  57. earlyClientCipher, _ := hs.prepareCipher(handshakeCtx, earlySecret, "client early traffic secret")
  58. ecdheSecret := deriveECDHESecret(ks, privateKey)
  59. if ecdheSecret == nil {
  60. c.sendAlert(alertIllegalParameter)
  61. return errors.New("tls: bad ECDHE client share")
  62. }
  63. hs.finishedHash13.Write(hs.hello13.marshal())
  64. if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil {
  65. return err
  66. }
  67. handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret)
  68. handshakeCtx = hs.finishedHash13.Sum(nil)
  69. clientCipher, cTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "client handshake traffic secret")
  70. hs.hsClientCipher = clientCipher
  71. serverCipher, sTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "server handshake traffic secret")
  72. c.out.setCipher(c.vers, serverCipher)
  73. serverFinishedKey := hkdfExpandLabel(hash, sTrafficSecret, nil, "finished", hashSize)
  74. hs.clientFinishedKey = hkdfExpandLabel(hash, cTrafficSecret, nil, "finished", hashSize)
  75. hs.finishedHash13.Write(hs.hello13Enc.marshal())
  76. if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil {
  77. return err
  78. }
  79. if !isPSK {
  80. if err := hs.sendCertificate13(); err != nil {
  81. return err
  82. }
  83. }
  84. verifyData := hmacOfSum(hash, hs.finishedHash13, serverFinishedKey)
  85. serverFinished := &finishedMsg{
  86. verifyData: verifyData,
  87. }
  88. hs.finishedHash13.Write(serverFinished.marshal())
  89. if _, err := c.writeRecord(recordTypeHandshake, serverFinished.marshal()); err != nil {
  90. return err
  91. }
  92. hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret)
  93. handshakeCtx = hs.finishedHash13.Sum(nil)
  94. hs.appClientCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "client application traffic secret")
  95. serverCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "server application traffic secret")
  96. c.out.setCipher(c.vers, serverCipher)
  97. if hs.hello13Enc.earlyData {
  98. c.in.setCipher(c.vers, earlyClientCipher)
  99. c.phase = readingEarlyData
  100. } else if hs.clientHello.earlyData {
  101. c.in.setCipher(c.vers, hs.hsClientCipher)
  102. c.phase = discardingEarlyData
  103. } else {
  104. c.in.setCipher(c.vers, hs.hsClientCipher)
  105. c.phase = waitingClientFinished
  106. }
  107. return nil
  108. }
  109. // readClientFinished13 is called when, on the second flight of the client,
  110. // a handshake message is received. This might be immediately or after the
  111. // early data. Once done it sends the session tickets. Under c.in lock.
  112. func (hs *serverHandshakeState) readClientFinished13() error {
  113. c := hs.c
  114. c.phase = readingClientFinished
  115. msg, err := c.readHandshake()
  116. if err != nil {
  117. return err
  118. }
  119. clientFinished, ok := msg.(*finishedMsg)
  120. if !ok {
  121. c.sendAlert(alertUnexpectedMessage)
  122. return unexpectedMessageError(clientFinished, msg)
  123. }
  124. hash := hashForSuite(hs.suite)
  125. expectedVerifyData := hmacOfSum(hash, hs.finishedHash13, hs.clientFinishedKey)
  126. if len(expectedVerifyData) != len(clientFinished.verifyData) ||
  127. subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 {
  128. c.sendAlert(alertHandshakeFailure)
  129. return errors.New("tls: client's Finished message is incorrect")
  130. }
  131. hs.finishedHash13.Write(clientFinished.marshal())
  132. c.hs = nil // Discard the server handshake state
  133. c.phase = handshakeConfirmed
  134. c.in.setCipher(c.vers, hs.appClientCipher)
  135. c.in.traceErr, c.out.traceErr = nil, nil
  136. return hs.sendSessionTicket13()
  137. }
  138. func (hs *serverHandshakeState) sendCertificate13() error {
  139. c := hs.c
  140. certMsg := &certificateMsg13{
  141. certificates: hs.cert.Certificate,
  142. }
  143. hs.finishedHash13.Write(certMsg.marshal())
  144. if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
  145. return err
  146. }
  147. sigScheme, err := hs.selectTLS13SignatureScheme()
  148. if err != nil {
  149. c.sendAlert(alertInternalError)
  150. return err
  151. }
  152. sigHash := hashForSignatureScheme(sigScheme)
  153. opts := crypto.SignerOpts(sigHash)
  154. if signatureSchemeIsPSS(sigScheme) {
  155. opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
  156. }
  157. toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.finishedHash13.Sum(nil))
  158. signature, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), toSign[:], opts)
  159. if err != nil {
  160. c.sendAlert(alertInternalError)
  161. return err
  162. }
  163. verifyMsg := &certificateVerifyMsg{
  164. hasSignatureAndHash: true,
  165. signatureAndHash: sigSchemeToSigAndHash(sigScheme),
  166. signature: signature,
  167. }
  168. hs.finishedHash13.Write(verifyMsg.marshal())
  169. if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil {
  170. return err
  171. }
  172. return nil
  173. }
  174. func (c *Conn) handleEndOfEarlyData() {
  175. if c.phase != readingEarlyData || c.vers < VersionTLS13 {
  176. c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
  177. return
  178. }
  179. c.phase = waitingClientFinished
  180. c.in.setCipher(c.vers, c.hs.hsClientCipher)
  181. }
  182. // selectTLS13SignatureScheme chooses the SignatureScheme for the CertificateVerify
  183. // based on the certificate type and client supported schemes. If no overlap is found,
  184. // a fallback is selected.
  185. //
  186. // See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.4.1.2
  187. func (hs *serverHandshakeState) selectTLS13SignatureScheme() (sigScheme SignatureScheme, err error) {
  188. var supportedSchemes []SignatureScheme
  189. signer, ok := hs.cert.PrivateKey.(crypto.Signer)
  190. if !ok {
  191. return 0, errors.New("tls: certificate private key does not implement crypto.Signer")
  192. }
  193. pk := signer.Public()
  194. if _, ok := pk.(*rsa.PublicKey); ok {
  195. sigScheme = PSSWithSHA256
  196. supportedSchemes = []SignatureScheme{PSSWithSHA256, PSSWithSHA384, PSSWithSHA512}
  197. } else if pk, ok := pk.(*ecdsa.PublicKey); ok {
  198. switch pk.Curve {
  199. case elliptic.P256():
  200. sigScheme = ECDSAWithP256AndSHA256
  201. supportedSchemes = []SignatureScheme{ECDSAWithP256AndSHA256}
  202. case elliptic.P384():
  203. sigScheme = ECDSAWithP384AndSHA384
  204. supportedSchemes = []SignatureScheme{ECDSAWithP384AndSHA384}
  205. case elliptic.P521():
  206. sigScheme = ECDSAWithP521AndSHA512
  207. supportedSchemes = []SignatureScheme{ECDSAWithP521AndSHA512}
  208. default:
  209. return 0, errors.New("tls: unknown ECDSA certificate curve")
  210. }
  211. } else {
  212. return 0, errors.New("tls: unknown certificate key type")
  213. }
  214. for _, ss := range supportedSchemes {
  215. for _, cs := range hs.clientHello.signatureAndHashes {
  216. if ss == sigAndHashToSigScheme(cs) {
  217. return ss, nil
  218. }
  219. }
  220. }
  221. return sigScheme, nil
  222. }
  223. func sigSchemeToSigAndHash(s SignatureScheme) (sah signatureAndHash) {
  224. sah.hash = byte(s >> 8)
  225. sah.signature = byte(s)
  226. return
  227. }
  228. func sigAndHashToSigScheme(sah signatureAndHash) SignatureScheme {
  229. return SignatureScheme(sah.hash)<<8 | SignatureScheme(sah.signature)
  230. }
  231. func signatureSchemeIsPSS(s SignatureScheme) bool {
  232. return s == PSSWithSHA256 || s == PSSWithSHA384 || s == PSSWithSHA512
  233. }
  234. // hashForSignatureScheme returns the Hash used by a SignatureScheme which is
  235. // supported by selectTLS13SignatureScheme.
  236. func hashForSignatureScheme(ss SignatureScheme) crypto.Hash {
  237. switch ss {
  238. case PSSWithSHA256, ECDSAWithP256AndSHA256:
  239. return crypto.SHA256
  240. case PSSWithSHA384, ECDSAWithP384AndSHA384:
  241. return crypto.SHA384
  242. case PSSWithSHA512, ECDSAWithP521AndSHA512:
  243. return crypto.SHA512
  244. default:
  245. panic("unsupported SignatureScheme passed to hashForSignatureScheme")
  246. }
  247. }
  248. func hashForSuite(suite *cipherSuite) crypto.Hash {
  249. if suite.flags&suiteSHA384 != 0 {
  250. return crypto.SHA384
  251. }
  252. return crypto.SHA256
  253. }
  // prepareDigitallySigned returns the digest, under hash, of the TLS 1.3
// signed-content layout: 64 bytes of 0x20 (space), the context string, a
// single 0x00 separator, then data. The parts are streamed into the hash in
// that order. The result is what gets passed to crypto.Signer.Sign.
func prepareDigitallySigned(hash crypto.Hash, context string, data []byte) []byte {
	h := hash.New()
	h.Write(bytes.Repeat([]byte{' '}, 64))
	h.Write([]byte(context))
	h.Write([]byte{0})
	h.Write(data)
	return h.Sum(nil)
}
  263. func (c *Config) generateKeyShare(curveID CurveID) ([]byte, keyShare, error) {
  264. if curveID == X25519 {
  265. var scalar, public [32]byte
  266. if _, err := io.ReadFull(c.rand(), scalar[:]); err != nil {
  267. return nil, keyShare{}, err
  268. }
  269. curve25519.ScalarBaseMult(&public, &scalar)
  270. return scalar[:], keyShare{group: curveID, data: public[:]}, nil
  271. }
  272. curve, ok := curveForCurveID(curveID)
  273. if !ok {
  274. return nil, keyShare{}, errors.New("tls: preferredCurves includes unsupported curve")
  275. }
  276. privateKey, x, y, err := elliptic.GenerateKey(curve, c.rand())
  277. if err != nil {
  278. return nil, keyShare{}, err
  279. }
  280. ecdhePublic := elliptic.Marshal(curve, x, y)
  281. return privateKey, keyShare{group: curveID, data: ecdhePublic}, nil
  282. }
  283. func deriveECDHESecret(ks keyShare, pk []byte) []byte {
  284. if ks.group == X25519 {
  285. if len(ks.data) != 32 {
  286. return nil
  287. }
  288. var theirPublic, sharedKey, scalar [32]byte
  289. copy(theirPublic[:], ks.data)
  290. copy(scalar[:], pk)
  291. curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
  292. return sharedKey[:]
  293. }
  294. curve, ok := curveForCurveID(ks.group)
  295. if !ok {
  296. return nil
  297. }
  298. x, y := elliptic.Unmarshal(curve, ks.data)
  299. if x == nil {
  300. return nil
  301. }
  302. x, _ = curve.ScalarMult(x, y, pk)
  303. xBytes := x.Bytes()
  304. curveSize := (curve.Params().BitSize + 8 - 1) >> 3
  305. if len(xBytes) == curveSize {
  306. return xBytes
  307. }
  308. buf := make([]byte, curveSize)
  309. copy(buf[len(buf)-len(xBytes):], xBytes)
  310. return buf
  311. }
  312. func hkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte {
  313. hkdfLabel := make([]byte, 4+len("TLS 1.3, ")+len(label)+len(hashValue))
  314. hkdfLabel[0] = byte(L >> 8)
  315. hkdfLabel[1] = byte(L)
  316. hkdfLabel[2] = byte(len("TLS 1.3, ") + len(label))
  317. copy(hkdfLabel[3:], "TLS 1.3, ")
  318. z := hkdfLabel[3+len("TLS 1.3, "):]
  319. copy(z, label)
  320. z = z[len(label):]
  321. z[0] = byte(len(hashValue))
  322. copy(z[1:], hashValue)
  323. return hkdfExpand(hash, secret, hkdfLabel, L)
  324. }
  // hmacOfSum returns HMAC-f(key, d), where d is the current digest of the
// running hash. hash.Sum does not change the hash state, so callers can keep
// writing to the transcript afterwards.
func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte {
	transcript := hash.Sum(nil)
	mac := hmac.New(f.New, key)
	mac.Write(transcript)
	return mac.Sum(nil)
}
  330. func (hs *serverHandshakeState) prepareCipher(handshakeCtx, secret []byte, label string) (interface{}, []byte) {
  331. hash := hashForSuite(hs.suite)
  332. trafficSecret := hkdfExpandLabel(hash, secret, handshakeCtx, label, hash.Size())
  333. key := hkdfExpandLabel(hash, trafficSecret, nil, "key", hs.suite.keyLen)
  334. iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", 12)
  335. return hs.suite.aead(key, iv), trafficSecret
  336. }
  // ticketAgeSkewAllowance is the largest tolerated difference, in either
// direction, between two ages: the ticket age the client reports and the age
// the server computes from the ticket's creation time. checkPSK skips tickets
// outside this window. See
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2.
const ticketAgeSkewAllowance time.Duration = 10 * time.Second
  341. func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
  342. if hs.c.config.SessionTicketsDisabled {
  343. return nil, false
  344. }
  345. foundDHE := false
  346. for _, mode := range hs.clientHello.pskKeyExchangeModes {
  347. if mode == pskDHEKeyExchange {
  348. foundDHE = true
  349. break
  350. }
  351. }
  352. if !foundDHE {
  353. return nil, false
  354. }
  355. hash := hashForSuite(hs.suite)
  356. hashSize := hash.Size()
  357. for i := range hs.clientHello.psks {
  358. sessionTicket := append([]uint8{}, hs.clientHello.psks[i].identity...)
  359. serializedTicket, _ := hs.c.decryptTicket(sessionTicket)
  360. if serializedTicket == nil {
  361. continue
  362. }
  363. s := &sessionState13{}
  364. if ok := s.unmarshal(serializedTicket); !ok {
  365. continue
  366. }
  367. if s.vers != hs.c.vers {
  368. continue
  369. }
  370. clientAge := time.Duration(hs.clientHello.psks[i].obfTicketAge-s.ageAdd) * time.Millisecond
  371. serverAge := time.Since(time.Unix(int64(s.createdAt), 0))
  372. if clientAge-serverAge > ticketAgeSkewAllowance || clientAge-serverAge < -ticketAgeSkewAllowance {
  373. continue
  374. }
  375. // This enforces the stricter 0-RTT requirements on all ticket uses.
  376. // The benefit of using PSK+ECDHE without 0-RTT are small enough that
  377. // we can give them up in the edge case of changed suite or ALPN or SNI.
  378. if s.suite != hs.suite.id {
  379. continue
  380. }
  381. if s.alpnProtocol != hs.c.clientProtocol {
  382. continue
  383. }
  384. if s.SNI != hs.c.serverName {
  385. continue
  386. }
  387. earlySecret := hkdfExtract(hash, s.resumptionSecret, nil)
  388. handshakeCtx := hash.New().Sum(nil)
  389. binderKey := hkdfExpandLabel(hash, earlySecret, handshakeCtx, "resumption psk binder key", hashSize)
  390. binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize)
  391. chHash := hash.New()
  392. chHash.Write(hs.clientHello.rawTruncated)
  393. expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)
  394. if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 {
  395. if i == 0 && hs.clientHello.earlyData {
  396. // This is a ticket intended to be used for 0-RTT
  397. if s.maxEarlyDataLen == 0 {
  398. // But we had not tagged it as such. We could close the connection
  399. // here, but instead we just ignore the ticket and the 0-RTT data.
  400. continue
  401. }
  402. if hs.c.config.Accept0RTTData {
  403. hs.c.ticketMaxEarlyData = int64(s.maxEarlyDataLen)
  404. hs.hello13Enc.earlyData = true
  405. }
  406. }
  407. hs.hello13.psk = true
  408. hs.hello13.pskIdentity = uint16(i)
  409. return earlySecret, true
  410. }
  411. }
  412. return nil, false
  413. }
  414. func (hs *serverHandshakeState) sendSessionTicket13() error {
  415. c := hs.c
  416. if c.config.SessionTicketsDisabled {
  417. return nil
  418. }
  419. foundDHE := false
  420. for _, mode := range hs.clientHello.pskKeyExchangeModes {
  421. if mode == pskDHEKeyExchange {
  422. foundDHE = true
  423. break
  424. }
  425. }
  426. if !foundDHE {
  427. return nil
  428. }
  429. hash := hashForSuite(hs.suite)
  430. handshakeCtx := hs.finishedHash13.Sum(nil)
  431. resumptionSecret := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "resumption master secret", hash.Size())
  432. ageAddBuf := make([]byte, 4)
  433. if _, err := io.ReadFull(c.config.rand(), ageAddBuf); err != nil {
  434. c.sendAlert(alertInternalError)
  435. return err
  436. }
  437. sessionState := &sessionState13{
  438. vers: c.vers,
  439. suite: hs.suite.id,
  440. ageAdd: uint32(ageAddBuf[0])<<24 | uint32(ageAddBuf[1])<<16 |
  441. uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3]),
  442. createdAt: uint64(time.Now().Unix()),
  443. resumptionSecret: resumptionSecret,
  444. alpnProtocol: c.clientProtocol,
  445. SNI: c.serverName,
  446. maxEarlyDataLen: c.config.Max0RTTDataSize,
  447. }
  448. ticket, err := c.encryptTicket(sessionState.marshal())
  449. if err != nil {
  450. c.sendAlert(alertInternalError)
  451. return err
  452. }
  453. ticketMsg := &newSessionTicketMsg13{
  454. lifetime: 24 * 3600, // TODO(filippo)
  455. maxEarlyDataLength: c.config.Max0RTTDataSize,
  456. withEarlyDataInfo: c.config.Max0RTTDataSize > 0,
  457. ageAdd: sessionState.ageAdd,
  458. ticket: ticket,
  459. }
  460. if _, err := c.writeRecord(recordTypeHandshake, ticketMsg.marshal()); err != nil {
  461. return err
  462. }
  463. return nil
  464. }
  465. func (hs *serverHandshakeState) traceErr(err error) {
  466. if err == nil {
  467. return
  468. }
  469. if os.Getenv("TLSDEBUG") == "error" {
  470. if hs != nil && hs.clientHello != nil {
  471. os.Stderr.WriteString(hex.Dump(hs.clientHello.marshal()))
  472. } else if err == io.EOF {
  473. return // don't stack trace on EOF before CH
  474. }
  475. fmt.Fprintf(os.Stderr, "\n%s\n", debug.Stack())
  476. }
  477. }