最简单那种自带的梯度裁剪就不说了,这里说的是TensorFlow没有我们想要的操作。 minimize() minimize()
实际上包含了两个步骤,即
compute_gradients
,用于计算梯度,
和
apply_gradients
,用于使用计算得到的梯度来更新对应的variable。
建议使用anaconda 创建虚拟环境,以免玩坏了,或者出现各种bug。初次玩总会遇到莫名其妙的问题!
利用TensorFlow自定义op的机制(要在Linux环境下进行)定义自己处理梯度的方法,相当于在上面两步间插入一个自定义的操作,把处理后的值apply就行。 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 //zero_out.cc 文件 #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/op_kernel.h" #include<cmath> #include <cstdlib> #include <ctime> using namespace tensorflow; double RefactorinGradient(double num) { //something return 0; } REGISTER_OP("ZeroOut") .Input("to_zero: float") .Output("zeroed: float") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); return Status::OK(); }); class ZeroOutOp : public OpKernel { public: explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { // 将输入 tensor 从 context 中取出。 const Tensor& input_tensor = context->input(0); auto input = input_tensor.flat<float>(); // 创建一个 ouput_tensor, 使用 context->allocate_ouput() 给它分配空间。 Tensor* output_tensor = NULL; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto output_flat = output_tensor->flat<float>(); // Set all but the first element of the output tensor to 0. const int N = input.size(); for (int i = 0; i < N; i++) { output_flat(i) = RefactorinGradient(input(i)); } } }; REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
1 2 3 4 5 6 TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') ) TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') ) sudo g++ -std=c++11 -shared zero_out.cc -o zero_out.so -fPIC -D_GLIBCXX_USE_CXX11_ABI=0 ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2
百度的paddlepaddle深度学习框架可以在Python端自定义,相对简单一点,但性能肯定是没有c++的好