luaunit.lua 98 KB


  1. --[[
  2. luaunit.lua
  3. Description: A unit testing framework
  4. Homepage: https://github.com/bluebird75/luaunit
  5. Development by Philippe Fremy <phil@freehackers.org>
  6. Based on initial work of Ryu, Gwang (http://www.gpgstudy.com/gpgiki/LuaUnit)
  7. License: BSD License, see LICENSE.txt
  8. Version: 3.2
  9. ]]--
  10. require("math")
  11. local M={}
  12. -- private exported functions (for testing)
  13. M.private = {}
  14. M.VERSION='3.2'
  15. M._VERSION=M.VERSION -- For LuaUnit v2 compatibility
  16. --[[ Some people like assertEquals( actual, expected ) and some people prefer
  17. assertEquals( expected, actual ).
  18. ]]--
  19. M.ORDER_ACTUAL_EXPECTED = true
  20. M.PRINT_TABLE_REF_IN_ERROR_MSG = false
  21. M.TABLE_EQUALS_KEYBYCONTENT = true
  22. M.LINE_LENGTH = 80
  23. M.TABLE_DIFF_ANALYSIS_THRESHOLD = 10 -- display deep analysis for more than 10 items
  24. M.LIST_DIFF_ANALYSIS_THRESHOLD = 10 -- display deep analysis for more than 10 items
  25. --[[ M.EPSILON is meant to help with Lua's floating point math in simple corner
  26. cases like almostEquals(1.1-0.1, 1), which may not work as-is (e.g. on numbers
  27. with rational binary representation) if the user doesn't provide some explicit
  28. error margin.
  29. The default margin used by almostEquals() in such cases is M.EPSILON; and since
  30. Lua may be compiled with different numeric precisions (single vs. double), we
  31. try to select a useful default for it dynamically. Note: If the initial value
  32. is not acceptable, it can be changed by the user to better suit specific needs.
  33. See also: https://en.wikipedia.org/wiki/Machine_epsilon
  34. ]]
  35. M.EPSILON = 2^-52 -- = machine epsilon for "double", ~2.22E-16
  36. if math.abs(1.1 - 1 - 0.1) > M.EPSILON then
  37. -- rounding error is above EPSILON, assume single precision
  38. M.EPSILON = 2^-23 -- = machine epsilon for "float", ~1.19E-07
  39. end
  40. -- set this to false to debug luaunit
  41. local STRIP_LUAUNIT_FROM_STACKTRACE = true
  42. M.VERBOSITY_DEFAULT = 10
  43. M.VERBOSITY_LOW = 1
  44. M.VERBOSITY_QUIET = 0
  45. M.VERBOSITY_VERBOSE = 20
  46. M.DEFAULT_DEEP_ANALYSIS = nil
  47. M.FORCE_DEEP_ANALYSIS = true
  48. M.DISABLE_DEEP_ANALYSIS = false
  49. -- set EXPORT_ASSERT_TO_GLOBALS to have all asserts visible as global values
  50. -- EXPORT_ASSERT_TO_GLOBALS = true
  51. -- we need to keep a copy of the script args before it is overriden
  52. local cmdline_argv = rawget(_G, "arg")
  53. M.FAILURE_PREFIX = 'LuaUnit test FAILURE: ' -- prefix string for failed tests
  54. M.USAGE=[[Usage: lua <your_test_suite.lua> [options] [testname1 [testname2] ... ]
  55. Options:
  56. -h, --help: Print this help
  57. --version: Print version information
  58. -v, --verbose: Increase verbosity
  59. -q, --quiet: Set verbosity to minimum
  60. -e, --error: Stop on first error
  61. -f, --failure: Stop on first failure or error
  62. -r, --random Run tests in random order
  63. -o, --output OUTPUT: Set output type to OUTPUT
  64. Possible values: text, tap, junit, nil
  65. -n, --name NAME: For junit only, mandatory name of xml file
  66. -c, --count NUM: Execute all tests NUM times, e.g. to trig the JIT
  67. -p, --pattern PATTERN: Execute all test names matching the Lua PATTERN
  68. May be repeated to include several patterns
  69. Make sure you escape magic chars like +? with %
  70. -x, --exclude PATTERN: Exclude all test names matching the Lua PATTERN
  71. May be repeated to exclude several patterns
  72. Make sure you escape magic chars like +? with %
  73. testname1, testname2, ... : tests to run in the form of testFunction,
  74. TestClass or TestClass.testMethod
  75. ]]
  76. local is_equal -- defined here to allow calling from mismatchFormattingPureList
  77. ----------------------------------------------------------------
  78. --
  79. -- general utility functions
  80. --
  81. ----------------------------------------------------------------
  82. local function pcall_or_abort(func, ...)
  83. -- unpack is a global function for Lua 5.1, otherwise use table.unpack
  84. local unpack = rawget(_G, "unpack") or table.unpack
  85. local result = {pcall(func, ...)}
  86. if not result[1] then
  87. -- an error occurred
  88. print(result[2]) -- error message
  89. print()
  90. print(M.USAGE)
  91. os.exit(-1)
  92. end
  93. return unpack(result, 2)
  94. end
  95. local crossTypeOrdering = {
  96. number = 1, boolean = 2, string = 3, table = 4, other = 5
  97. }
  98. local crossTypeComparison = {
  99. number = function(a, b) return a < b end,
  100. string = function(a, b) return a < b end,
  101. other = function(a, b) return tostring(a) < tostring(b) end,
  102. }
  103. local function crossTypeSort(a, b)
  104. local type_a, type_b = type(a), type(b)
  105. if type_a == type_b then
  106. local func = crossTypeComparison[type_a] or crossTypeComparison.other
  107. return func(a, b)
  108. end
  109. type_a = crossTypeOrdering[type_a] or crossTypeOrdering.other
  110. type_b = crossTypeOrdering[type_b] or crossTypeOrdering.other
  111. return type_a < type_b
  112. end
  113. local function __genSortedIndex( t )
  114. -- Returns a sequence consisting of t's keys, sorted.
  115. local sortedIndex = {}
  116. for key,_ in pairs(t) do
  117. table.insert(sortedIndex, key)
  118. end
  119. table.sort(sortedIndex, crossTypeSort)
  120. return sortedIndex
  121. end
  122. M.private.__genSortedIndex = __genSortedIndex
  123. local function sortedNext(state, control)
  124. -- Equivalent of the next() function of table iteration, but returns the
  125. -- keys in sorted order (see __genSortedIndex and crossTypeSort).
  126. -- The state is a temporary variable during iteration and contains the
  127. -- sorted key table (state.sortedIdx). It also stores the last index (into
  128. -- the keys) used by the iteration, to find the next one quickly.
  129. local key
  130. --print("sortedNext: control = "..tostring(control) )
  131. if control == nil then
  132. -- start of iteration
  133. state.count = #state.sortedIdx
  134. state.lastIdx = 1
  135. key = state.sortedIdx[1]
  136. return key, state.t[key]
  137. end
  138. -- normally, we expect the control variable to match the last key used
  139. if control ~= state.sortedIdx[state.lastIdx] then
  140. -- strange, we have to find the next value by ourselves
  141. -- the key table is sorted in crossTypeSort() order! -> use bisection
  142. local lower, upper = 1, state.count
  143. repeat
  144. state.lastIdx = math.modf((lower + upper) / 2)
  145. key = state.sortedIdx[state.lastIdx]
  146. if key == control then
  147. break -- key found (and thus prev index)
  148. end
  149. if crossTypeSort(key, control) then
  150. -- key < control, continue search "right" (towards upper bound)
  151. lower = state.lastIdx + 1
  152. else
  153. -- key > control, continue search "left" (towards lower bound)
  154. upper = state.lastIdx - 1
  155. end
  156. until lower > upper
  157. if lower > upper then -- only true if the key wasn't found, ...
  158. state.lastIdx = state.count -- ... so ensure no match in code below
  159. end
  160. end
  161. -- proceed by retrieving the next value (or nil) from the sorted keys
  162. state.lastIdx = state.lastIdx + 1
  163. key = state.sortedIdx[state.lastIdx]
  164. if key then
  165. return key, state.t[key]
  166. end
  167. -- getting here means returning `nil`, which will end the iteration
  168. end
  169. local function sortedPairs(tbl)
  170. -- Equivalent of the pairs() function on tables. Allows to iterate in
  171. -- sorted order. As required by "generic for" loops, this will return the
  172. -- iterator (function), an "invariant state", and the initial control value.
  173. -- (see http://www.lua.org/pil/7.2.html)
  174. return sortedNext, {t = tbl, sortedIdx = __genSortedIndex(tbl)}, nil
  175. end
  176. M.private.sortedPairs = sortedPairs
  177. -- seed the random with a strongly varying seed
  178. math.randomseed(os.clock()*1E11)
  179. local function randomizeTable( t )
  180. -- randomize the item orders of the table t
  181. for i = #t, 2, -1 do
  182. local j = math.random(i)
  183. if i ~= j then
  184. t[i], t[j] = t[j], t[i]
  185. end
  186. end
  187. end
  188. M.private.randomizeTable = randomizeTable
  189. local function strsplit(delimiter, text)
  190. -- Split text into a list consisting of the strings in text, separated
  191. -- by strings matching delimiter (which may _NOT_ be a pattern).
  192. -- Example: strsplit(", ", "Anna, Bob, Charlie, Dolores")
  193. if delimiter == "" then -- this would result in endless loops
  194. error("delimiter matches empty string!")
  195. end
  196. local list, pos, first, last = {}, 1
  197. while true do
  198. first, last = text:find(delimiter, pos, true)
  199. if first then -- found?
  200. table.insert(list, text:sub(pos, first - 1))
  201. pos = last + 1
  202. else
  203. table.insert(list, text:sub(pos))
  204. break
  205. end
  206. end
  207. return list
  208. end
  209. M.private.strsplit = strsplit
  210. local function hasNewLine( s )
  211. -- return true if s has a newline
  212. return (string.find(s, '\n', 1, true) ~= nil)
  213. end
  214. M.private.hasNewLine = hasNewLine
  215. local function prefixString( prefix, s )
  216. -- Prefix all the lines of s with prefix
  217. return prefix .. string.gsub(s, '\n', '\n' .. prefix)
  218. end
  219. M.private.prefixString = prefixString
  220. local function strMatch(s, pattern, start, final )
  221. -- return true if s matches completely the pattern from index start to index end
  222. -- return false in every other cases
  223. -- if start is nil, matches from the beginning of the string
  224. -- if final is nil, matches to the end of the string
  225. start = start or 1
  226. final = final or string.len(s)
  227. local foundStart, foundEnd = string.find(s, pattern, start, false)
  228. return foundStart == start and foundEnd == final
  229. end
  230. M.private.strMatch = strMatch
  231. local function patternFilter(patterns, expr)
  232. -- Run `expr` through the inclusion and exclusion rules defined in patterns
  233. -- and return true if expr shall be included, false for excluded.
  234. -- Inclusion pattern are defined as normal patterns, exclusions
  235. -- patterns start with `!` and are followed by a normal pattern
  236. -- result: nil = UNKNOWN (not matched yet), true = ACCEPT, false = REJECT
  237. -- default: true if no explicit "include" is found, set to false otherwise
  238. local default, result = true, nil
  239. if patterns ~= nil then
  240. for _, pattern in ipairs(patterns) do
  241. local exclude = pattern:sub(1,1) == '!'
  242. if exclude then
  243. pattern = pattern:sub(2)
  244. else
  245. -- at least one include pattern specified, a match is required
  246. default = false
  247. end
  248. -- print('pattern: ',pattern)
  249. -- print('exclude: ',exclude)
  250. -- print('default: ',default)
  251. if string.find(expr, pattern) then
  252. -- set result to false when excluding, true otherwise
  253. result = not exclude
  254. end
  255. end
  256. end
  257. if result ~= nil then
  258. return result
  259. end
  260. return default
  261. end
  262. M.private.patternFilter = patternFilter
  263. local function xmlEscape( s )
  264. -- Return s escaped for XML attributes
  265. -- escapes table:
  266. -- " &quot;
  267. -- ' &apos;
  268. -- < &lt;
  269. -- > &gt;
  270. -- & &amp;
  271. return string.gsub( s, '.', {
  272. ['&'] = "&amp;",
  273. ['"'] = "&quot;",
  274. ["'"] = "&apos;",
  275. ['<'] = "&lt;",
  276. ['>'] = "&gt;",
  277. } )
  278. end
  279. M.private.xmlEscape = xmlEscape
  280. local function xmlCDataEscape( s )
  281. -- Return s escaped for CData section, escapes: "]]>"
  282. return string.gsub( s, ']]>', ']]&gt;' )
  283. end
  284. M.private.xmlCDataEscape = xmlCDataEscape
  285. local function stripLuaunitTrace( stackTrace )
  286. --[[
  287. -- Example of a traceback:
  288. <<stack traceback:
  289. example_with_luaunit.lua:130: in function 'test2_withFailure'
  290. ./luaunit.lua:1449: in function <./luaunit.lua:1449>
  291. [C]: in function 'xpcall'
  292. ./luaunit.lua:1449: in function 'protectedCall'
  293. ./luaunit.lua:1508: in function 'execOneFunction'
  294. ./luaunit.lua:1596: in function 'runSuiteByInstances'
  295. ./luaunit.lua:1660: in function 'runSuiteByNames'
  296. ./luaunit.lua:1736: in function 'runSuite'
  297. example_with_luaunit.lua:140: in main chunk
  298. [C]: in ?>>
  299. Other example:
  300. <<stack traceback:
  301. ./luaunit.lua:545: in function 'assertEquals'
  302. example_with_luaunit.lua:58: in function 'TestToto.test7'
  303. ./luaunit.lua:1517: in function <./luaunit.lua:1517>
  304. [C]: in function 'xpcall'
  305. ./luaunit.lua:1517: in function 'protectedCall'
  306. ./luaunit.lua:1578: in function 'execOneFunction'
  307. ./luaunit.lua:1677: in function 'runSuiteByInstances'
  308. ./luaunit.lua:1730: in function 'runSuiteByNames'
  309. ./luaunit.lua:1806: in function 'runSuite'
  310. example_with_luaunit.lua:140: in main chunk
  311. [C]: in ?>>
  312. <<stack traceback:
  313. luaunit2/example_with_luaunit.lua:124: in function 'test1_withFailure'
  314. luaunit2/luaunit.lua:1532: in function <luaunit2/luaunit.lua:1532>
  315. [C]: in function 'xpcall'
  316. luaunit2/luaunit.lua:1532: in function 'protectedCall'
  317. luaunit2/luaunit.lua:1591: in function 'execOneFunction'
  318. luaunit2/luaunit.lua:1679: in function 'runSuiteByInstances'
  319. luaunit2/luaunit.lua:1743: in function 'runSuiteByNames'
  320. luaunit2/luaunit.lua:1819: in function 'runSuite'
  321. luaunit2/example_with_luaunit.lua:140: in main chunk
  322. [C]: in ?>>
  323. -- first line is "stack traceback": KEEP
  324. -- next line may be luaunit line: REMOVE
  325. -- next lines are call in the program under testOk: REMOVE
  326. -- next lines are calls from luaunit to call the program under test: KEEP
  327. -- Strategy:
  328. -- keep first line
  329. -- remove lines that are part of luaunit
  330. -- kepp lines until we hit a luaunit line
  331. ]]
  332. local function isLuaunitInternalLine( s )
  333. -- return true if line of stack trace comes from inside luaunit
  334. return s:find('[/\\]luaunit%.lua:%d+: ') ~= nil
  335. end
  336. -- print( '<<'..stackTrace..'>>' )
  337. local t = strsplit( '\n', stackTrace )
  338. -- print( prettystr(t) )
  339. local idx = 2
  340. -- remove lines that are still part of luaunit
  341. while t[idx] and isLuaunitInternalLine( t[idx] ) do
  342. -- print('Removing : '..t[idx] )
  343. table.remove(t, idx)
  344. end
  345. -- keep lines until we hit luaunit again
  346. while t[idx] and (not isLuaunitInternalLine(t[idx])) do
  347. -- print('Keeping : '..t[idx] )
  348. idx = idx + 1
  349. end
  350. -- remove remaining luaunit lines
  351. while t[idx] do
  352. -- print('Removing : '..t[idx] )
  353. table.remove(t, idx)
  354. end
  355. -- print( prettystr(t) )
  356. return table.concat( t, '\n')
  357. end
  358. M.private.stripLuaunitTrace = stripLuaunitTrace
  359. local function prettystr_sub(v, indentLevel, keeponeline, printTableRefs, recursionTable )
  360. local type_v = type(v)
  361. if "string" == type_v then
  362. if keeponeline then
  363. v = v:gsub("\n", "\\n") -- escape newline(s)
  364. end
  365. -- use clever delimiters according to content:
  366. -- enclose with single quotes if string contains ", but no '
  367. if v:find('"', 1, true) and not v:find("'", 1, true) then
  368. return "'" .. v .. "'"
  369. end
  370. -- use double quotes otherwise, escape embedded "
  371. return '"' .. v:gsub('"', '\\"') .. '"'
  372. elseif "table" == type_v then
  373. --if v.__class__ then
  374. -- return string.gsub( tostring(v), 'table', v.__class__ )
  375. --end
  376. return M.private._table_tostring(v, indentLevel, keeponeline,
  377. printTableRefs, recursionTable)
  378. elseif "number" == type_v then
  379. -- eliminate differences in formatting between various Lua versions
  380. if v ~= v then
  381. return "#NaN" -- "not a number"
  382. end
  383. if v == math.huge then
  384. return "#Inf" -- "infinite"
  385. end
  386. if v == -math.huge then
  387. return "-#Inf"
  388. end
  389. if _VERSION == "Lua 5.3" then
  390. local i = math.tointeger(v)
  391. if i then
  392. return tostring(i)
  393. end
  394. end
  395. end
  396. return tostring(v)
  397. end
  398. local function prettystr( v, keeponeline )
  399. --[[ Better string conversion, to display nice variable content:
  400. For strings, if keeponeline is set to true, string is displayed on one line, with visible \n
  401. * string are enclosed with " by default, or with ' if string contains a "
  402. * if table is a class, display class name
  403. * tables are expanded
  404. ]]--
  405. local recursionTable = {}
  406. local s = prettystr_sub(v, 1, keeponeline, M.PRINT_TABLE_REF_IN_ERROR_MSG, recursionTable)
  407. if recursionTable.recursionDetected and not M.PRINT_TABLE_REF_IN_ERROR_MSG then
  408. -- some table contain recursive references,
  409. -- so we must recompute the value by including all table references
  410. -- else the result looks like crap
  411. recursionTable = {}
  412. s = prettystr_sub(v, 1, keeponeline, true, recursionTable)
  413. end
  414. return s
  415. end
  416. M.prettystr = prettystr
  417. local function tryMismatchFormatting( table_a, table_b, doDeepAnalysis )
  418. --[[
  419. Prepares a nice error message when comparing tables, performing a deeper
  420. analysis.
  421. Arguments:
  422. * table_a, table_b: tables to be compared
  423. * doDeepAnalysis:
  424. M.DEFAULT_DEEP_ANALYSIS: (the default if not specified) perform deep analysis only for big lists and big dictionnaries
  425. M.FORCE_DEEP_ANALYSIS : always perform deep analysis
  426. M.DISABLE_DEEP_ANALYSIS: never perform deep analysis
  427. Returns: {success, result}
  428. * success: false if deep analysis could not be performed
  429. in this case, just use standard assertion message
  430. * result: if success is true, a multi-line string with deep analysis of the two lists
  431. ]]
  432. -- check if table_a & table_b are suitable for deep analysis
  433. if type(table_a) ~= 'table' or type(table_b) ~= 'table' then
  434. return false
  435. end
  436. if doDeepAnalysis == M.DISABLE_DEEP_ANALYSIS then
  437. return false
  438. end
  439. local len_a, len_b, isPureList = #table_a, #table_b, true
  440. for k1, v1 in pairs(table_a) do
  441. if type(k1) ~= 'number' or k1 > len_a then
  442. -- this table a mapping
  443. isPureList = false
  444. break
  445. end
  446. end
  447. if isPureList then
  448. for k2, v2 in pairs(table_b) do
  449. if type(k2) ~= 'number' or k2 > len_b then
  450. -- this table a mapping
  451. isPureList = false
  452. break
  453. end
  454. end
  455. end
  456. if isPureList and math.min(len_a, len_b) < M.LIST_DIFF_ANALYSIS_THRESHOLD then
  457. if not (doDeepAnalysis == M.FORCE_DEEP_ANALYSIS) then
  458. return false
  459. end
  460. end
  461. if isPureList then
  462. return M.private.mismatchFormattingPureList( table_a, table_b )
  463. else
  464. -- only work on mapping for the moment
  465. -- return M.private.mismatchFormattingMapping( table_a, table_b, doDeepAnalysis )
  466. return false
  467. end
  468. end
  469. M.private.tryMismatchFormatting = tryMismatchFormatting
  470. local function getTaTbDescr()
  471. if not M.ORDER_ACTUAL_EXPECTED then
  472. return 'expected', 'actual'
  473. end
  474. return 'actual', 'expected'
  475. end
  476. local function extendWithStrFmt( res, ... )
  477. table.insert( res, string.format( ... ) )
  478. end
  479. local function mismatchFormattingMapping( table_a, table_b, doDeepAnalysis )
  480. --[[
  481. Prepares a nice error message when comparing tables which are not pure lists, performing a deeper
  482. analysis.
  483. Returns: {success, result}
  484. * success: false if deep analysis could not be performed
  485. in this case, just use standard assertion message
  486. * result: if success is true, a multi-line string with deep analysis of the two lists
  487. ]]
  488. -- disable for the moment
  489. --[[
  490. local result = {}
  491. local descrTa, descrTb = getTaTbDescr()
  492. local keysCommon = {}
  493. local keysOnlyTa = {}
  494. local keysOnlyTb = {}
  495. local keysDiffTaTb = {}
  496. local k, v
  497. for k,v in pairs( table_a ) do
  498. if is_equal( v, table_b[k] ) then
  499. table.insert( keysCommon, k )
  500. else
  501. if table_b[k] == nil then
  502. table.insert( keysOnlyTa, k )
  503. else
  504. table.insert( keysDiffTaTb, k )
  505. end
  506. end
  507. end
  508. for k,v in pairs( table_b ) do
  509. if not is_equal( v, table_a[k] ) and table_a[k] == nil then
  510. table.insert( keysOnlyTb, k )
  511. end
  512. end
  513. local len_a = #keysCommon + #keysDiffTaTb + #keysOnlyTa
  514. local len_b = #keysCommon + #keysDiffTaTb + #keysOnlyTb
  515. local limited_display = (len_a < 5 or len_b < 5)
  516. if math.min(len_a, len_b) < M.TABLE_DIFF_ANALYSIS_THRESHOLD then
  517. return false
  518. end
  519. if not limited_display then
  520. if len_a == len_b then
  521. extendWithStrFmt( result, 'Table A (%s) and B (%s) both have %d items', descrTa, descrTb, len_a )
  522. else
  523. extendWithStrFmt( result, 'Table A (%s) has %d items and table B (%s) has %d items', descrTa, len_a, descrTb, len_b )
  524. end
  525. if #keysCommon == 0 and #keysDiffTaTb == 0 then
  526. table.insert( result, 'Table A and B have no keys in common, they are totally different')
  527. else
  528. local s_other = 'other '
  529. if #keysCommon then
  530. extendWithStrFmt( result, 'Table A and B have %d identical items', #keysCommon )
  531. else
  532. table.insert( result, 'Table A and B have no identical items' )
  533. s_other = ''
  534. end
  535. if #keysDiffTaTb ~= 0 then
  536. result[#result] = string.format( '%s and %d items differing present in both tables', result[#result], #keysDiffTaTb)
  537. else
  538. result[#result] = string.format( '%s and no %sitems differing present in both tables', result[#result], s_other, #keysDiffTaTb)
  539. end
  540. end
  541. extendWithStrFmt( result, 'Table A has %d keys not present in table B and table B has %d keys not present in table A', #keysOnlyTa, #keysOnlyTb )
  542. end
  543. local function keytostring(k)
  544. if "string" == type(k) and k:match("^[_%a][_%w]*$") then
  545. return k
  546. end
  547. return prettystr(k)
  548. end
  549. if #keysDiffTaTb ~= 0 then
  550. table.insert( result, 'Items differing in A and B:')
  551. for k,v in sortedPairs( keysDiffTaTb ) do
  552. extendWithStrFmt( result, ' - A[%s]: %s', keytostring(v), prettystr(table_a[v]) )
  553. extendWithStrFmt( result, ' + B[%s]: %s', keytostring(v), prettystr(table_b[v]) )
  554. end
  555. end
  556. if #keysOnlyTa ~= 0 then
  557. table.insert( result, 'Items only in table A:' )
  558. for k,v in sortedPairs( keysOnlyTa ) do
  559. extendWithStrFmt( result, ' - A[%s]: %s', keytostring(v), prettystr(table_a[v]) )
  560. end
  561. end
  562. if #keysOnlyTb ~= 0 then
  563. table.insert( result, 'Items only in table B:' )
  564. for k,v in sortedPairs( keysOnlyTb ) do
  565. extendWithStrFmt( result, ' + B[%s]: %s', keytostring(v), prettystr(table_b[v]) )
  566. end
  567. end
  568. if #keysCommon ~= 0 then
  569. table.insert( result, 'Items common to A and B:')
  570. for k,v in sortedPairs( keysCommon ) do
  571. extendWithStrFmt( result, ' = A and B [%s]: %s', keytostring(v), prettystr(table_a[v]) )
  572. end
  573. end
  574. return true, table.concat( result, '\n')
  575. ]]
  576. end
  577. M.private.mismatchFormattingMapping = mismatchFormattingMapping
  578. local function mismatchFormattingPureList( table_a, table_b )
  579. --[[
  580. Prepares a nice error message when comparing tables which are lists, performing a deeper
  581. analysis.
  582. Returns: {success, result}
  583. * success: false if deep analysis could not be performed
  584. in this case, just use standard assertion message
  585. * result: if success is true, a multi-line string with deep analysis of the two lists
  586. ]]
  587. local result, descrTa, descrTb = {}, getTaTbDescr()
  588. local len_a, len_b, refa, refb = #table_a, #table_b, '', ''
  589. if M.PRINT_TABLE_REF_IN_ERROR_MSG then
  590. refa, refb = string.format( '<%s> ', tostring(table_a)), string.format('<%s> ', tostring(table_b) )
  591. end
  592. local longest, shortest = math.max(len_a, len_b), math.min(len_a, len_b)
  593. local deltalv = longest - shortest
  594. local commonUntil = shortest
  595. for i = 1, shortest do
  596. if not is_equal(table_a[i], table_b[i]) then
  597. commonUntil = i - 1
  598. break
  599. end
  600. end
  601. local commonBackTo = shortest - 1
  602. for i = 0, shortest - 1 do
  603. if not is_equal(table_a[len_a-i], table_b[len_b-i]) then
  604. commonBackTo = i - 1
  605. break
  606. end
  607. end
  608. table.insert( result, 'List difference analysis:' )
  609. if len_a == len_b then
  610. -- TODO: handle expected/actual naming
  611. extendWithStrFmt( result, '* lists %sA (%s) and %sB (%s) have the same size', refa, descrTa, refb, descrTb )
  612. else
  613. extendWithStrFmt( result, '* list sizes differ: list %sA (%s) has %d items, list %sB (%s) has %d items', refa, descrTa, len_a, refb, descrTb, len_b )
  614. end
  615. extendWithStrFmt( result, '* lists A and B start differing at index %d', commonUntil+1 )
  616. if commonBackTo >= 0 then
  617. if deltalv > 0 then
  618. extendWithStrFmt( result, '* lists A and B are equal again from index %d for A, %d for B', len_a-commonBackTo, len_b-commonBackTo )
  619. else
  620. extendWithStrFmt( result, '* lists A and B are equal again from index %d', len_a-commonBackTo )
  621. end
  622. end
  623. local function insertABValue(ai, bi)
  624. bi = bi or ai
  625. if is_equal( table_a[ai], table_b[bi]) then
  626. return extendWithStrFmt( result, ' = A[%d], B[%d]: %s', ai, bi, prettystr(table_a[ai]) )
  627. else
  628. extendWithStrFmt( result, ' - A[%d]: %s', ai, prettystr(table_a[ai]))
  629. extendWithStrFmt( result, ' + B[%d]: %s', bi, prettystr(table_b[bi]))
  630. end
  631. end
  632. -- common parts to list A & B, at the beginning
  633. if commonUntil > 0 then
  634. table.insert( result, '* Common parts:' )
  635. for i = 1, commonUntil do
  636. insertABValue( i )
  637. end
  638. end
  639. -- diffing parts to list A & B
  640. if commonUntil < shortest - commonBackTo - 1 then
  641. table.insert( result, '* Differing parts:' )
  642. for i = commonUntil + 1, shortest - commonBackTo - 1 do
  643. insertABValue( i )
  644. end
  645. end
  646. -- display indexes of one list, with no match on other list
  647. if shortest - commonBackTo <= longest - commonBackTo - 1 then
  648. table.insert( result, '* Present only in one list:' )
  649. for i = shortest - commonBackTo, longest - commonBackTo - 1 do
  650. if len_a > len_b then
  651. extendWithStrFmt( result, ' - A[%d]: %s', i, prettystr(table_a[i]) )
  652. -- table.insert( result, '+ (no matching B index)')
  653. else
  654. -- table.insert( result, '- no matching A index')
  655. extendWithStrFmt( result, ' + B[%d]: %s', i, prettystr(table_b[i]) )
  656. end
  657. end
  658. end
  659. -- common parts to list A & B, at the end
  660. if commonBackTo >= 0 then
  661. table.insert( result, '* Common parts at the end of the lists' )
  662. for i = longest - commonBackTo, longest do
  663. if len_a > len_b then
  664. insertABValue( i, i-deltalv )
  665. else
  666. insertABValue( i-deltalv, i )
  667. end
  668. end
  669. end
  670. return true, table.concat( result, '\n')
  671. end
  672. M.private.mismatchFormattingPureList = mismatchFormattingPureList
  673. local function prettystrPairs(value1, value2, suffix_a, suffix_b)
  674. --[[
  675. This function helps with the recurring task of constructing the "expected
  676. vs. actual" error messages. It takes two arbitrary values and formats
  677. corresponding strings with prettystr().
  678. To keep the (possibly complex) output more readable in case the resulting
  679. strings contain line breaks, they get automatically prefixed with additional
  680. newlines. Both suffixes are optional (default to empty strings), and get
  681. appended to the "value1" string. "suffix_a" is used if line breaks were
  682. encountered, "suffix_b" otherwise.
  683. Returns the two formatted strings (including padding/newlines).
  684. ]]
  685. local str1, str2 = prettystr(value1), prettystr(value2)
  686. if hasNewLine(str1) or hasNewLine(str2) then
  687. -- line break(s) detected, add padding
  688. return "\n" .. str1 .. (suffix_a or ""), "\n" .. str2
  689. end
  690. return str1 .. (suffix_b or ""), str2
  691. end
  692. M.private.prettystrPairs = prettystrPairs
  693. local TABLE_TOSTRING_SEP = ", "
  694. local TABLE_TOSTRING_SEP_LEN = string.len(TABLE_TOSTRING_SEP)
  695. local function _table_tostring( tbl, indentLevel, keeponeline, printTableRefs, recursionTable )
  696. printTableRefs = printTableRefs or M.PRINT_TABLE_REF_IN_ERROR_MSG
  697. recursionTable = recursionTable or {}
  698. recursionTable[tbl] = true
  699. local result, dispOnMultLines = {}, false
  700. -- like prettystr but do not enclose with "" if the string is just alphanumerical
  701. -- this is better for displaying table keys who are often simple strings
  702. local function keytostring(k)
  703. if "string" == type(k) and k:match("^[_%a][_%w]*$") then
  704. return k
  705. end
  706. return prettystr_sub(k, indentLevel+1, true, printTableRefs, recursionTable)
  707. end
  708. local entry, count, seq_index = nil, 0, 1
  709. for k, v in sortedPairs( tbl ) do
  710. if k == seq_index then
  711. -- for the sequential part of tables, we'll skip the "<key>=" output
  712. entry = ''
  713. seq_index = seq_index + 1
  714. elseif recursionTable[k] then
  715. -- recursion in the key detected
  716. recursionTable.recursionDetected = true
  717. entry = "<"..tostring(k)..">="
  718. else
  719. entry = keytostring(k) .. "="
  720. end
  721. if recursionTable[v] then
  722. -- recursion in the value detected!
  723. recursionTable.recursionDetected = true
  724. entry = entry .. "<"..tostring(v)..">"
  725. else
  726. entry = entry ..
  727. prettystr_sub( v, indentLevel+1, keeponeline, printTableRefs, recursionTable )
  728. end
  729. count = count + 1
  730. result[count] = entry
  731. end
  732. if not keeponeline then
  733. -- set dispOnMultLines if the maximum LINE_LENGTH would be exceeded
  734. local totalLength = 0
  735. for k, v in ipairs( result ) do
  736. totalLength = totalLength + string.len( v )
  737. if totalLength >= M.LINE_LENGTH then
  738. dispOnMultLines = true
  739. break
  740. end
  741. end
  742. if not dispOnMultLines then
  743. -- adjust with length of separator(s):
  744. -- two items need 1 sep, three items two seps, ... plus len of '{}'
  745. if count > 0 then
  746. totalLength = totalLength + TABLE_TOSTRING_SEP_LEN * (count - 1)
  747. end
  748. dispOnMultLines = totalLength + 2 >= M.LINE_LENGTH
  749. end
  750. end
  751. -- now reformat the result table (currently holding element strings)
  752. if dispOnMultLines then
  753. local indentString = string.rep(" ", indentLevel - 1)
  754. result = {"{\n ", indentString,
  755. table.concat(result, ",\n " .. indentString), "\n",
  756. indentString, "}"}
  757. else
  758. result = {"{", table.concat(result, TABLE_TOSTRING_SEP), "}"}
  759. end
  760. if printTableRefs then
  761. table.insert(result, 1, "<"..tostring(tbl).."> ") -- prepend table ref
  762. end
  763. return table.concat(result)
  764. end
  765. M.private._table_tostring = _table_tostring -- prettystr_sub() needs it
  766. local function _table_contains(t, element)
  767. if type(t) == "table" then
  768. local type_e = type(element)
  769. for _, value in pairs(t) do
  770. if type(value) == type_e then
  771. if value == element then
  772. return true
  773. end
  774. if type_e == 'table' then
  775. -- if we wanted recursive items content comparison, we could use
  776. -- _is_table_items_equals(v, expected) but one level of just comparing
  777. -- items is sufficient
  778. if M.private._is_table_equals( value, element ) then
  779. return true
  780. end
  781. end
  782. end
  783. end
  784. end
  785. return false
  786. end
  787. local function _is_table_items_equals(actual, expected )
  788. local type_a, type_e = type(actual), type(expected)
  789. if (type_a == 'table') and (type_e == 'table') then
  790. for k, v in pairs(actual) do
  791. if not _table_contains(expected, v) then
  792. return false
  793. end
  794. end
  795. for k, v in pairs(expected) do
  796. if not _table_contains(actual, v) then
  797. return false
  798. end
  799. end
  800. return true
  801. elseif type_a ~= type_e then
  802. return false
  803. elseif actual ~= expected then
  804. return false
  805. end
  806. return true
  807. end
  808. --[[
  809. This is a specialized metatable to help with the bookkeeping of recursions
  810. in _is_table_equals(). It provides an __index table that implements utility
  811. functions for easier management of the table. The "cached" method queries
  812. the state of a specific (actual,expected) pair; and the "store" method sets
  813. this state to the given value. The state of pairs not "seen" / visited is
  814. assumed to be `nil`.
  815. ]]
  816. local _recursion_cache_MT = {
  817. __index = {
  818. -- Return the cached value for an (actual,expected) pair (or `nil`)
  819. cached = function(t, actual, expected)
  820. local subtable = t[actual] or {}
  821. return subtable[expected]
  822. end,
  823. -- Store cached value for a specific (actual,expected) pair.
  824. -- Returns the value, so it's easy to use for a "tailcall" (return ...).
  825. store = function(t, actual, expected, value, asymmetric)
  826. local subtable = t[actual]
  827. if not subtable then
  828. subtable = {}
  829. t[actual] = subtable
  830. end
  831. subtable[expected] = value
  832. -- Unless explicitly marked "asymmetric": Consider the recursion
  833. -- on (expected,actual) to be equivalent to (actual,expected) by
  834. -- default, and thus cache the value for both.
  835. if not asymmetric then
  836. t:store(expected, actual, value, true)
  837. end
  838. return value
  839. end
  840. }
  841. }
  842. local function _is_table_equals(actual, expected, recursions)
  843. local type_a, type_e = type(actual), type(expected)
  844. recursions = recursions or setmetatable({}, _recursion_cache_MT)
  845. if type_a ~= type_e then
  846. return false -- different types won't match
  847. end
  848. if (type_a == 'table') --[[ and (type_e == 'table') ]] then
  849. if actual == expected then
  850. -- Both reference the same table, so they are actually identical
  851. return recursions:store(actual, expected, true)
  852. end
  853. -- If we've tested this (actual,expected) pair before: return cached value
  854. local previous = recursions:cached(actual, expected)
  855. if previous ~= nil then
  856. return previous
  857. end
  858. -- Mark this (actual,expected) pair, so we won't recurse it again. For
  859. -- now, assume a "false" result, which we might adjust later if needed.
  860. recursions:store(actual, expected, false)
  861. -- Tables must have identical element count, or they can't match.
  862. if (#actual ~= #expected) then
  863. return false
  864. end
  865. local actualKeysMatched, actualTableKeys = {}, {}
  866. for k, v in pairs(actual) do
  867. if M.TABLE_EQUALS_KEYBYCONTENT and type(k) == "table" then
  868. -- If the keys are tables, things get a bit tricky here as we
  869. -- can have _is_table_equals(t[k1], t[k2]) despite k1 ~= k2. So
  870. -- we first collect table keys from "actual", and then later try
  871. -- to match each table key from "expected" to actualTableKeys.
  872. table.insert(actualTableKeys, k)
  873. else
  874. if not _is_table_equals(v, expected[k], recursions) then
  875. return false -- Mismatch on value, tables can't be equal
  876. end
  877. actualKeysMatched[k] = true -- Keep track of matched keys
  878. end
  879. end
  880. for k, v in pairs(expected) do
  881. if M.TABLE_EQUALS_KEYBYCONTENT and type(k) == "table" then
  882. local found = false
  883. -- Note: DON'T use ipairs() here, table may be non-sequential!
  884. for i, candidate in pairs(actualTableKeys) do
  885. if _is_table_equals(candidate, k, recursions) then
  886. if _is_table_equals(actual[candidate], v, recursions) then
  887. found = true
  888. -- Remove the candidate we matched against from the list
  889. -- of table keys, so each key in actual can only match
  890. -- one key in expected.
  891. actualTableKeys[i] = nil
  892. break
  893. end
  894. -- keys match but values don't, keep searching
  895. end
  896. end
  897. if not found then
  898. return false -- no matching (key,value) pair
  899. end
  900. else
  901. if not actualKeysMatched[k] then
  902. -- Found a key that we did not see in "actual" -> mismatch
  903. return false
  904. end
  905. -- Otherwise actual[k] was already matched against v = expected[k].
  906. end
  907. end
  908. if next(actualTableKeys) then
  909. -- If there is any key left in actualTableKeys, then that is
  910. -- a table-type key in actual with no matching counterpart
  911. -- (in expected), and so the tables aren't equal.
  912. return false
  913. end
  914. -- The tables are actually considered equal, update cache and return result
  915. return recursions:store(actual, expected, true)
  916. elseif actual ~= expected then
  917. return false
  918. end
  919. return true
  920. end
-- Publish the deep comparison in M.private. _table_contains() reaches it
-- there because it is defined before this function.
M.private._is_table_equals = _is_table_equals
-- NOTE(review): there is no `local` here, so this assigns a *global* named
-- is_equal, leaking into the caller's environment. Presumably it is kept for
-- backward compatibility; confirm nobody depends on it before removing.
is_equal = _is_table_equals
  923. local function failure(msg, level)
  924. -- raise an error indicating a test failure
  925. -- for error() compatibility we adjust "level" here (by +1), to report the
  926. -- calling context
  927. error(M.FAILURE_PREFIX .. msg, (level or 1) + 1)
  928. end
  929. local function fail_fmt(level, ...)
  930. -- failure with printf-style formatted message and given error level
  931. failure(string.format(...), (level or 1) + 1)
  932. end
-- Also publish fail_fmt() in M.private, the table of private exported functions.
M.private.fail_fmt = fail_fmt
  934. local function error_fmt(level, ...)
  935. -- printf-style error()
  936. error(string.format(...), (level or 1) + 1)
  937. end
  938. ----------------------------------------------------------------
  939. --
  940. -- assertions
  941. --
  942. ----------------------------------------------------------------
  943. local function errorMsgEquality(actual, expected, doDeepAnalysis)
  944. if not M.ORDER_ACTUAL_EXPECTED then
  945. expected, actual = actual, expected
  946. end
  947. if type(expected) == 'string' or type(expected) == 'table' then
  948. local strExpected, strActual = prettystrPairs(expected, actual)
  949. local result = string.format("expected: %s\nactual: %s", strExpected, strActual)
  950. -- extend with mismatch analysis if possible:
  951. local success, mismatchResult
  952. success, mismatchResult = tryMismatchFormatting( actual, expected, doDeepAnalysis )
  953. if success then
  954. result = table.concat( { result, mismatchResult }, '\n' )
  955. end
  956. return result
  957. end
  958. return string.format("expected: %s, actual: %s",
  959. prettystr(expected), prettystr(actual))
  960. end
  961. function M.assertError(f, ...)
  962. -- assert that calling f with the arguments will raise an error
  963. -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error
  964. if pcall( f, ... ) then
  965. failure( "Expected an error when calling function but no error generated", 2 )
  966. end
  967. end
  968. function M.assertEvalToTrue(value)
  969. if not value then
  970. failure("expected: a value evaluating to true, actual: " ..prettystr(value), 2)
  971. end
  972. end
  973. function M.assertEvalToFalse(value)
  974. if value then
  975. failure("expected: false or nil, actual: " ..prettystr(value), 2)
  976. end
  977. end
  978. function M.assertIsTrue(value)
  979. if value ~= true then
  980. failure("expected: true, actual: " ..prettystr(value), 2)
  981. end
  982. end
  983. function M.assertNotIsTrue(value)
  984. if value == true then
  985. failure("expected: anything but true, actual: " ..prettystr(value), 2)
  986. end
  987. end
  988. function M.assertIsFalse(value)
  989. if value ~= false then
  990. failure("expected: false, actual: " ..prettystr(value), 2)
  991. end
  992. end
  993. function M.assertNotIsFalse(value)
  994. if value == false then
  995. failure("expected: anything but false, actual: " ..prettystr(value), 2)
  996. end
  997. end
  998. function M.assertIsNil(value)
  999. if value ~= nil then
  1000. failure("expected: nil, actual: " ..prettystr(value), 2)
  1001. end
  1002. end
  1003. function M.assertNotIsNil(value)
  1004. if value == nil then
  1005. failure("expected non nil value, received nil", 2)
  1006. end
  1007. end
  1008. function M.assertIsNaN(value)
  1009. if type(value) ~= "number" or value == value then
  1010. failure("expected: nan, actual: " ..prettystr(value), 2)
  1011. end
  1012. end
  1013. function M.assertNotIsNaN(value)
  1014. if type(value) == "number" and value ~= value then
  1015. failure("expected non nan value, received nan", 2)
  1016. end
  1017. end
  1018. function M.assertIsInf(value)
  1019. if type(value) ~= "number" or math.abs(value) ~= math.huge then
  1020. failure("expected: inf, actual: " ..prettystr(value), 2)
  1021. end
  1022. end
  1023. function M.assertNotIsInf(value)
  1024. if type(value) == "number" and math.abs(value) == math.huge then
  1025. failure("expected non inf value, received ±inf", 2)
  1026. end
  1027. end
  1028. function M.assertEquals(actual, expected, doDeepAnalysis)
  1029. if type(actual) == 'table' and type(expected) == 'table' then
  1030. if not _is_table_equals(actual, expected) then
  1031. failure( errorMsgEquality(actual, expected, doDeepAnalysis), 2 )
  1032. end
  1033. elseif type(actual) ~= type(expected) then
  1034. failure( errorMsgEquality(actual, expected), 2 )
  1035. elseif actual ~= expected then
  1036. failure( errorMsgEquality(actual, expected), 2 )
  1037. end
  1038. end
  1039. function M.almostEquals( actual, expected, margin )
  1040. if type(actual) ~= 'number' or type(expected) ~= 'number' or type(margin) ~= 'number' then
  1041. error_fmt(3, 'almostEquals: must supply only number arguments.\nArguments supplied: %s, %s, %s',
  1042. prettystr(actual), prettystr(expected), prettystr(margin))
  1043. end
  1044. if margin < 0 then
  1045. error('almostEquals: margin must not be negative, current value is ' .. margin, 3)
  1046. end
  1047. return math.abs(expected - actual) <= margin
  1048. end
  1049. function M.assertAlmostEquals( actual, expected, margin )
  1050. -- check that two floats are close by margin
  1051. margin = margin or M.EPSILON
  1052. if not M.almostEquals(actual, expected, margin) then
  1053. if not M.ORDER_ACTUAL_EXPECTED then
  1054. expected, actual = actual, expected
  1055. end
  1056. local delta = math.abs(actual - expected) - margin
  1057. fail_fmt(2, 'Values are not almost equal\n' ..
  1058. 'Actual: %s, expected: %s with margin of %s; delta: %s',
  1059. actual, expected, margin, delta)
  1060. end
  1061. end
  1062. function M.assertNotEquals(actual, expected)
  1063. if type(actual) ~= type(expected) then
  1064. return
  1065. end
  1066. if type(actual) == 'table' and type(expected) == 'table' then
  1067. if not _is_table_equals(actual, expected) then
  1068. return
  1069. end
  1070. elseif actual ~= expected then
  1071. return
  1072. end
  1073. fail_fmt(2, 'Received the not expected value: %s', prettystr(actual))
  1074. end
  1075. function M.assertNotAlmostEquals( actual, expected, margin )
  1076. -- check that two floats are not close by margin
  1077. margin = margin or M.EPSILON
  1078. if M.almostEquals(actual, expected, margin) then
  1079. if not M.ORDER_ACTUAL_EXPECTED then
  1080. expected, actual = actual, expected
  1081. end
  1082. local delta = margin - math.abs(actual - expected)
  1083. fail_fmt(2, 'Values are almost equal\nActual: %s, expected: %s' ..
  1084. ' with a difference above margin of %s; delta: %s',
  1085. actual, expected, margin, delta)
  1086. end
  1087. end
  1088. function M.assertStrContains( str, sub, useRe )
  1089. -- this relies on lua string.find function
  1090. -- a string always contains the empty string
  1091. if not string.find(str, sub, 1, not useRe) then
  1092. sub, str = prettystrPairs(sub, str, '\n')
  1093. fail_fmt(2, 'Error, %s %s was not found in string %s',
  1094. useRe and 'regexp' or 'substring', sub, str)
  1095. end
  1096. end
  1097. function M.assertStrIContains( str, sub )
  1098. -- this relies on lua string.find function
  1099. -- a string always contains the empty string
  1100. if not string.find(str:lower(), sub:lower(), 1, true) then
  1101. sub, str = prettystrPairs(sub, str, '\n')
  1102. fail_fmt(2, 'Error, substring %s was not found (case insensitively) in string %s',
  1103. sub, str)
  1104. end
  1105. end
  1106. function M.assertNotStrContains( str, sub, useRe )
  1107. -- this relies on lua string.find function
  1108. -- a string always contains the empty string
  1109. if string.find(str, sub, 1, not useRe) then
  1110. sub, str = prettystrPairs(sub, str, '\n')
  1111. fail_fmt(2, 'Error, %s %s was found in string %s',
  1112. useRe and 'regexp' or 'substring', sub, str)
  1113. end
  1114. end
  1115. function M.assertNotStrIContains( str, sub )
  1116. -- this relies on lua string.find function
  1117. -- a string always contains the empty string
  1118. if string.find(str:lower(), sub:lower(), 1, true) then
  1119. sub, str = prettystrPairs(sub, str, '\n')
  1120. fail_fmt(2, 'Error, substring %s was found (case insensitively) in string %s',
  1121. sub, str)
  1122. end
  1123. end
  1124. function M.assertStrMatches( str, pattern, start, final )
  1125. -- Verify a full match for the string
  1126. -- for a partial match, simply use assertStrContains with useRe set to true
  1127. if not strMatch( str, pattern, start, final ) then
  1128. pattern, str = prettystrPairs(pattern, str, '\n')
  1129. fail_fmt(2, 'Error, pattern %s was not matched by string %s',
  1130. pattern, str)
  1131. end
  1132. end
  1133. function M.assertErrorMsgEquals( expectedMsg, func, ... )
  1134. -- assert that calling f with the arguments will raise an error
  1135. -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error
  1136. local no_error, error_msg = pcall( func, ... )
  1137. if no_error then
  1138. failure( 'No error generated when calling function but expected error: "'..expectedMsg..'"', 2 )
  1139. end
  1140. if error_msg ~= expectedMsg then
  1141. error_msg, expectedMsg = prettystrPairs(error_msg, expectedMsg)
  1142. fail_fmt(2, 'Exact error message expected: %s\nError message received: %s\n',
  1143. expectedMsg, error_msg)
  1144. end
  1145. end
  1146. function M.assertErrorMsgContains( partialMsg, func, ... )
  1147. -- assert that calling f with the arguments will raise an error
  1148. -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error
  1149. local no_error, error_msg = pcall( func, ... )
  1150. if no_error then
  1151. failure( 'No error generated when calling function but expected error containing: '..prettystr(partialMsg), 2 )
  1152. end
  1153. if not string.find( error_msg, partialMsg, nil, true ) then
  1154. error_msg, partialMsg = prettystrPairs(error_msg, partialMsg)
  1155. fail_fmt(2, 'Error message does not contain: %s\nError message received: %s\n',
  1156. partialMsg, error_msg)
  1157. end
  1158. end
  1159. function M.assertErrorMsgMatches( expectedMsg, func, ... )
  1160. -- assert that calling f with the arguments will raise an error
  1161. -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error
  1162. local no_error, error_msg = pcall( func, ... )
  1163. if no_error then
  1164. failure( 'No error generated when calling function but expected error matching: "'..expectedMsg..'"', 2 )
  1165. end
  1166. if not strMatch( error_msg, expectedMsg ) then
  1167. expectedMsg, error_msg = prettystrPairs(expectedMsg, error_msg)
  1168. fail_fmt(2, 'Error message does not match: %s\nError message received: %s\n',
  1169. expectedMsg, error_msg)
  1170. end
  1171. end
  1172. --[[
  1173. Add type assertion functions to the module table M. Each of these functions
  1174. takes a single parameter "value", and checks that its Lua type matches the
  1175. expected string (derived from the function name):
  1176. M.assertIsXxx(value) -> ensure that type(value) conforms to "xxx"
  1177. ]]
  1178. for _, funcName in ipairs(
  1179. {'assertIsNumber', 'assertIsString', 'assertIsTable', 'assertIsBoolean',
  1180. 'assertIsFunction', 'assertIsUserdata', 'assertIsThread'}
  1181. ) do
  1182. local typeExpected = funcName:match("^assertIs([A-Z]%a*)$")
  1183. -- Lua type() always returns lowercase, also make sure the match() succeeded
  1184. typeExpected = typeExpected and typeExpected:lower()
  1185. or error("bad function name '"..funcName.."' for type assertion")
  1186. M[funcName] = function(value)
  1187. if type(value) ~= typeExpected then
  1188. fail_fmt(2, 'Expected: a %s value, actual: type %s, value %s',
  1189. typeExpected, type(value), prettystrPairs(value))
  1190. end
  1191. end
  1192. end
  1193. --[[
  1194. Add shortcuts for verifying type of a variable, without failure (luaunit v2 compatibility)
  1195. M.isXxx(value) -> returns true if type(value) conforms to "xxx"
  1196. ]]
  1197. for _, typeExpected in ipairs(
  1198. {'Number', 'String', 'Table', 'Boolean',
  1199. 'Function', 'Userdata', 'Thread', 'Nil' }
  1200. ) do
  1201. local typeExpectedLower = typeExpected:lower()
  1202. local isType = function(value)
  1203. return (type(value) == typeExpectedLower)
  1204. end
  1205. M['is'..typeExpected] = isType
  1206. M['is_'..typeExpectedLower] = isType
  1207. end
  1208. --[[
  1209. Add non-type assertion functions to the module table M. Each of these functions
  1210. takes a single parameter "value", and checks that its Lua type differs from the
  1211. expected string (derived from the function name):
  1212. M.assertNotIsXxx(value) -> ensure that type(value) is not "xxx"
  1213. ]]
  1214. for _, funcName in ipairs(
  1215. {'assertNotIsNumber', 'assertNotIsString', 'assertNotIsTable', 'assertNotIsBoolean',
  1216. 'assertNotIsFunction', 'assertNotIsUserdata', 'assertNotIsThread'}
  1217. ) do
  1218. local typeUnexpected = funcName:match("^assertNotIs([A-Z]%a*)$")
  1219. -- Lua type() always returns lowercase, also make sure the match() succeeded
  1220. typeUnexpected = typeUnexpected and typeUnexpected:lower()
  1221. or error("bad function name '"..funcName.."' for type assertion")
  1222. M[funcName] = function(value)
  1223. if type(value) == typeUnexpected then
  1224. fail_fmt(2, 'Not expected: a %s type, actual: value %s',
  1225. typeUnexpected, prettystrPairs(value))
  1226. end
  1227. end
  1228. end
  1229. function M.assertIs(actual, expected)
  1230. if actual ~= expected then
  1231. if not M.ORDER_ACTUAL_EXPECTED then
  1232. actual, expected = expected, actual
  1233. end
  1234. expected, actual = prettystrPairs(expected, actual, '\n', ', ')
  1235. fail_fmt(2, 'Expected object and actual object are not the same\nExpected: %sactual: %s',
  1236. expected, actual)
  1237. end
  1238. end
  1239. function M.assertNotIs(actual, expected)
  1240. if actual == expected then
  1241. if not M.ORDER_ACTUAL_EXPECTED then
  1242. expected = actual
  1243. end
  1244. fail_fmt(2, 'Expected object and actual object are the same object: %s',
  1245. prettystrPairs(expected))
  1246. end
  1247. end
  1248. function M.assertItemsEquals(actual, expected)
  1249. -- checks that the items of table expected
  1250. -- are contained in table actual. Warning, this function
  1251. -- is at least O(n^2)
  1252. if not _is_table_items_equals(actual, expected ) then
  1253. expected, actual = prettystrPairs(expected, actual)
  1254. fail_fmt(2, 'Contents of the tables are not identical:\nExpected: %s\nActual: %s',
  1255. expected, actual)
  1256. end
  1257. end
  1258. ----------------------------------------------------------------
  1259. -- Compatibility layer
  1260. ----------------------------------------------------------------
  1261. -- for compatibility with LuaUnit v2.x
  1262. function M.wrapFunctions()
  1263. -- In LuaUnit version <= 2.1 , this function was necessary to include
  1264. -- a test function inside the global test suite. Nowadays, the functions
  1265. -- are simply run directly as part of the test discovery process.
  1266. -- so just do nothing !
  1267. io.stderr:write[[Use of WrapFunctions() is no longer needed.
  1268. Just prefix your test function names with "test" or "Test" and they
  1269. will be picked up and run by LuaUnit.
  1270. ]]
  1271. end
  1272. local list_of_funcs = {
  1273. -- { official function name , alias }
  1274. -- general assertions
  1275. { 'assertEquals' , 'assert_equals' },
  1276. { 'assertItemsEquals' , 'assert_items_equals' },
  1277. { 'assertNotEquals' , 'assert_not_equals' },
  1278. { 'assertAlmostEquals' , 'assert_almost_equals' },
  1279. { 'assertNotAlmostEquals' , 'assert_not_almost_equals' },
  1280. { 'assertEvalToTrue' , 'assert_eval_to_true' },
  1281. { 'assertEvalToFalse' , 'assert_eval_to_false' },
  1282. { 'assertStrContains' , 'assert_str_contains' },
  1283. { 'assertStrIContains' , 'assert_str_icontains' },
  1284. { 'assertNotStrContains' , 'assert_not_str_contains' },
  1285. { 'assertNotStrIContains' , 'assert_not_str_icontains' },
  1286. { 'assertStrMatches' , 'assert_str_matches' },
  1287. { 'assertError' , 'assert_error' },
  1288. { 'assertErrorMsgEquals' , 'assert_error_msg_equals' },
  1289. { 'assertErrorMsgContains' , 'assert_error_msg_contains' },
  1290. { 'assertErrorMsgMatches' , 'assert_error_msg_matches' },
  1291. { 'assertIs' , 'assert_is' },
  1292. { 'assertNotIs' , 'assert_not_is' },
  1293. { 'wrapFunctions' , 'WrapFunctions' },
  1294. { 'wrapFunctions' , 'wrap_functions' },
  1295. -- type assertions: assertIsXXX -> assert_is_xxx
  1296. { 'assertIsNumber' , 'assert_is_number' },
  1297. { 'assertIsString' , 'assert_is_string' },
  1298. { 'assertIsTable' , 'assert_is_table' },
  1299. { 'assertIsBoolean' , 'assert_is_boolean' },
  1300. { 'assertIsNil' , 'assert_is_nil' },
  1301. { 'assertIsTrue' , 'assert_is_true' },
  1302. { 'assertIsFalse' , 'assert_is_false' },
  1303. { 'assertIsNaN' , 'assert_is_nan' },
  1304. { 'assertIsInf' , 'assert_is_inf' },
  1305. { 'assertIsFunction' , 'assert_is_function' },
  1306. { 'assertIsThread' , 'assert_is_thread' },
  1307. { 'assertIsUserdata' , 'assert_is_userdata' },
  1308. -- type assertions: assertIsXXX -> assertXxx
  1309. { 'assertIsNumber' , 'assertNumber' },
  1310. { 'assertIsString' , 'assertString' },
  1311. { 'assertIsTable' , 'assertTable' },
  1312. { 'assertIsBoolean' , 'assertBoolean' },
  1313. { 'assertIsNil' , 'assertNil' },
  1314. { 'assertIsTrue' , 'assertTrue' },
  1315. { 'assertIsFalse' , 'assertFalse' },
  1316. { 'assertIsNaN' , 'assertNaN' },
  1317. { 'assertIsInf' , 'assertInf' },
  1318. { 'assertIsFunction' , 'assertFunction' },
  1319. { 'assertIsThread' , 'assertThread' },
  1320. { 'assertIsUserdata' , 'assertUserdata' },
  1321. -- type assertions: assertIsXXX -> assert_xxx (luaunit v2 compat)
  1322. { 'assertIsNumber' , 'assert_number' },
  1323. { 'assertIsString' , 'assert_string' },
  1324. { 'assertIsTable' , 'assert_table' },
  1325. { 'assertIsBoolean' , 'assert_boolean' },
  1326. { 'assertIsNil' , 'assert_nil' },
  1327. { 'assertIsTrue' , 'assert_true' },
  1328. { 'assertIsFalse' , 'assert_false' },
  1329. { 'assertIsNaN' , 'assert_nan' },
  1330. { 'assertIsInf' , 'assert_inf' },
  1331. { 'assertIsFunction' , 'assert_function' },
  1332. { 'assertIsThread' , 'assert_thread' },
  1333. { 'assertIsUserdata' , 'assert_userdata' },
  1334. -- type assertions: assertNotIsXXX -> assert_not_is_xxx
  1335. { 'assertNotIsNumber' , 'assert_not_is_number' },
  1336. { 'assertNotIsString' , 'assert_not_is_string' },
  1337. { 'assertNotIsTable' , 'assert_not_is_table' },
  1338. { 'assertNotIsBoolean' , 'assert_not_is_boolean' },
  1339. { 'assertNotIsNil' , 'assert_not_is_nil' },
  1340. { 'assertNotIsTrue' , 'assert_not_is_true' },
  1341. { 'assertNotIsFalse' , 'assert_not_is_false' },
  1342. { 'assertNotIsNaN' , 'assert_not_is_nan' },
  1343. { 'assertNotIsInf' , 'assert_not_is_inf' },
  1344. { 'assertNotIsFunction' , 'assert_not_is_function' },
  1345. { 'assertNotIsThread' , 'assert_not_is_thread' },
  1346. { 'assertNotIsUserdata' , 'assert_not_is_userdata' },
  1347. -- type assertions: assertNotIsXXX -> assertNotXxx (luaunit v2 compat)
  1348. { 'assertNotIsNumber' , 'assertNotNumber' },
  1349. { 'assertNotIsString' , 'assertNotString' },
  1350. { 'assertNotIsTable' , 'assertNotTable' },
  1351. { 'assertNotIsBoolean' , 'assertNotBoolean' },
  1352. { 'assertNotIsNil' , 'assertNotNil' },
  1353. { 'assertNotIsTrue' , 'assertNotTrue' },
  1354. { 'assertNotIsFalse' , 'assertNotFalse' },
  1355. { 'assertNotIsNaN' , 'assertNotNaN' },
  1356. { 'assertNotIsInf' , 'assertNotInf' },
  1357. { 'assertNotIsFunction' , 'assertNotFunction' },
  1358. { 'assertNotIsThread' , 'assertNotThread' },
  1359. { 'assertNotIsUserdata' , 'assertNotUserdata' },
  1360. -- type assertions: assertNotIsXXX -> assert_not_xxx
  1361. { 'assertNotIsNumber' , 'assert_not_number' },
  1362. { 'assertNotIsString' , 'assert_not_string' },
  1363. { 'assertNotIsTable' , 'assert_not_table' },
  1364. { 'assertNotIsBoolean' , 'assert_not_boolean' },
  1365. { 'assertNotIsNil' , 'assert_not_nil' },
  1366. { 'assertNotIsTrue' , 'assert_not_true' },
  1367. { 'assertNotIsFalse' , 'assert_not_false' },
  1368. { 'assertNotIsNaN' , 'assert_not_nan' },
  1369. { 'assertNotIsInf' , 'assert_not_inf' },
  1370. { 'assertNotIsFunction' , 'assert_not_function' },
  1371. { 'assertNotIsThread' , 'assert_not_thread' },
  1372. { 'assertNotIsUserdata' , 'assert_not_userdata' },
  1373. -- all assertions with Coroutine duplicate Thread assertions
  1374. { 'assertIsThread' , 'assertIsCoroutine' },
  1375. { 'assertIsThread' , 'assertCoroutine' },
  1376. { 'assertIsThread' , 'assert_is_coroutine' },
  1377. { 'assertIsThread' , 'assert_coroutine' },
  1378. { 'assertNotIsThread' , 'assertNotIsCoroutine' },
  1379. { 'assertNotIsThread' , 'assertNotCoroutine' },
  1380. { 'assertNotIsThread' , 'assert_not_is_coroutine' },
  1381. { 'assertNotIsThread' , 'assert_not_coroutine' },
  1382. }
  1383. -- Create all aliases in M
  1384. for _,v in ipairs( list_of_funcs ) do
  1385. local funcname, alias = v[1], v[2]
  1386. M[alias] = M[funcname]
  1387. if EXPORT_ASSERT_TO_GLOBALS then
  1388. _G[funcname] = M[funcname]
  1389. _G[alias] = M[funcname]
  1390. end
  1391. end
  1392. ----------------------------------------------------------------
  1393. --
  1394. -- Outputters
  1395. --
  1396. ----------------------------------------------------------------
  1397. -- A common "base" class for outputters
  1398. -- For concepts involved (class inheritance) see http://www.lua.org/pil/16.2.html
  1399. local genericOutput = { __class__ = 'genericOutput' } -- class
  1400. local genericOutput_MT = { __index = genericOutput } -- metatable
  1401. M.genericOutput = genericOutput -- publish, so that custom classes may derive from it
  1402. function genericOutput.new(runner, default_verbosity)
  1403. -- runner is the "parent" object controlling the output, usually a LuaUnit instance
  1404. local t = { runner = runner }
  1405. if runner then
  1406. t.result = runner.result
  1407. t.verbosity = runner.verbosity or default_verbosity
  1408. t.fname = runner.fname
  1409. else
  1410. t.verbosity = default_verbosity
  1411. end
  1412. return setmetatable( t, genericOutput_MT)
  1413. end
  1414. -- abstract ("empty") methods
  1415. function genericOutput:startSuite() end
  1416. function genericOutput:startClass(className) end
  1417. function genericOutput:startTest(testName) end
  1418. function genericOutput:addStatus(node) end
  1419. function genericOutput:endTest(node) end
  1420. function genericOutput:endClass() end
  1421. function genericOutput:endSuite() end
  1422. ----------------------------------------------------------------
  1423. -- class TapOutput
  1424. ----------------------------------------------------------------
  1425. local TapOutput = genericOutput.new() -- derived class
  1426. local TapOutput_MT = { __index = TapOutput } -- metatable
  1427. TapOutput.__class__ = 'TapOutput'
  1428. -- For a good reference for TAP format, check: http://testanything.org/tap-specification.html
  1429. function TapOutput.new(runner)
  1430. local t = genericOutput.new(runner, M.VERBOSITY_LOW)
  1431. return setmetatable( t, TapOutput_MT)
  1432. end
  1433. function TapOutput:startSuite()
  1434. print("1.."..self.result.testCount)
  1435. print('# Started on '..self.result.startDate)
  1436. end
  1437. function TapOutput:startClass(className)
  1438. if className ~= '[TestFunctions]' then
  1439. print('# Starting class: '..className)
  1440. end
  1441. end
  1442. function TapOutput:addStatus( node )
  1443. io.stdout:write("not ok ", self.result.currentTestNumber, "\t", node.testName, "\n")
  1444. if self.verbosity > M.VERBOSITY_LOW then
  1445. print( prefixString( ' ', node.msg ) )
  1446. end
  1447. if self.verbosity > M.VERBOSITY_DEFAULT then
  1448. print( prefixString( ' ', node.stackTrace ) )
  1449. end
  1450. end
  1451. function TapOutput:endTest( node )
  1452. if node:isPassed() then
  1453. io.stdout:write("ok ", self.result.currentTestNumber, "\t", node.testName, "\n")
  1454. end
  1455. end
  1456. function TapOutput:endSuite()
  1457. print( '# '..M.LuaUnit.statusLine( self.result ) )
  1458. return self.result.notPassedCount
  1459. end
  1460. -- class TapOutput end
  1461. ----------------------------------------------------------------
  1462. -- class JUnitOutput
  1463. ----------------------------------------------------------------
  1464. -- See directory junitxml for more information about the junit format
  1465. local JUnitOutput = genericOutput.new() -- derived class
  1466. local JUnitOutput_MT = { __index = JUnitOutput } -- metatable
  1467. JUnitOutput.__class__ = 'JUnitOutput'
  1468. function JUnitOutput.new(runner)
  1469. local t = genericOutput.new(runner, M.VERBOSITY_LOW)
  1470. t.testList = {}
  1471. return setmetatable( t, JUnitOutput_MT )
  1472. end
  1473. function JUnitOutput:startSuite()
  1474. -- open xml file early to deal with errors
  1475. if self.fname == nil then
  1476. error('With Junit, an output filename must be supplied with --name!')
  1477. end
  1478. if string.sub(self.fname,-4) ~= '.xml' then
  1479. self.fname = self.fname..'.xml'
  1480. end
  1481. self.fd = io.open(self.fname, "w")
  1482. if self.fd == nil then
  1483. error("Could not open file for writing: "..self.fname)
  1484. end
  1485. print('# XML output to '..self.fname)
  1486. print('# Started on '..self.result.startDate)
  1487. end
  1488. function JUnitOutput:startClass(className)
  1489. if className ~= '[TestFunctions]' then
  1490. print('# Starting class: '..className)
  1491. end
  1492. end
  1493. function JUnitOutput:startTest(testName)
  1494. print('# Starting test: '..testName)
  1495. end
  1496. function JUnitOutput:addStatus( node )
  1497. if node:isFailure() then
  1498. print('# Failure: ' .. node.msg)
  1499. -- print('# ' .. node.stackTrace)
  1500. elseif node:isError() then
  1501. print('# Error: ' .. node.msg)
  1502. -- print('# ' .. node.stackTrace)
  1503. end
  1504. end
  1505. function JUnitOutput:endSuite()
  1506. print( '# '..M.LuaUnit.statusLine(self.result))
  1507. -- XML file writing
  1508. self.fd:write('<?xml version="1.0" encoding="UTF-8" ?>\n')
  1509. self.fd:write('<testsuites>\n')
  1510. self.fd:write(string.format(
  1511. ' <testsuite name="LuaUnit" id="00001" package="" hostname="localhost" tests="%d" timestamp="%s" time="%0.3f" errors="%d" failures="%d">\n',
  1512. self.result.runCount, self.result.startIsodate, self.result.duration, self.result.errorCount, self.result.failureCount ))
  1513. self.fd:write(" <properties>\n")
  1514. self.fd:write(string.format(' <property name="Lua Version" value="%s"/>\n', _VERSION ) )
  1515. self.fd:write(string.format(' <property name="LuaUnit Version" value="%s"/>\n', M.VERSION) )
  1516. -- XXX please include system name and version if possible
  1517. self.fd:write(" </properties>\n")
  1518. for i,node in ipairs(self.result.tests) do
  1519. self.fd:write(string.format(' <testcase classname="%s" name="%s" time="%0.3f">\n',
  1520. node.className, node.testName, node.duration ) )
  1521. if node:isNotPassed() then
  1522. self.fd:write(node:statusXML())
  1523. end
  1524. self.fd:write(' </testcase>\n')
  1525. end
  1526. -- Next two lines are needed to validate junit ANT xsd, but really not useful in general:
  1527. self.fd:write(' <system-out/>\n')
  1528. self.fd:write(' <system-err/>\n')
  1529. self.fd:write(' </testsuite>\n')
  1530. self.fd:write('</testsuites>\n')
  1531. self.fd:close()
  1532. return self.result.notPassedCount
  1533. end
-- class JUnitOutput end
  1535. ----------------------------------------------------------------
  1536. -- class TextOutput
  1537. ----------------------------------------------------------------
  1538. --[[
  1539. -- Python Non verbose:
  1540. For each test: . or F or E
  1541. If some failed tests:
  1542. ==============
  1543. ERROR / FAILURE: TestName (testfile.testclass)
  1544. ---------
  1545. Stack trace
  1546. then --------------
  1547. then "Ran x tests in 0.000s"
  1548. then OK or FAILED (failures=1, error=1)
  1549. -- Python Verbose:
  1550. testname (filename.classname) ... ok
  1551. testname (filename.classname) ... FAIL
  1552. testname (filename.classname) ... ERROR
  1553. then --------------
  1554. then "Ran x tests in 0.000s"
  1555. then OK or FAILED (failures=1, error=1)
  1556. -- Ruby:
  1557. Started
  1558. .
  1559. Finished in 0.002695 seconds.
  1560. 1 tests, 2 assertions, 0 failures, 0 errors
  1561. -- Ruby:
  1562. >> ruby tc_simple_number2.rb
  1563. Loaded suite tc_simple_number2
  1564. Started
  1565. F..
  1566. Finished in 0.038617 seconds.
  1567. 1) Failure:
  1568. test_failure(TestSimpleNumber) [tc_simple_number2.rb:16]:
  1569. Adding doesn't work.
  1570. <3> expected but was
  1571. <4>.
  1572. 3 tests, 4 assertions, 1 failures, 0 errors
  1573. -- Java Junit
  1574. .......F.
  1575. Time: 0,003
  1576. There was 1 failure:
  1577. 1) testCapacity(junit.samples.VectorTest)junit.framework.AssertionFailedError
  1578. at junit.samples.VectorTest.testCapacity(VectorTest.java:87)
  1579. at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
  1580. at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
  1581. at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
  1582. FAILURES!!!
  1583. Tests run: 8, Failures: 1, Errors: 0
  1584. -- Maven
  1585. # mvn test
  1586. -------------------------------------------------------
  1587. T E S T S
  1588. -------------------------------------------------------
  1589. Running math.AdditionTest
  1590. Tests run: 2, Failures: 1, Errors: 0, Skipped: 0, Time elapsed:
  1591. 0.03 sec <<< FAILURE!
  1592. Results :
  1593. Failed tests:
  1594. testLireSymbole(math.AdditionTest)
  1595. Tests run: 2, Failures: 1, Errors: 0, Skipped: 0
  1596. -- LuaUnit
  1597. ---- non verbose
  1598. * display . or F or E when running tests
  1599. ---- verbose
  1600. * display test name + ok/fail
  1601. ----
  1602. * blank line
  1603. * number) ERROR or FAILURE: TestName
  1604. Stack trace
  1605. * blank line
  1606. * number) ERROR or FAILURE: TestName
  1607. Stack trace
  1608. then --------------
  1609. then "Ran x tests in 0.000s (%d not selected, %d skipped)"
  1610. then OK or FAILED (failures=1, error=1)
  1611. ]]
  1612. local TextOutput = genericOutput.new() -- derived class
  1613. local TextOutput_MT = { __index = TextOutput } -- metatable
  1614. TextOutput.__class__ = 'TextOutput'
  1615. function TextOutput.new(runner)
  1616. local t = genericOutput.new(runner, M.VERBOSITY_DEFAULT)
  1617. t.errorList = {}
  1618. return setmetatable( t, TextOutput_MT )
  1619. end
  1620. function TextOutput:startSuite()
  1621. if self.verbosity > M.VERBOSITY_DEFAULT then
  1622. print( 'Started on '.. self.result.startDate )
  1623. end
  1624. end
  1625. function TextOutput:startTest(testName)
  1626. if self.verbosity > M.VERBOSITY_DEFAULT then
  1627. io.stdout:write( " ", self.result.currentNode.testName, " ... " )
  1628. end
  1629. end
  1630. function TextOutput:endTest( node )
  1631. if node:isPassed() then
  1632. if self.verbosity > M.VERBOSITY_DEFAULT then
  1633. io.stdout:write("Ok\n")
  1634. else
  1635. io.stdout:write(".")
  1636. end
  1637. else
  1638. if self.verbosity > M.VERBOSITY_DEFAULT then
  1639. print( node.status )
  1640. print( node.msg )
  1641. --[[
  1642. -- find out when to do this:
  1643. if self.verbosity > M.VERBOSITY_DEFAULT then
  1644. print( node.stackTrace )
  1645. end
  1646. ]]
  1647. else
  1648. -- write only the first character of status
  1649. io.stdout:write(string.sub(node.status, 1, 1))
  1650. end
  1651. end
  1652. end
  1653. function TextOutput:displayOneFailedTest( index, fail )
  1654. print(index..") "..fail.testName )
  1655. print( fail.msg )
  1656. print( fail.stackTrace )
  1657. print()
  1658. end
  1659. function TextOutput:displayFailedTests()
  1660. if self.result.notPassedCount ~= 0 then
  1661. print("Failed tests:")
  1662. print("-------------")
  1663. for i, v in ipairs(self.result.notPassed) do
  1664. self:displayOneFailedTest(i, v)
  1665. end
  1666. end
  1667. end
  1668. function TextOutput:endSuite()
  1669. if self.verbosity > M.VERBOSITY_DEFAULT then
  1670. print("=========================================================")
  1671. else
  1672. print()
  1673. end
  1674. self:displayFailedTests()
  1675. print( M.LuaUnit.statusLine( self.result ) )
  1676. if self.result.notPassedCount == 0 then
  1677. print('OK')
  1678. end
  1679. end
  1680. -- class TextOutput end
  1681. ----------------------------------------------------------------
  1682. -- class NilOutput
  1683. ----------------------------------------------------------------
  1684. local function nopCallable()
  1685. --print(42)
  1686. return nopCallable
  1687. end
  1688. local NilOutput = { __class__ = 'NilOuptut' } -- class
  1689. local NilOutput_MT = { __index = nopCallable } -- metatable
  1690. function NilOutput.new(runner)
  1691. return setmetatable( { __class__ = 'NilOutput' }, NilOutput_MT )
  1692. end
  1693. ----------------------------------------------------------------
  1694. --
  1695. -- class LuaUnit
  1696. --
  1697. ----------------------------------------------------------------
  1698. M.LuaUnit = {
  1699. outputType = TextOutput,
  1700. verbosity = M.VERBOSITY_DEFAULT,
  1701. __class__ = 'LuaUnit'
  1702. }
  1703. local LuaUnit_MT = { __index = M.LuaUnit }
  1704. if EXPORT_ASSERT_TO_GLOBALS then
  1705. LuaUnit = M.LuaUnit
  1706. end
  1707. function M.LuaUnit.new()
  1708. return setmetatable( {}, LuaUnit_MT )
  1709. end
  1710. -----------------[[ Utility methods ]]---------------------
  1711. function M.LuaUnit.asFunction(aObject)
  1712. -- return "aObject" if it is a function, and nil otherwise
  1713. if 'function' == type(aObject) then
  1714. return aObject
  1715. end
  1716. end
  1717. function M.LuaUnit.splitClassMethod(someName)
  1718. --[[
  1719. Return a pair of className, methodName strings for a name in the form
  1720. "class.method". If no class part (or separator) is found, will return
  1721. nil, someName instead (the latter being unchanged).
  1722. This convention thus also replaces the older isClassMethod() test:
  1723. You just have to check for a non-nil className (return) value.
  1724. ]]
  1725. local separator = string.find(someName, '.', 1, true)
  1726. if separator then
  1727. return someName:sub(1, separator - 1), someName:sub(separator + 1)
  1728. end
  1729. return nil, someName
  1730. end
  1731. function M.LuaUnit.isMethodTestName( s )
  1732. -- return true is the name matches the name of a test method
  1733. -- default rule is that is starts with 'Test' or with 'test'
  1734. return string.sub(s, 1, 4):lower() == 'test'
  1735. end
  1736. function M.LuaUnit.isTestName( s )
  1737. -- return true is the name matches the name of a test
  1738. -- default rule is that is starts with 'Test' or with 'test'
  1739. return string.sub(s, 1, 4):lower() == 'test'
  1740. end
  1741. function M.LuaUnit.collectTests()
  1742. -- return a list of all test names in the global namespace
  1743. -- that match LuaUnit.isTestName
  1744. local testNames = {}
  1745. for k, _ in pairs(_G) do
  1746. if type(k) == "string" and M.LuaUnit.isTestName( k ) then
  1747. table.insert( testNames , k )
  1748. end
  1749. end
  1750. table.sort( testNames )
  1751. return testNames
  1752. end
  1753. function M.LuaUnit.parseCmdLine( cmdLine )
  1754. -- parse the command line
  1755. -- Supported command line parameters:
  1756. -- --verbose, -v: increase verbosity
  1757. -- --quiet, -q: silence output
  1758. -- --error, -e: treat errors as fatal (quit program)
  1759. -- --output, -o, + name: select output type
  1760. -- --pattern, -p, + pattern: run test matching pattern, may be repeated
  1761. -- --exclude, -x, + pattern: run test not matching pattern, may be repeated
  1762. -- --random, -r, : run tests in random order
  1763. -- --name, -n, + fname: name of output file for junit, default to stdout
  1764. -- --count, -c, + num: number of times to execute each test
  1765. -- [testnames, ...]: run selected test names
  1766. --
  1767. -- Returns a table with the following fields:
  1768. -- verbosity: nil, M.VERBOSITY_DEFAULT, M.VERBOSITY_QUIET, M.VERBOSITY_VERBOSE
  1769. -- output: nil, 'tap', 'junit', 'text', 'nil'
  1770. -- testNames: nil or a list of test names to run
  1771. -- exeCount: num or 1
  1772. -- pattern: nil or a list of patterns
  1773. -- exclude: nil or a list of patterns
  1774. local result, state = {}, nil
  1775. local SET_OUTPUT = 1
  1776. local SET_PATTERN = 2
  1777. local SET_EXCLUDE = 3
  1778. local SET_FNAME = 4
  1779. local SET_XCOUNT = 5
  1780. if cmdLine == nil then
  1781. return result
  1782. end
  1783. local function parseOption( option )
  1784. if option == '--help' or option == '-h' then
  1785. result['help'] = true
  1786. return
  1787. elseif option == '--version' then
  1788. result['version'] = true
  1789. return
  1790. elseif option == '--verbose' or option == '-v' then
  1791. result['verbosity'] = M.VERBOSITY_VERBOSE
  1792. return
  1793. elseif option == '--quiet' or option == '-q' then
  1794. result['verbosity'] = M.VERBOSITY_QUIET
  1795. return
  1796. elseif option == '--error' or option == '-e' then
  1797. result['quitOnError'] = true
  1798. return
  1799. elseif option == '--failure' or option == '-f' then
  1800. result['quitOnFailure'] = true
  1801. return
  1802. elseif option == '--random' or option == '-r' then
  1803. result['randomize'] = true
  1804. return
  1805. elseif option == '--output' or option == '-o' then
  1806. state = SET_OUTPUT
  1807. return state
  1808. elseif option == '--name' or option == '-n' then
  1809. state = SET_FNAME
  1810. return state
  1811. elseif option == '--count' or option == '-c' then
  1812. state = SET_XCOUNT
  1813. return state
  1814. elseif option == '--pattern' or option == '-p' then
  1815. state = SET_PATTERN
  1816. return state
  1817. elseif option == '--exclude' or option == '-x' then
  1818. state = SET_EXCLUDE
  1819. return state
  1820. end
  1821. error('Unknown option: '..option,3)
  1822. end
  1823. local function setArg( cmdArg, state )
  1824. if state == SET_OUTPUT then
  1825. result['output'] = cmdArg
  1826. return
  1827. elseif state == SET_FNAME then
  1828. result['fname'] = cmdArg
  1829. return
  1830. elseif state == SET_XCOUNT then
  1831. result['exeCount'] = tonumber(cmdArg)
  1832. or error('Malformed -c argument: '..cmdArg)
  1833. return
  1834. elseif state == SET_PATTERN then
  1835. if result['pattern'] then
  1836. table.insert( result['pattern'], cmdArg )
  1837. else
  1838. result['pattern'] = { cmdArg }
  1839. end
  1840. return
  1841. elseif state == SET_EXCLUDE then
  1842. local notArg = '!'..cmdArg
  1843. if result['pattern'] then
  1844. table.insert( result['pattern'], notArg )
  1845. else
  1846. result['pattern'] = { notArg }
  1847. end
  1848. return
  1849. end
  1850. error('Unknown parse state: '.. state)
  1851. end
  1852. for i, cmdArg in ipairs(cmdLine) do
  1853. if state ~= nil then
  1854. setArg( cmdArg, state, result )
  1855. state = nil
  1856. else
  1857. if cmdArg:sub(1,1) == '-' then
  1858. state = parseOption( cmdArg )
  1859. else
  1860. if result['testNames'] then
  1861. table.insert( result['testNames'], cmdArg )
  1862. else
  1863. result['testNames'] = { cmdArg }
  1864. end
  1865. end
  1866. end
  1867. end
  1868. if result['help'] then
  1869. M.LuaUnit.help()
  1870. end
  1871. if result['version'] then
  1872. M.LuaUnit.version()
  1873. end
  1874. if state ~= nil then
  1875. error('Missing argument after '..cmdLine[ #cmdLine ],2 )
  1876. end
  1877. return result
  1878. end
  1879. function M.LuaUnit.help()
  1880. print(M.USAGE)
  1881. os.exit(0)
  1882. end
  1883. function M.LuaUnit.version()
  1884. print('LuaUnit v'..M.VERSION..' by Philippe Fremy <phil@freehackers.org>')
  1885. os.exit(0)
  1886. end
  1887. ----------------------------------------------------------------
  1888. -- class NodeStatus
  1889. ----------------------------------------------------------------
  1890. local NodeStatus = { __class__ = 'NodeStatus' } -- class
  1891. local NodeStatus_MT = { __index = NodeStatus } -- metatable
  1892. M.NodeStatus = NodeStatus
  1893. -- values of status
  1894. NodeStatus.PASS = 'PASS'
  1895. NodeStatus.FAIL = 'FAIL'
  1896. NodeStatus.ERROR = 'ERROR'
  1897. function NodeStatus.new( number, testName, className )
  1898. local t = { number = number, testName = testName, className = className }
  1899. setmetatable( t, NodeStatus_MT )
  1900. t:pass()
  1901. return t
  1902. end
  1903. function NodeStatus:pass()
  1904. self.status = self.PASS
  1905. -- useless but we know it's the field we want to use
  1906. self.msg = nil
  1907. self.stackTrace = nil
  1908. end
  1909. function NodeStatus:fail(msg, stackTrace)
  1910. self.status = self.FAIL
  1911. self.msg = msg
  1912. self.stackTrace = stackTrace
  1913. end
  1914. function NodeStatus:error(msg, stackTrace)
  1915. self.status = self.ERROR
  1916. self.msg = msg
  1917. self.stackTrace = stackTrace
  1918. end
  1919. function NodeStatus:isPassed()
  1920. return self.status == NodeStatus.PASS
  1921. end
  1922. function NodeStatus:isNotPassed()
  1923. -- print('hasFailure: '..prettystr(self))
  1924. return self.status ~= NodeStatus.PASS
  1925. end
  1926. function NodeStatus:isFailure()
  1927. return self.status == NodeStatus.FAIL
  1928. end
  1929. function NodeStatus:isError()
  1930. return self.status == NodeStatus.ERROR
  1931. end
  1932. function NodeStatus:statusXML()
  1933. if self:isError() then
  1934. return table.concat(
  1935. {' <error type="', xmlEscape(self.msg), '">\n',
  1936. ' <![CDATA[', xmlCDataEscape(self.stackTrace),
  1937. ']]></error>\n'})
  1938. elseif self:isFailure() then
  1939. return table.concat(
  1940. {' <failure type="', xmlEscape(self.msg), '">\n',
  1941. ' <![CDATA[', xmlCDataEscape(self.stackTrace),
  1942. ']]></failure>\n'})
  1943. end
  1944. return ' <passed/>\n' -- (not XSD-compliant! normally shouldn't get here)
  1945. end
  1946. --------------[[ Output methods ]]-------------------------
  1947. local function conditional_plural(number, singular)
  1948. -- returns a grammatically well-formed string "%d <singular/plural>"
  1949. local suffix = ''
  1950. if number ~= 1 then -- use plural
  1951. suffix = (singular:sub(-2) == 'ss') and 'es' or 's'
  1952. end
  1953. return string.format('%d %s%s', number, singular, suffix)
  1954. end
  1955. function M.LuaUnit.statusLine(result)
  1956. -- return status line string according to results
  1957. local s = {
  1958. string.format('Ran %d tests in %0.3f seconds',
  1959. result.runCount, result.duration),
  1960. conditional_plural(result.passedCount, 'success'),
  1961. }
  1962. if result.notPassedCount > 0 then
  1963. if result.failureCount > 0 then
  1964. table.insert(s, conditional_plural(result.failureCount, 'failure'))
  1965. end
  1966. if result.errorCount > 0 then
  1967. table.insert(s, conditional_plural(result.errorCount, 'error'))
  1968. end
  1969. else
  1970. table.insert(s, '0 failures')
  1971. end
  1972. if result.nonSelectedCount > 0 then
  1973. table.insert(s, string.format("%d non-selected", result.nonSelectedCount))
  1974. end
  1975. return table.concat(s, ', ')
  1976. end
  1977. function M.LuaUnit:startSuite(testCount, nonSelectedCount)
  1978. self.result = {
  1979. testCount = testCount,
  1980. nonSelectedCount = nonSelectedCount,
  1981. passedCount = 0,
  1982. runCount = 0,
  1983. currentTestNumber = 0,
  1984. currentClassName = "",
  1985. currentNode = nil,
  1986. suiteStarted = true,
  1987. startTime = os.clock(),
  1988. startDate = os.date(os.getenv('LUAUNIT_DATEFMT')),
  1989. startIsodate = os.date('%Y-%m-%dT%H:%M:%S'),
  1990. patternIncludeFilter = self.patternIncludeFilter,
  1991. tests = {},
  1992. failures = {},
  1993. errors = {},
  1994. notPassed = {},
  1995. }
  1996. self.outputType = self.outputType or TextOutput
  1997. self.output = self.outputType.new(self)
  1998. self.output:startSuite()
  1999. end
  2000. function M.LuaUnit:startClass( className )
  2001. self.result.currentClassName = className
  2002. self.output:startClass( className )
  2003. end
  2004. function M.LuaUnit:startTest( testName )
  2005. self.result.currentTestNumber = self.result.currentTestNumber + 1
  2006. self.result.runCount = self.result.runCount + 1
  2007. self.result.currentNode = NodeStatus.new(
  2008. self.result.currentTestNumber,
  2009. testName,
  2010. self.result.currentClassName
  2011. )
  2012. self.result.currentNode.startTime = os.clock()
  2013. table.insert( self.result.tests, self.result.currentNode )
  2014. self.output:startTest( testName )
  2015. end
  2016. function M.LuaUnit:addStatus( err )
  2017. -- "err" is expected to be a table / result from protectedCall()
  2018. if err.status == NodeStatus.PASS then
  2019. return
  2020. end
  2021. local node = self.result.currentNode
  2022. --[[ As a first approach, we will report only one error or one failure for one test.
  2023. However, we can have the case where the test is in failure, and the teardown is in error.
  2024. In such case, it's a good idea to report both a failure and an error in the test suite. This is
  2025. what Python unittest does for example. However, it mixes up counts so need to be handled carefully: for
  2026. example, there could be more (failures + errors) count that tests. What happens to the current node ?
  2027. We will do this more intelligent version later.
  2028. ]]
  2029. -- if the node is already in failure/error, just don't report the new error (see above)
  2030. if node.status ~= NodeStatus.PASS then
  2031. return
  2032. end
  2033. if err.status == NodeStatus.FAIL then
  2034. node:fail( err.msg, err.trace )
  2035. table.insert( self.result.failures, node )
  2036. elseif err.status == NodeStatus.ERROR then
  2037. node:error( err.msg, err.trace )
  2038. table.insert( self.result.errors, node )
  2039. end
  2040. if node:isFailure() or node:isError() then
  2041. -- add to the list of failed tests (gets printed separately)
  2042. table.insert( self.result.notPassed, node )
  2043. end
  2044. self.output:addStatus( node )
  2045. end
  2046. function M.LuaUnit:endTest()
  2047. local node = self.result.currentNode
  2048. -- print( 'endTest() '..prettystr(node))
  2049. -- print( 'endTest() '..prettystr(node:isNotPassed()))
  2050. node.duration = os.clock() - node.startTime
  2051. node.startTime = nil
  2052. self.output:endTest( node )
  2053. if node:isPassed() then
  2054. self.result.passedCount = self.result.passedCount + 1
  2055. elseif node:isError() then
  2056. if self.quitOnError or self.quitOnFailure then
  2057. -- Runtime error - abort test execution as requested by
  2058. -- "--error" option. This is done by setting a special
  2059. -- flag that gets handled in runSuiteByInstances().
  2060. print("\nERROR during LuaUnit test execution:\n" .. node.msg)
  2061. self.result.aborted = true
  2062. end
  2063. elseif node:isFailure() then
  2064. if self.quitOnFailure then
  2065. -- Failure - abort test execution as requested by
  2066. -- "--failure" option. This is done by setting a special
  2067. -- flag that gets handled in runSuiteByInstances().
  2068. print("\nFailure during LuaUnit test execution:\n" .. node.msg)
  2069. self.result.aborted = true
  2070. end
  2071. end
  2072. self.result.currentNode = nil
  2073. end
  2074. function M.LuaUnit:endClass()
  2075. self.output:endClass()
  2076. end
  2077. function M.LuaUnit:endSuite()
  2078. if self.result.suiteStarted == false then
  2079. error('LuaUnit:endSuite() -- suite was already ended' )
  2080. end
  2081. self.result.duration = os.clock()-self.result.startTime
  2082. self.result.suiteStarted = false
  2083. -- Expose test counts for outputter's endSuite(). This could be managed
  2084. -- internally instead, but unit tests (and existing use cases) might
  2085. -- rely on these fields being present.
  2086. self.result.notPassedCount = #self.result.notPassed
  2087. self.result.failureCount = #self.result.failures
  2088. self.result.errorCount = #self.result.errors
  2089. self.output:endSuite()
  2090. end
  2091. function M.LuaUnit:setOutputType(outputType)
  2092. -- default to text
  2093. -- tap produces results according to TAP format
  2094. if outputType:upper() == "NIL" then
  2095. self.outputType = NilOutput
  2096. return
  2097. end
  2098. if outputType:upper() == "TAP" then
  2099. self.outputType = TapOutput
  2100. return
  2101. end
  2102. if outputType:upper() == "JUNIT" then
  2103. self.outputType = JUnitOutput
  2104. return
  2105. end
  2106. if outputType:upper() == "TEXT" then
  2107. self.outputType = TextOutput
  2108. return
  2109. end
  2110. error( 'No such format: '..outputType,2)
  2111. end
  2112. --------------[[ Runner ]]-----------------
  2113. function M.LuaUnit:protectedCall(classInstance, methodInstance, prettyFuncName)
  2114. -- if classInstance is nil, this is just a function call
  2115. -- else, it's method of a class being called.
  2116. local function err_handler(e)
  2117. -- transform error into a table, adding the traceback information
  2118. return {
  2119. status = NodeStatus.ERROR,
  2120. msg = e,
  2121. trace = string.sub(debug.traceback("", 3), 2)
  2122. }
  2123. end
  2124. local ok, err
  2125. if classInstance then
  2126. -- stupid Lua < 5.2 does not allow xpcall with arguments so let's use a workaround
  2127. ok, err = xpcall( function () methodInstance(classInstance) end, err_handler )
  2128. else
  2129. ok, err = xpcall( function () methodInstance() end, err_handler )
  2130. end
  2131. if ok then
  2132. return {status = NodeStatus.PASS}
  2133. end
  2134. -- determine if the error was a failed test:
  2135. -- We do this by stripping the failure prefix from the error message,
  2136. -- while keeping track of the gsub() count. A non-zero value -> failure
  2137. local failed, iter_msg
  2138. iter_msg = self.exeCount and 'iteration: '..self.currentCount..', '
  2139. err.msg, failed = err.msg:gsub(M.FAILURE_PREFIX, iter_msg or '', 1)
  2140. if failed > 0 then
  2141. err.status = NodeStatus.FAIL
  2142. end
  2143. -- reformat / improve the stack trace
  2144. if prettyFuncName then -- we do have the real method name
  2145. err.trace = err.trace:gsub("in (%a+) 'methodInstance'", "in %1 '"..prettyFuncName.."'")
  2146. end
  2147. if STRIP_LUAUNIT_FROM_STACKTRACE then
  2148. err.trace = stripLuaunitTrace(err.trace)
  2149. end
  2150. return err -- return the error "object" (table)
  2151. end
  2152. function M.LuaUnit:execOneFunction(className, methodName, classInstance, methodInstance)
  2153. -- When executing a test function, className and classInstance must be nil
  2154. -- When executing a class method, all parameters must be set
  2155. if type(methodInstance) ~= 'function' then
  2156. error( tostring(methodName)..' must be a function, not '..type(methodInstance))
  2157. end
  2158. local prettyFuncName
  2159. if className == nil then
  2160. className = '[TestFunctions]'
  2161. prettyFuncName = methodName
  2162. else
  2163. prettyFuncName = className..'.'..methodName
  2164. end
  2165. if self.lastClassName ~= className then
  2166. if self.lastClassName ~= nil then
  2167. self:endClass()
  2168. end
  2169. self:startClass( className )
  2170. self.lastClassName = className
  2171. end
  2172. self:startTest(prettyFuncName)
  2173. local node = self.result.currentNode
  2174. for iter_n = 1, self.exeCount or 1 do
  2175. if node:isNotPassed() then
  2176. break
  2177. end
  2178. self.currentCount = iter_n
  2179. -- run setUp first (if any)
  2180. if classInstance then
  2181. local func = self.asFunction( classInstance.setUp ) or
  2182. self.asFunction( classInstance.Setup ) or
  2183. self.asFunction( classInstance.setup ) or
  2184. self.asFunction( classInstance.SetUp )
  2185. if func then
  2186. self:addStatus(self:protectedCall(classInstance, func, className..'.setUp'))
  2187. end
  2188. end
  2189. -- run testMethod()
  2190. if node:isPassed() then
  2191. self:addStatus(self:protectedCall(classInstance, methodInstance, prettyFuncName))
  2192. end
  2193. -- lastly, run tearDown (if any)
  2194. if classInstance then
  2195. local func = self.asFunction( classInstance.tearDown ) or
  2196. self.asFunction( classInstance.TearDown ) or
  2197. self.asFunction( classInstance.teardown ) or
  2198. self.asFunction( classInstance.Teardown )
  2199. if func then
  2200. self:addStatus(self:protectedCall(classInstance, func, className..'.tearDown'))
  2201. end
  2202. end
  2203. end
  2204. self:endTest()
  2205. end
  2206. function M.LuaUnit.expandOneClass( result, className, classInstance )
  2207. --[[
  2208. Input: a list of { name, instance }, a class name, a class instance
  2209. Ouptut: modify result to add all test method instance in the form:
  2210. { className.methodName, classInstance }
  2211. ]]
  2212. for methodName, methodInstance in sortedPairs(classInstance) do
  2213. if M.LuaUnit.asFunction(methodInstance) and M.LuaUnit.isMethodTestName( methodName ) then
  2214. table.insert( result, { className..'.'..methodName, classInstance } )
  2215. end
  2216. end
  2217. end
  2218. function M.LuaUnit.expandClasses( listOfNameAndInst )
  2219. --[[
  2220. -- expand all classes (provided as {className, classInstance}) to a list of {className.methodName, classInstance}
  2221. -- functions and methods remain untouched
  2222. Input: a list of { name, instance }
  2223. Output:
  2224. * { function name, function instance } : do nothing
  2225. * { class.method name, class instance }: do nothing
  2226. * { class name, class instance } : add all method names in the form of (className.methodName, classInstance)
  2227. ]]
  2228. local result = {}
  2229. for i,v in ipairs( listOfNameAndInst ) do
  2230. local name, instance = v[1], v[2]
  2231. if M.LuaUnit.asFunction(instance) then
  2232. table.insert( result, { name, instance } )
  2233. else
  2234. if type(instance) ~= 'table' then
  2235. error( 'Instance must be a table or a function, not a '..type(instance)..', value '..prettystr(instance))
  2236. end
  2237. local className, methodName = M.LuaUnit.splitClassMethod( name )
  2238. if className then
  2239. local methodInstance = instance[methodName]
  2240. if methodInstance == nil then
  2241. error( "Could not find method in class "..tostring(className).." for method "..tostring(methodName) )
  2242. end
  2243. table.insert( result, { name, instance } )
  2244. else
  2245. M.LuaUnit.expandOneClass( result, name, instance )
  2246. end
  2247. end
  2248. end
  2249. return result
  2250. end
  2251. function M.LuaUnit.applyPatternFilter( patternIncFilter, listOfNameAndInst )
  2252. local included, excluded = {}, {}
  2253. for i, v in ipairs( listOfNameAndInst ) do
  2254. -- local name, instance = v[1], v[2]
  2255. if patternFilter( patternIncFilter, v[1] ) then
  2256. table.insert( included, v )
  2257. else
  2258. table.insert( excluded, v )
  2259. end
  2260. end
  2261. return included, excluded
  2262. end
  2263. function M.LuaUnit:runSuiteByInstances( listOfNameAndInst )
  2264. --[[ Run an explicit list of tests. All test instances and names must be supplied.
  2265. each test must be one of:
  2266. * { function name, function instance }
  2267. * { class name, class instance }
  2268. * { class.method name, class instance }
  2269. ]]
  2270. local expandedList = self.expandClasses( listOfNameAndInst )
  2271. if self.randomize then
  2272. randomizeTable( expandedList )
  2273. end
  2274. local filteredList, filteredOutList = self.applyPatternFilter(
  2275. self.patternIncludeFilter, expandedList )
  2276. self:startSuite( #filteredList, #filteredOutList )
  2277. for i,v in ipairs( filteredList ) do
  2278. local name, instance = v[1], v[2]
  2279. if M.LuaUnit.asFunction(instance) then
  2280. self:execOneFunction( nil, name, nil, instance )
  2281. else
  2282. -- expandClasses() should have already taken care of sanitizing the input
  2283. assert( type(instance) == 'table' )
  2284. local className, methodName = M.LuaUnit.splitClassMethod( name )
  2285. assert( className ~= nil )
  2286. local methodInstance = instance[methodName]
  2287. assert(methodInstance ~= nil)
  2288. self:execOneFunction( className, methodName, instance, methodInstance )
  2289. end
  2290. if self.result.aborted then
  2291. break -- "--error" or "--failure" option triggered
  2292. end
  2293. end
  2294. if self.lastClassName ~= nil then
  2295. self:endClass()
  2296. end
  2297. self:endSuite()
  2298. if self.result.aborted then
  2299. print("LuaUnit ABORTED (as requested by --error or --failure option)")
  2300. os.exit(-2)
  2301. end
  2302. end
  function M.LuaUnit:runSuiteByNames( listOfName )
      --[[ Resolve test names against the global table, then run them.

      The names come either from the command line or from a scan of the global
      namespace. Each name is one of:
      * a plain global name: it must resolve to a table or a function;
      * a "Class.method" name (as split by M.LuaUnit.splitClassMethod): "Class"
        must resolve to a table that has a non-nil "method" field.

      Each resolved name becomes a pair { name, instance }, where instance is the
      global value (the class table for "Class.method" names). The list of pairs
      is handed to self:runSuiteByInstances().

      Raises an error on the first name that cannot be resolved.
      ]]
      local resolved = {}

      for _, name in ipairs( listOfName ) do
          local className, methodName = M.LuaUnit.splitClassMethod( name )

          -- for "Class.method" names only the class part is looked up globally
          local globalName = name
          if className then
              globalName = className
          end

          local target = _G[globalName]
          if target == nil then
              error( "No such name in global space: "..globalName )
          end

          if className then
              if type(target) ~= 'table' then
                  error( 'Instance of '..globalName..' must be a table, not '..type(target))
              end
              if target[methodName] == nil then
                  error( "Could not find method in class "..tostring(className).." for method "..tostring(methodName) )
              end
          elseif type(target) ~= 'table' and type(target) ~= 'function' then
              error( 'Name must match a function or a table: '..globalName )
          end

          resolved[#resolved + 1] = { name, target }
      end

      self:runSuiteByInstances( resolved )
  end
  function M.LuaUnit.run(...)
      -- Create a fresh runner with M.LuaUnit.new() and run the suite on it.
      --
      -- Without arguments, runSuite() falls back to the process command line.
      -- If that names no tests, it runs whatever M.LuaUnit.collectTests() returns.
      --
      -- Any arguments must be strings: test class names and/or the generic
      -- command-line options (-o, -p, -v, ...).
      --
      -- Both M.LuaUnit.run() and M.LuaUnit:run() are accepted. In the colon form
      -- the class arrives as the first argument, and runSuite() drops a leading
      -- argument whose __class__ field is 'LuaUnit'.
      --
      -- Returns whatever runSuite() returns (self.result.notPassedCount).
      return M.LuaUnit.new():runSuite(...)
  end
  function M.LuaUnit:runSuite( ... )
      -- Configure this runner from command-line style arguments, then run tests.
      --
      -- Arguments: strings such as test names or options (-o, -p, -v, ...).
      -- When none are given, cmdline_argv is used instead.
      -- The arguments are parsed by M.LuaUnit.parseCmdLine, wrapped in
      -- pcall_or_abort. The parsed options are copied onto the runner, the output
      -- type is selected, and the tests are run: options.testNames if present,
      -- otherwise the result of M.LuaUnit.collectTests().
      --
      -- Exits the process with status -1 when junit output is requested without
      -- a file name. Returns self.result.notPassedCount.
      local args = { ... }

      -- M.LuaUnit:run() forwards the class itself as the first argument;
      -- both M.LuaUnit.run() and M.LuaUnit:run() are supported, so drop it.
      local first = args[1]
      if type(first) == 'table' and first.__class__ == 'LuaUnit' then
          table.remove( args, 1 )
      end

      if #args == 0 then
          args = cmdline_argv
      end

      local options = pcall_or_abort( M.LuaUnit.parseCmdLine, args )

      -- Each field is either nil or an already-validated value, so the fields
      -- are copied unconditionally.
      self.verbosity            = options.verbosity
      self.quitOnError          = options.quitOnError
      self.quitOnFailure        = options.quitOnFailure
      self.fname                = options.fname
      self.exeCount             = options.exeCount
      self.patternIncludeFilter = options.pattern
      self.randomize            = options.randomize

      local output = options.output
      if output then
          -- read options.fname, not self.fname: the latter may resolve to a
          -- value inherited from elsewhere when the option was not given
          if output:lower() == 'junit' and options.fname == nil then
              print('With junit output, a filename must be supplied with -n or --name')
              os.exit(-1)
          end
          pcall_or_abort(self.setOutputType, self, output)
      end

      local testNames = options.testNames or M.LuaUnit.collectTests()
      self:runSuiteByNames( testNames )
      return self.result.notPassedCount
  end
  -- class LuaUnit

  -- For compatibility with LuaUnit v2
  M.run = M.LuaUnit.run
  M.Run = M.LuaUnit.run

  function M.setVerbosity( selfOrVerbosity, verbosity )
      -- Set the class-level verbosity field M.LuaUnit.verbosity.
      --
      -- Both call forms are accepted:
      --   M:setVerbosity( level )   -- method form
      --   M.setVerbosity( level )   -- plain function form
      --
      -- The function is defined with '.', not ':', on purpose. A ':'-defined
      -- version receives `level` in the self slot when called with '.', and
      -- then silently stores nil. This matters because the function is also
      -- exported under the aliases below.
      -- A table first argument is treated as the receiver and ignored.
      if type(selfOrVerbosity) ~= 'table' then
          verbosity = selfOrVerbosity
      end
      M.LuaUnit.verbosity = verbosity
  end

  -- LuaUnit v2 compatible aliases (both call forms work for them as well)
  M.set_verbosity = M.setVerbosity
  M.SetVerbosity = M.setVerbosity

  return M