Train a Mask R-CNN for Instance Segmentation with TF Model Garden | HackerNoon

News Room · Published 14 October 2025

Content Overview

  • Install Necessary Dependencies
  • Import required libraries
  • Download a subset of the LVIS dataset
  • Configure the Mask R-CNN MobileNet FPN COCO model for the custom dataset
  • Create the Task object (tfm.core.base_task.Task) from the config_definitions.TaskConfig

This tutorial fine-tunes a Mask R-CNN with MobileNet V2 as the backbone, taken from the TensorFlow Model Garden package (tensorflow-models).

Model Garden contains a collection of state-of-the-art models implemented with TensorFlow's high-level APIs. The implementations demonstrate best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.

This tutorial demonstrates how to:

  1. Use models from the TensorFlow Models package.
  2. Train/fine-tune a pre-built Mask R-CNN with MobileNet as the backbone for object detection and instance segmentation.
  3. Export the trained/tuned Mask R-CNN model.

Install Necessary Dependencies

pip install -U -q "tf-models-official"
pip install -U -q remotezip tqdm opencv-python einops

Import required libraries

import os
import io
import json
import tqdm
import shutil
import pprint
import pathlib
import tempfile
import requests
import collections
import matplotlib
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from PIL import Image
from six import BytesIO
from etils import epath
from IPython import display
from urllib.request import urlopen

2023-11-30 12:05:19.630836: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-30 12:05:19.630880: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-30 12:05:19.632442: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

import orbit
import tensorflow as tf
import tensorflow_models as tfm
import tensorflow_datasets as tfds

from official.core import exp_factory
from official.core import config_definitions as cfg
from official.vision.data import tfrecord_lib
from official.vision.serving import export_saved_model_lib
from official.vision.dataloaders.tf_example_decoder import TfExampleDecoder
from official.vision.utils.object_detection import visualization_utils
from official.vision.ops.preprocess_ops import normalize_image, resize_and_crop_image
from official.vision.data.create_coco_tf_record import coco_annotations_to_lists

pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation
print(tf.__version__) # Check the version of tensorflow used

%matplotlib inline

2.15.0

Download a subset of the LVIS dataset

LVIS: A dataset for large vocabulary instance segmentation.

:::tip
Note: LVIS uses the COCO 2017 train, validation, and test image sets. If you have already downloaded the COCO images, you only need to download the LVIS annotations. LVIS val set contains images from COCO 2017 train in addition to the COCO 2017 val split.

:::

# @title Download annotation files

wget https://dl.fbaipublicfiles.com/LVIS/lvis_v1_train.json.zip
unzip -q lvis_v1_train.json.zip
rm lvis_v1_train.json.zip

wget https://dl.fbaipublicfiles.com/LVIS/lvis_v1_val.json.zip
unzip -q lvis_v1_val.json.zip
rm lvis_v1_val.json.zip

wget https://dl.fbaipublicfiles.com/LVIS/lvis_v1_image_info_test_dev.json.zip
unzip -q lvis_v1_image_info_test_dev.json.zip
rm lvis_v1_image_info_test_dev.json.zip

--2023-11-30 12:05:23--  https://dl.fbaipublicfiles.com/LVIS/lvis_v1_train.json.zip
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.163.189.51, 3.163.189.108, 3.163.189.14, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.163.189.51|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 350264821 (334M) [application/zip]
Saving to: ‘lvis_v1_train.json.zip’

lvis_v1_train.json. 100%[===================>] 334.04M   295MB/s    in 1.1s    

2023-11-30 12:05:25 (295 MB/s) - ‘lvis_v1_train.json.zip’ saved [350264821/350264821]

--2023-11-30 12:05:34--  https://dl.fbaipublicfiles.com/LVIS/lvis_v1_val.json.zip
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.163.189.51, 3.163.189.108, 3.163.189.14, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.163.189.51|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 64026968 (61M) [application/zip]
Saving to: ‘lvis_v1_val.json.zip’

lvis_v1_val.json.zi 100%[===================>]  61.06M   184MB/s    in 0.3s    

2023-11-30 12:05:34 (184 MB/s) - ‘lvis_v1_val.json.zip’ saved [64026968/64026968]

--2023-11-30 12:05:36--  https://dl.fbaipublicfiles.com/LVIS/lvis_v1_image_info_test_dev.json.zip
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.163.189.51, 3.163.189.108, 3.163.189.14, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.163.189.51|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 384629 (376K) [application/zip]
Saving to: ‘lvis_v1_image_info_test_dev.json.zip’

lvis_v1_image_info_ 100%[===================>] 375.61K  --.-KB/s    in 0.03s   

2023-11-30 12:05:37 (12.3 MB/s) - ‘lvis_v1_image_info_test_dev.json.zip’ saved [384629/384629]

# @title LVIS annotation parsing

# Annotations with invalid bounding boxes. Will not be used.
_INVALID_ANNOTATIONS = [
    # Train split.
    662101,
    81217,
    462924,
    227817,
    29381,
    601484,
    412185,
    504667,
    572573,
    91937,
    239022,
    181534,
    101685,
    # Validation split.
    36668,
    57541,
    33126,
    10932,
]

def get_category_map(annotation_path, num_classes):
  with epath.Path(annotation_path).open() as f:
    data = json.load(f)

  category_map = {id + 1: {'id': cat_dict['id'], 'name': cat_dict['name']}
                  for id, cat_dict in enumerate(data['categories'][:num_classes])}
  return category_map

class LvisAnnotation:
  """LVIS annotation helper class.
  The format of the annations is explained on
  https://www.lvisdataset.org/dataset.
  """

  def __init__(self, annotation_path):
    with epath.Path(annotation_path).open() as f:
      data = json.load(f)
    self._data = data

    img_id2annotations = collections.defaultdict(list)
    for a in self._data.get('annotations', []):
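      # `category_ids` is a module-level list of class ids, populated below via get_category_map.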
      if a['category_id'] in category_ids:
        img_id2annotations[a['image_id']].append(a)
    self._img_id2annotations = {
        k: list(sorted(v, key=lambda a: a['id']))
        for k, v in img_id2annotations.items()
    }

  @property
  def categories(self):
    """Return the category dicts, as sorted in the file."""
    return self._data['categories']

  @property
  def images(self):
    """Return the image dicts, as sorted in the file."""
    sub_images = []
    for image_info in self._data['images']:
      if image_info['id'] in self._img_id2annotations:
        sub_images.append(image_info)
    return sub_images

  def get_annotations(self, img_id):
    """Return all annotations associated with the image id string."""
    # Some images don't have any annotations. Return empty list instead.
    return self._img_id2annotations.get(img_id, [])
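
# Usage sketch (hypothetical; note that LvisAnnotation reads the module-level
# `category_ids`, which is only populated further below):
#   val_ann = LvisAnnotation('./lvis_v1_val.json')
#   img = val_ann.images[0]
#   print(img['coco_url'], len(val_ann.get_annotations(img['id'])))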

def _generate_tf_records(prefix, images_zip, annotation_file, num_shards=5):
    """Generate TFRecords."""

    lvis_annotation = LvisAnnotation(annotation_file)

    def _process_example(prefix, image_info, id_to_name_map):
      # Search image dirs.
      filename = pathlib.Path(image_info['coco_url']).name
      image = tf.io.read_file(os.path.join(IMGS_DIR, filename))
      instances = lvis_annotation.get_annotations(img_id=image_info['id'])
      instances = [x for x in instances if x['id'] not in _INVALID_ANNOTATIONS]
      # print([x['category_id'] for x in instances])
      is_crowd = {'iscrowd': 0}
      instances = [dict(x, **is_crowd) for x in instances]
      neg_category_ids = image_info.get('neg_category_ids', [])
      not_exhaustive_category_ids = image_info.get(
          'not_exhaustive_category_ids', []
      )
      data, _ = coco_annotations_to_lists(instances,
                                          id_to_name_map,
                                          image_info['height'],
                                          image_info['width'],
                                          include_masks=True)
      # data['category_id'] = [id-1 for id in data['category_id']]
      keys_to_features = {
          'image/encoded':
              tfrecord_lib.convert_to_feature(image.numpy()),
          'image/filename':
               tfrecord_lib.convert_to_feature(filename.encode('utf8')),
          'image/format':
              tfrecord_lib.convert_to_feature('jpg'.encode('utf8')),
          'image/height':
              tfrecord_lib.convert_to_feature(image_info['height']),
          'image/width':
              tfrecord_lib.convert_to_feature(image_info['width']),
          'image/source_id':
              tfrecord_lib.convert_to_feature(str(image_info['id']).encode('utf8')),
          'image/object/bbox/xmin':
              tfrecord_lib.convert_to_feature(data['xmin']),
          'image/object/bbox/xmax':
              tfrecord_lib.convert_to_feature(data['xmax']),
          'image/object/bbox/ymin':
              tfrecord_lib.convert_to_feature(data['ymin']),
          'image/object/bbox/ymax':
              tfrecord_lib.convert_to_feature(data['ymax']),
          'image/object/class/text':
              tfrecord_lib.convert_to_feature(data['category_names']),
          'image/object/class/label':
              tfrecord_lib.convert_to_feature(data['category_id']),
          'image/object/is_crowd':
              tfrecord_lib.convert_to_feature(data['is_crowd']),
          'image/object/area':
              tfrecord_lib.convert_to_feature(data['area'], 'float_list'),
          'image/object/mask':
              tfrecord_lib.convert_to_feature(data['encoded_mask_png'])
      }
      # print(keys_to_features['image/object/class/label'])
      example = tf.train.Example(
          features=tf.train.Features(feature=keys_to_features))
      return example



    # file_names = [f"{prefix}/{pathlib.Path(image_info['coco_url']).name}"
    #               for image_info in lvis_annotation.images]
    # _extract_images(images_zip, file_names)
    writers = [
        tf.io.TFRecordWriter(
            tf_records_dir + prefix +'-%05d-of-%05d.tfrecord' % (i, num_shards))
        for i in range(num_shards)
    ]
    id_to_name_map = {cat_dict['id']: cat_dict['name']
                      for cat_dict in lvis_annotation.categories[:NUM_CLASSES]}
    # print(id_to_name_map)
    for idx, image_info in enumerate(tqdm.tqdm(lvis_annotation.images)):
      img_data = requests.get(image_info['coco_url'], stream=True).content
      img_name = image_info['coco_url'].split('/')[-1]
      with open(os.path.join(IMGS_DIR, img_name), 'wb') as handler:
          handler.write(img_data)
      tf_example = _process_example(prefix, image_info, id_to_name_map)
      writers[idx % num_shards].write(tf_example.SerializeToString())

    del lvis_annotation

_URLS = {
    'train_images': 'http://images.cocodataset.org/zips/train2017.zip',
    'validation_images': 'http://images.cocodataset.org/zips/val2017.zip',
    'test_images': 'http://images.cocodataset.org/zips/test2017.zip',
}

train_prefix = 'train'
valid_prefix = 'val'

train_annotation_path="./lvis_v1_train.json"
valid_annotation_path="./lvis_v1_val.json"

IMGS_DIR = './lvis_sub_dataset/'
tf_records_dir="./lvis_tfrecords/"


if not os.path.exists(IMGS_DIR):
  os.mkdir(IMGS_DIR)

if not os.path.exists(tf_records_dir):
  os.mkdir(tf_records_dir)



NUM_CLASSES = 3
category_index = get_category_map(valid_annotation_path, NUM_CLASSES)
category_ids = list(category_index.keys())

# The helper functions above are adapted from the TensorFlow Datasets LVIS builder:
# https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/datasets/lvis/lvis_dataset_builder.py
_generate_tf_records(train_prefix,
                     _URLS['train_images'],
                     train_annotation_path)

100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 2338/2338 [16:14<00:00,  2.40it/s]

_generate_tf_records(valid_prefix,
                     _URLS['validation_images'],
                     valid_annotation_path)

100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 422/422 [02:56<00:00,  2.40it/s]

Configure the Mask R-CNN MobileNet FPN COCO model for the custom dataset

train_data_input_path="./lvis_tfrecords/train*"
valid_data_input_path="./lvis_tfrecords/val*"
test_data_input_path="./lvis_tfrecords/test*"
model_dir="./trained_model/"
export_dir="./exported_model/"

if not os.path.exists(model_dir):
  os.mkdir(model_dir)

In Model Garden, the collections of parameters that define a model are called configs. Model Garden can create a config based on a known set of parameters via a factory.

Use the maskrcnn_mobilenet_coco experiment configuration, as defined by tfm.vision.configs.maskrcnn.maskrcnn_mobilenet_coco.

You can find all the registered experiments in the Model Garden's experiment factory (official.core.exp_factory).

The configuration defines an experiment to train a Mask R-CNN model with MobileNet as the backbone and FPN as the decoder. The default configuration is trained on COCO train2017 and evaluated on COCO val2017.

There are also other alternative experiments available, such as maskrcnn_resnetfpn_coco, maskrcnn_spinenet_coco, and more. You can switch to them by changing the experiment name argument passed to the get_exp_config function, as shown below.
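
For example, switching to the ResNet-FPN variant is a one-line change (a sketch; a different experiment also expects its own matching backbone checkpoint, so the init_checkpoint set below would need to change too):

# Hypothetical alternative experiment, shown commented out:
# exp_config = exp_factory.get_exp_config('maskrcnn_resnetfpn_coco')

This tutorial sticks with the MobileNet experiment: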

exp_config = exp_factory.get_exp_config('maskrcnn_mobilenet_coco')

model_ckpt_path="./model_ckpt/"
if not os.path.exists(model_ckpt_path):
  os.mkdir(model_ckpt_path)

!gsutil cp gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.data-00000-of-00001 './model_ckpt/'
!gsutil cp gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.index './model_ckpt/'

Copying gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.data-00000-of-00001...

Operation completed over 1 objects/26.9 MiB.                                     
Copying gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.index...

Operation completed over 1 objects/7.5 KiB.

Adjust the model and dataset configurations so that they work with the custom dataset.

BATCH_SIZE = 8
HEIGHT, WIDTH = 256, 256
IMG_SHAPE = [HEIGHT, WIDTH, 3]


# Backbone Config
exp_config.task.annotation_file = None
exp_config.task.freeze_backbone = True
exp_config.task.init_checkpoint = "./model_ckpt/ckpt-180648"
exp_config.task.init_checkpoint_modules = "backbone"

# Model Config
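# NUM_CLASSES + 1 accounts for the background class.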
exp_config.task.model.num_classes = NUM_CLASSES + 1
exp_config.task.model.input_size = IMG_SHAPE

# Training Data Config
exp_config.task.train_data.input_path = train_data_input_path
exp_config.task.train_data.dtype="float32"
exp_config.task.train_data.global_batch_size = BATCH_SIZE
exp_config.task.train_data.shuffle_buffer_size = 64
exp_config.task.train_data.parser.aug_scale_max = 1.0
exp_config.task.train_data.parser.aug_scale_min = 1.0

# Validation Data Config
exp_config.task.validation_data.input_path = valid_data_input_path
exp_config.task.validation_data.dtype="float32"
exp_config.task.validation_data.global_batch_size = BATCH_SIZE

Adjust the trainer configuration.

logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if 'GPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device="GPU"
elif 'TPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device="TPU"
else:
  print('Running on CPU is slow, so only train for a few steps.')
  device="CPU"


train_steps = 2000
exp_config.trainer.steps_per_loop = 200 # steps_per_loop = num_of_training_examples // train_batch_size

exp_config.trainer.summary_interval = 200
exp_config.trainer.checkpoint_interval = 200
exp_config.trainer.validation_interval = 200
exp_config.trainer.validation_steps =  200 # validation_steps = num_of_validation_examples // eval_batch_size
exp_config.trainer.train_steps = train_steps
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 200
exp_config.trainer.optimizer_config.learning_rate.type="cosine"
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.07
exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05

This may be broken in Colab.

Print the modified configuration.

pp.pprint(exp_config.as_dict())
display.Javascript("google.colab.output.setIframeHeight('500px');")

{   'runtime': {   'all_reduce_alg': None,
                   'batchnorm_spatial_persistent': False,
                   'dataset_num_private_threads': None,
                   'default_shard_dim': -1,
                   'distribution_strategy': 'mirrored',
                   'enable_xla': False,
                   'gpu_thread_mode': None,
                   'loss_scale': None,
                   'mixed_precision_dtype': 'bfloat16',
                   'num_cores_per_replica': 1,
                   'num_gpus': 0,
                   'num_packs': 1,
                   'per_gpu_thread_count': 0,
                   'run_eagerly': False,
                   'task_index': -1,
                   'tpu': None,
                   'tpu_enable_xla_dynamic_padder': None,
                   'use_tpu_mp_strategy': False,
                   'worker_hosts': None},
    'task': {   'allow_image_summary': False,
                'allowed_mask_class_ids': None,
                'annotation_file': None,
                'differential_privacy_config': None,
                'freeze_backbone': True,
                'init_checkpoint': './model_ckpt/ckpt-180648',
                'init_checkpoint_modules': 'backbone',
                'losses': {   'class_weights': None,
                              'frcnn_box_weight': 1.0,
                              'frcnn_class_loss_top_k_percent': 1.0,
                              'frcnn_class_use_binary_cross_entropy': False,
                              'frcnn_class_weight': 1.0,
                              'frcnn_huber_loss_delta': 1.0,
                              'l2_weight_decay': 4e-05,
                              'loss_weight': 1.0,
                              'mask_weight': 1.0,
                              'rpn_box_weight': 1.0,
                              'rpn_huber_loss_delta': 0.1111111111111111,
                              'rpn_score_weight': 1.0},
                'model': {   'anchor': {   'anchor_size': 3,
                                           'aspect_ratios': [0.5, 1.0, 2.0],
                                           'num_scales': 1},
                             'backbone': {   'mobilenet': {   'filter_size_scale': 1.0,
                                                              'model_id': 'MobileNetV2',
                                                              'output_intermediate_endpoints': False,
                                                              'output_stride': None,
                                                              'stochastic_depth_drop_rate': 0.0},
                                             'type': 'mobilenet'},
                             'decoder': {   'fpn': {   'fusion_type': 'sum',
                                                       'num_filters': 128,
                                                       'use_keras_layer': False,
                                                       'use_separable_conv': True},
                                            'type': 'fpn'},
                             'detection_generator': {   'apply_nms': True,
                                                        'max_num_detections': 100,
                                                        'nms_iou_threshold': 0.5,
                                                        'nms_version': 'v2',
                                                        'pre_nms_score_threshold': 0.05,
                                                        'pre_nms_top_k': 5000,
                                                        'soft_nms_sigma': None,
                                                        'use_cpu_nms': False,
                                                        'use_sigmoid_probability': False},
                             'detection_head': {   'cascade_class_ensemble': False,
                                                   'class_agnostic_bbox_pred': False,
                                                   'fc_dims': 512,
                                                   'num_convs': 4,
                                                   'num_fcs': 1,
                                                   'num_filters': 128,
                                                   'use_separable_conv': True},
                             'include_mask': True,
                             'input_size': [256, 256, 3],
                             'mask_head': {   'class_agnostic': False,
                                              'num_convs': 4,
                                              'num_filters': 128,
                                              'upsample_factor': 2,
                                              'use_separable_conv': True},
                             'mask_roi_aligner': {   'crop_size': 14,
                                                     'sample_offset': 0.5},
                             'mask_sampler': {'num_sampled_masks': 128},
                             'max_level': 6,
                             'min_level': 3,
                             'norm_activation': {   'activation': 'relu6',
                                                    'norm_epsilon': 0.001,
                                                    'norm_momentum': 0.99,
                                                    'use_sync_bn': True},
                             'num_classes': 4,
                             'outer_boxes_scale': 1.0,
                             'roi_aligner': {   'crop_size': 7,
                                                'sample_offset': 0.5},
                             'roi_generator': {   'nms_iou_threshold': 0.7,
                                                  'num_proposals': 1000,
                                                  'pre_nms_min_size_threshold': 0.0,
                                                  'pre_nms_score_threshold': 0.0,
                                                  'pre_nms_top_k': 2000,
                                                  'test_nms_iou_threshold': 0.7,
                                                  'test_num_proposals': 1000,
                                                  'test_pre_nms_min_size_threshold': 0.0,
                                                  'test_pre_nms_score_threshold': 0.0,
                                                  'test_pre_nms_top_k': 1000,
                                                  'use_batched_nms': False},
                             'roi_sampler': {   'background_iou_high_threshold': 0.5,
                                                'background_iou_low_threshold': 0.0,
                                                'cascade_iou_thresholds': None,
                                                'foreground_fraction': 0.25,
                                                'foreground_iou_threshold': 0.5,
                                                'mix_gt_boxes': True,
                                                'num_sampled_rois': 512},
                             'rpn_head': {   'num_convs': 1,
                                             'num_filters': 128,
                                             'use_separable_conv': True}},
                'name': None,
                'per_category_metrics': False,
                'train_data': {   'apply_tf_data_service_before_batching': False,
                                  'autotune_algorithm': None,
                                  'block_length': 1,
                                  'cache': False,
                                  'cycle_length': None,
                                  'decoder': {   'simple_decoder': {   'attribute_names': [],
                                                                       'mask_binarize_threshold': None,
                                                                       'regenerate_source_id': False},
                                                 'type': 'simple_decoder'},
                                  'deterministic': None,
                                  'drop_remainder': True,
                                  'dtype': 'float32',
                                  'enable_shared_tf_data_service_between_parallel_trainers': False,
                                  'enable_tf_data_service': False,
                                  'file_type': 'tfrecord',
                                  'global_batch_size': 8,
                                  'input_path': './lvis_tfrecords/train*',
                                  'is_training': True,
                                  'num_examples': -1,
                                  'parser': {   'aug_rand_hflip': True,
                                                'aug_rand_vflip': False,
                                                'aug_scale_max': 1.0,
                                                'aug_scale_min': 1.0,
                                                'aug_type': None,
                                                'mask_crop_size': 112,
                                                'match_threshold': 0.5,
                                                'max_num_instances': 100,
                                                'num_channels': 3,
                                                'pad': True,
                                                'rpn_batch_size_per_im': 256,
                                                'rpn_fg_fraction': 0.5,
                                                'rpn_match_threshold': 0.7,
                                                'rpn_unmatched_threshold': 0.3,
                                                'skip_crowd_during_training': True,
                                                'unmatched_threshold': 0.5},
                                  'prefetch_buffer_size': None,
                                  'seed': None,
                                  'sharding': True,
                                  'shuffle_buffer_size': 64,
                                  'tf_data_service_address': None,
                                  'tf_data_service_job_name': None,
                                  'tfds_as_supervised': False,
                                  'tfds_data_dir': '',
                                  'tfds_name': '',
                                  'tfds_skip_decoding_feature': '',
                                  'tfds_split': '',
                                  'trainer_id': None,
                                  'weights': None},
                'use_approx_instance_metrics': False,
                'use_coco_metrics': True,
                'use_wod_metrics': False,
                'validation_data': {   'apply_tf_data_service_before_batching': False,
                                       'autotune_algorithm': None,
                                       'block_length': 1,
                                       'cache': False,
                                       'cycle_length': None,
                                       'decoder': {   'simple_decoder': {   'attribute_names': [],
                                                                            'mask_binarize_threshold': None,
                                                                            'regenerate_source_id': False},
                                                      'type': 'simple_decoder'},
                                       'deterministic': None,
                                       'drop_remainder': False,
                                       'dtype': 'float32',
                                       'enable_shared_tf_data_service_between_parallel_trainers': False,
                                       'enable_tf_data_service': False,
                                       'file_type': 'tfrecord',
                                       'global_batch_size': 8,
                                       'input_path': './lvis_tfrecords/val*',
                                       'is_training': False,
                                       'num_examples': -1,
                                       'parser': {   'aug_rand_hflip': False,
                                                     'aug_rand_vflip': False,
                                                     'aug_scale_max': 1.0,
                                                     'aug_scale_min': 1.0,
                                                     'aug_type': None,
                                                     'mask_crop_size': 112,
                                                     'match_threshold': 0.5,
                                                     'max_num_instances': 100,
                                                     'num_channels': 3,
                                                     'pad': True,
                                                     'rpn_batch_size_per_im': 256,
                                                     'rpn_fg_fraction': 0.5,
                                                     'rpn_match_threshold': 0.7,
                                                     'rpn_unmatched_threshold': 0.3,
                                                     'skip_crowd_during_training': True,
                                                     'unmatched_threshold': 0.5},
                                       'prefetch_buffer_size': None,
                                       'seed': None,
                                       'sharding': True,
                                       'shuffle_buffer_size': 10000,
                                       'tf_data_service_address': None,
                                       'tf_data_service_job_name': None,
                                       'tfds_as_supervised': False,
                                       'tfds_data_dir': '',
                                       'tfds_name': '',
                                       'tfds_skip_decoding_feature': '',
                                       'tfds_split': '',
                                       'trainer_id': None,
                                       'weights': None}},
    'trainer': {   'allow_tpu_summary': False,
                   'best_checkpoint_eval_metric': '',
                   'best_checkpoint_export_subdir': '',
                   'best_checkpoint_metric_comp': 'higher',
                   'checkpoint_interval': 200,
                   'continuous_eval_timeout': 3600,
                   'eval_tf_function': True,
                   'eval_tf_while_loop': False,
                   'loss_upper_bound': 1000000.0,
                   'max_to_keep': 5,
                   'optimizer_config': {   'ema': None,
                                           'learning_rate': {   'cosine': {   'alpha': 0.0,
                                                                              'decay_steps': 2000,
                                                                              'initial_learning_rate': 0.07,
                                                                              'name': 'CosineDecay',
                                                                              'offset': 0},
                                                                'type': 'cosine'},
                                           'optimizer': {   'sgd': {   'clipnorm': None,
                                                                       'clipvalue': None,
                                                                       'decay': 0.0,
                                                                       'global_clipnorm': None,
                                                                       'momentum': 0.9,
                                                                       'name': 'SGD',
                                                                       'nesterov': False},
                                                            'type': 'sgd'},
                                           'warmup': {   'linear': {   'name': 'linear',
                                                                       'warmup_learning_rate': 0.05,
                                                                       'warmup_steps': 200},
                                                         'type': 'linear'}},
                   'preemption_on_demand_checkpoint': True,
                   'recovery_begin_steps': 0,
                   'recovery_max_trials': 0,
                   'steps_per_loop': 200,
                   'summary_interval': 200,
                   'train_steps': 2000,
                   'train_tf_function': True,
                   'train_tf_while_loop': True,
                   'validation_interval': 200,
                   'validation_steps': 200,
                   'validation_summary_subdir': 'validation'}}
<IPython.core.display.Javascript object>

Set up the distribution strategy.

# Setting up the Strategy
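# Note: mixed_precision_dtype is a string in this config ('bfloat16' above),
# so this tf.float16 comparison does not trigger here.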
if exp_config.runtime.mixed_precision_dtype == tf.float16:
    tf.keras.mixed_precision.set_global_policy('mixed_float16')

if 'GPU' in ''.join(logical_device_names):
  distribution_strategy = tf.distribute.MirroredStrategy()
elif 'TPU' in ''.join(logical_device_names):
  tf.tpu.experimental.initialize_tpu_system()
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')
  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
  print('Warning: this will be really slow.')
  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])

print("Done")

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
Done

Create the Task object (tfm.core.base_task.Task) from the config_definitions.TaskConfig.

The Task object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by tfm.core.train_lib.run_experiment.

with distribution_strategy.scope():
  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)
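
With the Task in place, training and evaluation are driven by tfm.core.train_lib.run_experiment, and the tuned model can then be exported with the export_saved_model_lib module imported earlier. Below is a minimal sketch of these remaining steps from the overview, assuming the objects built above (argument names follow the official.core.train_lib and official.vision.serving.export_saved_model_lib APIs and may vary across Model Garden versions):

# Run training and evaluation with the config, task, and strategy built above.
model, eval_logs = tfm.core.train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='train_and_eval',
    params=exp_config,
    model_dir=model_dir,
    run_post_eval=True)

# Export the tuned model as a SavedModel for inference.
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[HEIGHT, WIDTH],
    params=exp_config,
    checkpoint_path=tf.train.latest_checkpoint(model_dir),
    export_dir=export_dir)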

:::info
Originally published on the TensorFlow website, this article appears here under a new headline and is licensed under CC BY 4.0. Code samples shared under the Apache 2.0 License.

:::
