You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

295 lines
7.5 KiB

  1. // Copyright (c) 2016, Google Inc.
  2. //
  3. // Permission to use, copy, modify, and/or distribute this software for any
  4. // purpose with or without fee is hereby granted, provided that the above
  5. // copyright notice and this permission notice appear in all copies.
  6. //
  7. // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  8. // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  9. // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
  10. // SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  11. // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
  12. // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
  13. // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  14. package main
  15. import (
  16. "bufio"
  17. "errors"
  18. "fmt"
  19. "io"
  20. "math/big"
  21. "os"
  22. "strings"
  23. )
  24. type test struct {
  25. LineNumber int
  26. Type string
  27. Values map[string]*big.Int
  28. }
  29. type testScanner struct {
  30. scanner *bufio.Scanner
  31. lineNo int
  32. err error
  33. test test
  34. }
  35. func newTestScanner(r io.Reader) *testScanner {
  36. return &testScanner{scanner: bufio.NewScanner(r)}
  37. }
  38. func (s *testScanner) scanLine() bool {
  39. if !s.scanner.Scan() {
  40. return false
  41. }
  42. s.lineNo++
  43. return true
  44. }
  45. func (s *testScanner) addAttribute(line string) (key string, ok bool) {
  46. fields := strings.SplitN(line, "=", 2)
  47. if len(fields) != 2 {
  48. s.setError(errors.New("invalid syntax"))
  49. return "", false
  50. }
  51. key = strings.TrimSpace(fields[0])
  52. value := strings.TrimSpace(fields[1])
  53. valueInt, ok := new(big.Int).SetString(value, 16)
  54. if !ok {
  55. s.setError(fmt.Errorf("could not parse %q", value))
  56. return "", false
  57. }
  58. if _, dup := s.test.Values[key]; dup {
  59. s.setError(fmt.Errorf("duplicate key %q", key))
  60. return "", false
  61. }
  62. s.test.Values[key] = valueInt
  63. return key, true
  64. }
  65. func (s *testScanner) Scan() bool {
  66. s.test = test{
  67. Values: make(map[string]*big.Int),
  68. }
  69. // Scan until the first attribute.
  70. for {
  71. if !s.scanLine() {
  72. return false
  73. }
  74. if len(s.scanner.Text()) != 0 && s.scanner.Text()[0] != '#' {
  75. break
  76. }
  77. }
  78. var ok bool
  79. s.test.Type, ok = s.addAttribute(s.scanner.Text())
  80. if !ok {
  81. return false
  82. }
  83. s.test.LineNumber = s.lineNo
  84. for s.scanLine() {
  85. if len(s.scanner.Text()) == 0 {
  86. break
  87. }
  88. if s.scanner.Text()[0] == '#' {
  89. continue
  90. }
  91. if _, ok := s.addAttribute(s.scanner.Text()); !ok {
  92. return false
  93. }
  94. }
  95. return s.scanner.Err() == nil
  96. }
  97. func (s *testScanner) Test() test {
  98. return s.test
  99. }
  100. func (s *testScanner) Err() error {
  101. if s.err != nil {
  102. return s.err
  103. }
  104. return s.scanner.Err()
  105. }
  106. func (s *testScanner) setError(err error) {
  107. s.err = fmt.Errorf("line %d: %s", s.lineNo, err)
  108. }
  109. func checkKeys(t test, keys ...string) bool {
  110. var foundErrors bool
  111. for _, k := range keys {
  112. if _, ok := t.Values[k]; !ok {
  113. fmt.Fprintf(os.Stderr, "Line %d: missing key %q.\n", t.LineNumber, k)
  114. foundErrors = true
  115. }
  116. }
  117. for k, _ := range t.Values {
  118. var found bool
  119. for _, k2 := range keys {
  120. if k == k2 {
  121. found = true
  122. break
  123. }
  124. }
  125. if !found {
  126. fmt.Fprintf(os.Stderr, "Line %d: unexpected key %q.\n", t.LineNumber, k)
  127. foundErrors = true
  128. }
  129. }
  130. return !foundErrors
  131. }
  132. func checkResult(t test, expr, key string, r *big.Int) {
  133. if t.Values[key].Cmp(r) != 0 {
  134. fmt.Fprintf(os.Stderr, "Line %d: %s did not match %s.\n\tGot %s\n", t.LineNumber, expr, key, r.Text(16))
  135. }
  136. }
  137. func main() {
  138. if len(os.Args) != 2 {
  139. fmt.Fprintf(os.Stderr, "Usage: %s bn_tests.txt\n", os.Args[0])
  140. os.Exit(1)
  141. }
  142. in, err := os.Open(os.Args[1])
  143. if err != nil {
  144. fmt.Fprintf(os.Stderr, "Error opening %s: %s.\n", os.Args[0], err)
  145. os.Exit(1)
  146. }
  147. defer in.Close()
  148. scanner := newTestScanner(in)
  149. for scanner.Scan() {
  150. test := scanner.Test()
  151. switch test.Type {
  152. case "Sum":
  153. if checkKeys(test, "A", "B", "Sum") {
  154. r := new(big.Int).Add(test.Values["A"], test.Values["B"])
  155. checkResult(test, "A + B", "Sum", r)
  156. }
  157. case "LShift1":
  158. if checkKeys(test, "A", "LShift1") {
  159. r := new(big.Int).Add(test.Values["A"], test.Values["A"])
  160. checkResult(test, "A + A", "LShift1", r)
  161. }
  162. case "LShift":
  163. if checkKeys(test, "A", "N", "LShift") {
  164. r := new(big.Int).Lsh(test.Values["A"], uint(test.Values["N"].Uint64()))
  165. checkResult(test, "A << N", "LShift", r)
  166. }
  167. case "RShift":
  168. if checkKeys(test, "A", "N", "RShift") {
  169. r := new(big.Int).Rsh(test.Values["A"], uint(test.Values["N"].Uint64()))
  170. checkResult(test, "A >> N", "RShift", r)
  171. }
  172. case "Square":
  173. if checkKeys(test, "A", "Square") {
  174. r := new(big.Int).Mul(test.Values["A"], test.Values["A"])
  175. checkResult(test, "A * A", "Square", r)
  176. }
  177. case "Product":
  178. if checkKeys(test, "A", "B", "Product") {
  179. r := new(big.Int).Mul(test.Values["A"], test.Values["B"])
  180. checkResult(test, "A * B", "Product", r)
  181. }
  182. case "Quotient":
  183. if checkKeys(test, "A", "B", "Quotient", "Remainder") {
  184. q, r := new(big.Int).QuoRem(test.Values["A"], test.Values["B"], new(big.Int))
  185. checkResult(test, "A / B", "Quotient", q)
  186. checkResult(test, "A % B", "Remainder", r)
  187. }
  188. case "ModMul":
  189. if checkKeys(test, "A", "B", "M", "ModMul") {
  190. r := new(big.Int).Mul(test.Values["A"], test.Values["B"])
  191. r = r.Mod(r, test.Values["M"])
  192. checkResult(test, "A * B (mod M)", "ModMul", r)
  193. }
  194. case "ModExp":
  195. if checkKeys(test, "A", "E", "M", "ModExp") {
  196. r := new(big.Int).Exp(test.Values["A"], test.Values["E"], test.Values["M"])
  197. checkResult(test, "A ^ E (mod M)", "ModExp", r)
  198. }
  199. case "Exp":
  200. if checkKeys(test, "A", "E", "Exp") {
  201. r := new(big.Int).Exp(test.Values["A"], test.Values["E"], nil)
  202. checkResult(test, "A ^ E", "Exp", r)
  203. }
  204. case "ModSqrt":
  205. bigOne := new(big.Int).SetInt64(1)
  206. bigTwo := new(big.Int).SetInt64(2)
  207. if checkKeys(test, "A", "P", "ModSqrt") {
  208. test.Values["A"].Mod(test.Values["A"], test.Values["P"])
  209. r := new(big.Int).Mul(test.Values["ModSqrt"], test.Values["ModSqrt"])
  210. r = r.Mod(r, test.Values["P"])
  211. checkResult(test, "ModSqrt ^ 2 (mod P)", "A", r)
  212. if test.Values["P"].Cmp(bigTwo) > 0 {
  213. pMinus1Over2 := new(big.Int).Sub(test.Values["P"], bigOne)
  214. pMinus1Over2.Rsh(pMinus1Over2, 1)
  215. if test.Values["ModSqrt"].Cmp(pMinus1Over2) > 0 {
  216. fmt.Fprintf(os.Stderr, "Line %d: ModSqrt should be minimal.\n", test.LineNumber)
  217. }
  218. }
  219. }
  220. case "ModInv":
  221. if checkKeys(test, "A", "M", "ModInv") {
  222. r := new(big.Int).ModInverse(test.Values["A"], test.Values["M"])
  223. checkResult(test, "A ^ -1 (mod M)", "ModInv", r)
  224. }
  225. case "ModSquare":
  226. if checkKeys(test, "A", "M", "ModSquare") {
  227. r := new(big.Int).Mul(test.Values["A"], test.Values["A"])
  228. r = r.Mod(r, test.Values["M"])
  229. checkResult(test, "A * A (mod M)", "ModSquare", r)
  230. }
  231. case "NotModSquare":
  232. if checkKeys(test, "P", "NotModSquare") {
  233. if new(big.Int).ModSqrt(test.Values["NotModSquare"], test.Values["P"]) != nil {
  234. fmt.Fprintf(os.Stderr, "Line %d: value was a square.\n", test.LineNumber)
  235. }
  236. }
  237. case "GCD":
  238. if checkKeys(test, "A", "B", "GCD", "LCM") {
  239. a := test.Values["A"]
  240. b := test.Values["B"]
  241. // Go's GCD function does not accept zero, unlike OpenSSL.
  242. var g *big.Int
  243. if a.Sign() == 0 {
  244. g = b
  245. } else if b.Sign() == 0 {
  246. g = a
  247. } else {
  248. g = new(big.Int).GCD(nil, nil, a, b)
  249. }
  250. checkResult(test, "GCD(A, B)", "GCD", g)
  251. if g.Sign() != 0 {
  252. lcm := new(big.Int).Mul(a, b)
  253. lcm = lcm.Div(lcm, g)
  254. checkResult(test, "LCM(A, B)", "LCM", lcm)
  255. }
  256. }
  257. default:
  258. fmt.Fprintf(os.Stderr, "Line %d: unknown test type %q.\n", test.LineNumber, test.Type)
  259. }
  260. }
  261. if scanner.Err() != nil {
  262. fmt.Fprintf(os.Stderr, "Error reading tests: %s.\n", scanner.Err())
  263. }
  264. }