Flask 源码分析

参考文章: http://www.zchengjoey.com/posts/flask%E6%BA%90%E7%A0%81%E5%88%86%E6%9E%90/

WSGI

在说到wsgi时我们先看一下面向http的python程序需要关心的内容:

  1. 请求:
    • 请求方法(method)
    • 请求地址(url)
    • 请求内容(body)
    • 请求头(header)
    • 请求环境(environ)
  2. 响应:
    • 响应码(status_code)
    • 响应数据(data)
    • 响应头(header)

wsgi要做的就是关于程序端和服务端的标准规范, 将服务程序接收到的请求传递给python程序, 并将网络的数据流和python的结构体进行转换.
它规定了python程序必须是一个可调用对象(函数、方法, 或实现了 __call__ 方法的类的实例), 接受两个参数environ(WSGI的环境信息)和start_response(开始响应请求的函数), 并返回可迭代的结果. 直接上代码来实现一个最简单的web程序返回hello world:

from werkzeug.serving import run_simple

class WebClass:
    """Minimal WSGI application that answers every request with "Hello World!"."""

    def __init__(self):
        """Nothing to configure."""

    def __call__(self, environ, start_response):
        """WSGI entry point: announce a 200 plain-text response, then yield the body.

        This is a generator, so ``start_response`` is only invoked once the
        caller starts iterating over the returned object.
        """
        start_response('200 OK', [('Content-type', 'text/plain')])
        yield "Hello World!\n".encode()

# Script entry point: serve the WSGI callable on 127.0.0.1:5000 using
# werkzeug's run_simple.
if __name__ == "__main__":
    app = WebClass()
    run_simple("127.0.0.1", 5000, app)

WebClass正是实现了__call__方法的可调用对象, 接受environ和start_response, 并在返回之前调用start_response, start_response接受两个必须的参数, status_code(http状态码)和response_header(响应头), yield hello world正是要求的可迭代结果, 现在这个类只是实现了最简单的功能, 路由注册, 模板渲染等都没有实现. 这里用了werkzeug提供的run_simple, 其实我们创建flask应用, 跑起来的时候调用的也是这个函数, 后面将会讲到。

项目运行

app.run(), 源码如下:

    def run(self, host=None, port=None, debug=None, load_dotenv=True, **options):
        """Run the application with werkzeug's ``run_simple``.

        Resolves host, port and debug settings, fills in default server
        options, shows the startup banner and then blocks in ``run_simple``.

        :param host: hostname to listen on; falls back to the host part of
            ``SERVER_NAME`` and then to ``127.0.0.1``.
        :param port: port to listen on; falls back to the port part of
            ``SERVER_NAME`` and then to ``5000``. ``0`` is allowed.
        :param debug: if not ``None``, overrides ``self.debug``.
        :param load_dotenv: passed to ``get_load_dotenv`` to decide whether
            ``cli.load_dotenv()`` is called.
        :param options: extra keyword arguments forwarded to ``run_simple``.
        """
        # Started through the CLI: this call is ignored, so explain and return.
        if os.environ.get("FLASK_RUN_FROM_CLI") == "true":
            from .debughelpers import explain_ignored_app_run

            explain_ignored_app_run()
            return

        if get_load_dotenv(load_dotenv):
            cli.load_dotenv()

            # if set, let env vars override previous values
            if "FLASK_ENV" in os.environ:
                self.env = get_env()
                self.debug = get_debug_flag()
            elif "FLASK_DEBUG" in os.environ:
                self.debug = get_debug_flag()

        # debug passed to method overrides all other sources
        if debug is not None:
            self.debug = bool(debug)

        _host = "127.0.0.1"
        _port = 5000
        server_name = self.config.get("SERVER_NAME")
        sn_host, sn_port = None, None

        if server_name:
            sn_host, _, sn_port = server_name.partition(":")
            # str.partition returns "" (not None) for the port when SERVER_NAME
            # has no ":port" suffix. Normalise it so the default port is used
            # below instead of failing on int("").
            sn_port = sn_port or None

        host = host or sn_host or _host
        # pick the first value that's not None (0 is allowed)
        port = int(next((p for p in (port, sn_port) if p is not None), _port))

        options.setdefault("use_reloader", self.debug)
        options.setdefault("use_debugger", self.debug)
        # The server is threaded by default unless the caller says otherwise
        options.setdefault("threaded", True)
        # Print the startup banner
        cli.show_server_banner(self.env, self.debug, self.name, False)

        from werkzeug.serving import run_simple

        try:
            # Hand over to werkzeug's run_simple; this blocks while serving
            run_simple(host, port, self, **options)
        finally:
            # Reset the first-request flag once the server returns
            self._got_first_request = False

上面函数的功能很简单,处理了以下参数,其中最主要的还是调用 werkzeug 的 run_simple 函数, run_simple 源码如下:

def run_simple(
    hostname,
    port,
    application,
    use_reloader=False,
    use_debugger=False,
    use_evalex=True,
    extra_files=None,
    reloader_interval=1,
    reloader_type="auto",
    threaded=False,
    processes=1,
    request_handler=None,
    static_files=None,
    passthrough_errors=False,
    ssl_context=None,
):
    """Start a server for ``application`` on ``hostname:port``.

    Optionally wraps the application in ``DebuggedApplication``
    (``use_debugger``) and ``SharedDataMiddleware`` (``static_files``), then
    either serves directly or runs the serving function under the reloader
    (``use_reloader``).

    :raises TypeError: if ``port`` is not an ``int``.
    :raises ValueError: if the reloader is enabled with ``port == 0`` and
        sockets cannot be opened by file descriptor.
    """

    if not isinstance(port, int):
        raise TypeError("port must be an integer")
    if use_debugger:
        from .debug import DebuggedApplication

        application = DebuggedApplication(application, use_evalex)
    if static_files:
        from .middleware.shared_data import SharedDataMiddleware

        application = SharedDataMiddleware(application, static_files)
    # Startup message: logs the address (and port) the server is running on
    def log_startup(sock):
        display_hostname = hostname if hostname not in ("", "*") else "localhost"
        quit_msg = "(Press CTRL+C to quit)"
        if sock.family == af_unix:
            _log("info", " * Running on %s %s", display_hostname, quit_msg)
        else:
            # A ":" in the hostname gets the host wrapped in brackets
            if ":" in display_hostname:
                display_hostname = "[%s]" % display_hostname
            port = sock.getsockname()[1]
            _log(
                "info",
                " * Running on %s://%s:%d/ %s",
                "http" if ssl_context is None else "https",
                display_hostname,
                port,
                quit_msg,
            )

    def inner():
        # Reuse a socket passed through WERKZEUG_SERVER_FD when one is set
        try:
            fd = int(os.environ["WERKZEUG_SERVER_FD"])
        except (LookupError, ValueError):
            fd = None
        # Build the server object
        srv = make_server(
            hostname,
            port,
            application,
            threaded,
            processes,
            request_handler,
            passthrough_errors,
            ssl_context,
            fd=fd,
        )
        # With an inherited fd the startup line was already logged below
        if fd is None:
            log_startup(srv.socket)
        # Serve requests until the server is stopped
        srv.serve_forever()

    if use_reloader:
        if not is_running_from_reloader():
            if port == 0 and not can_open_by_fd:
                raise ValueError(
                    "Cannot bind to a random port with enabled "
                    "reloader if the Python interpreter does "
                    "not support socket opening by fd."
                )
            # Bind the socket up front, before handing over to the reloader
            address_family = select_address_family(hostname, port)
            server_address = get_sockaddr(hostname, port, address_family)
            s = socket.socket(address_family, socket.SOCK_STREAM)
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(server_address)
            if hasattr(s, "set_inheritable"):
                s.set_inheritable(True)
            if can_open_by_fd:
                # Publish the fd through the environment; inner() reads it back
                os.environ["WERKZEUG_SERVER_FD"] = str(s.fileno())
                s.listen(LISTEN_QUEUE)
                log_startup(s)
            else:
                s.close()
                if address_family == af_unix:
                    _log("info", "Unlinking %s" % server_address)
                    os.unlink(server_address)
        from ._reloader import run_with_reloader

        run_with_reloader(inner, extra_files, reloader_interval, reloader_type)
    else:
        # No reloader: serve directly
        inner()

run_simple 中最后调用了其中的 inner 函数,inner 中的 make_server 函数返回一个 WSGIServer 对象,然后再调用 WSGIServer 对象的 serve_forever 方法,创建 wsgi 的服务, 然后运行,这样项目就运行起来了。

请求处理

前面说了wsgi规定应用程序必须实现__call__方法, 找到Flask对应的内容:

    def __call__(self, environ, start_response):
        """WSGI entry point; delegates straight to ``wsgi_app``."""
        return self.wsgi_app(environ, start_response)

wsgi_app 源码如下:

    def wsgi_app(self, environ, start_response):
        """Handle one WSGI request and return the response iterable.

        Pushes a request context, dispatches the request, turns ``Exception``
        instances into a response via ``handle_exception`` and always hands
        the context to ``auto_pop`` afterwards.
        """
        # Build the RequestContext for this environ
        ctx = self.request_context(environ)
        error = None
        try:
            try:
                ctx.push()
                # full_dispatch_request runs the request hooks and the view
                response = self.full_dispatch_request()
            except Exception as e:
                error = e
                response = self.handle_exception(e)
            except:
                # Anything that is not an Exception subclass: remember it for
                # the context pop below, then let it propagate.
                error = sys.exc_info()[1]
                raise
            return response(environ, start_response)
        finally:
            # Errors the app chooses to ignore are not passed on to auto_pop
            if self.should_ignore_error(error):
                error = None
            ctx.auto_pop(error)

full_dispatch_request 源码如下:

def full_dispatch_request(self):
    """Dispatch the request together with its pre- and post-processing.

    Runs ``preprocess_request`` and, unless that already produced a value,
    ``dispatch_request``; exceptions are routed through
    ``handle_user_exception``. The result is converted by
    ``finalize_request``.
    """
    # before_first_request handling is triggered here, ahead of
    # preprocess_request (presumably a no-op after the first request — confirm
    # in the helper's definition).
    self.try_trigger_before_first_request_functions()
    try:
        request_started.send(self)
        # Run the before-request processing; a non-None result replaces the
        # view's return value, so the view is skipped.
        rv = self.preprocess_request()
        if rv is None:
            rv = self.dispatch_request()
    except Exception as e:
        rv = self.handle_user_exception(e)
    # Turn the return value into a response and run the response
    # post-processing (finalize_request -> process_response).
    return self.finalize_request(rv)

这段最核心的就是 dispatch_request , dispatch_request 就是我们注册的路由函数的执行结果, 在 dispatch_request 之前我们看到 preprocess_request , 它的作用是将钩子函数处理一下

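  1. 第一次请求处理之前的钩子函数, 通过 before_first_request 定义的 (从上面 full_dispatch_request 的源码可以看到, 它实际由开头的 try_trigger_before_first_request_functions 触发, 先于 preprocess_request 执行)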
  2. 每个请求处理之前的钩子函数, 通过 before_request 定义的

而在 dispatch_request 之后还有 finalize_request 函数, 它的作用同样是将请求结果通过钩子函数处理一下:

  1. 每个请求正常处理之后的 hook 函数,通过 after_request 定义
  2. 不管请求是否异常都要执行的 teardown_request hook 函数 (从后文 RequestContext.pop 的源码可以看到, 它实际是在请求上下文出栈时由 do_teardown_request 触发的, 而不是在 finalize_request 中)

所以上面最重要的就是 dispatch_request 函数, 找到我们注册的路由函数, 并返回

preprocess_request 源码如下:

    def preprocess_request(self):
        """Run the URL value preprocessors and the before-request functions.

        For both kinds, the functions registered under ``None`` run first,
        followed by those registered for the current request's blueprint.
        Returns the first non-``None`` value produced by a before-request
        function, or ``None`` if there is none.
        """
        # Blueprint of the request on top of the request context stack
        blueprint = _request_ctx_stack.top.request.blueprint

        preprocessors = self.url_value_preprocessors.get(None, ())
        if blueprint is not None and blueprint in self.url_value_preprocessors:
            preprocessors = chain(
                preprocessors, self.url_value_preprocessors[blueprint]
            )
        for preprocessor in preprocessors:
            preprocessor(request.endpoint, request.view_args)

        hooks = self.before_request_funcs.get(None, ())
        if blueprint is not None and blueprint in self.before_request_funcs:
            hooks = chain(hooks, self.before_request_funcs[blueprint])
        for hook in hooks:
            result = hook()
            if result is not None:
                # A hook that returns a value ends the preprocessing early
                return result
        return None

finalize_request 源码如下:

def finalize_request(self, rv, from_error_handler=False):
    """Convert ``rv`` into a response object and post-process it.

    The response is built with ``make_response``, passed through
    ``process_response`` and announced via ``request_finished``. A failure
    during post-processing is re-raised, unless ``from_error_handler`` is
    true, in which case it is only logged and the response is still returned.
    """
    response = self.make_response(rv)
    try:
        # Post-processing may hand back a different response object
        response = self.process_response(response)
        request_finished.send(self, response=response)
    except Exception:
        if from_error_handler:
            self.logger.exception(
                "Request finalizing failed with an error while handling an error"
            )
        else:
            raise
    return response

process_response 源码如下:

   def process_response(self, response):
        """Pass ``response`` through the after-request handlers and save the session.

        Handlers run in this order: those attached to the current request
        context, then the current blueprint's in reverse registration order,
        then those registered under ``None`` in reverse registration order.
        Each handler receives the response and returns the one passed on.
        """
        ctx = _request_ctx_stack.top
        blueprint = ctx.request.blueprint

        handlers = ctx._after_request_functions
        registered = self.after_request_funcs
        if blueprint is not None and blueprint in registered:
            handlers = chain(handlers, reversed(registered[blueprint]))
        if None in registered:
            handlers = chain(handlers, reversed(registered[None]))

        for after_request in handlers:
            response = after_request(response)

        # Save the session unless it is the null session
        session_interface = self.session_interface
        if not session_interface.is_null_session(ctx.session):
            session_interface.save_session(self, ctx.session, response)
        return response

响应

make_response 源码如下:

def make_response(self, rv):
    """Convert a view's return value into an instance of ``response_class``.

    Accepts a response object, a string/bytes body, a dict, a WSGI callable,
    or a tuple ``(body, status, headers)``, ``(body, status)`` or
    ``(body, headers)``.

    :raises TypeError: if ``rv`` is ``None``, a tuple of another length, or
        of a type that cannot be converted.
    """
    status = headers = None
    # The view returned a tuple: unpack it
    if isinstance(rv, tuple):
        len_rv = len(rv)

        # Three items: body, status and headers
        if len_rv == 3:
            rv, status, headers = rv
        # Two items: decide by the type of the second one
        elif len_rv == 2:
            # A Headers/dict/tuple/list second item means (body, headers)
            if isinstance(rv[1], (Headers, dict, tuple, list)):
                rv, headers = rv
            else:
                # Anything else means (body, status)
                rv, status = rv
        # Any other tuple length is rejected
        else:
            raise TypeError(
                "The view function did not return a valid response tuple."
                " The tuple must have the form (body, status, headers),"
                " (body, status), or (body, headers)."
            )

    # A None body is rejected
    if rv is None:
        raise TypeError(
            "The view function did not return a valid response. The"
            " function either returned None or ended without a return"
            " statement."
        )

    # Make sure rv ends up as an instance of response_class
    if not isinstance(rv, self.response_class):
        if isinstance(rv, (text_type, bytes, bytearray)):
            # status/headers are consumed by the constructor, so clear them
            # to keep them from being applied a second time below
            rv = self.response_class(rv, status=status, headers=headers)
            status = headers = None
        elif isinstance(rv, dict):
            rv = jsonify(rv)
        elif isinstance(rv, BaseResponse) or callable(rv):
            try:
                rv = self.response_class.force_type(rv, request.environ)
            except TypeError as e:
                new_error = TypeError(
                    "{e}\nThe view function did not return a valid"
                    " response. The return type must be a string, dict, tuple,"
                    " Response instance, or WSGI callable, but it was a"
                    " {rv.__class__.__name__}.".format(e=e, rv=rv)
                )
                reraise(TypeError, new_error, sys.exc_info()[2])
        else:
            raise TypeError(
                "The view function did not return a valid"
                " response. The return type must be a string, dict, tuple,"
                " Response instance, or WSGI callable, but it was a"
                " {rv.__class__.__name__}.".format(rv=rv)
            )

    # Apply a status taken from the tuple: text goes to .status, anything
    # else to .status_code
    if status is not None:
        if isinstance(status, (text_type, bytes, bytearray)):
            rv.status = status
        else:
            rv.status_code = status
    # Add the headers taken from the tuple to the response's own
    if headers:
        rv.headers.extend(headers)
    return rv

路由匹配

在 flask 中, 构建路由规则有两种方法, 这两种方法其实是一样的,都是调用 add_url_rule 来实现

  1. 通过`@app.route()`的装饰器, 上面例子用的就是这种方法
  2. 通过app.add_url_rule, 这个方法的签名为 add_url_rule(self, rule, endpoint=None, view_func=None, **options)
def route(self, rule, **options):
    """Return a decorator that registers a view function for ``rule``.

    An ``endpoint`` key in ``options`` is removed and passed on as the
    endpoint; the remaining options are forwarded to ``add_url_rule``. The
    decorated function is returned unchanged.
    """

    def decorator(view):
        chosen_endpoint = options.pop("endpoint", None)
        self.add_url_rule(rule, chosen_endpoint, view, **options)
        return view

    return decorator

add_url_rule 源码如下:

@setupmethod
def add_url_rule(
    self,
    rule,
    endpoint=None,
    view_func=None,
    provide_automatic_options=None,
    **options
):
    """Register a URL rule and, if given, bind its endpoint to ``view_func``.

    :param rule: the URL rule string.
    :param endpoint: endpoint name; derived from ``view_func`` via
        ``_endpoint_from_view_func`` when ``None``.
    :param view_func: function to call for this endpoint (may be ``None``).
    :param provide_automatic_options: whether OPTIONS is handled
        automatically; when ``None`` it is taken from ``view_func`` or
        derived from the allowed methods.
    :param options: forwarded to ``url_rule_class``; ``methods`` is consumed
        here.
    :raises TypeError: if ``methods`` is a single string.
    :raises AssertionError: if the endpoint is already bound to a different
        view function.
    """
    # Without an explicit endpoint, derive it from the view function
    if endpoint is None:
        endpoint = _endpoint_from_view_func(view_func)
    options["endpoint"] = endpoint

    # Allowed HTTP methods: explicit option, then the view function's
    # ``methods`` attribute, then GET only
    methods = options.pop("methods", None)
    if methods is None:
        methods = getattr(view_func, "methods", None) or ("GET",)
    if isinstance(methods, string_types):
        raise TypeError(
            "Allowed methods have to be iterables of strings, "
            'for example: @app.route(..., methods=["POST"])'
        )
    methods = set(item.upper() for item in methods)

    # Methods that are always added on top of the allowed ones
    required_methods = set(getattr(view_func, "required_methods", ()))
    if provide_automatic_options is None:
        provide_automatic_options = getattr(
            view_func, "provide_automatic_options", None
        )
    if provide_automatic_options is None:
        # Handle OPTIONS automatically only if the view does not list it
        if "OPTIONS" not in methods:
            provide_automatic_options = True
            required_methods.add("OPTIONS")
        else:
            provide_automatic_options = False

    # Final method set
    methods |= required_methods

    # Build the rule object
    rule = self.url_rule_class(rule, methods=methods, **options)
    rule.provide_automatic_options = provide_automatic_options

    # Add the rule to url_map, then record endpoint -> view function in
    # view_functions
    self.url_map.add(rule)
    if view_func is not None:
        old_func = self.view_functions.get(endpoint)
        # An endpoint may not be rebound to a different function
        if old_func is not None and old_func != view_func:
            raise AssertionError(
                "View function mapping is overwriting an "
                "existing endpoint function: %s" % endpoint
            )
        self.view_functions[endpoint] = view_func

可以发现这个函数主要做的就是更新 app 的 url_map 和 view_functions 这两个变量. 查找定义, 发现 url_map 是 werkzeug.routing 的 Map 类对象, rule 是 werkzeug.routing 的 Rule 类对象, view_functions 就是一个字典, 从上我们也可以知道每个视图函数的 endpoint 必须是不同的. 也可以发现, flask 的核心路由逻辑其实是在 werkzeug 中实现的。

dispatch_request 源码如下,在这里面将路由和要处理的函数结合起来,实现执行视图函数。

def dispatch_request(self):
    """Call the view function matched for the current request.

    Raises the stored routing exception if URL matching failed, answers
    automatic OPTIONS requests with the default options response, and
    otherwise returns whatever the view function returns.
    """
    # Request on top of the request context stack
    req = _request_ctx_stack.top.request
    if req.routing_exception is not None:
        self.raise_routing_exception(req)

    rule = req.url_rule
    wants_automatic_options = (
        getattr(rule, "provide_automatic_options", False)
        and req.method == "OPTIONS"
    )
    if wants_automatic_options:
        return self.make_default_options_response()

    # Look the view up by endpoint and call it with the URL arguments
    view = self.view_functions[rule.endpoint]
    return view(**req.view_args)

上下文

之前在上面我们已经讲到dispatch_request函数在找到view_function后, 只是将最基本的参数传给了view_function, 可是有时这对视图函数来说是远远不够的, 它有时还需要头部(header), body里的数据, 才能正确运行, 可能最简单的方法就是将所有的这些信息封装成一个对象, 作为参数传给视图函数, 可是这样一来所有的视图函数都需要添加对应的参数, 即使并没有用到它.

flask 的做法是把这些信息作为上下文, 类似全局变量的东西, 在需要的时候, 用 from flask import request 获取, 比如经常用的request.json, request.args, 这里有一个很重要的点就是它们必须是动态的, 在多线程或多协程的情况下, 每个线程或协程获取的必须是自己独特的对象, 不能导入后结果获取的是其他请求的内容, 那就乱套了.

那么flask是如何实现不同的线程协程准确获得自己的上下文的呢, 我们先来看一下这两个上下文的定义:

application context 演化出两个变量: current_app 和 g

request context 演化出两个变量: request 和 session

它们的实现正是依靠 LocalStack 和 LocalProxy 这两个类, 正是这两个东西才让我们在并发程序中每个视图函数都会看到属于自己的上下文而不会混乱, 而这两个类能在多线程或多协程情况下实现隔离效果是靠了另一个基础类 Local, 它实现了类似 threading.local 的效果

def _lookup_req_object(name):
    """Return attribute ``name`` of the request context on top of the stack.

    Raises ``RuntimeError`` when the request context stack is empty.
    """
    ctx = _request_ctx_stack.top
    if ctx is not None:
        return getattr(ctx, name)
    raise RuntimeError(_request_ctx_err_msg)

def _lookup_app_object(name):
    """Return attribute ``name`` of the app context on top of the stack.

    Raises ``RuntimeError`` when the app context stack is empty.
    """
    ctx = _app_ctx_stack.top
    if ctx is not None:
        return getattr(ctx, name)
    raise RuntimeError(_app_ctx_err_msg)

def _find_app():
    """Return the ``app`` of the app context on top of the stack.

    Raises ``RuntimeError`` when the app context stack is empty.
    """
    ctx = _app_ctx_stack.top
    if ctx is not None:
        return ctx.app
    raise RuntimeError(_app_ctx_err_msg)

# The two context stacks, plus proxies that resolve against whatever is on
# top of them at access time.
_request_ctx_stack = LocalStack()
_app_ctx_stack = LocalStack()
# current_app -> .app of the top app context
current_app = LocalProxy(_find_app)
# request / session -> attributes of the top request context
request = LocalProxy(partial(_lookup_req_object, "request"))
session = LocalProxy(partial(_lookup_req_object, "session"))
# g -> .g of the top app context
g = LocalProxy(partial(_lookup_app_object, "g"))

Local 代码如下:

__storage__ 是用于存储内容的地方,格式为 {"线程ID": {"key": "value"}},因为每个请求的线程 ID 不一样,所以 Local 实现了隔离的效果。__ident_func__ 绑定的是 get_ident 函数,作用是获取当前的线程 ID。

class Local(object):
    """Attribute namespace whose values are kept per ident from ``__ident_func__``.

    ``__storage__`` maps ``ident -> {attribute name: value}``. Every
    attribute read, write and delete works on the dict belonging to the
    ident returned by ``__ident_func__`` (``get_ident`` by default).
    """

    # The only real instance attributes; all other attribute access goes
    # through __getattr__ / __setattr__ / __delattr__ below.
    __slots__ = ("__storage__", "__ident_func__")

    def __init__(self):
        # object.__setattr__ bypasses the overridden __setattr__, which would
        # otherwise put these two values into the per-ident dict.
        set_slot = object.__setattr__
        set_slot(self, "__storage__", {})
        set_slot(self, "__ident_func__", get_ident)

    def __iter__(self):
        """Iterate over ``(ident, attribute dict)`` pairs."""
        return iter(self.__storage__.items())

    def __call__(self, proxy):
        """Create a proxy for a name."""
        return LocalProxy(self, proxy)

    def __release_local__(self):
        """Drop everything stored for the current ident."""
        self.__storage__.pop(self.__ident_func__(), None)

    # The three methods below implement attribute get, set and delete. Each
    # one keys into the storage with the current ident, so callers just see
    # plain attributes while the values are kept apart per ident.
    def __getattr__(self, name):
        ident = self.__ident_func__()
        try:
            return self.__storage__[ident][name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        # The first write for an ident creates its attribute dict
        self.__storage__.setdefault(self.__ident_func__(), {})[name] = value

    def __delattr__(self, name):
        ident = self.__ident_func__()
        try:
            del self.__storage__[ident][name]
        except KeyError:
            raise AttributeError(name)

Local 是用来提供多线程或多协程的隔离属性访问的, 那么 LocalStack 就提供了隔离的栈访问, 它主要提供了 push, pop, top 这几个栈方法. 在 LocalStack 的 push 方法中, 其实是对属性 _local 也就是 Local 的操作: 先创建一个列表, 相当于 self._local.__storage__[ident(当前线程或协程id)]['stack'] = [], 然后再用 append 将 request 请求信息添加进去

class LocalStack(object):
    """Stack kept in the ``stack`` attribute of a private ``Local``."""

    def __init__(self):
        self._local = Local()

    def __release_local__(self):
        self._local.__release_local__()

    @property
    def __ident_func__(self):
        """The ident function of the underlying ``Local`` (readable and settable)."""
        return self._local.__ident_func__

    @__ident_func__.setter
    def __ident_func__(self, value):
        # Bypass Local.__setattr__, which would store the value per ident
        object.__setattr__(self._local, '__ident_func__', value)

    def __call__(self):
        """Return a proxy for the top item; using it on an empty stack raises RuntimeError."""
        def _lookup():
            current = self.top
            if current is None:
                raise RuntimeError('object unbound')
            return current
        return LocalProxy(_lookup)

    # push / pop / top implement the stack operations
    def push(self, obj):
        """Append ``obj``, creating the list on first use, and return the list."""
        stack = getattr(self._local, 'stack', None)
        if stack is None:
            stack = []
            self._local.stack = stack
        stack.append(obj)
        return stack

    def pop(self):
        """Remove and return the top item, or ``None`` if there is no stack.

        When the last item is taken the local is released instead of popping
        from the list.
        """
        stack = getattr(self._local, 'stack', None)
        if stack is None:
            return None
        if len(stack) != 1:
            return stack.pop()
        release_local(self._local)
        return stack[-1]

    @property
    def top(self):
        """The top item, or ``None`` if the stack is missing or empty."""
        try:
            return self._local.stack[-1]
        except (AttributeError, IndexError):
            return None

上述已经将 Local 和 LocalStack 讲的差不多了。request ,g 的实现,用的是 LocalProxy,LocalProxy 还重写了所有的魔术方法,具体实现都是代理对象,LocalProxy 简要代码如下:

@implements_bool
class LocalProxy(object):
    """Proxy that resolves its target object on every access.

    ``local`` is either a callable returning the target, or an object with a
    ``__release_local__`` attribute from which the target is read as the
    attribute called ``name``.
    """
    # '__local' is name-mangled inside the class body, which is why __init__
    # stores it under '_LocalProxy__local'.
    __slots__ = ('__local', '__dict__', '__name__', '__wrapped__')

    def __init__(self, local, name=None):
        # object.__setattr__ is used directly for all three assignments
        object.__setattr__(self, '_LocalProxy__local', local)
        object.__setattr__(self, '__name__', name)
        if callable(local) and not hasattr(local, '__release_local__'):
            # "local" is a lookup callback; expose it as __wrapped__
            object.__setattr__(self, '__wrapped__', local)
    # Resolve the object this proxy currently stands for
    def _get_current_object(self):
        # No __release_local__ attribute: treat local as a callable and ask it
        if not hasattr(self.__local, '__release_local__'):
            return self.__local()
        try:
            return getattr(self.__local, self.__name__)
        except AttributeError:
            raise RuntimeError('no object bound to %s' % self.__name__)

_request_ctx_stack 代表的就是请求上下文,ctx 其实是 RequestContext,请求来的时候会调用 RequestContext 的 push 方法

def push(self):
        """Push this request context, and an app context when one is needed."""
        # If no app context is on top of the stack, or the one on top belongs
        # to a different app, create and push one for this app. Record it (or
        # None) so that pop() knows whether it has an app context to pop.
        app_ctx = _app_ctx_stack.top
        if app_ctx is None or app_ctx.app != self.app:
            app_ctx = self.app.app_context()
            app_ctx.push()
            self._implicit_app_ctx_stack.append(app_ctx)
        else:
            self._implicit_app_ctx_stack.append(None)

        # Only interpreters that provide sys.exc_clear get this call
        if hasattr(sys, "exc_clear"):
            sys.exc_clear()
        # Push this RequestContext onto _request_ctx_stack
        _request_ctx_stack.push(self)
        # Open the session if not done yet; fall back to a null session
        if self.session is None:
            session_interface = self.app.session_interface
            self.session = session_interface.open_session(self.app, self.request)

            if self.session is None:
                self.session = session_interface.make_null_session(self.app)

        # Match the request only when a URL adapter is available
        if self.url_adapter is not None:
            self.match_request()

    def pop(self, exc=_sentinel):
        """Pop this request context, and the app context recorded by push().

        :param exc: passed to ``do_teardown_request`` and to the app
            context's ``pop``.
        """
        # Entry recorded by the matching push(): an app context or None
        app_ctx = self._implicit_app_ctx_stack.pop()

        try:
            clear_request = False
            # Teardown runs only once no recorded entries remain
            if not self._implicit_app_ctx_stack:
                self.app.do_teardown_request(exc)

                # Close the request object if it has a close() method
                request_close = getattr(self.request, 'close', None)
                if request_close is not None:
                    request_close()
                clear_request = True
        finally:
            # The stacks are popped even if teardown raised
            rv = _request_ctx_stack.pop()
            if clear_request:
                rv.request.environ['werkzeug.request'] = None
            if app_ctx is not None:
                app_ctx.pop(exc)

    def auto_pop(self, exc):
        """Pop the context, or mark it as preserved instead.

        The context is preserved when the environ has
        ``flask._preserve_context`` set, or when ``exc`` is not ``None`` and
        the app has ``preserve_context_on_exception`` set. Otherwise it is
        popped with ``exc``.
        """
        keep_context = self.request.environ.get('flask._preserve_context') or (
            exc is not None and self.app.preserve_context_on_exception
        )
        if not keep_context:
            self.pop(exc)
            return
        self.preserved = True
        self._preserved_exc = exc

push 就是将该请求的 application context (如果 _app_ctx_stack 栈顶不是当前请求所在 app,需要重新创建 app context )和 request context 都保存到相关的栈上, pop 则相反, 做一些出栈清理操作。

现在上下文就比较清楚了,就是每次有请求过来,flask 会创建当前线程或协程需要处理的两个上下文,并压入隔离的栈。application context 针对的是 flask 实例的, 因为 app 实例只有一个, 所以多个请求其实是共用一个 application context, 而 request context 是每次请求过来都要创建的, 在请求结束时又出栈, 所以两个的生命周期是不同的, 也就是 application context 的周期就是实例的生命周期, 而 request context 的生命周期取决于请求存在的时间。

flask: app.run() 时, 并没有看到哪里启动了多线程, 理论上在单线程的情况下, 只有一个请求处理完成之后才能处理下一个请求, 那么上面为什么能同时处理多个请求, 哪里创建了多线程呢?
在flask的启动函数(run)中, 有这么一个源码:

# Sets threaded=True only when the caller did not pass a value for "threaded"
options.setdefault("threaded", True)