diff --git a/python/otbtf.py b/python/otbtf.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f37f5440ee74765554471f7fdcbc308976d13e4
--- /dev/null
+++ b/python/otbtf.py
@@ -0,0 +1,525 @@
+# -*- coding: utf-8 -*-
+# ==========================================================================
+#
+#   Copyright 2018-2019 Remi Cresson (IRSTEA)
+#   Copyright 2020 Remi Cresson (INRAE)
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#          http://www.apache.org/licenses/LICENSE-2.0.txt
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+#
+# ==========================================================================
+import threading
+import multiprocessing
+import time
+import numpy as np
+import tensorflow as tf
+import gdal
+import logging
+from abc import ABC, abstractmethod
+
+
+"""
+------------------------------------------------------- Helpers --------------------------------------------------------
+"""
+
+
+def gdal_open(filename):
+    """
+    Open a GDAL raster
+    :param filename: raster file
+    :return: a GDAL ds instance
+    """
+    ds = gdal.Open(filename)
+    if ds is None:
+        raise Exception("Unable to open file {}".format(filename))
+    return ds
+
+
+def read_as_np_arr(ds, as_patches=True):
+    """
+    Read a GDAL raster as a numpy array
+    :param ds: GDAL ds instance
+    :param as_patches: if True, the returned numpy array has the following shape (n, psz_x, psz_x, nb_channels). If
+        False, the shape is (1, psz_y, psz_x, nb_channels)
+    :return: Numpy array of dim 4
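+
+    Example (a minimal sketch; the file name is hypothetical):
+        ds = gdal_open("s1_patches.tif")  # e.g. 100 patches of 16x16 pixels with 4 channels, stacked in rows
+        arr = read_as_np_arr(ds)          # expected shape: (100, 16, 16, 4)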
+    """
+    buffer = ds.ReadAsArray()
+    szx = ds.RasterXSize
+    if len(buffer.shape) == 3:
+        buffer = np.transpose(buffer, axes=(1, 2, 0))
+    if not as_patches:
+        n = 1
+        szy = ds.RasterYSize
+    else:
+        n = int(ds.RasterYSize / szx)
+        szy = szx
+    return np.float32(buffer.reshape((n, szy, szx, ds.RasterCount)))
+
+
+"""
+---------------------------------------------------- Buffer class ------------------------------------------------------
+"""
+
+
+class Buffer:
+    """
+    Used to store and access a list of objects
+    """
+
+    def __init__(self, max_length):
+        self.max_length = max_length
+        self.container = []
+
+    def size(self):
+        return len(self.container)
+
+    def add(self, x):
+        self.container.append(x)
+        assert (self.size() <= self.max_length)
+
+    def is_complete(self):
+        return self.size() == self.max_length
+
+
+"""
+------------------------------------------------ PatchesReaderBase class -----------------------------------------------
+"""
+
+
+class PatchesReaderBase(ABC):
+    """
+    Base class for patches delivery
+    """
+
+    @abstractmethod
+    def get_sample(self, index):
+        """
+        Return one sample.
+        :return: One sample instance, whatever the sample structure is (dict, numpy array, ...)
+        """
+        pass
+
+    @abstractmethod
+    def get_stats(self) -> dict:
+        """
+        Compute some statistics for each source.
+        Depending on whether streaming is used, the statistics are computed either directly in memory or chunk by chunk.
+
+        :return: a dict having the following structure:
+        {
+        "src_key_0":
+            {"min": np.array([...]),
+            "max": np.array([...]),
+            "mean": np.array([...]),
+            "std": np.array([...])},
+        ...,
+        "src_key_M":
+            {"min": np.array([...]),
+            "max": np.array([...]),
+            "mean": np.array([...]),
+            "std": np.array([...])},
+        }
+        """
+        pass
+
+    @abstractmethod
+    def get_size(self):
+        """
+        Returns the total number of samples
+        :return: number of samples (int)
+        """
+        pass
+
+
+"""
+----------------------------------------------- PatchesImagesReader class ----------------------------------------------
+"""
+
+
+class PatchesImagesReader(PatchesReaderBase):
+    """
+    This class provides a read access to a set of patches images.
+
+    A patches image is an image of patches stacked in rows, as produced from the OTBTF "PatchesExtraction"
+    application, and is stored in a raster format (e.g. GeoTiff).
+    A source corresponds to a particular domain from which the patches are extracted (remember that in OTBTF
+    applications, the number of sources is controlled by the OTB_TF_NSOURCES environment variable).
+
+    This class enables the use of:
+     - multiple sources
+     - multiple patches images per source
+
+    Each patch can be independently accessed using the get_sample(index) function, with index in [0, self.size),
+    self.size being the total number of patches (which must be the same for each source).
+
+    :see PatchesReaderBase
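+
+    Example (a minimal usage sketch; file names below are hypothetical):
+        reader = PatchesImagesReader({"s1": ["s1_patches_1.tif", "s1_patches_2.tif"],
+                                      "s2": ["s2_patches_1.tif", "s2_patches_2.tif"]})
+        sample = reader.get_sample(0)   # {"s1": np.array(...), "s2": np.array(...)}
+        stats = reader.get_stats()      # per-source "min"/"max"/"mean"/"std"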
+    """
+
+    def __init__(self, filenames_dict: dict, use_streaming=False):
+        """
+        :param filenames_dict: A dict() structured as follows:
+            {src_name1: [src1_patches_image_1.tif, ..., src1_patches_image_N.tif],
+             src_name2: [src2_patches_image_1.tif, ..., src2_patches_image_N.tif],
+             ...
+             src_nameM: [srcM_patches_image_1.tif, ..., srcM_patches_image_N.tif]}
+        :param use_streaming: if True, the patches are read on the fly from disk; nothing is kept in memory.
+        """
+
+        assert (len(filenames_dict.values()) > 0)
+
+        # ds dict
+        self.ds = dict()
+        for src_key, src_filenames in filenames_dict.items():
+            self.ds[src_key] = []
+            for src_filename in src_filenames:
+                self.ds[src_key].append(gdal_open(src_filename))
+
+        if len(set([len(ds_list) for ds_list in self.ds.values()])) != 1:
+            raise Exception("Each source must have the same number of patches images")
+
+        # streaming on/off
+        self.use_streaming = use_streaming
+
+        # ds check
+        nb_of_patches = {key: 0 for key in self.ds}
+        self.nb_of_channels = dict()
+        for src_key, ds_list in self.ds.items():
+            for ds in ds_list:
+                nb_of_patches[src_key] += self._get_nb_of_patches(ds)
+                if src_key not in self.nb_of_channels:
+                    self.nb_of_channels[src_key] = ds.RasterCount
+                else:
+                    if self.nb_of_channels[src_key] != ds.RasterCount:
+                        raise Exception("All patches images from one source must have the same number of channels! "
+                                        "Error happened for source: {}".format(src_key))
+        if len(set(nb_of_patches.values())) != 1:
+            raise Exception("Sources must have the same number of patches! Number of patches: {}".format(nb_of_patches))
+
+        # ds sizes
+        src_key_0 = list(self.ds)[0]  # first key
+        self.ds_sizes = [self._get_nb_of_patches(ds) for ds in self.ds[src_key_0]]
+        self.size = sum(self.ds_sizes)
+
+        # if use_streaming is False, we store in memory all patches images
+        if not self.use_streaming:
+            patches_list = {src_key: [read_as_np_arr(ds) for ds in self.ds[src_key]] for src_key in self.ds}
+            # Stack the patches along the first (sample) axis, so that patches_buffer[src_key][index] is one patch
+            self.patches_buffer = {src_key: np.concatenate(patches_list[src_key], axis=0) for src_key in self.ds}
+
+    def _get_ds_and_offset_from_index(self, index):
+        """
+        Map a global sample index to (patches image index, patch offset inside this patches image)
+        """
+        offset = index
+        for i, ds_size in enumerate(self.ds_sizes):
+            if offset < ds_size:
+                break
+            offset -= ds_size
+
+        return i, offset
+
+    @staticmethod
+    def _get_nb_of_patches(ds):
+        return int(ds.RasterYSize / ds.RasterXSize)
+
+    @staticmethod
+    def _read_extract_as_np_arr(ds, offset):
+        """
+        Read the patch at the given offset (patch index) from a GDAL ds of patches stacked in rows
+        """
+        assert (ds is not None)
+        psz = ds.RasterXSize
+        yoff = int(offset * psz)
+        assert (yoff + psz <= ds.RasterYSize)
+        buffer = ds.ReadAsArray(0, yoff, psz, psz)
+        if len(buffer.shape) == 3:
+            buffer = np.transpose(buffer, axes=(1, 2, 0))
+        return np.float32(buffer)
+
+    def get_sample(self, index):
+        """
+        Return one sample of the dataset.
+        :param index: the sample index. Must be in the [0, self.size) range.
+        :return: The sample is stored in a dict() with the following structure:
+            {"src_key_0": np.array((psz_y_0, psz_x_0, nb_ch_0)),
+             "src_key_1": np.array((psz_y_1, psz_x_1, nb_ch_1)),
+             ...
+             "src_key_M": np.array((psz_y_M, psz_x_M, nb_ch_M))}
+        """
+        assert (0 <= index)
+        assert (index < self.size)
+
+        if not self.use_streaming:
+            res = {src_key: self.patches_buffer[src_key][index, :, :, :] for src_key in self.ds}
+        else:
+            i, offset = self._get_ds_and_offset_from_index(index)
+            res = {src_key: self._read_extract_as_np_arr(self.ds[src_key][i], offset) for src_key in self.ds}
+
+        return res
+
+    def get_stats(self):
+        """
+        Compute some statistics for each source.
+        Depending on whether streaming is used, the statistics are computed either directly in memory or chunk by chunk.
+
+        :return: statistics dict
+        """
+        logging.info("Computing stats")
+        if not self.use_streaming:
+            axis = (0, 1, 2)  # (patch, row, col)
+            stats = {src_key: {"min": np.amin(patches_buffer, axis=axis),
+                               "max": np.amax(patches_buffer, axis=axis),
+                               "mean": np.mean(patches_buffer, axis=axis),
+                               "std": np.std(patches_buffer, axis=axis)} for src_key, patches_buffer in
+                     self.patches_buffer.items()}
+        else:
+            axis = (0, 1)  # (row, col)
+
+            def _filled(value):
+                return {src_key: value * np.ones((self.nb_of_channels[src_key])) for src_key in self.ds}
+
+            _maxs = _filled(float("-inf"))
+            _mins = _filled(float("inf"))
+            _sums = _filled(0.0)
+            _sqsums = _filled(0.0)
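+            # Accumulate per-channel running min/max, plus the per-sample means of x and x^2, so that
+            # mean = E[x] and std = sqrt(E[x^2] - E[x]^2) can be derived without holding all patches in memory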
+            for index in range(self.size):
+                sample = self.get_sample(index=index)
+                for src_key, np_arr in sample.items():
+                    rnumel = 1.0 / float(np_arr.shape[0] * np_arr.shape[1])
+                    _mins[src_key] = np.minimum(np.amin(np_arr, axis=axis).flatten(), _mins[src_key])
+                    _maxs[src_key] = np.maximum(np.amax(np_arr, axis=axis).flatten(), _maxs[src_key])
+                    _sums[src_key] += rnumel * np.sum(np_arr, axis=axis).flatten()
+                    _sqsums[src_key] += rnumel * np.sum(np.square(np_arr), axis=axis).flatten()
+
+            rsize = 1.0 / float(self.size)
+            stats = {src_key: {"min": _mins[src_key],
+                               "max": _maxs[src_key],
+                               "mean": rsize * _sums[src_key],
+                               "std": np.sqrt(rsize * _sqsums[src_key] - np.square(rsize * _sums[src_key]))
+                               } for src_key in self.ds}
+        logging.info("Stats: {}".format(stats))
+        return stats
+
+    def get_size(self):
+        return self.size
+
+
+"""
+------------------------------------------------- IteratorBase class ---------------------------------------------------
+"""
+
+
+class IteratorBase(ABC):
+    """
+    Base class for iterators
+    """
+    @abstractmethod
+    def __init__(self, patches_reader: PatchesReaderBase):
+        pass
+
+
+"""
+------------------------------------------------ RandomIterator class --------------------------------------------------
+"""
+
+
+class RandomIterator(IteratorBase):
+    """
+    Iterates over the sample indices in [0, patches_reader.get_size()) in random order, reshuffling the indices once
+    they have all been delivered.
+    """
+
+    def __init__(self, patches_reader):
+        super().__init__(patches_reader=patches_reader)
+        self.indices = np.arange(0, patches_reader.get_size())
+        self._shuffle()
+        self.count = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        current_index = self.indices[self.count]
+        if self.count < len(self.indices) - 1:
+            self.count += 1
+        else:
+            self._shuffle()
+            self.count = 0
+        return current_index
+
+    def _shuffle(self):
+        np.random.shuffle(self.indices)
+
+
+"""
+--------------------------------------------------- Dataset class ------------------------------------------------------
+"""
+
+
+class Dataset:
+    """
+    Handles the "mining" of patches.
+    This class spawns a thread that extracts samples from the patches reader, while giving access to the samples
+    already gathered.
+
+    :see PatchesReaderBase
+    :see Buffer
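+
+    Example (a minimal usage sketch; assumes `reader` is an existing PatchesReaderBase instance):
+        dataset = Dataset(patches_reader=reader, buffer_length=128)
+        tf_ds = dataset.get_tf_dataset(batch_size=16)
+        for batch in tf_ds:  # one epoch; each batch is a dict of tensors, one entry per source
+            ...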
+    """
+
+    def __init__(self, patches_reader: PatchesReaderBase, buffer_length: int = 128,
+                 Iterator: IteratorBase = RandomIterator):
+        """
+        :param patches_reader: The patches reader instance
+        :param buffer_length: The number of samples that are stored in the buffer
+        :param Iterator: The iterator class used to generate the sequence of patches indices.
+        """
+
+        # patches reader
+        self.patches_reader = patches_reader
+        self.size = self.patches_reader.get_size()
+
+        # iterator
+        self.iterator = Iterator(patches_reader=self.patches_reader)
+
+        # Get the patches sizes and types from the first sample of the first patches image
+        self.output_types = dict()
+        self.output_shapes = dict()
+        one_sample = self.patches_reader.get_sample(index=0)
+        for src_key, np_arr in one_sample.items():
+            self.output_shapes[src_key] = np_arr.shape
+            self.output_types[src_key] = tf.dtypes.as_dtype(np_arr.dtype)
+
+        logging.info("output_types: {}".format(self.output_types))
+        logging.info("output_shapes: {}".format(self.output_shapes))
+
+        # buffers
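+        # Double buffering: the miner thread fills miner_buffer in the background while samples are read from
+        # consumer_buffer; once consumer_buffer is exhausted, _dump() copies miner_buffer into it and restarts mining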
+        self.miner_buffer = Buffer(buffer_length)
+        self.consumer_buffer = Buffer(buffer_length)
+        self.consumer_buffer_pos = 0
+        self.tot_wait = 0
+        self.miner_thread = self._summon_miner_thread()
+        self.read_lock = multiprocessing.Lock()
+        self._dump()
+
+        # Prepare tf dataset for one epoch
+        self.tf_dataset = tf.data.Dataset.from_generator(self._generator,
+                                                         output_types=self.output_types,
+                                                         output_shapes=self.output_shapes).repeat(1)
+
+    def get_stats(self) -> dict:
+        """
+        :return: the dataset statistics, computed by the patches reader
+        """
+        return self.patches_reader.get_stats()
+
+    def read_one_sample(self):
+        """
+        Read one element of the consumer_buffer
+        The lock is used to prevent different threads from reading and updating the internal counter concurrently
+        """
+        with self.read_lock:
+            output = None
+            if self.consumer_buffer_pos < self.consumer_buffer.max_length:
+                output = self.consumer_buffer.container[self.consumer_buffer_pos]
+                self.consumer_buffer_pos += 1
+            if self.consumer_buffer_pos == self.consumer_buffer.max_length:
+                self._dump()
+                self.consumer_buffer_pos = 0
+            return output
+
+    def _dump(self):
+        """
+        This function dumps the miner_buffer into the consumer_buffer, and restarts the miner_thread
+        """
+        # Wait for the miner to finish its job
+        t = time.time()
+        self.miner_thread.join()
+        self.tot_wait += time.time() - t
+
+        # Copy miner_buffer.container --> consumer_buffer.container
+        self.consumer_buffer.container = [elem for elem in self.miner_buffer.container]
+
+        # Clear miner_buffer.container
+        self.miner_buffer.container.clear()
+
+        # Restart miner_thread
+        self.miner_thread = self._summon_miner_thread()
+
+    def _collect(self):
+        """
+        This function collects samples.
+        It is threaded by the miner_thread.
+        """
+        # Fill the miner_container until it's full
+        while not self.miner_buffer.is_complete():
+            try:
+                index = next(self.iterator)
+                new_sample = self.patches_reader.get_sample(index=index)
+                self.miner_buffer.add(new_sample)
+            except Exception as e:
+                logging.warning("Error while collecting samples: {}".format(e))
+
+    def _summon_miner_thread(self):
+        """
+        Creates and starts the data collection thread
+        """
+        t = threading.Thread(target=self._collect)
+        t.start()
+        return t
+
+    def _generator(self):
+        """
+        Generator function, used for the tf dataset
+        """
+        for _ in range(self.size):
+            yield self.read_one_sample()
+
+    def get_tf_dataset(self, batch_size, drop_remainder=True):
+        """
+        Returns a TF dataset, ready to be used with the provided batch size
+        :param batch_size: the batch size
+        :param drop_remainder: drop incomplete batches
+        :return: The TF dataset
+        """
+        if 2 * batch_size >= self.miner_buffer.max_length:
+            logging.warning("Batch size is {} but dataset buffer has {} elements. Consider using a larger dataset "
+                            "buffer to avoid I/O bottleneck".format(batch_size, self.miner_buffer.max_length))
+        return self.tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
+
+    def get_total_wait_in_seconds(self):
+        """
+        Returns the number of seconds during which data gathering was delayed due to the I/O bottleneck
+        :return: duration in seconds
+        """
+        return self.tot_wait
+
+
+"""
+------------------------------------------- DatasetFromPatchesImages class ---------------------------------------------
+"""
+
+
+class DatasetFromPatchesImages(Dataset):
+    """
+    Handles the "mining" of a set of patches images.
+
+    :see PatchesImagesReader
+    :see Dataset
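+
+    Example (a minimal usage sketch; file names below are hypothetical):
+        dataset = DatasetFromPatchesImages({"s1": ["s1_patches_1.tif"], "labels": ["labels_patches_1.tif"]},
+                                           use_streaming=True)
+        tf_ds = dataset.get_tf_dataset(batch_size=16)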
+    """
+
+    def __init__(self, filenames_dict: dict, use_streaming: bool = False, buffer_length: int = 128,
+                 Iterator: IteratorBase = RandomIterator):
+        """
+        :param filenames_dict: A dict() structured as follows:
+            {src_name1: [src1_patches_image_1, ..., src1_patches_image_N],
+             src_name2: [src2_patches_image_1, ..., src2_patches_image_N],
+             ...
+             src_nameM: [srcM_patches_image_1, ..., srcM_patches_image_N]}
+        :param use_streaming: if True, the patches are read on the fly from disk; nothing is kept in memory.
+        :param buffer_length: The number of samples that are stored in the buffer (used when "use_streaming" is True).
+        :param Iterator: The iterator class used to generate the sequence of patches indices.
+        """
+        # patches reader
+        patches_reader = PatchesImagesReader(filenames_dict=filenames_dict, use_streaming=use_streaming)
+
+        super().__init__(patches_reader=patches_reader, buffer_length=buffer_length, Iterator=Iterator)