Copyright © All Rights Reserved. World of Software.

Fine-Tuning ResNet-18 with TensorFlow Model Garden for CIFAR-10 Classification | HackerNoon

News Room
Last updated: 2025/08/13 at 1:53 AM
News Room Published 13 August 2025

Content Overview

  • Setup
  • Configure the ResNet-18 model for the CIFAR-10 dataset
  • Visualize the training data
  • Visualize the testing data
  • Train and evaluate
  • Export a SavedModel

This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow Model Garden package (tensorflow-models) to classify images in the CIFAR-10 dataset.

Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow’s high-level APIs. The implementations demonstrate best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.

ResNet is a widely used family of image classifiers. This tutorial uses ResNet-18, a convolutional neural network with 18 layers.

This tutorial demonstrates how to:

  1. Use models from the TensorFlow Models package.
  2. Fine-tune a pre-built ResNet for image classification.
  3. Export the tuned ResNet model.

Setup

Install and import the necessary modules.

pip install -U -q "tf-models-official"

Import TensorFlow, TensorFlow Datasets, and a few helper libraries.

import pprint
import tempfile

from IPython import display
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds

The tensorflow_models package contains the ResNet vision model, and the official.vision.serving module contains the function to save and export the tuned model.

import tensorflow_models as tfm

# These are not in the tfm public API for v2.9. They will be available in v2.10
from official.vision.serving import export_saved_model_lib
import official.core.train_lib

Configure the ResNet-18 model for the CIFAR-10 dataset

The CIFAR-10 dataset contains 60,000 32x32 color images in 10 mutually exclusive classes, with 6,000 images per class: 50,000 for training and 10,000 for testing.

In Model Garden, the collections of parameters that define a model are called configs. Model Garden can create a config based on a known set of parameters via a factory.

Use the resnet_imagenet factory configuration, as defined by tfm.vision.configs.image_classification.image_classification_imagenet. The configuration is set up to train ResNet to converge on ImageNet.
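The factory pattern behind get_exp_config can be sketched in plain Python: a registry maps an experiment name to a function that builds a fresh, fully populated config, and you then override individual fields. The classes, the registry, and the default values below are hypothetical stand-ins chosen for illustration; Model Garden's real configs are far larger and are not reproduced here.

```python
from dataclasses import dataclass, field


# Hypothetical, heavily simplified config classes (illustrative defaults only).
@dataclass
class ModelConfig:
  num_classes: int = 1000
  input_size: list = field(default_factory=lambda: [224, 224, 3])
  model_id: int = 50


@dataclass
class ExperimentConfig:
  model: ModelConfig = field(default_factory=ModelConfig)
  train_steps: int = 100_000


# Hypothetical registry: experiment name -> function that builds its config.
_EXPERIMENTS = {}


def register_experiment(name):
  def decorator(factory):
    _EXPERIMENTS[name] = factory
    return factory
  return decorator


@register_experiment('resnet_imagenet')
def resnet_imagenet():
  return ExperimentConfig()


def get_exp_config(name):
  # Every call builds a new object, so overrides never leak between experiments.
  return _EXPERIMENTS[name]()


toy_config = get_exp_config('resnet_imagenet')
toy_config.model.num_classes = 10
toy_config.model.input_size = [32, 32, 3]
toy_config.model.model_id = 18
print(toy_config.model.num_classes, get_exp_config('resnet_imagenet').model.num_classes)  # 10 1000
```

Because each call returns a new object, editing the returned config in place only changes the copy you hold, not the registered defaults.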

exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')
tfds_name = 'cifar10'
ds, ds_info = tfds.load(
    tfds_name,
    with_info=True)
ds_info
tfds.core.DatasetInfo(
    name='cifar10',
    full_name='cifar10/3.0.2',
    description="""
    The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
    """,
    homepage='https://www.cs.toronto.edu/~kriz/cifar.html',
    data_dir='gs://tensorflow-datasets/datasets/cifar10/3.0.2',
    file_format=tfrecord,
    download_size=162.17 MiB,
    dataset_size=132.40 MiB,
    features=FeaturesDict({
        'id': Text(shape=(), dtype=string),
        'image': Image(shape=(32, 32, 3), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=50000, num_shards=1>,
    },
    citation="""@TECHREPORT{Krizhevsky09learningmultiple,
        author = {Alex Krizhevsky},
        title = {Learning multiple layers of features from tiny images},
        institution = {},
        year = {2009}
    }""",
)

Adjust the model and dataset configurations so that they work with CIFAR-10 (cifar10).

# Configure model
exp_config.task.model.num_classes = 10
exp_config.task.model.input_size = list(ds_info.features["image"].shape)
exp_config.task.model.backbone.resnet.model_id = 18

# Configure training and testing data
batch_size = 128

exp_config.task.train_data.input_path = ''
exp_config.task.train_data.tfds_name = tfds_name
exp_config.task.train_data.tfds_split = 'train'
exp_config.task.train_data.global_batch_size = batch_size

exp_config.task.validation_data.input_path = ''
exp_config.task.validation_data.tfds_name = tfds_name
exp_config.task.validation_data.tfds_split = 'test'
exp_config.task.validation_data.global_batch_size = batch_size

Adjust the trainer configuration.

logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if 'GPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'GPU'
elif 'TPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'TPU'
else:
  print('Running on CPU is slow, so only train for a few steps.')
  device = 'CPU'

if device == 'CPU':
  train_steps = 20
  exp_config.trainer.steps_per_loop = 5
else:
  train_steps = 5000
  exp_config.trainer.steps_per_loop = 100

exp_config.trainer.summary_interval = 100
exp_config.trainer.checkpoint_interval = train_steps
exp_config.trainer.validation_interval = 1000
exp_config.trainer.validation_steps = ds_info.splits['test'].num_examples // batch_size
exp_config.trainer.train_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100
Running on CPU is slow, so only train for a few steps.
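The step counts above translate into passes over the training set with a little arithmetic. This sketch assumes the 50,000-image train split and the batch size of 128 shown earlier, and that drop_remainder=True (the setting in the printed data configuration) discards the last partial batch.

```python
def schedule_summary(train_steps, steps_per_loop, num_examples=50_000, batch_size=128):
  """Relate trainer step settings to passes over the training set."""
  # With drop_remainder=True only full batches count: 50,000 // 128 == 390.
  steps_per_epoch = num_examples // batch_size
  return {
      'steps_per_epoch': steps_per_epoch,
      'epochs': train_steps / steps_per_epoch,
      'loops': train_steps // steps_per_loop,  # outer loops of steps_per_loop steps each
      'examples_seen': train_steps * batch_size,
  }


print(schedule_summary(train_steps=20, steps_per_loop=5))      # CPU settings
print(schedule_summary(train_steps=5000, steps_per_loop=100))  # GPU/TPU settings
```

With the CPU settings the model sees 2,560 images, about 5% of one epoch, which is why the CPU run is not expected to learn much; the GPU/TPU settings cover roughly 12.8 epochs.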

Print the modified configuration.

pprint.pprint(exp_config.as_dict())

display.Javascript("google.colab.output.setIframeHeight('300px');")
{'runtime': {'all_reduce_alg': None,
             'batchnorm_spatial_persistent': False,
             'dataset_num_private_threads': None,
             'default_shard_dim': -1,
             'distribution_strategy': 'mirrored',
             'enable_xla': True,
             'gpu_thread_mode': None,
             'loss_scale': None,
             'mixed_precision_dtype': None,
             'num_cores_per_replica': 1,
             'num_gpus': 0,
             'num_packs': 1,
             'per_gpu_thread_count': 0,
             'run_eagerly': False,
             'task_index': -1,
             'tpu': None,
             'tpu_enable_xla_dynamic_padder': None,
             'use_tpu_mp_strategy': False,
             'worker_hosts': None},
 'task': {'allow_image_summary': False,
          'differential_privacy_config': None,
          'eval_input_partition_dims': [],
          'evaluation': {'precision_and_recall_thresholds': None,
                         'report_per_class_precision_and_recall': False,
                         'top_k': 5},
          'freeze_backbone': False,
          'init_checkpoint': None,
          'init_checkpoint_modules': 'all',
          'losses': {'l2_weight_decay': 0.0001,
                     'label_smoothing': 0.0,
                     'loss_weight': 1.0,
                     'one_hot': True,
                     'soft_labels': False,
                     'use_binary_cross_entropy': False},
          'model': {'add_head_batch_norm': False,
                    'backbone': {'resnet': {'bn_trainable': True,
                                            'depth_multiplier': 1.0,
                                            'model_id': 18,
                                            'replace_stem_max_pool': False,
                                            'resnetd_shortcut': False,
                                            'scale_stem': True,
                                            'se_ratio': 0.0,
                                            'stem_type': 'v0',
                                            'stochastic_depth_drop_rate': 0.0},
                                 'type': 'resnet'},
                    'dropout_rate': 0.0,
                    'input_size': [32, 32, 3],
                    'kernel_initializer': 'random_uniform',
                    'norm_activation': {'activation': 'relu',
                                        'norm_epsilon': 1e-05,
                                        'norm_momentum': 0.9,
                                        'use_sync_bn': False},
                    'num_classes': 10,
                    'output_softmax': False},
          'model_output_keys': [],
          'name': None,
          'train_data': {'apply_tf_data_service_before_batching': False,
                         'aug_crop': True,
                         'aug_policy': None,
                         'aug_rand_hflip': True,
                         'aug_type': None,
                         'autotune_algorithm': None,
                         'block_length': 1,
                         'cache': False,
                         'center_crop_fraction': 0.875,
                         'color_jitter': 0.0,
                         'crop_area_range': (0.08, 1.0),
                         'cycle_length': 10,
                         'decode_jpeg_only': True,
                         'decoder': {'simple_decoder': {'attribute_names': [],
                                                        'mask_binarize_threshold': None,
                                                        'regenerate_source_id': False},
                                     'type': 'simple_decoder'},
                         'deterministic': None,
                         'drop_remainder': True,
                         'dtype': 'float32',
                         'enable_shared_tf_data_service_between_parallel_trainers': False,
                         'enable_tf_data_service': False,
                         'file_type': 'tfrecord',
                         'global_batch_size': 128,
                         'image_field_key': 'image/encoded',
                         'input_path': '',
                         'is_multilabel': False,
                         'is_training': True,
                         'label_field_key': 'image/class/label',
                         'mixup_and_cutmix': None,
                         'prefetch_buffer_size': None,
                         'randaug_magnitude': 10,
                         'random_erasing': None,
                         'repeated_augment': None,
                         'seed': None,
                         'sharding': True,
                         'shuffle_buffer_size': 10000,
                         'tf_data_service_address': None,
                         'tf_data_service_job_name': None,
                         'tf_resize_method': 'bilinear',
                         'tfds_as_supervised': False,
                         'tfds_data_dir': '',
                         'tfds_name': 'cifar10',
                         'tfds_skip_decoding_feature': '',
                         'tfds_split': 'train',
                         'three_augment': False,
                         'trainer_id': None,
                         'weights': None},
          'train_input_partition_dims': [],
          'validation_data': {'apply_tf_data_service_before_batching': False,
                              'aug_crop': True,
                              'aug_policy': None,
                              'aug_rand_hflip': True,
                              'aug_type': None,
                              'autotune_algorithm': None,
                              'block_length': 1,
                              'cache': False,
                              'center_crop_fraction': 0.875,
                              'color_jitter': 0.0,
                              'crop_area_range': (0.08, 1.0),
                              'cycle_length': 10,
                              'decode_jpeg_only': True,
                              'decoder': {'simple_decoder': {'attribute_names': [],
                                                             'mask_binarize_threshold': None,
                                                             'regenerate_source_id': False},
                                          'type': 'simple_decoder'},
                              'deterministic': None,
                              'drop_remainder': True,
                              'dtype': 'float32',
                              'enable_shared_tf_data_service_between_parallel_trainers': False,
                              'enable_tf_data_service': False,
                              'file_type': 'tfrecord',
                              'global_batch_size': 128,
                              'image_field_key': 'image/encoded',
                              'input_path': '',
                              'is_multilabel': False,
                              'is_training': False,
                              'label_field_key': 'image/class/label',
                              'mixup_and_cutmix': None,
                              'prefetch_buffer_size': None,
                              'randaug_magnitude': 10,
                              'random_erasing': None,
                              'repeated_augment': None,
                              'seed': None,
                              'sharding': True,
                              'shuffle_buffer_size': 10000,
                              'tf_data_service_address': None,
                              'tf_data_service_job_name': None,
                              'tf_resize_method': 'bilinear',
                              'tfds_as_supervised': False,
                              'tfds_data_dir': '',
                              'tfds_name': 'cifar10',
                              'tfds_skip_decoding_feature': '',
                              'tfds_split': 'test',
                              'three_augment': False,
                              'trainer_id': None,
                              'weights': None}},
 'trainer': {'allow_tpu_summary': False,
             'best_checkpoint_eval_metric': '',
             'best_checkpoint_export_subdir': '',
             'best_checkpoint_metric_comp': 'higher',
             'checkpoint_interval': 20,
             'continuous_eval_timeout': 3600,
             'eval_tf_function': True,
             'eval_tf_while_loop': False,
             'loss_upper_bound': 1000000.0,
             'max_to_keep': 5,
             'optimizer_config': {'ema': None,
                                  'learning_rate': {'cosine': {'alpha': 0.0,
                                                               'decay_steps': 20,
                                                               'initial_learning_rate': 0.1,
                                                               'name': 'CosineDecay',
                                                               'offset': 0},
                                                    'type': 'cosine'},
                                  'optimizer': {'sgd': {'clipnorm': None,
                                                        'clipvalue': None,
                                                        'decay': 0.0,
                                                        'global_clipnorm': None,
                                                        'momentum': 0.9,
                                                        'name': 'SGD',
                                                        'nesterov': False},
                                                'type': 'sgd'},
                                  'warmup': {'linear': {'name': 'linear',
                                                        'warmup_learning_rate': 0,
                                                        'warmup_steps': 100},
                                             'type': 'linear'}},
             'preemption_on_demand_checkpoint': True,
             'recovery_begin_steps': 0,
             'recovery_max_trials': 0,
             'steps_per_loop': 5,
             'summary_interval': 100,
             'train_steps': 20,
             'train_tf_function': True,
             'train_tf_while_loop': True,
             'validation_interval': 1000,
             'validation_steps': 78,
             'validation_summary_subdir': 'validation'}}

Set up the distribution strategy.

logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if exp_config.runtime.mixed_precision_dtype == tf.float16:
  tf.keras.mixed_precision.set_global_policy('mixed_float16')

if 'GPU' in ''.join(logical_device_names):
  distribution_strategy = tf.distribute.MirroredStrategy()
elif 'TPU' in ''.join(logical_device_names):
  tf.tpu.experimental.initialize_tpu_system()
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')
  # tf.distribute.TPUStrategy is the stable replacement for the deprecated experimental alias.
  distribution_strategy = tf.distribute.TPUStrategy(tpu)
else:
  print('Warning: this will be really slow.')
  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])
Warning: this will be really slow.

Create the Task object (tfm.core.base_task.Task) from the config_definitions.TaskConfig.

The Task object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by tfm.core.train_lib.run_experiment.
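This split of responsibilities, where the task knows how to build data and a model and how to run one step while a generic driver owns the loop, can be sketched in plain Python. ToyTask and run_experiment below are hypothetical, stdlib-only stand-ins whose method names only echo Model Garden's; the real Task works with Keras models, tf.data pipelines, metrics, and a distribution strategy, none of which is modeled here.

```python
import itertools


class ToyTask:
  """Hypothetical stand-in for a Task: builds data and a model, defines one step."""

  def build_model(self):
    return {'w': 0.0}  # a "model" with one trainable weight

  def build_inputs(self, training):
    # Endless stream of (x, y) pairs from y = 3x; validation uses a different x.
    xs = [1.0, 2.0, 3.0] if training else [4.0]
    return itertools.cycle([(x, 3.0 * x) for x in xs])

  def train_step(self, batch, model, learning_rate=0.05):
    x, y = batch
    error = model['w'] * x - y
    model['w'] -= learning_rate * 2.0 * error * x  # gradient of error**2 w.r.t. w
    return {'training_loss': error ** 2}

  def validation_step(self, batch, model):
    x, y = batch
    return {'validation_loss': (model['w'] * x - y) ** 2}


def run_experiment(task, train_steps, steps_per_loop, validation_steps):
  """Hypothetical driver: owns the loop and delegates everything else to the task."""
  model = task.build_model()
  train_iter = task.build_inputs(training=True)
  for loop in range(train_steps // steps_per_loop):
    for _ in range(steps_per_loop):
      logs = task.train_step(next(train_iter), model)
    print(f'train | step: {(loop + 1) * steps_per_loop:3d} | {logs}')
  val_iter = task.build_inputs(training=False)
  losses = [task.validation_step(next(val_iter), model)['validation_loss']
            for _ in range(validation_steps)]
  return model, {'validation_loss': sum(losses) / len(losses)}


trained, toy_eval_logs = run_experiment(
    ToyTask(), train_steps=20, steps_per_loop=5, validation_steps=3)
print(trained, toy_eval_logs)
```

Keeping the loop in a generic driver means the same driver can run any task that implements these methods.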

with distribution_strategy.scope():
  model_dir = tempfile.mkdtemp()
  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)

#  tf.keras.utils.plot_model(task.build_model(), show_shapes=True)
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  print()
  print(f'images.shape: {str(images.shape):16}  images.dtype: {images.dtype!r}')
  print(f'labels.shape: {str(labels.shape):16}  labels.dtype: {labels.dtype!r}')
images.shape: (128, 32, 32, 3)  images.dtype: tf.float32
labels.shape: (128,)            labels.dtype: tf.int32

Visualize the training data

The dataloader applies a z-score normalization using preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB), so the images returned by the dataset can’t be directly displayed by standard tools. The visualization code needs to rescale the data into the [0,1] range.

plt.hist(images.numpy().flatten());
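The normalization and the rescaling for display can be sketched without TensorFlow. The per-channel statistics below are assumed to be the standard ImageNet mean and standard deviation on the [0, 255] scale; check MEAN_RGB and STDDEV_RGB in Model Garden's preprocess_ops before relying on these exact values. The sketch also contrasts min-max rescaling, as the visualization code does, with an exact inverse of the normalization.

```python
# Assumed per-channel statistics (ImageNet mean/std scaled to [0, 255]).
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)


def normalize_pixel(rgb):
  """Z-score one RGB pixel per channel: (value - mean) / std."""
  return [(v - m) / s for v, m, s in zip(rgb, MEAN_RGB, STDDEV_RGB)]


def denormalize_pixel(z):
  """Exact inverse of normalize_pixel, divided by 255 to land in [0, 1]."""
  return [(v * s + m) / 255 for v, m, s in zip(z, MEAN_RGB, STDDEV_RGB)]


def minmax_rescale(values):
  """Stretch a flat list of floats to [0, 1] using its own min and max."""
  lo, hi = min(values), max(values)
  return [(v - lo) / (hi - lo) for v in values]


pixels = [(0, 0, 0), (255, 255, 255), (128, 64, 32)]
normalized = [c for p in pixels for c in normalize_pixel(p)]
print(min(normalized) < 0, max(normalized) > 1)  # True True: not a displayable range
display_values = minmax_rescale(normalized)
print(min(display_values), max(display_values))  # 0.0 1.0
print(denormalize_pixel(normalize_pixel((128, 64, 32))))  # ~[0.502, 0.251, 0.125]
```

Min-max rescaling depends on the extremes of whatever values it is given, so displayed colors are only approximately right; inverting with the known statistics recovers the original pixels, but only if you know the exact values the loader used.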

Use ds_info (which is an instance of tfds.core.DatasetInfo) to look up the text description of each class ID.

label_info = ds_info.features['label']
label_info.int2str(1)
'automobile'
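Conceptually, this lookup is an ordered list of class names indexed by label ID. The plain-Python stand-in below covers the int2str and str2int lookups; SimpleClassLabel is hypothetical, and the names are listed in the standard CIFAR-10 order, which agrees with int2str(1) returning 'automobile' above.

```python
# Standard CIFAR-10 class names, ordered by label ID 0-9.
CIFAR10_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                 'dog', 'frog', 'horse', 'ship', 'truck']


class SimpleClassLabel:
  """Hypothetical stand-in for the int2str/str2int lookups of a class-label feature."""

  def __init__(self, names):
    self.names = list(names)
    self._ids = {name: i for i, name in enumerate(self.names)}

  @property
  def num_classes(self):
    return len(self.names)

  def int2str(self, class_id):
    class_id = int(class_id)
    if not 0 <= class_id < self.num_classes:
      raise ValueError(f'class_id {class_id} out of range')
    return self.names[class_id]

  def str2int(self, name):
    return self._ids[name]


label_lookup = SimpleClassLabel(CIFAR10_NAMES)
print(label_lookup.int2str(1))        # automobile
print(label_lookup.str2int('truck'))  # 9
```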

Visualize a batch of the data.

def show_batch(images, labels, predictions=None):
  plt.figure(figsize=(10, 10))
  # Rescale to [0, 1] for display; normalized training images can't be shown directly.
  img_min = images.numpy().min()
  img_max = images.numpy().max()
  delta = img_max - img_min

  for i in range(12):
    plt.subplot(3, 4, i + 1)
    plt.imshow((images[i] - img_min) / delta)
    if predictions is None:
      plt.title(label_info.int2str(labels[i]))
    else:
      # Green title: correct prediction; red title: wrong prediction.
      color = 'g' if labels[i] == predictions[i] else 'r'
      plt.title(label_info.int2str(predictions[i]), color=color)
    plt.axis("off")

# show_batch creates its own figure, so no separate plt.figure() call is needed.
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  show_batch(images, labels)

Visualize the testing data

Visualize a batch of images from the validation dataset.

for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):
  show_batch(images, labels)

Train and evaluate

model, eval_logs = tfm.core.train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='train_and_eval',
    params=exp_config,
    model_dir=model_dir,
    run_post_eval=True)
restoring or initializing model...
INFO:tensorflow:Customized initialization is done through the passed `init_fn`.
train | step:      0 | training until step 20...
train | step:      5 | steps/sec:    0.5 | output: 
    {'accuracy': 0.103125,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.4828125,
     'training_loss': 2.7998607}
saved checkpoint to /tmpfs/tmp/tmpu0ate1h5/ckpt-5.
train | step:     10 | steps/sec:    0.8 | output: 
    {'accuracy': 0.0828125,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.4984375,
     'training_loss': 2.8205295}
train | step:     15 | steps/sec:    0.8 | output: 
    {'accuracy': 0.0921875,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.503125,
     'training_loss': 2.8169343}
train | step:     20 | steps/sec:    0.8 | output: 
    {'accuracy': 0.1015625,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.45,
     'training_loss': 2.8760865}
 eval | step:     20 | running 78 steps of evaluation...
 eval | step:     20 | steps/sec:   24.4 | eval time:    3.2 sec | output: 
    {'accuracy': 0.09485176,
     'steps_per_second': 24.40085348913806,
     'top_5_accuracy': 0.49589342,
     'validation_loss': 2.5864375}
saved checkpoint to /tmpfs/tmp/tmpu0ate1h5/ckpt-20.
eval | step:     20 | running 78 steps of evaluation...
eval | step:     20 | steps/sec:   40.1 | eval time:    1.9 sec | output: 
    {'accuracy': 0.09485176,
     'steps_per_second': 40.14298727815298,
     'top_5_accuracy': 0.49589342,
     'validation_loss': 2.5864375}
#  tf.keras.utils.plot_model(model, show_shapes=True)
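The training log above reports learning_rate: 0.0 at every step of the CPU run. A plausible explanation, under the assumption that the linear warmup ramps toward the value the cosine schedule has at step warmup_steps: the CPU configuration sets the cosine decay_steps to train_steps (20), far below warmup_steps (100), so that target is already 0 and the rate never rises. The sketch re-implements both schedules in plain Python to show the effect; it is not Model Garden's code.

```python
import math


def cosine_decay(step, initial_lr, decay_steps, alpha=0.0):
  """Cosine decay from initial_lr down to alpha * initial_lr, then constant."""
  step = min(step, decay_steps)
  cosine = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
  return initial_lr * ((1 - alpha) * cosine + alpha)


def linear_warmup(step, warmup_steps, warmup_lr, after_warmup):
  """Assumed behavior: ramp toward after_warmup(warmup_steps), then hand over."""
  if step < warmup_steps:
    target = after_warmup(warmup_steps)
    return warmup_lr + step / warmup_steps * (target - warmup_lr)
  return after_warmup(step)


def make_schedule(train_steps, warmup_steps=100, initial_lr=0.1):
  """Cosine decay over train_steps wrapped in a linear warmup from 0, as configured above."""
  def after_warmup(step):
    return cosine_decay(step, initial_lr, decay_steps=train_steps)

  def schedule(step):
    return linear_warmup(step, warmup_steps, 0.0, after_warmup)

  return schedule


cpu_lr = make_schedule(train_steps=20)
gpu_lr = make_schedule(train_steps=5000)
print([cpu_lr(s) for s in (5, 10, 15, 20)])  # [0.0, 0.0, 0.0, 0.0]
print([round(gpu_lr(s), 4) for s in (50, 100, 2500, 5000)])  # ramps up, peaks near 0.1, decays to 0
```

If this is the cause, keeping warmup_steps well below train_steps on short runs would let the learning rate rise above zero; the 5,000-step GPU/TPU settings are not affected.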

Print the evaluation metrics: accuracy, top_5_accuracy, validation_loss, and steps_per_second.

for key, value in eval_logs.items():
  if isinstance(value, tf.Tensor):
    value = value.numpy()
  print(f'{key:20}: {value:.3f}')
accuracy            : 0.095
top_5_accuracy      : 0.496
validation_loss     : 2.586
steps_per_second    : 40.143
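To put these numbers in context, compare them with uniform random guessing on a 10-class problem, which averages 1/10 top-1 and 5/10 top-5 accuracy. The sketch also checks how much of the 10,000-image test split the 78 validation steps cover, assuming drop_remainder=True as in the printed configuration. The reported values are copied from the CPU run above.

```python
num_classes = 10
test_examples = 10_000
eval_batch_size = 128

# Expected accuracy of uniform random guessing.
baseline = {'accuracy': 1 / num_classes, 'top_5_accuracy': 5 / num_classes}
# Metrics reported by the short CPU run above.
reported = {'accuracy': 0.095, 'top_5_accuracy': 0.496}

for key, value in reported.items():
  print(f'{key:15} reported {value:.3f} | random guessing {baseline[key]:.3f}')

# validation_steps = 10000 // 128; only full batches run, so a few images are skipped.
validation_steps = test_examples // eval_batch_size
evaluated = validation_steps * eval_batch_size
print(validation_steps, evaluated, test_examples - evaluated)  # 78 9984 16
```

Both metrics sit within about half a percentage point of chance, which matches the expectation that the short CPU run does no better than random.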

Run a batch of the processed training data through the model and view the results.

for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  predictions = model.predict(images)
  predictions = tf.argmax(predictions, axis=-1)

show_batch(images, labels, tf.cast(predictions, tf.int32))

if device=='CPU':
  plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')
4/4 [==============================] - 1s 13ms/step

Export a SavedModel

The keras.Model object returned by train_lib.run_experiment expects its inputs to be normalized the same way the dataset loader normalizes them, with preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB), which subtracts a per-channel mean and divides by a per-channel standard deviation. The export function export_saved_model_lib.export_inference_graph handles those details, so you can pass tf.uint8 images to the exported model and get the correct results.

# Saving and exporting the trained model
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[32, 32],
    params=exp_config,
    checkpoint_path=tf.train.latest_checkpoint(model_dir),
    export_dir='./export/')
INFO:tensorflow:Assets written to: ./export/assets

Test the exported model.

# Importing SavedModel
imported = tf.saved_model.load('./export/')
model_fn = imported.signatures['serving_default']
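model_fn accepts raw uint8 images because the export folds the normalization into the serving function. How export_inference_graph does this internally is not shown in this tutorial; the plain-Python sketch below only illustrates the general idea. make_serving_fn and stand_in_model are hypothetical, and the per-channel statistics are assumed ImageNet values.

```python
# Assumed per-channel statistics on the [0, 255] scale.
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)


def normalize(pixels):
  """Training-time preprocessing: per-channel z-score of RGB pixels."""
  return [[(v - m) / s for v, m, s in zip(p, MEAN_RGB, STDDEV_RGB)] for p in pixels]


def stand_in_model(normalized_pixels):
  """Hypothetical two-class 'network' that expects normalized input."""
  red = sum(p[0] for p in normalized_pixels)
  blue = sum(p[2] for p in normalized_pixels)
  return [red, blue]


def make_serving_fn(model):
  """Wrap the model so callers pass raw uint8 pixels and preprocessing happens inside."""
  def serving_fn(raw_pixels):
    if any(not 0 <= v <= 255 for p in raw_pixels for v in p):
      raise ValueError('expected uint8 pixel values in [0, 255]')
    return {'logits': model(normalize(raw_pixels))}
  return serving_fn


raw = [(255, 0, 0), (10, 200, 30)]
serve = make_serving_fn(stand_in_model)
print(serve(raw)['logits'] == stand_in_model(normalize(raw)))  # True
print(serve(raw)['logits'] == stand_in_model(raw))             # False
```

Shipping the preprocessing inside the exported function means every caller uses the statistics the model was trained with, instead of each client reimplementing them.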

Visualize the predictions.

for data in tfds.load('cifar10', split='test').batch(12).take(1):
  predictions = []
  for image in data['image']:
    index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]
    predictions.append(index)
  show_batch(data['image'], data['label'], predictions)

  if device=='CPU':
    plt.suptitle('The model was only trained for a few steps, it is not expected to do better than random.')

Originally published on the TensorFlow website, this article appears here under a new headline and is licensed under CC BY 4.0. Code samples shared under the Apache 2.0 License.
