Alternative TLS implementation in Go
Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

334 lignes
8.6 KiB

  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package tls
  5. import (
  6. "bytes"
  7. "math/rand"
  8. "reflect"
  9. "strings"
  10. "testing"
  11. "testing/quick"
  12. )
  13. var tests = []interface{}{
  14. &clientHelloMsg{},
  15. &serverHelloMsg{},
  16. &finishedMsg{},
  17. &certificateMsg{},
  18. &certificateRequestMsg{},
  19. &certificateVerifyMsg{},
  20. &certificateStatusMsg{},
  21. &clientKeyExchangeMsg{},
  22. &nextProtoMsg{},
  23. &newSessionTicketMsg{},
  24. &sessionState{},
  25. }
  26. type testMessage interface {
  27. marshal() []byte
  28. unmarshal([]byte) bool
  29. equal(interface{}) bool
  30. }
  31. func TestMarshalUnmarshal(t *testing.T) {
  32. rand := rand.New(rand.NewSource(0))
  33. for i, iface := range tests {
  34. ty := reflect.ValueOf(iface).Type()
  35. n := 100
  36. if testing.Short() {
  37. n = 5
  38. }
  39. for j := 0; j < n; j++ {
  40. v, ok := quick.Value(ty, rand)
  41. if !ok {
  42. t.Errorf("#%d: failed to create value", i)
  43. break
  44. }
  45. m1 := v.Interface().(testMessage)
  46. marshaled := m1.marshal()
  47. m2 := iface.(testMessage)
  48. if !m2.unmarshal(marshaled) {
  49. t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
  50. break
  51. }
  52. m2.marshal() // to fill any marshal cache in the message
  53. if !m1.equal(m2) {
  54. t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
  55. break
  56. }
  57. if i >= 3 {
  58. // The first three message types (ClientHello,
  59. // ServerHello and Finished) are allowed to
  60. // have parsable prefixes because the extension
  61. // data is optional and the length of the
  62. // Finished varies across versions.
  63. for j := 0; j < len(marshaled); j++ {
  64. if m2.unmarshal(marshaled[0:j]) {
  65. t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
  66. break
  67. }
  68. }
  69. }
  70. }
  71. }
  72. }
  73. func TestFuzz(t *testing.T) {
  74. rand := rand.New(rand.NewSource(0))
  75. for _, iface := range tests {
  76. m := iface.(testMessage)
  77. for j := 0; j < 1000; j++ {
  78. len := rand.Intn(100)
  79. bytes := randomBytes(len, rand)
  80. // This just looks for crashes due to bounds errors etc.
  81. m.unmarshal(bytes)
  82. }
  83. }
  84. }
  85. func randomBytes(n int, rand *rand.Rand) []byte {
  86. r := make([]byte, n)
  87. for i := 0; i < n; i++ {
  88. r[i] = byte(rand.Int31())
  89. }
  90. return r
  91. }
  92. func randomString(n int, rand *rand.Rand) string {
  93. b := randomBytes(n, rand)
  94. return string(b)
  95. }
  96. func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  97. m := &clientHelloMsg{}
  98. m.vers = uint16(rand.Intn(65536))
  99. m.random = randomBytes(32, rand)
  100. m.sessionId = randomBytes(rand.Intn(32), rand)
  101. m.cipherSuites = make([]uint16, rand.Intn(63)+1)
  102. for i := 0; i < len(m.cipherSuites); i++ {
  103. cs := uint16(rand.Int31())
  104. if cs == scsvRenegotiation {
  105. cs += 1
  106. }
  107. m.cipherSuites[i] = cs
  108. }
  109. m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
  110. if rand.Intn(10) > 5 {
  111. m.nextProtoNeg = true
  112. }
  113. if rand.Intn(10) > 5 {
  114. m.serverName = randomString(rand.Intn(255), rand)
  115. for strings.HasSuffix(m.serverName, ".") {
  116. m.serverName = m.serverName[:len(m.serverName)-1]
  117. }
  118. }
  119. m.ocspStapling = rand.Intn(10) > 5
  120. m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
  121. m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
  122. for i := range m.supportedCurves {
  123. m.supportedCurves[i] = CurveID(rand.Intn(30000))
  124. }
  125. if rand.Intn(10) > 5 {
  126. m.ticketSupported = true
  127. if rand.Intn(10) > 5 {
  128. m.sessionTicket = randomBytes(rand.Intn(300), rand)
  129. }
  130. }
  131. if rand.Intn(10) > 5 {
  132. m.signatureAndHashes = supportedSignatureAlgorithms
  133. }
  134. m.alpnProtocols = make([]string, rand.Intn(5))
  135. for i := range m.alpnProtocols {
  136. m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
  137. }
  138. if rand.Intn(10) > 5 {
  139. m.scts = true
  140. }
  141. return reflect.ValueOf(m)
  142. }
  143. func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  144. m := &serverHelloMsg{}
  145. m.vers = uint16(rand.Intn(65536))
  146. m.random = randomBytes(32, rand)
  147. m.sessionId = randomBytes(rand.Intn(32), rand)
  148. m.cipherSuite = uint16(rand.Int31())
  149. m.compressionMethod = uint8(rand.Intn(256))
  150. if rand.Intn(10) > 5 {
  151. m.nextProtoNeg = true
  152. n := rand.Intn(10)
  153. m.nextProtos = make([]string, n)
  154. for i := 0; i < n; i++ {
  155. m.nextProtos[i] = randomString(20, rand)
  156. }
  157. }
  158. if rand.Intn(10) > 5 {
  159. m.ocspStapling = true
  160. }
  161. if rand.Intn(10) > 5 {
  162. m.ticketSupported = true
  163. }
  164. m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
  165. if rand.Intn(10) > 5 {
  166. numSCTs := rand.Intn(4)
  167. m.scts = make([][]byte, numSCTs)
  168. for i := range m.scts {
  169. m.scts[i] = randomBytes(rand.Intn(500), rand)
  170. }
  171. }
  172. return reflect.ValueOf(m)
  173. }
  174. func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  175. m := &certificateMsg{}
  176. numCerts := rand.Intn(20)
  177. m.certificates = make([][]byte, numCerts)
  178. for i := 0; i < numCerts; i++ {
  179. m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
  180. }
  181. return reflect.ValueOf(m)
  182. }
  183. func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  184. m := &certificateRequestMsg{}
  185. m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
  186. numCAs := rand.Intn(100)
  187. m.certificateAuthorities = make([][]byte, numCAs)
  188. for i := 0; i < numCAs; i++ {
  189. m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
  190. }
  191. return reflect.ValueOf(m)
  192. }
  193. func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  194. m := &certificateVerifyMsg{}
  195. m.signature = randomBytes(rand.Intn(15)+1, rand)
  196. return reflect.ValueOf(m)
  197. }
  198. func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  199. m := &certificateStatusMsg{}
  200. if rand.Intn(10) > 5 {
  201. m.statusType = statusTypeOCSP
  202. m.response = randomBytes(rand.Intn(10)+1, rand)
  203. } else {
  204. m.statusType = 42
  205. }
  206. return reflect.ValueOf(m)
  207. }
  208. func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  209. m := &clientKeyExchangeMsg{}
  210. m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
  211. return reflect.ValueOf(m)
  212. }
  213. func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  214. m := &finishedMsg{}
  215. m.verifyData = randomBytes(12, rand)
  216. return reflect.ValueOf(m)
  217. }
  218. func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  219. m := &nextProtoMsg{}
  220. m.proto = randomString(rand.Intn(255), rand)
  221. return reflect.ValueOf(m)
  222. }
  223. func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  224. m := &newSessionTicketMsg{}
  225. m.ticket = randomBytes(rand.Intn(4), rand)
  226. return reflect.ValueOf(m)
  227. }
  228. func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
  229. s := &sessionState{}
  230. s.vers = uint16(rand.Intn(10000))
  231. s.cipherSuite = uint16(rand.Intn(10000))
  232. s.masterSecret = randomBytes(rand.Intn(100), rand)
  233. numCerts := rand.Intn(20)
  234. s.certificates = make([][]byte, numCerts)
  235. for i := 0; i < numCerts; i++ {
  236. s.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
  237. }
  238. return reflect.ValueOf(s)
  239. }
  240. func TestRejectEmptySCTList(t *testing.T) {
  241. // https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that
  242. // empty SCT lists are invalid.
  243. var random [32]byte
  244. sct := []byte{0x42, 0x42, 0x42, 0x42}
  245. serverHello := serverHelloMsg{
  246. vers: VersionTLS12,
  247. random: random[:],
  248. scts: [][]byte{sct},
  249. }
  250. serverHelloBytes := serverHello.marshal()
  251. var serverHelloCopy serverHelloMsg
  252. if !serverHelloCopy.unmarshal(serverHelloBytes) {
  253. t.Fatal("Failed to unmarshal initial message")
  254. }
  255. // Change serverHelloBytes so that the SCT list is empty
  256. i := bytes.Index(serverHelloBytes, sct)
  257. if i < 0 {
  258. t.Fatal("Cannot find SCT in ServerHello")
  259. }
  260. var serverHelloEmptySCT []byte
  261. serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
  262. // Append the extension length and SCT list length for an empty list.
  263. serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
  264. serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
  265. // Update the handshake message length.
  266. serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
  267. serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
  268. serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
  269. // Update the extensions length
  270. serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
  271. serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
  272. if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
  273. t.Fatal("Unmarshaled ServerHello with empty SCT list")
  274. }
  275. }
  276. func TestRejectEmptySCT(t *testing.T) {
  277. // Not only must the SCT list be non-empty, but the SCT elements must
  278. // not be zero length.
  279. var random [32]byte
  280. serverHello := serverHelloMsg{
  281. vers: VersionTLS12,
  282. random: random[:],
  283. scts: [][]byte{nil},
  284. }
  285. serverHelloBytes := serverHello.marshal()
  286. var serverHelloCopy serverHelloMsg
  287. if serverHelloCopy.unmarshal(serverHelloBytes) {
  288. t.Fatal("Unmarshaled ServerHello with zero-length SCT")
  289. }
  290. }