Tensorflow 源码中python API和底层c/c++之间的关联主要是靠 SWIG 来完成的,这是一个用于把C/C++集成到其他语言中的编译器。本文主要介绍swig以及tensorflow如何利用swig关联前后端,另外主要介绍了一下Session的生命周期。
http://www.swig.org/
SWIG实例之python调用C
首先准备C语言函数
/* File : example.c */
#include < time.h>
double My_variable = 3.0;
int fact(int n) {
if (n < = 1) return 1;
else return n*fact(n-1);
}
int my_mod(int x, int y) {
return (x%y);
}
char *get_time()
{
time_t ltime;
time(& ltime);
return ctime(& ltime);
}
然后准备swig interface接口文件
/* example.i */
%module example
%{
/* Put header files here or function declarations like below */
extern double My_variable;
extern int fact(int n);
extern int my_mod(int x, int y);
extern char *get_time();
%}
extern double My_variable;
extern int fact(int n);
extern int my_mod(int x, int y);
extern char *get_time();
编译
需要安装swig,Ubuntu环境执行 sudo apt-get install swig ,然后执行如下指令,会生产example.py 和 example_wrap.c 两个文件
$swig -python example.i
在example_wrap.c 中可以找到一个静态的函数符号表如下,这里将python调用名称如get_time与包装的函数_wrap_get_time进行连接起来,匹配对应的C函数实现,最终在example.c中找到具体实现
static PyMethodDef SwigMethods[] = {
{ (char *)"SWIG_PyInstanceMethod_New", (PyCFunction)SWIG_PyInstanceMethod_New, METH_O, NULL},
{ (char *)"fact", _wrap_fact, METH_VARARGS, NULL},
{ (char *)"my_mod", _wrap_my_mod, METH_VARARGS, NULL},
{ (char *)"get_time", _wrap_get_time, METH_VARARGS, NULL},
{ NULL, NULL, 0, NULL }
};
然后利用gcc编译C语言文件,注意这里-I需要输入python的include目录,若报错 _fatal error: Python.h: No such file or director_ 说明目录没有设置正确,可以使用locate Python.h查找include目录,tensorflow使用的是python 2.7所以这里执行如下
$gcc -c example.c example_wrap.c -I/usr/include/python2.7
执行完后生成对应的.o目标文件,然后链接成共享库
$ld -shared example.o example_wrap.o -o _example.so
到这里时执行会报错,_ld: example_wrap.o: relocation R_X86_64_32S against .rodata can not be used when making a shared object; recompile with -fPIC_
这里在上一个环节gcc编译的时候加上-fPIC即可,如 gcc -fPIC -c ...,然后再执行对应的链接ld操作,就可以生成对应的共享库 _example.so
使用python模块
可以如下用Python调用example中的函数
>>> import example >>> example.fact(5) 120 >>> example.my_mod(7,3) 1 >>> example.get_time() 'Sun Mar 18 12:59:22 2018\n' >>>
Tensorflow中Session生命周期
tensorflow利用swig包装C/C++实现系统,tensorflow使用Bazel来构建,在编译之前启动swig的代码生成过程,tensorflow中Session的swig接口文件为 tf_session.i 会生成适配文件:pywrap_tensorflow.py 和 pywrap_tensorflow.cpp,在编译之后会生成共享库 _pywrap_tensorflow.so。
类似于上述的swig例子,在pywrap_tensorflow.cpp中同样静态注册了一个函数符号表,在运行时会对Python的函数名称匹配找到对应的C函数实现,最终转到c_api.c
Session 在Python中的创建过程
当client启动Session执行之前,先创建一个Session实例,进而调用父类BaseSession的构造函数,这里开始调用swig创建的包装Python库pywrap_tensorflow的函数
tf_session.TF_NewSessionOptions等来创建或删除Session 。 而且从BaseSession的构造函数中也可以看到,如果没有初始化graph的话,会调用一个default的graph
# tensorflow/python/client/session.py
from tensorflow.python import pywrap_tensorflow as tf_session
class BaseSession(SessionInterface):
def __init__(self, target='', graph=None, config=None):
# other code
if graph is None:
self._graph = ops.get_default_graph()
# other code
self._session = None
opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
try:
with errors.raise_exception_on_not_ok_status() as status:
if self._created_with_new_api:
self._session = tf_session.TF_NewSession(self._graph._c_graph, opts,
else:
self._session = tf_session.TF_NewDeprecatedSession(opts, status)
finally:
tf_session.TF_DeleteSessionOptions(opts)
# other code
从这里开始就经过swig包装的中间层之后,就能够调用共享库中的C语言API了,同样在pywrap_tensorflow.cpp中也静态注册了函数调用的符号表
Session 在C/C++中的生命周期
以调用新API为例,在BaseSession的初始化时会调用pywrap_tensorflow的TF_NewSession,并进一步通过swig调用c_api中的函数TF_NewSession
TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
TF_Status* status) {
Session* session;
status->status = NewSession(opt->options, & session);
if (status->status.ok()) {
TF_Session* new_session = new TF_Session(session, graph);
if (graph != nullptr) {
mutex_lock l(graph->mu);
graph->sessions[new_session] = Status::OK();
}
return new_session;
} else {
DCHECK_EQ(nullptr, session);
return nullptr;
}
}
如上C API函数只是做了一个过渡,函数中的NewSession(opt->options, &session)才会调用后端的C++ 系统
后端系统C++中的NewSession函数如下,这是一个全局函数,不属于哪个类,之后会调用SessionFactory类的一系列函数来进行Session的创建,这一部分的内容就属于tensorflow的核心后端部分。
// tensorflow/core/common_runtime/session.cc
Status NewSession(const SessionOptions& options, Session** out_session) {
SessionFactory* factory;
const Status s = SessionFactory::GetFactory(options, & factory);
if (!s.ok()) {
*out_session = nullptr;
LOG(ERROR) < < s;
return s;
}
*out_session = factory->NewSession(options);
if (!*out_session) {
return errors::Internal("Failed to create session.");
}
return Status::OK();
}
在C++系统中,SessionFactory根据前端传递的Session.target来多态创建Session对象,如DirectionSession将启动本地运行模式,而GrpcSession将启动给予RPC的分布式运行模式
创建扩展运行graph
Python前端调用Session.run时,会将构造好的graph以GraphDef的形式发给C++,每次调用run都会将新增节点的graph发给C++后端用于extend原来的graph
后端系统每次Session.run执行被称为一次step, 每次step计算图会正想计算网络的输出,反向传递梯度,并完成训练参数的更新,后端系统根据Feed和Fetch对整个graph进行剪枝,得到一个最小依赖的计算子图(称为client graph)
在运行时会启动设备分配算法,如果节点之间的边横跨了设备,则将该边分裂,插入Send和Recv节点,实现通信。然后将子图片段(partition graph)注册到相应设备并启动执行。
关闭和销毁Session
主动调用sess.close之后进行Session的关闭,关闭之后Python前端会启动GC进行垃圾回收,当Session.del被调用后,启动后台C++的Session对象销毁过程。
参考文章:
- https://www.jianshu.com/p/667cbb20d802