diff --git a/src/fix.py b/src/fix.py index a96c2631a..72380d3e8 100644 --- a/src/fix.py +++ b/src/fix.py @@ -290,4 +290,15 @@ def itersplit(iterable, isSeparator, yieldEmpty=False): if acc or yieldEmpty: yield acc +def flatten(seq, strings=False): + for elt in seq: + if not strings and type(elt) == str or type(elt) == unicode: + yield elt + else: + try: + for x in flatten(elt): + yield x + except TypeError: + yield elt + # vim:set shiftwidth=4 tabstop=8 expandtab textwidth=78: diff --git a/test/fix_test.py b/test/fix_test.py index b195a1419..4c475b330 100644 --- a/test/fix_test.py +++ b/test/fix_test.py @@ -240,3 +240,17 @@ class FunctionsTest(unittest.TestCase): self.assertEqual(AL.values(), [2, 3, 4]) self.assertEqual(list(AL.itervalues()), [2, 3, 4]) self.assertEqual(len(AL), 3) + + def testFlatten(self): + def lflatten(seq): + return list(flatten(seq)) + self.assertEqual(lflatten([]), []) + self.assertEqual(lflatten([1]), [1]) + self.assertEqual(lflatten(range(10)), range(10)) + twoRanges = range(10)*2 + twoRanges.sort() + self.assertEqual(lflatten(zip(range(10), range(10))), twoRanges) + self.assertEqual(lflatten([1, [2, 3], 4]), [1, 2, 3, 4]) + self.assertEqual(lflatten([[[[[[[[[[]]]]]]]]]]), []) + self.assertEqual(lflatten([1, [2, [3, 4], 5], 6]), [1, 2, 3, 4, 5, 6]) + self.assertRaises(TypeError, lflatten, 1)