25개 이상의 토픽을 선택하실 수 없습니다. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

326 lines
8.4 KiB

  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package tls
  5. import (
  6. "bytes"
  7. "math/rand"
  8. "reflect"
  9. "testing"
  10. "testing/quick"
  11. )
  12. var tests = []interface{}{
  13. &clientHelloMsg{},
  14. &serverHelloMsg{},
  15. &finishedMsg{},
  16. &certificateMsg{},
  17. &certificateRequestMsg{},
  18. &certificateVerifyMsg{},
  19. &certificateStatusMsg{},
  20. &clientKeyExchangeMsg{},
  21. &nextProtoMsg{},
  22. &newSessionTicketMsg{},
  23. &sessionState{},
  24. }
  // testMessage is the subset of handshake-message behaviour the tests in
  // this file rely on.
  type testMessage interface {
  	// marshal encodes the message to its wire form.
  	marshal() []byte
  	// unmarshal decodes the message from data and reports whether the
  	// input was accepted.
  	unmarshal(data []byte) bool
  	// equal reports whether the message matches another value.
  	equal(other interface{}) bool
  }
  // TestMarshalUnmarshal generates random instances of every message type in
  // tests and checks that each survives a marshal/unmarshal round trip and
  // compares equal afterwards. For every type after the first three it also
  // checks that unmarshal rejects each strict prefix of the encoding.
  func TestMarshalUnmarshal(t *testing.T) {
  	rng := rand.New(rand.NewSource(0))
  	for idx, proto := range tests {
  		typ := reflect.ValueOf(proto).Type()
  		rounds := 100
  		if testing.Short() {
  			rounds = 5
  		}
  		for round := 0; round < rounds; round++ {
  			v, ok := quick.Value(typ, rng)
  			if !ok {
  				t.Errorf("#%d: failed to create value", idx)
  				break
  			}
  			want := v.Interface().(testMessage)
  			encoded := want.marshal()
  
  			// The prototype from tests is reused as the decode target.
  			got := proto.(testMessage)
  			if !got.unmarshal(encoded) {
  				t.Errorf("#%d failed to unmarshal %#v %x", idx, want, encoded)
  				break
  			}
  			got.marshal() // to fill any marshal cache in the message
  			if !want.equal(got) {
  				t.Errorf("#%d got:%#v want:%#v %x", idx, got, want, encoded)
  				break
  			}
  
  			// The first three message types (ClientHello, ServerHello and
  			// Finished) are allowed to have parsable prefixes because the
  			// extension data is optional and the length of the Finished
  			// varies across versions.
  			if idx < 3 {
  				continue
  			}
  			for cut := range encoded {
  				if got.unmarshal(encoded[:cut]) {
  					t.Errorf("#%d unmarshaled a prefix of length %d of %#v", idx, cut, want)
  					break
  				}
  			}
  		}
  	}
  }
  // TestFuzz feeds 1000 random byte strings (each shorter than 100 bytes) to
  // the unmarshal method of every message type in tests. The result of
  // unmarshal is discarded; the test exists to surface crashes such as
  // out-of-range slice indexing on malformed input.
  func TestFuzz(t *testing.T) {
  	rng := rand.New(rand.NewSource(0))
  	for _, proto := range tests {
  		msg := proto.(testMessage)
  		for iter := 0; iter < 1000; iter++ {
  			junk := randomBytes(rng.Intn(100), rng)
  			msg.unmarshal(junk)
  		}
  	}
  }
  // randomBytes returns a slice of n bytes, each being the low eight bits
  // of a successive rand.Int31 draw.
  func randomBytes(n int, rand *rand.Rand) []byte {
  	out := make([]byte, n)
  	for i := range out {
  		out[i] = byte(rand.Int31())
  	}
  	return out
  }
  
  // randomString returns randomBytes(n, rand) as a string. The bytes are
  // arbitrary, so the result is not necessarily valid UTF-8.
  func randomString(n int, rand *rand.Rand) string {
  	return string(randomBytes(n, rand))
  }
  // Generate implements quick.Generator for clientHelloMsg, filling every
  // field with random data drawn from rand; size is ignored. For a fixed
  // seed the output depends on the exact sequence of draws, so statements
  // here should not be reordered.
  func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  	// flip consumes one draw and reports true for 6..9 out of 0..9.
  	flip := func() bool { return rand.Intn(10) > 5 }
  
  	hello := &clientHelloMsg{}
  	hello.vers = uint16(rand.Intn(65536))
  	hello.random = randomBytes(32, rand)
  	hello.sessionId = randomBytes(rand.Intn(32), rand)
  
  	suites := make([]uint16, rand.Intn(63)+1)
  	for i := range suites {
  		suites[i] = uint16(rand.Int31())
  	}
  	hello.cipherSuites = suites
  
  	hello.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
  	hello.nextProtoNeg = flip()
  	if flip() {
  		hello.serverName = randomString(rand.Intn(255), rand)
  	}
  	hello.ocspStapling = flip()
  	hello.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
  
  	curves := make([]CurveID, rand.Intn(5)+1)
  	for i := range curves {
  		curves[i] = CurveID(rand.Intn(30000))
  	}
  	hello.supportedCurves = curves
  
  	if flip() {
  		hello.ticketSupported = true
  		if flip() {
  			hello.sessionTicket = randomBytes(rand.Intn(300), rand)
  		}
  	}
  	if flip() {
  		hello.signatureAndHashes = supportedSignatureAlgorithms
  	}
  
  	protos := make([]string, rand.Intn(5))
  	for i := range protos {
  		protos[i] = randomString(rand.Intn(20)+1, rand)
  	}
  	hello.alpnProtocols = protos
  
  	hello.scts = flip()
  	return reflect.ValueOf(hello)
  }
  // Generate implements quick.Generator for serverHelloMsg, filling every
  // field with random data drawn from rand; size is ignored. Each optional
  // feature is enabled when rand.Intn(10) > 5.
  //
  // Every generated SCT is at least one byte long. Unmarshal rejects
  // zero-length SCTs (see TestRejectEmptySCT), so emitting one would make
  // the round trip in TestMarshalUnmarshal fail depending on the seed.
  func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  	m := &serverHelloMsg{}
  	m.vers = uint16(rand.Intn(65536))
  	m.random = randomBytes(32, rand)
  	m.sessionId = randomBytes(rand.Intn(32), rand)
  	m.cipherSuite = uint16(rand.Int31())
  	m.compressionMethod = uint8(rand.Intn(256))
  	if rand.Intn(10) > 5 {
  		m.nextProtoNeg = true
  		n := rand.Intn(10)
  		m.nextProtos = make([]string, n)
  		for i := range m.nextProtos {
  			m.nextProtos[i] = randomString(20, rand)
  		}
  	}
  	if rand.Intn(10) > 5 {
  		m.ocspStapling = true
  	}
  	if rand.Intn(10) > 5 {
  		m.ticketSupported = true
  	}
  	m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
  	if rand.Intn(10) > 5 {
  		numSCTs := rand.Intn(4)
  		m.scts = make([][]byte, numSCTs)
  		for i := range m.scts {
  			// +1: a zero-length SCT is invalid and would be rejected by
  			// unmarshal; the number of draws is unchanged.
  			m.scts[i] = randomBytes(rand.Intn(500)+1, rand)
  		}
  	}
  	return reflect.ValueOf(m)
  }
  // Generate implements quick.Generator for certificateMsg: fewer than 20
  // certificates, each 1 to 10 random bytes; size is ignored.
  func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  	certs := make([][]byte, rand.Intn(20))
  	for i := range certs {
  		certs[i] = randomBytes(rand.Intn(10)+1, rand)
  	}
  	return reflect.ValueOf(&certificateMsg{certificates: certs})
  }
  // Generate implements quick.Generator for certificateRequestMsg: 1 to 5
  // random certificate-type bytes followed by fewer than 100 authorities of
  // 1 to 15 random bytes each; size is ignored.
  func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  	types := randomBytes(rand.Intn(5)+1, rand)
  	cas := make([][]byte, rand.Intn(100))
  	for i := range cas {
  		cas[i] = randomBytes(rand.Intn(15)+1, rand)
  	}
  	return reflect.ValueOf(&certificateRequestMsg{
  		certificateTypes:       types,
  		certificateAuthorities: cas,
  	})
  }
  // Generate implements quick.Generator for certificateVerifyMsg with a
  // random 1 to 15 byte signature; size is ignored.
  func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  	sig := randomBytes(rand.Intn(15)+1, rand)
  	return reflect.ValueOf(&certificateVerifyMsg{signature: sig})
  }
  // Generate implements quick.Generator for certificateStatusMsg; size is
  // ignored. When rand.Intn(10) > 5 it produces an OCSP status with a 1 to
  // 10 byte random response. Otherwise it uses status type 42 with no
  // response.
  func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  	status := &certificateStatusMsg{statusType: 42}
  	if rand.Intn(10) > 5 {
  		status.statusType = statusTypeOCSP
  		status.response = randomBytes(rand.Intn(10)+1, rand)
  	}
  	return reflect.ValueOf(status)
  }
  // Generate implements quick.Generator for clientKeyExchangeMsg with a
  // random 1 to 1000 byte ciphertext; size is ignored.
  func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  	return reflect.ValueOf(&clientKeyExchangeMsg{
  		ciphertext: randomBytes(rand.Intn(1000)+1, rand),
  	})
  }
  // Generate implements quick.Generator for finishedMsg with 12 bytes of
  // random verify data; size is ignored.
  func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  	return reflect.ValueOf(&finishedMsg{verifyData: randomBytes(12, rand)})
  }
  // Generate implements quick.Generator for nextProtoMsg with a random
  // protocol string shorter than 255 bytes; size is ignored.
  func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  	return reflect.ValueOf(&nextProtoMsg{proto: randomString(rand.Intn(255), rand)})
  }
  // Generate implements quick.Generator for newSessionTicketMsg with a
  // random ticket of 0 to 3 bytes; size is ignored.
  func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  	return reflect.ValueOf(&newSessionTicketMsg{ticket: randomBytes(rand.Intn(4), rand)})
  }
  // Generate implements quick.Generator for sessionState; size is ignored.
  // The composite literal relies on Go's lexical left-to-right evaluation
  // of calls, so the draws happen in the order the fields are listed.
  func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
  	state := &sessionState{
  		vers:         uint16(rand.Intn(10000)),
  		cipherSuite:  uint16(rand.Intn(10000)),
  		masterSecret: randomBytes(rand.Intn(100), rand),
  	}
  	state.certificates = make([][]byte, rand.Intn(20))
  	for i := range state.certificates {
  		state.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
  	}
  	return reflect.ValueOf(state)
  }
  // TestRejectEmptySCTList checks that unmarshal rejects a ServerHello whose
  // SCT extension carries an empty SCT list. RFC 6962, section 3.3.1,
  // specifies that empty SCT lists are invalid
  // (https://tools.ietf.org/html/rfc6962#section-3.3.1).
  func TestRejectEmptySCTList(t *testing.T) {
  	var random [32]byte
  	sct := []byte{0x42, 0x42, 0x42, 0x42}
  	original := serverHelloMsg{
  		vers:   VersionTLS12,
  		random: random[:],
  		scts:   [][]byte{sct},
  	}
  	encoded := original.marshal()
  
  	var decoded serverHelloMsg
  	if !decoded.unmarshal(encoded) {
  		t.Fatal("Failed to unmarshal initial message")
  	}
  
  	// Locate the SCT body so the surrounding length fields can be rewritten.
  	pos := bytes.Index(encoded, sct)
  	if pos < 0 {
  		t.Fatal("Cannot find SCT in ServerHello")
  	}
  
  	// Rebuild the message with an empty SCT list. Keep everything before
  	// the six bytes that precede the SCT body, then write an extension
  	// length of 2 and an SCT list length of 0. Resume after the SCT body.
  	tampered := make([]byte, 0, len(encoded))
  	tampered = append(tampered, encoded[:pos-6]...)
  	tampered = append(tampered, 0, 2, 0, 0)
  	tampered = append(tampered, encoded[pos+len(sct):]...)
  
  	// Patch the 3-byte handshake message length (everything after the
  	// 4-byte header).
  	bodyLen := len(tampered) - 4
  	tampered[1] = byte(bodyLen >> 16)
  	tampered[2] = byte(bodyLen >> 8)
  	tampered[3] = byte(bodyLen)
  
  	// Patch the 2-byte extensions length at offset 42. With an empty
  	// session ID, the fields before it are the 4-byte header, 2-byte
  	// version, 32-byte random, 1-byte session ID length, 2-byte cipher
  	// suite and 1-byte compression method.
  	extLen := len(tampered) - 44
  	tampered[42] = byte(extLen >> 8)
  	tampered[43] = byte(extLen)
  
  	if decoded.unmarshal(tampered) {
  		t.Fatal("Unmarshaled ServerHello with empty SCT list")
  	}
  }
  // TestRejectEmptySCT checks that a non-empty SCT list is still rejected
  // when one of its elements has zero length.
  func TestRejectEmptySCT(t *testing.T) {
  	var random [32]byte
  	encoded := (&serverHelloMsg{
  		vers:   VersionTLS12,
  		random: random[:],
  		scts:   [][]byte{nil},
  	}).marshal()
  
  	var decoded serverHelloMsg
  	if decoded.unmarshal(encoded) {
  		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
  	}
  }
  282. }