Alternative TLS implementation in Go


// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package tls

import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"crypto/hmac"
	"crypto/sha256"
	"crypto/subtle"
	"errors"
	"io"
)

// sessionState contains the information that is serialized into a session
// ticket in order to later resume a connection.
type sessionState struct {
	vers         uint16
	cipherSuite  uint16
	masterSecret []byte
	certificates [][]byte
	// usedOldKey is true if the ticket from which this session came
	// was encrypted with an older key and thus should be refreshed.
	usedOldKey bool
}

// equal reports whether i is a *sessionState with the same version, cipher
// suite, master secret, and certificates. usedOldKey is not compared.
func (s *sessionState) equal(i interface{}) bool {
	s1, ok := i.(*sessionState)
	if !ok {
		return false
	}

	if s.vers != s1.vers ||
		s.cipherSuite != s1.cipherSuite ||
		!bytes.Equal(s.masterSecret, s1.masterSecret) {
		return false
	}

	if len(s.certificates) != len(s1.certificates) {
		return false
	}

	for i := range s.certificates {
		if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
			return false
		}
	}

	return true
}

// marshal serializes s in big-endian order: version (2 bytes), cipher suite
// (2 bytes), master secret with a 2-byte length prefix, a 2-byte certificate
// count, then each certificate with a 4-byte length prefix.
func (s *sessionState) marshal() []byte {
	length := 2 + 2 + 2 + len(s.masterSecret) + 2
	for _, cert := range s.certificates {
		length += 4 + len(cert)
	}

	ret := make([]byte, length)
	x := ret
	x[0] = byte(s.vers >> 8)
	x[1] = byte(s.vers)
	x[2] = byte(s.cipherSuite >> 8)
	x[3] = byte(s.cipherSuite)
	x[4] = byte(len(s.masterSecret) >> 8)
	x[5] = byte(len(s.masterSecret))
	x = x[6:]
	copy(x, s.masterSecret)
	x = x[len(s.masterSecret):]

	x[0] = byte(len(s.certificates) >> 8)
	x[1] = byte(len(s.certificates))
	x = x[2:]

	for _, cert := range s.certificates {
		x[0] = byte(len(cert) >> 24)
		x[1] = byte(len(cert) >> 16)
		x[2] = byte(len(cert) >> 8)
		x[3] = byte(len(cert))
		copy(x[4:], cert)
		x = x[4+len(cert):]
	}

	return ret
}

// unmarshal parses data in the layout written by marshal and reports whether
// it was well formed. The master secret and certificates alias data rather
// than copying it.
func (s *sessionState) unmarshal(data []byte) bool {
	if len(data) < 8 {
		return false
	}

	s.vers = uint16(data[0])<<8 | uint16(data[1])
	s.cipherSuite = uint16(data[2])<<8 | uint16(data[3])
	masterSecretLen := int(data[4])<<8 | int(data[5])
	data = data[6:]
	if len(data) < masterSecretLen {
		return false
	}

	s.masterSecret = data[:masterSecretLen]
	data = data[masterSecretLen:]

	if len(data) < 2 {
		return false
	}

	numCerts := int(data[0])<<8 | int(data[1])
	data = data[2:]

	s.certificates = make([][]byte, numCerts)
	for i := range s.certificates {
		if len(data) < 4 {
			return false
		}
		certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
		data = data[4:]
		// certLen can come out negative where int is 32 bits wide.
		if certLen < 0 {
			return false
		}
		if len(data) < certLen {
			return false
		}
		s.certificates[i] = data[:certLen]
		data = data[certLen:]
	}

	// Trailing bytes make the input invalid.
	return len(data) == 0
}

// sessionState13 holds the resumption state that is serialized for
// TLS 1.3 sessions.
type sessionState13 struct {
	vers             uint16
	suite            uint16
	ageAdd           uint32
	createdAt        uint64
	maxEarlyDataLen  uint32
	resumptionSecret []byte
	alpnProtocol     string
	SNI              string
}

func (s *sessionState13) equal(i interface{}) bool {
	s1, ok := i.(*sessionState13)
	if !ok {
		return false
	}

	return s.vers == s1.vers &&
		s.suite == s1.suite &&
		s.ageAdd == s1.ageAdd &&
		s.createdAt == s1.createdAt &&
		s.maxEarlyDataLen == s1.maxEarlyDataLen &&
		bytes.Equal(s.resumptionSecret, s1.resumptionSecret) &&
		s.alpnProtocol == s1.alpnProtocol &&
		s.SNI == s1.SNI
}

// marshal serializes s in big-endian order: a fixed 20-byte header (vers,
// suite, ageAdd, createdAt, maxEarlyDataLen) followed by resumptionSecret,
// alpnProtocol, and SNI, each with a 2-byte length prefix.
func (s *sessionState13) marshal() []byte {
	length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI)

	x := make([]byte, length)
	x[0] = byte(s.vers >> 8)
	x[1] = byte(s.vers)
	x[2] = byte(s.suite >> 8)
	x[3] = byte(s.suite)
	x[4] = byte(s.ageAdd >> 24)
	x[5] = byte(s.ageAdd >> 16)
	x[6] = byte(s.ageAdd >> 8)
	x[7] = byte(s.ageAdd)
	x[8] = byte(s.createdAt >> 56)
	x[9] = byte(s.createdAt >> 48)
	x[10] = byte(s.createdAt >> 40)
	x[11] = byte(s.createdAt >> 32)
	x[12] = byte(s.createdAt >> 24)
	x[13] = byte(s.createdAt >> 16)
	x[14] = byte(s.createdAt >> 8)
	x[15] = byte(s.createdAt)
	x[16] = byte(s.maxEarlyDataLen >> 24)
	x[17] = byte(s.maxEarlyDataLen >> 16)
	x[18] = byte(s.maxEarlyDataLen >> 8)
	x[19] = byte(s.maxEarlyDataLen)
	x[20] = byte(len(s.resumptionSecret) >> 8)
	x[21] = byte(len(s.resumptionSecret))
	copy(x[22:], s.resumptionSecret)
	z := x[22+len(s.resumptionSecret):]
	z[0] = byte(len(s.alpnProtocol) >> 8)
	z[1] = byte(len(s.alpnProtocol))
	copy(z[2:], s.alpnProtocol)
	z = z[2+len(s.alpnProtocol):]
	z[0] = byte(len(s.SNI) >> 8)
	z[1] = byte(len(s.SNI))
	copy(z[2:], s.SNI)

	return x
}

// unmarshal parses data in the layout written by marshal and reports whether
// it was well formed. resumptionSecret aliases data; the strings are copies.
func (s *sessionState13) unmarshal(data []byte) bool {
	if len(data) < 24 {
		return false
	}

	s.vers = uint16(data[0])<<8 | uint16(data[1])
	s.suite = uint16(data[2])<<8 | uint16(data[3])
	s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
	s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 |
		uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15])
	s.maxEarlyDataLen = uint32(data[16])<<24 | uint32(data[17])<<16 | uint32(data[18])<<8 | uint32(data[19])

	// Each length check also reserves the 2 bytes of the next length prefix.
	l := int(data[20])<<8 | int(data[21])
	if len(data) < 22+l+2 {
		return false
	}
	s.resumptionSecret = data[22 : 22+l]
	z := data[22+l:]

	l = int(z[0])<<8 | int(z[1])
	if len(z) < 2+l+2 {
		return false
	}
	s.alpnProtocol = string(z[2 : 2+l])
	z = z[2+l:]

	// SNI is the last field, so it must consume the input exactly.
	l = int(z[0])<<8 | int(z[1])
	if len(z) != 2+l {
		return false
	}
	s.SNI = string(z[2 : 2+l])

	return true
}

// encryptTicket encrypts and authenticates serialized with the first ticket
// key. The result is laid out as key name, IV, AES-CTR ciphertext, and an
// HMAC-SHA256 tag that covers everything preceding it.
func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) {
	encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(serialized)+sha256.Size)
	keyName := encrypted[:ticketKeyNameLen]
	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
	macBytes := encrypted[len(encrypted)-sha256.Size:]

	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
		return nil, err
	}
	key := c.config.ticketKeys()[0]
	copy(keyName, key.keyName[:])
	block, err := aes.NewCipher(key.aesKey[:])
	if err != nil {
		return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
	}
	cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], serialized)

	mac := hmac.New(sha256.New, key.hmacKey[:])
	mac.Write(encrypted[:len(encrypted)-sha256.Size])
	// Sum appends the tag to the zero-length prefix of macBytes, writing it in place.
	mac.Sum(macBytes[:0])

	return encrypted, nil
}

// decryptTicket verifies and decrypts a ticket produced by encryptTicket.
// It returns nil if tickets are disabled, the key name is unknown, or the MAC
// does not verify. usedOldKey is true if a key other than the first one
// matched. Decryption happens in place, so the returned plaintext overwrites
// part of encrypted.
func (c *Conn) decryptTicket(encrypted []byte) (serialized []byte, usedOldKey bool) {
	if c.config.SessionTicketsDisabled ||
		len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
		return nil, false
	}

	keyName := encrypted[:ticketKeyNameLen]
	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
	macBytes := encrypted[len(encrypted)-sha256.Size:]

	// Select the key whose name matches the ticket's key name.
	keys := c.config.ticketKeys()
	keyIndex := -1
	for i, candidateKey := range keys {
		if bytes.Equal(keyName, candidateKey.keyName[:]) {
			keyIndex = i
			break
		}
	}

	if keyIndex == -1 {
		return nil, false
	}
	key := &keys[keyIndex]

	// Check the MAC in constant time before touching the ciphertext.
	mac := hmac.New(sha256.New, key.hmacKey[:])
	mac.Write(encrypted[:len(encrypted)-sha256.Size])
	expected := mac.Sum(nil)

	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
		return nil, false
	}

	block, err := aes.NewCipher(key.aesKey[:])
	if err != nil {
		return nil, false
	}
	ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
	plaintext := ciphertext
	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)

	return plaintext, keyIndex > 0
}