contextvars模块

本文总阅读量

前记

contextvars模块最主要的功能就是可以为asyncio生态添加上下文功能,即使程序在多个协程并发运行的情况下,也能调用到程序的上下文变量.

上下文,可以理解为我们说话的语境,有些话脱离了特定的语境,他的意思就变了,程序的运行也是如此.在线程中也是有他的上下文,只不过称为堆栈,如在python中就是保存在thread.local变量中,而协程也有他自己的上下文,但是没有暴露出来,不过我们可以通过contextvars模块去保存与读取.

使用contextvars的好处不仅可以防止’一个变量传遍天’的事情发生外,还能很好的结合TypeHint,,让自己的代码更加适应工程化,同时可以让自己的代码可以被mypy以及IDE检查.

更新说明

  • 切换web框架sanic为starlette
  • 增加一个自己编写且可用于starlette,fastapi的context说明

1.有无上下文的区别

首先,先回到两个Python的两个经典框架Django与Flask,他们有很多不同,比如在传request对象时,他们的方式就是不一样的.
先看看Django的

1
2
3
4
5
6
from django.http import HttpResponse


def root(request):
so1n_name = request.GET.get('so1n_name')
return HttpResponse(f'Name is {so1n_name}')

再看看Flask

1
2
3
4
5
6
7
8
9
from flask import Flask, request


app = Flask(__name__)

@app.route('/')
def root():
so1n_name = request.GET.get('so1n_name')
return f'Name is {so1n_name}'

可以发现,在Django中,我们需要显示的传一个叫request变量,而Flask则是import一个叫request的全局变量,并在视图中直接使用,达到解耦的目的.

可能会有人说, 也就是传个变量的区别,为了省传这个变量,而花许多功夫去维护一个全局变量,有点不值得,那可以看看下面的例子,如果层次多就会出现’一个参数传一天’的情况.分层做的好或者需求不坑爹一般不会出现像下面的情况,一个好的程序员能做好代码的分层,但还是要按照需求去更改代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 伪代码,举个例子一个request传了3个函数
from django.http import HttpResponse


def is_allow(request, uid):
if request.ip == '127.0.0.1' and check_permissions(uid):
return True
else:
return False

def check_permissions(request, uid):
pass

def root(request):
user_id = request.GET.get('uid')
if is_allow(request, id):
return HttpResponse('ok')
else
return HttpResponse('error')

2.在asyncio中使用contextvars

threading.local 的隔离效果很好,但是他是针对线程的,只隔离线程之间的数据状态,所以werkzeug为了支持在gevent中运行,自己实现了一个Local变量,与contextvars相似,不过contextvars是对asyncio的上下文提供支持.

下面以一个redis client为例子,展示如何在asyncio生态中使用contextvars(详细解释见代码).

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# demo/context.py
# 该文件存放contextvars相关
import contextvars

if TYPE_CHECKING:
from demo.redis_dal import RDS # 这里是一个redis的封装实例

# 初始化一个redis相关的全局context
redis_pool_context = contextvars.ContextVar('redis_pool')

# 通过函数调用可以获取到当前协程运行时的context上下文
def get_redis() -> 'RDS':
return redis_pool_context.get()

# demo/web_tool.py
# 该文件存放starlette相关模块
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.middleware.base import RequestResponseEndpoint
from starlette.responses import Response
from demo.redis_dal import RDS


# 初始化一个redis客户端变量,当前为空
REDIS_POOL = None # type: Optional[RDS]


class RequestContextMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
# 通过中间件,在进入路由之前,把redis客户端放入当前协程的上下文之中
token = redis_pool_context.set(REDIS_POOL)
try:
response = await call_next(request)
return response
finally:
# 调用完成,回收当前请求设置的redis客户端的上下文
redis_pool_context.reset(token)

async def startup_event() -> None:
global REDIS_POOL

REDIS_POOL = RDS() # 初始化客户端,里面通过asyncio.ensure_future逻辑延后连接

async def shutdown_event() -> None:
if REDIS_POOL:
await REDIS_POOL.close() # 关闭redis客户端

# demo/server.py
# 该文件存放starlette main逻辑
from starlette.applications import Starlette
from starlette.responses import JSONResponse

from demo.web_tool import RequestContextMiddleware
from demo.context import get_redis


APP = Starlette()
APP.add_middleware(RequestContextMiddleware)


@APP.route('/')
async def homepage(request):
# 伪代码,这里是执行redis命令
# 只要验证 id(get_redis())等于demo.web_tool里REDID_POOL的id一致,那证明contextvars可以为asyncio维护一套上下文状态
await get_redis().execute()
return JSONResponse({'hello': 'world'})

3.在asyncio框架中使用contextvars中的好处

首先看看未使用contextvars时,asyncio框架是如何传变量的,根据starlette的文档,在未使用contextvars时,传递redis客户端实例的办法是通过request.stat这个变量保存redis客户端的实例,改写代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# demo/web_tools.py
class RequestContextMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
request.stat.redis = REDIS_POOL
response = await call_next(request)
return response
# demo/server.py
@APP.route('/')
async def homepage(request):
# 伪代码,这里是执行redis命令
await request.stat.redis.execute()
return JSONResponse({'hello': 'world'})

可以发现,使用request.stat方法后代码量变少了,但是这是引用到了python动态特性的方法,所以除了在运行时,永远不知道这个方法调用到的变量的类型是什么,同时在编写代码时,IDE也无法智能的帮你检查(如输入request.stat.redis.时,IDE不会出现execute,或者出错时,IDE并不会提示).

2.在asyncio中使用contextvars中可以看到:

1
2
def get_redis() -> 'RDS':
return redis_pool_context.get()

使用contextvars可以与Type Hints结合,解决request.stat的缺点,虽然python是一种表达很强的动态语言,但是我觉得在web这种工程化比较严格的项目中,使用Type Hints来提升代码的可维护性和便利性是非常有必要的

4.如何优雅的使用contextvars

从上面的示例代码来看,每多一个变量,就需要自己去写一个context,一个变量的初始化,一个变量的get函数,同时在引用时使用函数会比较别扭.

自己在使用了contextvars一段时间后,觉得这样太麻烦了,每次都要做一堆重复的操作,且平时使用最多的就是把一个实例或者提炼出Headers的参数放入contextvars中,所以写了一个封装fastapi_tools.context(虽然是用于fastapi的,但也是可以用于starlette),他能屏蔽所有与contextvars的相关逻辑,其中由ContextModel负责contextvars的set和get操作,ContextMiddleware管理contextvars的周期,HeaderQuery和CustomQuery分别托管Headers相关的参数和用户自己的实例.调用者只需要在ContextModel中写入自己需要的变量,引用时调用ContextModel的属性即可

使用代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import uuid
import httpx
from fastapi import FastAPI
from fastapi_tools.context import ContextMiddleware
from fastapi_tools.context import ContextBaseModel
from fastapi_tools.context import HeaderQuery # 用于把header的变量存放于contextvars中
from fastapi_tools.context import CustomQuery # 用于把自己的实例(如上文所说的redis客户端)存放于contextvars中

app = FastAPI()
client = httpx.AsyncClient()


class ContextModel(ContextBaseModel):
# 通过该实例可以屏蔽大部分与contextvars相关的操作,如果要添加一个变量,则在该实例添加一个属性即可.
# 属性必须要使用Type Hints的写法,不然不会识别(强制使用Type Hints)
request_id: str = HeaderQuery(
'X-Request-Id',
default_func=lambda request: str(uuid.uuid4())
) # 从headers中获取request id,如果没有这个值,则用uuid代替
ip: str = HeaderQuery(
'X-Real-IP',
default_func=lambda request: request.client.host
) # 从headers中获取ip,如果没有就从request.client.host中获取,default_func需要我们传入一个函数,且该函数必须接受一个starlette.request的变量
user_agent: str = HeaderQuery('User-Agent')

# 自己传入自己实现的一个实例,交由CustomQuery托管
http_client: httpx.AsyncClient = CustomQuery(client)


# 添加一个可以维护context的中间件,该中间件会自动维护该请求期间context_model的状态,与见'2.在asyncio中使用contextvars'中'demo/web_pool.py'的RequestContextMiddleware的原理差不多
app.add_middleware(ContextMiddleware, context_model=ContextModel())


@app.get("/")
async def root():
# 通过访问ContextModel的属性即可获取到对应的上下文变量
# 检查contextvars的client实例是否与自己创建的一致
assert ContextModel.http_client == client
# 返回自己设置的headers参数
return {
"message": {
key: value
for key, value in ContextModel.to_dict().items()
if not key.startswith('custom') # 去掉自己设置的实例,这里不应该被返回到用户展示
}
}


if __name__ == '__main__':
import uvicorn
uvicorn.run(app)

5.contextvars的原理

在第一次使用时,我就很好奇contextvars是如何去维护程序的上下文的,好在contextvars的作者出了一个向下兼容的contextvars库,虽然他不支持asyncio,但我们还是可以通过代码了解到他的基本原理.

5.1 ContextMeta,ContextVarMeta和TokenMeta

ContextMeta,ContextVarMetaTokenMeta的功能都是防止用户来继承Context,ContextVarToken,原理都是通过元类来判断类名是否是自己编写类的名称,如果不是则抛错

5.2 Token

用于保存调用set时的数据,包括context本身和上一次被set的旧数据,并在调用set后返回token.
返回的token可以被用户在调用context后,调用context.reset(token)来清空保存的上下文,方便context能及时的被回收

5.3 全局唯一context

Python中由threading.local()负责每个线程的context, 协程属于线程的’子集’,所以contextvar直接基于threading.local()生成自己的全局context.

contextvar通过threading.local()生成自己的全局context–_state.并通过_set_context把上下文对象设置到_state.context和_get_context从_state.context获取上下文(实际上也充当了set的功能)

5.3.1 threading.local()

关于threading.local(),虽然不是本文重点,但由于contextvars是基于threading.local()进行封装的,所以还是要明白threading.local(),这里做一个简单的示例解释.

一个线程使用自己的局部变量比使用全局变量好,因为局部变量只有线程自己能看见,不会影响其他线程,而全局变量的修改必须加锁.但是局部变量也有问题,就是在函数调用的时候,传递起来很麻烦.
先看全局变量的例子:

1
2
3
4
5
6
7
pet_dict = {}

def get_pet(pet_name):
return pet_dict[pet_name]

def set_pet(pet_name):
return pet_dict[pet_name]

如果是多线程调用的话,那就需要加锁啦,如果用局部变量,实际上就是每个线程有一个自己的pet_dict,假设每个线程调用get_pet,set_pet时,都会把自己的pid传入进来, 那么可以改为这样子

1
2
3
4
5
6
7
pet_dict = {}

def get_pet(pet_name, pid):
return pet_dict[pid][pet_name]

def set_pet(pet_name, pid):
return pet_dict[pid][pet_name]

可以看到,没有对异常检查和初始化等处理,如果值比较复杂,靠我们手动维护太麻烦了.
这时候threading.local()就应运而生了,他负责帮我们处理这些维护的工作,我们只要对他进行一些调用即可,调用起来跟单线程调用一样简单方便.

1
2
3
4
5
6
7
thread_local=threading.local()

def get_pet(pet_name):
return thread_local[pet_name]

def set_pet(pet_name, pid):
return thread_local[pet_name]

5.4contextvar自己封装的Context

contextvars自己封装的Context比较简单,主要是self._data,这里使用到了一个叫immutables.Map()的不可变对象,并对immutables.Map()进行一些封装,所以context可以看成一个不可变的dict.

查看immutables.Map()的示例代码可以看到,每次对原对象的修改时,原对象并不会发生改变,并会返回一个已经发生改变的新对象.

1
2
3
4
5
6
7
8
9
10
11
12
map2 = map.set('a', 10)
print(map, map2)
# will print:
# <immutables.Map({'a': 1, 'b': 2})>
# <immutables.Map({'a': 10, 'b': 2})>

map3 = map2.delete('b')
print(map, map2, map3)
# will print:
# <immutables.Map({'a': 1, 'b': 2})>
# <immutables.Map({'a': 10, 'b': 2})>
# <immutables.Map({'a': 10})>

此外,context还有一个叫run的方法,执行run的时候,会使用一个新的上下文来调用传入的函数.结果并不会影响当前上下文.

1
2
3
4
5
6
7
8
9
10
11
12
13
def run(self, callable, *args, **kwargs):
# 已经存在旧的context,抛出异常,防止多线程循环调用
if self._prev_context is not None:
raise RuntimeError(
'cannot enter context: {} is already entered'.format(self))

self._prev_context = _get_context() # 保存当前的context
try:
_set_context(self) # 设置新的context
return callable(*args, **kwargs) # 执行函数
finally:
_set_context(self._prev_context) # 设置为旧的context
self._prev_context = None

5.5 ContextVar

我们一般在使用contextvars模块时,经常使用的就是ContextVar这个类了,这个类很简单,主要提供了set–设置值,get–获取值,reset–重置值三个方法.

  • set
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    def set(self, value):
    ctx = _get_context() # 获取当前上下文对象`Context`
    data = ctx._data
    try:
    old_value = data[self] # 获取Context旧对象
    except KeyError:
    old_value = Token.MISSING # 获取不到则填充一个object(全局唯一)

    updated_data = data.set(self, value) # 设置新的值
    ctx._data = updated_data
    return Token(ctx, self, old_value) # 返回带有旧值的token
  • get
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    def get(self, default=_NO_DEFAULT):
    ctx = _get_context() # 获取当前上下文对象`Context`
    try:
    return ctx[self] # 返回获取的值
    except KeyError:
    pass

    if default is not _NO_DEFAULT:
    return default # 返回调用get时设置的值

    if self._default is not _NO_DEFAULT:
    return self._default # 返回初始化context时设置的默认值

    raise LookupError # 都没有则会抛错
  • reset
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    def reset(self, token):
    if token._used:
    # 判断token是否已经被使用
    raise RuntimeError("Token has already been used once")

    if token._var is not self:
    # 判断token是否是当前contextvar返回的
    raise ValueError(
    "Token was created by a different ContextVar")

    if token._context is not _get_context():
    # 判断token的上下文是否跟contextvar上下文一致
    raise ValueError(
    "Token was created in a different Context")

    ctx = token._context
    if token._old_value is Token.MISSING:
    # 如果没有旧值则删除该值
    ctx._data = ctx._data.delete(token._var)
    else:
    # 有旧值则当前contextvar变为旧值
    ctx._data = ctx._data.set(token._var, token._old_value)

    token._used = True # 设置flag,标记token已经被使用了
    则此,contextvar的原理了解完了,接下来再看看他是如何在asyncio运行的.

6.contextvars asyncio

这里使用aiotask-context进行源码分析,主要分析在如何在asyncio中如何获取和设置context

6.1在asyncio中获取context

相比起contextvars复杂的概念,在asyncio中,我们可以很简单的获取到当前协程的task,并由task获取context,由于Pyhon3.7对asyncio的高级API 重新设计,所以可以看到需要对获取当前task进行封装

1
2
3
4
5
6
7
8
9
10
11
12
PY37 = sys.version_info >= (3, 7)

if PY37:
def asyncio_current_task(loop=None):
"""Return the current task or None."""
try:
return asyncio.current_task(loop)
except RuntimeError:
# simulate old behaviour
return None
else:
asyncio_current_task = asyncio.Task.current_task

之后我们调用asyncio_current_task().context即可获取到当前的上下文了…

6.2 对上下文的操作

同样的,我们这里也需要set, get, clear的操作,不过十分简单,只要获取context后对他进行判空操作,再进行类似于dict的操作即可.

  • set
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    def set(key, value):
    """
    Sets the given value inside Task.context[key]. If the key does not exist it creates it.
    :param key: identifier for accessing the context dict.
    :param value: value to store inside context[key].
    :raises
    """
    current_task = asyncio_current_task()
    if not current_task:
    raise ValueError(NO_LOOP_EXCEPTION_MSG.format(key))

    current_task.context[key] = value
  • get
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    def get(key, default=None):
    """
    Retrieves the value stored in key from the Task.context dict. If key does not exist,
    or there is no event loop running, default will be returned
    :param key: identifier for accessing the context dict.
    :param default: None by default, returned in case key is not found.
    :return: Value stored inside the dict[key].
    """
    current_task = asyncio_current_task()
    if not current_task:
    raise ValueError(NO_LOOP_EXCEPTION_MSG.format(key))

    return current_task.context.get(key, default)
  • clear
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    def clear():
    """
    Clear the Task.context.
    :raises ValueError: if no current task.
    """
    current_task = asyncio_current_task()
    if not current_task:
    raise ValueError("No event loop found")

    current_task.context.clear()

    6.2 copying_task_factory和chainmap_task_factory

    在Python的更高级版本中,以及支持设置context了,所以这两个方法可以不再使用了.他们最后都用到了task_factory的方法,简单说就是创建一个新的task,再合成context,最后把context设置到task
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    def task_factory(loop, coro, copy_context=False, context_factory=None):
    """
    By default returns a task factory that uses a simple dict as the task context,
    but allows context creation and inheritance to be customized via ``context_factory``.
    """
    # 生成context工厂函数
    context_factory = context_factory or partial(
    dict_context_factory, copy_context=copy_context)

    # 创建task
    task = asyncio.tasks.Task(coro, loop=loop)
    if task._source_traceback:
    del task._source_traceback[-1]

    # 获取task的context
    try:
    context = asyncio_current_task(loop=loop).context
    except AttributeError:
    context = None

    # 从context工厂中处理context并赋值在task
    task.context = context_factory(context)

    return task
    aiotask-context提供了两个对context处理的函数dict_context_factorychainmap_context_factory.在aiotask-context中,context是一个dict对象,dict_context_factory可以选择赋值或者设置新的context
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    def dict_context_factory(parent_context=None, copy_context=False):
    """A traditional ``dict`` context to keep things simple"""
    if parent_context is None:
    # initial context
    return {}
    else:
    # inherit context
    new_context = parent_context
    if copy_context:
    new_context = deepcopy(new_context)
    return new_context
    chainmap_context_factorydict_context_factory的区别就是在合并context而不是直接继承.同时借用ChainMap保证合并context后,还能同步context的改变
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    def chainmap_context_factory(parent_context=None):
    """
    A ``ChainMap`` context, to avoid copying any data
    and yet preserve strict one-way inheritance
    (just like with dict copying)
    """
    if parent_context is None:
    # initial context
    return ChainMap()
    else:
    # inherit context
    if not isinstance(parent_context, ChainMap):
    # if a dict context was previously used, then convert
    # (without modifying the original dict)
    parent_context = ChainMap(parent_context)
    return parent_context.new_child()

    7.总结

    contextvars本身其实很简单,但他可以让我们调用起来更加方便便捷,减少我们的传参次数,同时还可以结合TypeHint使项目更加工成化.除此之外,我们在编写web项目时,我们都会使用一个request_id来记录请求的链路日志,排查问题十分方便,借用contextvars也可以非常容易做到,可以查看aiotask-context的一个例子进行了解
查看评论