DALL-E 3 generates images from text. Stable Audio creates music. OpenAI's Sora turns prompts into video. Three different models, three different architectures, three separate training pipelines, three development teams, three sets of GPU clusters for training.
What if a single model could generate everything? Text-to-image, image-to-video, video-to-audio, audio-to-3D, text-to-anything, all in a single unified framework? Not through separate encoders and decoders, but through a truly shared representation in which the model "understands" the essence of the content regardless of its form.
Unified diffusion architectures are not just a technical optimization to save compute. They are an attempt to build a "unified theory of generation": a fundamental approach to multimodal AI in which a single model has genuine cross-modal understanding of the world.
The problem: isolated generative models
The current state of generative AI
Text-to-Image: CLIP encoder + U-Net + VAE decoder
Text-to-Video: VideoLDM + custom temporal attention layers
Text-to-Audio: AudioLDM + mel-spectrogram VAE
Text-to-3D: Score Distillation Sampling + NeRF
Image-to-Video: I2V adapters + temporal convolutions
Audio-to-Music: MusicGen + audio codec models
Each task is a separate architecture. Each architecture means separate training. Each training run means terabytes of data and weeks of compute.
Fundamental problems of isolation
1. Knowledge duplication
The text encoder in DALL-E 3 learned that "sunset over mountains" means an orange sky, mountain silhouettes, warm tones. The text encoder in Sora learned the same thing. AudioLDM learns yet again that "sunset" is associated with calm music. Three models, threefold duplication of the same conceptual knowledge.
2. No cross-modal transfer
A model that generates images of cats perfectly cannot help a model that generates cat sounds, even though conceptually a "cat" is one and the same object. The knowledge stays locked in isolated silos.
3. Inconsistent quality
DALL-E 3 generates photorealistic images. But when you need video, quality drops sharply. The audio for that video is worse still. The user ends up with "stitched-together" results of uneven quality.
4. Computational waste
Three separate models mean three transformer backbones, three sets of attention weights, and three embedding spaces, even though much of that knowledge overlaps.
The idea of unification
┌─────────────────────────────────────────────────────────┐
│ UNIFIED FOUNDATION MODEL │
│ │
│ Input (any modality) │
│ ↓ │
│ Universal Tokenization │
│ ↓ │
│ Shared Transformer Backbone │
│ ↓ │
│ Output (any modality) │
│ │
│ One model. One training. All modalities. │
└─────────────────────────────────────────────────────────┘
Diffusion Models: the technical foundation
Forward process (adding noise)
import torch
import torch.nn as nn
import numpy as np
from typing import Tuple
class DiffusionProcess:
"""
    Basic diffusion process: forward and reverse.
"""
def __init__(self, num_timesteps: int = 1000,
beta_start: float = 1e-4,
beta_end: float = 0.02):
"""
        num_timesteps: number of diffusion steps
beta_start, beta_end: noise schedule bounds
"""
self.T = num_timesteps
# Linear noise schedule
self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
self.alphas = 1 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Pre-compute for efficiency
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)
def forward_process(self, x_0: torch.Tensor, t: torch.Tensor,
noise: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward diffusion: x_0 → x_t
q(x_t | x_0) = N(x_t; √ᾱ_t × x_0, (1-ᾱ_t)I)
Closed form: x_t = √ᾱ_t × x_0 + √(1-ᾱ_t) × ε
"""
if noise is None:
noise = torch.randn_like(x_0)
        # Move schedule buffers to the data device before indexing with t
        sqrt_alpha = self.sqrt_alphas_cumprod.to(x_0.device)[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod.to(x_0.device)[t].view(-1, 1, 1, 1)
x_t = sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise
return x_t, noise
def reverse_step(self, model: nn.Module, x_t: torch.Tensor,
t: int, condition: torch.Tensor = None) -> torch.Tensor:
"""
Reverse diffusion step: x_t → x_{t-1}
p(x_{t-1} | x_t) = N(x_{t-1}; μ_θ(x_t, t), σ_t²I)
"""
        # Predict noise (broadcast the scalar timestep across the batch)
        t_tensor = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long)
        if condition is not None:
            predicted_noise = model(x_t, t_tensor, condition)
        else:
            predicted_noise = model(x_t, t_tensor)
# Compute mean
alpha = self.alphas[t]
alpha_cumprod = self.alphas_cumprod[t]
beta = self.betas[t]
# μ = (1/√α_t) × (x_t - (β_t/√(1-ᾱ_t)) × ε_θ)
coef1 = 1 / torch.sqrt(alpha)
coef2 = beta / self.sqrt_one_minus_alphas_cumprod[t]
mean = coef1 * (x_t - coef2 * predicted_noise)
# Add noise (except for t=0)
if t > 0:
noise = torch.randn_like(x_t)
sigma = torch.sqrt(beta)
x_prev = mean + sigma * noise
else:
x_prev = mean
return x_prev
@torch.no_grad()
def sample(self, model: nn.Module, shape: Tuple,
condition: torch.Tensor = None,
device: str = "cuda") -> torch.Tensor:
"""
Full reverse process: noise → clean sample.
"""
# Start from pure noise
x = torch.randn(shape, device=device)
# Iterate T → 0
        for t in reversed(range(self.T)):
            x = self.reverse_step(model, x, t, condition)
return x
class DiffusionTraining:
"""
Training diffusion models.
"""
def __init__(self, model: nn.Module, diffusion: DiffusionProcess):
self.model = model
self.diffusion = diffusion
def training_loss(self, x_0: torch.Tensor,
condition: torch.Tensor = None) -> torch.Tensor:
"""
Simple diffusion loss: predict noise.
L = E_{t,x_0,ε} [||ε - ε_θ(x_t, t)||²]
"""
batch_size = x_0.shape[0]
# Random timestep for each sample
t = torch.randint(0, self.diffusion.T, (batch_size,), device=x_0.device)
# Add noise
x_t, noise = self.diffusion.forward_process(x_0, t)
# Predict noise
if condition is not None:
predicted_noise = self.model(x_t, t, condition)
else:
predicted_noise = self.model(x_t, t)
# MSE loss
loss = nn.functional.mse_loss(predicted_noise, noise)
return loss
Unified Tokenization: the key to unification
The principle: tokenize everything
"""
Unified tokenization: converting any modality
into a sequence of discrete or continuous tokens.
"""
import torch
import torch.nn as nn
from einops import rearrange
from typing import Tuple
class UnifiedTokenizer(nn.Module):
    """
    Tokenize any modality into a common format.
    """
    def __init__(self, token_dim: int = 768, patch_size: int = 16):
        super().__init__()  # nn.Module so sub-encoders register their parameters
        self.token_dim = token_dim
        self.patch_size = patch_size
# Modality-specific encoders
self.image_encoder = ImagePatchEncoder(patch_size, token_dim)
self.video_encoder = VideoPatchEncoder(patch_size, token_dim)
self.audio_encoder = AudioSpectrogramEncoder(token_dim)
self.text_encoder = TextTokenEncoder(token_dim)
self.mesh_encoder = Mesh3DEncoder(token_dim)
def tokenize(self, data, modality: str) -> torch.Tensor:
"""
Convert any modality to unified token sequence.
Output shape: [batch, num_tokens, token_dim]
"""
if modality == "image":
return self.image_encoder(data)
elif modality == "video":
return self.video_encoder(data)
elif modality == "audio":
return self.audio_encoder(data)
elif modality == "text":
return self.text_encoder(data)
elif modality == "3d":
return self.mesh_encoder(data)
else:
raise ValueError(f"Unknown modality: {modality}")
class ImagePatchEncoder(nn.Module):
"""
Image → patches → tokens (ViT-style).
"""
def __init__(self, patch_size: int = 16, token_dim: int = 768):
super().__init__()
self.patch_size = patch_size
# Patch embedding: conv with kernel=stride=patch_size
self.patch_embed = nn.Conv2d(
3, token_dim,
kernel_size=patch_size,
stride=patch_size
)
# Learnable position embeddings
# Will be interpolated for different resolutions
self.pos_embed = nn.Parameter(torch.zeros(1, 256, token_dim))
def forward(self, images: torch.Tensor) -> torch.Tensor:
"""
images: [B, 3, H, W]
returns: [B, num_patches, token_dim]
"""
# Patch embedding
patches = self.patch_embed(images) # [B, D, H/P, W/P]
# Flatten spatial dimensions
B, D, H, W = patches.shape
tokens = rearrange(patches, 'b d h w -> b (h w) d')
# Add positional embeddings
# Interpolate if resolution differs from training
pos = self._interpolate_pos_embed(self.pos_embed, H * W)
tokens = tokens + pos
return tokens
def _interpolate_pos_embed(self, pos_embed: torch.Tensor,
num_patches: int) -> torch.Tensor:
"""Interpolate position embeddings for arbitrary sizes."""
if pos_embed.shape[1] == num_patches:
return pos_embed
# 2D interpolation
pos_embed_2d = rearrange(
pos_embed,
'1 (h w) d -> 1 d h w',
h=int(pos_embed.shape[1] ** 0.5)
)
new_size = int(num_patches ** 0.5)
pos_embed_2d = nn.functional.interpolate(
pos_embed_2d,
size=(new_size, new_size),
mode='bilinear',
align_corners=False
)
return rearrange(pos_embed_2d, '1 d h w -> 1 (h w) d')
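The patchification arithmetic above is worth seeing in isolation. A minimal standalone sketch (toy sizes chosen for illustration): a Conv2d with kernel = stride = patch_size turns a 224×224 image into (224/16)² = 196 tokens:

```python
import torch
import torch.nn as nn

# ViT-style patch embedding: one conv, kernel = stride = patch size
patch_size, token_dim = 16, 768
patch_embed = nn.Conv2d(3, token_dim, kernel_size=patch_size, stride=patch_size)

images = torch.randn(2, 3, 224, 224)
patches = patch_embed(images)                # [2, 768, 14, 14]
tokens = patches.flatten(2).transpose(1, 2)  # [2, 196, 768]
print(tokens.shape)  # torch.Size([2, 196, 768])
```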
class VideoPatchEncoder(nn.Module):
"""
Video → 3D patches → tokens.
"""
def __init__(self, patch_size: int = 16, token_dim: int = 768,
temporal_patch: int = 2):
super().__init__()
self.patch_size = patch_size
self.temporal_patch = temporal_patch
# 3D patch embedding
self.patch_embed = nn.Conv3d(
3, token_dim,
kernel_size=(temporal_patch, patch_size, patch_size),
stride=(temporal_patch, patch_size, patch_size)
)
self.pos_embed = nn.Parameter(torch.zeros(1, 1024, token_dim))
def forward(self, video: torch.Tensor) -> torch.Tensor:
"""
video: [B, C, T, H, W]
returns: [B, num_patches, token_dim]
"""
# 3D patch embedding
patches = self.patch_embed(video) # [B, D, T', H', W']
# Flatten
B, D, T, H, W = patches.shape
tokens = rearrange(patches, 'b d t h w -> b (t h w) d')
# Position embeddings
pos = self._interpolate_pos_embed_3d(self.pos_embed, T * H * W)
tokens = tokens + pos
return tokens
def _interpolate_pos_embed_3d(self, pos_embed, num_patches):
"""Interpolate 3D position embeddings."""
# Simplified: use learned or sinusoidal
return pos_embed[:, :num_patches, :]
class AudioSpectrogramEncoder(nn.Module):
"""
Audio → mel-spectrogram → patches → tokens.
"""
def __init__(self, token_dim: int = 768,
n_mels: int = 128,
patch_size: Tuple[int, int] = (16, 16)):
super().__init__()
        # Mel-spectrogram extraction; identity placeholder here.
        # In practice use torchaudio.transforms.MelSpectrogram.
        self.mel_transform = nn.Identity()
# Treat spectrogram as image
self.patch_embed = nn.Conv2d(
1, token_dim, # 1 channel for mono
kernel_size=patch_size,
stride=patch_size
)
def forward(self, audio: torch.Tensor) -> torch.Tensor:
"""
audio: [B, samples] or [B, 1, samples]
returns: [B, num_patches, token_dim]
"""
# Convert to mel-spectrogram
mel = self.mel_transform(audio) # [B, 1, n_mels, time]
# Patch embedding (same as image)
patches = self.patch_embed(mel)
# Flatten
tokens = rearrange(patches, 'b d h w -> b (h w) d')
return tokens
class Mesh3DEncoder(nn.Module):
"""
3D mesh/point cloud → tokens.
"""
def __init__(self, token_dim: int = 768, num_points: int = 2048):
super().__init__()
# Point cloud processing (PointNet-style)
self.point_embed = nn.Sequential(
nn.Linear(3, 64),
nn.ReLU(),
nn.Linear(64, 256),
nn.ReLU(),
nn.Linear(256, token_dim)
)
# Group points into tokens
self.num_groups = 64
self.points_per_group = num_points // self.num_groups
def forward(self, points: torch.Tensor) -> torch.Tensor:
"""
points: [B, N, 3] (xyz coordinates)
returns: [B, num_groups, token_dim]
"""
# Embed each point
point_features = self.point_embed(points) # [B, N, D]
# Group and pool (simplified)
B, N, D = point_features.shape
grouped = point_features.view(B, self.num_groups, -1, D)
tokens = grouped.mean(dim=2) # [B, num_groups, D]
return tokens
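The grouping-and-pooling step of Mesh3DEncoder can be sketched standalone (dimensions are illustrative, matching the defaults above): 2048 points split into 64 groups of 32 consecutive points, mean-pooled into one token per group:

```python
import torch

# 2048 points, 64 groups -> 32 points per group, one token per group
B, N, D = 2, 2048, 768
num_groups = 64
point_features = torch.randn(B, N, D)  # per-point embeddings

grouped = point_features.view(B, num_groups, -1, D)  # [2, 64, 32, 768]
tokens = grouped.mean(dim=2)                         # [2, 64, 768]
print(tokens.shape)  # torch.Size([2, 64, 768])
```

Note that `view` groups consecutive points; a real PointNet-style encoder would group by spatial neighborhoods (e.g. farthest-point sampling), which the "simplified" comment in the class glosses over.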
DiT: Diffusion Transformer Architecture
From U-Net to Transformer
"""
Diffusion Transformer (DiT): the base architecture behind Sora and other SOTA models.
"""
import torch
import torch.nn as nn
from einops import rearrange
import math
class DiTBlock(nn.Module):
"""
    A single DiT transformer block with adaptive normalization.
    Key feature: the timestep and condition influence the block through
    scale/shift parameters instead of simple concatenation.
"""
def __init__(self, hidden_dim: int, num_heads: int = 12,
mlp_ratio: float = 4.0, dropout: float = 0.0):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
self.norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
# Self-attention
self.attn = nn.MultiheadAttention(
hidden_dim, num_heads,
dropout=dropout,
batch_first=True
)
# MLP
mlp_hidden = int(hidden_dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(hidden_dim, mlp_hidden),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_hidden, hidden_dim),
nn.Dropout(dropout)
)
# AdaLN-Zero modulation
# Outputs: scale1, shift1, gate1, scale2, shift2, gate2
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_dim, 6 * hidden_dim)
)
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
"""
x: [B, N, D] — input tokens
c: [B, D] — conditioning (timestep + optional condition)
"""
# Get modulation parameters
modulation = self.adaLN_modulation(c)
scale1, shift1, gate1, scale2, shift2, gate2 = modulation.chunk(6, dim=-1)
# Expand for broadcasting: [B, D] → [B, 1, D]
scale1 = scale1.unsqueeze(1)
shift1 = shift1.unsqueeze(1)
gate1 = gate1.unsqueeze(1)
scale2 = scale2.unsqueeze(1)
shift2 = shift2.unsqueeze(1)
gate2 = gate2.unsqueeze(1)
# Self-attention with adaptive normalization
h = self.norm1(x) * (1 + scale1) + shift1
h, _ = self.attn(h, h, h)
x = x + gate1 * h
# MLP with adaptive normalization
h = self.norm2(x) * (1 + scale2) + shift2
h = self.mlp(h)
x = x + gate2 * h
return x
class DiT(nn.Module):
"""
Diffusion Transformer — scalable architecture for generation.
"""
def __init__(self, input_dim: int = 4, # VAE latent channels
hidden_dim: int = 1152,
num_heads: int = 16,
depth: int = 28,
patch_size: int = 2,
num_classes: int = 1000, # For class-conditional
learn_sigma: bool = True):
super().__init__()
self.patch_size = patch_size
self.hidden_dim = hidden_dim
self.learn_sigma = learn_sigma
# Patch embedding for latent
self.patch_embed = nn.Conv2d(
input_dim, hidden_dim,
kernel_size=patch_size, stride=patch_size
)
# Position embedding (will be resized at runtime)
self.pos_embed = nn.Parameter(torch.zeros(1, 256, hidden_dim))
# Timestep embedding
self.time_embed = TimestepEmbedder(hidden_dim)
# Class embedding (optional)
self.class_embed = nn.Embedding(num_classes, hidden_dim)
# Transformer blocks
self.blocks = nn.ModuleList([
DiTBlock(hidden_dim, num_heads)
for _ in range(depth)
])
# Final layer
self.final_layer = FinalLayer(hidden_dim, patch_size, input_dim * 2 if learn_sigma else input_dim)
self._init_weights()
def _init_weights(self):
"""Initialize with careful scheme for training stability."""
# Initialize embeddings
nn.init.normal_(self.pos_embed, std=0.02)
# Zero-initialize final layer for better training
nn.init.zeros_(self.final_layer.linear.weight)
nn.init.zeros_(self.final_layer.linear.bias)
def forward(self, x: torch.Tensor, t: torch.Tensor,
y: torch.Tensor = None) -> torch.Tensor:
"""
x: [B, C, H, W] — noisy latent
t: [B] — timesteps
y: [B] — class labels (optional)
"""
# Patch embedding
x = self.patch_embed(x) # [B, D, H/P, W/P]
x = rearrange(x, 'b d h w -> b (h w) d')
# Add position embedding
x = x + self.pos_embed[:, :x.shape[1], :]
# Conditioning: timestep + class
c = self.time_embed(t)
if y is not None:
c = c + self.class_embed(y)
# Transformer blocks
for block in self.blocks:
x = block(x, c)
# Final projection
x = self.final_layer(x, c)
# Unpatchify
x = self.unpatchify(x)
return x
def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
"""Convert tokens back to spatial format."""
# Assume square output
h = w = int(x.shape[1] ** 0.5)
c = x.shape[-1] // (self.patch_size ** 2)
x = rearrange(
x,
'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
h=h, w=w, p1=self.patch_size, p2=self.patch_size
)
return x
class TimestepEmbedder(nn.Module):
"""
Sinusoidal timestep embedding + MLP.
"""
def __init__(self, hidden_dim: int, frequency_dim: int = 256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim)
)
self.frequency_dim = frequency_dim
def forward(self, t: torch.Tensor) -> torch.Tensor:
"""t: [B] → [B, hidden_dim]"""
# Sinusoidal embedding
half_dim = self.frequency_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
emb = t[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# MLP
return self.mlp(emb)
class FinalLayer(nn.Module):
"""Final projection with adaptive normalization."""
def __init__(self, hidden_dim: int, patch_size: int, out_channels: int):
super().__init__()
self.norm = nn.LayerNorm(hidden_dim, elementwise_affine=False)
self.linear = nn.Linear(hidden_dim, patch_size * patch_size * out_channels)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_dim, 2 * hidden_dim)
)
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
scale, shift = self.adaLN_modulation(c).chunk(2, dim=-1)
x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x = self.linear(x)
return x
Transfusion: unifying text and diffusion
"""
Transfusion (Meta, 2024): unified training for text (autoregressive) and images (diffusion).
"""
import torch
import torch.nn as nn
from einops import rearrange
from typing import Optional, Tuple
class Transfusion(nn.Module):
"""
    Unified model for text and images.
Text: Autoregressive (next token prediction)
Images: Diffusion (denoising)
"""
def __init__(self, vocab_size: int = 32000,
image_tokens: int = 256, # Image token vocabulary
hidden_dim: int = 2048,
num_layers: int = 32,
num_heads: int = 16):
super().__init__()
self.hidden_dim = hidden_dim
self.vocab_size = vocab_size
self.image_token_offset = vocab_size # Image tokens after text tokens
# Text embedding
self.text_embed = nn.Embedding(vocab_size, hidden_dim)
# Image patch embedding (continuous)
self.image_embed = nn.Linear(4, hidden_dim) # VAE latent dim
# Timestep embedding (for diffusion)
self.time_embed = TimestepEmbedder(hidden_dim)
# Shared transformer
self.transformer = nn.ModuleList([
TransfusionBlock(hidden_dim, num_heads)
for _ in range(num_layers)
])
# Output heads
self.text_head = nn.Linear(hidden_dim, vocab_size) # For next token
self.image_head = nn.Linear(hidden_dim, 4) # For noise prediction
# Modality tokens
self.text_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
self.image_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
def forward(self, input_ids: Optional[torch.Tensor] = None,
image_latents: Optional[torch.Tensor] = None,
timesteps: Optional[torch.Tensor] = None,
mode: str = "text") -> Tuple[torch.Tensor, ...]:
"""
        Flexible forward for the different modes.
"""
if mode == "text":
return self._forward_text(input_ids)
elif mode == "image":
return self._forward_image(image_latents, timesteps)
elif mode == "interleaved":
return self._forward_interleaved(input_ids, image_latents, timesteps)
else:
raise ValueError(f"Unknown mode: {mode}")
def _forward_text(self, input_ids: torch.Tensor) -> torch.Tensor:
"""
Autoregressive text generation.
"""
B, L = input_ids.shape
# Embed text
x = self.text_embed(input_ids) # [B, L, D]
# Add modality token
x = torch.cat([self.text_token.expand(B, -1, -1), x], dim=1)
# Causal transformer
for layer in self.transformer:
x = layer(x, causal=True)
# Predict next tokens
logits = self.text_head(x[:, 1:, :]) # Remove modality token
return logits
def _forward_image(self, image_latents: torch.Tensor,
timesteps: torch.Tensor) -> torch.Tensor:
"""
Diffusion image generation/denoising.
"""
B = image_latents.shape[0]
# Flatten latents to patches
patches = rearrange(image_latents, 'b c h w -> b (h w) c')
# Embed patches
x = self.image_embed(patches) # [B, N, D]
# Add timestep conditioning
t_emb = self.time_embed(timesteps) # [B, D]
# Add modality token
x = torch.cat([self.image_token.expand(B, -1, -1), x], dim=1)
# Full attention (not causal for diffusion)
for layer in self.transformer:
x = layer(x, causal=False, condition=t_emb)
# Predict noise
noise_pred = self.image_head(x[:, 1:, :]) # [B, N, 4]
# Reshape back
h = w = int(noise_pred.shape[1] ** 0.5)
noise_pred = rearrange(noise_pred, 'b (h w) c -> b c h w', h=h, w=w)
return noise_pred
def _forward_interleaved(self, input_ids: torch.Tensor,
image_latents: torch.Tensor,
timesteps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Interleaved text and image (for multimodal generation).
"""
B = input_ids.shape[0]
# Embed text
text_emb = self.text_embed(input_ids) # [B, L_text, D]
# Embed images
patches = rearrange(image_latents, 'b c h w -> b (h w) c')
image_emb = self.image_embed(patches) # [B, L_img, D]
# Timestep for images
t_emb = self.time_embed(timesteps)
# Concatenate: [text] [image]
x = torch.cat([text_emb, image_emb], dim=1)
        # Transformer; the text part should use causal attention,
        # simplified here to full attention for both parts
for layer in self.transformer:
x = layer(x, causal=False, condition=t_emb) # Simplified
# Split outputs
L_text = input_ids.shape[1]
text_out = x[:, :L_text, :]
image_out = x[:, L_text:, :]
text_logits = self.text_head(text_out)
noise_pred = self.image_head(image_out)
noise_pred = rearrange(noise_pred, 'b (h w) c -> b c h w', h=int(noise_pred.shape[1]**0.5))
return text_logits, noise_pred
class TransfusionBlock(nn.Module):
"""
    Transformer block with optional causal masking and conditioning.
"""
def __init__(self, hidden_dim: int, num_heads: int):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_dim)
self.norm2 = nn.LayerNorm(hidden_dim)
self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
self.mlp = nn.Sequential(
nn.Linear(hidden_dim, 4 * hidden_dim),
nn.GELU(),
nn.Linear(4 * hidden_dim, hidden_dim)
)
# Conditioning projection
self.cond_proj = nn.Linear(hidden_dim, 2 * hidden_dim)
def forward(self, x: torch.Tensor, causal: bool = False,
condition: torch.Tensor = None) -> torch.Tensor:
# Optional conditioning via scale/shift
if condition is not None:
scale, shift = self.cond_proj(condition).chunk(2, dim=-1)
x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
# Attention
mask = self._get_causal_mask(x.shape[1], x.device) if causal else None
h = self.norm1(x)
h, _ = self.attn(h, h, h, attn_mask=mask)
x = x + h
# MLP
x = x + self.mlp(self.norm2(x))
return x
def _get_causal_mask(self, size: int, device) -> torch.Tensor:
mask = torch.triu(torch.ones(size, size, device=device), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
return mask
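The causal mask built in _get_causal_mask is additive: the -inf entries above the diagonal become exact zeros after softmax, so each position attends only to itself and earlier positions. A standalone sketch of that behavior:

```python
import torch

size = 4
mask = torch.triu(torch.ones(size, size), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))

# Uniform scores plus the mask: row i may attend only to positions <= i
scores = torch.zeros(size, size) + mask
attn = torch.softmax(scores, dim=-1)
print(attn)
# Row 0 attends only to position 0; row 3 attends uniformly to 0..3
```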
Multi-Modal Training Strategy
"""
Strategies for training unified multi-modal models.
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Dict, List
class UnifiedTrainer:
"""
    Train a unified diffusion model on multiple modalities.
"""
def __init__(self, model: nn.Module,
modalities: List[str],
loss_weights: Dict[str, float] = None):
self.model = model
self.modalities = modalities
self.loss_weights = loss_weights or {m: 1.0 for m in modalities}
self.optimizer = torch.optim.AdamW(
model.parameters(),
lr=1e-4,
weight_decay=0.05
)
def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
"""
        A single training step with multi-task balancing.
"""
self.optimizer.zero_grad()
total_loss = 0
losses = {}
for modality in self.modalities:
if modality not in batch:
continue
data = batch[modality]
# Compute modality-specific loss
if modality == "text":
loss = self._text_loss(data)
elif modality in ["image", "video", "audio"]:
loss = self._diffusion_loss(data, modality)
else:
continue
weighted_loss = self.loss_weights[modality] * loss
total_loss = total_loss + weighted_loss
losses[modality] = loss.item()
total_loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
return losses
def _text_loss(self, data: Dict) -> torch.Tensor:
"""Cross-entropy loss для autoregressive text."""
input_ids = data['input_ids']
labels = data['labels']
logits = self.model(input_ids=input_ids, mode="text")
loss = nn.functional.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1),
ignore_index=-100
)
return loss
def _diffusion_loss(self, data: Dict, modality: str) -> torch.Tensor:
"""MSE loss для diffusion denoising."""
x_0 = data['latents']
# Random timesteps
B = x_0.shape[0]
t = torch.randint(0, 1000, (B,), device=x_0.device)
# Add noise
noise = torch.randn_like(x_0)
x_t = self._add_noise(x_0, t, noise)
# Predict noise
pred_noise = self.model(
image_latents=x_t,
timesteps=t,
mode=modality
)
loss = nn.functional.mse_loss(pred_noise, noise)
return loss
def _add_noise(self, x_0, t, noise):
"""Simplified noise addition."""
alpha = 1 - t.float() / 1000
alpha = alpha.view(-1, 1, 1, 1)
return alpha ** 0.5 * x_0 + (1 - alpha) ** 0.5 * noise
class GradientBalancing:
"""
    Techniques for balancing gradients across modalities.
"""
@staticmethod
def gradnorm(losses: Dict[str, torch.Tensor],
weights: Dict[str, float],
shared_params: List[nn.Parameter],
alpha: float = 0.5) -> Dict[str, float]:
"""
        GradNorm: balance gradients by magnitude.
"""
# Compute gradient norms
grad_norms = {}
for task, loss in losses.items():
grads = torch.autograd.grad(loss, shared_params, retain_graph=True)
grad_norm = sum(g.norm() ** 2 for g in grads if g is not None) ** 0.5
grad_norms[task] = grad_norm
# Target: average gradient norm
avg_norm = sum(grad_norms.values()) / len(grad_norms)
# Update weights
new_weights = {}
for task, norm in grad_norms.items():
relative_loss = losses[task] / sum(losses.values())
target_norm = avg_norm * (relative_loss ** alpha)
new_weights[task] = weights[task] * (target_norm / (norm + 1e-8))
return new_weights
@staticmethod
def pcgrad(grads: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
PCGrad: Project conflicting gradients.
"""
task_names = list(grads.keys())
num_tasks = len(task_names)
# Stack gradients
grad_list = [grads[name].flatten() for name in task_names]
# Project conflicting gradients
for i in range(num_tasks):
for j in range(num_tasks):
if i == j:
continue
dot = torch.dot(grad_list[i], grad_list[j])
if dot < 0:
# Conflicting: project out component
grad_list[i] = grad_list[i] - dot * grad_list[j] / (grad_list[j].norm() ** 2 + 1e-8)
# Average projected gradients
return torch.stack(grad_list).mean(dim=0)
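The PCGrad projection step can be seen on two toy 2D gradients: when the dot product is negative, projecting one gradient onto the normal plane of the other removes exactly the conflicting component:

```python
import torch

# Two conflicting task gradients (negative dot product)
g1 = torch.tensor([1.0, 0.0])
g2 = torch.tensor([-1.0, 1.0])
assert torch.dot(g1, g2) < 0

# PCGrad: project g1 onto the normal plane of g2
dot = torch.dot(g1, g2)
g1_proj = g1 - dot * g2 / (g2.norm() ** 2)

# After projection, g1 no longer conflicts with g2
print(torch.dot(g1_proj, g2))  # ~0
```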
Benchmarks of unified models
| Model | Image FID | Video FVD | Audio FAD | Params | Training |
|-------|-----------|-----------|-----------|--------|----------|
| Specialized (ensemble) | 2.1 | 180 | 1.2 | 3 × 3B | 3 × weeks |
| DiT-XL (image only) | 2.3 | - | - | 675M | 1 week |
| Unified-IO 2 | 3.5 | 220 | 1.8 | 7B | 2 weeks |
| Transfusion | 2.8 | 195 | - | 7B | 2 weeks |
| OmniGen | 3.1 | 205 | 1.5 | 3.8B | 10 days |
Observation: unified models are approaching specialized ones in quality while using far fewer total parameters.
Research ideas
For bachelor's theses:
- Fine-tune DiT on a specific domain
- Compare U-Net vs DiT for image generation
- Visualize attention maps in DiT
For master's theses:
- Add a new modality to a unified model
- Task balancing strategies for multi-modal training
- Efficient inference for large DiT models
For PhD research:
- Theoretical foundations of unified representations
- Scaling laws for multi-modal diffusion
- Novel architectures for 4D+ data (video + audio + depth)
- Emergent capabilities in unified models
Conclusion: the path to AGI-style generation
GPT showed that scale plus a unified architecture yields the emergence of new capabilities. Generative AI is following the same path: instead of specialized models for each modality, unified foundation models for generating anything.
Sora, Gemini 2.0, and future systems will all be unified. One model that generates images, video, audio, 3D, and text, not by stitching together separate components but through genuine cross-modal understanding.
This is not a dream; it is a trajectory clearly visible in the research papers of 2023-2024. Those who understand unified architectures today will be building AGI-style systems tomorrow. For research in this area, the SKP-Degree team at skp-degree.com.ua is ready to help with implementing DiT-based architectures and multi-modal training. Consultations via Telegram: @kursovi_diplomy.
Unified diffusion, DiT, Diffusion Transformer, Transfusion, multi-modal generation, Sora architecture, foundation models, cross-modal AI: key terms for a bachelor's or master's thesis on generative artificial intelligence and multi-modal learning.