From 705ec54538bf938884eb777f6215358afca4b728 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@inrae.fr> Date: Mon, 19 Sep 2022 19:23:23 +0200 Subject: [PATCH 1/2] REFAC: move _is_chief() and cropped_tensor_name() into model.py --- otbtf/model.py | 35 ++++++++++++++++++++++++++++++++++- otbtf/utils.py | 34 ---------------------------------- 2 files changed, 34 insertions(+), 35 deletions(-) diff --git a/otbtf/model.py b/otbtf/model.py index a6306e27..38fc6ca6 100644 --- a/otbtf/model.py +++ b/otbtf/model.py @@ -3,7 +3,6 @@ import abc import logging import tensorflow -from otbtf.utils import _is_chief, cropped_tensor_name class ModelBase(abc.ABC): @@ -165,3 +164,37 @@ class ModelBase(abc.ABC): model_simplified = tensorflow.keras.Model(inputs=inputs, outputs=outputs, name=self.__class__.__name__ + '_simplified') tensorflow.keras.utils.plot_model(model_simplified, output_path) + + +def _is_chief(strategy): + """ + Tell if the current worker is the chief. + + :param strategy: strategy + :return: True if the current worker is the chief, False else + """ + # Note: there are two possible `TF_CONFIG` configuration. + # 1) In addition to `worker` tasks, a `chief` task type is use; + # in this case, this function should be modified to + # `return task_type == 'chief'`. + # 2) Only `worker` task type is used; in this case, worker 0 is + # regarded as the chief. The implementation demonstrated here + # is for this case. + # For the purpose of this Colab section, the `task_type is None` case + # is added because it is effectively run with only a single worker. + + if strategy.cluster_resolver: # this means MultiWorkerMirroredStrategy + task_type, task_id = strategy.cluster_resolver.task_type, strategy.cluster_resolver.task_id + return (task_type == 'chief') or (task_type == 'worker' and task_id == 0) or task_type is None + # strategy with only one worker + return True + + +def cropped_tensor_name(tensor_name, crop): + """ + A name for the padded tensor + :param tensor_name: tensor name + :param pad: pad value + :return: name + """ + return "{}_crop{}".format(tensor_name, crop) diff --git a/otbtf/utils.py b/otbtf/utils.py index 7aa777af..069638a5 100644 --- a/otbtf/utils.py +++ b/otbtf/utils.py @@ -63,37 +63,3 @@ def read_as_np_arr(gdal_ds, as_patches=True, dtype=None): buffer = buffer.astype(dtype) return buffer - - -def _is_chief(strategy): - """ - Tell if the current worker is the chief. - - :param strategy: strategy - :return: True if the current worker is the chief, False else - """ - # Note: there are two possible `TF_CONFIG` configuration. - # 1) In addition to `worker` tasks, a `chief` task type is use; - # in this case, this function should be modified to - # `return task_type == 'chief'`. - # 2) Only `worker` task type is used; in this case, worker 0 is - # regarded as the chief. The implementation demonstrated here - # is for this case. - # For the purpose of this Colab section, the `task_type is None` case - # is added because it is effectively run with only a single worker. - - if strategy.cluster_resolver: # this means MultiWorkerMirroredStrategy - task_type, task_id = strategy.cluster_resolver.task_type, strategy.cluster_resolver.task_id - return (task_type == 'chief') or (task_type == 'worker' and task_id == 0) or task_type is None - # strategy with only one worker - return True - - -def cropped_tensor_name(tensor_name, crop): - """ - A name for the padded tensor - :param tensor_name: tensor name - :param pad: pad value - :return: name - """ - return "{}_crop{}".format(tensor_name, crop) -- GitLab From abd9d433ba746aac22e29b98b4118fd5053c512e Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@inrae.fr> Date: Mon, 19 Sep 2022 19:33:33 +0200 Subject: [PATCH 2/2] ADD: warning when otbtf is imported w/o GDAL --- otbtf/__init__.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/otbtf/__init__.py b/otbtf/__init__.py index 5321e3db..6ee17f9e 100644 --- a/otbtf/__init__.py +++ b/otbtf/__init__.py @@ -20,9 +20,12 @@ """ OTBTF python module """ +try: + from otbtf.utils import read_as_np_arr, gdal_open + from otbtf.dataset import Buffer, PatchesReaderBase, PatchesImagesReader, IteratorBase, RandomIterator, Dataset, \ + DatasetFromPatchesImages +except ImportError: + print("Warning: otbtf.utils and otbtf.dataset were not imported. Using OTBTF without GDAL.") -from otbtf.utils import read_as_np_arr, gdal_open -from otbtf.dataset import Buffer, PatchesReaderBase, PatchesImagesReader, IteratorBase, RandomIterator, Dataset, \ - DatasetFromPatchesImages from otbtf.tfrecords import TFRecords from otbtf.model import ModelBase -- GitLab