Saturday, June 24, 2017

Device Interfaces in Tensorflow

Tensorflow has taken the developer world by storm in recent years. It enables a developer, even a newbie to machine learning, to build a neural network in just a couple of minutes. It is also designed to run on a variety of devices and platforms, such as CPUs, GPUs, and distributed systems. In this article we're going to focus on the latter feature and see how Tensorflow interacts with different hardware to perform those heavy computations.

Scenario

We want to find out how to add new devices to Tensorflow in order to execute kernels on top of them.

In an ideal world, every kernel implementation is independent of the underlying device. That is, one could execute a kernel on a variety of devices with no, or minimal, modification to the kernel code. That way, if we want to use an alternative device, say an FPGA, to power the computations, all we need to do is implement some sort of device interface (maybe a C++ class?) rather than rewrite all of the kernels. Kernels would leverage the device interface to perform the critical and more primitive (in comparison with the ML algorithms on top of them) calculations, for example matrix multiplication.
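To make the idea concrete, here is a purely hypothetical sketch of the kind of interface we are hoping for. None of these names exist in Tensorflow; they only illustrate the shape of the abstraction:

#include <cstddef>

// Hypothetical device interface -- NOT an actual Tensorflow API.
class DeviceInterface {
 public:
  virtual ~DeviceInterface() {}
  // Primitive building blocks that kernels would be written against.
  virtual void* Allocate(std::size_t num_bytes) = 0;
  virtual void Deallocate(void* ptr) = 0;
  virtual void MatMul(const float* a, const float* b, float* out,
                      int m, int k, int n) = 0;
};

// A kernel written against this interface would not care whether `dev`
// is backed by a CPU, a GPU, or an FPGA.
void DenseLayer(DeviceInterface* dev, const float* x, const float* w,
                float* y, int m, int k, int n) {
  dev->MatMul(x, w, y, m, k, n);
}

With such an interface, supporting a new device would mean writing one more subclass, not touching every kernel.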

Take 1: The tensorflow::Device class

Let's start with this SO question. The answer indicates that if we want to add a new device, we need to implement the tensorflow::Device class and register it with some macro. In that class (tensorflow/core/common_runtime/device.h) there is a noteworthy virtual function that is most likely the place where the device-specific computation logic is implemented: Device::Compute.

  // Performs the actual compute function.
  //
  // Subclasses may override this function if they wish to perform
  // some initialization before each compute.
  virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) {
    op_kernel->Compute(context);
  }

Unlike the Compute function in the OpKernel class, which performs the real computation, the Compute function here acts more like a wrapper around OpKernel::Compute. Just as the comment says, this function is responsible for setting up or initializing the device context before each kernel computation. Let's look at a concrete example: BaseGPUDevice (tensorflow/core/common_runtime/gpu/gpu_device.h).

The thing we care about most is the BaseGPUDevice::Compute function, and basically all of its primary computation logic is delegated to BaseGPUDevice::ComputeHelper, which is shown below (some verbose code has been trimmed).

void BaseGPUDevice::ComputeHelper(OpKernel* op_kernel,
                                  OpKernelContext* context) {
  GPUDeviceContext* gpu_device_context = device_contexts_[0];
  if (context->op_device_context() != nullptr) {
    gpu_device_context =
        static_cast<GPUDeviceContext*>(context->op_device_context());
  }
  gpu::Stream* stream = gpu_device_context->stream();
  //const auto stream_id = gpu_device_context->stream_id();

  const auto num_streams = streams_.size();
  if (num_streams > 1) {
    // If this op's device context is different from the other contexts,
    // we must wait on the stream.
    for (int i = 0; i < context->num_inputs(); ++i) {
      const GPUDeviceContext* idc =
          static_cast<GPUDeviceContext*>(context->input_device_context(i));

      if (idc->stream() != stream) stream->ThenWaitFor(idc->stream());
    }
  }
  gpu::cuda::ScopedActivateExecutorContext scoped_activation{stream->parent()};
  op_kernel->Compute(context);
  if (context->status().ok()) {
    if (sync_every_op_) {
      // Note: GPUUtil::Sync() only syncs the default stream.
      // We need to either sync the stream used by this op, or
      // all streams.  Given that this flag is typically used for
      // debugging it makes more sense to sync all GPU activity.
      context->SetStatus(GPUUtil::SyncAll(this));
    }
  }
}

Before calling the kernel's Compute function (the op_kernel->Compute(context) call), this function waits for the inputs to be ready if their (CUDA) streams are different from the kernel's (the stream->ThenWaitFor(idc->stream()) call).
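By analogy, a new backend could hook into the same place by overriding Compute. Below is a minimal sketch, assuming a hypothetical FpgaDevice; the constructor and the registration macro mentioned in the SO answer are left out on purpose:

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

class FpgaDevice : public Device {
 public:
  using Device::Device;  // inherit constructors (setup details omitted)

  void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
    // Per-op, device-level setup (e.g. picking a command queue) goes here...
    op_kernel->Compute(context);  // ...but the math still lives in the kernel.
  }

  Status Sync() override { return Status::OK(); }  // wait for outstanding work
};

}  // namespace tensorflow

Note that even in this sketch the actual number crunching still happens inside op_kernel->Compute(context); the Device subclass only wraps it.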

Now it's pretty clear that this interface is NOT the one we desire in the scenario mentioned previously: the device-specific logic still lives in the kernels themselves. For example, the bias_op kernel is separated into two classes, Bias and BiasGPU, which are located in bias_op.cc and bias_op_gpu.cu.cc under tensorflow/core/kernels, respectively. CUDA code is hard-coded into the implementation of BiasGPU::compute (lines 78 ~ 88 of that file):

  CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
  if (data_format == FORMAT_NHWC) {
    BiasNHWCKernel<
        T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
        config.virtual_thread_count, input, bias, output, bias_size);
  } else {
    BiasNCHWKernel<
        T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
        config.virtual_thread_count, input, bias, output, bias_size,
        image_size);
  }

Then, at runtime, the framework picks the corresponding kernel variant depending on which device the op has been placed on.
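The selection is driven by the kernel registry: each variant registers itself for a device type under the same op name, roughly like this (simplified from the macro-expanded registrations in bias_op.cc):

// Simplified sketch of the registrations in tensorflow/core/kernels/bias_op.cc.
// The runtime looks up the ("BiasAdd", device type) pair and dispatches to the
// matching variant.
REGISTER_KERNEL_BUILDER(
    Name("BiasAdd").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    BiasOp<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(
    Name("BiasAdd").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    BiasOp<GPUDevice, float>);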

Take 2: The Eigen::TensorDevice class

During our inspection of tensorflow/core/kernels/bias_op_gpu.cu.cc, we found an interesting line (line 29):
typedef Eigen::GpuDevice GPUDevice;
Although instances of this type are only used to access the CUDA stream in this file, we're still curious about the relationship between Eigen and Tensorflow here.

Eigen is a well-known linear algebra library, and Tensorflow uses it heavily in its codebase, since ML algorithms involve lots of linear algebra calculations, for example, again, matrix multiplication. The question is: how does Tensorflow take advantage of Eigen?

Before moving forward, you should know that, for reasons related to the build system, part of the Eigen library code introduced here is not present in the Tensorflow codebase; it's stored here and here. I recommend learning where these two URLs can be found from this[1] note, since they may vary from version to version.

Let's look at QuantizeAndDequantizeOneScaleImpl::Compute (tensorflow/core/kernels/quantize_and_dequantize_op.h). First, note that this class is not divided into separate CPU and GPU variants. The Device template type, which is eventually resolved to an Eigen device type, and the Device& d argument play the main roles here. Here is an example of how the latter is used (lines 71 ~ 76):

        out.device(d) =
            ((input.cwiseMin(max_range).cwiseMax(min_range) - min_range) *
                 scale +
             T(0.5)).floor() *
                inverse_scale +
            min_range;

The out variable above is a tensor, and the RHS above also resolves to a tensor expression. However, instead of assigning the result of the RHS to out directly, evaluation of the RHS is postponed: out.device(d) wraps out in an Eigen::TensorDevice, and the entire RHS expression is handed to Eigen::TensorDevice::operator=. The device implementations, Eigen::GpuDevice [2] for example, are responsible for executing the RHS expression that is passed in.
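To see this mechanism in isolation, here is a small standalone Eigen example (no Tensorflow involved); the device object decides how the expression tree built on the RHS gets evaluated:

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 1> a(4), b(4), out(4);
  a.setConstant(1.0f);
  b.setConstant(2.0f);

  // `a + b * 0.5f` only builds an expression tree; nothing is computed yet.
  // `out.device(d)` wraps `out` in an Eigen::TensorDevice, whose operator=
  // evaluates the expression using the given device.
  Eigen::DefaultDevice d;  // single-threaded CPU evaluation
  out.device(d) = a + b * 0.5f;
  return 0;
}

Swapping in Eigen::ThreadPoolDevice (or Eigen::GpuDevice inside CUDA code) changes how the very same expression is executed, which is exactly why the quantize kernel above needs no separate CPU and GPU variants.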

Summary

Now we know there are two ways to enable kernel execution on new devices:
  1. Modify the kernel source with device-specific code, or add another variant of that kernel (e.g. Bias and BiasGPU).
  2. Implement another Eigen device (e.g. Eigen::GpuDevice) for Eigen::TensorDevice to dispatch to, as sketched below.
The two methods are complementary; which one to adopt depends on the properties of the device and of the kernels. For example, if the operations are strongly related to linear algebra, the second method is more adequate; otherwise, the first one might be more expressive, although it may require lots of kernel modifications (not every kernel, though, since some kernels simply cannot be executed on devices other than the CPU).
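For the second route, a custom device type would need to expose a surface similar to Eigen::DefaultDevice (unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h). The sketch below just falls back to the host so it stays self-contained; the exact set of required methods is dictated by the tensor evaluators that end up instantiated with it, so treat it as an outline rather than a complete implementation:

#include <cstdlib>
#include <cstring>

// Loose sketch of a custom Eigen device. A real backend would call into its
// own driver/runtime instead of the host fallbacks used here.
struct MyAcceleratorDevice {
  void* allocate(std::size_t num_bytes) const { return std::malloc(num_bytes); }
  void deallocate(void* buffer) const { std::free(buffer); }
  void memcpy(void* dst, const void* src, std::size_t n) const {
    std::memcpy(dst, src, n);
  }
  void memset(void* buffer, int c, std::size_t n) const {
    std::memset(buffer, c, n);
  }
  std::size_t numThreads() const { return 1; }  // used to size parallel loops
  std::size_t firstLevelCacheSize() const { return 32 * 1024; }       // blocking
  std::size_t lastLevelCacheSize() const { return 8 * 1024 * 1024; }  // hints
  int majorDeviceVersion() const { return 1; }
};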

[1]: Tensorflow chooses to use part of the Eigen library without any modification, so the build system does not fetch the archived library files from the official Eigen repository until the first build. This behavior is defined in tensorflow/workspace.bzl, line 148.

[2]: Eigen/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h