了解上下文管理器 & with 块

with 语句的目的是简化 try/finally 模式。这种模式用于保证一段代码运行完毕后执行某项操作,即便那段代码由于异常、return 语句或 sys.exit()调用而中止,也会执行指定的操作。finally 子句中的代码通常用于释放重要的资源,或者还原临时变更的状态。

“上下文管理器协议”包含 __enter____exit__ 两个方法。with 语句开始运行时,会在上下文管理器对象上调用 __enter__ 方法。with 语句运行结束后,会在上下文管理器对象上调用 __exit__ 方法,以此扮演 finally 子句的角色。

>>> with open('mirror.py') as fp:
...     src = fp.read(60) 
...
>>> len(src)
60
>>> fp
<_io.TextIOWrapper name='mirror.py' mode='r' encoding='UTF-8'>
>>> fp.closed, fp.encoding
(True, 'UTF-8')
>>> fp.read(60)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: I/O operation on closed file.

open()函数返回 TextIOWrapper 类的实例,而该实例的 __enter__ 方法返回 self。不过,__enter__ 方法除了返回上下文管理器之外,还可能返回其他对象。通过 as 子句,可以将 __enter__ 方法返回的结果传递给外部对象。

不管控制流程以哪种方式退出 with 块,都会在上下文管理器对象上调用 __exit__ 方法,而不是在 __enter__ 方法返回的对象上调用。

自定义上下文管理器

class LookingGlass:
    def __enter__(self):
        import sys
        self.original_write = sys.stdout.write  
        sys.stdout.write = self.reverse_write 
        return 'JABBERWOCKY' 
    def reverse_write(self, text): 
        self.original_write(text[::-1])
    def __exit__(self, exc_type, exc_value, traceback): 
        import sys  
        sys.stdout.write = self.original_write  
        if exc_type is ZeroDivisionError: 
            print('Please DO NOT divide by zero!')
            return True
 
>>> from mirror import LookingGlass
>>> with LookingGlass() as what: 
...      print('Alice, Kitty and Snowdrop') 
...      print(what)
...
pordwonS dna yttiK ,ecilA 
YKCOWREBBAJ
>>> what  
'JABBERWOCKY'
>>> print('Back to normal.') 
Back to normal.

在实际使用中,如果应用程序接管了标准输出,可能会暂时把 sys.stdout 换成类似文件的其他对象,然后再切换成原来的版本。contextlib.redirect_stdout 上下文管理器就是这么做的:只需传入类似文件的对象,用于替代 sys.stdout

例如,在使用 tqdm 时,避免打印混乱,可以加重定向上下文,把 print 结果输出到文件中。

__enter__:只有隐式的 self 作为入参。

__exit__:除了隐式的 self,还有三个参数:

  • exc_type 异常类(例如 ZeroDivisionError)
  • exc_value 异常实例。有时会有参数传给异常构造方法,例如错误消息,这些参数可以使用 exc_value.args 获取。
  • traceback traceback 对象。

标准库示例

在 sqlite3 模块中用于管理事务,参见“12.6.7.3. Using the connection as a context manager”。

在 threading 模块中用于维护锁、条件和信号,参见“17.1.10. Using locks, conditions, and semaphores in the with statement”。

为 Decimal 对象的算术运算设置环境,参见 decimal.localcontext 函数的文档。

为了测试临时给对象打补丁,参见 unittest.mock.patch 函数的文档。

contextlib 模块

closing:如果对象提供了 close() 方法,但没有实现 __enter__ / __exit__ 协议,那么可以使用这个函数构建上下文管理器。

suppress:构建临时忽略指定异常的上下文管理器。

@contextmanager:这个装饰器把简单的生成器函数变成上下文管理器,这样就不用创建类去实现管理器协议了。

ContextDecorator:这是个基类,用于定义基于类的上下文管理器。这种上下文管理器也能用于装饰函数,在受管理的上下文中运行整个函数。

ExitStack:这个上下文管理器能进入多个上下文管理器。with 块结束时,ExitStack 按照后进先出的顺序调用栈中各个上下文管理器的 __exit__ 方法。如果事先不知道 with 块要进入多少个上下文管理器,可以使用这个类。例如,同时打开任意一个文件列表中的所有文件。

@contextmanager

@contextmanager 装饰器能减少创建上下文管理器的样板代码量,因为不用编写一个完整的类、定义 __enter____exit__ 方法,而只需实现有一个 yield 语句的生成器,生成想让 __enter__ 方法返回的值。

在使用 @contextmanager 装饰的生成器中,yield 语句的作用是把函数的定义体分成两部分:yield 语句前面的所有代码在 with 块开始时(即解释器调用 __enter__ 方法时)执行, yield 语句后面的代码在 with 块结束时(即调用 __exit__ 方法时)执行。

或者说,会把函数包装成实现了 __enter____exit__ 方法的类。

这个类的 __enter__ 方法有如下作用。

  1. 调用生成器函数,保存生成器对象(这里把它称为 gen)。
  2. 调用 next(gen),执行到 yield 关键字所在的位置。
  3. 返回 next(gen)产出的值,以便把产出的值绑定到 with/as 语句中的目标变量上。

with 块终止时,__exit__ 方法会做以下几件事。

  1. 检查有没有把异常传给 exc_type;如果有,调用 gen.throw(exception),在生成器函数定义体中包含 yield 关键字的那一行抛出异常。
  2. 否则,调用 next(gen),继续执行生成器函数定义体中 yield 语句之后的代码。
import contextlib
@contextlib.contextmanager  
def looking_glass():
    import sys
    original_write = sys.stdout.write  
    def reverse_write(text):  
        original_write(text[::-1])
    sys.stdout.write = reverse_write  
    yield 'JABBERWOCKY'  
    sys.stdout.write = original_write  
 
>>> from mirror_gen import looking_glass
>>> with looking_glass() as what: 
...      print('Alice, Kitty and Snowdrop')
...      print(what)
...
pordwonS dna yttiK ,ecilA
YKCOWREBBAJ
>>> what
'JABBERWOCKY'

上面有一个严重错误:如果在 with 块中抛出了异常,Python 解释器会将其捕获,然后在 looking_glass 函数的 yield 表达式里再次抛出。但是,那里没有处理错误的代码,因此 looking_glass 函数会中止,永远无法恢复成原来的 sys.stdout.write 方法,导致系统处于无效状态。

解决办法如下:

import contextlib
@contextlib.contextmanager
def looking_glass():
    import sys
    original_write = sys.stdout.write
    def reverse_write(text):
        original_write(text[::-1])
    sys.stdout.write = reverse_write
    msg = ''  
    try:
        yield 'JABBERWOCKY'
    except ZeroDivisionError:  
        msg = 'Please DO NOT divide by zero!'
    finally:
        sys.stdout.write = original_write  
        if msg:
            print(msg) 

一般来说,需要在 yield 语句上增加 try-except 语句,避免出现 with 块内抛出异常导致 yield 后面语句(__exit__ 内容)没有执行的情况。