Transcript
Jarmul: I'm Katharine Jarmul. I'm really excited to talk about advances in federated learning at the edge with you. When I think about the edge, I often think about small embedded devices, IoT, other types of things that might have a small computer in them that I might not even realize is there. I recently learned that these little scooters that are all over my city in Berlin, Germany, and maybe yours as well, are collecting quite a lot of data and sending it. When I think about the data they might be collecting, and when I put on my data science and machine learning hat and think about the problems that they might want to solve, they might want to know about maintenance. They might want to know about road and weather conditions. They might want to know about driver performance. Really, the ultimate question they're trying to answer is this last one: is this going to result in some problem for the scooter, or for the human, or for the other things around the scooter and the human? These are the types of questions we ask when we think about data and machine learning. When we think about it on the edge, or with small embedded systems, this often becomes a problem, because traditional machine learning needs quite a lot of extra information to answer these questions.
Traditional ML
Let's take a look at a traditional machine learning system and investigate how it might go about collecting this data and answering this question. First, all the data would have to be aggregated and collected into a data lake. It might need to be standardized, or munged, or cleaned, or have something else done with it beforehand. Then, eventually, that data is pulled, usually by a data science team or by scripts written by the data engineers or data scientists on the team. That data is transformed. Initially, we might do EDA, or what's called exploratory data analysis. Then we might decide, ok, here's the data and how I want it prepared. Then we might automate it with a pipeline. Eventually, that goes into model training, testing, validation, and selection. In that case, what we're trying to do as a data team is figure out the best way to solve this problem with the machine learning and the data that we have available. That's an iterative process. At the end we might decide that we have one or two models that we'd like to try, and we're ready to deploy one or more of them. At that point in time, we would deploy them, usually to a cloud service or other API that is reachable for the participants or for whatever system is going to use the machine learning. This could also be internal. It could be something centralized and internal to the scooter company that interacts via another API with the scooters themselves.
Traditional ML: Limitations
We can immediately, hopefully, start to see some of the limitations of this, first off that we need the data collected. This all has to happen some which way and be sent over the network. I'm not sure what type of network connection scooters have, but probably it's not stable. Probably it doesn't have much upload bandwidth. For this reason, this is going to be a problem. Then we have to collect it all across the data lake. Then we need to aggregate across many different locations. Then we're going to train the model and so on. When we think about things like mobile devices, small embedded systems, and robots in factories, we can see some of the problems here, not only from a connectivity point of view, but also in terms of memory, storage, and compute power. Then also aggregating and sending all of that, and then having the devices interact in real time with an API to make a decision, that's a lot. That's a lot of problems and limitations that we can immediately see.
What Is Federated Learning?
How do we approach these? One such approach is to use federated learning. Federated learning was presented by Google, and was first actively used in production by Google when they wanted to learn on Android devices. There were numerous problems that they wanted to solve on Android devices. Let's take an easy example of predictive keyboard support. I want to build a better predictive text model for Spanish speakers by using keyboard data. What I will do is first take the model that I'm actively using, that's the blue circle. I will send it out to all of the Android devices. There's some selection criteria here. First off, I only want Spanish language keyboards, so that's going to rule out some others. Then I want the process to complete successfully, so that means that I want devices that have high battery power or that are charging. I want devices with a good or stable mobile or internet connection. Perhaps I want devices with enough text data to learn from. There may be multiple criteria to select on.
At that point in time the model is sent to those phones. Then the machine learning happens on the device. Now the device is essentially doing some of the training for me, and that can be an iterative process or a one-shot process. At the end of that there is an update. That update is not the data itself, but instead a representation, usually in some vectorized form, that tells me, ok, this is the direction that the machine learning model should move to improve, based on the data on this device. All of those updates across all of those different devices are collected. They're usually aggregated, so this is step B, and sent to a centralized aggregator, which then uses some computation to combine them. This can be averaging, summing, or some other computation, depending on the algorithm and the architecture of the model. Then that is sent back out to all the devices. Then every device has an updated model. I myself, as the person managing the process, also have an updated model. Presumably, another round could then start, or we could say, now the model is good enough, we stop here.
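To make the aggregation step concrete, here is a minimal sketch of the kind of computation the centralized aggregator might run, modeled on Google's Federated Averaging (FedAvg) idea of weighting each device's update by how much data it trained on. The function and variable names are illustrative, not from any particular library:

```python
def federated_average(client_updates, client_sizes):
    """Combine per-device updates into one global update.

    client_updates: list of equal-length lists of floats (weight deltas)
    client_sizes: number of training examples on each device, used to
                  weight that device's contribution (as in FedAvg)
    """
    total = sum(client_sizes)
    n_params = len(client_updates[0])
    combined = [0.0] * n_params
    for update, size in zip(client_updates, client_sizes):
        weight = size / total
        for i, delta in enumerate(update):
            combined[i] += weight * delta
    return combined

# Three devices report updates for a toy two-parameter model.
updates = [[0.2, -0.1], [0.4, 0.1], [0.0, 0.3]]
sizes = [10, 30, 60]
global_update = federated_average(updates, sizes)
# Devices with more data pull the global update further their way.
```

In a real round the result would be applied to the shared model and redistributed, starting the next round.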
Federated Learning Architectures
When we think about the architecture, the design that I just showed you is the classic one, where we have one centralized aggregator with distributed participants. However, you can also imagine other architectures, such as clustering participants together and deploying multiple aggregators. Depending on your infrastructure, this might be a better approach or a worse one. This is something that should be discussed between the data science and machine learning teams, as well as your infrastructure teams. Together, this is a good group to help solve those problems and think about exactly how this should be deployed.
Finally, we can also think of a scenario where it's fully distributed. This is where we start entering the cryptography realm, and we can think about things like multi-party computation, or other forms of distributed encrypted computation, that allow multiple participants to do the machine learning on their local devices, coordinate with one another to combine those updates, and then send the result out in a meaningful manner. If you have exposure to the cryptographic protocols and communities around things like multi-party computation and other multi-party approaches, then this might be an interesting way to think about the problem. Of course, this also adds massive amounts of complexity for coordinating all of the different devices.
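One building block behind those distributed encrypted computations is additive secret sharing. The sketch below, with invented names and toy values, shows the core idea: each participant splits its private update into random shares, so no single party ever sees an individual value, yet the shares still sum to the true total:

```python
import random

def make_shares(value, n_shares):
    """Split `value` into n additive shares that sum back to it.

    Any subset of fewer than n shares looks random; only the full
    sum across all shares recovers information.
    """
    shares = [random.uniform(-1e6, 1e6) for _ in range(n_shares - 1)]
    shares.append(value - sum(shares))
    return shares

# Three participants, each holding a private update value.
private_updates = [0.5, -0.2, 0.9]
n = len(private_updates)

# Each participant splits its value and sends one share to each peer.
all_shares = [make_shares(v, n) for v in private_updates]

# Each participant sums the shares it received (one per peer) and
# publishes only that partial sum, which reveals no single input.
partial_sums = [sum(all_shares[p][i] for p in range(n)) for i in range(n)]

# The public total equals the true sum of the private updates.
total = sum(partial_sums)
```

Real protocols add integer arithmetic over finite fields, authentication, and dropout handling, but this is the arithmetic trick at their core.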
Use Case: FLoCs
It helps to look at a practical use case. We're going to look at FLoCs, which is Federated Learning of Cohorts, that's also from Google. This Federated Learning of Cohorts was a proposal, and I think it's active on some Chrome browsers already. It's essentially a proposal to say, let's do away with third-party cookies, because they're awful, and nobody likes them. They are a privacy nightmare. Many different browsers are removing functionality for third-party cookies or automatically blocking third-party cookies. Let's do away with third-party cookies, but what should we use in order to do targeted, personalized advertising if we don't have third-party cookies? That's what Google proposed FLoC would do. Their idea for FLoC is that you get assigned a cohort based on maybe something you like or different browsing preferences, browsing history, and so on. Here, we have a plum group, and we have an orange group. We have these two different cohorts. Perhaps, if you were a fruit seller, you might say, I want to target this plum cohort with plum ads and this orange cohort with orange ads, for example.
Use Case: FLoCs - How it Works
Let's take a look at how it works, and what the federated learning part of it is. The first thing that happens when you open up a FLoC-enabled Chrome browser is that you get assigned a cohort if you don't already have one. This is probably an iterative process. Those details weren't shared, but it's probably an iterative process. It's likely that my cohort will change over time, because, obviously, my interests and my browsing change over time. I start browsing as a cohort, and let's say I'm now in the plum cohort, and it turns out via federated learning that we find that five different users visited plums.com today. Then that small piece of information could be sent as an update to the aggregator, which could then update the targeted ad model, and the process continues.
As you can see, this is not typical federated learning, because the ad targeting model is likely not going to be shipped out and running on all of our devices. This is a hybrid model, where some parts are centralized and some parts are federated. The cohort model most likely is shipped out, and so that is also something that will probably be iterated across cohorts over time, and sent and updated. We can see how this process works. The goal was to resolve the privacy problems, and also the poor user experience, of online targeted ads, where sometimes you would click on something and it would follow you for days.
We can ask, will this work? Does this address those problems? It will probably lead to better ads, because it's an iterative process. You're not forever in a bucket based on something that you liked five weeks ago, so from a data perspective that will most likely change. However, from a privacy perspective, does it guarantee us something? Is my membership in plum lovers private to me, or can this be public in some way? This is a higher level philosophical question. It's interesting to think about what qualities you would want, or be comfortable with, being aggregated for your cohorts. Google already said that medical history and sexual preferences would not be included, but are those the only categories?
Use Case: FLoCs - Privacy Attacks
In addition, the Electronic Frontier Foundation, as well as several other prominent technologists, came out and critiqued the FLoC approach, because it's fairly easy to discern potential privacy attacks. Such as: someone using FLoC arrives on my website, and I present them with a login screen. Now I have both their cohort and their login. Or I use browser fingerprinting, and now I have the browser fingerprint and their cohort. These combined, again, when we think of cohorts as a few thousand people, create an environment where it might be fairly easy to deanonymize or re-identify those individuals. Think about collecting a massive database, let's say we had a large presence, or we aggregated across many different partners or other properties. We could potentially create a database full of user IDs and their cohorts, full of browser fingerprints and their cohorts, or full of a combination of those, which might even allow us to reverse engineer cohort selection, as well as potentially to reverse engineer even who those individuals are from the cohorts. It's not a silver bullet for privacy.
Federated Learning: Benefits and Weaknesses
Let's think about the benefits and weaknesses of federated learning. One big benefit is that there's no centralized data collection, so I don't have to send and centralize the data all in one place. Instead, the data remains on the device, and I can use federated learning to get small updates to send to the aggregators. A weakness of this is that the data must be standardized on the device. When we had traditional machine learning, we had a large pipeline. We could change and clean and segregate data in different ways. When it's on the device, it needs to be ready for machine learning, or it needs to be quickly ready for machine learning. This is great if you've already standardized data on your devices, but maybe it would be hard if you haven't thought about that problem yet.
The second huge benefit is more diverse datasets. We have far more diverse data when we expand our reach to many different edge devices. That's one of the reasons Google even considered federated learning: how do we get better targeted improvements for particular types of machine learning that we're trying to do? When you get that more diverse data from a more varied user base or population, you also have to consider that it's unevenly distributed. Look at the example of our scooters: if I'm riding a scooter in Berlin, and you're riding a scooter in a city where the commute works very differently, then your updates and mine might actually cancel each other out. It depends on the data science team thinking it through, obviously with other teams: what's the goal of this model? Are we going to be able to sub-select populations that allow us to diversify our data without creating a bias in and of itself because of the unevenly distributed data that's available?
Finally, a huge benefit that I think is an undervalued benefit of federated learning is on-device machine learning. How do we make real-time decisions? We can also even personalize the model on the device in certain ways when we think about federated learning that allows for particularly smaller updates to smaller groups. This on-device machine learning, the speed, and the ability for us to answer those questions on the device is huge. A weakness of it if it wasn't already clear to you is that all of those devices must have the model or have whatever models they are working on. If there's proprietary or confidential information in the model, then one should think about the appropriate security concerns there.
Finally, privacy is often touted as a benefit of federated learning. I'm not here to say that it isn't. It definitely is. It's better to send a small update than it is to send the entire browsing history, or something like that, to a centralized location. However, those small updates can leak information. There are several other types of privacy attacks, like the one that we just talked about for FLoCs. The privacy benefits and guarantees are very dependent on the implementation. If you're really choosing federated learning for its privacy benefits, then you should also think about employing techniques such as secure computation of the aggregates, or adding differential privacy via gradient clipping and noise, or any other techniques that are common when thinking about the privacy of those updates and what guarantees you'd like to have in the system you're designing.
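As a rough illustration of that last point, here is what gradient clipping with added noise might look like on a single device's update before it leaves the device. The clip bound and noise scale below are arbitrary placeholders; a real deployment would calibrate the noise with proper differential privacy accounting:

```python
import random

def clip_and_noise(update, clip_norm=1.0, noise_std=0.1):
    """Bound one user's influence, then blur it.

    1. Clip the update to an L2 norm of at most `clip_norm`, so no
       single device can dominate the aggregate.
    2. Add Gaussian noise scaled to that bound, so the exact update
       (and the data behind it) is harder to reconstruct.
    """
    norm = sum(x * x for x in update) ** 0.5
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    clipped = [x * scale for x in update]
    return [x + random.gauss(0.0, noise_std) for x in clipped]

raw_update = [3.0, 4.0]   # L2 norm 5.0, well above the clip bound
private_update = clip_and_noise(raw_update, clip_norm=1.0, noise_std=0.1)
```

Note that clipping alone (before noise) already guarantees the transmitted update has norm at most 1.0; the noise is what turns that bound into a formal privacy guarantee when the right accounting is applied.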
Federated Learning: When Does It Make Sense?
When does it make sense to do machine learning in a federated way? My first question to you is: if traditional machine learning works, then you should probably use it. Federated learning is super cool, but it's also still super new. If you're looking for a tried-and-true, 100% guaranteed solution, it's probably not it, unless you're working at Google, or one of the other places that has been running federated learning for a long time. Maybe traditional machine learning doesn't work, and you're excited to think about federated learning. My second question for you would be: is your data standardized? Can the devices run federated learning software? I have a list at the end of a bunch of different open source libraries to play around with and get a feel for federated learning software. Can the devices run that software? Is the data standardized? If not, I would recommend that you try distributed data analysis first, which is essentially asking questions of, or querying, the data on the devices, and then sending just the aggregate summations, or counts, or things like this, back to a central location. Then have your data team, or if you're a data person yourself, take a look at it and start answering the questions. Eventually, hopefully, you can roll out a fully federated system.
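As a sketch of what that intermediate distributed data analysis step could look like, the toy example below has each device answer a query locally and return only a tiny aggregate, which the central side then combines. The scooter-style readings and function names are invented for illustration:

```python
# Each list stands in for data that never leaves one device,
# e.g. per-ride vibration readings on a scooter.
device_data = [
    [12.1, 14.0, 9.5],        # device 1
    [22.3, 18.7],             # device 2
    [10.0, 11.2, 13.3, 9.9],  # device 3
]

def local_aggregate(readings, threshold=15.0):
    """Runs on the device: return only small summary statistics."""
    return {
        "n": len(readings),
        "over_threshold": sum(1 for r in readings if r > threshold),
        "total": sum(readings),
    }

# Only these tiny summaries travel over the network.
summaries = [local_aggregate(d) for d in device_data]

# Central side: combine the aggregates to answer the question,
# without ever seeing a raw reading.
n = sum(s["n"] for s in summaries)
over = sum(s["over_threshold"] for s in summaries)
mean = sum(s["total"] for s in summaries) / n
```

Getting comfortable with this pattern, and standardizing the on-device data it requires, is a natural stepping stone before shipping full model training to the devices.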
If the data is standardized, and the devices can run the software, then the final decision for you to make is how you want to do aggregations. Do you want to do them centralized or clustered? I would recommend centralized: it's going to have less complexity and be simpler to start with. If you already have a clustered setup, or if your system is already designed in that manner, maybe clustered is easier for you to go to. That's up to you, how you want to structure that and what questions and requirements you have to make that final decision. Some other factors to consider are the privacy requirements and guarantees, which again are not automatically built in and are something that you have to think about. Please talk with privacy and security experts on your team. I'm also a privacy and security expert, so should you ever want to reach out, please let me know. Then, there are also connectivity requirements. What are you going to do when a device is lost? Is there a certain tipping point, when a particular number of devices are lost, where you want to actually halt the round and stop the process? These are questions that you have to think about and should think about, again, combining the knowledge from your software and infrastructure teams with the knowledge from your data team.
The Promise of Distributed Data Science and Analytics
Finally, I strongly believe that the future of machine learning and data science and analytics is distributed. When we think about the problems that we face as humans, as parts of our many societies and our one world, we can think about many situations where, if we had intelligent edge devices, we could probably help nudge things in positive directions. Think about farming, and climate-related changes to farming, and the need for food in our growing world. Think about Coronavirus, and other viruses, and how we can track their progress across the globe. Think about climate-related natural disasters, such as flooding, hurricanes, and so on. For all of these, we can imagine edge devices that will be appropriately positioned to ask and answer questions, potentially to help create earlier warning systems, and to help create better preparation for folks in our governments, folks working in NGOs, folks trying to solve these problems. Instead, again, of going and trying to collect all data possible and figuring it out later, we're thinking about the problem and actively asking: what data do we need to solve this problem? How do we go get it?
I think that's, honestly, a much more intelligent, and also much more privacy aware, way to solve these problems: what problem am I trying to solve? How can I use edge devices to help me along the way? Is this a good fit for distributed data analysis or data science, as well as potentially federated learning? Can I solve this problem this way, rather than randomly collecting data and hoping we get to the outcomes? I hope you'll give federated learning a try, or start thinking about how to distribute your data science and collectivise it, and ask and answer the right questions, and ensure you're collecting the right data from the right places.
Questions and Answers
Fedorov: You've mentioned the standardization for the data that is required for the federated learning techniques, would you be able to dig in a little bit more into what kind of standardization? Maybe you have any examples. What's the amount of effort that will be expected?
Jarmul: It can vary greatly depending on your machine learning use case. It's maybe important to define there how complex you'd like it to be. Essentially, the data needs to be built into something that looks like either a matrix or a vector. If you're familiar with multi-dimensional arrays, this is usually how we would prepare data. Think about an image: if we wanted to analyze video of you and myself and find where the faces are, that would be broken down into many different matrices, usually with red, green, and blue channels. Each pixel has three channels. Those are all in an array. Then the array is in an N-dimensional space; I think for photos it's still two-dimensional.
All that has to happen on-device. For example, even with what Google does on Android, some older devices just do not have the same amount of compute power, so they are not pre-selected for rounds. If you have a newer device, you're probably in more rounds; if you have an older device, you're probably left out of more rounds. The data has to be in some form like that, or has to be made available in a form like that, in, again, these arrays, so that it can be easily used. You can imagine that if you are building your own Android app and you want the data, then somewhere in your software the data has to look like what your machine learning algorithm expects.
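As a toy illustration of that standardization step, assuming NumPy is available, the snippet below turns a tiny made-up RGB image, stored as the height x width x 3 array of per-pixel channels described above, into the fixed floating-point form a model would typically expect:

```python
import numpy as np

# A made-up 2x2 RGB "image": each pixel is three 0-255 channel values.
image = np.array(
    [[[255, 0, 0], [0, 255, 0]],
     [[0, 0, 255], [128, 128, 128]]],
    dtype=np.uint8,
)

# rows x columns x channels: the array shape every device must agree on.
assert image.shape == (2, 2, 3)

# Standardize on-device: convert to float32 and scale into [0, 1],
# producing the predictable format a federated round can rely on.
standardized = image.astype(np.float32) / 255.0
```

The exact scaling and layout are a design choice, but the point stands: every participating device has to produce the same shape and dtype before any shared model can train on its data.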
Fedorov: A slightly different way of thinking about collecting the data. If today, the approach is basically to log data, somewhat generically, and then figure out later, what parts of the data is useful or not? Here, you really have to think first about your data logging and the structure of your data, is that the correct way to summarize it?
Jarmul: That is a very good way to summarize it. It involves us using our intelligence first before we try to use artificial intelligence to help us.
Fedorov: I can imagine that could be a pretty significant mental shift for many applications.
Jarmul: Yes, definitely. I think it also puts the onus on the data team to think a little bit more about the problem they're trying to solve. Unfortunately, a common experience of data teams is that they have all this data, but none of it answers the question that they want to ask. Maybe this is also helpful in terms of thinking about asking the questions first, and then going about collecting the right data or the fitting data.
Fedorov: I think you mentioned about the capabilities of the device. The fact that the devices have to meet specific minimal requirements, in order to log information, and also to process that. Are there any specific data references, or the idea for what devices would be able to run the workloads, like in today's environment? How is it shifting over time? How do you see the evolution of enabling more devices to run federated learning?
Jarmul: It's pretty interesting to watch what's happening in the iOS world, because obviously Apple is starting to ship really quite performant chips, even chips designed for machine learning, or at least designed with machine learning in mind, to devices. Once we have mini GPUs in our devices, we can do all sorts of crazy machine learning ourselves. Really, the reason GPUs are so popular in machine learning is that we're talking about vectorized math. When you're looking at chips or devices, you want to ask the question: if I were to perform vectorized math operations, would these be easily available? Then, what type of RAM do I have? What type of memory do I have available? That probably answers 80% of whether the device will work appropriately, although devices are getting more intelligent over time.
Fedorov: Any idea in terms of the CPU or memory requirements, or the video cores, that would be the minimum needed for a realistic use case, if you have any data points like that?
Jarmul: It's interesting, because quite a lot of the federated learning libraries are in the deep learning space. That's by no means a requirement. If you look at the open source libraries that I was able to find and share, there are even a few that use more classic machine learning models. Those would obviously have much smaller memory requirements. Then, one good question to ask yourself is: do I actually need video processing, or can I just take a quick shot of an image and make a decision based on a few screen grabs? It would probably be good to have folks here who understand edge systems well talk directly with the machine learning team and see if there's a good compromise there, because probably a lot more is possible, but the machine learning team will inevitably ask for the most complicated model that is possible.
Fedorov: That generally makes sense. You want to get the most of a particular logic.
How much commonality do you see between iOS and Android? Generally, do they allow similar ML capabilities with a standard interface?
Jarmul: iOS, probably because of the extremely standardized and controlled approach to the hardware, is ahead of where Android is in terms of machine learning. With iOS, you can directly build deep learning models without switching languages or doing anything extra. Apple is quite advanced at shipping models directly to your phone. If you use an Apple device, they actually already have differential privacy built into the types of distributed data analysis they do. They do quite a lot of data analysis on keyboard usage, and I think some gesture support, and so forth. iOS is far more advanced if you take it as a whole. My guess is that some of the newer Samsung devices, and other folks at the cutting edge of Android, are also going to be shipping with these special chips. As soon as we start to have these chips designed for vectorized math, or for performing massive amounts of vectorized computations at once, this is where we'll see the surge of Android availability of these things. With Google, you can run TensorFlow mobile on Android devices as well. The libraries are there to help. It's just a matter of whether the hardware can keep up.
Fedorov: Basically, there are some APIs and the interfaces, but developers need to be aware whether the device can actually support the runtime of the model.
Jarmul: Yes, your app probably wouldn't be very popular if all of a sudden all of the phones started shutting off. If you read the Google paper on federated learning on Android, in case you're targeting Android devices, there are some good, interesting tidbits in the details about how they selected devices. Namely, on WiFi, plugged into the wall, and with a certain Android version.
Fedorov: What role does the metadata play with standardization for aggregation pipelines?
Jarmul: Metadata, depending on whether you're creating it or the device is creating it, can always be pulled in and used by machine learning models. It is quite common for us to attach metadata as part of what we would call the feature space, which is how we define what we're going to take and show to the model. In an aggregation context, depending on the type of system, this metadata is often dropped. It's actually something that is much easier, I think, to get in real time at the location where it's generated, so on the device itself. For example, especially if you're thinking about intelligent gestures, or other things like this, obviously the movement and the tilt and so on of the device can be really helpful in figuring out what's happening. It also then doesn't have to be encoded and potentially lost when it's sent to a centralized aggregator.
Fedorov: Do you know of anyone who has attempted to apply federated learning techniques in the cybersecurity space?
Jarmul: Actually, yes. I was happy to be a part of a distributed machine learning that was based on a cryptographic approach that allowed for several large network operators to compare notes on essentially malware. It's like, can we actually find via DNS and IP, locations of malware running across large networks? This is just one approach that I know. I think that there's probably more outside of my experience. I think actually cybersecurity is one of the best places for distributed learning, because it often has the most proprietary and sensitive data. It's highly unlikely, for example, that all these large telecoms would get together and actively share a bunch of their transit data. That's not going to happen. Nor would we want it to happen. Choosing a distributed approach is a really good way to get multiple, large parties together to answer questions. That's why I think for cybersecurity, it's actually quite fitting.
Fedorov: One more question is about the cases with running the logic on the device itself. The big challenge is the connectivity, because if the device goes offline, what should we do in terms of running and executing on the logic, and when should we stop? What should the aggregator do in a case where one of the edge nodes is offline? Do you have any suggestions, any advice on that?
Jarmul: This is an incredibly hard problem just in distributed computing. Edge folks, you all probably know much more than me, but I can share what I do know, which is that we're always going to have the straggler problem. Some of the most intelligent approaches I've seen to that have been included in Flink's design, so Apache Flink. It is a streaming data pipeline, and the way they've designed the different types of windowing requirements for responses is essentially intelligently selecting when we say we give up on the last stragglers and move on with the computation. There are whole algorithms and theories and approaches here. From what I've seen and what I've heard, Flink has quite an interesting approach. It's worth actually reading through the documentation on the different algorithms chosen and used.
Then, Beam uses some of those same things. If you're familiar with Beam as well, you've probably seen some of the same windowing. At some point in time, you have to just decide, ok, we don't wait any longer for any more updates or responses. We've obviously either lost that device, or the phone has decided that the computation is going to last too long for us to wait. Then you need to appropriately size your rounds with the knowledge that you will have some amount of dropout. What might make sense first is deploying something federated and just observing dropout for a while without trying to learn, then actually deploying the models you want to train, so that you can first measure, or better learn, the probability of a device dropping, and choose devices accordingly.
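To sketch the round-sizing idea, here is a toy version of a windowed round: accept whatever updates arrive before the deadline, then close the round only if enough of the selected devices reported back. The quorum fraction and device counts are invented for illustration:

```python
def close_round(received_updates, selected_count, quorum=0.6):
    """Decide what to do once the round's window expires.

    received_updates: updates that arrived before the deadline
    selected_count: how many devices were selected for the round
    quorum: minimum fraction of devices that must report back
            for the round to count
    """
    if len(received_updates) / selected_count >= quorum:
        return True, received_updates   # aggregate what we have
    return False, []                    # halt: too many stragglers

# 10 devices selected; only 7 report back before the window closes.
arrived = [f"update_{i}" for i in range(7)]
ok, usable = close_round(arrived, selected_count=10)
```

Oversizing the selection relative to the quorum, based on the observed dropout probability, is what keeps rounds completing despite stragglers.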