Python mock的时候偷换对象

815 查看

最近做发短信的service的时候,与短信相关的测试需要mock,于是碰到以下问题。
例如有三个模块a和b和test。

# a.py:

def send_sms():
    # 调用运营商的短信接口
    print 'send_sms'


# b.py:

from a import send_sms

def func():
    status_code = send_sms()
    return status_code

# test.py:

from b import func

def dummy_send_sms():
    print 'send_sms'

test_assert(func() == status_success)

a模块负责发短信,b是具体的业务,我们想测试b的业务,但是我们可不想在测试的时候发短信。于是我们想在测试的时候,把b模块func里面对send_sms的调用改成对一个mock函数的调用,如test模块的dummy_send_sms函数。

先说结论,这里要这么干:

# test.py:
import sys
from b import func

def dummy_send_sms():
    print 'send_sms'

sys.modules['b'].__dict__['send_sms'] = dummy_send_sms
# 如果b模块是import a,然后a.send_sms的话就要这样
# sys.modules['a'].__dict__['send_sms'] = dummy_send_sms

test_assert(func() == status_success)

或者使用Python的mock库的patch。
可以这么干,有两个原因:
首先,Python里面模块就是一个Python对象,所以我们可以随时通过篡改这个模块对象的成员来篡改模块里面符号与对象的对应关系。
然后,Python里面对名字的resolution是在运行的时候发生的。

具体来说
函数调用的字节码如下:

LOAD_NAME  0 (send_sms) 
# 0 表示符号表第0个元素,也就是字符串"send_sms",然后把"send_sms"对应的对象压到栈里面(注意Python是基于堆栈的虚拟机,这里的栈跟调用栈并不完全一样)
CALL_FUNCTION 0
# 0表示0个参数,所调用的函数就是栈上面的由"send_sms"查到的对象

拿到"send_sms"这个参数,CALL_FUNCTION这条字节码怎么去执行函数呢?

# /Python/ceval.c

TARGET(LOAD_NAME) {
    PyObject *name = GETITEM(names, oparg);
    PyObject *locals = f->f_locals;
    PyObject *v;
    if (locals == NULL) {
        PyErr_Format(PyExc_SystemError,
                     "no locals when loading %R", name);
        goto error;
    }
    if (PyDict_CheckExact(locals)) {
        v = PyDict_GetItem(locals, name);
        Py_XINCREF(v);
    }
    else {
        v = PyObject_GetItem(locals, name);
        if (v == NULL && _PyErr_OCCURRED()) {
            if (!PyErr_ExceptionMatches(PyExc_KeyError))
                goto error;
            PyErr_Clear();
        }
    }
    if (v == NULL) {
        v = PyDict_GetItem(f->f_globals, name);
        Py_XINCREF(v);
        if (v == NULL) {
            if (PyDict_CheckExact(f->f_builtins)) {
                v = PyDict_GetItem(f->f_builtins, name);
                if (v == NULL) {
                    format_exc_check_arg(
                                PyExc_NameError,
                                NAME_ERROR_MSG, name);
                    goto error;
                }
                Py_INCREF(v);
            }
            else {
                v = PyObject_GetItem(f->f_builtins, name);
                if (v == NULL) {
                    if (PyErr_ExceptionMatches(PyExc_KeyError))
                        format_exc_check_arg(
                                    PyExc_NameError,
                                    NAME_ERROR_MSG, name);
                    goto error;
                }
            }
        }
    }
    PUSH(v);
    DISPATCH();
}

这段代码首先拿出name(也就是"send_sms"这个字符串),然后就在ff_localsf_globalsf_builtins 里面找相应的对象。
这个f就是当前的栈帧PyFrameObject。

typedef struct _frame {
    PyObject_VAR_HEAD
    # ...
    PyObject *f_builtins;       /* builtin symbol table (PyDictObject) */
    PyObject *f_globals;        /* global symbol table (PyDictObject) */
    PyObject *f_locals;         /* local symbol table (any mapping) */
    # ...
} PyFrameObject;

所以由名字找出对象,是在LOAD_NAME这条指令运行的时候才计算的,这样就为我们的篡改留了机会。
所以,当我们运行了sys.modules['b'].__dict__['send_sms'] = dummy_send_sms之后,LOAD_NAME之后根据"send_sms"找到的对象就是我们我们篡改的dummy_send_sms

其实,由于PyFunctionObject拥有个func_globals的指针指向所在模块的符号表:

typedef struct {
    PyObject_HEAD
    PyObject *func_code;    /* A code object */
    PyObject *func_globals;    /* A dictionary (other mappings won't do) */
    PyObject *func_defaults;    /* NULL or a tuple */
    PyObject *func_closure;    /* NULL or a tuple of cell objects */
    PyObject *func_doc;        /* The __doc__ attribute, can be anything */
    PyObject *func_name;    /* The __name__ attribute, a string object */
    PyObject *func_dict;    /* The __dict__ attribute, a dict or NULL */
    PyObject *func_weakreflist;    /* List of weak references */
    PyObject *func_module;    /* The __module__ attribute, can be anything */

    /* Invariant:
     *     func_closure contains the bindings for func_code->co_freevars, so
     *     PyTuple_Size(func_closure) == PyCode_GetNumFree(func_code)
     *     (func_closure may be NULL if PyCode_GetNumFree(func_code) == 0).
     */
} PyFunctionObject;

所以还可以这么干

# test.py:
import sys
from b import func

def dummy_send_sms():
    print 'send_sms'

# 可以直接从函数的 func_globals 指针修改符号表
b_globals = getattr(func, '__globals__')
b_globals['send_sms'] = dummy_send_sms

test_assert(func() == status_success)