Machine Learning Design Patterns: Reusable Solutions for Complex Problems

Introduction

Design patterns are a formalized way to capture the knowledge and experience of experts, providing actionable advice that practitioners can use to solve recurring problems with proven approaches. Just as design patterns exist in traditional software development, microservices, APIs, and game development, the same concept applies to machine learning (ML).

The unique challenges in ML, such as data quality, reproducibility, data drift, and retraining, necessitate specific solutions. This article explores the concept of design patterns, their importance in machine learning, and provides concrete examples.

What are Design Patterns?

The idea of design patterns originated in architecture with Christopher Alexander et al., who documented over 200 architecture patterns in their book, A Pattern Language. A pattern describes a recurring problem in a discipline and offers a reusable solution as a "recipe." Each pattern has a name, allowing architects to communicate efficiently about problems and solutions. The solutions are abstract, enabling architects to adapt them to their specific contexts and environments.

Software Design Patterns

A software design pattern is a general, reusable solution to a common problem within a specific context. It's not a finished design that can be directly implemented in code but rather a template for solving a problem in various situations. Design patterns are based on software development best practices and serve as a shared vocabulary for developers to communicate their intentions.

Using design patterns enables developers to discuss problems and solutions at a high level without delving into implementation details. A team familiar with design patterns can quickly reach a consensus with minimal misunderstandings. Design patterns offer experience reusability at the design level, rather than code reusability at the source code level.


The Need for Machine Learning Design Patterns

ML, like other areas of computer science, started in an academic context where scalability, reliability, and performance were not primary concerns. However, deploying ML models in production is now an engineering discipline, requiring the application of software and data engineering best practices.

ML practitioners must leverage existing software engineering methods to solve recurring problems and develop new ML-specific design patterns. ML projects present unique challenges, such as data quality, concept drift, reproducibility, bias, and explainability, that influence solution design. Documenting these problems, their context, and solutions is crucial for knowledge transfer, communication, and democratizing the ML discipline.

The book Design Patterns: Elements of Reusable Object-Oriented Software is a seminal work in software design patterns, documenting them using a specific template. While ML patterns are still evolving, a universally accepted standard for documenting them is expected to emerge in the coming years.

Understanding Machine Learning Design Patterns

Machine learning design patterns encode the experience and knowledge of experts into actionable advice for practitioners. They capture best practices and solutions to common problems in designing, building, training, and deploying ML systems. These patterns complement traditional software development design patterns, extend the software engineering body of knowledge, and help avoid common pitfalls by using proven solutions.

Observer

The Observer pattern is a behavioral design pattern that defines a one-to-many dependency between objects, where a subject notifies all its observers about any state changes. This pattern is an excellent example of applying the principle of making the coupling between objects as weak as possible.


In the context of machine learning, the Observer pattern can be used to notify various components of a system when a model has been retrained or updated. For example, a model monitoring service could subscribe to notifications from a model training pipeline, allowing it to automatically update its performance metrics and detect concept drift.
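This retraining scenario can be sketched as a minimal Observer implementation. The class and method names (`TrainingPipeline`, `MonitoringService`, `subscribe`, `update`) are hypothetical, chosen only to illustrate the subject/observer relationship described above:

```python
class TrainingPipeline:
    """Subject: notifies subscribed observers when a model is retrained."""

    def __init__(self):
        self._observers = []

    def subscribe(self, observer):
        self._observers.append(observer)

    def retrain(self, model_version):
        # ... actual training logic would run here ...
        for observer in self._observers:
            observer.update(model_version)


class MonitoringService:
    """Observer: reacts to retraining events, e.g., by refreshing metrics."""

    def __init__(self):
        self.latest_version = None

    def update(self, model_version):
        self.latest_version = model_version  # refresh metrics, check drift, etc.


pipeline = TrainingPipeline()
monitor = MonitoringService()
pipeline.subscribe(monitor)
pipeline.retrain("v2")
print(monitor.latest_version)  # v2
```

The pipeline never references the monitoring service's concrete type; it only calls `update()`, which keeps the coupling weak.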

Model-View-Controller (MVC)

MVC is an architectural design pattern that divides an application into three interconnected parts: the model (data), the view (user interface), and the controller (logic). This separation promotes flexibility, reliability, and security.

In machine learning, the MVC pattern can be applied to build interactive tools for data exploration and model visualization. The model represents the underlying data and model predictions, the view displays the data and predictions to the user, and the controller handles user interactions and updates the model and view accordingly.

Factory

The Factory pattern decouples objects, such as training data, from how they are created. Creating these objects can sometimes be complex (e.g., distributed data loaders) and providing a base factory helps users by simplifying object creation and providing constraints that prevent mistakes.

The base factory can be defined via an interface or abstract class. Then, to create a new factory, we can subclass it and provide our own implementation details.
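The subclassing approach can be sketched as follows. The names (`DataLoaderFactory`, `CSVLoaderFactory`, `create_loader`) are hypothetical, and the loader itself is deliberately trivial; the point is that the abstract base class constrains what every concrete factory must provide:

```python
from abc import ABC, abstractmethod


class DataLoaderFactory(ABC):
    """Base factory: defines the interface every loader factory must satisfy."""

    @abstractmethod
    def create_loader(self, path):
        ...


class CSVLoaderFactory(DataLoaderFactory):
    """Concrete factory: supplies its own implementation details."""

    def create_loader(self, path):
        # Trivial stand-in loader: split comma-separated rows.
        def load(text):
            return [line.split(",") for line in text.strip().splitlines()]
        return load


factory = CSVLoaderFactory()
loader = factory.create_loader("train.csv")
rows = loader("a,b\n1,2")  # [['a', 'b'], ['1', '2']]
```

Client code depends only on the `DataLoaderFactory` interface, so swapping in, say, a distributed loader factory requires no changes downstream.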


Adapter

The adapter pattern increases compatibility between interfaces, such as data formats like CSV, Parquet, JSON, etc. This allows objects (e.g., stored data) with incompatible interfaces to collaborate.
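A minimal sketch of format adapters, using hypothetical class names: the client only ever calls `read_rows()`, while each adapter translates that call into whatever its wrapped source actually exposes:

```python
import csv
import io
import json


class CSVSource:
    def read_csv(self, text):
        return list(csv.DictReader(io.StringIO(text)))


class JSONSource:
    def read_json(self, text):
        return json.loads(text)


class CSVToRowsAdapter:
    """Adapts a CSV source to the row-of-dicts interface clients expect."""

    def __init__(self, source):
        self.source = source

    def read_rows(self, text):
        return self.source.read_csv(text)


class JSONToRowsAdapter:
    """Adapts a JSON source to the same interface."""

    def __init__(self, source):
        self.source = source

    def read_rows(self, text):
        return self.source.read_json(text)


# Client code is identical regardless of the underlying format.
csv_rows = CSVToRowsAdapter(CSVSource()).read_rows("x,y\n1,2")
json_rows = JSONToRowsAdapter(JSONSource()).read_rows('[{"x": "1", "y": "2"}]')
```

Both calls return the same shape of data, so downstream code stays format-agnostic.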

Decorator

The decorator pattern allows users to add functionality to existing code without modifying it. A common example is a memoization decorator that caches recent calls in a dictionary mapping input arguments to returned results, so repeated calls with the same arguments skip recomputation.
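A minimal memoizing decorator might look like this. The function name `featurize` is a hypothetical stand-in for an expensive transform:

```python
import functools


def memoize(func):
    """Decorator that caches results keyed by input arguments."""
    cache = {}

    @functools.wraps(func)
    def wrapper(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]

    wrapper.cache = cache  # exposed for inspection
    return wrapper


@memoize
def featurize(x):
    return x * 2  # stand-in for an expensive transformation


featurize(3)
featurize(3)  # second call is served from the cache
```

In practice, Python's standard library already offers this via `functools.lru_cache`.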

Strategy

The strategy pattern defines a family of algorithms, encapsulates each one, and makes them interchangeable, letting the algorithm vary independently from the clients that use it. Users create a new object for each strategy (aka algorithm), and the context's behavior varies at runtime depending on which strategy object is used. This decouples algorithms from clients, adding flexibility and reusability. It applies naturally to components such as a TuningStrategy or EvaluationStrategy.
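A minimal sketch of the pattern, with hypothetical strategy classes standing in for real tuning algorithms (the `tune` bodies are deliberately trivial):

```python
class GridSearch:
    def tune(self, params):
        return min(params)  # stand-in for exhaustive search


class RandomSearch:
    def tune(self, params):
        return params[0]  # stand-in for random sampling


class Tuner:
    """Context: its behavior varies with the injected strategy object."""

    def __init__(self, strategy):
        self.strategy = strategy

    def run(self, params):
        return self.strategy.tune(params)


best_grid = Tuner(GridSearch()).run([3, 1, 2])    # 1
best_random = Tuner(RandomSearch()).run([3, 1, 2])  # 3
```

Swapping tuning algorithms requires no change to `Tuner` or its callers, only a different strategy object.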

Iterator

The iterator pattern provides a way to go through objects in a collection of objects. This decouples the traversal algorithm from the data container and users can define their own algorithm. For example, if we have data in a tree, we might want to swap between breadth-first and depth-first traversal.
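The tree-traversal example can be sketched with two interchangeable generator-based iterators (the `Node` class is a hypothetical minimal container):

```python
from collections import deque


class Node:
    def __init__(self, value, children=()):
        self.value = value
        self.children = list(children)


def depth_first(root):
    """Yield values in depth-first order."""
    stack = [root]
    while stack:
        node = stack.pop()
        yield node.value
        stack.extend(reversed(node.children))


def breadth_first(root):
    """Yield values in breadth-first order."""
    queue = deque([root])
    while queue:
        node = queue.popleft()
        yield node.value
        queue.extend(node.children)


tree = Node("a", [Node("b", [Node("d")]), Node("c")])
dfs = list(depth_first(tree))    # ['a', 'b', 'd', 'c']
bfs = list(breadth_first(tree))  # ['a', 'b', 'c', 'd']
```

Client code that consumes the iterator is identical either way; only the traversal function is swapped.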

Pipeline

The pipeline pattern lets users chain a sequence of transformations. Transforms are steps to process the data, such as data cleaning, feature engineering, dimensionality reduction, etc. An estimator (aka machine learning model) can be added to the end of the pipeline. Thus, the pipeline takes data as input, transforms it, and trains a model at the end. IMHO, in ML pipelines, everything from data preparation to feature engineering to model hyperparameters is a parameter to tune. Being able to chain and vary transforms and estimators makes for convenient pipeline parameter tuning. How would model metrics change if we np.log continuous variables or add trigrams to our text features?

When using a pipeline, we should be aware that the input and output of each transform step should have a standardized format, such as a Pandas Dataframe. Also, each transform should have the required methods such as .fit() and .transform().
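A toy pipeline illustrating both points: every step exposes `.fit()` and `.transform()`, and data flows through the steps in a standardized format (here, a plain list of floats). The step classes are hypothetical stand-ins for real transforms; libraries such as scikit-learn provide a production-grade version of this pattern:

```python
class Center:
    """Subtract the mean (stand-in for a data-cleaning step)."""

    def fit(self, data):
        self.mean = sum(data) / len(data)

    def transform(self, data):
        return [x - self.mean for x in data]


class Scale:
    """Divide by the maximum absolute value (stand-in for feature scaling)."""

    def fit(self, data):
        self.scale = max(abs(x) for x in data) or 1.0

    def transform(self, data):
        return [x / self.scale for x in data]


class Pipeline:
    """Chains steps that each expose .fit() and .transform()."""

    def __init__(self, steps):
        self.steps = steps

    def fit_transform(self, data):
        for step in self.steps:
            step.fit(data)
            data = step.transform(data)
        return data


pipe = Pipeline([Center(), Scale()])
out = pipe.fit_transform([0.0, 5.0, 10.0])  # [-1.0, 0.0, 1.0]
```

Because every step honors the same interface, swapping or reordering transforms to tune the pipeline is just a matter of editing the `steps` list.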

Proxy

The proxy pattern allows us to provide a substitute for the production database or service. The proxy can then perform tasks on behalf of the endpoint.

One example is using a cache proxy to serve search queries. Search queries follow the Pareto principle where a fraction of queries account for the bulk of requests. Thus, by caching results for the top queries (e.g., query suggestion, query classification, retrieval and ranking), we can reduce the amount of compute required for real-time inference.
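The cache proxy idea can be sketched as follows, with hypothetical class names and a counter standing in for the expensive backend work:

```python
class SearchService:
    """The real endpoint: expensive retrieval and ranking."""

    def __init__(self):
        self.calls = 0

    def search(self, query):
        self.calls += 1  # each call here represents costly compute
        return f"results for {query}"


class CachingSearchProxy:
    """Proxy: serves repeated queries from a cache instead of the backend."""

    def __init__(self, backend):
        self.backend = backend
        self.cache = {}

    def search(self, query):
        if query not in self.cache:
            self.cache[query] = self.backend.search(query)
        return self.cache[query]


backend = SearchService()
proxy = CachingSearchProxy(backend)
proxy.search("shoes")
proxy.search("shoes")  # second call never reaches the backend
```

The proxy exposes the same `search()` interface as the real service, so clients need no changes; under a Pareto-shaped query distribution, the cache absorbs most of the load.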

Mediator

The mediator pattern provides a middleman that restricts direct communication between services, forcing services to collaborate—indirectly—via the mediator. This promotes loose coupling by preventing services from referring to one another explicitly and implementing to the other services' interface.

If there are mediators in our existing systems, we should expect to serve machine learning services through them instead of directly to downstream applications (e.g., e-commerce sites, trade orders). Mediators typically have standard requirements (e.g., request schemas) and constraints (e.g., latency, rate limiting) to adhere to.

Process Data Once

A key pattern when designing data pipelines is to process and aggregate raw data just once, preferably early on. This reduces redundancy and streamlines data processing jobs, making pipelines more efficient and maintainable.

However, it can be challenging to process and aggregate data in a way so it flexibly supports various use cases. Imagine we’re part of an e-commerce company that tracks user behavior via clickstream logs. The logs capture pages they visit, products they viewed, and actions they took (e.g., click, add-to-cart, purchase). To build a dashboard for the conversion funnel, we’ll want to group logs by session and aggregate metrics on page visits, click-through rate, add-to-cart rate, conversion rate, etc. You can imagine other teams wanting to do similar analysis on this data too. Thus, instead of multiple teams building duplicate pipelines to process the raw data, we can process it once and store it in a tabular format for everyone to use.
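The clickstream scenario above can be sketched as a single sessionization job whose output table any team can reuse. The log format and field names here are hypothetical:

```python
from collections import defaultdict

# Hypothetical raw clickstream logs: (session_id, action)
logs = [
    ("s1", "view"), ("s1", "add_to_cart"), ("s1", "purchase"),
    ("s2", "view"), ("s2", "view"),
    ("s3", "view"), ("s3", "add_to_cart"),
]


def sessionize(logs):
    """Aggregate raw logs once into a per-session table for shared use."""
    sessions = defaultdict(lambda: {"views": 0, "carts": 0, "purchases": 0})
    for session_id, action in logs:
        row = sessions[session_id]
        if action == "view":
            row["views"] += 1
        elif action == "add_to_cart":
            row["carts"] += 1
        elif action == "purchase":
            row["purchases"] += 1
    return dict(sessions)


table = sessionize(logs)
# Downstream teams compute their own metrics from the shared table.
conversion_rate = sum(r["purchases"] > 0 for r in table.values()) / len(table)
```

Each team queries the pre-aggregated table rather than re-processing raw logs, which is the "process once" payoff.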

Human-in-the-Loop (HITL)

If our ML system involves supervised learning, we can’t get away with not having labels. And if we don’t have labels—such as when we’re solving a new problem—we can apply HITL to collect them. We can collect labels explicitly (e.g., annotation) or implicitly (e.g., organic clicks, purchases).

Human annotators can deal with ambiguous or complex examples where heuristics or automated methods might fail. One way to collect labels is to ask users directly. For example, Stack Exchange lets users flag posts as spam and then uses these labels to train spam detection models.

Data Augmentation

Data augmentation involves applying transformations to existing data to create new, synthetic examples. This is valuable when ground truth data is limited, imbalanced, or lacks diversity.

The classic example of data augmentation is in computer vision. In PyTorch, for example, images can be augmented through geometric transforms such as cropping, rotation, and flipping, as well as color transforms such as grayscaling, blurring, and inversion.

Hard Mining

Hard mining can be considered an extension of data augmentation where we find or generate challenging data points to train the model on. One approach to hard mining is to analyze model predictions for misclassified or low-confidence examples, find similar examples (e.g., nearest neighbors), and emphasize them in subsequent training.

Reframe

To reframe is to make a problem easier to solve by refining the initial problem statement or transforming the target feature.

For Amazon’s semantic search, the team initially had a binary label: purchased or not. They used a two-part hinge loss where ŷ is the cosine similarity between query and product embeddings, and y = 1 if the product was purchased in response to the query and 0 otherwise. However, they found a large overlap in the score distributions of random negatives and purchased positives. After some analysis, they traced this to products that were impressed but not purchased. Thus, they updated their target label to distinguish impressed-but-not-purchased products from random negatives.

Cascade

The cascade pattern splits an initial problem into smaller problems that can be solved sequentially. Thus, each subsequent model or system focuses on an increasingly difficult or smaller subset of the data. This improves overall efficiency and system performance.

A classic example of a cascade is recommender system design, usually split into retrieval and ranking steps. Retrieval is a fast but coarse step to narrow down millions of items into hundreds of candidates. Ranking is a slower but more precise step to score the retrieval candidates.
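The two-stage split can be sketched with a toy catalog. The data and scoring functions are hypothetical: popularity stands in for a cheap retrieval signal, and a per-item relevance score stands in for an expensive ranking model:

```python
# Hypothetical catalog: item -> (popularity, relevance_score)
catalog = {
    "item_a": (100, 0.2),
    "item_b": (90, 0.9),
    "item_c": (5, 0.95),
    "item_d": (80, 0.1),
}


def retrieve(catalog, k=3):
    """Fast, coarse step: narrow the catalog using a cheap signal."""
    return sorted(catalog, key=lambda i: catalog[i][0], reverse=True)[:k]


def rank(candidates, catalog):
    """Slower, precise step: score only the retrieved candidates."""
    return sorted(candidates, key=lambda i: catalog[i][1], reverse=True)


candidates = retrieve(catalog)               # ['item_a', 'item_b', 'item_d']
recommendations = rank(candidates, catalog)  # ['item_b', 'item_a', 'item_d']
```

Note the tradeoff the cascade introduces: `item_c` has the highest relevance score but is pruned by the coarse retrieval step, so the expensive ranker never sees it.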

Data Flywheel

The data flywheel pattern revolves around continuously collecting data which then improves models which then improves user experience. This is one of the few sources of long-term competitive advantage.

A prime example of the data flywheel is the recommendation systems at Amazon and Netflix. When users look for their next show on Netflix, their searches, clicks, previews, watch time, and ratings are logged. This data is then used to build recommendation systems that learn from their preferences and serve personalized recommendations.

Business Rules

The business rules layer allows us to incorporate domain expertise and business rules to augment or adjust the output of ML systems. This ensures that machine learning outputs align with business requirements and constraints.

Evaluate Before Deploy

This pattern represents the best practice of evaluating model (e.g., evaluation metrics) and system (e.g., error rate, latency) performance before integrating it into production. To adopt this pattern, simply have a validation hold-out when (re)training models.

An Example ML Pattern: Rebalancing

A common problem in Machine Learning is imbalanced classes in classification or regression problems, which can impact the trained model's performance.

Problem

It's common to encounter classification problems where the dataset's classes are imbalanced, meaning the distribution of examples across the known classes is severely skewed. This presents a challenge because most classification algorithms assume a roughly equal number of examples per class. Consequently, the trained model may perform poorly on the minority classes, which are often the more "important" ones.

For instance, consider credit card fraud detection or a model to predict melanoma from images. The problem with imbalanced classes is "blindly" believing in accuracy values. If only 3% of a dataset contains melanoma images, a model might achieve 97% accuracy by simply guessing the majority class (no melanoma) for each example, without learning how to predict the minority class.
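The melanoma numbers above can be verified with a few lines of Python. A "model" that always predicts the majority class on a dataset with 3% positives scores 97% accuracy while catching zero positive cases:

```python
# 3 positives (melanoma), 97 negatives; the model always predicts 0.
labels = [1] * 3 + [0] * 97
predictions = [0] * 100

accuracy = sum(p == y for p, y in zip(predictions, labels)) / len(labels)
recall = sum(p == 1 and y == 1 for p, y in zip(predictions, labels)) / sum(labels)

print(accuracy)  # 0.97
print(recall)    # 0.0 -- not a single melanoma case is detected
```

This is exactly why accuracy alone cannot be trusted on imbalanced data.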

In regression modeling, imbalanced datasets occur when the data has outliers that are either much lower or higher than the median.

Solution

Since accuracy is affected by imbalanced classes, the first step is to choose the right metric to evaluate the model. Techniques can be applied at the dataset or model level. At the data level, downsampling or upsampling can be used. At the model level, the classification problem can be converted into a regression one.

Choosing an Evaluation Metric

For imbalanced datasets, metrics like precision, recall, or F-measure are preferable to accuracy.
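These metrics can be computed from true/false positives and negatives. A minimal sketch (the helper name and test labels are illustrative):

```python
def precision_recall_f1(y_true, y_pred):
    """Compute precision, recall, and F1 from binary labels."""
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# 2 true positives, 1 false negative, 1 false positive:
p, r, f = precision_recall_f1([1, 1, 1, 0, 0], [1, 1, 0, 1, 0])
# p = r = f = 2/3
```

Unlike accuracy, these metrics focus on the positive (minority) class, so the majority-class-guessing model scores 0 on all three.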

Definitions: