Key Takeaways
- Federated machine learning is useful for edge devices with limited network bandwidth, since only model updates need to be sent to a central location, instead of large volumes of data
- Federated machine learning can improve data privacy, but it does not guarantee it
- Learning on edge devices can improve data diversity and enable model personalization
- Hosting models on the edge device allows for predictions even when the device is no longer connected
- Edge devices represent an opportunity to gather targeted data to help us answer specific questions
At QCon Plus 2021, a virtual conference for senior software engineers and architects covering the trends, best practices, and solutions leveraged by the world's most innovative software organizations, Katharine Jarmul spoke about machine learning on edge devices using federated machine learning.
Introduction
When we think about the edge, we typically think about small embedded devices: Internet of Things (IoT) or other types of devices that might have a small computer in them. This could even include things that aren’t obvious. For example, I recently learned that these little scooters that are all over my city in Berlin, Germany, and maybe yours as well, are collecting quite a lot of data and uploading it.
Suppose we were designing these scooters and specifying the data they could collect. We put on our data science and machine learning hats and think about the problems that we might want to solve. For example, we might want to know which scooters need maintenance. We might want to know about road and weather conditions. We might want to know about driver performance: could it result in some problem for the scooter, or for the driver, or for the other people and things around the scooter?
These are the types of questions we ask when we think about data and machine learning. When we think about implementing it on the edge, or with embedded small systems, we can encounter problems, because traditional machine learning needs extra information to answer these questions.
Traditional ML
Let's take a look at a traditional machine learning system and investigate how we could collect this data and answer questions. First, all the data would have to be aggregated into a data lake. It might need to be standardized, cleaned, or pre-processed. Eventually, that data would be handled by a data science team, who might do exploratory data analysis (EDA).
Once we have decided on the data preparation steps, we could automate them with a pipeline. Then that data would feed into model training, testing, validation, and selection. Ultimately, what we're trying to do as a data team is figure out what's the best way to solve a problem with the machine learning tools and the data we have available.
This is an iterative process. We might decide at the end that we have one or two models that we'd like to try, and we're ready to deploy one or more of them. At that point in time, we would deploy them, usually to a cloud service or other API that is reachable by the system that’s going to use the machine learning. That model could even be deployed directly to the edge hardware.
Traditional ML: Limitations
We can immediately start to see some limitations in this approach. First off: we need the data collected and sent over the network. The network connection the scooters have is probably not stable, and doesn't have a lot of upload bandwidth. For this reason, data collection is going to be a problem. Then we have to ingest the data into a data lake and aggregate across many locations. Then we can begin to train the model.
When we think about things like mobile devices, small embedded systems, or robots in factories, we can see some other problems: not only from a connectivity point of view, but also from a memory, storage, and compute power perspective. In addition to collecting and transmitting data, the devices will also have to interact in real time with a cloud API to make a decision if we want to use a recently updated model. These are some concerns and limitations that we can immediately see.
What Is Federated Learning?
How do we solve these problems? One approach is to use federated learning. Federated learning was initially developed by Google, as a solution for using machine learning on Android devices to solve some similar problems.
Let's take an easy example: suppose we want to build a better predictive text model for Spanish speakers by using keyboard data from mobile devices. The federated learning lifecycle is shown in the figure below. We start with a trained model, represented by the blue circle. We will send it out to all the Android devices that meet some selection criteria. First off, we only want Spanish language keyboards, so that's going to rule out some devices. Then we want the process to complete successfully, so that means that we want devices that have high battery power or that are charging. We want devices with a good mobile or internet connection.
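The selection criteria above can be sketched as a simple filter. This is only an illustration: the `Device` fields, the `is_eligible` function, and the 80% battery threshold are assumptions made for the example, not details of Google's actual system.

```python
from dataclasses import dataclass


@dataclass
class Device:
    # All field names are hypothetical, chosen for this example.
    device_id: str
    keyboard_language: str     # e.g. "es" for a Spanish keyboard
    battery_level: float       # 0.0 (empty) to 1.0 (full)
    is_charging: bool
    has_good_connection: bool


def is_eligible(device: Device, language: str = "es", min_battery: float = 0.8) -> bool:
    """Return True if the device meets the selection criteria for a round."""
    has_power = device.is_charging or device.battery_level >= min_battery
    return (
        device.keyboard_language == language
        and has_power
        and device.has_good_connection
    )


devices = [
    Device("a", "es", 0.95, False, True),   # eligible: high battery
    Device("b", "es", 0.20, True, True),    # eligible: charging
    Device("c", "en", 0.90, True, True),    # excluded: not a Spanish keyboard
    Device("d", "es", 0.30, False, True),   # excluded: low battery, not charging
    Device("e", "es", 0.90, False, False),  # excluded: poor connection
]

selected = [d.device_id for d in devices if is_eligible(d)]
print(selected)  # ['a', 'b']
```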
In step A, the model is sent to those phones and machine learning happens on the devices. Each device is essentially doing part of the training (this can be an iterative process or a one-shot process) with a model update at the end. That update is not the data itself, but instead a representation, usually in some vectorized form, that indicates the direction that the machine learning model should go to improve, based on the data on this device.
All of those updates across all of those different devices are collected, as shown in step B, and sent to a centralized aggregator, which then uses some computation to combine them. This can be averaging, summation, or some other computation, depending on the algorithm and the architecture of the model.
Finally, in step C the updated model is sent back out to all the devices, completing the cycle. Presumably, another round could start, or if we determine that the model is good enough, we stop here.
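To make steps A through C concrete, here is a minimal sketch of the cycle using a toy model with a single weight `w` that fits `y = w * x`. The function names, the learning rate, the stopping tolerance, and the choice of a weighted average as the combining step are all assumptions for illustration; a real system would use a federated learning framework and a far larger model.

```python
def local_update(global_w, data, lr=0.01, epochs=5):
    """Step A: train on the device and return (weight delta, example count).

    Only the delta leaves the device; the raw (x, y) pairs do not.
    """
    w = global_w
    for _ in range(epochs):
        # Gradient of the mean squared error for the model y = w * x.
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w - global_w, len(data)


def aggregate(updates):
    """Step B: combine deltas, weighting each device by its example count."""
    total = sum(n for _, n in updates)
    return sum(delta * n for delta, n in updates) / total


def run_rounds(device_data, global_w=0.0, max_rounds=50, tol=1e-4):
    """Repeat rounds until the aggregated step is small or rounds run out."""
    rounds_run = 0
    for _ in range(max_rounds):
        rounds_run += 1
        updates = [local_update(global_w, data) for data in device_data]
        step = aggregate(updates)
        global_w += step  # Step C: the updated model goes back to the devices.
        if abs(step) < tol:
            break  # The model is "good enough", so we stop here.
    return global_w, rounds_run


# Hypothetical on-device datasets; every point lies on y = 3 * x.
device_data = [
    [(1.0, 3.0), (2.0, 6.0)],
    [(3.0, 9.0)],
    [(1.5, 4.5), (0.5, 1.5), (2.5, 7.5)],
]

final_w, rounds_run = run_rounds(device_data)
print(final_w, rounds_run)
```

In this sketch the global weight moves toward 3 over successive rounds, even though the aggregator never sees any of the data points.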
Federated Learning Architectures
This is the classic federated learning architecture, where we have one centralized aggregator with distributed participants. However, there are other approaches, such as clustering participants together and deploying multiple aggregators. Depending on your current infrastructure, these might be better methods, and the data science and machine learning teams and your infrastructure teams should work together to determine the best architecture for your needs.
A fully-distributed architecture might be appropriate for scenarios like multi-party computation or other forms of distributed encrypted computations that allow multiple participants to do the machine learning on their local devices. These devices must coordinate with one another to combine their updates and to send the result out in a meaningful manner. If you have exposure to cryptographic protocols and communities around multi-party computation, or other types of multi-party approaches, this might be an interesting way to think about the problem. Of course, this architecture also adds complexity for coordinating all the different devices.
Use Case: FLoCs
It helps to look at a practical use case. We're going to look at Federated Learning of Cohorts, or FLoCs, also developed by Google. It's essentially a proposal to do away with third-party cookies, because they're awful and nobody likes them, and they are a privacy nightmare. Many browsers are removing functionality for third-party cookies or automatically blocking third-party cookies. But what should we use in order to do targeted, personalized advertising if we don't have third-party cookies?
That's what Google proposed FLoCs would do. Their idea for FLoCs is that you get assigned a cohort based on something you like, your browsing history, and so on. In the diagram below, we have two different cohorts: a group that likes plums and a group that likes oranges. Perhaps, if you were a fruit seller, you might want to target the plum cohort with plum ads and the orange cohort with orange ads, for example. The goal was to resolve the privacy problems and the poor user experience of online targeted ads, where sometimes a user would click on something and it would follow them for days.
Use Case: FLoCs - How it Works
Let's take a look at how it works, and identify the federated learning part. The first thing that happens when a user opens up a FLoC-enabled Chrome browser, is they get assigned a cohort if they don't already have one. The details weren't shared, but it's probably an iterative process. It's likely that a user’s cohort will change over time, because, obviously, their interests and browsing patterns change over time. Let’s say, in the “plum” cohort, federated learning finds that five different users visited plums.com today. That small piece of information could be sent as an update to the aggregator, which could then update the targeted ad model, and the process continues.
As you can see, this is not a typical federated learning process because the ad targeting model is likely not going to be shipped out and running on all of our devices. This is a hybrid model where some parts are centralized and some are federated. The cohort model is shipped out, and so it will probably be improved across cohorts over time, and those updates will be distributed across devices.
Does this address the initial privacy and user experience problems related to third-party cookies? It will probably lead to better ads, because it's an iterative process, and a user is not forever in a bucket based on something that they liked five weeks ago. However, from a privacy perspective, does it guarantee something? Is membership in the “plums” cohort private to a user, or can this be public in some way? This is a higher level ethical and philosophical question. It's interesting to think about: what are the qualities that users would want representing their cohorts? Google already said that medical history and sexual preferences would not be included, but are those the only categories?
Use Case: FLoCs - Privacy Attacks
The Electronic Frontier Foundation and several prominent technologists have critiqued the FLoC approach, because it's fairly easy to discern potential privacy attacks. For example, someone using FLoC arrives at our website, and we present them with a login screen. Now we have both their cohort and their login. If we use browser fingerprinting, now we have the browser fingerprint and their cohort. If cohorts are only a few thousand people, these scenarios mean it might be fairly easy to deanonymize or re-identify those individuals.
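A small sketch shows why combining a cohort with a second signal is risky. The visitor records and fingerprint labels below are made up for illustration; the point is only that each extra attribute shrinks the group of people a visitor can hide in.

```python
from collections import Counter

# Hypothetical site visitors: (cohort, browser fingerprint).
visitors = [
    ("plums", "fp-1"), ("plums", "fp-1"), ("plums", "fp-2"),
    ("oranges", "fp-1"), ("oranges", "fp-3"), ("plums", "fp-3"),
]

# Group sizes when the site only knows the cohort.
by_cohort = Counter(cohort for cohort, _ in visitors)

# Group sizes when the site knows the cohort and the fingerprint.
by_both = Counter(visitors)

# Combinations seen exactly once identify a single visitor.
unique = [combo for combo, count in by_both.items() if count == 1]

print(by_cohort)
print(unique)
```

With the cohort alone, every visitor shares a group with at least one other person. Once the fingerprint is added, four of the six visitors are the only member of their group.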
Large databases of these associations between FLoCs and other data could be collected by single businesses that have a large web presence, or could be aggregated across many partners and properties. This might even allow the reverse engineering of cohort selection, or even who those individuals are from the cohorts. FLoCs are not a silver bullet for privacy.
Federated Learning: Benefits and Weaknesses
What are the benefits and weaknesses of federated learning? One big benefit is that there's no centralized data collection, so we don't have to send and centralize the data all in one place. Instead, the data remains on the device, and we can use federated learning to get small updates to send to the aggregators. A weakness of this is that the data must be standardized on the device. With traditional machine learning, we had a large pipeline. We could change and clean and segregate data in different ways. When it's on the device, it needs to be ready for machine learning, or it needs to be quickly ready for machine learning. This is great if we already standardized data on our devices, but it might be hard if we haven't thought about that problem yet.
The second benefit is more diverse datasets from many edge devices. That's one of the reasons why Google deployed federated learning: how do we get better targeted improvements for particular types of machine learning that we're trying to do? However, with more diverse data from more varied users or populations, we also have to think about the fact that the data is unevenly distributed.
Let’s go back to the scooters. If I'm riding a scooter in Berlin, and you're riding a scooter in a different city with a very different commute, then our updates might actually cancel each other out. The data science team would have to consider: are we going to be able to sub-select populations that allow us to diversify our data without creating unnecessary bias due to the unevenly distributed data that's available?
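The cancellation effect is easy to see with numbers. In this sketch, the update vectors for the two cities are invented for illustration, and plain averaging stands in for whatever combining step the aggregator actually uses.

```python
def average(updates):
    """Element-wise mean of equally weighted update vectors."""
    return [sum(values) / len(updates) for values in zip(*updates)]


# Hypothetical model updates from scooters in two cities.
berlin = [[0.4, -0.2], [0.6, -0.4]]
other_city = [[-0.5, 0.3], [-0.5, 0.3]]

# Pooling everyone: the opposing updates roughly cancel out.
pooled = average(berlin + other_city)

# Aggregating per population keeps each city's signal.
per_city = {"berlin": average(berlin), "other_city": average(other_city)}

print(pooled)
print(per_city)
```

The pooled average is close to zero in both dimensions, so the global model barely moves, while the per-city averages still carry a clear direction. Sub-selecting populations is one way to keep that signal, at the cost of having less data per group.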
Third, a benefit that is often undervalued is on-device machine learning. This gives us the ability to answer questions in real time on the device or to personalize a model. A weakness is that all the edge devices must have a copy of the model to operate on. If there is proprietary or confidential information in the model, then we should think about the appropriate security concerns there.
Finally, privacy is typically touted as a benefit of federated learning, and it definitely is. It's better to send a small update to a centralized location than it is to send the entire browsing history. However, those small updates can still leak information. There are several other types of attacks on user privacy like the one that we just talked about for FLoCs.
The benefits and guarantees of privacy are very dependent on the implementation. If we are really choosing federated learning for privacy benefits, then we should think about also employing techniques such as secure computation of the aggregates, or adding differential privacy and gradient clipping.
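As a rough sketch of two of these techniques, the code below clips each update to a maximum L2 norm and adds Gaussian noise to the sum before averaging. The function names and parameter values are assumptions, and the noise level here is not calibrated to any formal differential privacy guarantee; choosing it properly requires a privacy analysis that is out of scope for this example.

```python
import math
import random


def clip_update(update, max_norm):
    """Scale the update down so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(v * v for v in update))
    if norm <= max_norm:
        return list(update)
    scale = max_norm / norm
    return [v * scale for v in update]


def private_average(updates, max_norm=1.0, noise_std=0.1, rng=None):
    """Clip every update, add Gaussian noise to the sum, then average.

    noise_std is an illustrative value, not a calibrated privacy parameter.
    """
    if rng is None:
        rng = random.Random()
    clipped = [clip_update(u, max_norm) for u in updates]
    sums = [sum(values) for values in zip(*clipped)]
    # Noise is scaled by max_norm, the most any one device can contribute.
    noisy = [s + rng.gauss(0.0, noise_std * max_norm) for s in sums]
    return [v / len(clipped) for v in noisy]


# Hypothetical updates; the first has norm 5 and is clipped down to norm 1.
updates = [[3.0, 4.0], [0.1, -0.2], [0.0, 0.5]]
print(private_average(updates, rng=random.Random(0)))
```

Clipping bounds how much any single device can influence the result, which is what makes the added noise meaningful.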
Federated Learning: When Does It Make Sense?
When should we use federated machine learning? The flowchart below can help us decide that. First, if traditional machine learning works for an application, then it probably makes sense to use that. Federated learning has a lot of benefits, but it's also still quite new.
If traditional machine learning doesn't work, the next two questions are: is our data standardized? Can the devices run federated learning software? If not, it might be better to simply try distributed data analysis first, which simply queries the data on the devices, sending just the aggregated results back to a central location. Then our data team can take a look at it and start answering the questions. Eventually, this could lead to rolling out a fully federated system.
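Distributed data analysis can be as simple as each device reducing its raw readings to a small summary. The sketch below assumes a hypothetical question, the average battery temperature across scooters, and made-up readings; the function names are illustrative.

```python
def local_summary(readings):
    """Runs on the device: reduce raw readings to a count and a sum."""
    return {"count": len(readings), "total": sum(readings)}


def combine(summaries):
    """Runs centrally: compute the global mean from per-device summaries."""
    count = sum(s["count"] for s in summaries)
    total = sum(s["total"] for s in summaries)
    return total / count if count else None


# Hypothetical battery temperatures (in Celsius) recorded on three scooters.
scooter_readings = [[20.0, 22.0], [30.0], [25.0, 25.0, 28.0]]

# Only these small summaries are sent over the network.
summaries = [local_summary(r) for r in scooter_readings]
global_mean = combine(summaries)
print(global_mean)  # 25.0
```

The central location gets the answer to its question without ever receiving the individual readings.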
If the data is standardized, and the devices can run the software, then the final decision is on the federated architecture: centralized or clustered? A centralized aggregation is a good default choice, as this will have less complexity and be simpler to start with. However, if our system is already designed in a clustered architecture, maybe a clustered federated learning architecture is a better choice.
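The decision process described above can be written down as a small function. The parameter names and return strings are invented for this sketch, and a real decision would of course weigh more factors than four yes-or-no answers.

```python
def recommend_approach(
    traditional_ml_works: bool,
    data_standardized: bool,
    devices_can_run_fl: bool,
    system_is_clustered: bool,
) -> str:
    """Walk the decision steps in order and return a suggested starting point."""
    if traditional_ml_works:
        return "traditional machine learning"
    if not (data_standardized and devices_can_run_fl):
        # Start by querying devices and sending back aggregated results.
        return "distributed data analysis"
    if system_is_clustered:
        return "clustered federated learning"
    # Centralized aggregation is the simpler default.
    return "centralized federated learning"


print(recommend_approach(False, True, True, False))
# centralized federated learning
```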
Some other factors to consider are the privacy requirements and guarantees, which, again, are not automatically built into federated learning. We should discuss these with privacy and security experts on our team. There are also connectivity requirements: what happens when a device is lost? Is there a tipping point, a certain number of lost devices, at which we want to halt the learning round and stop the process? To answer these questions, we have to combine the knowledge of our software and infrastructure teams with that of our data team.
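One way to express such a tipping point is a minimum fraction of selected devices that must report back before a round counts. The sketch below uses a hypothetical `collect_round` function and an arbitrary 70% threshold; the right value depends on the application.

```python
def collect_round(updates_by_device, min_fraction=0.7):
    """Return the received updates, or None if too many devices were lost.

    updates_by_device maps each selected device id to its update,
    or to None if the device dropped out during the round.
    """
    received = [u for u in updates_by_device.values() if u is not None]
    if not updates_by_device:
        return None
    if len(received) / len(updates_by_device) < min_fraction:
        return None  # Halt the round rather than learn from too few devices.
    return received


# Hypothetical round: one of four devices is lost, so 75% reported.
healthy_round = {"a": [0.1], "b": [0.2], "c": None, "d": [0.3]}
# Hypothetical round: two of four devices are lost, so only 50% reported.
degraded_round = {"a": [0.1], "b": None, "c": None, "d": [0.3]}

print(collect_round(healthy_round))   # [[0.1], [0.2], [0.3]]
print(collect_round(degraded_round))  # None
```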
The Promise of Distributed Data Science and Analytics
I strongly believe that the future of machine learning and data science and analytics is distributed. When we think about the problems that we face as humans, as parts of our many societies and our one world, there are many situations where if we had intelligent edge devices, we could likely help nudge things in positive directions. For example: adapting to climate-related changes in farming, dealing with natural disasters, or tracking pandemics across the globe. For all of these things, we could deploy edge devices that would be appropriately positioned to ask and answer questions.
These devices could help create early warning systems, to better prepare civil servants or NGO workers who are trying to solve these problems. Instead of trying to collect all possible data, and then figuring out answers later, a much more intelligent and much more privacy-aware way to solve these problems begins with asking several questions. What problem are we trying to solve? How can we use data from edge devices to help? Are distributed data analysis and federated learning a good fit? Asking and answering these questions can help ensure we’re collecting the right data from the right places.