Recently open-sourced by Google, Metrax is a JAX library providing standardized, performant metric implementations for classification, regression, NLP, vision, and audio models.
Metrax addresses a gap in the JAX ecosystem, Google explains, that has forced many teams migrating from TensorFlow to JAX to implement their own versions of common evaluation metrics such as accuracy, F1, RMS error, and others:
While creating metrics may seem, to some, like a fairly simple and straightforward topic, when considering large scale training and evaluation across datacenter-sized distributed compute environments, it becomes somewhat less trivial.
Metrax provides predefined evaluation metrics for a range of machine learning models, including classification, regression, recommendation, vision, and audio, with particular support for distributed and large-scale training environments. For vision models, the library includes metrics such as Intersection over Union (IoU), Signal-to-Noise Ratio (SNR), and Structural Similarity Index (SSIM). Metrax also includes robust NLP-related metrics, including Perplexity, BLEU, and ROUGE.
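For intuition about what one of these vision metrics computes, IoU for binary segmentation masks is simply the ratio of overlapping to combined foreground pixels. A minimal sketch of the definition (not the Metrax API, just the underlying math):

```python
import jax.numpy as jnp

# Toy binary segmentation masks (1 = foreground pixel).
pred_mask = jnp.array([[1, 1, 0],
                       [0, 1, 0]])
true_mask = jnp.array([[1, 0, 0],
                       [0, 1, 1]])

# IoU = |prediction AND truth| / |prediction OR truth|
intersection = jnp.sum((pred_mask == 1) & (true_mask == 1))
union = jnp.sum((pred_mask == 1) | (true_mask == 1))
iou = intersection / union  # 2 overlapping pixels out of 4 total -> 0.5
```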
Google notes that one of Metrax’s goals is to ensure that all metrics are well implemented and adhere to best practices. Where supported by the metric definition, Metrax uses advanced JAX features like vmap and jit to boost performance. For example, these features are used in the implementation of the new “at K” metrics to enable computing multiple values of K in parallel. This makes it possible to evaluate a model both more comprehensively and faster.
You can use PrecisionAtK to determine the precision of your model for multiple values of K (say, at K=1, K=8, and K=20), all in one forward pass through your model, rather than needing to call PrecisionAtK multiple times with each of these arguments.
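The multi-K idea can be sketched directly with jax.vmap: compute each item’s rank once, then vectorize a scalar precision@k function over an array of K values. This is an illustrative sketch of the technique, not Metrax’s actual implementation:

```python
import jax
import jax.numpy as jnp

def precision_at_k(scores, relevant, k):
    """Fraction of the top-k ranked items that are relevant."""
    order = jnp.argsort(-scores)   # indices sorted by descending score
    ranks = jnp.argsort(order)     # rank of each item (0 = best)
    in_top_k = ranks < k
    return jnp.sum(in_top_k & relevant) / k

scores = jnp.array([0.9, 0.1, 0.8, 0.4, 0.7])
relevant = jnp.array([True, False, False, True, True])

# Evaluate at K=1, 3, and 5 in one vectorized call.
ks = jnp.array([1, 3, 5])
precisions = jax.vmap(precision_at_k, in_axes=(None, None, 0))(scores, relevant, ks)
# -> [1.0, 0.667, 0.6]
```

Because the ranking is computed from the same model outputs, all K values share one forward pass; vmap turns the per-K loop into a single batched computation.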
A DevOps engineer writing on Substack under the name Neural Foundry wrote:
The fact that Metrax supports computing multiple K values in a single pass is a huge win for ranking systems. I’ve been rewriting metrics utilities every time I switch projects and this kind of standardization is long overdue. The API looks clean too. Curious if they’ve benchmarked it against custom implementations for specific use cases like large-scale recommendation pipelines.
The following snippet shows how to compute precision metrics given predictions and labels. An optional threshold can be specified for converting probability predictions into binary predictions:
import metrax
import jax.numpy as jnp

# Example predicted probabilities and ground-truth labels.
predictions = jnp.array([0.9, 0.2, 0.7, 0.4])
labels = jnp.array([1, 0, 0, 1])

# Directly compute the metric state.
metric_state = metrax.Precision.from_model_output(
    predictions=predictions,
    labels=labels,
    threshold=0.5,
)

# The result is then readily available by calling compute().
result = metric_state.compute()
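Under the hood, thresholded binary precision is just true positives divided by all positive predictions, TP / (TP + FP). As a sanity check on the semantics described above, the same value can be computed by hand; this sketch is independent of Metrax, and binary_precision is a hypothetical helper:

```python
import jax.numpy as jnp

def binary_precision(predictions, labels, threshold=0.5):
    # Threshold probabilities into hard 0/1 predictions.
    preds = (predictions >= threshold).astype(jnp.int32)
    tp = jnp.sum((preds == 1) & (labels == 1))  # true positives
    fp = jnp.sum((preds == 1) & (labels == 0))  # false positives
    # Guard against division by zero when there are no positive predictions.
    return jnp.where(tp + fp > 0, tp / (tp + fp), 0.0)

predictions = jnp.array([0.9, 0.2, 0.7, 0.4])
labels = jnp.array([1, 0, 0, 1])
precision = binary_precision(predictions, labels)  # 1 TP, 1 FP -> 0.5
```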
Google has also published a notebook containing a comprehensive set of examples, including multi-device scaling and integration with Flax NNX, a simplified API that makes it easier to create, inspect, debug, and analyze neural networks in JAX.
JAX is an open-source Python library for high-performance numerical computation and machine learning. While offering a
