3 """Unit tests for the with statement specified in PEP 343."""
6 __author__ = "Mike Bland"
7 __email__ = "mbland at acm dot org"
11 from collections import deque
12 from contextlib import GeneratorContextManager, contextmanager
13 from test.test_support import run_unittest
16 class MockContextManager(GeneratorContextManager):
17 def __init__(self, gen):
18 GeneratorContextManager.__init__(self, gen)
19 self.enter_called = False
20 self.exit_called = False
24 self.enter_called = True
25 return GeneratorContextManager.__enter__(self)
27 def __exit__(self, type, value, traceback):
28 self.exit_called = True
29 self.exit_args = (type, value, traceback)
30 return GeneratorContextManager.__exit__(self, type,
34 def mock_contextmanager(func):
35 def helper(*args, **kwds):
36 return MockContextManager(func(*args, **kwds))
40 class MockResource(object):
47 def mock_contextmanager_generator():
58 def __init__(self, *managers):
59 self.managers = managers
63 if self.entered is not None:
64 raise RuntimeError("Context is not reentrant")
65 self.entered = deque()
68 for mgr in self.managers:
69 vars.append(mgr.__enter__())
70 self.entered.appendleft(mgr)
72 if not self.__exit__(*sys.exc_info()):
76 def __exit__(self, *exc_info):
77 # Behave like nested with statements
79 # New exceptions override old ones
81 for mgr in self.entered:
84 ex = (None, None, None)
88 if ex is not exc_info:
89 raise ex[0], ex[1], ex[2]
92 class MockNested(Nested):
93 def __init__(self, *managers):
94 Nested.__init__(self, *managers)
95 self.enter_called = False
96 self.exit_called = False
100 self.enter_called = True
101 return Nested.__enter__(self)
103 def __exit__(self, *exc_info):
104 self.exit_called = True
105 self.exit_args = exc_info
106 return Nested.__exit__(self, *exc_info)
109 class FailureTestCase(unittest.TestCase):
110 def testNameError(self):
111 def fooNotDeclared():
113 self.assertRaises(NameError, fooNotDeclared)
115 def testEnterAttributeError(self):
116 class LacksEnter(object):
117 def __exit__(self, type, value, traceback):
123 self.assertRaises(AttributeError, fooLacksEnter)
125 def testExitAttributeError(self):
126 class LacksExit(object):
133 self.assertRaises(AttributeError, fooLacksExit)
135 def assertRaisesSyntaxError(self, codestr):
136 def shouldRaiseSyntaxError(s):
137 compile(s, '', 'single')
138 self.assertRaises(SyntaxError, shouldRaiseSyntaxError, codestr)
140 def testAssignmentToNoneError(self):
141 self.assertRaisesSyntaxError('with mock as None:\n pass')
142 self.assertRaisesSyntaxError(
143 'with mock as (None):\n'
146 def testAssignmentToEmptyTupleError(self):
147 self.assertRaisesSyntaxError(
151 def testAssignmentToTupleOnlyContainingNoneError(self):
152 self.assertRaisesSyntaxError('with mock as None,:\n pass')
153 self.assertRaisesSyntaxError(
154 'with mock as (None,):\n'
157 def testAssignmentToTupleContainingNoneError(self):
158 self.assertRaisesSyntaxError(
159 'with mock as (foo, None, bar):\n'
162 def testEnterThrows(self):
163 class EnterThrows(object):
165 raise RuntimeError("Enter threw")
166 def __exit__(self, *args):
174 self.assertRaises(RuntimeError, shouldThrow)
175 self.assertEqual(self.foo, None)
177 def testExitThrows(self):
178 class ExitThrows(object):
181 def __exit__(self, *args):
182 raise RuntimeError(42)
186 self.assertRaises(RuntimeError, shouldThrow)
188 class ContextmanagerAssertionMixin(object):
189 TEST_EXCEPTION = RuntimeError("test exception")
191 def assertInWithManagerInvariants(self, mock_manager):
192 self.assertTrue(mock_manager.enter_called)
193 self.assertFalse(mock_manager.exit_called)
194 self.assertEqual(mock_manager.exit_args, None)
196 def assertAfterWithManagerInvariants(self, mock_manager, exit_args):
197 self.assertTrue(mock_manager.enter_called)
198 self.assertTrue(mock_manager.exit_called)
199 self.assertEqual(mock_manager.exit_args, exit_args)
201 def assertAfterWithManagerInvariantsNoError(self, mock_manager):
202 self.assertAfterWithManagerInvariants(mock_manager,
205 def assertInWithGeneratorInvariants(self, mock_generator):
206 self.assertTrue(mock_generator.yielded)
207 self.assertFalse(mock_generator.stopped)
209 def assertAfterWithGeneratorInvariantsNoError(self, mock_generator):
210 self.assertTrue(mock_generator.yielded)
211 self.assertTrue(mock_generator.stopped)
213 def raiseTestException(self):
214 raise self.TEST_EXCEPTION
216 def assertAfterWithManagerInvariantsWithError(self, mock_manager):
217 self.assertTrue(mock_manager.enter_called)
218 self.assertTrue(mock_manager.exit_called)
219 self.assertEqual(mock_manager.exit_args[0], RuntimeError)
220 self.assertEqual(mock_manager.exit_args[1], self.TEST_EXCEPTION)
222 def assertAfterWithGeneratorInvariantsWithError(self, mock_generator):
223 self.assertTrue(mock_generator.yielded)
224 self.assertTrue(mock_generator.stopped)
227 class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
228 def testInlineGeneratorSyntax(self):
229 with mock_contextmanager_generator():
232 def testUnboundGenerator(self):
233 mock = mock_contextmanager_generator()
236 self.assertAfterWithManagerInvariantsNoError(mock)
238 def testInlineGeneratorBoundSyntax(self):
239 with mock_contextmanager_generator() as foo:
240 self.assertInWithGeneratorInvariants(foo)
241 # FIXME: In the future, we'll try to keep the bound names from leaking
242 self.assertAfterWithGeneratorInvariantsNoError(foo)
244 def testInlineGeneratorBoundToExistingVariable(self):
246 with mock_contextmanager_generator() as foo:
247 self.assertInWithGeneratorInvariants(foo)
248 self.assertAfterWithGeneratorInvariantsNoError(foo)
    def testInlineGeneratorBoundToDottedVariable(self):
        # The with-statement target here is an attribute reference
        # (self.foo) rather than a plain name; that exact statement form is
        # what this test exercises, so it is kept verbatim.
        with mock_contextmanager_generator() as self.foo:
            self.assertInWithGeneratorInvariants(self.foo)
        # The attribute still refers to the resource after the block exits.
        self.assertAfterWithGeneratorInvariantsNoError(self.foo)
255 def testBoundGenerator(self):
256 mock = mock_contextmanager_generator()
258 self.assertInWithGeneratorInvariants(foo)
259 self.assertInWithManagerInvariants(mock)
260 self.assertAfterWithGeneratorInvariantsNoError(foo)
261 self.assertAfterWithManagerInvariantsNoError(mock)
263 def testNestedSingleStatements(self):
264 mock_a = mock_contextmanager_generator()
266 mock_b = mock_contextmanager_generator()
268 self.assertInWithManagerInvariants(mock_a)
269 self.assertInWithManagerInvariants(mock_b)
270 self.assertInWithGeneratorInvariants(foo)
271 self.assertInWithGeneratorInvariants(bar)
272 self.assertAfterWithManagerInvariantsNoError(mock_b)
273 self.assertAfterWithGeneratorInvariantsNoError(bar)
274 self.assertInWithManagerInvariants(mock_a)
275 self.assertInWithGeneratorInvariants(foo)
276 self.assertAfterWithManagerInvariantsNoError(mock_a)
277 self.assertAfterWithGeneratorInvariantsNoError(foo)
280 class NestedNonexceptionalTestCase(unittest.TestCase,
281 ContextmanagerAssertionMixin):
282 def testSingleArgInlineGeneratorSyntax(self):
283 with Nested(mock_contextmanager_generator()):
286 def testSingleArgUnbound(self):
287 mock_contextmanager = mock_contextmanager_generator()
288 mock_nested = MockNested(mock_contextmanager)
290 self.assertInWithManagerInvariants(mock_contextmanager)
291 self.assertInWithManagerInvariants(mock_nested)
292 self.assertAfterWithManagerInvariantsNoError(mock_contextmanager)
293 self.assertAfterWithManagerInvariantsNoError(mock_nested)
295 def testSingleArgBoundToNonTuple(self):
296 m = mock_contextmanager_generator()
297 # This will bind all the arguments to nested() into a single list
299 with Nested(m) as foo:
300 self.assertInWithManagerInvariants(m)
301 self.assertAfterWithManagerInvariantsNoError(m)
303 def testSingleArgBoundToSingleElementParenthesizedList(self):
304 m = mock_contextmanager_generator()
305 # This will bind all the arguments to nested() into a single list
307 with Nested(m) as (foo):
308 self.assertInWithManagerInvariants(m)
309 self.assertAfterWithManagerInvariantsNoError(m)
311 def testSingleArgBoundToMultipleElementTupleError(self):
312 def shouldThrowValueError():
313 with Nested(mock_contextmanager_generator()) as (foo, bar):
315 self.assertRaises(ValueError, shouldThrowValueError)
317 def testSingleArgUnbound(self):
318 mock_contextmanager = mock_contextmanager_generator()
319 mock_nested = MockNested(mock_contextmanager)
321 self.assertInWithManagerInvariants(mock_contextmanager)
322 self.assertInWithManagerInvariants(mock_nested)
323 self.assertAfterWithManagerInvariantsNoError(mock_contextmanager)
324 self.assertAfterWithManagerInvariantsNoError(mock_nested)
326 def testMultipleArgUnbound(self):
327 m = mock_contextmanager_generator()
328 n = mock_contextmanager_generator()
329 o = mock_contextmanager_generator()
330 mock_nested = MockNested(m, n, o)
332 self.assertInWithManagerInvariants(m)
333 self.assertInWithManagerInvariants(n)
334 self.assertInWithManagerInvariants(o)
335 self.assertInWithManagerInvariants(mock_nested)
336 self.assertAfterWithManagerInvariantsNoError(m)
337 self.assertAfterWithManagerInvariantsNoError(n)
338 self.assertAfterWithManagerInvariantsNoError(o)
339 self.assertAfterWithManagerInvariantsNoError(mock_nested)
341 def testMultipleArgBound(self):
342 mock_nested = MockNested(mock_contextmanager_generator(),
343 mock_contextmanager_generator(), mock_contextmanager_generator())
344 with mock_nested as (m, n, o):
345 self.assertInWithGeneratorInvariants(m)
346 self.assertInWithGeneratorInvariants(n)
347 self.assertInWithGeneratorInvariants(o)
348 self.assertInWithManagerInvariants(mock_nested)
349 self.assertAfterWithGeneratorInvariantsNoError(m)
350 self.assertAfterWithGeneratorInvariantsNoError(n)
351 self.assertAfterWithGeneratorInvariantsNoError(o)
352 self.assertAfterWithManagerInvariantsNoError(mock_nested)
355 class ExceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
356 def testSingleResource(self):
357 cm = mock_contextmanager_generator()
359 with cm as self.resource:
360 self.assertInWithManagerInvariants(cm)
361 self.assertInWithGeneratorInvariants(self.resource)
362 self.raiseTestException()
363 self.assertRaises(RuntimeError, shouldThrow)
364 self.assertAfterWithManagerInvariantsWithError(cm)
365 self.assertAfterWithGeneratorInvariantsWithError(self.resource)
367 def testNestedSingleStatements(self):
368 mock_a = mock_contextmanager_generator()
369 mock_b = mock_contextmanager_generator()
371 with mock_a as self.foo:
372 with mock_b as self.bar:
373 self.assertInWithManagerInvariants(mock_a)
374 self.assertInWithManagerInvariants(mock_b)
375 self.assertInWithGeneratorInvariants(self.foo)
376 self.assertInWithGeneratorInvariants(self.bar)
377 self.raiseTestException()
378 self.assertRaises(RuntimeError, shouldThrow)
379 self.assertAfterWithManagerInvariantsWithError(mock_a)
380 self.assertAfterWithManagerInvariantsWithError(mock_b)
381 self.assertAfterWithGeneratorInvariantsWithError(self.foo)
382 self.assertAfterWithGeneratorInvariantsWithError(self.bar)
384 def testMultipleResourcesInSingleStatement(self):
385 cm_a = mock_contextmanager_generator()
386 cm_b = mock_contextmanager_generator()
387 mock_nested = MockNested(cm_a, cm_b)
389 with mock_nested as (self.resource_a, self.resource_b):
390 self.assertInWithManagerInvariants(cm_a)
391 self.assertInWithManagerInvariants(cm_b)
392 self.assertInWithManagerInvariants(mock_nested)
393 self.assertInWithGeneratorInvariants(self.resource_a)
394 self.assertInWithGeneratorInvariants(self.resource_b)
395 self.raiseTestException()
396 self.assertRaises(RuntimeError, shouldThrow)
397 self.assertAfterWithManagerInvariantsWithError(cm_a)
398 self.assertAfterWithManagerInvariantsWithError(cm_b)
399 self.assertAfterWithManagerInvariantsWithError(mock_nested)
400 self.assertAfterWithGeneratorInvariantsWithError(self.resource_a)
401 self.assertAfterWithGeneratorInvariantsWithError(self.resource_b)
403 def testNestedExceptionBeforeInnerStatement(self):
404 mock_a = mock_contextmanager_generator()
405 mock_b = mock_contextmanager_generator()
408 with mock_a as self.foo:
409 self.assertInWithManagerInvariants(mock_a)
410 self.assertInWithGeneratorInvariants(self.foo)
411 self.raiseTestException()
412 with mock_b as self.bar:
414 self.assertRaises(RuntimeError, shouldThrow)
415 self.assertAfterWithManagerInvariantsWithError(mock_a)
416 self.assertAfterWithGeneratorInvariantsWithError(self.foo)
418 # The inner statement stuff should never have been touched
419 self.assertEqual(self.bar, None)
420 self.assertFalse(mock_b.enter_called)
421 self.assertFalse(mock_b.exit_called)
422 self.assertEqual(mock_b.exit_args, None)
424 def testNestedExceptionAfterInnerStatement(self):
425 mock_a = mock_contextmanager_generator()
426 mock_b = mock_contextmanager_generator()
428 with mock_a as self.foo:
429 with mock_b as self.bar:
430 self.assertInWithManagerInvariants(mock_a)
431 self.assertInWithManagerInvariants(mock_b)
432 self.assertInWithGeneratorInvariants(self.foo)
433 self.assertInWithGeneratorInvariants(self.bar)
434 self.raiseTestException()
435 self.assertRaises(RuntimeError, shouldThrow)
436 self.assertAfterWithManagerInvariantsWithError(mock_a)
437 self.assertAfterWithManagerInvariantsNoError(mock_b)
438 self.assertAfterWithGeneratorInvariantsWithError(self.foo)
439 self.assertAfterWithGeneratorInvariantsNoError(self.bar)
441 def testRaisedStopIteration1(self):
449 raise StopIteration("from with")
451 self.assertRaises(StopIteration, shouldThrow)
453 def testRaisedStopIteration2(self):
458 def __exit__(self, type, value, traceback):
463 raise StopIteration("from with")
465 self.assertRaises(StopIteration, shouldThrow)
467 def testRaisedStopIteration3(self):
468 # Another variant where the exception hasn't been instantiated
476 raise iter([]).next()
478 self.assertRaises(StopIteration, shouldThrow)
480 def testRaisedGeneratorExit1(self):
488 raise GeneratorExit("from with")
490 self.assertRaises(GeneratorExit, shouldThrow)
492 def testRaisedGeneratorExit2(self):
497 def __exit__(self, type, value, traceback):
502 raise GeneratorExit("from with")
504 self.assertRaises(GeneratorExit, shouldThrow)
506 def testErrorsInBool(self):
507 # issue4589: __exit__ return code may raise an exception
508 # when looking at its truth value.
511 def __init__(self, bool_conversion):
513 def __nonzero__(self):
514 return bool_conversion()
515 self.exit_result = Bool()
518 def __exit__(self, a, b, c):
519 return self.exit_result
522 with cm(lambda: True):
523 self.fail("Should NOT see this")
527 with cm(lambda: False):
528 self.fail("Should raise")
529 self.assertRaises(AssertionError, falseAsBool)
532 with cm(lambda: 1//0):
533 self.fail("Should NOT see this")
534 self.assertRaises(ZeroDivisionError, failAsBool)
537 class NonLocalFlowControlTestCase(unittest.TestCase):
539 def testWithBreak(self):
543 with mock_contextmanager_generator():
546 counter += 100 # Not reached
547 self.assertEqual(counter, 11)
549 def testWithContinue(self):
555 with mock_contextmanager_generator():
558 counter += 100 # Not reached
559 self.assertEqual(counter, 12)
561 def testWithReturn(self):
566 with mock_contextmanager_generator():
569 counter += 100 # Not reached
570 self.assertEqual(foo(), 11)
572 def testWithYield(self):
574 with mock_contextmanager_generator():
578 self.assertEqual(x, [12, 13])
580 def testWithRaise(self):
584 with mock_contextmanager_generator():
587 counter += 100 # Not reached
589 self.assertEqual(counter, 11)
591 self.fail("Didn't raise RuntimeError")
594 class AssignmentTargetTestCase(unittest.TestCase):
596 def testSingleComplexTarget(self):
597 targets = {1: [0, 1, 2]}
598 with mock_contextmanager_generator() as targets[1][0]:
599 self.assertEqual(targets.keys(), [1])
600 self.assertEqual(targets[1][0].__class__, MockResource)
601 with mock_contextmanager_generator() as targets.values()[0][1]:
602 self.assertEqual(targets.keys(), [1])
603 self.assertEqual(targets[1][1].__class__, MockResource)
604 with mock_contextmanager_generator() as targets[2]:
605 keys = targets.keys()
607 self.assertEqual(keys, [1, 2])
610 with mock_contextmanager_generator() as blah.foo:
611 self.assertEqual(hasattr(blah, "foo"), True)
613 def testMultipleComplexTargets(self):
615 def __enter__(self): return 1, 2, 3
616 def __exit__(self, t, v, tb): pass
617 targets = {1: [0, 1, 2]}
618 with C() as (targets[1][0], targets[1][1], targets[1][2]):
619 self.assertEqual(targets, {1: [1, 2, 3]})
620 with C() as (targets.values()[0][2], targets.values()[0][1], targets.values()[0][0]):
621 self.assertEqual(targets, {1: [3, 2, 1]})
622 with C() as (targets[1], targets[2], targets[3]):
623 self.assertEqual(targets, {1: 1, 2: 2, 3: 3})
626 with C() as (blah.one, blah.two, blah.three):
627 self.assertEqual(blah.one, 1)
628 self.assertEqual(blah.two, 2)
629 self.assertEqual(blah.three, 3)
632 class ExitSwallowsExceptionTestCase(unittest.TestCase):
634 def testExitTrueSwallowsException(self):
635 class AfricanSwallow:
636 def __enter__(self): pass
637 def __exit__(self, t, v, tb): return True
639 with AfricanSwallow():
641 except ZeroDivisionError:
642 self.fail("ZeroDivisionError should have been swallowed")
644 def testExitFalseDoesntSwallowException(self):
645 class EuropeanSwallow:
646 def __enter__(self): pass
647 def __exit__(self, t, v, tb): return False
649 with EuropeanSwallow():
651 except ZeroDivisionError:
654 self.fail("ZeroDivisionError should have been raised")
658 run_unittest(FailureTestCase, NonexceptionalTestCase,
659 NestedNonexceptionalTestCase, ExceptionalTestCase,
660 NonLocalFlowControlTestCase,
661 AssignmentTargetTestCase,
662 ExitSwallowsExceptionTestCase)
665 if __name__ == '__main__':