Transcript
Tramel: Today we're going to be talking about federated learning. What this is going to start with is just a very high-level birds-eye view of federated learning and what it can mean, and then we're going to jump into some of the challenges about what happens on the ground. Like Mike was saying, I'm really excited to be talking about this here in this crowd, because as you heard in the introductory speech this morning, this is really this intersection between software engineering, ops, machine learning, and trying to find that nice space in the middle, that nice interface.
Federated learning is a real crucible because it brings together even more, so it's really an interface between data science, machine learning, engineering, DevOps, software data, and security engineering, and bringing all this together in one spot. If you don't think you have any organizational weaknesses between these skill sets in your company, you're going to find out real quick that you just haven't found them yet, once you start taking on a federated project. Despite some of the challenges that we're going to talk about today, I think you'll see that there's some real rewards for this.
Just a little bit of introduction. Who am I, why am I talking about this? I'm Eric Tramel, I'm a Ph.D. researcher, and I lead our federated learning R&D team at Owkin. What is Owkin? I can give you just a one-minute pitch: we're like the littlest, best biotech that you've never heard of. We're located in Europe and here in the U.S., and one of the things that we're really proud of building through a lot of hard work of our team is our Owkin Loop, which is a medical research network that brings together 30 hospitals and research institutions. Our goal at Owkin is really to advance clinical research and improve the pace of research, especially in oncology, by bringing AI and ML to researchers. Federated learning is one of the ways that we see to do that, so, how can we bring together high-quality clinical research data sets?
Machine Learning Today
Let's talk a little bit about federated learning, so, to start with, what is machine learning today? Machine learning today looks like a "gather and analyze" situation. For a little bit of this introductory part of the talk, I'm going to be framing things in terms of mobile devices, but this applies broadly if you have computation available on the edge of your network. What does machine learning look like today? We start with acquisition, we have a bunch of devices, we have users, participants in our systems that are providing data. Maybe we're generating the data ourselves, but somehow, we have data that's split between many different sources.
To run conventional machine learning pipelines, we need to bring all this data together because the compute needs to happen where our data is. If we take a look at the mobile setting, what do we need to do? We need to acquire this data, so, we need to get our users to agree to send us that data, which may be more or less complicated, and then, we need to pay the cost of bringing and gathering all that data in one spot. Then also too, if we want fresh data, we need to continually do this, bring more and more data to us.
Once we have the data, we have a number of costs and risks that come up in just maintaining data that we've gathered from multiple sources. If we're dealing with data from users, we have, for example, GDPR compliance in the EU where we work a lot, which is saying, a user can ask, "Oh, I don't want my data to be used in your system anymore, please remove it." Now you have to dig through all of your infrastructure and find out where that data is, remove it from everything, and then give a certification that you've actually done that.
If you're storing personally identifying information on your side, then you need to make sure that you are following all the proper security protocols, and maintaining that, because there are always threats that loom from people trying to take advantage of those situations. Also, you just have to pay the maintenance of keeping all of this data. Additionally, once you have all this data in one spot, you're going to need some big compute, potentially. If you have a lot of data and a very complicated problem, then you're going to need a lot of compute. One of the common themes in machine learning is you start with a little data, then you do something simple. You get some more data and then you think you might do something a little bit more exciting, so, things start to get more complicated. You need more horsepower and then the cycle continues, you end up using more and more resources, so, you need to be able to scale your compute along with your data.
After that, you have some other things to consider in this gather and analyze. FAANG has kind of gotten a little bit of a bad rap in recent years, potential PR blowback of vacuuming up a lot of user data, and then people wonder, "Well, what happens with it once it's in your cloud? Where does it go, how is it used? Is it passed on to third parties?" Some particularly big companies got into some particular hot water about this. Then also too, if you're dealing with certain kinds of data that are very strongly regulated like health data, maybe this is an extra burden that you don't really want to take on, but maybe there's some particular challenges, some tasks in machine learning, something important for your business that you want to do with this data, but the regulations around moving that data and storing that data on your servers make it difficult to manage.
The Federated Approach
All of these give rise to a different kind of approach, which is, what can we do if we don't move the data? That's where the federated approach is, so let's take a look at these same challenges in the federated setting. In the federated setting, the idea is to keep the data where it originated, so, if it's user data that's generated on a mobile device, keep it on the mobile device, never take it off that device. If it's data that's generated in a hospital, never take it out of the hospital, leave it there.
For machine learning tasks, the compute is happening where the data is, so now we have some other extra complications, we need to bring the compute to the data. We'll go into those challenges in just a bit, but let's think about what is the potential of this federated setting if we can do all of the compute on the edge on the devices and still get high-quality machine learning prediction models?
In terms of acquisition, what cost do you pay? Not so much, you're not moving the data, you're not uploading it to somewhere, the data is always fresh because you can make requests to the user device, they have their fresh data there. Also, you can limit your user agreement because they're not passing their data off to somewhere and wondering what happens up there. The user can exactly specify, "Oh, I'm ready to use my data in your system now and not tomorrow, but yes, on Thursday." So, that helps.
In terms of maintenance, you're not maintaining a data store because your users and their devices are the data store. Also you're not moving the PII to yourself, so you're not taking on that extra liability. In terms of compute while mobile devices, let's say, might be quite limited in terms of if you compare them to big beefy GPU clusters, there's still a lot of idle CPU that's available on mobile devices that isn't being used all the time, especially while you sleep, so, there's a lot of extra compute that's available on the edge, and we can use this to train models.
The other nice feature of this is that as you get more devices in your system that have data, they come with their own compute, so you get a little bit of scaling for free. If you get another user, you have more data and more compute. More data and more compute, it goes on and on and on, so you get this "free" auto-scaling. Then because of this limited agreement for the users, they can decide to not participate or to join in the network, and this gives more control to the data owners. This is a real big PR win because you don't have to appear as the one that's taking all the data from everybody.
Use Cases of Federated Learning
Let's talk about a few use cases of federated learning, and maybe you're already thinking about a few of them. What I'm going to introduce is just a few use cases that I've seen publicized and that I think are particularly interesting. One of the first ones is the use case that was originally popularized by Google, maybe you've seen their original blog post on federated learning, and this is what they're already pushing out to Android devices right now. What they've advertised is doing or training language models on personal text data.
Here when we think about potentially hot data that you don't want to move, your personal text history is not something that you probably feel very comfortable with sending off to Google to do with what they will. A technique like federated learning really offers a lot because if you can train a general nice language model for text prediction on the mobile device, but without moving all that data off, then you suddenly gain a lot. You gain access to potentially one of the largest data sets in the world, and you can use that without having to incur all of the burden of moving that data.
Another nice use case was in the Firefox browser, there was a study that was done by an intern there on doing URL bar ranking prediction. This is another situation where you might have potentially hot data and you don't want to share all of your browsing history with Firefox all the time. What's done here is a very similar situation, one wants to learn a ranking model, but this model can be trained locally on your computer, and then this information is aggregated to train a better overall model.
Another task that we've been looking at is collaboration between pharmaceutical companies in the context of drug discovery. Here you have competitors that they say, "We understand that if we had all of our data in one spot together, we could train some really nice representation models that we could use to really speed up drug discovery. However, our data is our lifeblood and we don't want to share it." This is billions of dollars' worth of investment, so, there's really high incentive to not collaborate, but also some incentive to collaborate. Technologies like federated learning can really offer a lot here, federated learning has enabled one of the first of its kind, grants from the EU that we'll be leading, which is on providing a federated learning system between 10 different pharmaceutical companies in Europe, and this is really cool to see.
Another use case is in the hospital, maybe trying to train very large models, maybe U-Nets for doing tissue segmentation, and then trying to train these collaboratively between multiple hospitals. Intel had a very interesting paper on this, and in the U.K., King's College is actually putting together an initiative called their AI Center where they're going to be doing a federated learning deployment between four different NHS trusts to do these very large radiology machine learning model training.
Finally, another nice one was another [inaudible 00:13:11] startup working on wake-word detection for digital assistants. Here, many people have different ways of pronouncing a particular wake-word for these assistants and from all of these, you want to train one very nice model, but we don't want these devices recording us all the time in our homes. All of these things could be powered by federated learning.
Federated Averaging
How does that work in practice? What are the actual machine learning operations that are happening under the hood? What we're going to do is take a look at one simple approach, it's simple, but it's actually very effective, and in fact, it's the algorithm that's powering all of these use cases that I've just shown you, and you're going to be amazed how simple this is. It works like this, we're just going to break it down really quick. You're going to start with a model, some initial state, maybe it's just your random initial state, maybe it's some other model checkpoint, so, you're going to start with this at some model server. Quite simply, you're going to ask for some available workers “Who's ready to do some training?” Maybe you have a population of devices, if you're Google, it's millions, tens of millions of devices, if you're in our case, it's tens of research institutions. You're going to ask which ones are ready and available for training. After this, you're going to just send down the model checkpoint that you want to train, you just synchronize this to all of your available workers. Subsequently, you're going to do some steps of some local fit, this could be many, it could be few. This is, in fact, a hyperparameter of your federated learning training.
Based on the data that's available at each of these different sites, each worker, or mobile device, or research institution, this is going to lead to subtly different models, which are going to be a little bit biased towards the data that's available in each one of these places, but it's going to give some representation of the overall data set. After you perform this local training, you're simply going to take all of these model updates, how much they moved from their initial setting, and report these back to your central model server who's going to take a look at everything that has been given and average it, the federated averaging. Congratulations, you're done. After this, you can repeat this for many rounds.
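The round described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the production algorithm from any of the systems mentioned: the one-dimensional linear model, the toy data, the names `local_fit` and `federated_round`, and the equal weighting of workers in the average are all assumptions made for the example.

```python
import random

def local_fit(weights, data, lr=0.1, steps=5):
    """Run a few passes of SGD on one worker's data for the model y = w*x + b."""
    w, b = weights
    for _ in range(steps):
        for x, y in data:
            err = (w * x + b) - y
            w -= lr * err * x
            b -= lr * err
    return (w, b)

def federated_round(global_weights, workers, lr=0.1, steps=5):
    """One round: sync the checkpoint, fit locally, average the reported updates."""
    updates = []
    for data in workers:  # each worker only ever touches its own data
        local = local_fit(global_weights, data, lr, steps)
        # The update is how far the local model moved from the checkpoint.
        updates.append(tuple(l - g for l, g in zip(local, global_weights)))
    n = len(updates)
    avg = tuple(sum(u[i] for u in updates) / n for i in range(len(global_weights)))
    return tuple(g + a for g, a in zip(global_weights, avg))

# Hypothetical toy setup: three workers sample y = 2x + 1 over different x ranges,
# so each local model is biased toward its own slice of the data.
random.seed(0)
workers = [
    [(x, 2 * x + 1) for x in (random.uniform(lo, lo + 1) for _ in range(20))]
    for lo in (-1.0, 0.0, 1.0)
]
weights = (0.0, 0.0)  # initial state held by the model server
for _ in range(100):  # repeat for many rounds
    weights = federated_round(weights, workers)
print(weights)
```

The `steps` argument plays the role of the local-fit hyperparameter mentioned above: more local steps means fewer communication rounds, at the cost of each worker drifting further toward its own data between averages.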
Challenges in FL
All of this looks really great on paper, but there are some challenges. Maybe you've already seen some of the bottlenecks that exist in a system like this and some potential risks, and I'm going to go through some of those that we've found as we've been working on these systems. There's a lot of very practical challenges, and some of these look a little bit daunting, but I'll say that a lot of these have already been surmounted in production systems. From here it's a matter of how do we find the most elegant way of dealing with these challenges, and then, how do we provide toolsets that really open up this technique for even more companies? This leads to a no free lunch theorem of federated learning, which is anything that's worthwhile is worth working for, which you've probably heard your father say at some point.
One of the first challenges associated with federated learning is workflows. What does the standard workflow look like for a machine learning engineer? A machine learning engineer, depending on your organization, maybe you also call them a data scientist, is going to sit down and look at a problem and think about how to solve it. What architecture do I need to use? What loss? What am I going to measure, what optimizer? Are there hyperparameters? What's the pipeline for producing batches? What's the right features? How am I going to need to augment my data? All of the things that you need to be able to produce a machine learning solution.
They're going to come up with all this. Maybe they're going to use PyTorch, maybe they're going to use TensorFlow, maybe they're going to use something else that hasn't been invented yet or something that has, and they're going to take all of these things together and then say, "Great. I've written it. Now run it." Okay, maybe that works. You can take this, you can run it locally, do a little debug, that's fine. Maybe you put in a little bit more work, and you containerize this whole thing, you deploy it on some other machine learning infrastructure, you deploy it on the cloud even, for training continuously on new data, and you do a lot of work here.
What I should say is that these workflows are really focused on a single machine learning runtime. Maybe if you're doing some extra work in TF Distributed, and you want to spread this out over many GPUs, it gets a little bit more complex, but overall, you're writing a piece of code that you expect to run in one place. You're just going to take that code and run it somewhere somehow.
When it comes to federated learning, you're writing code, but where is it going to run? It's running remotely somewhere else in an environment that the ML engineer doesn't necessarily have direct control over or quite understand the specifications for. In this you say, "Okay. Well, also what do they need to produce? Are they producing a script? Are they producing some other kind of a set of code? What is it exactly that they need to build to be able to use it on a federated learning infrastructure?"
This also introduces another interesting point of who is the person that controls the federated learning specific design choices that you're making, which is how often do you need to communicate? What style of federated algorithm are you using? Is this the ML engineer that produces the original solution for your problem, or is it somebody else in your organization? Is it someone more on the ops side? Who is making these decisions? This is when I say that you're going to find out about these organizational weaknesses. It's this question about, where are the roles when it comes to federated learning that really brings a lot of this stuff to the surface.
What this means is that you need some kind of intermediate description to try to decouple the work of the ML engineer from the work of the federated learning ops side of things. This intermediate description, what it can allow for is saying, "Oh, I want to write my code in PyTorch, I'm going to describe my model in PyTorch. I want to describe my optimizers in PyTorch, and I want to just pass that code off to somewhere, and I just want this to be the one description that I write and have it work everywhere," or TensorFlow, of course.
One really needs this intermediate description, not just of the training or experiment that you want to run, but really of the operations that you want to do here, operations in terms of the actual machine learning code that's running. We heard a talk earlier that mentioned the need for a byte code for machine learning, this is where this would be really helpful in the setting of federated learning, some kind of standardization here for intermediate descriptions of machine learning code.
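As a rough illustration of that decoupling, here is a small sketch of a serializable job description in which the ML engineer and the federated learning ops side each own a separate part. Every class and field name here (`ModelSpec`, `FederationSpec`, `TrainingJob`, `local_steps`, and so on) is hypothetical, and this is a much coarser thing than the operation-level "byte code" asked for above: it only describes the experiment, not the machine learning operations themselves.

```python
import json
from dataclasses import asdict, dataclass, field

@dataclass
class ModelSpec:
    """Owned by the ML engineer: what to train."""
    architecture: str
    loss: str
    optimizer: str
    learning_rate: float
    batch_size: int

@dataclass
class FederationSpec:
    """Owned by the federated learning ops side: how to train it across sites."""
    strategy: str = "federated_averaging"
    rounds: int = 100
    local_steps: int = 5
    min_workers: int = 3

@dataclass
class TrainingJob:
    """The one description that gets shipped to every worker."""
    model: ModelSpec
    federation: FederationSpec = field(default_factory=FederationSpec)

    def to_json(self) -> str:
        return json.dumps(asdict(self), sort_keys=True)

    @classmethod
    def from_json(cls, payload: str) -> "TrainingJob":
        raw = json.loads(payload)
        return cls(ModelSpec(**raw["model"]), FederationSpec(**raw["federation"]))

# Hypothetical job: the values are placeholders, not recommendations.
job = TrainingJob(
    model=ModelSpec("unet", "dice", "sgd", learning_rate=0.01, batch_size=8),
    federation=FederationSpec(rounds=50, local_steps=10),
)
wire_format = job.to_json()
restored = TrainingJob.from_json(wire_format)
```

The point of the split is that changing the communication frequency or the federated strategy touches only `FederationSpec`, while the model description stays exactly as the ML engineer wrote it.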
There are some approaches to this already that you can find in the different open-source frameworks. There was the recently released TensorFlow Federated, where some Google engineers invented a whole new language that isn't even the graph, it's not C code, it's yet another language for describing machine learning operations in a MapReduce format as this intermediate description layer, so the complexity of systems starts to go up a lot. There's another approach that's taken in the context of PyTorch by the OpenMined project and their [inaudible 00:21:53] module, where again, it's about trying to abstract the operations that you want to do.
Once you have this intermediate description of your federated learning problem, you also need to think about where it's getting run. What is actually catching these operations and machine learning operations that you want to do and then actually running them? You have to think about, "Ok, what is the device? What is its context? How does it get created? What are the interpreters that you need to build?" Unfortunately, there's nothing quite off the shelf for this in the context of federated learning, so, this is something that you have to put together on your own, which is one of the challenges to implementing these systems in the wild.
This elucidates the need for this abstraction layer from the original ML engineer that described the algorithm that you want to run, because you may be having to do very specific things on your device, like using TF Lite or some very device-specific libraries to make this stuff happen. This is not something that you want your original data scientists or ML engineer to have to work with or try to muck about with. This is really where the power of that abstraction comes in.
Then finally, you have this other big problem of this orchestrator itself. What is this orchestrator? When a federated learning job is submitted, there's something that needs to catch that, that needs to go out and find what devices are available on the network. Which ones to accept right now, which ones to reject, how to send to them the commands that need to be run, and this is the role of the orchestrator, this is a pretty heavy task. Just like with the device context, there's nothing that's built out there that you can just take off the shelf and use. There's a very nice description and a reference architecture that Google's given and they had a technical paper that you can check out in the slides after this talk, it's linked here at the bottom, where they describe how to build this service, but it's something that you have to start from the beginning. It can be more or less complicated depending on what exactly you're trying to build.
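To make the accept-or-reject part of the orchestrator's job concrete, here is a minimal sketch of worker selection for a single round. The `Device` fields and the eligibility rule (idle, charging, on an unmetered network) are assumptions chosen for illustration, not a description of the reference architecture mentioned above, and a real orchestrator also has to handle dropouts, timeouts, and scheduling across jobs.

```python
import random
from dataclasses import dataclass

@dataclass(frozen=True)
class Device:
    device_id: str
    idle: bool
    charging: bool
    unmetered_network: bool

def is_eligible(device):
    """Hypothetical rule: only train when it will not bother the user."""
    return device.idle and device.charging and device.unmetered_network

def select_workers(devices, target, rng):
    """Return (accepted_ids, rejected_ids) for one round.

    If fewer than `target` devices are eligible, nobody is accepted and
    the round should be retried later.
    """
    eligible = [d for d in devices if is_eligible(d)]
    accepted = rng.sample(eligible, target) if len(eligible) >= target else []
    accepted_ids = {d.device_id for d in accepted}
    rejected_ids = [d.device_id for d in devices if d.device_id not in accepted_ids]
    return sorted(accepted_ids), rejected_ids

# Hypothetical population of devices checking in with the orchestrator.
devices = [
    Device("phone-1", idle=True, charging=True, unmetered_network=True),
    Device("phone-2", idle=False, charging=True, unmetered_network=True),
    Device("phone-3", idle=True, charging=True, unmetered_network=True),
    Device("phone-4", idle=True, charging=False, unmetered_network=True),
    Device("phone-5", idle=True, charging=True, unmetered_network=True),
]
accepted, rejected = select_workers(devices, target=2, rng=random.Random(0))
```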
Once you've done all that and you have your system deployed, and you say, "Ok, we're ready to start doing some federated learning," now you have to start thinking about, "Ok, now what's the workflow for the data scientist that's working on this kind of a system?" The conventional approach is, you say, "Oh, ok, I want to start working on a particular problem. I've got unknown data I'm going to start taking a look at that data, I'm going to try to understand the features. I'm going to try to understand the distributions. I'm going to dig into it so that I can find the right approach to use." In this setting, this is problematic because you can't pick up that data and bring it back to your servers to analyze it in such a nice and neat format for your data scientists, so, you need to bootstrap in some way.
A more bootstrapping approach would be to go out to public data to try to make a hypothesis about the kind of algorithm or architecture that you want to use, and it changes the mindset a bit. Working in the federated context, your data scientists can't have a very rapid iteration of looking at the data and then going back and forth. You really have to take this like a scientific approach of hypothesizing and then making a prediction about what the system would do with this given architecture or algorithm, and then making the observations about what it comes back with. It's a little bit of a slower process, it takes a little bit more work on the front end to think about the kinds of things that you want to deploy to your federated system, which creates the need for parallel experimentation. The kind of system that you want to build, you want to build it in such a way that it can really support not just one experiment running on your system, but many, because you're going to try to say, "Well, maybe this algorithm will work, but I don't exactly know and it's going to take some time to understand if it does. So, what about this other architecture, or this one, or maybe at this rate, or with that batch size?" You need to have this parallel experimentation going on in your system.
Then, it comes to privacy. Now you think about, "Ok, so what are we moving?" If you've accomplished everything before this, you say, "All right, well, the data is not leaving the device. The data is not leaving the hospital. It's not leaving our mobile device, so, everything is safe." No, you have to think about everything that you're moving, so what are you moving? Going out from the user device is some kind of update to a model, which, if you're training some deep convnet, maybe it's hard to try to understand what the specific features are that could be used to identify your users, but maybe you're trying to train a different kind of a model. Maybe it's user sentiment, and you're trying to understand whether they like this thing more, or not that thing. They really like sushi, they hate dog parks, and this could end up revealing some kind of information.
What one has to do is to try to restrict the access to all kinds of information that are leaving the device, that sounds a little bit challenging. How do you restrict all information coming from this device, because who's going to see it? Is it a man in the middle? Well, you can encrypt the communications and you can sort of mitigate that risk, but then what about the server at the end? You need your users to trust you and your server that you're not digging into their information, so how can they trust you?
One of the ways that you can do that is through multiparty computation. This is a technique to encrypt all of the information, in this case, model updates, which are coming from the individual devices, and encrypt them all with individual keys such that the server on the other side can't decrypt these individual contributions. However, with the right secure multiparty computation protocol, like secure aggregation, what you can decrypt is their average. There are many different flavors of MPC that you could be using and for different kinds of tasks, but what you can guarantee with this kind of approach is you say, "Well, we're never looking at an individual contribution. We can only see their aggregate or their average. So, we can only see the overall statistic but not the individual."
This is the way that you can show that you're not digging into any single one's personal data, except for the fact that if you only had two users, it's not such a strong guarantee, but as you have more and more participants in your network, then this aggregation becomes less and less sensitive to an individual. This is actually the kind of system that Google uses in their federated learning production system. The cost of doing this is really just in time, the cryptography is mostly lossless, but there are multiple rounds to this protocol, so you spend a little bit extra time to do this.
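The core idea, that the server can recover the average without being able to read any single contribution, can be illustrated with pairwise masks that cancel in the sum. This is a toy sketch only: in a real protocol each pair of clients derives its shared mask through a key agreement the server cannot see, and dropped clients need extra rounds to handle, whereas here one seeded generator stands in for all of that and `random` is not a cryptographic source. The names `pairwise_masks`, `mask_update`, and `aggregate`, and the fixed-point constants, are assumptions for the example.

```python
import random

MODULUS = 2 ** 32   # all arithmetic is done modulo this value
SCALE = 10 ** 4     # fixed-point scale for encoding floats as integers

def encode(x):
    return int(round(x * SCALE)) % MODULUS

def decode(v):
    if v >= MODULUS // 2:  # map the upper half back to negative numbers
        v -= MODULUS
    return v / SCALE

def pairwise_masks(n_clients, length, seed):
    """Build one mask per client such that all masks sum to zero mod MODULUS."""
    rng = random.Random(seed)  # stand-in for per-pair shared secrets
    masks = [[0] * length for _ in range(n_clients)]
    for i in range(n_clients):
        for j in range(i + 1, n_clients):
            shared = [rng.randrange(MODULUS) for _ in range(length)]
            for k in range(length):
                masks[i][k] = (masks[i][k] + shared[k]) % MODULUS  # i adds
                masks[j][k] = (masks[j][k] - shared[k]) % MODULUS  # j subtracts
    return masks

def mask_update(update, mask):
    """What a client actually sends: its encoded update plus its mask."""
    return [(encode(u) + m) % MODULUS for u, m in zip(update, mask)]

def aggregate(masked_updates):
    """Server side: the masks cancel in the sum, leaving only the average."""
    n = len(masked_updates)
    length = len(masked_updates[0])
    totals = [sum(m[k] for m in masked_updates) % MODULUS for k in range(length)]
    return [decode(t) / n for t in totals]

# Hypothetical model updates from three clients.
updates = [[0.5, -1.0], [1.5, 2.0], [-0.5, 0.25]]
masks = pairwise_masks(len(updates), length=2, seed=42)
masked = [mask_update(u, m) for u, m in zip(updates, masks)]
average = aggregate(masked)
```

Each entry of `masked` looks like uniform noise on its own; only the sum over all clients is meaningful, which is also why the guarantee is weak with very few participants.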
Collaborative FL in Health
A lot of challenges that we see, a lot of things to try to apply, but I want to talk now a little bit about how these sorts of systems get used in healthcare, and the different kinds of challenges that we have to think about here. This is really what we've been seeing and addressing in our experience at Owkin. Why would one want to approach collaborative federated learning in health care? What are the benefits? The first one is quite clear, if you have more data, there's more that you can do, quite bluntly, but specifically in health, you have situations where, like for rare disease, there's maybe a few available patients or pieces of data inside each of the hospitals, and so, maybe within one region, or one hospital network, or even within one country, you might not have enough data to build up a very powerful machine learning model to investigate these diseases, but by joining together many more institutions, you can do something really useful.
Another use for this, especially for drug development is to try to generalize across multiple population centers. You can pick up data or make a contract with one single hospital and try to understand how a drug responds in one population center, but this isn't going to do very well for the drug company that wants to sell its therapies across the entire world. You really need to join data from multiple different sites, but this starts to become quite a legal burden in trying to make many, many individual contracts to do this.
The other thing that federated learning in the hospitals can do is incentivize hospitals to try to make their health records machine learning ready, because with the right systems, you can give attribution back to the hospitals and say, "Hey, your data was used like this, and it helped in this way, and now your data that's just been sitting on your servers suddenly has value." It not just serves the operator of the federated learning system, but it also serves the hospital and the researchers inside of that research hospital too because now suddenly, they have more data sets that they can make use of. There are still a lot of non-technical challenges associated with this setting just in terms of health regulations, the legal and contracting, and fighting innovation departments around intellectual property is complicated. I talked earlier about competition between pharmaceutical companies, but you'd be surprised how much competition there is between hospitals. It's just astounding, I wouldn't have thought it was that way, but it is.
What are some of the challenges that we face here? One of the first ones that we saw early on was large data and large models. The data is not moving, but say you're working on trying to do volumetric U-Nets, to do segmentation tasks on whole volumes of MRI or CT. Now you're dealing with quite large models, so certainly upwards of 500 megs, just for the model weights. Now you want to do a distributed training, so what do you need to do? Maybe you need to push those weights one way or another, now you're starting to incur a lot of bandwidth. The thing is that bandwidth in the data center comes for free in some sense, it's just a little bit of time, but bandwidth has other implications when you're in the hospital setting, because that bandwidth is not just powering your machine learning task, but other critical systems in the hospital, so, you really need to restrict your usage.
What you have to do here is try to develop algorithms that communicate as little as possible. Through quantization applied to the model, applied to the gradients, changing how you do the distributed learning strategy to reduce the amount of communication rounds that you need, all of this is a very big problem in the health setting. The nice thing is that there's some very cool research that's been happening over the past couple of years that can show if you're training DCNNs on something like ImageNet, you can get like 10,000 X compression. Apparently, there's a lot of information when you train a neural network that's really useless, so it’s a nice feature.
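Two of the simplest ways to shrink what gets communicated are to send only the largest entries of an update and to send values at reduced precision. The sketch below shows both on a plain list of floats. It is an illustration of the general idea only: the function names are hypothetical, and published compression schemes that reach very high ratios combine ideas like these with further techniques not shown here.

```python
def top_k_sparsify(update, k):
    """Keep the k largest-magnitude entries and send them as (index, value) pairs."""
    ranked = sorted(range(len(update)), key=lambda i: abs(update[i]), reverse=True)
    return [(i, update[i]) for i in sorted(ranked[:k])]

def densify(sparse, length):
    """Receiver side: rebuild a full-length update, treating dropped entries as zero."""
    dense = [0.0] * length
    for i, value in sparse:
        dense[i] = value
    return dense

def quantize(values, bits=8):
    """Uniformly quantize floats to integer codes in [0, 2**bits - 1]."""
    lo, hi = min(values), max(values)
    levels = 2 ** bits - 1
    scale = (hi - lo) / levels if hi > lo else 1.0
    codes = [round((v - lo) / scale) for v in values]
    return codes, lo, scale

def dequantize(codes, lo, scale):
    return [lo + c * scale for c in codes]

# Hypothetical model update: most of the mass sits in two coordinates.
update = [0.01, -2.0, 0.3, 1.5, -0.02]
sparse = top_k_sparsify(update, k=2)
rebuilt = densify(sparse, len(update))

codes, lo, scale = quantize(update, bits=8)
approx = dequantize(codes, lo, scale)
```

Here `sparse` carries two (index, value) pairs instead of five floats, and `codes` carries one byte per entry plus the two floats `lo` and `scale` needed to decode them.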
Another problem in this setting is the networks are a lot smaller, so it's not tens of millions of users, but it's really tens of research institutions. Here the effects of biases and heterogeneities can cause bigger problems in your model training. You really need to take special care here in monitoring the progress of training and trying to understand which institutions are pushing in one way or another and try to develop your algorithms in such a way that they're robust to these kinds of heterogeneities. This is especially prevalent in hospitals where you have different lab procedures, different procedures for coding. This is like, what are you going to call one disease versus another? All of this can change from hospital to hospital and create big problems when you're trying to do machine learning in health. Another big step here is to make sure that you have consistency in how the data is generated, which if you own the app or mobile device and you produce your own data, then it's quite nice, but in this setting, it's quite difficult.
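One simple way to see which institutions are pushing in one direction or another is to compare each site's update with the average update of all the other sites. The sketch below does this with cosine similarity. It is a hypothetical monitoring heuristic for illustration, not a method described in the talk, and the site names and update values are made up.

```python
import math

def cosine(a, b):
    """Cosine similarity between two vectors, 0.0 if either is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def site_agreement(updates):
    """For each site, compare its update with the mean update of the other sites.

    Scores near 1 mean the site agrees with the rest of the network;
    scores near -1 mean it is pulling the model the opposite way.
    """
    scores = {}
    for site, update in updates.items():
        others = [u for name, u in updates.items() if name != site]
        mean_of_others = [sum(col) / len(others) for col in zip(*others)]
        scores[site] = cosine(update, mean_of_others)
    return scores

# Hypothetical per-site updates from one round of training.
updates = {
    "hospital_a": [1.0, 0.9],
    "hospital_b": [0.9, 1.1],
    "hospital_c": [-1.0, -0.8],
}
scores = site_agreement(updates)
outlier = min(scores, key=scores.get)
```

A low score does not say a site is wrong, only that its data pulls differently, which could come from different coding practices or lab procedures as much as from a genuinely different population.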
Second to last problem is traceability, this doesn't crop up so much in a big consumer setting for federated learning, but when you're dealing with health institutions that have tight regulations on data use, you need to be able to produce logs of what data was accessed, when, how it was used? You need to do this in an unforgeable way. At Owkin, the approach that we've taken here is to use technologies like Hyperledger Fabric to track and to trace each of the machine learning operations that's happening in the different hospitals, so we can provide that guaranteed record of everything that's occurring. This is also what allows us to give value back to the hospitals, because they can see, "Oh, my data was used like this. Oh, and it boosted that model that way. Nice."
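The property being asked for, a record of operations that cannot be quietly rewritten, can be illustrated with a hash-chained log. This sketch is not Hyperledger Fabric and does not use its API; it is a single-process stand-in with a hypothetical `AuditLog` class and made-up record fields. On its own a hash chain is only tamper-evident: making the record hard to forge in practice also needs the chain to be replicated or signed across independent parties, which is what a distributed ledger adds.

```python
import hashlib
import json

def entry_hash(prev_hash, record):
    """Hash a record together with the hash of the entry before it."""
    payload = json.dumps(record, sort_keys=True).encode("utf-8")
    return hashlib.sha256(prev_hash.encode("utf-8") + payload).hexdigest()

class AuditLog:
    GENESIS = "0" * 64  # placeholder hash for the start of the chain

    def __init__(self):
        self.entries = []

    def append(self, record):
        prev = self.entries[-1]["hash"] if self.entries else self.GENESIS
        digest = entry_hash(prev, record)
        self.entries.append({"record": record, "prev": prev, "hash": digest})
        return digest

    def verify(self):
        """Recompute every hash; any edited or reordered entry breaks the chain."""
        prev = self.GENESIS
        for entry in self.entries:
            if entry["prev"] != prev or entry_hash(prev, entry["record"]) != entry["hash"]:
                return False
            prev = entry["hash"]
        return True

# Hypothetical records of machine learning operations at two sites.
log = AuditLog()
log.append({"site": "hospital_a", "op": "local_fit", "dataset": "ct_scans_v1", "round": 1})
log.append({"site": "hospital_b", "op": "local_fit", "dataset": "ct_scans_v1", "round": 1})
intact = log.verify()

# Rewriting history after the fact is detected.
log.entries[0]["record"]["dataset"] = "something_else"
tampered_ok = log.verify()
```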
The other big problem here is having to go even the extra mile in terms of privacy. Here it's not enough to give a hand-waving approach to say, "We just average over many users and therefore don't worry about it," but one actually needs to be able to give a demonstration of privacy. Here the challenge is really about saying, "Ok, if patient A is in the data set, we need to be able to demonstrate that gradients produced by that model that are transmitted, and also too, that the final model can't somehow be used to identify that user A was in that data set." This is what we call a membership risk in trying to say that there's no risk in participating in the training set. This can be approached through smart application of differential privacy and some other strong permissioning around the model use, and ideally, secure enclaves for all of your machine learning operations, so a big challenge there.
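A common building block when applying differential privacy to transmitted updates is to bound how much any one contribution can move the model and then add calibrated noise. The sketch below shows that clip-then-add-Gaussian-noise step on a plain list of floats. It is an assumption about one way to apply the idea, not a description of Owkin's system, and it does not compute a privacy guarantee: turning `max_norm` and `noise_multiplier` into an actual privacy budget requires a separate accounting step that is not shown.

```python
import math
import random

def clip_update(update, max_norm):
    """Scale the update down so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(x * x for x in update))
    factor = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [x * factor for x in update]

def privatize(update, max_norm, noise_multiplier, rng):
    """Clip the update, then add Gaussian noise scaled to the clipping bound."""
    clipped = clip_update(update, max_norm)
    sigma = noise_multiplier * max_norm
    return [x + rng.gauss(0.0, sigma) for x in clipped]

# Hypothetical update with L2 norm 5.0, clipped to norm 1.0 before noise is added.
update = [3.0, -4.0]
clipped = clip_update(update, max_norm=1.0)
noisy = privatize(update, max_norm=1.0, noise_multiplier=1.1, rng=random.Random(0))
```

Clipping is what makes the noise meaningful: without a bound on each contribution, no fixed amount of noise could hide the presence of a single unusually large update.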
Takeaways
The takeaways for this are: federated learning is a real sea change in terms of approach to machine learning operations, and in my opinion, it's really going to be the future of machine learning on personal data. The implementation is non-trivial and the challenges are there, but they're certainly surmountable, and they already have been surmounted by companies, so it's something that can be done. Lastly, get ahead of the curve: there are some open-source frameworks for starting to toy around with this. If you start to see some ways that federated learning can work in your business model, go ahead, dive in, have your ML engineers or data scientists start taking a look at it using those open-source frameworks and see what can be done, because the tooling is only going to improve as time goes on, and I think it's a bright future for federated learning.
Questions and Answers
Participant 1: You mentioned GDPR requires the person to be able to remove their data from a system. Have you really removed that data from a system if you've trained a machine learning model with it?
Tramel: This is something that's nebulous in terms of both the legislation and the legal opinion about what that means, and you're going to find different interpretations. On the first part, can you remove their data from your data lakes? Yes, you can find that record, trace everything down in the backups, and then get everything out. In the other case, of how it's used for a machine learning model, I haven't seen anything yet that says, "Ok, you need to get rid of your neural network because this one user's data was used in it," because there are many other users that also contributed to that model at the same time, and so what that one user's contribution was is hard to measure. In this case, you can't necessarily take the data out of the model, but you could certainly try to retrain the model without it, which happens naturally in a federated learning system if they say, "I'm out. I'm not using it again." If you relaunch your training, it's going to be without that data.
Participant 2: I'm unsure if there's actually research on this, but is there an ability to look at a privacy budget for an individual over time, so to actually implement differential privacy over time with sequential learning in a federated situation?
Tramel: There've been some interesting applications of differential privacy to SGD. There's a very nice SGD paper, and the thing is that the privacy budget is going to be given on a data-set-by-data-set basis. You can show the differential privacy accounting for one data set, and see when you want to restrict access to that data set afterwards, but you do get into some complicated situations, because you start to train your model, you measure the privacy budget that was given away for that single model by doing that training on that data set, and then you need to log that and keep up with it. This is where a ledger system helps, because you can have an unforgeable record of that. Then afterwards you think about, "Okay, well, can we continue training on that or not?" I'm curious to understand what it means for individual records, but I don't have anything to say about the individual-record case other than I think it's cool and I'd like to have it.
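The per-data-set bookkeeping described in this answer can be sketched as a small budget tracker. This is a simplified illustration under stated assumptions: it uses basic sequential composition, where the epsilons of successive trainings on the same data set simply add up, while practical accountants use tighter bounds. The class name, the limit, and the epsilon values are hypothetical.

```python
class PrivacyBudget:
    """Tracks the cumulative epsilon spent on a single data set.

    Assumes basic sequential composition: training runs with privacy costs
    eps_1, eps_2, ... on the same data set cost eps_1 + eps_2 + ... in total.
    """

    def __init__(self, epsilon_limit):
        self.epsilon_limit = epsilon_limit
        self.spent = 0.0
        self.history = []  # (model_id, epsilon) entries; in practice these would go to the ledger

    def can_spend(self, epsilon):
        """Answer 'can we continue training on that or not?' for a proposed cost."""
        return self.spent + epsilon <= self.epsilon_limit

    def spend(self, model_id, epsilon):
        """Record the budget given away by one training run, or refuse it."""
        if not self.can_spend(epsilon):
            raise ValueError("privacy budget for this data set would be exceeded")
        self.spent += epsilon
        self.history.append((model_id, epsilon))


# Hypothetical data set with a total budget of epsilon = 1.0.
budget = PrivacyBudget(epsilon_limit=1.0)
budget.spend("model_v1", 0.5)
budget.spend("model_v2", 0.5)
allowed = budget.can_spend(0.25)  # False: the budget is already used up
```

Once `can_spend` returns False, access to that data set would be restricted for further training.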