Технології AI Написано практикуючими розробниками

Unified Diffusion: одна модель для всіх модальностей — архітектура майбутнього генеративного AI

Оновлено: 18 хв читання 4 переглядів

DALL-E 3 генерує зображення з тексту. Stable Audio створює музику. Sora від OpenAI перетворює промпти у відео. Три різні моделі, три різні архітектури, три окремі training pipelines, три команди розробників, три набори GPU-кластерів для навчання.


DALL-E 3 генерує зображення з тексту. Stable Audio створює музику. Sora від OpenAI перетворює промпти у відео. Три різні моделі, три різні архітектури, три окремі training pipelines, три команди розробників, три набори GPU-кластерів для навчання.

А що якщо одна модель могла б генерувати все? Text-to-image, image-to-video, video-to-audio, audio-to-3D, text-to-anything — в єдиному уніфікованому framework? Не через окремі encoders та decoders, а через справді shared representation, де модель «розуміє» сутність контенту незалежно від форми.

Unified diffusion architectures — це не просто технічна оптимізація для економії compute. Це спроба побудувати «єдину теорію генерації» — фундаментальний підхід до мультимодального AI, де одна модель володіє справжнім cross-modal розумінням світу.


Проблема: ізольовані генеративні моделі

Поточний стан генеративного AI

Text-to-Image:    CLIP encoder + U-Net + VAE decoder
Text-to-Video:    VideoLDM + custom temporal attention layers
Text-to-Audio:    AudioLDM + mel-spectrogram VAE
Text-to-3D:       Score Distillation Sampling + NeRF
Image-to-Video:   I2V adapters + temporal convolutions
Audio-to-Music:   MusicGen + audio codec models

Кожна задача — окрема архітектура. Кожна архітектура — окреме навчання. Кожне навчання — терабайти даних та тижні compute.

Фундаментальні проблеми ізоляції

1. Дублювання знань

Text encoder в DALL-E 3 вивчив, що «sunset over mountains» означає помаранчеве небо, силуети гір, теплі відтінки. Text encoder в Sora вивчив те саме. AudioLDM знову вивчає, що «sunset» асоціюється з спокійною музикою. Три моделі — трикратне дублювання одних і тих самих концептуальних знань.

2. Відсутність cross-modal transfer

Модель, яка ідеально генерує зображення котів, не може допомогти моделі, яка генерує звуки котів. Хоча концептуально «кіт» — один і той самий об'єкт. Знання заблоковані в ізольованих silos.

3. Inconsistent quality

DALL-E 3 генерує фотореалістичні зображення. Але коли вам потрібне відео — якість різко падає. Звук для цього відео — ще гірший. Користувач отримує «зшиті» результати різної якості.

4. Computational waste

Три окремі моделі = три transformer backbones, три наборів attention weights, три embeddings. При цьому багато з цих знань overlap.

Ідея уніфікації

┌─────────────────────────────────────────────────────────┐
│           UNIFIED FOUNDATION MODEL                       │
│                                                          │
│   Input (any modality)                                  │
│          ↓                                              │
│   Universal Tokenization                                │
│          ↓                                              │
│   Shared Transformer Backbone                           │
│          ↓                                              │
│   Output (any modality)                                 │
│                                                          │
│   One model. One training. All modalities.              │
└─────────────────────────────────────────────────────────┘

Diffusion Models: технічна база

Forward process (додавання шуму)

import torch
import torch.nn as nn
import numpy as np
from typing import Tuple

class DiffusionProcess:
    """
    Базовий diffusion process: forward і reverse.
    """

    def __init__(self, num_timesteps: int = 1000,
                 beta_start: float = 1e-4,
                 beta_end: float = 0.02):
        """
        num_timesteps: кількість кроків diffusion
        beta_start, beta_end: noise schedule bounds
        """
        self.T = num_timesteps

        # Linear noise schedule
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Pre-compute for efficiency
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)

    def forward_process(self, x_0: torch.Tensor, t: torch.Tensor,
                        noise: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward diffusion: x_0 → x_t

        q(x_t | x_0) = N(x_t; √ᾱ_t × x_0, (1-ᾱ_t)I)

        Closed form: x_t = √ᾱ_t × x_0 + √(1-ᾱ_t) × ε
        """
        if noise is None:
            noise = torch.randn_like(x_0)

        sqrt_alpha = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)

        x_t = sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise

        return x_t, noise

    def reverse_step(self, model: nn.Module, x_t: torch.Tensor,
                     t: int, condition: torch.Tensor = None) -> torch.Tensor:
        """
        Reverse diffusion step: x_t → x_{t-1}

        p(x_{t-1} | x_t) = N(x_{t-1}; μ_θ(x_t, t), σ_t²I)
        """
        # Predict noise
        if condition is not None:
            predicted_noise = model(x_t, t, condition)
        else:
            predicted_noise = model(x_t, t)

        # Compute mean
        alpha = self.alphas[t]
        alpha_cumprod = self.alphas_cumprod[t]
        beta = self.betas[t]

        # μ = (1/√α_t) × (x_t - (β_t/√(1-ᾱ_t)) × ε_θ)
        coef1 = 1 / torch.sqrt(alpha)
        coef2 = beta / self.sqrt_one_minus_alphas_cumprod[t]

        mean = coef1 * (x_t - coef2 * predicted_noise)

        # Add noise (except for t=0)
        if t > 0:
            noise = torch.randn_like(x_t)
            sigma = torch.sqrt(beta)
            x_prev = mean + sigma * noise
        else:
            x_prev = mean

        return x_prev

    @torch.no_grad()
    def sample(self, model: nn.Module, shape: Tuple,
               condition: torch.Tensor = None,
               device: str = "cuda") -> torch.Tensor:
        """
        Full reverse process: noise → clean sample.
        """
        # Start from pure noise
        x = torch.randn(shape, device=device)

        # Iterate T → 0
        for t in reversed(range(self.T)):
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
            x = self.reverse_step(model, x, t, condition)

        return x


class DiffusionTraining:
    """
    Training diffusion models.
    """

    def __init__(self, model: nn.Module, diffusion: DiffusionProcess):
        self.model = model
        self.diffusion = diffusion

    def training_loss(self, x_0: torch.Tensor,
                      condition: torch.Tensor = None) -> torch.Tensor:
        """
        Simple diffusion loss: predict noise.

        L = E_{t,x_0,ε} [||ε - ε_θ(x_t, t)||²]
        """
        batch_size = x_0.shape[0]

        # Random timestep for each sample
        t = torch.randint(0, self.diffusion.T, (batch_size,), device=x_0.device)

        # Add noise
        x_t, noise = self.diffusion.forward_process(x_0, t)

        # Predict noise
        if condition is not None:
            predicted_noise = self.model(x_t, t, condition)
        else:
            predicted_noise = self.model(x_t, t)

        # MSE loss
        loss = nn.functional.mse_loss(predicted_noise, noise)

        return loss

Unified Tokenization: ключ до уніфікації

Принцип: токенізуй все

"""
Unified Tokenization — перетворення будь-якої модальності
в послідовність дискретних або континуальних токенів.
"""

import torch
import torch.nn as nn
from einops import rearrange

class UnifiedTokenizer:
    """
    Tokenize any modality into a common format.
    """

    def __init__(self, token_dim: int = 768, patch_size: int = 16):
        self.token_dim = token_dim
        self.patch_size = patch_size

        # Modality-specific encoders
        self.image_encoder = ImagePatchEncoder(patch_size, token_dim)
        self.video_encoder = VideoPatchEncoder(patch_size, token_dim)
        self.audio_encoder = AudioSpectrogramEncoder(token_dim)
        self.text_encoder = TextTokenEncoder(token_dim)
        self.mesh_encoder = Mesh3DEncoder(token_dim)

    def tokenize(self, data, modality: str) -> torch.Tensor:
        """
        Convert any modality to unified token sequence.

        Output shape: [batch, num_tokens, token_dim]
        """
        if modality == "image":
            return self.image_encoder(data)
        elif modality == "video":
            return self.video_encoder(data)
        elif modality == "audio":
            return self.audio_encoder(data)
        elif modality == "text":
            return self.text_encoder(data)
        elif modality == "3d":
            return self.mesh_encoder(data)
        else:
            raise ValueError(f"Unknown modality: {modality}")


class ImagePatchEncoder(nn.Module):
    """
    Image → patches → tokens (ViT-style).
    """

    def __init__(self, patch_size: int = 16, token_dim: int = 768):
        super().__init__()
        self.patch_size = patch_size

        # Patch embedding: conv with kernel=stride=patch_size
        self.patch_embed = nn.Conv2d(
            3, token_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

        # Learnable position embeddings
        # Will be interpolated for different resolutions
        self.pos_embed = nn.Parameter(torch.zeros(1, 256, token_dim))

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        images: [B, 3, H, W]
        returns: [B, num_patches, token_dim]
        """
        # Patch embedding
        patches = self.patch_embed(images)  # [B, D, H/P, W/P]

        # Flatten spatial dimensions
        B, D, H, W = patches.shape
        tokens = rearrange(patches, 'b d h w -> b (h w) d')

        # Add positional embeddings
        # Interpolate if resolution differs from training
        pos = self._interpolate_pos_embed(self.pos_embed, H * W)
        tokens = tokens + pos

        return tokens

    def _interpolate_pos_embed(self, pos_embed: torch.Tensor,
                               num_patches: int) -> torch.Tensor:
        """Interpolate position embeddings for arbitrary sizes."""
        if pos_embed.shape[1] == num_patches:
            return pos_embed

        # 2D interpolation
        pos_embed_2d = rearrange(
            pos_embed,
            '1 (h w) d -> 1 d h w',
            h=int(pos_embed.shape[1] ** 0.5)
        )

        new_size = int(num_patches ** 0.5)
        pos_embed_2d = nn.functional.interpolate(
            pos_embed_2d,
            size=(new_size, new_size),
            mode='bilinear',
            align_corners=False
        )

        return rearrange(pos_embed_2d, '1 d h w -> 1 (h w) d')


class VideoPatchEncoder(nn.Module):
    """
    Video → 3D patches → tokens.
    """

    def __init__(self, patch_size: int = 16, token_dim: int = 768,
                 temporal_patch: int = 2):
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch = temporal_patch

        # 3D patch embedding
        self.patch_embed = nn.Conv3d(
            3, token_dim,
            kernel_size=(temporal_patch, patch_size, patch_size),
            stride=(temporal_patch, patch_size, patch_size)
        )

        self.pos_embed = nn.Parameter(torch.zeros(1, 1024, token_dim))

    def forward(self, video: torch.Tensor) -> torch.Tensor:
        """
        video: [B, C, T, H, W]
        returns: [B, num_patches, token_dim]
        """
        # 3D patch embedding
        patches = self.patch_embed(video)  # [B, D, T', H', W']

        # Flatten
        B, D, T, H, W = patches.shape
        tokens = rearrange(patches, 'b d t h w -> b (t h w) d')

        # Position embeddings
        pos = self._interpolate_pos_embed_3d(self.pos_embed, T * H * W)
        tokens = tokens + pos

        return tokens

    def _interpolate_pos_embed_3d(self, pos_embed, num_patches):
        """Interpolate 3D position embeddings."""
        # Simplified: use learned or sinusoidal
        return pos_embed[:, :num_patches, :]


class AudioSpectrogramEncoder(nn.Module):
    """
    Audio → mel-spectrogram → patches → tokens.
    """

    def __init__(self, token_dim: int = 768,
                 n_mels: int = 128,
                 patch_size: Tuple[int, int] = (16, 16)):
        super().__init__()

        # Mel-spectrogram extraction
        self.mel_transform = nn.Sequential(
            # Placeholder: use torchaudio.transforms.MelSpectrogram in practice
        )

        # Treat spectrogram as image
        self.patch_embed = nn.Conv2d(
            1, token_dim,  # 1 channel for mono
            kernel_size=patch_size,
            stride=patch_size
        )

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        audio: [B, samples] or [B, 1, samples]
        returns: [B, num_patches, token_dim]
        """
        # Convert to mel-spectrogram
        mel = self.mel_transform(audio)  # [B, 1, n_mels, time]

        # Patch embedding (same as image)
        patches = self.patch_embed(mel)

        # Flatten
        tokens = rearrange(patches, 'b d h w -> b (h w) d')

        return tokens


class Mesh3DEncoder(nn.Module):
    """
    3D mesh/point cloud → tokens.
    """

    def __init__(self, token_dim: int = 768, num_points: int = 2048):
        super().__init__()

        # Point cloud processing (PointNet-style)
        self.point_embed = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, token_dim)
        )

        # Group points into tokens
        self.num_groups = 64
        self.points_per_group = num_points // self.num_groups

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        """
        points: [B, N, 3] (xyz coordinates)
        returns: [B, num_groups, token_dim]
        """
        # Embed each point
        point_features = self.point_embed(points)  # [B, N, D]

        # Group and pool (simplified)
        B, N, D = point_features.shape
        grouped = point_features.view(B, self.num_groups, -1, D)
        tokens = grouped.mean(dim=2)  # [B, num_groups, D]

        return tokens

DiT: Diffusion Transformer Architecture

Від U-Net до Transformer

"""
Diffusion Transformer (DiT) — базова архітектура для Sora та інших SOTA моделей.
"""

import torch
import torch.nn as nn
from einops import rearrange
import math

class DiTBlock(nn.Module):
    """
    Single DiT transformer block з адаптивною нормалізацією.

    Особливість: timestep та condition впливають через
    scale/shift параметри замість simple concatenation.
    """

    def __init__(self, hidden_dim: int, num_heads: int = 12,
                 mlp_ratio: float = 4.0, dropout: float = 0.0):
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False)

        # Self-attention
        self.attn = nn.MultiheadAttention(
            hidden_dim, num_heads,
            dropout=dropout,
            batch_first=True
        )

        # MLP
        mlp_hidden = int(hidden_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, hidden_dim),
            nn.Dropout(dropout)
        )

        # AdaLN-Zero modulation
        # Outputs: scale1, shift1, gate1, scale2, shift2, gate2
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_dim, 6 * hidden_dim)
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """
        x: [B, N, D] — input tokens
        c: [B, D] — conditioning (timestep + optional condition)
        """
        # Get modulation parameters
        modulation = self.adaLN_modulation(c)
        scale1, shift1, gate1, scale2, shift2, gate2 = modulation.chunk(6, dim=-1)

        # Expand for broadcasting: [B, D] → [B, 1, D]
        scale1 = scale1.unsqueeze(1)
        shift1 = shift1.unsqueeze(1)
        gate1 = gate1.unsqueeze(1)
        scale2 = scale2.unsqueeze(1)
        shift2 = shift2.unsqueeze(1)
        gate2 = gate2.unsqueeze(1)

        # Self-attention with adaptive normalization
        h = self.norm1(x) * (1 + scale1) + shift1
        h, _ = self.attn(h, h, h)
        x = x + gate1 * h

        # MLP with adaptive normalization
        h = self.norm2(x) * (1 + scale2) + shift2
        h = self.mlp(h)
        x = x + gate2 * h

        return x


class DiT(nn.Module):
    """
    Diffusion Transformer — scalable architecture for generation.
    """

    def __init__(self, input_dim: int = 4,  # VAE latent channels
                 hidden_dim: int = 1152,
                 num_heads: int = 16,
                 depth: int = 28,
                 patch_size: int = 2,
                 num_classes: int = 1000,  # For class-conditional
                 learn_sigma: bool = True):
        super().__init__()

        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.learn_sigma = learn_sigma

        # Patch embedding for latent
        self.patch_embed = nn.Conv2d(
            input_dim, hidden_dim,
            kernel_size=patch_size, stride=patch_size
        )

        # Position embedding (will be resized at runtime)
        self.pos_embed = nn.Parameter(torch.zeros(1, 256, hidden_dim))

        # Timestep embedding
        self.time_embed = TimestepEmbedder(hidden_dim)

        # Class embedding (optional)
        self.class_embed = nn.Embedding(num_classes, hidden_dim)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_dim, num_heads)
            for _ in range(depth)
        ])

        # Final layer
        self.final_layer = FinalLayer(hidden_dim, patch_size, input_dim * 2 if learn_sigma else input_dim)

        self._init_weights()

    def _init_weights(self):
        """Initialize with careful scheme for training stability."""
        # Initialize embeddings
        nn.init.normal_(self.pos_embed, std=0.02)

        # Zero-initialize final layer for better training
        nn.init.zeros_(self.final_layer.linear.weight)
        nn.init.zeros_(self.final_layer.linear.bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor,
                y: torch.Tensor = None) -> torch.Tensor:
        """
        x: [B, C, H, W] — noisy latent
        t: [B] — timesteps
        y: [B] — class labels (optional)
        """
        # Patch embedding
        x = self.patch_embed(x)  # [B, D, H/P, W/P]
        x = rearrange(x, 'b d h w -> b (h w) d')

        # Add position embedding
        x = x + self.pos_embed[:, :x.shape[1], :]

        # Conditioning: timestep + class
        c = self.time_embed(t)
        if y is not None:
            c = c + self.class_embed(y)

        # Transformer blocks
        for block in self.blocks:
            x = block(x, c)

        # Final projection
        x = self.final_layer(x, c)

        # Unpatchify
        x = self.unpatchify(x)

        return x

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Convert tokens back to spatial format."""
        # Assume square output
        h = w = int(x.shape[1] ** 0.5)
        c = x.shape[-1] // (self.patch_size ** 2)

        x = rearrange(
            x,
            'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
            h=h, w=w, p1=self.patch_size, p2=self.patch_size
        )

        return x


class TimestepEmbedder(nn.Module):
    """
    Sinusoidal timestep embedding + MLP.
    """

    def __init__(self, hidden_dim: int, frequency_dim: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.frequency_dim = frequency_dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """t: [B] → [B, hidden_dim]"""
        # Sinusoidal embedding
        half_dim = self.frequency_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

        # MLP
        return self.mlp(emb)


class FinalLayer(nn.Module):
    """Final projection with adaptive normalization."""

    def __init__(self, hidden_dim: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.linear = nn.Linear(hidden_dim, patch_size * patch_size * out_channels)

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_dim, 2 * hidden_dim)
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        scale, shift = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = self.linear(x)
        return x

Transfusion: уніфікація text і diffusion

"""
Transfusion (Meta, 2024) — unified training для text (AR) та images (diffusion).
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple

class Transfusion(nn.Module):
    """
    Unified model для text та images.

    Text: Autoregressive (next token prediction)
    Images: Diffusion (denoising)
    """

    def __init__(self, vocab_size: int = 32000,
                 image_tokens: int = 256,  # Image token vocabulary
                 hidden_dim: int = 2048,
                 num_layers: int = 32,
                 num_heads: int = 16):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.image_token_offset = vocab_size  # Image tokens after text tokens

        # Text embedding
        self.text_embed = nn.Embedding(vocab_size, hidden_dim)

        # Image patch embedding (continuous)
        self.image_embed = nn.Linear(4, hidden_dim)  # VAE latent dim

        # Timestep embedding (for diffusion)
        self.time_embed = TimestepEmbedder(hidden_dim)

        # Shared transformer
        self.transformer = nn.ModuleList([
            TransfusionBlock(hidden_dim, num_heads)
            for _ in range(num_layers)
        ])

        # Output heads
        self.text_head = nn.Linear(hidden_dim, vocab_size)  # For next token
        self.image_head = nn.Linear(hidden_dim, 4)  # For noise prediction

        # Modality tokens
        self.text_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.image_token = nn.Parameter(torch.randn(1, 1, hidden_dim))

    def forward(self, input_ids: Optional[torch.Tensor] = None,
                image_latents: Optional[torch.Tensor] = None,
                timesteps: Optional[torch.Tensor] = None,
                mode: str = "text") -> Tuple[torch.Tensor, ...]:
        """
        Flexible forward для різних режимів.
        """
        if mode == "text":
            return self._forward_text(input_ids)
        elif mode == "image":
            return self._forward_image(image_latents, timesteps)
        elif mode == "interleaved":
            return self._forward_interleaved(input_ids, image_latents, timesteps)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    def _forward_text(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Autoregressive text generation.
        """
        B, L = input_ids.shape

        # Embed text
        x = self.text_embed(input_ids)  # [B, L, D]

        # Add modality token
        x = torch.cat([self.text_token.expand(B, -1, -1), x], dim=1)

        # Causal transformer
        for layer in self.transformer:
            x = layer(x, causal=True)

        # Predict next tokens
        logits = self.text_head(x[:, 1:, :])  # Remove modality token

        return logits

    def _forward_image(self, image_latents: torch.Tensor,
                       timesteps: torch.Tensor) -> torch.Tensor:
        """
        Diffusion image generation/denoising.
        """
        B = image_latents.shape[0]

        # Flatten latents to patches
        patches = rearrange(image_latents, 'b c h w -> b (h w) c')

        # Embed patches
        x = self.image_embed(patches)  # [B, N, D]

        # Add timestep conditioning
        t_emb = self.time_embed(timesteps)  # [B, D]

        # Add modality token
        x = torch.cat([self.image_token.expand(B, -1, -1), x], dim=1)

        # Full attention (not causal for diffusion)
        for layer in self.transformer:
            x = layer(x, causal=False, condition=t_emb)

        # Predict noise
        noise_pred = self.image_head(x[:, 1:, :])  # [B, N, 4]

        # Reshape back
        h = w = int(noise_pred.shape[1] ** 0.5)
        noise_pred = rearrange(noise_pred, 'b (h w) c -> b c h w', h=h, w=w)

        return noise_pred

    def _forward_interleaved(self, input_ids: torch.Tensor,
                             image_latents: torch.Tensor,
                             timesteps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Interleaved text and image (for multimodal generation).
        """
        B = input_ids.shape[0]

        # Embed text
        text_emb = self.text_embed(input_ids)  # [B, L_text, D]

        # Embed images
        patches = rearrange(image_latents, 'b c h w -> b (h w) c')
        image_emb = self.image_embed(patches)  # [B, L_img, D]

        # Timestep for images
        t_emb = self.time_embed(timesteps)

        # Concatenate: [text] [image]
        x = torch.cat([text_emb, image_emb], dim=1)

        # Transformer з causal attention для text частини
        for layer in self.transformer:
            x = layer(x, causal=False, condition=t_emb)  # Simplified

        # Split outputs
        L_text = input_ids.shape[1]
        text_out = x[:, :L_text, :]
        image_out = x[:, L_text:, :]

        text_logits = self.text_head(text_out)
        noise_pred = self.image_head(image_out)
        noise_pred = rearrange(noise_pred, 'b (h w) c -> b c h w', h=int(noise_pred.shape[1]**0.5))

        return text_logits, noise_pred


class TransfusionBlock(nn.Module):
    """
    Transformer block з optional causal masking та conditioning.
    """

    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.GELU(),
            nn.Linear(4 * hidden_dim, hidden_dim)
        )

        # Conditioning projection
        self.cond_proj = nn.Linear(hidden_dim, 2 * hidden_dim)

    def forward(self, x: torch.Tensor, causal: bool = False,
                condition: torch.Tensor = None) -> torch.Tensor:
        # Optional conditioning via scale/shift
        if condition is not None:
            scale, shift = self.cond_proj(condition).chunk(2, dim=-1)
            x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

        # Attention
        mask = self._get_causal_mask(x.shape[1], x.device) if causal else None
        h = self.norm1(x)
        h, _ = self.attn(h, h, h, attn_mask=mask)
        x = x + h

        # MLP
        x = x + self.mlp(self.norm2(x))

        return x

    def _get_causal_mask(self, size: int, device) -> torch.Tensor:
        mask = torch.triu(torch.ones(size, size, device=device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

Multi-Modal Training Strategy

"""
Strategies для training unified multi-modal models.
"""

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Dict, List

class UnifiedTrainer:
    """
    Training unified diffusion model на multiple modalities.
    """

    def __init__(self, model: nn.Module,
                 modalities: List[str],
                 loss_weights: Dict[str, float] = None):
        self.model = model
        self.modalities = modalities
        self.loss_weights = loss_weights or {m: 1.0 for m in modalities}

        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=1e-4,
            weight_decay=0.05
        )

    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """
        Single training step з multi-task balancing.
        """
        self.optimizer.zero_grad()

        total_loss = 0
        losses = {}

        for modality in self.modalities:
            if modality not in batch:
                continue

            data = batch[modality]

            # Compute modality-specific loss
            if modality == "text":
                loss = self._text_loss(data)
            elif modality in ["image", "video", "audio"]:
                loss = self._diffusion_loss(data, modality)
            else:
                continue

            weighted_loss = self.loss_weights[modality] * loss
            total_loss = total_loss + weighted_loss
            losses[modality] = loss.item()

        total_loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

        self.optimizer.step()

        return losses

    def _text_loss(self, data: Dict) -> torch.Tensor:
        """Cross-entropy loss для autoregressive text."""
        input_ids = data['input_ids']
        labels = data['labels']

        logits = self.model(input_ids=input_ids, mode="text")
        loss = nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            ignore_index=-100
        )
        return loss

    def _diffusion_loss(self, data: Dict, modality: str) -> torch.Tensor:
        """MSE loss для diffusion denoising."""
        x_0 = data['latents']

        # Random timesteps
        B = x_0.shape[0]
        t = torch.randint(0, 1000, (B,), device=x_0.device)

        # Add noise
        noise = torch.randn_like(x_0)
        x_t = self._add_noise(x_0, t, noise)

        # Predict noise
        pred_noise = self.model(
            image_latents=x_t,
            timesteps=t,
            mode=modality
        )

        loss = nn.functional.mse_loss(pred_noise, noise)
        return loss

    def _add_noise(self, x_0, t, noise):
        """Simplified noise addition."""
        alpha = 1 - t.float() / 1000
        alpha = alpha.view(-1, 1, 1, 1)
        return alpha ** 0.5 * x_0 + (1 - alpha) ** 0.5 * noise


class GradientBalancing:
    """
    Techniques для balancing gradients між modalities.
    """

    @staticmethod
    def gradnorm(losses: Dict[str, torch.Tensor],
                 weights: Dict[str, float],
                 shared_params: List[nn.Parameter],
                 alpha: float = 0.5) -> Dict[str, float]:
        """
        GradNorm: balance gradients по magnitude.
        """
        # Compute gradient norms
        grad_norms = {}
        for task, loss in losses.items():
            grads = torch.autograd.grad(loss, shared_params, retain_graph=True)
            grad_norm = sum(g.norm() ** 2 for g in grads if g is not None) ** 0.5
            grad_norms[task] = grad_norm

        # Target: average gradient norm
        avg_norm = sum(grad_norms.values()) / len(grad_norms)

        # Update weights
        new_weights = {}
        for task, norm in grad_norms.items():
            relative_loss = losses[task] / sum(losses.values())
            target_norm = avg_norm * (relative_loss ** alpha)
            new_weights[task] = weights[task] * (target_norm / (norm + 1e-8))

        return new_weights

    @staticmethod
    def pcgrad(grads: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        PCGrad: Project conflicting gradients.
        """
        task_names = list(grads.keys())
        num_tasks = len(task_names)

        # Stack gradients
        grad_list = [grads[name].flatten() for name in task_names]

        # Project conflicting gradients
        for i in range(num_tasks):
            for j in range(num_tasks):
                if i == j:
                    continue

                dot = torch.dot(grad_list[i], grad_list[j])
                if dot < 0:
                    # Conflicting: project out component
                    grad_list[i] = grad_list[i] - dot * grad_list[j] / (grad_list[j].norm() ** 2 + 1e-8)

        # Average projected gradients
        return torch.stack(grad_list).mean(dim=0)

Бенчмарки уніфікованих моделей

| Model | Image FID | Video FVD | Audio FAD | Params | Training |

|-------|-----------|-----------|-----------|--------|----------|

| Specialized (ensemble) | 2.1 | 180 | 1.2 | 3 × 3B | 3 × weeks |

| DiT-XL (image only) | 2.3 | - | - | 675M | 1 week |

| Unified-IO 2 | 3.5 | 220 | 1.8 | 7B | 2 weeks |

| Transfusion | 2.8 | 195 | - | 7B | 2 weeks |

| OmniGen | 3.1 | 205 | 1.5 | 3.8B | 10 days |

Спостереження: Unified моделі наближаються до specialized за якістю, при цьому маючи значно меншу сумарну кількість параметрів.


Ідеї для дослідження

Для бакалавра:

  • Fine-tune DiT на специфічному domain
  • Порівняння U-Net vs DiT для image generation
  • Візуалізація attention maps в DiT

Для магістра:

  • Додавання нової модальності до unified model
  • Task balancing strategies для multi-modal training
  • Efficient inference для великих DiT моделей

Для PhD:

  • Theoretical foundations of unified representations
  • Scaling laws для multi-modal diffusion
  • Novel architectures для 4D+ data (video + audio + depth)
  • Emergent capabilities in unified models

Висновок: шлях до AGI-style generation

GPT показав: scale + unified architecture = emergence нових здібностей. Generative AI йде тим самим шляхом. Замість спеціалізованих моделей для кожної модальності — unified foundation models для генерації будь-чого.

Sora, Gemini 2.0, майбутні системи — все буде unified. Одна модель, яка генерує зображення, відео, аудіо, 3D, текст. Не через stitching окремих компонентів, а через справжнє cross-modal розуміння.

Це не мрія — це trajectory, який чітко видно в research papers 2023-2024 років. Хто розуміє unified architectures сьогодні, буде будувати AGI-style системи завтра. Для досліджень у цій галузі команда SKP-Degree на skp-degree.com.ua готова допомогти з реалізацією DiT-based архітектур та multi-modal training. Консультації в Telegram: @kursovi_diplomy.


Unified diffusion, DiT, Diffusion Transformer, Transfusion, multi-modal generation, Sora architecture, foundation models, cross-modal AI — ключові терміни для дипломної чи магістерської роботи з генеративного штучного інтелекту та multi-modal learning.

Про автора

Команда SKP-Degree

Верифікований автор

Розробники та дослідники AI · Python, TensorFlow, PyTorch · Досвід у промисловій розробці

Команда SKP-Degree — професійні розробники з досвідом 7+ років у промисловій розробці. Виконали 1000+ проєктів для студентів з України, Польщі та країн Балтії.

Python Django Java ML/AI React C# / .NET JavaScript

Потрібна допомога з роботою?

Замовте курсову чи дипломну роботу з програмування. Оплата після демонстрації!

Без передоплати Відеодемонстрація Автономна робота 24/7
Написати в Telegram