You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

371 lines
9.8 KiB

  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package tls
  5. import (
  6. "bytes"
  7. "math/rand"
  8. "reflect"
  9. "strings"
  10. "testing"
  11. "testing/quick"
  12. )
  13. var tests = []interface{}{
  14. &clientHelloMsg{},
  15. &serverHelloMsg{},
  16. &finishedMsg{},
  17. &certificateMsg{},
  18. &certificateRequestMsg{},
  19. &certificateVerifyMsg{},
  20. &certificateStatusMsg{},
  21. &clientKeyExchangeMsg{},
  22. &nextProtoMsg{},
  23. &newSessionTicketMsg{},
  24. &sessionState{},
  25. &serverHelloMsg13{},
  26. &encryptedExtensionsMsg{},
  27. &certificateMsg13{},
  28. }
  29. type testMessage interface {
  30. marshal() []byte
  31. unmarshal([]byte) bool
  32. equal(interface{}) bool
  33. }
  34. func TestMarshalUnmarshal(t *testing.T) {
  35. rand := rand.New(rand.NewSource(0))
  36. for i, iface := range tests {
  37. ty := reflect.ValueOf(iface).Type()
  38. n := 100
  39. if testing.Short() {
  40. n = 5
  41. }
  42. for j := 0; j < n; j++ {
  43. v, ok := quick.Value(ty, rand)
  44. if !ok {
  45. t.Errorf("#%d: failed to create value", i)
  46. break
  47. }
  48. m1 := v.Interface().(testMessage)
  49. marshaled := m1.marshal()
  50. m2 := iface.(testMessage)
  51. if !m2.unmarshal(marshaled) {
  52. t.Errorf("#%d.%d failed to unmarshal %#v %x", i, j, m1, marshaled)
  53. break
  54. }
  55. m2.marshal() // to fill any marshal cache in the message
  56. if !m1.equal(m2) {
  57. t.Errorf("#%d.%d got:%#v want:%#v %x", i, j, m2, m1, marshaled)
  58. break
  59. }
  60. if i >= 3 {
  61. // The first three message types (ClientHello,
  62. // ServerHello and Finished) are allowed to
  63. // have parsable prefixes because the extension
  64. // data is optional and the length of the
  65. // Finished varies across versions.
  66. for j := 0; j < len(marshaled); j++ {
  67. if m2.unmarshal(marshaled[0:j]) {
  68. t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
  69. break
  70. }
  71. }
  72. }
  73. }
  74. }
  75. }
  76. func TestFuzz(t *testing.T) {
  77. rand := rand.New(rand.NewSource(0))
  78. for _, iface := range tests {
  79. m := iface.(testMessage)
  80. for j := 0; j < 1000; j++ {
  81. len := rand.Intn(100)
  82. bytes := randomBytes(len, rand)
  83. // This just looks for crashes due to bounds errors etc.
  84. m.unmarshal(bytes)
  85. }
  86. }
  87. }
  88. func randomBytes(n int, rand *rand.Rand) []byte {
  89. r := make([]byte, n)
  90. for i := 0; i < n; i++ {
  91. r[i] = byte(rand.Int31())
  92. }
  93. return r
  94. }
  95. func randomString(n int, rand *rand.Rand) string {
  96. b := randomBytes(n, rand)
  97. return string(b)
  98. }
  99. func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  100. m := &clientHelloMsg{}
  101. m.vers = uint16(rand.Intn(65536))
  102. m.random = randomBytes(32, rand)
  103. m.sessionId = randomBytes(rand.Intn(32), rand)
  104. m.cipherSuites = make([]uint16, rand.Intn(63)+1)
  105. for i := 0; i < len(m.cipherSuites); i++ {
  106. m.cipherSuites[i] = uint16(rand.Int31())
  107. }
  108. m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
  109. if rand.Intn(10) > 5 {
  110. m.nextProtoNeg = true
  111. }
  112. if rand.Intn(10) > 5 {
  113. m.serverName = randomString(rand.Intn(255), rand)
  114. for strings.HasSuffix(m.serverName, ".") {
  115. m.serverName = m.serverName[:len(m.serverName)-1]
  116. }
  117. }
  118. m.ocspStapling = rand.Intn(10) > 5
  119. m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
  120. m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
  121. for i := range m.supportedCurves {
  122. m.supportedCurves[i] = CurveID(rand.Intn(30000))
  123. }
  124. if rand.Intn(10) > 5 {
  125. m.ticketSupported = true
  126. if rand.Intn(10) > 5 {
  127. m.sessionTicket = randomBytes(rand.Intn(300), rand)
  128. }
  129. }
  130. if rand.Intn(10) > 5 {
  131. m.signatureAndHashes = supportedSignatureAlgorithms
  132. }
  133. m.alpnProtocols = make([]string, rand.Intn(5))
  134. for i := range m.alpnProtocols {
  135. m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
  136. }
  137. if rand.Intn(10) > 5 {
  138. m.scts = true
  139. }
  140. m.keyShares = make([]keyShare, rand.Intn(4))
  141. for i := range m.keyShares {
  142. m.keyShares[i].group = CurveID(rand.Intn(30000))
  143. m.keyShares[i].data = randomBytes(rand.Intn(300), rand)
  144. }
  145. m.supportedVersions = make([]uint16, rand.Intn(5))
  146. for i := range m.supportedVersions {
  147. m.supportedVersions[i] = uint16(rand.Intn(30000))
  148. }
  149. return reflect.ValueOf(m)
  150. }
  151. func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  152. m := &serverHelloMsg{}
  153. m.vers = uint16(rand.Intn(65536))
  154. m.random = randomBytes(32, rand)
  155. m.sessionId = randomBytes(rand.Intn(32), rand)
  156. m.cipherSuite = uint16(rand.Int31())
  157. m.compressionMethod = uint8(rand.Intn(256))
  158. if rand.Intn(10) > 5 {
  159. m.nextProtoNeg = true
  160. n := rand.Intn(10)
  161. m.nextProtos = make([]string, n)
  162. for i := 0; i < n; i++ {
  163. m.nextProtos[i] = randomString(20, rand)
  164. }
  165. }
  166. if rand.Intn(10) > 5 {
  167. m.ocspStapling = true
  168. }
  169. if rand.Intn(10) > 5 {
  170. m.ticketSupported = true
  171. }
  172. m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
  173. if rand.Intn(10) > 5 {
  174. numSCTs := rand.Intn(4)
  175. m.scts = make([][]byte, numSCTs)
  176. for i := range m.scts {
  177. m.scts[i] = randomBytes(rand.Intn(500), rand)
  178. }
  179. }
  180. return reflect.ValueOf(m)
  181. }
  182. func (*serverHelloMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
  183. m := &serverHelloMsg13{}
  184. m.vers = uint16(rand.Intn(65536))
  185. m.random = randomBytes(32, rand)
  186. m.cipherSuite = uint16(rand.Int31())
  187. m.keyShare.group = CurveID(rand.Intn(30000))
  188. m.keyShare.data = randomBytes(rand.Intn(300), rand)
  189. return reflect.ValueOf(m)
  190. }
  191. func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  192. m := &encryptedExtensionsMsg{}
  193. m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
  194. return reflect.ValueOf(m)
  195. }
  196. func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  197. m := &certificateMsg{}
  198. numCerts := rand.Intn(20)
  199. m.certificates = make([][]byte, numCerts)
  200. for i := 0; i < numCerts; i++ {
  201. m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
  202. }
  203. return reflect.ValueOf(m)
  204. }
  205. func (*certificateMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
  206. m := &certificateMsg13{}
  207. numCerts := rand.Intn(20)
  208. m.certificates = make([][]byte, numCerts)
  209. for i := 0; i < numCerts; i++ {
  210. m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
  211. }
  212. m.requestContext = randomBytes(rand.Intn(5), rand)
  213. return reflect.ValueOf(m)
  214. }
  215. func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  216. m := &certificateRequestMsg{}
  217. m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
  218. numCAs := rand.Intn(100)
  219. m.certificateAuthorities = make([][]byte, numCAs)
  220. for i := 0; i < numCAs; i++ {
  221. m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
  222. }
  223. return reflect.ValueOf(m)
  224. }
  225. func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  226. m := &certificateVerifyMsg{}
  227. m.signature = randomBytes(rand.Intn(15)+1, rand)
  228. return reflect.ValueOf(m)
  229. }
  230. func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  231. m := &certificateStatusMsg{}
  232. if rand.Intn(10) > 5 {
  233. m.statusType = statusTypeOCSP
  234. m.response = randomBytes(rand.Intn(10)+1, rand)
  235. } else {
  236. m.statusType = 42
  237. }
  238. return reflect.ValueOf(m)
  239. }
  240. func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  241. m := &clientKeyExchangeMsg{}
  242. m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
  243. return reflect.ValueOf(m)
  244. }
  245. func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  246. m := &finishedMsg{}
  247. m.verifyData = randomBytes(12, rand)
  248. return reflect.ValueOf(m)
  249. }
  250. func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  251. m := &nextProtoMsg{}
  252. m.proto = randomString(rand.Intn(255), rand)
  253. return reflect.ValueOf(m)
  254. }
  255. func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  256. m := &newSessionTicketMsg{}
  257. m.ticket = randomBytes(rand.Intn(4), rand)
  258. return reflect.ValueOf(m)
  259. }
  260. func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
  261. s := &sessionState{}
  262. s.vers = uint16(rand.Intn(10000))
  263. s.cipherSuite = uint16(rand.Intn(10000))
  264. s.masterSecret = randomBytes(rand.Intn(100), rand)
  265. numCerts := rand.Intn(20)
  266. s.certificates = make([][]byte, numCerts)
  267. for i := 0; i < numCerts; i++ {
  268. s.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
  269. }
  270. return reflect.ValueOf(s)
  271. }
  272. func TestRejectEmptySCTList(t *testing.T) {
  273. // https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that
  274. // empty SCT lists are invalid.
  275. var random [32]byte
  276. sct := []byte{0x42, 0x42, 0x42, 0x42}
  277. serverHello := serverHelloMsg{
  278. vers: VersionTLS12,
  279. random: random[:],
  280. scts: [][]byte{sct},
  281. }
  282. serverHelloBytes := serverHello.marshal()
  283. var serverHelloCopy serverHelloMsg
  284. if !serverHelloCopy.unmarshal(serverHelloBytes) {
  285. t.Fatal("Failed to unmarshal initial message")
  286. }
  287. // Change serverHelloBytes so that the SCT list is empty
  288. i := bytes.Index(serverHelloBytes, sct)
  289. if i < 0 {
  290. t.Fatal("Cannot find SCT in ServerHello")
  291. }
  292. var serverHelloEmptySCT []byte
  293. serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
  294. // Append the extension length and SCT list length for an empty list.
  295. serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
  296. serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
  297. // Update the handshake message length.
  298. serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
  299. serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
  300. serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
  301. // Update the extensions length
  302. serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
  303. serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
  304. if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
  305. t.Fatal("Unmarshaled ServerHello with empty SCT list")
  306. }
  307. }
  308. func TestRejectEmptySCT(t *testing.T) {
  309. // Not only must the SCT list be non-empty, but the SCT elements must
  310. // not be zero length.
  311. var random [32]byte
  312. serverHello := serverHelloMsg{
  313. vers: VersionTLS12,
  314. random: random[:],
  315. scts: [][]byte{nil},
  316. }
  317. serverHelloBytes := serverHello.marshal()
  318. var serverHelloCopy serverHelloMsg
  319. if serverHelloCopy.unmarshal(serverHelloBytes) {
  320. t.Fatal("Unmarshaled ServerHello with zero-length SCT")
  321. }
  322. }