Federated Learning: як навчати AI без передачі даних
Революція в машинному навчанні — модель їде до даних, а не дані до моделі
Уявіть: ви хочете навчити модель розпізнавання хвороб на медичних знімках. Вам потрібні дані з 100 лікарень. Але жодна лікарня не може передати вам дані пацієнтів — це порушення GDPR, HIPAA, та елементарної етики.
Традиційний підхід Machine Learning вимагає централізації даних. Всі дані збираються в одне місце, там відбувається навчання. Це створює колосальні проблеми:
- Приватність: Персональні дані покидають контрольовану зону
- Регуляції: GDPR, HIPAA, CCPA забороняють такі трансфери
- Безпека: Централізоване сховище — ідеальна мішень для атак
- Bandwidth: Передача терабайтів даних — повільно і дорого
Революційна ідея
Federated Learning перевертає парадигму: замість «дані їдуть до моделі» — «модель їде до даних». Кожен учасник навчає модель локально на своїх даних, потім лише оновлення ваг передаються на сервер для агрегації.
Архітектура федеративного навчання
Раунд федеративного навчання
- Broadcast: Сервер надсилає поточну глобальну модель всім клієнтам
- Local Training: Кожен клієнт навчає модель на своїх локальних даних
- Upload: Клієнти надсилають оновлення ваг (не дані!) на сервер
- Aggregation: Сервер агрегує оновлення в нову глобальну модель
- Repeat: Процес повторюється до збіжності
FedAvg: класичний алгоритм
Federated Averaging (FedAvg) — перший і досі найпопулярніший алгоритм федеративного навчання, запропонований Google у 2017 році.
Wt+1 = Σk=1K (nk / n) · Wkt+1
де:
• Wt+1 — нові глобальні ваги
• K — кількість клієнтів
• nk — кількість зразків на клієнті k
• n — загальна кількість зразків
• Wkt+1 — локальні ваги клієнта k після навчання
Реалізація FedAvg на Python
import numpy as np
from typing import List, Tuple
class FedAvgServer:
"""Сервер для федеративного навчання з FedAvg."""
def __init__(self, initial_model: np.ndarray):
self.global_model = initial_model.copy()
self.round_number = 0
def aggregate(
self,
client_updates: List[Tuple[np.ndarray, int]]
) -> np.ndarray:
"""
Агрегація оновлень від клієнтів.
Args:
client_updates: список (weights, n_samples) від кожного клієнта
Returns:
Нові глобальні ваги
"""
total_samples = sum(n for _, n in client_updates)
# Зважене усереднення
new_weights = np.zeros_like(self.global_model)
for weights, n_samples in client_updates:
weight_factor = n_samples / total_samples
new_weights += weight_factor * weights
self.global_model = new_weights
self.round_number += 1
return self.global_model
class FedAvgClient:
"""Клієнт для федеративного навчання."""
def __init__(self, client_id: str, local_data: np.ndarray):
self.client_id = client_id
self.local_data = local_data
self.model = None
def receive_model(self, global_model: np.ndarray):
"""Отримання глобальної моделі від сервера."""
self.model = global_model.copy()
def train_local(
self,
epochs: int = 5,
learning_rate: float = 0.01
) -> Tuple[np.ndarray, int]:
"""
Локальне навчання на власних даних.
Returns:
(оновлені ваги, кількість зразків)
"""
# Тут має бути реальне навчання моделі
# Спрощена імітація:
for epoch in range(epochs):
gradient = self._compute_gradient()
self.model -= learning_rate * gradient
return self.model, len(self.local_data)
def _compute_gradient(self) -> np.ndarray:
"""Обчислення градієнту на локальних даних."""
# Спрощена імітація
return np.random.randn(*self.model.shape) * 0.1
# Приклад використання
def run_federated_training(
server: FedAvgServer,
clients: List[FedAvgClient],
num_rounds: int = 10
):
for round_num in range(num_rounds):
print(f"Round {round_num + 1}/{num_rounds}")
# 1. Broadcast глобальної моделі
for client in clients:
client.receive_model(server.global_model)
# 2. Локальне навчання
updates = []
for client in clients:
weights, n_samples = client.train_local()
updates.append((weights, n_samples))
# 3. Агрегація
server.aggregate(updates)
print(f" Агреговано {len(updates)} оновлень")
Проблема Non-IID даних
FedAvg чудово працює, коли дані на клієнтах схожі (IID — Independent and Identically Distributed). Але в реальності це рідкість:
Кожен клієнт має репрезентативну вибірку всіх класів.
Приклад: кожна лікарня має пацієнтів з усіма хворобами у схожих пропорціяхДані на клієнтах суттєво відрізняються за розподілом.
Приклад: одна лікарня спеціалізується на кардіології, інша — на онкологіїFedProx: вирішення Non-IID
FedProx додає проксимальний терм до функції втрат, який не дозволяє локальній моделі занадто відхилятися від глобальної:
hk(w; wt) = Fk(w) + (μ/2) · ||w - wt||²
де μ — гіперпараметр регуляризації (зазвичай 0.001 - 0.1)
class FedProxClient(FedAvgClient):
"""Клієнт з підтримкою FedProx регуляризації."""
def __init__(self, client_id: str, local_data: np.ndarray, mu: float = 0.01):
super().__init__(client_id, local_data)
self.mu = mu
self.global_model_copy = None
def receive_model(self, global_model: np.ndarray):
super().receive_model(global_model)
# Зберігаємо копію для проксимального терму
self.global_model_copy = global_model.copy()
def _compute_loss(self, predictions, targets) -> float:
"""Обчислення loss з проксимальним термом."""
base_loss = self._cross_entropy(predictions, targets)
# Проксимальний терм
proximal_term = (self.mu / 2) * np.sum(
(self.model - self.global_model_copy) ** 2
)
return base_loss + proximal_term
Диференційна приватність
Навіть без передачі даних, оновлення ваг можуть розкривати інформацію. Атаки типу Membership Inference або Model Inversion можуть відновити чутливі дані з градієнтів.
Атака Model Inversion
Зловмисник, маючи доступ до градієнтів, може реконструювати вхідні дані. У 2020 році дослідники продемонстрували відновлення фотографій обличь з градієнтів з точністю 95%.
Differential Privacy (DP)
Рішення — додавати контрольований шум до градієнтів, забезпечуючи математичні гарантії приватності:
import numpy as np
def add_gaussian_noise(
gradients: np.ndarray,
sensitivity: float,
epsilon: float,
delta: float
) -> np.ndarray:
"""
Додавання гауссівського шуму для (ε, δ)-differential privacy.
Args:
gradients: оригінальні градієнти
sensitivity: чутливість функції (max зміна при зміні одного запису)
epsilon: параметр приватності (менше = більше приватності)
delta: ймовірність витоку
Returns:
Градієнти з доданим шумом
"""
# Обчислення стандартного відхилення шуму
sigma = sensitivity * np.sqrt(2 * np.log(1.25 / delta)) / epsilon
# Генерація та додавання шуму
noise = np.random.normal(0, sigma, gradients.shape)
return gradients + noise
def clip_gradients(
gradients: np.ndarray,
max_norm: float
) -> np.ndarray:
"""
Gradient clipping для обмеження чутливості.
"""
norm = np.linalg.norm(gradients)
if norm > max_norm:
gradients = gradients * (max_norm / norm)
return gradients
class DPFedAvgClient(FedAvgClient):
"""Клієнт з Differential Privacy."""
def __init__(
self,
client_id: str,
local_data: np.ndarray,
epsilon: float = 1.0,
delta: float = 1e-5,
max_grad_norm: float = 1.0
):
super().__init__(client_id, local_data)
self.epsilon = epsilon
self.delta = delta
self.max_grad_norm = max_grad_norm
def train_local(self, epochs: int = 5, learning_rate: float = 0.01):
for epoch in range(epochs):
gradient = self._compute_gradient()
# 1. Gradient clipping
gradient = clip_gradients(gradient, self.max_grad_norm)
# 2. Додавання шуму
gradient = add_gaussian_noise(
gradient,
sensitivity=self.max_grad_norm,
epsilon=self.epsilon / epochs, # Композиція бюджету
delta=self.delta / epochs
)
self.model -= learning_rate * gradient
return self.model, len(self.local_data)
Secure Aggregation
Secure Aggregation — криптографічний протокол, що дозволяє серверу обчислювати суму оновлень, не бачачи індивідуальних оновлень жодного клієнта.
import hashlib
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
class SecureAggregationClient:
"""Спрощена реалізація Secure Aggregation."""
def __init__(self, client_id: int, total_clients: int):
self.client_id = client_id
self.total_clients = total_clients
self.pairwise_masks = {}
def generate_pairwise_mask(
self,
other_client_id: int,
shared_secret: bytes,
shape: tuple
) -> np.ndarray:
"""Генерація маски на основі спільного секрету."""
# Детермінований seed з секрету
seed = int.from_bytes(
hashlib.sha256(shared_secret).digest()[:4],
'big'
)
rng = np.random.RandomState(seed)
mask = rng.randn(*shape)
# Визначаємо знак: менший ID додає, більший віднімає
if self.client_id > other_client_id:
mask = -mask
self.pairwise_masks[other_client_id] = mask
return mask
def mask_update(self, weights: np.ndarray) -> np.ndarray:
"""Застосування всіх масок до оновлення."""
masked = weights.copy()
for mask in self.pairwise_masks.values():
masked += mask
return masked
Flower Framework: практична реалізація
Flower — найпопулярніший open-source фреймворк для федеративного навчання. Підтримує PyTorch, TensorFlow, JAX та будь-який ML framework.
# pip install flwr torch torchvision
import flwr as fl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from collections import OrderedDict
# Проста CNN для MNIST
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.pool = nn.MaxPool2d(2)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
self.relu = nn.ReLU()
def forward(self, x):
x = self.pool(self.relu(self.conv1(x)))
x = self.pool(self.relu(self.conv2(x)))
x = x.view(-1, 64 * 7 * 7)
x = self.relu(self.fc1(x))
return self.fc2(x)
# Flower Client
class MNISTClient(fl.client.NumPyClient):
def __init__(self, trainloader, testloader):
self.trainloader = trainloader
self.testloader = testloader
self.model = SimpleCNN()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
def get_parameters(self, config):
"""Повернення параметрів моделі."""
return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
def set_parameters(self, parameters):
"""Встановлення параметрів від сервера."""
params_dict = zip(self.model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
self.model.load_state_dict(state_dict, strict=True)
def fit(self, parameters, config):
"""Локальне навчання."""
self.set_parameters(parameters)
optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
self.model.train()
for epoch in range(config.get("local_epochs", 1)):
for images, labels in self.trainloader:
images, labels = images.to(self.device), labels.to(self.device)
optimizer.zero_grad()
loss = criterion(self.model(images), labels)
loss.backward()
optimizer.step()
return self.get_parameters(config), len(self.trainloader.dataset), {}
def evaluate(self, parameters, config):
"""Оцінка моделі."""
self.set_parameters(parameters)
criterion = nn.CrossEntropyLoss()
correct, total, loss = 0, 0, 0.0
self.model.eval()
with torch.no_grad():
for images, labels in self.testloader:
images, labels = images.to(self.device), labels.to(self.device)
outputs = self.model(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
return loss / len(self.testloader), total, {"accuracy": correct / total}
# Запуск клієнта
def start_client(client_id: int, num_clients: int = 10):
# Завантаження MNIST
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
trainset = datasets.MNIST("./data", train=True, download=True, transform=transform)
testset = datasets.MNIST("./data", train=False, transform=transform)
# Розподіл даних між клієнтами (симуляція Non-IID)
samples_per_client = len(trainset) // num_clients
start_idx = client_id * samples_per_client
end_idx = start_idx + samples_per_client
client_trainset = Subset(trainset, range(start_idx, end_idx))
trainloader = DataLoader(client_trainset, batch_size=32, shuffle=True)
testloader = DataLoader(testset, batch_size=32)
# Запуск Flower клієнта
client = MNISTClient(trainloader, testloader)
fl.client.start_numpy_client(
server_address="localhost:8080",
client=client
)
if __name__ == "__main__":
import sys
client_id = int(sys.argv[1]) if len(sys.argv) > 1 else 0
start_client(client_id)
Сервер Flower
import flwr as fl
# Стратегія FedAvg з кастомними параметрами
strategy = fl.server.strategy.FedAvg(
fraction_fit=0.5, # 50% клієнтів на раунд
fraction_evaluate=0.3, # 30% для оцінки
min_fit_clients=2, # Мінімум 2 клієнти
min_available_clients=3, # Чекати 3 клієнти
on_fit_config_fn=lambda r: {"local_epochs": 2}
)
# Запуск сервера
fl.server.start_server(
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=10),
strategy=strategy
)
Порівняння алгоритмів
| Алгоритм | Non-IID | Комунікація | Privacy | Складність |
|---|---|---|---|---|
| FedAvg | Середня | Низька | Базова | Проста |
| FedProx | Висока | Низька | Базова | Проста |
| SCAFFOLD | Висока | 2× FedAvg | Базова | Середня |
| FedAvg + DP | Середня | Низька | Сильна | Середня |
| SecAgg | Середня | Висока | Криптографічна | Складна |
Реальні застосування
Клавіатура Google навчається на ваших повідомленнях без передачі тексту на сервер. Предиктивний ввід покращується локально, лише агреговані оновлення йдуть в хмару.
Мільярди пристроїв, терабайти даних — все приватно.Лікарні спільно навчають моделі діагностики без обміну даними пацієнтів. NVIDIA FLARE використовується в проектах з 20+ медичних центрів.
Відповідність HIPAA та GDPR з коробки.Банки спільно навчають моделі виявлення шахрайства, не розкриваючи транзакції. Кожен банк бачить лише свої дані, але модель вчиться на патернах усіх.
Виявлення fraud на 40% точніше.Tesla та інші виробники використовують FL для навчання на даних з мільйонів автомобілів. Edge-пристрої навчаються локально, синхронізуючи лише оновлення.
Bandwidth економія: 1000× менше даних.Потрібна допомога з проектом?
Federated Learning -- складна та актуальна тема для дипломної чи курсової роботи. Наші ML-інженери допоможуть з реалізацією та документацією.
Замовити ML проектІдеї для курсової роботи
Зміст: Реалізація обох алгоритмів, експерименти з різними ступенями Non-IID (label skew, quantity skew), аналіз збіжності.
Технології: Python, PyTorch, Flower
Складність: Середня
Зміст: Реалізація DP-FL для класифікації медичних зображень, аналіз trade-off між приватністю та точністю.
Технології: Python, PyTorch, Opacus, Flower
Складність: Висока
Зміст: Реалізація Per-FedAvg або FedPer — алгоритмів, що створюють персоналізовані моделі для кожного клієнта.
Технології: Python, TensorFlow/PyTorch, Flower
Складність: Середня
Потрібна допомога з курсовою?
Федеративне навчання — одна з найактуальніших тем у ML. Ми допоможемо з реалізацією, експериментами та оформленням роботи.
Замовити курсову з Federated Learning