To speed up data loading in PyTorch, I wrote TorchRecord, a storage format similar to TensorFlow's TFRecord.

You can find the TorchRecord project here.

## Introduction

Following the design of TFRecord and Caffe's data storage, I chose Protocol Buffers, a data interchange format developed by Google, as the storage format of TorchRecord. Protocol Buffers serialize structured data (such as tensors) into byte strings, and decoding those strings back is like shooting fish in a barrel. I then insert all the byte strings into LMDB (Lightning Memory-Mapped Database). Finally, we can read items from this database and decode them back into the original objects.
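The write and read paths can be sketched in a few lines of Python. This is a minimal sketch under loud assumptions: `pickle` stands in for Protocol Buffers and a plain dict stands in for the LMDB environment, so it only illustrates the flow (encode, store under a key, iterate in key order, decode). The helper names `make_key`, `write_records`, and `read_records` are hypothetical. One detail that does carry over to LMDB is the key format: LMDB compares keys byte by byte, so integer indices are zero-padded to keep the stored order equal to the numeric order.

```python
import pickle
from typing import Any, Dict, Iterator, List


def make_key(index: int, width: int = 10) -> bytes:
    """Zero-padded ASCII key, so byte-wise order matches numeric order."""
    return str(index).zfill(width).encode("ascii")


def write_records(samples: List[Any]) -> Dict[bytes, bytes]:
    """Encode every sample to bytes and store it under its index key."""
    store: Dict[bytes, bytes] = {}  # stand-in for an LMDB write transaction
    for i, sample in enumerate(samples):
        store[make_key(i)] = pickle.dumps(sample)  # stand-in for protobuf encoding
    return store


def read_records(store: Dict[bytes, bytes]) -> Iterator[Any]:
    """Visit keys in byte order (as an LMDB cursor would) and decode each value."""
    for key in sorted(store):
        yield pickle.loads(store[key])


# Usage: 12 samples, so keys 10 and 11 exercise the zero-padding.
samples = [{"id": i, "label": i % 3} for i in range(12)]
restored = list(read_records(write_records(samples)))
```

Without padding, `b"10"` would sort before `b"2"`, and a cursor would return records out of their original order.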

## Protocol Buffers

I chose caffe2's TensorProtos as the default proto. You can find the details here.
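To make the encode/decode round trip concrete without depending on caffe2, here is a hypothetical stand-in that packs a list of float tensors, each a `(dims, flat data)` pair loosely mirroring the `dims` and `float_data` fields of a TensorProto, into bytes with `struct`. It is not the protobuf wire format; with the real proto you would presumably fill a `caffe2_pb2.TensorProtos` message and call `SerializeToString()` / `ParseFromString()`. The names `encode_tensors` and `decode_tensors` and the byte layout are assumptions for illustration.

```python
import struct
from typing import List, Tuple

# A tensor is (dims, flat float data), loosely like TensorProto.dims / float_data.
Tensor = Tuple[List[int], List[float]]


def encode_tensors(tensors: List[Tensor]) -> bytes:
    """Hypothetical little-endian layout: count, then per tensor ndims, dims, ndata, data."""
    parts = [struct.pack("<I", len(tensors))]
    for dims, data in tensors:
        parts.append(struct.pack("<I", len(dims)))
        parts.append(struct.pack(f"<{len(dims)}q", *dims))  # int64 dims
        parts.append(struct.pack("<I", len(data)))
        parts.append(struct.pack(f"<{len(data)}f", *data))  # float32 values
    return b"".join(parts)


def decode_tensors(buf: bytes) -> List[Tensor]:
    """Inverse of encode_tensors: walk the buffer with a moving offset."""
    offset = 0

    def read(fmt: str) -> tuple:
        nonlocal offset
        values = struct.unpack_from(fmt, buf, offset)
        offset += struct.calcsize(fmt)
        return values

    (count,) = read("<I")
    tensors: List[Tensor] = []
    for _ in range(count):
        (ndims,) = read("<I")
        dims = list(read(f"<{ndims}q"))
        (ndata,) = read("<I")
        data = list(read(f"<{ndata}f"))
        tensors.append((dims, data))
    return tensors


# Usage: an image-like 2x2 tensor plus a scalar label tensor.
record = [([2, 2], [0.5, 1.0, -2.25, 3.0]), ([], [7.0])]
blob = encode_tensors(record)
```

Values are stored as float32, so only numbers exactly representable in 32 bits survive the round trip unchanged.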

## Random Training

Most deep learning models are trained on randomly ordered data, so the data loader needs to provide a random order of the dataset in each training epoch. The easiest way is to generate random indices and fetch each item with LMDB's get() API. But random get() calls are much slower than reading the database sequentially.

Instead, in TorchRecord I shuffle the dataset in two steps:

1. Shuffle the data once before inserting it into the database.
2. Read the items sequentially with an LMDB cursor, which iterates over the database in key order and is faster than random get() calls. Put each item into a buffer pool, shuffle the pool after every insertion, and pop a random item as the next training example.
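Step 2 is essentially a shuffle buffer. Below is a minimal sketch in plain Python, where `items` is assumed to be any sequential iterator (standing in for the LMDB cursor) and `buffered_shuffle` is a hypothetical name. Instead of reshuffling the whole pool after every insertion, it picks one random slot and pops it, which gives each buffered item the same chance of being chosen at O(1) cost per item.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int,
                     seed: Optional[int] = None) -> Iterator[T]:
    """Yield items in a locally shuffled order using a fixed-size buffer pool."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:  # sequential read, like iterating an LMDB cursor
        buffer.append(item)
        if len(buffer) >= buffer_size:
            # Pop a uniformly random element: swap it to the end, then pop.
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    rng.shuffle(buffer)  # drain whatever is left in random order
    yield from buffer


# Usage: a new seed per epoch gives a different order each epoch.
epoch_order = list(buffered_shuffle(range(100), buffer_size=10, seed=0))
```

This is why step 1 matters: the buffer only mixes items locally, so an item can come out at most `buffer_size - 1` positions earlier than it was read. The one-time global shuffle before insertion supplies the long-range mixing.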

Following these steps, we obtain a shuffled sequence of the dataset without any random get() calls.

## Working Process

The main process generates an index sequence (e.g. 1, 2, 3, …, n) and distributes it among the worker processes. Each worker process opens its own LMDB cursor and seeks to the position of the current index. After that, all the shuffling, decoding, and transforming is done inside each worker process.
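A minimal sketch of that division of labour, assuming (this is not stated above) that the main process hands each worker one contiguous range of indices and that keys are zero-padded decimal strings. `split_indices`, `key_for`, and `worker_read` are hypothetical names, and a sorted list of keys plus `bisect` emulates the cursor seek; in py-lmdb the equivalent seek would presumably be `Cursor.set_range()`, which positions the cursor at the first key greater than or equal to the given one.

```python
import bisect
from typing import List, Tuple


def key_for(index: int, width: int = 10) -> bytes:
    """Hypothetical zero-padded key format (byte order == numeric order)."""
    return str(index).zfill(width).encode("ascii")


def split_indices(n: int, num_workers: int) -> List[Tuple[int, int]]:
    """Split 0..n-1 into contiguous [start, stop) ranges whose sizes differ by at most one."""
    base, extra = divmod(n, num_workers)
    ranges, start = [], 0
    for w in range(num_workers):
        stop = start + base + (1 if w < extra else 0)
        ranges.append((start, stop))
        start = stop
    return ranges


def worker_read(sorted_keys: List[bytes], start: int, stop: int) -> List[bytes]:
    """Seek once to the first key >= start, then read sequentially until stop."""
    pos = bisect.bisect_left(sorted_keys, key_for(start))  # emulated cursor seek
    out: List[bytes] = []
    while pos < len(sorted_keys) and len(out) < stop - start:
        out.append(sorted_keys[pos])  # emulated cursor.next()
        pos += 1
    return out


# Usage: 10 records shared between 3 workers.
keys = [key_for(i) for i in range(10)]
parts = [worker_read(keys, s, e) for s, e in split_indices(10, 3)]
```

Each worker would then feed its own slice through the shuffle buffer and the decoding/transform steps, so the expensive work runs in parallel.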

## Benchmark

I tested TorchRecord on an Intel® Xeon® CPU E5-2603 0 @ 1.80GHz (4 cores) with 32 GB of memory.

| num_workers | Loader | Elapsed (mm:ss) | Throughput (it/s) |
|---|---|---|---|
| 2 | Conventional | 00:42 | 8.72 |
| 2 | TorchRecord | 00:21 | 16.91 |
| 4 | Conventional | 00:22 | 16.16 |
| 4 | TorchRecord | 00:13 | 26.73 |
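The relative speedup follows directly from the throughput figures: divide TorchRecord's it/s by the conventional loader's it/s for the same number of workers. A small Python check using only the numbers reported above:

```python
# Throughput (iterations per second) copied from the benchmark results above.
results = {
    2: {"conventional": 8.72, "torchrecord": 16.91},
    4: {"conventional": 16.16, "torchrecord": 26.73},
}


def speedup(r: dict) -> float:
    """TorchRecord throughput relative to the conventional loader."""
    return r["torchrecord"] / r["conventional"]


for workers, r in results.items():
    print(f"num_workers={workers}: {speedup(r):.2f}x")
# num_workers=2: 1.94x
# num_workers=4: 1.65x
```

So TorchRecord roughly doubles throughput with 2 workers, and the advantage narrows to about 1.65× with 4 workers.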