
Domain Adaptation for Time-Series Anomaly Detection: Complete Implementation Guide with Full Training Scripts

Introduction: The Domain Shift Problem in Anomaly Detection

Imagine you spent six months collecting labeled anomaly data from a CNC milling machine on your factory floor. You painstakingly tagged every spindle vibration spike, every thermal drift event, every bearing degradation signature. Your anomaly detection model hits 0.95 AUROC on that machine. Then your company buys a second milling machine—same manufacturer, same model number, but a different production year. You deploy your model, and the AUROC drops to 0.62. Barely better than a coin flip.

This is the domain shift problem, and it is one of the most expensive headaches in industrial machine learning. The statistical distribution of sensor readings changes between machines, factories, sensor brands, and even seasons. Noise floors differ. Baseline amplitudes drift. The relationship between “normal” and “anomalous” subtly warps. Your perfectly trained model becomes useless the moment it leaves its original domain.

The classical solution is to label data in every new domain. But labeling anomaly data is brutally expensive—anomalies are rare by definition, and expert annotators are scarce. What if you could transfer the anomaly detection knowledge from your labeled source domain (machine A) to an unlabeled target domain (machine B) without restarting from scratch?

That is exactly what domain adaptation does. By training a model to learn features that are invariant across domains—features that capture the essence of “anomaly” regardless of which machine produced the signal—you can detect anomalies in new domains with little or no labeled target data. The technique has roots in computer vision (the famous DANN paper by Ganin et al., 2016), but its application to time-series anomaly detection remains underexplored in practice, despite being exactly where it is needed most.

This post is not a theoretical survey. It is a complete, runnable implementation guide. By the end, you will have eight Python scripts (plus a requirements file) that implement three domain adaptation strategies: DANN (Domain-Adversarial Neural Networks), MMD (Maximum Mean Discrepancy), and CORAL (CORrelation ALignment). All three sit on top of a CNN-LSTM hybrid encoder for multi-channel time-series anomaly detection. Every script is complete. No ellipses, no “fill in the rest,” no pseudocode. Copy, paste, run.

Let us build it.

Project Structure and Setup

Before writing any code, let us establish a clean project layout. Every file has a single responsibility, making the codebase easy to understand and modify for your own use case.

da-anomaly-detection/
├── config.py                    # Hyperparameters and configuration
├── dataset.py                   # Dataset classes and data loading
├── model.py                     # Model architecture (encoder, classifier, discriminator)
├── losses.py                    # Loss function definitions (DANN, MMD, CORAL)
├── train.py                     # Main training script with domain adaptation
├── evaluate.py                  # Evaluation and metrics
├── utils.py                     # Utility functions (seeding, checkpoints, plotting)
├── generate_synthetic_data.py   # Generate example data for testing
├── requirements.txt             # Dependencies
├── data/                        # Generated or real data goes here
├── checkpoints/                 # Saved model weights
└── results/                     # Evaluation outputs, plots, metrics

Start by creating the directory and installing dependencies:

mkdir -p da-anomaly-detection/{data,checkpoints,results}
cd da-anomaly-detection

requirements.txt

torch>=2.0.0
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
matplotlib>=3.7.0
tqdm>=4.65.0

Install the dependencies:

pip install -r requirements.txt

Tip: If you have a CUDA-capable GPU, install PyTorch with CUDA support for significantly faster training: pip install torch --index-url https://download.pytorch.org/whl/cu121

Configuration and Hyperparameters

Centralizing configuration prevents magic numbers from scattering across your codebase. We use a Python dataclass so that every setting is a named, typed field: your IDE can autocomplete it, and a static type checker can flag misuse.

config.py

"""
config.py — Centralized configuration for domain-adaptive anomaly detection.
All hyperparameters live here. Override via CLI arguments in train.py.
"""

from dataclasses import dataclass, field
import torch
import os


@dataclass
class Config:
    """All hyperparameters and paths for the DA anomaly detection pipeline."""

    # --- Data Parameters ---
    num_features: int = 6           # Number of sensor channels
    window_size: int = 64           # Sliding window length (timesteps)
    stride: int = 16                # Stride for sliding window
    train_ratio: float = 0.8        # Train/val split ratio

    # --- Model Architecture ---
    cnn_channels: list = field(default_factory=lambda: [32, 64, 128])
    cnn_kernel_sizes: list = field(default_factory=lambda: [7, 5, 3])
    lstm_hidden_dim: int = 128
    lstm_num_layers: int = 2
    latent_dim: int = 128           # Dimension of the shared feature space
    classifier_hidden_dim: int = 64
    discriminator_hidden_dim: int = 64
    dropout: float = 0.3

    # --- Training Parameters ---
    batch_size: int = 64
    learning_rate: float = 1e-3
    discriminator_lr: float = 1e-3
    weight_decay: float = 1e-4
    epochs: int = 100
    patience: int = 15              # Early stopping patience

    # --- Domain Adaptation Parameters ---
    adaptation_method: str = "dann"  # 'dann', 'mmd', or 'coral'
    lambda_domain: float = 1.0       # Max domain loss weight
    lambda_recon: float = 0.5        # Reconstruction loss weight
    lambda_cls: float = 1.0          # Classification loss weight
    gamma: float = 10.0              # DANN lambda schedule steepness
    mmd_kernel_bandwidth: list = field(
        default_factory=lambda: [0.01, 0.1, 1.0, 10.0, 100.0]
    )

    # --- Anomaly Scoring ---
    alpha: float = 0.7              # Weight for classifier score vs recon error
    anomaly_threshold_percentile: float = 95.0

    # --- Paths ---
    data_dir: str = "data"
    checkpoint_dir: str = "checkpoints"
    results_dir: str = "results"

    # --- Device and Reproducibility ---
    seed: int = 42
    device: str = ""

    def __post_init__(self):
        if not self.device:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        os.makedirs(self.data_dir, exist_ok=True)
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.results_dir, exist_ok=True)

Key Takeaway: The most sensitive hyperparameter in domain adaptation is lambda_domain. Too high, and the model forgets how to classify anomalies. Too low, and domain adaptation has no effect. The progressive scheduling in our training script (DANN lambda schedule) addresses this by starting low and ramping up.
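train.py is not shown in this section, so here is a minimal sketch of the ramp, assuming it follows the schedule from Ganin et al., λ_p = 2 / (1 + exp(−γ·p)) − 1, where p ∈ [0, 1] is training progress, γ is the gamma field in Config, and the result is scaled by lambda_domain. The helper name dann_lambda is hypothetical.

```python
import math


def dann_lambda(step: int, total_steps: int,
                gamma: float = 10.0, lambda_max: float = 1.0) -> float:
    """Hypothetical helper: DANN-style ramp from 0 toward lambda_max.

    Assumes the schedule from Ganin et al.:
        lambda_p = 2 / (1 + exp(-gamma * p)) - 1, with p = step / total_steps.
    """
    # Training progress, clamped to [0, 1]
    p = min(max(step / max(total_steps, 1), 0.0), 1.0)
    return lambda_max * (2.0 / (1.0 + math.exp(-gamma * p)) - 1.0)


# Illustrative run: 100 epochs x 13 batches per epoch
total = 100 * 13
for step in (0, total // 10, total // 4, total // 2, total):
    print(f"step {step:5d}: lambda = {dann_lambda(step, total):.3f}")
```

With gamma = 10, lambda is still only about 0.46 at 10% of training and about 0.85 at 25%, so the anomaly classifier gets a head start before the adversarial signal reaches full strength. In a training loop this value would be passed to the model's set_domain_lambda method once per step.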

Generating Realistic Synthetic Data

Before touching real proprietary data, you need a sandbox. The script below generates two-domain synthetic time-series data with realistic characteristics: seasonal patterns, trends, multiple anomaly types, and domain-specific differences in noise, amplitude, and baseline offset. The source domain gets full labels; the target domain training set has no labels (simulating the real scenario), while the target test set has labels for evaluation.

generate_synthetic_data.py

"""
generate_synthetic_data.py — Generate realistic two-domain time-series data
with injected anomalies for testing domain adaptation.

Simulates 6-channel sensor data (e.g., 3 joints x [torque, position]) from
two different machines with different noise/amplitude characteristics.
"""

import argparse
import os
import numpy as np
import pandas as pd


def generate_base_signal(n_samples: int, num_features: int, seed: int = 42) -> np.ndarray:
    """Generate a base multi-channel time-series with realistic patterns."""
    rng = np.random.RandomState(seed)
    t = np.arange(n_samples)
    signals = np.zeros((n_samples, num_features))

    for ch in range(num_features):
        freq1 = 0.002 + ch * 0.001
        freq2 = 0.01 + ch * 0.003
        phase1 = rng.uniform(0, 2 * np.pi)
        phase2 = rng.uniform(0, 2 * np.pi)

        # Seasonal component
        seasonal = 2.0 * np.sin(2 * np.pi * freq1 * t + phase1)
        # Higher-frequency oscillation
        oscillation = 0.8 * np.sin(2 * np.pi * freq2 * t + phase2)
        # Slow trend
        trend = 0.0005 * t * ((-1) ** ch)
        # Combine
        signals[:, ch] = seasonal + oscillation + trend

    return signals


def inject_anomalies(
    signals: np.ndarray,
    anomaly_ratio: float = 0.05,
    seed: int = 42
) -> tuple:
    """
    Inject multiple anomaly types into signals.
    Returns (modified_signals, labels) where labels[i]=1 means anomaly.
    """
    rng = np.random.RandomState(seed)
    n_samples, num_features = signals.shape
    labels = np.zeros(n_samples, dtype=int)
    modified = signals.copy()

    n_anomalies = int(n_samples * anomaly_ratio)
    anomaly_types = ["spike", "drift", "level_shift", "frequency_change"]

    # Choose random anomaly locations (non-overlapping segments)
    segment_length = 20
    max_start = n_samples - segment_length
    starts = rng.choice(max_start, size=n_anomalies, replace=False)

    for i, start in enumerate(starts):
        end = start + segment_length
        a_type = anomaly_types[i % len(anomaly_types)]
        channel = rng.randint(0, num_features)

        if a_type == "spike":
            spike_pos = start + rng.randint(0, segment_length)
            magnitude = rng.uniform(5, 10) * (1 if rng.random() > 0.5 else -1)
            modified[spike_pos, channel] += magnitude
            labels[spike_pos] = 1

        elif a_type == "drift":
            drift = np.linspace(0, rng.uniform(3, 6), segment_length)
            modified[start:end, channel] += drift
            labels[start:end] = 1

        elif a_type == "level_shift":
            shift = rng.uniform(3, 7) * (1 if rng.random() > 0.5 else -1)
            modified[start:end, channel] += shift
            labels[start:end] = 1

        elif a_type == "frequency_change":
            t_seg = np.arange(segment_length)
            high_freq = 2.0 * np.sin(2 * np.pi * 0.15 * t_seg)
            modified[start:end, channel] += high_freq
            labels[start:end] = 1

    return modified, labels


def apply_domain_transform(
    signals: np.ndarray,
    noise_scale: float = 0.3,
    amplitude_scale: float = 1.0,
    baseline_offset: float = 0.0,
    seed: int = 42
) -> np.ndarray:
    """Apply domain-specific transformations to simulate a different machine."""
    rng = np.random.RandomState(seed)
    transformed = signals.copy()
    n_samples, num_features = transformed.shape

    # Per-channel amplitude scaling
    for ch in range(num_features):
        ch_amp = amplitude_scale * rng.uniform(0.8, 1.2)
        ch_offset = baseline_offset + rng.uniform(-0.5, 0.5)
        transformed[:, ch] = transformed[:, ch] * ch_amp + ch_offset

    # Add domain-specific noise
    noise = rng.normal(0, noise_scale, transformed.shape)
    transformed += noise

    return transformed


def generate_dataset(
    n_samples: int,
    num_features: int,
    anomaly_ratio: float,
    noise_scale: float,
    amplitude_scale: float,
    baseline_offset: float,
    seed: int
) -> pd.DataFrame:
    """Generate a complete dataset with signals, anomalies, and domain transform."""
    base = generate_base_signal(n_samples, num_features, seed=seed)
    with_anomalies, labels = inject_anomalies(base, anomaly_ratio, seed=seed + 1)
    transformed = apply_domain_transform(
        with_anomalies,
        noise_scale=noise_scale,
        amplitude_scale=amplitude_scale,
        baseline_offset=baseline_offset,
        seed=seed + 2
    )

    columns = [f"sensor_{i}" for i in range(num_features)]
    df = pd.DataFrame(transformed, columns=columns)
    df["label"] = labels
    df["timestamp"] = pd.date_range("2024-01-01", periods=n_samples, freq="s")
    return df


def main():
    parser = argparse.ArgumentParser(
        description="Generate synthetic two-domain time-series data."
    )
    parser.add_argument("--output_dir", type=str, default="data",
                        help="Output directory for CSV files")
    parser.add_argument("--n_samples", type=int, default=20000,
                        help="Number of samples per dataset")
    parser.add_argument("--num_features", type=int, default=6,
                        help="Number of sensor channels")
    parser.add_argument("--anomaly_ratio", type=float, default=0.05,
                        help="Fraction of timesteps with anomalies")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    print("Generating source domain data (Machine A)...")
    source_full = generate_dataset(
        n_samples=args.n_samples,
        num_features=args.num_features,
        anomaly_ratio=args.anomaly_ratio,
        noise_scale=0.2,
        amplitude_scale=1.0,
        baseline_offset=0.0,
        seed=args.seed
    )
    split_idx = int(len(source_full) * 0.7)
    source_train = source_full.iloc[:split_idx].reset_index(drop=True)
    source_test = source_full.iloc[split_idx:].reset_index(drop=True)

    print("Generating target domain data (Machine B)...")
    target_full = generate_dataset(
        n_samples=args.n_samples,
        num_features=args.num_features,
        anomaly_ratio=args.anomaly_ratio,
        noise_scale=0.5,           # Higher noise
        amplitude_scale=1.4,       # Different amplitude
        baseline_offset=2.0,       # Shifted baseline
        seed=args.seed + 100
    )
    split_idx_t = int(len(target_full) * 0.7)
    target_train = target_full.iloc[:split_idx_t].reset_index(drop=True)
    target_test = target_full.iloc[split_idx_t:].reset_index(drop=True)

    # Remove labels from target train (unsupervised in target domain)
    target_train_unlabeled = target_train.drop(columns=["label"])

    # Save all files
    source_train.to_csv(os.path.join(args.output_dir, "source_train.csv"), index=False)
    source_test.to_csv(os.path.join(args.output_dir, "source_test.csv"), index=False)
    target_train_unlabeled.to_csv(os.path.join(args.output_dir, "target_train.csv"), index=False)
    target_test.to_csv(os.path.join(args.output_dir, "target_test.csv"), index=False)

    print(f"\nDatasets saved to {args.output_dir}/")
    print(f"  source_train.csv: {len(source_train)} samples, "
          f"{source_train['label'].sum()} anomalies ({source_train['label'].mean()*100:.1f}%)")
    print(f"  source_test.csv:  {len(source_test)} samples, "
          f"{source_test['label'].sum()} anomalies ({source_test['label'].mean()*100:.1f}%)")
    print(f"  target_train.csv: {len(target_train_unlabeled)} samples (no labels)")
    print(f"  target_test.csv:  {len(target_test)} samples, "
          f"{target_test['label'].sum()} anomalies ({target_test['label'].mean()*100:.1f}%)")


if __name__ == "__main__":
    main()

Run it immediately:

python generate_synthetic_data.py --output_dir data/ --n_samples 20000

You will get four CSV files. The source data has labels everywhere. The target training data has no labels—this is the whole point of domain adaptation. The target test data has labels so we can measure how well the adaptation worked.
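Before training, it is worth confirming that the two domains really differ. A minimal sketch, assuming the CSV layout produced by generate_synthetic_data.py (sensor_* columns plus optional label and timestamp columns); the helper name summarize_shift is hypothetical:

```python
import numpy as np
import pandas as pd


def summarize_shift(source: pd.DataFrame, target: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper: compare per-channel mean and std across two domains."""
    cols = [c for c in source.columns if c.startswith("sensor_")]
    return pd.DataFrame({
        "source_mean": source[cols].mean(),
        "target_mean": target[cols].mean(),
        "source_std": source[cols].std(),
        "target_std": target[cols].std(),
        # Mean shift expressed in units of the source standard deviation
        "shift_in_source_std": (target[cols].mean() - source[cols].mean())
                               / source[cols].std(),
    })


# Usage on the generated files:
# src = pd.read_csv("data/source_train.csv")
# tgt = pd.read_csv("data/target_train.csv")
# print(summarize_shift(src, tgt).round(2))
```

Large values in shift_in_source_std, or a target_std well above source_std, are exactly the kind of mismatch the adaptation methods below are meant to absorb.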

Dataset Classes and Data Loading

Time-series anomaly detection operates on windows: fixed-length slices of the signal. The TimeSeriesDataset class handles windowing and optional data augmentation, the Normalizer class fits per-channel statistics on the source training data and applies them everywhere, and create_data_loaders builds separate source and target loaders so that the training loop can draw source and target batches side by side.
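The loaders that create_data_loaders returns are independent, so source and target batches have to be paired at iteration time. One common approach, shown here as a sketch with a hypothetical paired_batches helper and toy tensors, defines an epoch as one pass over the source loader and restarts the target loader whenever it runs out:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def paired_batches(source_loader, target_loader):
    """Hypothetical helper: yield (source_batch, target_batch) pairs.

    One epoch = one full pass over the source loader. The target loader is
    re-iterated whenever it is exhausted, so it is reshuffled on every pass
    (itertools.cycle would replay the cached first pass instead).
    Assumes the target loader yields at least one batch.
    """
    target_iter = iter(target_loader)
    for source_batch in source_loader:
        try:
            target_batch = next(target_iter)
        except StopIteration:
            target_iter = iter(target_loader)
            target_batch = next(target_iter)
        yield source_batch, target_batch


# Toy data: 10 source windows, 4 target windows, 6 channels x 64 timesteps
src_ds = TensorDataset(torch.randn(10, 6, 64), torch.randint(0, 2, (10,)).float())
tgt_ds = TensorDataset(torch.randn(4, 6, 64), torch.full((4,), -1.0))
src_loader = DataLoader(src_ds, batch_size=2, shuffle=True)
tgt_loader = DataLoader(tgt_ds, batch_size=2, shuffle=True)

steps = 0
for (x_s, y_s), (x_t, _) in paired_batches(src_loader, tgt_loader):
    # Target labels are placeholders (-1) and must never reach the classifier loss
    steps += 1
print(steps)  # 5 source batches -> 5 paired steps
```

Tying the epoch to the source loader keeps the supervised signal consistent across epochs; if your target set is much larger, iterating over the target loader and cycling the source loader instead is an equally valid choice.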

dataset.py

"""
dataset.py — PyTorch Dataset classes for time-series domain adaptation.

Handles sliding-window creation, normalization, augmentation, and
construction of separate source and target data loaders.
"""

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader


class TimeSeriesDataset(Dataset):
    """
    Sliding-window dataset for multi-channel time-series.

    Args:
        data: numpy array of shape (n_samples, num_features)
        labels: numpy array of shape (n_samples,) or None for unlabeled data
        window_size: number of timesteps per window
        stride: step between consecutive windows
        transform: optional callable for data augmentation
    """

    def __init__(
        self,
        data: np.ndarray,
        labels: np.ndarray = None,
        window_size: int = 64,
        stride: int = 16,
        transform=None
    ):
        self.data = data.astype(np.float32)
        self.labels = labels
        self.window_size = window_size
        self.stride = stride
        self.transform = transform

        # Precompute valid window start indices
        self.indices = list(range(0, len(data) - window_size + 1, stride))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        start = self.indices[idx]
        end = start + self.window_size
        window = self.data[start:end]  # (window_size, num_features)

        if self.transform is not None:
            window = self.transform(window)

        # Transpose to (num_features, window_size) for Conv1d
        window_tensor = torch.tensor(window, dtype=torch.float32).T

        if self.labels is not None:
            # Window label = 1 if any timestep in window is anomalous
            window_label = float(self.labels[start:end].max())
            return window_tensor, torch.tensor(window_label, dtype=torch.float32)
        else:
            return window_tensor, torch.tensor(-1.0, dtype=torch.float32)


class Normalizer:
    """
    Fit on source training data, transform all data.
    Uses per-channel mean and std normalization.
    """

    def __init__(self):
        self.mean = None
        self.std = None

    def fit(self, data: np.ndarray):
        """Compute mean and std from training data."""
        self.mean = data.mean(axis=0)
        self.std = data.std(axis=0)
        # Prevent division by zero
        self.std[self.std < 1e-8] = 1.0
        return self

    def transform(self, data: np.ndarray) -> np.ndarray:
        """Apply normalization."""
        return (data - self.mean) / self.std

    def fit_transform(self, data: np.ndarray) -> np.ndarray:
        """Fit and transform in one step."""
        self.fit(data)
        return self.transform(data)


class JitterTransform:
    """Add random Gaussian noise for data augmentation."""

    def __init__(self, sigma: float = 0.03):
        self.sigma = sigma

    def __call__(self, window: np.ndarray) -> np.ndarray:
        noise = np.random.normal(0, self.sigma, window.shape).astype(np.float32)
        return window + noise


class ScalingTransform:
    """Random per-channel amplitude scaling for data augmentation."""

    def __init__(self, sigma: float = 0.1):
        self.sigma = sigma

    def __call__(self, window: np.ndarray) -> np.ndarray:
        factor = np.random.normal(1.0, self.sigma, (1, window.shape[1])).astype(np.float32)
        return window * factor


class ComposeTransforms:
    """Chain multiple transforms together."""

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, window: np.ndarray) -> np.ndarray:
        for t in self.transforms:
            window = t(window)
        return window


def load_csv_data(filepath: str, has_labels: bool = True):
    """
    Load a CSV file and separate features from labels.

    Returns:
        data: numpy array (n_samples, num_features)
        labels: numpy array (n_samples,) or None
    """
    df = pd.read_csv(filepath)
    # Drop non-numeric columns like timestamp
    feature_cols = [c for c in df.columns if c not in ("label", "timestamp")]
    data = df[feature_cols].values.astype(np.float32)
    labels = df["label"].values.astype(np.float32) if (has_labels and "label" in df.columns) else None
    return data, labels


def create_data_loaders(config) -> tuple:
    """
    Create all data loaders for domain adaptation training.

    Returns:
        loaders: dict with keys 'source_train', 'source_test',
                 'target_train', 'target_test'
        normalizer: the Normalizer fitted on the source training data
    """
    import os

    # Load raw data
    source_train_data, source_train_labels = load_csv_data(
        os.path.join(config.data_dir, "source_train.csv"), has_labels=True
    )
    source_test_data, source_test_labels = load_csv_data(
        os.path.join(config.data_dir, "source_test.csv"), has_labels=True
    )
    target_train_data, _ = load_csv_data(
        os.path.join(config.data_dir, "target_train.csv"), has_labels=False
    )
    target_test_data, target_test_labels = load_csv_data(
        os.path.join(config.data_dir, "target_test.csv"), has_labels=True
    )

    # Normalize: fit on source train only
    normalizer = Normalizer()
    source_train_data = normalizer.fit_transform(source_train_data)
    source_test_data = normalizer.transform(source_test_data)
    target_train_data = normalizer.transform(target_train_data)
    target_test_data = normalizer.transform(target_test_data)

    # Optional augmentation for training
    train_transform = ComposeTransforms([
        JitterTransform(sigma=0.03),
        ScalingTransform(sigma=0.1),
    ])

    # Create datasets
    source_train_ds = TimeSeriesDataset(
        source_train_data, source_train_labels,
        window_size=config.window_size, stride=config.stride,
        transform=train_transform
    )
    source_test_ds = TimeSeriesDataset(
        source_test_data, source_test_labels,
        window_size=config.window_size, stride=config.stride
    )
    target_train_ds = TimeSeriesDataset(
        target_train_data, labels=None,
        window_size=config.window_size, stride=config.stride,
        transform=train_transform
    )
    target_test_ds = TimeSeriesDataset(
        target_test_data, target_test_labels,
        window_size=config.window_size, stride=config.stride
    )

    # Create loaders
    loaders = {
        "source_train": DataLoader(
            source_train_ds, batch_size=config.batch_size,
            shuffle=True, drop_last=True, num_workers=0
        ),
        "source_test": DataLoader(
            source_test_ds, batch_size=config.batch_size,
            shuffle=False, num_workers=0
        ),
        "target_train": DataLoader(
            target_train_ds, batch_size=config.batch_size,
            shuffle=True, drop_last=True, num_workers=0
        ),
        "target_test": DataLoader(
            target_test_ds, batch_size=config.batch_size,
            shuffle=False, num_workers=0
        ),
    }

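    return loaders, normalizer

Caution: Fit your normalizer on the source training data only. Fitting on anything that includes the target test split leaks information about the evaluation data and inflates your metrics. Even folding in the unlabeled target training data quietly performs a crude form of alignment (per-channel mean and variance matching), which makes it hard to tell how much of any improvement comes from DANN, MMD, or CORAL.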
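To see why the fitting set matters, the sketch below uses toy one-channel data (not the generated CSVs) and a plain mean/std normalizer with the same logic as the Normalizer class. Fitting on source alone leaves the target's offset visible after normalization; fitting on source plus target partly absorbs it before any adaptation method runs.

```python
import numpy as np

rng = np.random.default_rng(0)
source = rng.normal(loc=0.0, scale=1.0, size=(5000, 1))
target = rng.normal(loc=2.0, scale=1.4, size=(5000, 1))  # shifted, noisier "machine B"


def fit_stats(data: np.ndarray):
    """Per-channel mean/std, mirroring Normalizer.fit (with the same std floor)."""
    mean, std = data.mean(axis=0), data.std(axis=0)
    std[std < 1e-8] = 1.0
    return mean, std


# Recommended: statistics from source training data only
mean_s, std_s = fit_stats(source)
target_norm_source_fit = (target - mean_s) / std_s

# Leaky: statistics from source + target combined
mean_c, std_c = fit_stats(np.vstack([source, target]))
target_norm_combined_fit = (target - mean_c) / std_c

print(f"target mean, source-only fit: {target_norm_source_fit.mean():.2f}")
print(f"target mean, combined fit:    {target_norm_combined_fit.mean():.2f}")
```

The normalized target mean comes out around 2 with the source-only fit and around 0.6 with the combined fit: the combined statistics have already pulled the domains together, so a later evaluation cannot cleanly credit that alignment to the adaptation method.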

The Core Model Architecture

This is the heart of the system. Our architecture has four components working together: a shared encoder that turns each time-series window into a fixed-size feature vector, an anomaly classifier that predicts normal vs. anomaly, a reconstruction decoder that rebuilds the original input (its reconstruction error provides an auxiliary anomaly signal), and a domain discriminator that tries to identify which domain produced a given feature vector. The magic ingredient is the Gradient Reversal Layer (GRL): during backpropagation, it flips the sign of the gradients flowing from the domain discriminator into the encoder and scales them by a factor λ. The discriminator still learns to tell the domains apart, but the encoder is pushed in the opposite direction, toward features that make that task as hard as possible. Those are precisely the domain-invariant representations we want.

Architecture:
                        ┌─── Anomaly Classifier (binary: normal/anomaly)
Input → Shared Encoder ─┤
  (time-series)         ├─── Reconstruction Decoder (autoencoder branch)
                        └─── Domain Discriminator (with gradient reversal)
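losses.py and train.py are not shown in this section, so the following is only a sketch of how the four outputs could be combined for the DANN variant. Assumptions: source windows carry labels and target windows do not; domain labels are 0 for source and 1 for target, matching the DomainDiscriminator docstring in model.py; both domains contribute to the reconstruction loss; the weights mirror lambda_cls and lambda_recon from Config. The function name dann_step_loss is hypothetical.

```python
import torch
import torch.nn.functional as F


def dann_step_loss(src_out, tgt_out, x_src, x_tgt, y_src,
                   lambda_cls: float = 1.0, lambda_recon: float = 0.5):
    """Hypothetical combined loss for one DANN step.

    src_out / tgt_out are the tuples returned by the model's forward:
    (anomaly_logits, reconstruction, domain_logits, latent).
    """
    cls_src, rec_src, dom_src, _ = src_out
    _, rec_tgt, dom_tgt, _ = tgt_out

    # 1) Anomaly classification: source labels only (target labels are unknown)
    loss_cls = F.binary_cross_entropy_with_logits(cls_src.squeeze(1), y_src)

    # 2) Reconstruction: unsupervised, so both domains contribute
    loss_recon = 0.5 * (F.mse_loss(rec_src, x_src) + F.mse_loss(rec_tgt, x_tgt))

    # 3) Domain classification: source = 0, target = 1. Minimized as usual;
    #    the GRL inside the discriminator reverses (and scales) the gradient
    #    that reaches the encoder, while the discriminator itself trains normally.
    dom_logits = torch.cat([dom_src, dom_tgt]).squeeze(1)
    dom_labels = torch.cat([
        torch.zeros(len(dom_src), device=dom_logits.device),
        torch.ones(len(dom_tgt), device=dom_logits.device),
    ])
    loss_domain = F.binary_cross_entropy_with_logits(dom_logits, dom_labels)

    total = lambda_cls * loss_cls + lambda_recon * loss_recon + loss_domain
    parts = {"cls": loss_cls.item(), "recon": loss_recon.item(),
             "domain": loss_domain.item()}
    return total, parts


def fake_outputs(batch=4, channels=6, steps=64, latent=128):
    """Random stand-ins for the model's four outputs (shape check only)."""
    return (torch.randn(batch, 1), torch.randn(batch, channels, steps),
            torch.randn(batch, 1), torch.randn(batch, latent))


x_s, x_t = torch.randn(4, 6, 64), torch.randn(4, 6, 64)
y_s = torch.randint(0, 2, (4,)).float()
total, parts = dann_step_loss(fake_outputs(), fake_outputs(), x_s, x_t, y_s)
print(total.item(), parts)
```

Because the sign flip lives in the GRL, there is no minus sign in front of the domain term. In this sketch the adversarial strength is controlled through the GRL's λ (for example via set_domain_lambda), not through an extra weight on loss_domain.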

model.py

"""
model.py — Domain-adaptive anomaly detection model architecture.

Components:
  - GradientReversalLayer: reverses gradients for adversarial domain adaptation
  - SharedEncoder: CNN + BiLSTM feature extractor
  - AnomalyClassifier: binary classification head
  - ReconstructionDecoder: autoencoder branch for reconstruction-based scoring
  - DomainDiscriminator: adversarial domain classification head
  - DomainAdaptiveAnomalyDetector: full model combining all components
"""

import torch
import torch.nn as nn
from torch.autograd import Function


class GradientReversalFunction(Function):
    """
    Gradient Reversal Layer (GRL) — Ganin et al., 2016.
    Forward pass: identity.
    Backward pass: negate gradients and scale by lambda.
    """

    @staticmethod
    def forward(ctx, x, lambda_val):
        ctx.lambda_val = lambda_val
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambda_val * grad_output, None


class GradientReversalLayer(nn.Module):
    """Module wrapper for the gradient reversal function."""

    def __init__(self, lambda_val: float = 1.0):
        super().__init__()
        self.lambda_val = lambda_val

    def set_lambda(self, lambda_val: float):
        self.lambda_val = lambda_val

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambda_val)


class SharedEncoder(nn.Module):
    """
    1D-CNN + Bidirectional LSTM encoder for multi-channel time-series.

    Input shape:  (batch, num_features, window_size)
    Output shape: (batch, latent_dim)
    """

    def __init__(
        self,
        num_features: int = 6,
        cnn_channels: list = None,
        cnn_kernel_sizes: list = None,
        lstm_hidden_dim: int = 128,
        lstm_num_layers: int = 2,
        latent_dim: int = 128,
        dropout: float = 0.3,
    ):
        super().__init__()
        if cnn_channels is None:
            cnn_channels = [32, 64, 128]
        if cnn_kernel_sizes is None:
            cnn_kernel_sizes = [7, 5, 3]

        # Build CNN layers
        cnn_layers = []
        in_channels = num_features
        for out_ch, ks in zip(cnn_channels, cnn_kernel_sizes):
            cnn_layers.extend([
                nn.Conv1d(in_channels, out_ch, kernel_size=ks, padding=ks // 2),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            in_channels = out_ch
        self.cnn = nn.Sequential(*cnn_layers)

        # Bidirectional LSTM on top of CNN features
        self.lstm = nn.LSTM(
            input_size=cnn_channels[-1],
            hidden_size=lstm_hidden_dim,
            num_layers=lstm_num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if lstm_num_layers > 1 else 0.0,
        )

        # Project to latent space
        self.fc = nn.Sequential(
            nn.Linear(lstm_hidden_dim * 2, latent_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
        )
        self.latent_dim = latent_dim

    def forward(self, x):
        """
        Args:
            x: (batch, num_features, window_size)
        Returns:
            latent: (batch, latent_dim)
        """
        # CNN: (batch, cnn_channels[-1], window_size)
        cnn_out = self.cnn(x)
        # Transpose for LSTM: (batch, window_size, cnn_channels[-1])
        lstm_in = cnn_out.permute(0, 2, 1)
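        # h_n: (num_layers * 2, batch, lstm_hidden_dim)
        _, (h_n, _) = self.lstm(lstm_in)
        # Concatenate the final forward and backward hidden states of the last layer.
        # (lstm_out[:, -1, :] would give the backward direction only one timestep of context.)
        last_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)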
        # Project to latent space
        latent = self.fc(last_hidden)
        return latent


class AnomalyClassifier(nn.Module):
    """
    Binary classification head: normal (0) vs anomaly (1).

    Input:  (batch, latent_dim)
    Output: (batch, 1), a raw logit (apply sigmoid to get a probability)
    """

    def __init__(self, latent_dim: int = 128, hidden_dim: int = 64, dropout: float = 0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, latent):
        return self.net(latent)


class ReconstructionDecoder(nn.Module):
    """
    Decoder that reconstructs the original input from latent features.
    Uses LSTM + transposed Conv1d layers.

    Input:  (batch, latent_dim)
    Output: (batch, num_features, window_size)
    """

    def __init__(
        self,
        latent_dim: int = 128,
        num_features: int = 6,
        window_size: int = 64,
        lstm_hidden_dim: int = 128,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.window_size = window_size
        self.num_features = num_features
        self.lstm_hidden_dim = lstm_hidden_dim

        # Expand latent to sequence
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, lstm_hidden_dim),
            nn.ReLU(inplace=True),
        )

        # LSTM decoder
        self.lstm = nn.LSTM(
            input_size=lstm_hidden_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=1,
            batch_first=True,
        )

        # Transposed convolutions to reconstruct
        self.deconv = nn.Sequential(
            nn.ConvTranspose1d(lstm_hidden_dim, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.ConvTranspose1d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.ConvTranspose1d(32, num_features, kernel_size=3, padding=1),
        )

    def forward(self, latent):
        """
        Args:
            latent: (batch, latent_dim)
        Returns:
            reconstruction: (batch, num_features, window_size)
        """
        # Expand to sequence
        expanded = self.fc(latent).unsqueeze(1).repeat(1, self.window_size, 1)
        # LSTM decode
        lstm_out, _ = self.lstm(expanded)
        # Transpose for Conv1d: (batch, lstm_hidden, window_size)
        conv_in = lstm_out.permute(0, 2, 1)
        # Reconstruct
        reconstruction = self.deconv(conv_in)
        return reconstruction


class DomainDiscriminator(nn.Module):
    """
    Domain classification head with Gradient Reversal Layer.
    Classifies whether features came from source (0) or target (1) domain.

    Input:  (batch, latent_dim)
    Output: (batch, 1) — domain logit
    """

    def __init__(self, latent_dim: int = 128, hidden_dim: int = 64, dropout: float = 0.3):
        super().__init__()
        self.grl = GradientReversalLayer(lambda_val=1.0)
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
        )

    def set_lambda(self, lambda_val: float):
        self.grl.set_lambda(lambda_val)

    def forward(self, latent):
        reversed_features = self.grl(latent)
        return self.net(reversed_features)


class DomainAdaptiveAnomalyDetector(nn.Module):
    """
    Full domain-adaptive anomaly detection model.
    Combines encoder, anomaly classifier, reconstruction decoder,
    and domain discriminator.
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = SharedEncoder(
            num_features=config.num_features,
            cnn_channels=config.cnn_channels,
            cnn_kernel_sizes=config.cnn_kernel_sizes,
            lstm_hidden_dim=config.lstm_hidden_dim,
            lstm_num_layers=config.lstm_num_layers,
            latent_dim=config.latent_dim,
            dropout=config.dropout,
        )
        self.classifier = AnomalyClassifier(
            latent_dim=config.latent_dim,
            hidden_dim=config.classifier_hidden_dim,
            dropout=config.dropout,
        )
        self.decoder = ReconstructionDecoder(
            latent_dim=config.latent_dim,
            num_features=config.num_features,
            window_size=config.window_size,
            lstm_hidden_dim=config.lstm_hidden_dim,
            dropout=config.dropout,
        )
        self.discriminator = DomainDiscriminator(
            latent_dim=config.latent_dim,
            hidden_dim=config.discriminator_hidden_dim,
            dropout=config.dropout,
        )

    def set_domain_lambda(self, lambda_val: float):
        """Update the GRL lambda for progressive scheduling."""
        self.discriminator.set_lambda(lambda_val)

    def forward(self, x):
        """
        Full forward pass.

        Args:
            x: (batch, num_features, window_size)

        Returns:
            anomaly_logits:  (batch, 1) — raw logits for anomaly classification
            reconstruction:  (batch, num_features, window_size) — reconstructed input
            domain_logits:   (batch, 1) — raw logits for domain classification
            latent_features: (batch, latent_dim) — shared latent representation
        """
        latent = self.encoder(x)
        anomaly_logits = self.classifier(latent)
        reconstruction = self.decoder(latent)
        domain_logits = self.discriminator(latent)
        return anomaly_logits, reconstruction, domain_logits, latent
Key Takeaway: The Gradient Reversal Layer is just two lines of custom autograd code, but it is the entire mechanism that makes DANN work. In the forward pass it is the identity. In the backward pass it multiplies the incoming gradient by -λ, negating it and scaling it by the current schedule value. This simple trick turns a standard domain classifier into an adversarial training signal: the discriminator still learns to tell the domains apart, while the encoder receives the reversed gradient and is pushed toward domain-invariant features.
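To see the mechanism in isolation, here is a minimal standalone sketch of a gradient reversal function with a numerical check. It is not the GradientReversalLayer from model.py (whose code is not repeated in this section); the names GradReverse and grad_reverse are our own. The check confirms that values pass through unchanged while the gradient comes back multiplied by -λ.

```python
import torch


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by -lambda in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambda_val):
        ctx.lambda_val = lambda_val
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: x gets the reversed gradient, lambda_val gets None.
        return -ctx.lambda_val * grad_output, None


def grad_reverse(x, lambda_val=1.0):
    return GradReverse.apply(x, lambda_val)


# Toy check: d(sum(y))/dx would be 1 for an identity op; the GRL turns it into -lambda.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = grad_reverse(x, lambda_val=0.5)
y.sum().backward()

print(y.detach())  # same values as x
print(x.grad)      # every entry is -0.5
```

Because only the backward pass changes, the discriminator sitting after the GRL is trained normally, and only the parameters before it (the encoder) see the sign flip.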

Loss Functions: DANN, MMD, and CORAL

Domain adaptation is not one technique but a family of techniques, each with different strengths. Our implementation supports three approaches, selected with a single config value (adaptation_method, or --method on the command line). DANN uses adversarial training through the domain discriminator. MMD directly minimizes the statistical distance between the source and target feature distributions, computed with Gaussian kernels as the distance between their mean embeddings in a reproducing kernel Hilbert space. CORAL aligns the second-order statistics (covariance matrices) of the two domains.
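The difference in what MMD and CORAL can "see" is easy to demonstrate. The sketch below uses simplified stand-ins, mmd_rbf and coral (our own helpers, not the classes in losses.py), on synthetic 8-dimensional features; the bandwidths and data are arbitrary choices for illustration. Because CORAL centers the features before comparing covariances, it is blind to a pure mean shift, whereas a Gaussian-kernel MMD responds to both a mean shift and a change of scale.

```python
import torch


def mmd_rbf(x, y, bandwidths=(0.1, 1.0, 10.0)):
    """Biased MMD^2 estimate using a sum of Gaussian kernels exp(-d^2 / (2 * bw))."""
    def kernel(a, b):
        d2 = torch.cdist(a, b).pow(2)
        return sum(torch.exp(-d2 / (2.0 * bw)) for bw in bandwidths)

    return kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean()


def coral(x, y):
    """Squared Frobenius distance between feature covariances, scaled by 1 / (4 d^2)."""
    d = x.size(1)
    cov_x = torch.cov(x.T)  # torch.cov expects variables in rows, observations in columns
    cov_y = torch.cov(y.T)
    return ((cov_x - cov_y) ** 2).sum() / (4 * d * d)


torch.manual_seed(0)
n, d = 500, 8
source = torch.randn(n, d)
same_dist = torch.randn(n, d)            # control: same distribution as source
mean_shifted = torch.randn(n, d) + 2.0   # same covariance, shifted mean
scale_shifted = torch.randn(n, d) * 3.0  # same mean, 9x larger variance

for name, target in [("same", same_dist), ("mean shift", mean_shifted), ("scale shift", scale_shifted)]:
    mmd = mmd_rbf(source, target).item()
    cor = coral(source, target).item()
    print(f"{name:<12} MMD^2={mmd:.4f}  CORAL={cor:.4f}")
```

In practice this means CORAL alone provides no gradient to remove a constant offset between domains, while MMD does, at the price of choosing kernel bandwidths.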

losses.py

"""
losses.py — Loss functions for domain-adaptive anomaly detection.

Includes:
  - AnomalyDetectionLoss (BCE for anomaly classification)
  - ReconstructionLoss (MSE for autoencoder)
  - DomainAdversarialLoss (BCE for domain discrimination)
  - MMDLoss (Maximum Mean Discrepancy with Gaussian kernel)
  - CORALLoss (CORrelation ALignment)
  - CombinedLoss (weighted combination of all losses)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class AnomalyDetectionLoss(nn.Module):
    """Binary cross-entropy loss for anomaly classification."""

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, logits, labels):
        """
        Args:
            logits: (batch, 1) raw anomaly logits
            labels: (batch,) binary labels (0=normal, 1=anomaly)
        """
        # BCEWithLogitsLoss requires float targets; cast in case the loader yields integer labels
        return self.bce(logits.squeeze(-1), labels.float())


class ReconstructionLoss(nn.Module):
    """MSE loss between input and reconstruction."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, reconstruction, original):
        """
        Args:
            reconstruction: (batch, num_features, window_size)
            original: (batch, num_features, window_size)
        """
        return self.mse(reconstruction, original)


class DomainAdversarialLoss(nn.Module):
    """BCE loss for domain classification (used with GRL for DANN)."""

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, domain_logits, domain_labels):
        """
        Args:
            domain_logits: (batch, 1) raw domain logits
            domain_labels: (batch,) domain labels (0=source, 1=target)
        """
        return self.bce(domain_logits.squeeze(-1), domain_labels)


class MMDLoss(nn.Module):
    """
    Maximum Mean Discrepancy loss with multi-scale Gaussian kernel.

    Measures the distance between source and target feature distributions
    in a reproducing kernel Hilbert space (RKHS).
    """

    def __init__(self, kernel_bandwidths: list = None):
        super().__init__()
        if kernel_bandwidths is None:
            self.kernel_bandwidths = [0.01, 0.1, 1.0, 10.0, 100.0]
        else:
            self.kernel_bandwidths = kernel_bandwidths

    def gaussian_kernel(self, x, y):
        """
        Compute multi-scale Gaussian kernel matrices for x and y.

        Args:
            x: (n, d) tensor
            y: (m, d) tensor
        Returns:
            k_xx: (n, n), k_yy: (m, m), k_xy: (n, m) kernel matrices,
            each summed over all bandwidths
        """
        # Squared norm of each row
        x_sq = (x * x).sum(dim=1)
        y_sq = (y * y).sum(dim=1)

        # Pairwise squared distances ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
        # Broadcasting handles n != m (e.g. a short final batch);
        # clamp tiny negative values caused by floating-point error.
        dxx = (x_sq.unsqueeze(1) + x_sq.unsqueeze(0) - 2.0 * x @ x.t()).clamp_min(0.0)
        dyy = (y_sq.unsqueeze(1) + y_sq.unsqueeze(0) - 2.0 * y @ y.t()).clamp_min(0.0)
        dxy = (x_sq.unsqueeze(1) + y_sq.unsqueeze(0) - 2.0 * x @ y.t()).clamp_min(0.0)

        k_xx = torch.zeros_like(dxx)
        k_yy = torch.zeros_like(dyy)
        k_xy = torch.zeros_like(dxy)

        # Sum of Gaussian kernels exp(-d^2 / (2 * bw)) over all bandwidths
        for bw in self.kernel_bandwidths:
            k_xx += torch.exp(-dxx / (2.0 * bw))
            k_yy += torch.exp(-dyy / (2.0 * bw))
            k_xy += torch.exp(-dxy / (2.0 * bw))

        return k_xx, k_yy, k_xy

    def forward(self, source_features, target_features):
        """
        Compute MMD^2 between source and target feature distributions.

        Args:
            source_features: (n, d) latent features from source domain
            target_features:  (m, d) latent features from target domain
        Returns:
            mmd_loss: scalar
        """
        n = source_features.size(0)
        m = target_features.size(0)

        k_xx, k_yy, k_xy = self.gaussian_kernel(source_features, target_features)

        mmd = (k_xx.sum() / (n * n)
               + k_yy.sum() / (m * m)
               - 2.0 * k_xy.sum() / (n * m))

        return mmd


class CORALLoss(nn.Module):
    """
    CORrelation ALignment loss.

    Aligns the second-order statistics (covariance matrices) of
    source and target feature distributions.
    """

    def __init__(self):
        super().__init__()

    def forward(self, source_features, target_features):
        """
        Compute CORAL loss.

        Args:
            source_features: (n, d) latent features from source domain
            target_features:  (m, d) latent features from target domain
        Returns:
            coral_loss: scalar
        """
        d = source_features.size(1)
        n_s = source_features.size(0)
        n_t = target_features.size(0)

        # Compute covariance matrices
        source_centered = source_features - source_features.mean(dim=0, keepdim=True)
        target_centered = target_features - target_features.mean(dim=0, keepdim=True)

        cov_source = (source_centered.t() @ source_centered) / (n_s - 1)
        cov_target = (target_centered.t() @ target_centered) / (n_t - 1)

        # Frobenius norm of covariance difference
        diff = cov_source - cov_target
        coral_loss = (diff * diff).sum() / (4 * d * d)

        return coral_loss


class CombinedLoss(nn.Module):
    """
    Combines anomaly detection, reconstruction, and domain adaptation losses.

    total_loss = lambda_cls * anomaly_loss
               + lambda_recon * recon_loss
               + lambda_domain * domain_loss

    The domain_loss component uses DANN, MMD, or CORAL depending on config.
    """

    def __init__(self, config):
        super().__init__()
        self.anomaly_loss_fn = AnomalyDetectionLoss()
        self.recon_loss_fn = ReconstructionLoss()
        self.dann_loss_fn = DomainAdversarialLoss()
        self.mmd_loss_fn = MMDLoss(kernel_bandwidths=config.mmd_kernel_bandwidth)
        self.coral_loss_fn = CORALLoss()

        self.lambda_cls = config.lambda_cls
        self.lambda_recon = config.lambda_recon
        self.lambda_domain = config.lambda_domain
        self.method = config.adaptation_method

    def forward(
        self,
        anomaly_logits,
        anomaly_labels,
        reconstruction,
        original,
        domain_logits=None,
        domain_labels=None,
        source_features=None,
        target_features=None,
        current_lambda=None,
    ):
        """
        Compute combined loss.

        Args:
            anomaly_logits: (batch, 1) anomaly classification logits (source only)
            anomaly_labels: (batch,) anomaly labels (source only)
            reconstruction: (batch, num_features, window_size) reconstruction
            original: (batch, num_features, window_size) original input
            domain_logits: (batch, 1) domain logits (DANN only)
            domain_labels: (batch,) domain labels (DANN only)
            source_features: (n, d) source latent features (MMD/CORAL)
            target_features: (m, d) target latent features (MMD/CORAL)
            current_lambda: float — current domain adaptation weight

        Returns:
            total_loss, loss_dict (breakdown of individual losses)
        """
        domain_weight = current_lambda if current_lambda is not None else self.lambda_domain

        # Anomaly classification loss (source only)
        cls_loss = self.anomaly_loss_fn(anomaly_logits, anomaly_labels)

        # Reconstruction loss (both domains)
        recon_loss = self.recon_loss_fn(reconstruction, original)

        # Domain adaptation loss
        if self.method == "dann" and domain_logits is not None:
            domain_loss = self.dann_loss_fn(domain_logits, domain_labels)
        elif self.method == "mmd" and source_features is not None:
            domain_loss = self.mmd_loss_fn(source_features, target_features)
        elif self.method == "coral" and source_features is not None:
            domain_loss = self.coral_loss_fn(source_features, target_features)
        else:
            domain_loss = torch.tensor(0.0, device=anomaly_logits.device)

        total_loss = (
            self.lambda_cls * cls_loss
            + self.lambda_recon * recon_loss
            + domain_weight * domain_loss
        )

        loss_dict = {
            "total": total_loss.item(),
            "classification": cls_loss.item(),
            "reconstruction": recon_loss.item(),
            "domain": domain_loss.item(),
        }

        return total_loss, loss_dict

The Main Training Script

This is where everything comes together. The training loop handles the delicate dance of simultaneously training the anomaly classifier (on labeled source data), the reconstruction decoder (on both domains), and, when DANN is selected, the domain discriminator (adversarially, on both domains). The domain adaptation weight ramps up progressively over training using the schedule from the original DANN paper: λp = 2 / (1 + exp(-γ · p)) - 1, where p is the training progress from 0 to 1 and γ controls how quickly the ramp saturates (γ = 10 in the paper and in the default of compute_dann_lambda). The script applies the same schedule to MMD and CORAL.
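As a quick worked check of the schedule, the sketch below evaluates λp at a few points of training progress with γ = 10. The helper name dann_lambda is ours; it mirrors compute_dann_lambda in train.py but takes p directly. It also confirms a handy identity, 2 / (1 + exp(-x)) - 1 = tanh(x / 2), so the schedule is simply a scaled tanh.

```python
import math


def dann_lambda(p: float, gamma: float = 10.0) -> float:
    """lambda_p = 2 / (1 + exp(-gamma * p)) - 1, for training progress p in [0, 1]."""
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0


# Compare against the equivalent closed form tanh(gamma * p / 2).
for p in [0.0, 0.1, 0.25, 0.5, 1.0]:
    lam = dann_lambda(p)
    print(f"p={p:.2f}  lambda={lam:.4f}  tanh={math.tanh(10.0 * p / 2):.4f}")
```

With γ = 10, λ is already about 0.46 after 10% of training and about 0.85 after 25%, so most of the ramp happens early; a smaller γ spreads it over more epochs.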

train.py

"""
train.py — Main training script for domain-adaptive anomaly detection.

Supports three adaptation methods: DANN, MMD, CORAL.
Uses progressive lambda scheduling for stable training.
"""

import argparse
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

from config import Config
from dataset import create_data_loaders
from model import DomainAdaptiveAnomalyDetector
from losses import CombinedLoss
from utils import (
    set_seed,
    EarlyStopping,
    save_checkpoint,
    MetricLogger,
)


def compute_dann_lambda(epoch: int, total_epochs: int, gamma: float = 10.0) -> float:
    """
    Progressive lambda schedule from the DANN paper (Ganin et al., 2016).
    Ramps smoothly from 0 toward 1 over training; equivalent to tanh(gamma * p / 2).

    lambda_p = 2 / (1 + exp(-gamma * p)) - 1, where p = epoch / total_epochs.
    p is updated once per epoch, so it runs from 0 to (total_epochs - 1) / total_epochs.
    """
    p = epoch / total_epochs
    return float(2.0 / (1.0 + np.exp(-gamma * p)) - 1.0)


def train_one_epoch(
    model,
    source_loader,
    target_loader,
    criterion,
    optimizer,
    device,
    epoch,
    total_epochs,
    config,
):
    """Train for one epoch with domain adaptation."""
    model.train()
    epoch_losses = {"total": 0, "classification": 0, "reconstruction": 0, "domain": 0}
    n_batches = 0

    # Progressive schedule: ramps from 0 toward 1 over training
    current_lambda = compute_dann_lambda(epoch, total_epochs, config.gamma) * config.lambda_domain

    if config.adaptation_method == "dann":
        # The GRL applies the ramp to the encoder's gradient. The domain loss itself keeps
        # weight 1.0, so the discriminator trains at full strength from the start and the
        # ramp is not applied twice (once in the GRL and again in the loss weight).
        model.set_domain_lambda(current_lambda)
        domain_loss_weight = 1.0
    else:
        # MMD/CORAL have no GRL, so the ramp goes directly on the loss weight.
        domain_loss_weight = current_lambda

    # Iterate over the source loader; restart the target loader whenever it runs out
    target_iter = iter(target_loader)

    for source_batch, source_labels in source_loader:
        # Get a target batch (target labels are ignored: the target domain is treated as unlabeled)
        try:
            target_batch, _ = next(target_iter)
        except StopIteration:
            target_iter = iter(target_loader)
            target_batch, _ = next(target_iter)

        source_batch = source_batch.to(device)
        source_labels = source_labels.to(device)
        target_batch = target_batch.to(device)

        # Actual batch sizes (may differ, e.g. on the final batch)
        bs_s = source_batch.size(0)
        bs_t = target_batch.size(0)

        # Forward pass: source domain
        s_anomaly_logits, s_recon, s_domain_logits, s_latent = model(source_batch)

        # Forward pass: target domain (anomaly logits unused: no target labels)
        t_anomaly_logits, t_recon, t_domain_logits, t_latent = model(target_batch)

        # Combine reconstructions and originals for loss
        all_recon = torch.cat([s_recon, t_recon], dim=0)
        all_original = torch.cat([source_batch, target_batch], dim=0)

        # Domain labels: 0 for source, 1 for target
        domain_labels = torch.cat([
            torch.zeros(bs_s, device=device),
            torch.ones(bs_t, device=device),
        ])
        all_domain_logits = torch.cat([s_domain_logits, t_domain_logits], dim=0)

        # Compute combined loss
        total_loss, loss_dict = criterion(
            anomaly_logits=s_anomaly_logits,
            anomaly_labels=source_labels,
            reconstruction=all_recon,
            original=all_original,
            domain_logits=all_domain_logits,
            domain_labels=domain_labels,
            source_features=s_latent,
            target_features=t_latent,
            current_lambda=domain_loss_weight,
        )

        # Backprop
        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Accumulate losses
        for key in epoch_losses:
            epoch_losses[key] += loss_dict[key]
        n_batches += 1

    # Average losses
    for key in epoch_losses:
        epoch_losses[key] /= max(n_batches, 1)

    epoch_losses["lambda"] = current_lambda
    return epoch_losses


@torch.no_grad()
def validate(model, loader, criterion, device, config):
    """Validate on a labeled dataset (source test or target test)."""
    model.eval()
    all_logits = []
    all_labels = []
    total_recon_loss = 0
    n_batches = 0

    for batch, labels in loader:
        batch = batch.to(device)
        labels = labels.to(device)

        anomaly_logits, recon, _, latent = model(batch)
        recon_loss = nn.MSELoss()(recon, batch)

        all_logits.append(anomaly_logits.squeeze(-1).cpu())
        all_labels.append(labels.cpu())
        total_recon_loss += recon_loss.item()
        n_batches += 1

    all_logits = torch.cat(all_logits)
    all_labels = torch.cat(all_labels)

    # Compute metrics
    probs = torch.sigmoid(all_logits)
    preds = (probs > 0.5).float()
    accuracy = (preds == all_labels).float().mean().item()

    from sklearn.metrics import roc_auc_score, f1_score
    try:
        auroc = roc_auc_score(all_labels.numpy(), probs.numpy())
    except ValueError:
        auroc = 0.5  # Only one class present
    f1 = f1_score(all_labels.numpy(), preds.numpy(), zero_division=0)

    return {
        "accuracy": accuracy,
        "auroc": auroc,
        "f1": f1,
        "recon_loss": total_recon_loss / max(n_batches, 1),
    }


def main():
    parser = argparse.ArgumentParser(description="Train domain-adaptive anomaly detector")
    parser.add_argument("--method", type=str, default="dann",
                        choices=["dann", "mmd", "coral"],
                        help="Domain adaptation method")
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--lambda_domain", type=float, default=None)
    parser.add_argument("--lambda_recon", type=float, default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--data_dir", type=str, default=None)
    parser.add_argument("--device", type=str, default=None)
    args = parser.parse_args()

    # Build config with CLI overrides
    config = Config()
    config.adaptation_method = args.method
    if args.epochs is not None:
        config.epochs = args.epochs
    if args.batch_size is not None:
        config.batch_size = args.batch_size
    if args.lr is not None:
        config.learning_rate = args.lr
    if args.lambda_domain is not None:
        config.lambda_domain = args.lambda_domain
    if args.lambda_recon is not None:
        config.lambda_recon = args.lambda_recon
    if args.seed is not None:
        config.seed = args.seed
    if args.data_dir is not None:
        config.data_dir = args.data_dir
    if args.device is not None:
        config.device = args.device

    # Setup
    set_seed(config.seed)
    device = torch.device(config.device)
    os.makedirs(config.checkpoint_dir, exist_ok=True)  # best_model.pt is saved here
    print(f"Using device: {device}")
    print(f"Adaptation method: {config.adaptation_method}")
    print(f"Epochs: {config.epochs}, Batch size: {config.batch_size}, LR: {config.learning_rate}")

    # Data
    print("\nLoading data...")
    loaders, normalizer = create_data_loaders(config)
    print(f"Source train batches: {len(loaders['source_train'])}")
    print(f"Target train batches: {len(loaders['target_train'])}")

    # Model
    model = DomainAdaptiveAnomalyDetector(config).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {total_params:,}")

    # Optimizer (single optimizer for simplicity; separate LRs via param groups)
    optimizer = Adam([
        {"params": model.encoder.parameters(), "lr": config.learning_rate},
        {"params": model.classifier.parameters(), "lr": config.learning_rate},
        {"params": model.decoder.parameters(), "lr": config.learning_rate},
        {"params": model.discriminator.parameters(), "lr": config.discriminator_lr},
    ], weight_decay=config.weight_decay)

    scheduler = CosineAnnealingLR(optimizer, T_max=config.epochs, eta_min=1e-6)

    # Loss
    criterion = CombinedLoss(config)

    # Early stopping
    early_stopping = EarlyStopping(patience=config.patience, mode="max")

    # Logging
    logger = MetricLogger(config.results_dir)

    # Training loop
    best_target_auroc = 0.0
    print("\n" + "=" * 60)
    print("Starting training...")
    print("=" * 60)

    for epoch in range(config.epochs):
        start_time = time.time()

        # Train
        train_losses = train_one_epoch(
            model, loaders["source_train"], loaders["target_train"],
            criterion, optimizer, device, epoch, config.epochs, config
        )

        # Validate on source test
        source_metrics = validate(model, loaders["source_test"], criterion, device, config)

        # Evaluate on target test (the real metric we care about)
        target_metrics = validate(model, loaders["target_test"], criterion, device, config)

        scheduler.step()

        elapsed = time.time() - start_time

        # Log
        logger.log(epoch, train_losses, source_metrics, target_metrics)

        # Print progress
        if epoch % 5 == 0 or epoch == config.epochs - 1:
            print(
                f"Epoch {epoch:3d}/{config.epochs} ({elapsed:.1f}s) | "
                f"Loss: {train_losses['total']:.4f} "
                f"[cls={train_losses['classification']:.4f}, "
                f"rec={train_losses['reconstruction']:.4f}, "
                f"dom={train_losses['domain']:.4f}] | "
                f"λ={train_losses['lambda']:.3f} | "
                f"Src AUROC: {source_metrics['auroc']:.4f} | "
                f"Tgt AUROC: {target_metrics['auroc']:.4f}"
            )

        # Save best model (based on target AUROC)
        if target_metrics["auroc"] > best_target_auroc:
            best_target_auroc = target_metrics["auroc"]
            save_checkpoint(
                model, optimizer, epoch, target_metrics,
                os.path.join(config.checkpoint_dir, "best_model.pt")
            )

        # Early stopping on target AUROC
        if early_stopping.step(target_metrics["auroc"]):
            print(f"\nEarly stopping triggered at epoch {epoch}")
            break

    print("\n" + "=" * 60)
    print(f"Training complete. Best target AUROC: {best_target_auroc:.4f}")
    print(f"Best model saved to: {config.checkpoint_dir}/best_model.pt")
    print("=" * 60)

    # Save training curves
    logger.save()
    logger.plot_training_curves()


if __name__ == "__main__":
    main()
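Tip: The key metric to watch is target AUROC, not source AUROC. Source AUROC tells you the model can classify anomalies where it has labels, which is expected. Target AUROC tells you whether domain adaptation is actually transferring anomaly detection knowledge to the unlabeled domain. One caveat: as written, the script uses the target test labels both to choose best_model.pt and to trigger early stopping, so the reported best target AUROC is optimistically biased. That is acceptable when benchmarking on a dataset where target labels exist, but in a genuinely unlabeled deployment you would select checkpoints on source validation metrics (or a small labeled target validation split) and report target AUROC on a held-out split that played no part in selection.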

Evaluation and Metrics

After training, we need rigorous evaluation on the target domain. Our evaluation script computes standard anomaly detection metrics, combines classifier and reconstruction scores, implements multiple threshold strategies, and generates diagnostic plots. This is where you find out if domain adaptation actually worked.
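One refinement worth considering: find_optimal_threshold in evaluate.py scans a fixed grid of 200 thresholds and calls f1_score at each. A sketch of an exact alternative (f1_optimal_threshold is our own helper, not part of evaluate.py) uses scikit-learn's precision_recall_curve, which treats every distinct score as a candidate threshold in a single pass. Like the grid search, it reads the ground-truth labels, so the result is an oracle, best-case threshold rather than one you could pick in deployment.

```python
import numpy as np
from sklearn.metrics import f1_score, precision_recall_curve


def f1_optimal_threshold(labels, scores):
    """Return (threshold, f1) maximizing F1, treating every distinct score as a candidate."""
    precision, recall, thresholds = precision_recall_curve(labels, scores)
    # precision/recall have one more entry than thresholds: the final point
    # (precision=1, recall=0) has no threshold attached, so drop it.
    precision, recall = precision[:-1], recall[:-1]
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(denom), where=denom > 0)
    best = int(np.argmax(f1))
    return float(thresholds[best]), float(f1[best])


# Toy example: positives are scores >= threshold, matching find_optimal_threshold.
labels = np.array([0, 0, 1, 1, 0, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.9])
thresh, best_f1 = f1_optimal_threshold(labels, scores)
print(thresh, best_f1)  # threshold 0.35, F1 = 6/7 (about 0.857)

# Cross-check against sklearn's f1_score at the chosen threshold.
print(f1_score(labels, (scores >= thresh).astype(int)))
```

Because the candidates are the observed scores themselves, this never misses the best threshold between two grid points, and its cost is dominated by a single sort of the scores.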

evaluate.py

"""
evaluate.py — Evaluation script for domain-adaptive anomaly detection.

Loads a trained model and evaluates on target domain test data.
Computes AUROC, AUPRC, F1, precision, recall.
Generates diagnostic plots and saves results to JSON.
"""

import argparse
import json
import os
import numpy as np
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    f1_score,
    precision_score,
    recall_score,
    accuracy_score,
    confusion_matrix,
    roc_curve,
    precision_recall_curve,
)

from config import Config
from dataset import create_data_loaders
from model import DomainAdaptiveAnomalyDetector
from utils import set_seed, load_checkpoint


def compute_anomaly_scores(model, loader, device, alpha=0.7):
    """
    Compute anomaly scores combining classifier output and reconstruction error.

    anomaly_score = alpha * classifier_prob + (1 - alpha) * normalized_recon_error

    Returns:
        scores: numpy array of anomaly scores
        labels: numpy array of ground truth labels
        recon_errors: numpy array of per-sample reconstruction errors
        classifier_probs: numpy array of classifier probabilities
        latent_features: numpy array of latent features (for t-SNE)
    """
    model.eval()
    all_probs = []
    all_labels = []
    all_recon_errors = []
    all_latent = []

    with torch.no_grad():
        for batch, labels in loader:
            batch = batch.to(device)
            anomaly_logits, recon, _, latent = model(batch)

            # Classifier probability
            probs = torch.sigmoid(anomaly_logits.squeeze(-1))

            # Per-sample reconstruction error (mean across features and time)
            recon_error = ((recon - batch) ** 2).mean(dim=(1, 2))

            all_probs.append(probs.cpu().numpy())
            all_labels.append(labels.numpy())
            all_recon_errors.append(recon_error.cpu().numpy())
            all_latent.append(latent.cpu().numpy())

    all_probs = np.concatenate(all_probs)
    all_labels = np.concatenate(all_labels)
    all_recon_errors = np.concatenate(all_recon_errors)
    all_latent = np.concatenate(all_latent)

    # Normalize reconstruction errors to [0, 1]
    re_min, re_max = all_recon_errors.min(), all_recon_errors.max()
    if re_max - re_min > 1e-8:
        norm_recon = (all_recon_errors - re_min) / (re_max - re_min)
    else:
        norm_recon = np.zeros_like(all_recon_errors)

    # Combined anomaly score
    scores = alpha * all_probs + (1 - alpha) * norm_recon

    return scores, all_labels, all_recon_errors, all_probs, all_latent


def find_optimal_threshold(labels, scores):
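    """Grid-search 200 thresholds in [0, 1] for the one that maximizes F1.

    Uses the ground-truth labels, so the result is an oracle (best-case) threshold.
    """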
    thresholds = np.linspace(0, 1, 200)
    best_f1 = 0
    best_thresh = 0.5

    for thresh in thresholds:
        preds = (scores >= thresh).astype(int)
        f1 = f1_score(labels, preds, zero_division=0)
        if f1 > best_f1:
            best_f1 = f1
            best_thresh = thresh

    return best_thresh, best_f1


def compute_all_metrics(labels, scores, threshold):
    """Compute all evaluation metrics at a given threshold."""
    preds = (scores >= threshold).astype(int)
    metrics = {
        "auroc": float(roc_auc_score(labels, scores)),
        "auprc": float(average_precision_score(labels, scores)),
        "f1": float(f1_score(labels, preds, zero_division=0)),
        "precision": float(precision_score(labels, preds, zero_division=0)),
        "recall": float(recall_score(labels, preds, zero_division=0)),
        "accuracy": float(accuracy_score(labels, preds)),
        "threshold": float(threshold),
    }

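    # labels=[0, 1] keeps the matrix 2x2 even if a class is absent from both labels and preds
    cm = confusion_matrix(labels, preds, labels=[0, 1])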
    metrics["confusion_matrix"] = cm.tolist()
    metrics["true_negatives"] = int(cm[0, 0])
    metrics["false_positives"] = int(cm[0, 1])
    metrics["false_negatives"] = int(cm[1, 0])
    metrics["true_positives"] = int(cm[1, 1])

    return metrics


def plot_roc_curve(labels, scores, save_path):
    """Plot and save ROC curve."""
    fpr, tpr, _ = roc_curve(labels, scores)
    auroc = roc_auc_score(labels, scores)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, "b-", linewidth=2, label=f"AUROC = {auroc:.4f}")
    ax.plot([0, 1], [0, 1], "k--", alpha=0.5, label="Random")
    ax.set_xlabel("False Positive Rate", fontsize=12)
    ax.set_ylabel("True Positive Rate", fontsize=12)
    ax.set_title("ROC Curve — Target Domain", fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"ROC curve saved to {save_path}")


def plot_pr_curve(labels, scores, save_path):
    """Plot and save Precision-Recall curve."""
    precision, recall, _ = precision_recall_curve(labels, scores)
    auprc = average_precision_score(labels, scores)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(recall, precision, "r-", linewidth=2, label=f"AUPRC = {auprc:.4f}")
    baseline = labels.sum() / len(labels)
    ax.axhline(y=baseline, color="k", linestyle="--", alpha=0.5, label=f"Baseline = {baseline:.3f}")
    ax.set_xlabel("Recall", fontsize=12)
    ax.set_ylabel("Precision", fontsize=12)
    ax.set_title("Precision-Recall Curve — Target Domain", fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"PR curve saved to {save_path}")


def plot_score_distribution(labels, scores, threshold, save_path):
    """Plot anomaly score distribution for normal vs anomaly samples."""
    fig, ax = plt.subplots(figsize=(10, 6))

    normal_scores = scores[labels == 0]
    anomaly_scores = scores[labels == 1]

    ax.hist(normal_scores, bins=50, alpha=0.6, color="steelblue", label="Normal", density=True)
    ax.hist(anomaly_scores, bins=50, alpha=0.6, color="indianred", label="Anomaly", density=True)
    ax.axvline(x=threshold, color="black", linestyle="--", linewidth=2,
               label=f"Threshold = {threshold:.3f}")
    ax.set_xlabel("Anomaly Score", fontsize=12)
    ax.set_ylabel("Density", fontsize=12)
    ax.set_title("Anomaly Score Distribution — Target Domain", fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"Score distribution saved to {save_path}")


def plot_reconstruction_error(recon_errors, labels, save_path):
    """Plot reconstruction error over sample index, colored by label."""
    fig, ax = plt.subplots(figsize=(14, 5))

    indices = np.arange(len(recon_errors))
    normal_mask = labels == 0
    anomaly_mask = labels == 1

    ax.scatter(indices[normal_mask], recon_errors[normal_mask],
               s=2, alpha=0.4, c="steelblue", label="Normal")
    ax.scatter(indices[anomaly_mask], recon_errors[anomaly_mask],
               s=8, alpha=0.8, c="indianred", label="Anomaly")
    ax.set_xlabel("Sample Index", fontsize=12)
    ax.set_ylabel("Reconstruction Error", fontsize=12)
    ax.set_title("Reconstruction Error Over Time — Target Domain", fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"Reconstruction error plot saved to {save_path}")


def main():
    parser = argparse.ArgumentParser(description="Evaluate domain-adaptive anomaly detector")
    parser.add_argument("--checkpoint", type=str,
                        default="checkpoints/best_model.pt",
                        help="Path to model checkpoint")
    parser.add_argument("--data_dir", type=str, default="data",
                        help="Data directory")
    parser.add_argument("--results_dir", type=str, default="results",
                        help="Output directory for results")
    parser.add_argument("--alpha", type=float, default=0.7,
                        help="Weight for classifier score vs recon error")
    parser.add_argument("--method", type=str, default="dann",
                        choices=["dann", "mmd", "coral"])
    parser.add_argument("--device", type=str, default="")
    args = parser.parse_args()

    config = Config()
    config.data_dir = args.data_dir
    config.results_dir = args.results_dir
    config.adaptation_method = args.method
    if args.device:
        config.device = args.device

    set_seed(config.seed)
    device = torch.device(config.device)
    os.makedirs(config.results_dir, exist_ok=True)

    print(f"Device: {device}")
    print(f"Loading checkpoint: {args.checkpoint}")

    # Load model
    model = DomainAdaptiveAnomalyDetector(config).to(device)
    checkpoint = load_checkpoint(args.checkpoint, model, device=device)
    print(f"Loaded model from epoch {checkpoint.get('epoch', '?')}")

    # Load data
    loaders, normalizer = create_data_loaders(config)

    # --- Evaluate on target test set ---
    print("\n--- Target Domain Evaluation ---")
    scores, labels, recon_errors, probs, latent_features = compute_anomaly_scores(
        model, loaders["target_test"], device, alpha=args.alpha
    )

    # Find optimal threshold
    optimal_thresh, optimal_f1 = find_optimal_threshold(labels, scores)
    print(f"Optimal threshold: {optimal_thresh:.4f} (F1 = {optimal_f1:.4f})")

    # Percentile-based threshold
    percentile_thresh = np.percentile(scores, config.anomaly_threshold_percentile)
    print(f"Percentile ({config.anomaly_threshold_percentile}%) threshold: {percentile_thresh:.4f}")

    # Compute metrics at optimal threshold
    metrics_optimal = compute_all_metrics(labels, scores, optimal_thresh)
    metrics_optimal["threshold_method"] = "f1_optimal"

    # Compute metrics at percentile threshold
    metrics_percentile = compute_all_metrics(labels, scores, percentile_thresh)
    metrics_percentile["threshold_method"] = "percentile"

    # Print results
    print(f"\n{'Metric':<20} {'F1-Optimal':>12} {'Percentile':>12}")
    print("-" * 46)
    for key in ["auroc", "auprc", "f1", "precision", "recall", "accuracy"]:
        print(f"{key:<20} {metrics_optimal[key]:>12.4f} {metrics_percentile[key]:>12.4f}")

    # Also evaluate on source test for comparison
    print("\n--- Source Domain Evaluation (baseline) ---")
    src_scores, src_labels, _, _, src_latent = compute_anomaly_scores(
        model, loaders["source_test"], device, alpha=args.alpha
    )
    src_thresh, _ = find_optimal_threshold(src_labels, src_scores)
    src_metrics = compute_all_metrics(src_labels, src_scores, src_thresh)
    print(f"Source AUROC: {src_metrics['auroc']:.4f}, F1: {src_metrics['f1']:.4f}")

    # --- Generate plots ---
    print("\nGenerating plots...")
    plot_roc_curve(labels, scores, os.path.join(config.results_dir, "roc_curve.png"))
    plot_pr_curve(labels, scores, os.path.join(config.results_dir, "pr_curve.png"))
    plot_score_distribution(labels, scores, optimal_thresh,
                           os.path.join(config.results_dir, "score_distribution.png"))
    plot_reconstruction_error(recon_errors, labels,
                             os.path.join(config.results_dir, "recon_error.png"))

    # --- Save results ---
    results = {
        "method": config.adaptation_method,
        "alpha": args.alpha,
        "target_metrics_optimal": metrics_optimal,
        "target_metrics_percentile": metrics_percentile,
        "source_metrics": src_metrics,
    }
    results_path = os.path.join(config.results_dir, "evaluation_results.json")
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {results_path}")


if __name__ == "__main__":
    main()

Utility Functions

The utility module handles reproducibility, early stopping, checkpointing, metric logging, and visualization including t-SNE plots of feature distributions.

utils.py

"""
utils.py — Utility functions for the DA anomaly detection pipeline.

Includes:
  - Seed setting for reproducibility
  - EarlyStopping class
  - Checkpoint save/load
  - MetricLogger with CSV output and plotting
  - t-SNE visualization of domain features
"""

import os
import random
import json
import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


def set_seed(seed: int = 42):
    """Set random seeds for reproducibility across all libraries."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class EarlyStopping:
    """
    Early stopping to halt training when a metric stops improving.

    Args:
        patience: number of epochs to wait before stopping
        mode: 'min' or 'max' — whether lower or higher is better
        min_delta: minimum improvement to count as progress
    """

    def __init__(self, patience: int = 15, mode: str = "max", min_delta: float = 1e-4):
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.counter = 0
        self.best_value = None

    def step(self, value: float) -> bool:
        """
        Check if training should stop.

        Args:
            value: current metric value
        Returns:
            True if training should stop
        """
        if self.best_value is None:
            self.best_value = value
            return False

        if self.mode == "max":
            improved = value > self.best_value + self.min_delta
        else:
            improved = value < self.best_value - self.min_delta

        if improved:
            self.best_value = value
            self.counter = 0
        else:
            self.counter += 1

        return self.counter >= self.patience


def save_checkpoint(model, optimizer, epoch, metrics, filepath):
    """Save model and optimizer state plus training metadata."""
    checkpoint_dir = os.path.dirname(filepath)
    if checkpoint_dir:  # os.makedirs("") raises, e.g. for a bare "best_model.pt"
        os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "metrics": metrics,
    }, filepath)


def load_checkpoint(filepath, model, optimizer=None, device="cpu"):
    """Load model checkpoint."""
    checkpoint = torch.load(filepath, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    if optimizer is not None and "optimizer_state_dict" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint


class MetricLogger:
    """
    Logs training metrics to memory and saves to CSV/JSON.
    Also generates training curve plots.
    """

    def __init__(self, output_dir: str = "results"):
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.history = {
            "epoch": [],
            "train_total_loss": [],
            "train_cls_loss": [],
            "train_recon_loss": [],
            "train_domain_loss": [],
            "train_lambda": [],
            "source_auroc": [],
            "source_f1": [],
            "target_auroc": [],
            "target_f1": [],
        }

    def log(self, epoch, train_losses, source_metrics, target_metrics):
        """Record one epoch of metrics."""
        self.history["epoch"].append(epoch)
        self.history["train_total_loss"].append(train_losses["total"])
        self.history["train_cls_loss"].append(train_losses["classification"])
        self.history["train_recon_loss"].append(train_losses["reconstruction"])
        self.history["train_domain_loss"].append(train_losses["domain"])
        self.history["train_lambda"].append(train_losses.get("lambda", 0))
        self.history["source_auroc"].append(source_metrics["auroc"])
        self.history["source_f1"].append(source_metrics["f1"])
        self.history["target_auroc"].append(target_metrics["auroc"])
        self.history["target_f1"].append(target_metrics["f1"])

    def save(self):
        """Save metrics history to JSON."""
        path = os.path.join(self.output_dir, "training_history.json")
        with open(path, "w") as f:
            json.dump(self.history, f, indent=2)
        print(f"Training history saved to {path}")

    def plot_training_curves(self):
        """Generate and save training curve plots."""
        epochs = self.history["epoch"]

        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        # Loss curves
        ax = axes[0, 0]
        ax.plot(epochs, self.history["train_total_loss"], label="Total", linewidth=2)
        ax.plot(epochs, self.history["train_cls_loss"], label="Classification", linewidth=1.5)
        ax.plot(epochs, self.history["train_recon_loss"], label="Reconstruction", linewidth=1.5)
        ax.plot(epochs, self.history["train_domain_loss"], label="Domain", linewidth=1.5)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_title("Training Losses")
        ax.legend()
        ax.grid(True, alpha=0.3)

        # AUROC
        ax = axes[0, 1]
        ax.plot(epochs, self.history["source_auroc"], label="Source AUROC", linewidth=2)
        ax.plot(epochs, self.history["target_auroc"], label="Target AUROC", linewidth=2)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("AUROC")
        ax.set_title("AUROC Over Training")
        ax.legend()
        ax.grid(True, alpha=0.3)

        # F1
        ax = axes[1, 0]
        ax.plot(epochs, self.history["source_f1"], label="Source F1", linewidth=2)
        ax.plot(epochs, self.history["target_f1"], label="Target F1", linewidth=2)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("F1 Score")
        ax.set_title("F1 Score Over Training")
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Lambda schedule
        ax = axes[1, 1]
        ax.plot(epochs, self.history["train_lambda"], label="Domain λ", linewidth=2,
                color="purple")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Lambda Value")
        ax.set_title("Domain Adaptation Lambda Schedule")
        ax.legend()
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        path = os.path.join(self.output_dir, "training_curves.png")
        fig.savefig(path, dpi=150)
        plt.close(fig)
        print(f"Training curves saved to {path}")


def plot_tsne_features(
    source_features: np.ndarray,
    target_features: np.ndarray,
    save_path: str,
    title: str = "t-SNE Feature Visualization",
    max_samples: int = 2000,
):
    """
    Create t-SNE plot showing source vs target feature distributions.

    Args:
        source_features: (n, d) source latent features
        target_features: (m, d) target latent features
        save_path: path to save the plot
        title: plot title
        max_samples: max samples per domain (for speed)
    """
    from sklearn.manifold import TSNE

    # Subsample if needed
    if len(source_features) > max_samples:
        idx = np.random.choice(len(source_features), max_samples, replace=False)
        source_features = source_features[idx]
    if len(target_features) > max_samples:
        idx = np.random.choice(len(target_features), max_samples, replace=False)
        target_features = target_features[idx]

    # Combine and run t-SNE
    combined = np.concatenate([source_features, target_features], axis=0)
    n_source = len(source_features)

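    # t-SNE requires perplexity < number of points, so cap it for very small inputs
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(combined) - 1))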
    embedded = tsne.fit_transform(combined)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(embedded[:n_source, 0], embedded[:n_source, 1],
               s=10, alpha=0.5, c="steelblue", label="Source")
    ax.scatter(embedded[n_source:, 0], embedded[n_source:, 1],
               s=10, alpha=0.5, c="indianred", label="Target")
    ax.set_title(title, fontsize=14)
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"t-SNE plot saved to {save_path}")

Running the Full Pipeline

With all nine scripts in place, here is the complete workflow from data generation to final evaluation. Open a terminal in the da-anomaly-detection/ directory and run these commands in order.

Step-by-Step Commands

# Step 1: Install dependencies
pip install -r requirements.txt

# Step 2: Generate synthetic two-domain data
python generate_synthetic_data.py --output_dir data/ --n_samples 20000

# Step 3: Train with DANN (Domain-Adversarial Neural Network)
python train.py --method dann --epochs 100 --batch_size 64 --lr 0.001

# Step 4: Evaluate on target domain
python evaluate.py --checkpoint checkpoints/best_model.pt --data_dir data/ --method dann

# (Optional) Step 5: Train with MMD instead
python train.py --method mmd --epochs 100 --batch_size 64

# (Optional) Step 6: Train with CORAL instead
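python train.py --method coral --epochs 100 --batch_size 64

# Evaluate an MMD or CORAL run the same way as Step 4, passing the matching --method (mmd or coral)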

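Each training run prints progress every 5 epochs, saves the best model checkpoint (selected by target-domain AUROC), and writes training curves to the results/ directory. Note that selecting on target AUROC requires target labels; in a truly unsupervised deployment you would need a label-free selection criterion instead, such as source validation performance. The evaluation script generates ROC curves, PR curves, score distribution histograms, and a plot of reconstruction error over time.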

Understanding the Results

You have run the pipeline and have a results/evaluation_results.json file with numbers. But what do those numbers mean, and how do you know if domain adaptation is actually helping?

Interpreting the Evaluation Metrics

AUROC (Area Under the ROC Curve) is the primary metric. It measures the probability that a randomly chosen anomaly scores higher than a randomly chosen normal sample. An AUROC of 0.5 is random ranking; 1.0 is perfect. For domain adaptation to be considered successful, the target domain AUROC should be clearly higher than the “no adaptation” baseline: the same model trained only on source data and evaluated on target data without any domain adaptation loss.
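The probabilistic definition can be checked directly. The minimal NumPy sketch below compares every anomaly score with every normal score and counts ties as half a win; pairwise_auroc is a hypothetical helper, not part of the guide's scripts. For large test sets, sklearn's roc_auc_score computes the same quantity far more efficiently.

```python
import numpy as np


def pairwise_auroc(labels, scores):
    """AUROC as P(random anomaly scores higher than random normal); ties count as 1/2."""
    labels = np.asarray(labels).astype(bool)
    scores = np.asarray(scores, dtype=float)
    pos = scores[labels]    # anomaly scores
    neg = scores[~labels]   # normal scores
    # Compare every anomaly with every normal sample: O(n_pos * n_neg) memory.
    diff = pos[:, None] - neg[None, :]
    wins = np.sum(diff > 0) + 0.5 * np.sum(diff == 0)
    return float(wins / (len(pos) * len(neg)))


labels = [0, 0, 0, 1, 1]
scores = [0.10, 0.40, 0.35, 0.80, 0.30]
# Anomaly 0.80 beats all 3 normals; anomaly 0.30 beats only 0.10 -> 4 of 6 pairs
print(round(pairwise_auroc(labels, scores), 3))  # 0.667
```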

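AUPRC (Area Under the Precision-Recall Curve) is more informative when anomalies are rare. In highly imbalanced datasets (e.g., a 1% anomaly rate), AUROC can look good even when most alarms are false: the false positive rate divides by the large number of normal samples, so hundreds of false alarms barely move it. AUPRC is built on precision, which divides by the number of alarms raised, so it penalizes those false positives directly.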
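A constructed example makes the gap visible. In this hypothetical 1,000-window test set with a 1% anomaly rate, 50 normal windows score above every anomaly. Both helpers are minimal NumPy sketches that assume no tied scores; average_precision stands in for sklearn's average_precision_score.

```python
import numpy as np


def pairwise_auroc(labels, scores):
    """P(random anomaly outranks random normal); assumes no tied scores."""
    pos, neg = scores[labels == 1], scores[labels == 0]
    return float(np.mean(pos[:, None] > neg[None, :]))


def average_precision(labels, scores):
    """Mean precision at the rank of each anomaly (AUPRC estimate, assumes no ties)."""
    order = np.argsort(-scores, kind="stable")   # highest score first
    y = labels[order]
    precision_at_k = np.cumsum(y) / np.arange(1, len(y) + 1)
    return float(precision_at_k[y == 1].mean())


# Hypothetical 1,000-window test set with a 1% anomaly rate
suspicious_normals = np.linspace(0.90, 0.99, 50)   # 50 normal windows outranking every anomaly
anomalies = np.linspace(0.80, 0.89, 10)            # the 10 true anomalies
quiet_normals = np.linspace(0.00, 0.70, 940)       # the remaining normal windows
scores = np.concatenate([suspicious_normals, anomalies, quiet_normals])
labels = np.concatenate([np.zeros(50), np.ones(10), np.zeros(940)])

print(f"AUROC = {pairwise_auroc(labels, scores):.3f}")      # 0.949
print(f"AUPRC = {average_precision(labels, scores):.3f}")   # 0.097
```

Every anomaly still outranks 940 of the 990 normal windows, so AUROC barely notices the 50 false alarms, while precision never rises above 10/60 ≈ 0.17 and AUPRC exposes the problem.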

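F1 Score is the harmonic mean of precision and recall. It gives you a single number that balances false positives and false negatives. evaluate.py reports it at two thresholds: the F1-optimal one returned by find_optimal_threshold and a percentile cut-off. The F1-optimal threshold is tuned on the same labeled target test set it is scored on, so treat that F1 as an optimistic upper bound; the percentile threshold needs no labels and is closer to what you can do in deployment. For industrial applications, you typically care more about recall (do not miss anomalies) than precision (some false alarms are acceptable).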
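The search behind an F1-optimal threshold can be sketched by brute force: try each distinct score as the cut-off and keep the one with the highest F1. find_optimal_threshold is called in evaluate.py, but its body is not shown in this section and may work differently (for example via sklearn's precision_recall_curve); f1_optimal_threshold below is a hypothetical stand-in.

```python
import numpy as np


def f1_optimal_threshold(labels, scores):
    """Return (threshold, f1) maximizing F1 when 'anomaly' means score >= threshold."""
    labels = np.asarray(labels).astype(bool)
    scores = np.asarray(scores, dtype=float)
    best_thresh, best_f1 = None, -1.0
    for t in np.unique(scores):              # every distinct score is a candidate cut
        pred = scores >= t
        tp = np.sum(pred & labels)
        fp = np.sum(pred & ~labels)
        fn = np.sum(~pred & labels)
        f1 = 2 * tp / (2 * tp + fp + fn) if tp > 0 else 0.0   # F1 = 2TP / (2TP + FP + FN)
        if f1 > best_f1:
            best_thresh, best_f1 = float(t), float(f1)
    return best_thresh, best_f1


labels = [0, 0, 1, 0, 1, 1]
scores = [0.10, 0.20, 0.30, 0.40, 0.80, 0.90]
thresh, f1 = f1_optimal_threshold(labels, scores)
print(f"threshold = {thresh:.2f}, F1 = {f1:.3f}")  # threshold = 0.30, F1 = 0.857
```

The loop is O(n × distinct scores), which is fine for evaluation-sized arrays.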

What Good vs. Bad Domain Adaptation Looks Like

| Scenario | Source AUROC | Target AUROC (no adaptation) | Target AUROC (with DA) | Interpretation |
|---|---|---|---|---|
| Successful adaptation | 0.95 | 0.62 | 0.87 | Domain adaptation recovered most performance |
| Negative transfer | 0.95 | 0.65 | 0.58 | DA made things worse; domains may be too different |
| No domain shift | 0.93 | 0.91 | 0.92 | Little domain shift exists; DA not needed |
| Partial adaptation | 0.95 | 0.55 | 0.72 | DA helps but a gap remains; try tuning or more target data |

Understanding t-SNE Plots

The t-SNE visualization is your most intuitive diagnostic tool. Run it on the latent features before and after domain adaptation:

  • Before adaptation: You should see two distinct clusters—source samples clumped together in one region, target samples in another. This visual separation confirms that domain shift exists in the data.
  • After successful adaptation: The source and target clusters should overlap significantly. The encoder has learned features that look the same regardless of which domain produced the input. If the anomaly classifier works on source features, it should now work on the (overlapping) target features too.
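  • After failed adaptation: Clusters remain separate, or worse, all points collapse into one tight blob. That is feature collapse: the encoder maps every input to nearly the same representation, which trivially fools the discriminator but also erases the information the anomaly classifier needs.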

When to Use DANN vs. MMD vs. CORAL

| Method | Mechanism | Strengths | Weaknesses | Best For |
|---|---|---|---|---|
| DANN | Adversarial training via GRL | Powerful; learns complex alignment | Unstable training; sensitive to hyperparameters | Large domain shifts; enough training data |
| MMD | Kernel-based distribution matching | Stable training; mathematically principled | Cost grows quadratically with batch size; kernel bandwidth choice matters | Moderate domain shifts; limited compute |
| CORAL | Covariance matrix alignment | Simple; fast; no hyperparameters beyond the loss weight | Only matches second-order statistics | Small domain shifts; quick baseline |

Tip: Start with CORAL (simplest, fastest) to establish a baseline. If it does not close the gap enough, try MMD. If you need maximum performance and can handle some training instability, use DANN with careful lambda scheduling.
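To make the mechanism column of the comparison table concrete, here are NumPy sketches of the two non-adversarial losses. They are illustrative rather than the guide's training code: coral_loss follows the Deep CORAL formula (squared Frobenius distance between source and target feature covariances, scaled by 1/(4d²)), and gaussian_mmd2 is the biased MMD² estimate using a sum of Gaussian kernels. The bandwidths (1, 2, 4) are arbitrary assumptions that play the role of kernel_bandwidths in config.py.

```python
import numpy as np


def coral_loss(source, target):
    """Deep CORAL loss: squared Frobenius distance of feature covariances / (4 d^2)."""
    d = source.shape[1]
    cov_s = np.cov(source, rowvar=False)   # (d, d), rows are samples
    cov_t = np.cov(target, rowvar=False)
    return float(np.sum((cov_s - cov_t) ** 2) / (4 * d * d))


def gaussian_mmd2(source, target, bandwidths=(1.0, 2.0, 4.0)):
    """Biased MMD^2 estimate with a sum of Gaussian (RBF) kernels."""
    x = np.concatenate([source, target], axis=0)
    sq_norms = np.sum(x ** 2, axis=1)
    # Pairwise squared distances over the combined (n_s + n_t) batch: quadratic cost
    d2 = np.maximum(sq_norms[:, None] + sq_norms[None, :] - 2.0 * x @ x.T, 0.0)
    k = sum(np.exp(-d2 / (2.0 * bw ** 2)) for bw in bandwidths)
    n = len(source)
    k_ss, k_tt, k_st = k[:n, :n], k[n:, n:], k[:n, n:]
    return float(k_ss.mean() + k_tt.mean() - 2.0 * k_st.mean())


rng = np.random.default_rng(0)
src = rng.normal(size=(200, 4))                            # source-domain features
tgt_same = rng.normal(size=(200, 4))                       # same distribution
tgt_shift = rng.normal(loc=1.0, scale=2.0, size=(200, 4))  # shifted mean and scale
print("CORAL:", coral_loss(src, tgt_same), coral_loss(src, tgt_shift))
print("MMD^2:", gaussian_mmd2(src, tgt_same), gaussian_mmd2(src, tgt_shift))
```

Both losses are near zero when the two feature sets come from the same distribution and grow under the mean-and-scale shift. The (n_s + n_t) × (n_s + n_t) kernel matrix in gaussian_mmd2 is also where MMD's quadratic cost in batch size comes from.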

Adapting to Your Own Data

The synthetic data is a sandbox. Here is how to plug in your own time-series data with minimal code changes.

Modifying dataset.py for Your Data Format

Your CSV files need to follow this structure: each row is a timestep, each column (except label and timestamp) is a sensor channel. The column names do not matter as long as label and timestamp are correctly named (or absent). If your data uses a different format, modify the load_csv_data() function:

import numpy as np
import pandas as pd

# Example: your data has columns named 'temp_1', 'temp_2', 'vibration_x', etc.
# and uses 'anomaly' instead of 'label'
def load_csv_data(filepath, has_labels=True):
    df = pd.read_csv(filepath)
    # Non-sensor columns to drop; edit this list to match your CSV
    exclude = ["anomaly", "timestamp", "machine_id", "date"]
    feature_cols = [c for c in df.columns if c not in exclude]
    data = df[feature_cols].to_numpy(dtype=np.float32)  # (timesteps, channels)
    labels = df["anomaly"].to_numpy(dtype=np.float32) if has_labels else None
    return data, labels

Adjusting Model Dimensions

If your sensor data has a different number of channels, you only need to change num_features in config.py. The model automatically adjusts. For different sampling rates, adjust window_size—as a rule of thumb, your window should span roughly one “cycle” of the normal operating pattern. For a machine cycling every 5 seconds sampled at 100 Hz, use window_size=500. For slow processes (daily patterns at hourly sampling), use window_size=24.
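The cycle-length rule and the windowing step can be sketched in a few lines of NumPy. window_size_for_cycle and make_windows are hypothetical helpers, not functions from dataset.py, and the (n_windows, window_size, channels) layout and stride of 50 are assumptions; match them to whatever your loader expects. sliding_window_view requires NumPy 1.20 or later.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def window_size_for_cycle(cycle_seconds, sampling_hz):
    """Number of samples spanning one normal operating cycle."""
    return int(round(cycle_seconds * sampling_hz))


def make_windows(data, window_size, stride=1):
    """Slice a (timesteps, channels) array into (n_windows, window_size, channels)."""
    # Along the time axis the view has shape (timesteps - window_size + 1, channels, window_size)
    windows = sliding_window_view(data, window_size, axis=0)
    return windows[::stride].transpose(0, 2, 1)


print(window_size_for_cycle(5, 100))                 # 5 s cycle at 100 Hz -> 500
print(window_size_for_cycle(24 * 3600, 1 / 3600))    # daily cycle, hourly samples -> 24

sensor_data = np.random.default_rng(0).normal(size=(1000, 3)).astype(np.float32)
windows = make_windows(sensor_data, window_size=100, stride=50)
print(windows.shape)  # (19, 100, 3)
```

Because sliding_window_view returns a read-only view, no data is copied until you materialize the windows (for example with np.ascontiguousarray).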

Handling Class Imbalance

Real anomaly data is heavily imbalanced—often 1% anomalies or less. Three strategies that work well with this codebase:

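  1. Weighted BCE loss: Replace BCEWithLogitsLoss() with BCEWithLogitsLoss(pos_weight=torch.tensor([19.0])), where 19.0 is the ratio of normal to anomaly samples (a 5% anomaly rate; at 1% the ratio is 99.0). Compute it from your own training labels.
  2. Focal loss: Down-weights easy, well-classified examples, which in anomaly data are mostly normal windows, so training focuses on the hard cases. Replace the BCE in AnomalyDetectionLoss.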
  3. Oversampling: Use PyTorch’s WeightedRandomSampler to oversample anomaly windows in the source training loader.
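Strategies 1 and 2 can be sketched in NumPy so they run without a GPU. pos_weight_from_labels derives the pos_weight ratio from your labels instead of hard-coding 19.0, and binary_focal_loss implements the standard focal loss form, -(1 - p_t)^γ · log(p_t), on raw logits. Both names are hypothetical; a torch version inside AnomalyDetectionLoss would follow the same formula, with torch.nn.functional.logsigmoid providing the stable log-sigmoid.

```python
import numpy as np


def pos_weight_from_labels(labels):
    """Normal-to-anomaly ratio, the value passed as pos_weight to BCEWithLogitsLoss."""
    labels = np.asarray(labels)
    n_pos = labels.sum()
    return float((len(labels) - n_pos) / max(n_pos, 1))


def binary_focal_loss(logits, targets, gamma=2.0, alpha=None):
    """Mean binary focal loss from raw logits, computed in log space for stability."""
    logits = np.asarray(logits, dtype=float)
    targets = np.asarray(targets, dtype=float)
    log_p = -np.logaddexp(0.0, -logits)       # log(sigmoid(x))
    log_1mp = -np.logaddexp(0.0, logits)      # log(1 - sigmoid(x))
    log_pt = targets * log_p + (1.0 - targets) * log_1mp  # log-probability of the true class
    loss = -((1.0 - np.exp(log_pt)) ** gamma) * log_pt    # (1 - p_t)^gamma shrinks easy examples
    if alpha is not None:                                 # optional weight for the anomaly class
        loss = loss * (targets * alpha + (1.0 - targets) * (1.0 - alpha))
    return float(loss.mean())


labels = np.array([0] * 95 + [1] * 5)
print(pos_weight_from_labels(labels))  # 95 normal / 5 anomalous -> 19.0

# An easy negative (confidently normal, logit -4) barely contributes once gamma > 0
print(binary_focal_loss([-4.0], [0.0], gamma=0.0))  # plain BCE
print(binary_focal_loss([-4.0], [0.0], gamma=2.0))  # much smaller
```

With gamma=0 and no alpha, the focal loss reduces exactly to BCE, which is a convenient sanity check when swapping it in.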

Hyperparameter Tuning Guide

The hyperparameters listed below are ordered by sensitivity—tune the top ones first:

  1. lambda_domain (0.1–2.0): The most sensitive parameter. Too high, and the encoder prioritizes fooling the domain discriminator over keeping the information the anomaly classifier needs. Too low means no adaptation. Start at 0.5 and adjust.
  2. learning_rate (1e-4–1e-2): Standard neural network tuning. Use cosine annealing.
  3. window_size (32–256 is a typical starting range): Must capture enough context for anomalies to be visible. If one normal operating cycle is longer or shorter than that, follow the cycle-length rule under Adjusting Model Dimensions.
  4. latent_dim (64–256): Larger gives more capacity but risks overfitting.
  5. alpha (0.5–0.9): Anomaly scoring mix. Higher alpha trusts the classifier more; lower trusts reconstruction error more.
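The progressive lambda schedule recommended throughout this guide is commonly implemented with the ramp from the DANN paper (Ganin et al., 2016): λ_p = 2 / (1 + exp(-γp)) - 1, where p runs from 0 to 1 over training and γ = 10. The sketch below assumes that lambda_max plays the role of lambda_domain; train.py's actual schedule is not shown in this section and may differ.

```python
import math


def dann_lambda(epoch, total_epochs, lambda_max=1.0, gamma=10.0):
    """Progressive GRL weight: lambda_max * (2 / (1 + exp(-gamma * p)) - 1), p in [0, 1]."""
    p = epoch / max(total_epochs - 1, 1)   # training progress: 0 at the first epoch, 1 at the last
    return lambda_max * (2.0 / (1.0 + math.exp(-gamma * p)) - 1.0)


for epoch in (0, 5, 10, 25, 50, 99):
    print(epoch, round(dann_lambda(epoch, total_epochs=100, lambda_max=0.5), 3))
```

The weight starts at exactly zero and approaches lambda_max smoothly, so the domain loss enters gradually rather than at full strength from the first batch.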

Common Issues and Solutions

Domain adaptation training is notoriously finicky. Here is a reference table of problems you will likely encounter and how to fix them.

| Problem | Symptom | Cause | Solution |
|---|---|---|---|
| Discriminator mode collapse | Domain loss stays flat at ~0.69 (ln 2) from the first epochs | Discriminator outputs 0.5 for everything | Increase discriminator LR; add more layers; reduce GRL lambda |
| Training instability | Loss oscillates wildly or diverges | Lambda too high too early | Use progressive lambda schedule; reduce learning rate; clip gradients more aggressively (lower max norm) |
| Negative transfer | Target AUROC decreases with DA | Domains are too different or share no useful structure | Reduce lambda_domain; try CORAL (less aggressive); verify domains share anomaly types |
| High false positive rate | Good recall but terrible precision | Threshold too low; recon error noisy | Increase alpha (trust classifier more); use percentile threshold; add recon error smoothing |
| Source AUROC drops during DA | Classification degrades on source | Domain-invariant features lose discriminative power | Increase lambda_cls; reduce lambda_domain; train classifier longer before starting DA |
| Out of memory (GPU) | CUDA OOM error | Batch size or model too large | Reduce batch_size; reduce latent_dim; use gradient accumulation |
| MMD loss is NaN | NaN in training | Kernel bandwidth mismatch with feature scale | Normalize features; adjust kernel_bandwidths in config; add epsilon to kernel computation |

Note that a domain loss that first drops and later settles near ln 2 is the expected DANN equilibrium (the encoder has learned to fool the discriminator); it signals a problem only when the loss never drops below ln 2 at all.
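For the high-false-positive row, reconstruction-error smoothing can be as simple as a trailing moving average over consecutive window scores: isolated spikes get diluted while sustained deviations survive. smooth_scores is a hypothetical helper, not part of utils.py, and window=3 is an arbitrary choice.

```python
import numpy as np


def smooth_scores(scores, window=3):
    """Trailing (causal) moving average over consecutive anomaly scores."""
    scores = np.asarray(scores, dtype=float)
    csum = np.concatenate([[0.0], np.cumsum(scores)])   # csum[i] = sum of scores[:i]
    idx = np.arange(len(scores))
    start = np.maximum(0, idx - window + 1)               # shorter window at the beginning
    return (csum[idx + 1] - csum[start]) / (idx + 1 - start)


spike = [0.0, 0.0, 10.0, 0.0, 0.0, 0.0]      # one isolated high score
sustained = [0.0, 6.0, 6.0, 6.0, 6.0, 0.0]   # several consecutive high scores
print(smooth_scores(spike).max())      # spike diluted to about 3.33
print(smooth_scores(sustained).max())  # sustained deviation stays at 6.0
```

The average is causal (it uses only current and past windows), so the same code works online. Re-derive the threshold on the smoothed scores, since smoothing changes their distribution.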


Caution: Domain adaptation assumes the source and target domains share the same anomaly types, just with different feature distributions. If the target domain has fundamentally different anomaly mechanisms (not just different sensor characteristics), domain adaptation will not help, and you need at least some labeled target data (semi-supervised adaptation).

Conclusion

You now have a complete, end-to-end implementation of domain-adaptive time-series anomaly detection. Let us recap what we built and where to go next.

The nine scripts in this guide cover the full pipeline: generating realistic synthetic data with domain shift, building a CNN-LSTM encoder with multi-head outputs, implementing three different domain adaptation strategies (DANN, MMD, CORAL), training with progressive lambda scheduling, and evaluating with comprehensive metrics and diagnostic plots. Every script is complete and runnable as-is.

The core insight is simple but powerful: instead of requiring expensive labeled data in every new domain, you can train a model to learn domain-invariant features—representations that capture the essence of “anomaly” regardless of which machine, factory, or sensor produced the signal. The Gradient Reversal Layer is the elegant mechanism that makes this adversarial training possible in a single unified model, while MMD and CORAL offer simpler, more stable alternatives.

Where should you go from here? Three directions are most promising. First, semi-supervised adaptation: if you can label even 5–10% of the target domain data, you can add a supervised loss on those labeled target samples alongside the unsupervised domain alignment, dramatically improving results. Second, multi-source adaptation: if you have data from machines A, B, and C, you can adapt to machine D by combining knowledge from all three sources, not just one. Third, continual adaptation: in production, the target domain drifts over time as machines age and wear. Implement online or periodic re-adaptation to keep the model current.

Domain adaptation is not a silver bullet. It works best when domains share the same underlying anomaly mechanisms but differ in superficial signal characteristics—exactly the scenario in most industrial settings. When it works, it can save months of labeling effort and accelerate deployment of anomaly detection to new equipment. The code in this guide gives you everything you need to start experimenting with your own data today.

References

  1. Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F., Marchand, M., and Lempitsky, V. (2016). “Domain-Adversarial Training of Neural Networks.” Journal of Machine Learning Research, 17(59), 1-35.
  2. Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., and Smola, A. J. (2012). “A Kernel Two-Sample Test.” Journal of Machine Learning Research, 13, 723-773.
  3. Sun, B. and Saenko, K. (2016). “Deep CORAL: Correlation Alignment for Deep Domain Adaptation.” Proceedings of the European Conference on Computer Vision (ECCV) Workshops.
  4. Ragab, M., Lu, Z., Chen, Z., Wu, M., Kwoh, C. K., and Li, X. (2023). “Time-Series Domain Adaptation: A Survey.” arXiv preprint.
  5. Chalapathy, R. and Chawla, S. (2019). “Deep Learning for Anomaly Detection: A Survey.” arXiv preprint.
  6. PyTorch Documentation. “Extending torch.autograd — Custom Function.”
