Upgraded to 2.3 version; added a global 'asserts'.

This commit is contained in:
Jeremy Fincher 2003-04-29 12:59:35 +00:00
parent 54788a643a
commit 3c33583454

View File

@ -46,7 +46,7 @@ SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
__author__ = "Steve Purcell" __author__ = "Steve Purcell"
__email__ = "stephen_purcell at yahoo dot com" __email__ = "stephen_purcell at yahoo dot com"
__version__ = "#Revision: 1.43 $"[11:-2] __version__ = "#Revision: 1.46 $"[11:-2]
import time import time
import sys import sys
@ -55,10 +55,21 @@ import string
import os import os
import types import types
###
# Globals
###
asserts = 0
############################################################################## ##############################################################################
# Test framework core # Test framework core
############################################################################## ##############################################################################
# All classes defined herein are 'new-style' classes, allowing use of 'super()'
__metaclass__ = type
def _strclass(cls):
return "%s.%s" % (cls.__module__, cls.__name__)
class TestResult: class TestResult:
"""Holder for test result information. """Holder for test result information.
@ -109,11 +120,11 @@ class TestResult:
def _exc_info_to_string(self, err): def _exc_info_to_string(self, err):
"""Converts a sys.exc_info()-style tuple of values into a string.""" """Converts a sys.exc_info()-style tuple of values into a string."""
return string.join(apply(traceback.format_exception, err), '') return string.join(traceback.format_exception(*err), '')
def __repr__(self): def __repr__(self):
return "<%s run=%i errors=%i failures=%i>" % \ return "<%s run=%i errors=%i failures=%i>" % \
(self.__class__, self.testsRun, len(self.errors), (_strclass(self.__class__), self.testsRun, len(self.errors),
len(self.failures)) len(self.failures))
@ -183,14 +194,14 @@ class TestCase:
return doc and string.strip(string.split(doc, "\n")[0]) or None return doc and string.strip(string.split(doc, "\n")[0]) or None
def id(self): def id(self):
return "%s.%s" % (self.__class__, self.__testMethodName) return "%s.%s" % (_strclass(self.__class__), self.__testMethodName)
def __str__(self): def __str__(self):
return "%s (%s)" % (self.__testMethodName, self.__class__) return "%s (%s)" % (self.__testMethodName, _strclass(self.__class__))
def __repr__(self): def __repr__(self):
return "<%s testMethod=%s>" % \ return "<%s testMethod=%s>" % \
(self.__class__, self.__testMethodName) (_strclass(self.__class__), self.__testMethodName)
def run(self, result=None): def run(self, result=None):
return self(result) return self(result)
@ -249,17 +260,27 @@ class TestCase:
return (exctype, excvalue, tb) return (exctype, excvalue, tb)
return (exctype, excvalue, newtb) return (exctype, excvalue, newtb)
def _fail(self, msg):
"""Underlying implementation of failure."""
raise self.failureException, msg
def fail(self, msg=None): def fail(self, msg=None):
"""Fail immediately, with the given message.""" """Fail immediately, with the given message."""
raise self.failureException, msg global asserts
asserts += 1
self._fail(msg)
def failIf(self, expr, msg=None): def failIf(self, expr, msg=None):
"Fail the test if the expression is true." "Fail the test if the expression is true."
if expr: raise self.failureException, msg global asserts
asserts += 1
if expr: self._fail(msg)
def failUnless(self, expr, msg=None): def failUnless(self, expr, msg=None):
"""Fail the test unless the expression is true.""" """Fail the test unless the expression is true."""
if not expr: raise self.failureException, msg global asserts
asserts += 1
if not expr: self._fail(msg)
def failUnlessRaises(self, excClass, callableObj, *args, **kwargs): def failUnlessRaises(self, excClass, callableObj, *args, **kwargs):
"""Fail unless an exception of class excClass is thrown """Fail unless an exception of class excClass is thrown
@ -269,35 +290,71 @@ class TestCase:
deemed to have suffered an error, exactly as for an deemed to have suffered an error, exactly as for an
unexpected exception. unexpected exception.
""" """
global asserts
asserts += 1
try: try:
apply(callableObj, args, kwargs) callableObj(*args, **kwargs)
except excClass: except excClass:
return return
else: else:
if hasattr(excClass,'__name__'): excName = excClass.__name__ if hasattr(excClass,'__name__'): excName = excClass.__name__
else: excName = str(excClass) else: excName = str(excClass)
raise self.failureException, excName raise self._fail(excName)
def failUnlessEqual(self, first, second, msg=None): def failUnlessEqual(self, first, second, msg=None):
"""Fail if the two objects are unequal as determined by the '!=' """Fail if the two objects are unequal as determined by the '=='
operator. operator.
""" """
global asserts
asserts += 1
if not first == second: if not first == second:
raise self.failureException, \ self._fail(msg or '%s != %s' % (`first`, `second`))
(msg or '%s != %s' % (`first`, `second`))
def failIfEqual(self, first, second, msg=None): def failIfEqual(self, first, second, msg=None):
"""Fail if the two objects are equal as determined by the '==' """Fail if the two objects are equal as determined by the '=='
operator. operator.
""" """
global asserts
asserts += 1
if first == second: if first == second:
raise self.failureException, \ self._fail(msg or '%s == %s' % (`first`, `second`))
(msg or '%s == %s' % (`first`, `second`))
def failUnlessAlmostEqual(self, first, second, places=7, msg=None):
"""Fail if the two objects are unequal as determined by their
difference rounded to the given number of decimal places
(default 7) and comparing to zero.
Note that decimal places (from zero) is usually not the same
as significant digits (measured from the most signficant digit).
"""
global asserts
asserts += 1
if round(second-first, places) != 0:
self._fail(msg or '%s != %s within %s places' % \
(`first`, `second`, `places`))
def failIfAlmostEqual(self, first, second, places=7, msg=None):
"""Fail if the two objects are equal as determined by their
difference rounded to the given number of decimal places
(default 7) and comparing to zero.
Note that decimal places (from zero) is usually not the same
as significant digits (measured from the most signficant digit).
"""
global asserts
asserts += 1
if round(second-first, places) == 0:
self._fail(msg or '%s == %s within %s places' % \
(`first`, `second`, `places`))
assertEqual = assertEquals = failUnlessEqual assertEqual = assertEquals = failUnlessEqual
assertNotEqual = assertNotEquals = failIfEqual assertNotEqual = assertNotEquals = failIfEqual
assertAlmostEqual = assertAlmostEquals = failUnlessAlmostEqual
assertNotAlmostEqual = assertNotAlmostEquals = failIfAlmostEqual
assertRaises = failUnlessRaises assertRaises = failUnlessRaises
assert_ = failUnless assert_ = failUnless
@ -318,7 +375,7 @@ class TestSuite:
self.addTests(tests) self.addTests(tests)
def __repr__(self): def __repr__(self):
return "<%s tests=%s>" % (self.__class__, self._tests) return "<%s tests=%s>" % (_strclass(self.__class__), self._tests)
__str__ = __repr__ __str__ = __repr__
@ -382,10 +439,10 @@ class FunctionTestCase(TestCase):
return self.__testFunc.__name__ return self.__testFunc.__name__
def __str__(self): def __str__(self):
return "%s (%s)" % (self.__class__, self.__testFunc.__name__) return "%s (%s)" % (_strclass(self.__class__), self.__testFunc.__name__)
def __repr__(self): def __repr__(self):
return "<%s testFunc=%s>" % (self.__class__, self.__testFunc) return "<%s testFunc=%s>" % (_strclass(self.__class__), self.__testFunc)
def shortDescription(self): def shortDescription(self):
if self.__description is not None: return self.__description if self.__description is not None: return self.__description
@ -416,7 +473,8 @@ class TestLoader:
tests = [] tests = []
for name in dir(module): for name in dir(module):
obj = getattr(module, name) obj = getattr(module, name)
if type(obj) == types.ClassType and issubclass(obj, TestCase): if (isinstance(obj, (type, types.ClassType)) and
issubclass(obj, TestCase)):
tests.append(self.loadTestsFromTestCase(obj)) tests.append(self.loadTestsFromTestCase(obj))
return self.suiteClass(tests) return self.suiteClass(tests)
@ -450,7 +508,8 @@ class TestLoader:
import unittest import unittest
if type(obj) == types.ModuleType: if type(obj) == types.ModuleType:
return self.loadTestsFromModule(obj) return self.loadTestsFromModule(obj)
elif type(obj) == types.ClassType and issubclass(obj, unittest.TestCase): elif (isinstance(obj, (type, types.ClassType)) and
issubclass(obj, unittest.TestCase)):
return self.loadTestsFromTestCase(obj) return self.loadTestsFromTestCase(obj)
elif type(obj) == types.UnboundMethodType: elif type(obj) == types.UnboundMethodType:
return obj.im_class(obj.__name__) return obj.im_class(obj.__name__)
@ -525,7 +584,7 @@ class _WritelnDecorator:
return getattr(self.stream,attr) return getattr(self.stream,attr)
def writeln(self, *args): def writeln(self, *args):
if args: apply(self.write, args) if args: self.write(*args)
self.write('\n') # text-mode streams translate to \r\n if needed self.write('\n') # text-mode streams translate to \r\n if needed
@ -616,7 +675,7 @@ class TextTestRunner:
self.stream.writeln(result.separator2) self.stream.writeln(result.separator2)
run = result.testsRun run = result.testsRun
self.stream.writeln("Ran %d test%s in %.3fs" % self.stream.writeln("Ran %d test%s in %.3fs" %
(run, run == 1 and "" or "s", timeTaken)) (run, run != 1 and "s" or "", timeTaken))
self.stream.writeln() self.stream.writeln()
if not result.wasSuccessful(): if not result.wasSuccessful():
self.stream.write("FAILED (") self.stream.write("FAILED (")