Alternative TLS implementation in Go
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

handshake_messages_test.go 12 KiB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package tls
  5. import (
  6. "bytes"
  7. "math/rand"
  8. "reflect"
  9. "strings"
  10. "testing"
  11. "testing/quick"
  12. )
  13. var tests = []interface{}{
  14. &clientHelloMsg{},
  15. &serverHelloMsg{},
  16. &finishedMsg{},
  17. &certificateMsg{},
  18. &certificateRequestMsg{},
  19. &certificateRequestMsg13{},
  20. &certificateVerifyMsg{},
  21. &certificateStatusMsg{},
  22. &clientKeyExchangeMsg{},
  23. &nextProtoMsg{},
  24. &newSessionTicketMsg{},
  25. &sessionState{},
  26. &encryptedExtensionsMsg{},
  27. &certificateMsg13{},
  28. &newSessionTicketMsg13{},
  29. &sessionState13{},
  30. }
  31. type testMessage interface {
  32. marshal() []byte
  33. unmarshal([]byte) alert
  34. equal(interface{}) bool
  35. }
  36. func TestMarshalUnmarshal(t *testing.T) {
  37. rand := rand.New(rand.NewSource(0))
  38. for i, iface := range tests {
  39. ty := reflect.ValueOf(iface).Type()
  40. n := 100
  41. if testing.Short() {
  42. n = 5
  43. }
  44. for j := 0; j < n; j++ {
  45. v, ok := quick.Value(ty, rand)
  46. if !ok {
  47. t.Errorf("#%d: failed to create value", i)
  48. break
  49. }
  50. m1 := v.Interface().(testMessage)
  51. marshaled := m1.marshal()
  52. m2 := iface.(testMessage)
  53. if m2.unmarshal(marshaled) != alertSuccess {
  54. t.Errorf("#%d.%d failed to unmarshal %#v %x", i, j, m1, marshaled)
  55. break
  56. }
  57. m2.marshal() // to fill any marshal cache in the message
  58. if !m1.equal(m2) {
  59. t.Errorf("#%d.%d got:%#v want:%#v %x", i, j, m2, m1, marshaled)
  60. break
  61. }
  62. if i >= 3 {
  63. // The first three message types (ClientHello,
  64. // ServerHello and Finished) are allowed to
  65. // have parsable prefixes because the extension
  66. // data is optional and the length of the
  67. // Finished varies across versions.
  68. for j := 0; j < len(marshaled); j++ {
  69. if m2.unmarshal(marshaled[0:j]) == alertSuccess {
  70. t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
  71. break
  72. }
  73. }
  74. }
  75. }
  76. }
  77. }
  78. func TestFuzz(t *testing.T) {
  79. rand := rand.New(rand.NewSource(0))
  80. for _, iface := range tests {
  81. m := iface.(testMessage)
  82. for j := 0; j < 1000; j++ {
  83. len := rand.Intn(100)
  84. bytes := randomBytes(len, rand)
  85. // This just looks for crashes due to bounds errors etc.
  86. m.unmarshal(bytes)
  87. }
  88. }
  89. }
  90. func randomBytes(n int, rand *rand.Rand) []byte {
  91. r := make([]byte, n)
  92. if _, err := rand.Read(r); err != nil {
  93. panic("rand.Read failed: " + err.Error())
  94. }
  95. return r
  96. }
  97. func randomString(n int, rand *rand.Rand) string {
  98. b := randomBytes(n, rand)
  99. return string(b)
  100. }
  101. func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  102. m := &clientHelloMsg{}
  103. m.vers = uint16(rand.Intn(65536))
  104. m.random = randomBytes(32, rand)
  105. m.sessionId = randomBytes(rand.Intn(32), rand)
  106. m.cipherSuites = make([]uint16, rand.Intn(63)+1)
  107. for i := 0; i < len(m.cipherSuites); i++ {
  108. cs := uint16(rand.Int31())
  109. if cs == scsvRenegotiation {
  110. cs += 1
  111. }
  112. m.cipherSuites[i] = cs
  113. }
  114. m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
  115. if rand.Intn(10) > 5 {
  116. m.nextProtoNeg = true
  117. }
  118. if rand.Intn(10) > 5 {
  119. m.serverName = randomString(rand.Intn(255), rand)
  120. for strings.HasSuffix(m.serverName, ".") {
  121. m.serverName = m.serverName[:len(m.serverName)-1]
  122. }
  123. }
  124. m.ocspStapling = rand.Intn(10) > 5
  125. m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
  126. m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
  127. for i := range m.supportedCurves {
  128. m.supportedCurves[i] = CurveID(rand.Intn(30000))
  129. }
  130. if rand.Intn(10) > 5 {
  131. m.ticketSupported = true
  132. if rand.Intn(10) > 5 {
  133. m.sessionTicket = randomBytes(rand.Intn(300), rand)
  134. }
  135. }
  136. if rand.Intn(10) > 5 {
  137. m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
  138. }
  139. m.alpnProtocols = make([]string, rand.Intn(5))
  140. for i := range m.alpnProtocols {
  141. m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
  142. }
  143. if rand.Intn(10) > 5 {
  144. m.scts = true
  145. }
  146. m.keyShares = make([]keyShare, rand.Intn(4))
  147. for i := range m.keyShares {
  148. m.keyShares[i].group = CurveID(rand.Intn(30000))
  149. m.keyShares[i].data = randomBytes(rand.Intn(300)+1, rand)
  150. }
  151. m.supportedVersions = make([]uint16, rand.Intn(5))
  152. for i := range m.supportedVersions {
  153. m.supportedVersions[i] = uint16(rand.Intn(30000))
  154. }
  155. if rand.Intn(10) > 5 {
  156. m.earlyData = true
  157. }
  158. if rand.Intn(10) > 5 {
  159. m.delegatedCredential = true
  160. }
  161. if rand.Intn(10) > 5 {
  162. m.extendedMSSupported = true
  163. }
  164. return reflect.ValueOf(m)
  165. }
  166. func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  167. m := &serverHelloMsg{}
  168. m.vers = uint16(rand.Intn(65536))
  169. m.random = randomBytes(32, rand)
  170. m.sessionId = randomBytes(rand.Intn(32), rand)
  171. m.cipherSuite = uint16(rand.Int31())
  172. m.compressionMethod = uint8(rand.Intn(256))
  173. if rand.Intn(10) > 5 {
  174. m.nextProtoNeg = true
  175. n := rand.Intn(10)
  176. m.nextProtos = make([]string, n)
  177. for i := 0; i < n; i++ {
  178. m.nextProtos[i] = randomString(20, rand)
  179. }
  180. }
  181. if rand.Intn(10) > 5 {
  182. m.ocspStapling = true
  183. }
  184. if rand.Intn(10) > 5 {
  185. m.ticketSupported = true
  186. }
  187. if rand.Intn(10) > 5 {
  188. m.extendedMSSupported = true
  189. }
  190. m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
  191. if rand.Intn(10) > 5 {
  192. numSCTs := rand.Intn(4)
  193. m.scts = make([][]byte, numSCTs)
  194. for i := range m.scts {
  195. m.scts[i] = randomBytes(rand.Intn(500)+1, rand)
  196. }
  197. }
  198. if rand.Intn(10) > 5 {
  199. m.keyShare.group = CurveID(rand.Intn(30000) + 1)
  200. m.keyShare.data = randomBytes(rand.Intn(300)+1, rand)
  201. }
  202. if rand.Intn(10) > 5 {
  203. m.psk = true
  204. m.pskIdentity = uint16(rand.Int31())
  205. }
  206. return reflect.ValueOf(m)
  207. }
  208. func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  209. m := &encryptedExtensionsMsg{}
  210. if rand.Intn(10) > 5 {
  211. m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
  212. }
  213. if rand.Intn(10) > 5 {
  214. m.earlyData = true
  215. }
  216. return reflect.ValueOf(m)
  217. }
  218. func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  219. m := &certificateMsg{}
  220. numCerts := rand.Intn(20)
  221. m.certificates = make([][]byte, numCerts)
  222. for i := 0; i < numCerts; i++ {
  223. m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
  224. }
  225. return reflect.ValueOf(m)
  226. }
  227. func (*certificateMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
  228. m := &certificateMsg13{}
  229. numCerts := rand.Intn(20)
  230. m.certificates = make([]certificateEntry, numCerts)
  231. for i := 0; i < numCerts; i++ {
  232. m.certificates[i].data = randomBytes(rand.Intn(10)+1, rand)
  233. if rand.Intn(2) == 1 {
  234. m.certificates[i].ocspStaple = randomBytes(rand.Intn(10)+1, rand)
  235. }
  236. numScts := rand.Intn(3)
  237. for j := 0; j < numScts; j++ {
  238. m.certificates[i].sctList = append(m.certificates[i].sctList, randomBytes(rand.Intn(10)+1, rand))
  239. }
  240. }
  241. m.requestContext = randomBytes(rand.Intn(5), rand)
  242. return reflect.ValueOf(m)
  243. }
  244. func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  245. m := &certificateRequestMsg{}
  246. m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
  247. numCAs := rand.Intn(100)
  248. m.certificateAuthorities = make([][]byte, numCAs)
  249. for i := 0; i < numCAs; i++ {
  250. m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
  251. }
  252. return reflect.ValueOf(m)
  253. }
  254. func (*certificateRequestMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
  255. m := &certificateRequestMsg13{}
  256. m.requestContext = randomBytes(rand.Intn(5), rand)
  257. m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
  258. numCAs := rand.Intn(100)
  259. m.certificateAuthorities = make([][]byte, numCAs)
  260. for i := 0; i < numCAs; i++ {
  261. m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
  262. }
  263. return reflect.ValueOf(m)
  264. }
  265. func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  266. m := &certificateVerifyMsg{}
  267. m.signature = randomBytes(rand.Intn(15)+1, rand)
  268. return reflect.ValueOf(m)
  269. }
  270. func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  271. m := &certificateStatusMsg{}
  272. if rand.Intn(10) > 5 {
  273. m.statusType = statusTypeOCSP
  274. m.response = randomBytes(rand.Intn(10)+1, rand)
  275. } else {
  276. m.statusType = 42
  277. }
  278. return reflect.ValueOf(m)
  279. }
  280. func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  281. m := &clientKeyExchangeMsg{}
  282. m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
  283. return reflect.ValueOf(m)
  284. }
  285. func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  286. m := &finishedMsg{}
  287. m.verifyData = randomBytes(12, rand)
  288. return reflect.ValueOf(m)
  289. }
  290. func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  291. m := &nextProtoMsg{}
  292. m.proto = randomString(rand.Intn(255), rand)
  293. return reflect.ValueOf(m)
  294. }
  295. func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  296. m := &newSessionTicketMsg{}
  297. m.ticket = randomBytes(rand.Intn(4), rand)
  298. return reflect.ValueOf(m)
  299. }
  300. func (*newSessionTicketMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
  301. m := &newSessionTicketMsg13{}
  302. m.ageAdd = uint32(rand.Intn(0xffffffff))
  303. m.lifetime = uint32(rand.Intn(0xffffffff))
  304. m.nonce = randomBytes(1+rand.Intn(255), rand)
  305. m.ticket = randomBytes(1+rand.Intn(40), rand)
  306. if rand.Intn(10) > 5 {
  307. m.withEarlyDataInfo = true
  308. m.maxEarlyDataLength = uint32(rand.Intn(0xffffffff))
  309. }
  310. return reflect.ValueOf(m)
  311. }
  312. func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
  313. s := &sessionState{}
  314. s.vers = uint16(rand.Intn(10000))
  315. s.cipherSuite = uint16(rand.Intn(10000))
  316. s.masterSecret = randomBytes(rand.Intn(100), rand)
  317. if rand.Intn(10) > 5 {
  318. s.usedEMS = true
  319. }
  320. numCerts := rand.Intn(20)
  321. s.certificates = make([][]byte, numCerts)
  322. for i := 0; i < numCerts; i++ {
  323. s.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
  324. }
  325. return reflect.ValueOf(s)
  326. }
  327. func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value {
  328. s := &sessionState13{}
  329. s.vers = uint16(rand.Intn(10000))
  330. s.suite = uint16(rand.Intn(10000))
  331. s.ageAdd = uint32(rand.Intn(0xffffffff))
  332. s.maxEarlyDataLen = uint32(rand.Intn(0xffffffff))
  333. s.createdAt = uint64(rand.Int63n(0xfffffffffffffff))
  334. s.pskSecret = randomBytes(rand.Intn(100), rand)
  335. s.alpnProtocol = randomString(rand.Intn(100), rand)
  336. s.SNI = randomString(rand.Intn(100), rand)
  337. return reflect.ValueOf(s)
  338. }
  339. func TestRejectEmptySCTList(t *testing.T) {
  340. // https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that
  341. // empty SCT lists are invalid.
  342. var random [32]byte
  343. sct := []byte{0x42, 0x42, 0x42, 0x42}
  344. serverHello := serverHelloMsg{
  345. vers: VersionTLS12,
  346. random: random[:],
  347. scts: [][]byte{sct},
  348. }
  349. serverHelloBytes := serverHello.marshal()
  350. var serverHelloCopy serverHelloMsg
  351. if serverHelloCopy.unmarshal(serverHelloBytes) != alertSuccess {
  352. t.Fatal("Failed to unmarshal initial message")
  353. }
  354. // Change serverHelloBytes so that the SCT list is empty
  355. i := bytes.Index(serverHelloBytes, sct)
  356. if i < 0 {
  357. t.Fatal("Cannot find SCT in ServerHello")
  358. }
  359. var serverHelloEmptySCT []byte
  360. serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
  361. // Append the extension length and SCT list length for an empty list.
  362. serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
  363. serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
  364. // Update the handshake message length.
  365. serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
  366. serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
  367. serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
  368. // Update the extensions length
  369. serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
  370. serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
  371. if serverHelloCopy.unmarshal(serverHelloEmptySCT) == alertSuccess {
  372. t.Fatal("Unmarshaled ServerHello with empty SCT list")
  373. }
  374. }
  375. func TestRejectEmptySCT(t *testing.T) {
  376. // Not only must the SCT list be non-empty, but the SCT elements must
  377. // not be zero length.
  378. var random [32]byte
  379. serverHello := serverHelloMsg{
  380. vers: VersionTLS12,
  381. random: random[:],
  382. scts: [][]byte{nil},
  383. }
  384. serverHelloBytes := serverHello.marshal()
  385. var serverHelloCopy serverHelloMsg
  386. if serverHelloCopy.unmarshal(serverHelloBytes) == alertSuccess {
  387. t.Fatal("Unmarshaled ServerHello with zero-length SCT")
  388. }
  389. }