Introduction: The Domain Shift Problem in Anomaly Detection
Imagine you spent six months collecting labeled anomaly data from a CNC milling machine on your factory floor. You painstakingly tagged every spindle vibration spike, every thermal drift event, every bearing degradation signature. Your anomaly detection model hits 0.95 AUROC on that machine. Then your company buys a second milling machine—same manufacturer, same model number, but a different production year. You deploy your model, and the AUROC drops to 0.62. Barely better than a coin flip.
This is the domain shift problem, and it is one of the most expensive headaches in industrial machine learning. The statistical distribution of sensor readings changes between machines, factories, sensor brands, and even seasons. Noise floors differ. Baseline amplitudes drift. The relationship between “normal” and “anomalous” subtly warps. Your perfectly trained model becomes useless the moment it leaves its original domain.
The classical solution is to label data in every new domain. But labeling anomaly data is brutally expensive—anomalies are rare by definition, and expert annotators are scarce. What if you could transfer the anomaly detection knowledge from your labeled source domain (machine A) to an unlabeled target domain (machine B) without restarting from scratch?
That is exactly what domain adaptation does. By training a model to learn features that are invariant across domains—features that capture the essence of “anomaly” regardless of which machine produced the signal—you can detect anomalies in new domains with little or no labeled target data. The technique has roots in computer vision (the famous DANN paper by Ganin et al., 2016), but its application to time-series anomaly detection remains underexplored in practice, despite being exactly where it is needed most.
This post is not a theoretical survey. It is a complete, runnable implementation guide. By the end, you will have eight production-ready Python scripts that implement three domain adaptation strategies—DANN (Domain-Adversarial Neural Networks), MMD (Maximum Mean Discrepancy), and CORAL (CORrelation ALignment)—on top of a CNN-LSTM hybrid encoder for multi-channel time-series anomaly detection. Every script is complete. No ellipses, no “fill in the rest,” no pseudocode. Copy, paste, run.
Let us build it.
Project Structure and Setup
Before writing any code, let us establish a clean project layout. Every file has a single responsibility, making the codebase easy to understand and modify for your own use case.
da-anomaly-detection/
├── config.py # Hyperparameters and configuration
├── dataset.py # Dataset classes and data loading
├── model.py # Model architecture (encoder, classifier, discriminator)
├── losses.py # Loss function definitions (DANN, MMD, CORAL)
├── train.py # Main training script with domain adaptation
├── evaluate.py # Evaluation and metrics
├── utils.py # Utility functions (seeding, checkpoints, plotting)
├── generate_synthetic_data.py # Generate example data for testing
├── requirements.txt # Dependencies
├── data/ # Generated or real data goes here
├── checkpoints/ # Saved model weights
└── results/ # Evaluation outputs, plots, metrics
Start by creating the directory and installing dependencies:
mkdir -p da-anomaly-detection/{data,checkpoints,results}
cd da-anomaly-detection
requirements.txt
torch>=2.0.0
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
matplotlib>=3.7.0
tqdm>=4.65.0
pip install -r requirements.txt
# Optional: install the CUDA 12.1 build of PyTorch if you have a compatible GPU
pip install torch --index-url https://download.pytorch.org/whl/cu121
Configuration and Hyperparameters
Centralizing configuration prevents magic numbers from scattering across your codebase. We use a Python dataclass so the IDE gives you autocompletion and type checking for free.
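One dataclass subtlety is worth calling out before the listing: mutable defaults like lists must be declared through `field(default_factory=...)`, which is why the config below writes its channel lists that way. A standalone sketch (`DemoConfig` is an illustrative name, not the real `Config`):

```python
from dataclasses import dataclass, field

@dataclass
class DemoConfig:
    # A bare `channels: list = [32, 64, 128]` would raise ValueError:
    # mutable defaults must be created per-instance via default_factory.
    channels: list = field(default_factory=lambda: [32, 64, 128])
    lr: float = 1e-3

a, b = DemoConfig(), DemoConfig()
a.channels.append(256)
print(b.channels)  # [32, 64, 128] — instances do not share the list
```

This also gives you cheap per-experiment overrides, e.g. `DemoConfig(lr=3e-4)`.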
config.py
"""
config.py — Centralized configuration for domain-adaptive anomaly detection.
All hyperparameters live here. Override via CLI arguments in train.py.
"""
from dataclasses import dataclass, field
import torch
import os
@dataclass
class Config:
"""All hyperparameters and paths for the DA anomaly detection pipeline."""
# --- Data Parameters ---
num_features: int = 6 # Number of sensor channels
window_size: int = 64 # Sliding window length (timesteps)
stride: int = 16 # Stride for sliding window
train_ratio: float = 0.8 # Train/val split ratio
# --- Model Architecture ---
cnn_channels: list = field(default_factory=lambda: [32, 64, 128])
cnn_kernel_sizes: list = field(default_factory=lambda: [7, 5, 3])
lstm_hidden_dim: int = 128
lstm_num_layers: int = 2
latent_dim: int = 128 # Dimension of the shared feature space
classifier_hidden_dim: int = 64
discriminator_hidden_dim: int = 64
dropout: float = 0.3
# --- Training Parameters ---
batch_size: int = 64
learning_rate: float = 1e-3
discriminator_lr: float = 1e-3
weight_decay: float = 1e-4
epochs: int = 100
patience: int = 15 # Early stopping patience
# --- Domain Adaptation Parameters ---
adaptation_method: str = "dann" # 'dann', 'mmd', or 'coral'
lambda_domain: float = 1.0 # Max domain loss weight
lambda_recon: float = 0.5 # Reconstruction loss weight
lambda_cls: float = 1.0 # Classification loss weight
gamma: float = 10.0 # DANN lambda schedule steepness
mmd_kernel_bandwidth: list = field(
default_factory=lambda: [0.01, 0.1, 1.0, 10.0, 100.0]
)
# --- Anomaly Scoring ---
alpha: float = 0.7 # Weight for classifier score vs recon error
anomaly_threshold_percentile: float = 95.0
# --- Paths ---
data_dir: str = "data"
checkpoint_dir: str = "checkpoints"
results_dir: str = "results"
# --- Device and Reproducibility ---
seed: int = 42
device: str = ""
def __post_init__(self):
if not self.device:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
os.makedirs(self.data_dir, exist_ok=True)
os.makedirs(self.checkpoint_dir, exist_ok=True)
os.makedirs(self.results_dir, exist_ok=True)
The most sensitive hyperparameter here is lambda_domain. Too high, and the model forgets how to classify anomalies. Too low, and domain adaptation has no effect. The progressive scheduling in our training script (the DANN lambda schedule) addresses this by starting low and ramping up.
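The ramp is the standard schedule from Ganin et al.: lambda(p) = lambda_max * (2 / (1 + exp(-gamma * p)) - 1), where p is training progress in [0, 1] and gamma is the steepness from the config. A minimal sketch (`dann_lambda` is an illustrative helper name; train.py may compute this inline):

```python
import math

def dann_lambda(progress, gamma=10.0, lambda_max=1.0):
    """GRL weight schedule from the DANN paper.
    progress p in [0, 1] is (current step) / (total steps)."""
    return lambda_max * (2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0)

print(dann_lambda(0.0))  # 0.0 — adaptation off at the start
print(dann_lambda(1.0))  # ~0.9999 — saturates at lambda_max
```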
Generating Realistic Synthetic Data
Before touching real proprietary data, you need a sandbox. The script below generates two-domain synthetic time-series data with realistic characteristics: seasonal patterns, trends, multiple anomaly types, and domain-specific differences in noise, amplitude, and baseline offset. The source domain gets full labels; the target domain training set has no labels (simulating the real scenario), while the target test set has labels for evaluation.
generate_synthetic_data.py
"""
generate_synthetic_data.py — Generate realistic two-domain time-series data
with injected anomalies for testing domain adaptation.
Simulates 6-channel sensor data (e.g., 3 joints x [torque, position]) from
two different machines with different noise/amplitude characteristics.
"""
import argparse
import os
import numpy as np
import pandas as pd
def generate_base_signal(n_samples: int, num_features: int, seed: int = 42) -> np.ndarray:
"""Generate a base multi-channel time-series with realistic patterns."""
rng = np.random.RandomState(seed)
t = np.arange(n_samples)
signals = np.zeros((n_samples, num_features))
for ch in range(num_features):
freq1 = 0.002 + ch * 0.001
freq2 = 0.01 + ch * 0.003
phase1 = rng.uniform(0, 2 * np.pi)
phase2 = rng.uniform(0, 2 * np.pi)
# Seasonal component
seasonal = 2.0 * np.sin(2 * np.pi * freq1 * t + phase1)
# Higher-frequency oscillation
oscillation = 0.8 * np.sin(2 * np.pi * freq2 * t + phase2)
# Slow trend
trend = 0.0005 * t * ((-1) ** ch)
# Combine
signals[:, ch] = seasonal + oscillation + trend
return signals
def inject_anomalies(
signals: np.ndarray,
anomaly_ratio: float = 0.05,
seed: int = 42
) -> tuple:
"""
Inject multiple anomaly types into signals.
Returns (modified_signals, labels) where labels[i]=1 means anomaly.
"""
rng = np.random.RandomState(seed)
n_samples, num_features = signals.shape
labels = np.zeros(n_samples, dtype=int)
modified = signals.copy()
    segment_length = 20
    # anomaly_ratio is the target fraction of *timesteps* labeled anomalous.
    # Each event marks up to segment_length timesteps, so divide by
    # segment_length to get the number of anomaly events.
    n_anomalies = max(1, int(n_samples * anomaly_ratio / segment_length))
    anomaly_types = ["spike", "drift", "level_shift", "frequency_change"]
    # Choose distinct random start positions (segments can still overlap
    # occasionally, which merely merges neighboring labeled regions)
    max_start = n_samples - segment_length
    starts = rng.choice(max_start, size=n_anomalies, replace=False)
for i, start in enumerate(starts):
end = start + segment_length
a_type = anomaly_types[i % len(anomaly_types)]
channel = rng.randint(0, num_features)
if a_type == "spike":
spike_pos = start + rng.randint(0, segment_length)
magnitude = rng.uniform(5, 10) * (1 if rng.random() > 0.5 else -1)
modified[spike_pos, channel] += magnitude
labels[spike_pos] = 1
elif a_type == "drift":
drift = np.linspace(0, rng.uniform(3, 6), segment_length)
modified[start:end, channel] += drift
labels[start:end] = 1
elif a_type == "level_shift":
shift = rng.uniform(3, 7) * (1 if rng.random() > 0.5 else -1)
modified[start:end, channel] += shift
labels[start:end] = 1
elif a_type == "frequency_change":
t_seg = np.arange(segment_length)
high_freq = 2.0 * np.sin(2 * np.pi * 0.15 * t_seg)
modified[start:end, channel] += high_freq
labels[start:end] = 1
return modified, labels
def apply_domain_transform(
signals: np.ndarray,
noise_scale: float = 0.3,
amplitude_scale: float = 1.0,
baseline_offset: float = 0.0,
seed: int = 42
) -> np.ndarray:
"""Apply domain-specific transformations to simulate a different machine."""
rng = np.random.RandomState(seed)
transformed = signals.copy()
n_samples, num_features = transformed.shape
# Per-channel amplitude scaling
for ch in range(num_features):
ch_amp = amplitude_scale * rng.uniform(0.8, 1.2)
ch_offset = baseline_offset + rng.uniform(-0.5, 0.5)
transformed[:, ch] = transformed[:, ch] * ch_amp + ch_offset
# Add domain-specific noise
noise = rng.normal(0, noise_scale, transformed.shape)
transformed += noise
return transformed
def generate_dataset(
n_samples: int,
num_features: int,
anomaly_ratio: float,
noise_scale: float,
amplitude_scale: float,
baseline_offset: float,
seed: int
) -> pd.DataFrame:
"""Generate a complete dataset with signals, anomalies, and domain transform."""
base = generate_base_signal(n_samples, num_features, seed=seed)
with_anomalies, labels = inject_anomalies(base, anomaly_ratio, seed=seed + 1)
transformed = apply_domain_transform(
with_anomalies,
noise_scale=noise_scale,
amplitude_scale=amplitude_scale,
baseline_offset=baseline_offset,
seed=seed + 2
)
columns = [f"sensor_{i}" for i in range(num_features)]
df = pd.DataFrame(transformed, columns=columns)
df["label"] = labels
df["timestamp"] = pd.date_range("2024-01-01", periods=n_samples, freq="s")
return df
def main():
parser = argparse.ArgumentParser(
description="Generate synthetic two-domain time-series data."
)
parser.add_argument("--output_dir", type=str, default="data",
help="Output directory for CSV files")
parser.add_argument("--n_samples", type=int, default=20000,
help="Number of samples per dataset")
parser.add_argument("--num_features", type=int, default=6,
help="Number of sensor channels")
parser.add_argument("--anomaly_ratio", type=float, default=0.05,
help="Fraction of timesteps with anomalies")
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
print("Generating source domain data (Machine A)...")
source_full = generate_dataset(
n_samples=args.n_samples,
num_features=args.num_features,
anomaly_ratio=args.anomaly_ratio,
noise_scale=0.2,
amplitude_scale=1.0,
baseline_offset=0.0,
seed=args.seed
)
split_idx = int(len(source_full) * 0.7)
source_train = source_full.iloc[:split_idx].reset_index(drop=True)
source_test = source_full.iloc[split_idx:].reset_index(drop=True)
print("Generating target domain data (Machine B)...")
target_full = generate_dataset(
n_samples=args.n_samples,
num_features=args.num_features,
anomaly_ratio=args.anomaly_ratio,
noise_scale=0.5, # Higher noise
amplitude_scale=1.4, # Different amplitude
baseline_offset=2.0, # Shifted baseline
seed=args.seed + 100
)
split_idx_t = int(len(target_full) * 0.7)
target_train = target_full.iloc[:split_idx_t].reset_index(drop=True)
target_test = target_full.iloc[split_idx_t:].reset_index(drop=True)
# Remove labels from target train (unsupervised in target domain)
target_train_unlabeled = target_train.drop(columns=["label"])
# Save all files
source_train.to_csv(os.path.join(args.output_dir, "source_train.csv"), index=False)
source_test.to_csv(os.path.join(args.output_dir, "source_test.csv"), index=False)
target_train_unlabeled.to_csv(os.path.join(args.output_dir, "target_train.csv"), index=False)
target_test.to_csv(os.path.join(args.output_dir, "target_test.csv"), index=False)
print(f"\nDatasets saved to {args.output_dir}/")
print(f" source_train.csv: {len(source_train)} samples, "
f"{source_train['label'].sum()} anomalies ({source_train['label'].mean()*100:.1f}%)")
print(f" source_test.csv: {len(source_test)} samples, "
f"{source_test['label'].sum()} anomalies ({source_test['label'].mean()*100:.1f}%)")
print(f" target_train.csv: {len(target_train_unlabeled)} samples (no labels)")
print(f" target_test.csv: {len(target_test)} samples, "
f"{target_test['label'].sum()} anomalies ({target_test['label'].mean()*100:.1f}%)")
if __name__ == "__main__":
main()
Run it immediately:
python generate_synthetic_data.py --output_dir data/ --n_samples 20000
You will get four CSV files. The source data has labels everywhere. The target training data has no labels—this is the whole point of domain adaptation. The target test data has labels so we can measure how well the adaptation worked.
Dataset Classes and Data Loading
Time-series anomaly detection operates on windows: fixed-length slices of the signal. Our dataset class handles windowing, normalization (fit on source, apply everywhere), and optional data augmentation. create_data_loaders builds separate source and target loaders; the training script iterates them in lockstep to get paired source-target batches for simultaneous training.
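The windowing arithmetic is worth seeing in isolation. These helpers are an illustrative sketch mirroring the logic of the TimeSeriesDataset class below: the number of windows is floor((n - window_size) / stride) + 1, and a window inherits label 1 if any timestep inside it is anomalous:

```python
import numpy as np

def num_windows(n_samples, window_size=64, stride=16):
    """How many sliding windows a series of length n_samples yields."""
    return max(0, (n_samples - window_size) // stride + 1)

def window_label(labels, start, window_size=64):
    """A window is anomalous if any timestep inside it is."""
    return float(labels[start:start + window_size].max())

print(num_windows(14000))  # 872 — e.g. the 70% train split of a 20k series

labels = np.zeros(200)
labels[100] = 1
print(window_label(labels, 64))  # 1.0 — window [64, 128) covers timestep 100
print(window_label(labels, 0))   # 0.0
```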
dataset.py
"""
dataset.py — PyTorch Dataset classes for time-series domain adaptation.
Handles sliding-window creation, normalization, augmentation, and
paired source-target batch generation.
"""
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
class TimeSeriesDataset(Dataset):
"""
Sliding-window dataset for multi-channel time-series.
Args:
data: numpy array of shape (n_samples, num_features)
labels: numpy array of shape (n_samples,) or None for unlabeled data
window_size: number of timesteps per window
stride: step between consecutive windows
transform: optional callable for data augmentation
"""
def __init__(
self,
data: np.ndarray,
labels: np.ndarray = None,
window_size: int = 64,
stride: int = 16,
transform=None
):
self.data = data.astype(np.float32)
self.labels = labels
self.window_size = window_size
self.stride = stride
self.transform = transform
# Precompute valid window start indices
self.indices = list(range(0, len(data) - window_size + 1, stride))
def __len__(self):
return len(self.indices)
def __getitem__(self, idx):
start = self.indices[idx]
end = start + self.window_size
window = self.data[start:end] # (window_size, num_features)
if self.transform is not None:
window = self.transform(window)
# Transpose to (num_features, window_size) for Conv1d
window_tensor = torch.tensor(window, dtype=torch.float32).T
if self.labels is not None:
# Window label = 1 if any timestep in window is anomalous
window_label = float(self.labels[start:end].max())
return window_tensor, torch.tensor(window_label, dtype=torch.float32)
else:
return window_tensor, torch.tensor(-1.0, dtype=torch.float32)
class Normalizer:
"""
Fit on source training data, transform all data.
Uses per-channel mean and std normalization.
"""
def __init__(self):
self.mean = None
self.std = None
def fit(self, data: np.ndarray):
"""Compute mean and std from training data."""
self.mean = data.mean(axis=0)
self.std = data.std(axis=0)
# Prevent division by zero
self.std[self.std < 1e-8] = 1.0
return self
def transform(self, data: np.ndarray) -> np.ndarray:
"""Apply normalization."""
return (data - self.mean) / self.std
def fit_transform(self, data: np.ndarray) -> np.ndarray:
"""Fit and transform in one step."""
self.fit(data)
return self.transform(data)
class JitterTransform:
"""Add random Gaussian noise for data augmentation."""
def __init__(self, sigma: float = 0.03):
self.sigma = sigma
def __call__(self, window: np.ndarray) -> np.ndarray:
noise = np.random.normal(0, self.sigma, window.shape).astype(np.float32)
return window + noise
class ScalingTransform:
"""Random per-channel amplitude scaling for data augmentation."""
def __init__(self, sigma: float = 0.1):
self.sigma = sigma
def __call__(self, window: np.ndarray) -> np.ndarray:
factor = np.random.normal(1.0, self.sigma, (1, window.shape[1])).astype(np.float32)
return window * factor
class ComposeTransforms:
"""Chain multiple transforms together."""
def __init__(self, transforms: list):
self.transforms = transforms
def __call__(self, window: np.ndarray) -> np.ndarray:
for t in self.transforms:
window = t(window)
return window
def load_csv_data(filepath: str, has_labels: bool = True):
"""
Load a CSV file and separate features from labels.
Returns:
data: numpy array (n_samples, num_features)
labels: numpy array (n_samples,) or None
"""
df = pd.read_csv(filepath)
# Drop non-numeric columns like timestamp
feature_cols = [c for c in df.columns if c not in ("label", "timestamp")]
data = df[feature_cols].values.astype(np.float32)
labels = df["label"].values.astype(np.float32) if (has_labels and "label" in df.columns) else None
return data, labels
def create_data_loaders(config):
    """
    Create all data loaders for domain adaptation training.
    Returns a tuple (loaders, normalizer): loaders is a dict with keys
    'source_train', 'source_test', 'target_train', 'target_test', and
    normalizer is the Normalizer fitted on the source training data.
    """
import os
# Load raw data
source_train_data, source_train_labels = load_csv_data(
os.path.join(config.data_dir, "source_train.csv"), has_labels=True
)
source_test_data, source_test_labels = load_csv_data(
os.path.join(config.data_dir, "source_test.csv"), has_labels=True
)
target_train_data, _ = load_csv_data(
os.path.join(config.data_dir, "target_train.csv"), has_labels=False
)
target_test_data, target_test_labels = load_csv_data(
os.path.join(config.data_dir, "target_test.csv"), has_labels=True
)
# Normalize: fit on source train only
normalizer = Normalizer()
source_train_data = normalizer.fit_transform(source_train_data)
source_test_data = normalizer.transform(source_test_data)
target_train_data = normalizer.transform(target_train_data)
target_test_data = normalizer.transform(target_test_data)
# Optional augmentation for training
train_transform = ComposeTransforms([
JitterTransform(sigma=0.03),
ScalingTransform(sigma=0.1),
])
# Create datasets
source_train_ds = TimeSeriesDataset(
source_train_data, source_train_labels,
window_size=config.window_size, stride=config.stride,
transform=train_transform
)
source_test_ds = TimeSeriesDataset(
source_test_data, source_test_labels,
window_size=config.window_size, stride=config.stride
)
target_train_ds = TimeSeriesDataset(
target_train_data, labels=None,
window_size=config.window_size, stride=config.stride,
transform=train_transform
)
target_test_ds = TimeSeriesDataset(
target_test_data, target_test_labels,
window_size=config.window_size, stride=config.stride
)
# Create loaders
loaders = {
"source_train": DataLoader(
source_train_ds, batch_size=config.batch_size,
shuffle=True, drop_last=True, num_workers=0
),
"source_test": DataLoader(
source_test_ds, batch_size=config.batch_size,
shuffle=False, num_workers=0
),
"target_train": DataLoader(
target_train_ds, batch_size=config.batch_size,
shuffle=True, drop_last=True, num_workers=0
),
"target_test": DataLoader(
target_test_ds, batch_size=config.batch_size,
shuffle=False, num_workers=0
),
}
return loaders, normalizer
The Core Model Architecture
This is the heart of the system. Our architecture has four components working together: a shared encoder that processes time-series windows into a fixed-size feature vector, an anomaly classifier that predicts normal vs. anomaly, a reconstruction decoder that reconstructs the original input (providing an auxiliary anomaly signal), and a domain discriminator that tries to identify which domain produced a given feature vector.

The magic ingredient is the Gradient Reversal Layer (GRL): during backpropagation, it flips the sign of gradients flowing from the domain discriminator to the encoder. This forces the encoder to learn features that are maximally uninformative about domain identity—precisely the domain-invariant representations we want.
Architecture:
┌─── Anomaly Classifier (binary: normal/anomaly)
Input → Shared Encoder ─┤
(time-series) ├─── Reconstruction Decoder (autoencoder branch)
└─── Domain Discriminator (with gradient reversal)
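Before the full listing, here is a minimal standalone sketch of what gradient reversal actually does (`GRLDemo` is an illustrative name; the production version lives in model.py below): the forward pass is the identity, and the backward pass negates and scales the gradient.

```python
import torch
from torch.autograd import Function

class GRLDemo(Function):
    """Identity on forward; gradient negated and scaled by lambda on backward."""
    @staticmethod
    def forward(ctx, x, lambda_val):
        ctx.lambda_val = lambda_val
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign: the encoder is pushed AWAY from helping the discriminator
        return -ctx.lambda_val * grad_output, None

x = torch.ones(3, requires_grad=True)
y = GRLDemo.apply(x, 0.5).sum()  # forward is the identity: y == 3.0
y.backward()
print(x.grad)  # tensor([-0.5000, -0.5000, -0.5000]) — reversed and scaled
```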
model.py
"""
model.py — Domain-adaptive anomaly detection model architecture.
Components:
- GradientReversalLayer: reverses gradients for adversarial domain adaptation
- SharedEncoder: CNN + BiLSTM feature extractor
- AnomalyClassifier: binary classification head
- ReconstructionDecoder: autoencoder branch for reconstruction-based scoring
- DomainDiscriminator: adversarial domain classification head
- DomainAdaptiveAnomalyDetector: full model combining all components
"""
import torch
import torch.nn as nn
from torch.autograd import Function
class GradientReversalFunction(Function):
"""
Gradient Reversal Layer (GRL) — Ganin et al., 2016.
Forward pass: identity.
Backward pass: negate gradients and scale by lambda.
"""
@staticmethod
def forward(ctx, x, lambda_val):
ctx.lambda_val = lambda_val
return x.clone()
@staticmethod
def backward(ctx, grad_output):
return -ctx.lambda_val * grad_output, None
class GradientReversalLayer(nn.Module):
"""Module wrapper for the gradient reversal function."""
def __init__(self, lambda_val: float = 1.0):
super().__init__()
self.lambda_val = lambda_val
def set_lambda(self, lambda_val: float):
self.lambda_val = lambda_val
def forward(self, x):
return GradientReversalFunction.apply(x, self.lambda_val)
class SharedEncoder(nn.Module):
"""
1D-CNN + Bidirectional LSTM encoder for multi-channel time-series.
Input shape: (batch, num_features, window_size)
Output shape: (batch, latent_dim)
"""
def __init__(
self,
num_features: int = 6,
cnn_channels: list = None,
cnn_kernel_sizes: list = None,
lstm_hidden_dim: int = 128,
lstm_num_layers: int = 2,
latent_dim: int = 128,
dropout: float = 0.3,
):
super().__init__()
if cnn_channels is None:
cnn_channels = [32, 64, 128]
if cnn_kernel_sizes is None:
cnn_kernel_sizes = [7, 5, 3]
# Build CNN layers
cnn_layers = []
in_channels = num_features
for out_ch, ks in zip(cnn_channels, cnn_kernel_sizes):
cnn_layers.extend([
nn.Conv1d(in_channels, out_ch, kernel_size=ks, padding=ks // 2),
nn.BatchNorm1d(out_ch),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
])
in_channels = out_ch
self.cnn = nn.Sequential(*cnn_layers)
# Bidirectional LSTM on top of CNN features
self.lstm = nn.LSTM(
input_size=cnn_channels[-1],
hidden_size=lstm_hidden_dim,
num_layers=lstm_num_layers,
batch_first=True,
bidirectional=True,
dropout=dropout if lstm_num_layers > 1 else 0.0,
)
# Project to latent space
self.fc = nn.Sequential(
nn.Linear(lstm_hidden_dim * 2, latent_dim),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
)
self.latent_dim = latent_dim
def forward(self, x):
"""
Args:
x: (batch, num_features, window_size)
Returns:
latent: (batch, latent_dim)
"""
# CNN: (batch, cnn_channels[-1], window_size)
cnn_out = self.cnn(x)
# Transpose for LSTM: (batch, window_size, cnn_channels[-1])
lstm_in = cnn_out.permute(0, 2, 1)
# LSTM: (batch, window_size, lstm_hidden*2)
lstm_out, _ = self.lstm(lstm_in)
        # Take the last timestep. Caveat: for the backward direction this
        # position has only seen one step; mean-pooling over timesteps is a
        # reasonable alternative if this proves limiting.
        last_hidden = lstm_out[:, -1, :]
# Project to latent space
latent = self.fc(last_hidden)
return latent
class AnomalyClassifier(nn.Module):
"""
Binary classification head: normal (0) vs anomaly (1).
Input: (batch, latent_dim)
Output: (batch, 1) — sigmoid logit
"""
def __init__(self, latent_dim: int = 128, hidden_dim: int = 64, dropout: float = 0.3):
super().__init__()
self.net = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, 1),
)
def forward(self, latent):
return self.net(latent)
class ReconstructionDecoder(nn.Module):
"""
Decoder that reconstructs the original input from latent features.
Uses LSTM + transposed Conv1d layers.
Input: (batch, latent_dim)
Output: (batch, num_features, window_size)
"""
def __init__(
self,
latent_dim: int = 128,
num_features: int = 6,
window_size: int = 64,
lstm_hidden_dim: int = 128,
dropout: float = 0.3,
):
super().__init__()
self.window_size = window_size
self.num_features = num_features
self.lstm_hidden_dim = lstm_hidden_dim
# Expand latent to sequence
self.fc = nn.Sequential(
nn.Linear(latent_dim, lstm_hidden_dim),
nn.ReLU(inplace=True),
)
# LSTM decoder
self.lstm = nn.LSTM(
input_size=lstm_hidden_dim,
hidden_size=lstm_hidden_dim,
num_layers=1,
batch_first=True,
)
# Transposed convolutions to reconstruct
self.deconv = nn.Sequential(
nn.ConvTranspose1d(lstm_hidden_dim, 64, kernel_size=3, padding=1),
nn.BatchNorm1d(64),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.ConvTranspose1d(64, 32, kernel_size=3, padding=1),
nn.BatchNorm1d(32),
nn.ReLU(inplace=True),
nn.ConvTranspose1d(32, num_features, kernel_size=3, padding=1),
)
def forward(self, latent):
"""
Args:
latent: (batch, latent_dim)
Returns:
reconstruction: (batch, num_features, window_size)
"""
# Expand to sequence
expanded = self.fc(latent).unsqueeze(1).repeat(1, self.window_size, 1)
# LSTM decode
lstm_out, _ = self.lstm(expanded)
# Transpose for Conv1d: (batch, lstm_hidden, window_size)
conv_in = lstm_out.permute(0, 2, 1)
# Reconstruct
reconstruction = self.deconv(conv_in)
return reconstruction
class DomainDiscriminator(nn.Module):
"""
Domain classification head with Gradient Reversal Layer.
Classifies whether features came from source (0) or target (1) domain.
Input: (batch, latent_dim)
Output: (batch, 1) — domain logit
"""
def __init__(self, latent_dim: int = 128, hidden_dim: int = 64, dropout: float = 0.3):
super().__init__()
self.grl = GradientReversalLayer(lambda_val=1.0)
self.net = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, 1),
)
def set_lambda(self, lambda_val: float):
self.grl.set_lambda(lambda_val)
def forward(self, latent):
reversed_features = self.grl(latent)
return self.net(reversed_features)
class DomainAdaptiveAnomalyDetector(nn.Module):
"""
Full domain-adaptive anomaly detection model.
Combines encoder, anomaly classifier, reconstruction decoder,
and domain discriminator.
"""
def __init__(self, config):
super().__init__()
self.encoder = SharedEncoder(
num_features=config.num_features,
cnn_channels=config.cnn_channels,
cnn_kernel_sizes=config.cnn_kernel_sizes,
lstm_hidden_dim=config.lstm_hidden_dim,
lstm_num_layers=config.lstm_num_layers,
latent_dim=config.latent_dim,
dropout=config.dropout,
)
self.classifier = AnomalyClassifier(
latent_dim=config.latent_dim,
hidden_dim=config.classifier_hidden_dim,
dropout=config.dropout,
)
self.decoder = ReconstructionDecoder(
latent_dim=config.latent_dim,
num_features=config.num_features,
window_size=config.window_size,
lstm_hidden_dim=config.lstm_hidden_dim,
dropout=config.dropout,
)
self.discriminator = DomainDiscriminator(
latent_dim=config.latent_dim,
hidden_dim=config.discriminator_hidden_dim,
dropout=config.dropout,
)
def set_domain_lambda(self, lambda_val: float):
"""Update the GRL lambda for progressive scheduling."""
self.discriminator.set_lambda(lambda_val)
def forward(self, x):
"""
Full forward pass.
Args:
x: (batch, num_features, window_size)
Returns:
anomaly_logits: (batch, 1) — raw logits for anomaly classification
reconstruction: (batch, num_features, window_size) — reconstructed input
domain_logits: (batch, 1) — raw logits for domain classification
latent_features: (batch, latent_dim) — shared latent representation
"""
latent = self.encoder(x)
anomaly_logits = self.classifier(latent)
reconstruction = self.decoder(latent)
domain_logits = self.discriminator(latent)
return anomaly_logits, reconstruction, domain_logits, latent
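A quick shape sanity check (a standalone sketch, not one of the project files) confirms the "same-length" convolution convention both the encoder and decoder rely on: with stride 1 and padding = kernel_size // 2, odd kernels preserve the window length in both Conv1d and ConvTranspose1d.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 6, 64)  # (batch, num_features, window_size)

# Encoder-style conv: padding = ks // 2 keeps the temporal length at 64
conv = nn.Conv1d(6, 32, kernel_size=7, padding=7 // 2)
out = conv(x)
print(out.shape)  # torch.Size([2, 32, 64])

# Decoder-style transposed conv: same convention, length still 64
deconv = nn.ConvTranspose1d(32, 6, kernel_size=3, padding=1)
print(deconv(out).shape)  # torch.Size([2, 6, 64])
```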
Loss Functions: DANN, MMD, and CORAL
Domain adaptation is not one technique—it is a family of techniques, each with different strengths. Our implementation supports three approaches, all selectable via a single config flag. DANN uses adversarial training (the discriminator approach). MMD directly minimizes the statistical distance between source and target feature distributions using a kernel trick. CORAL aligns the second-order statistics (covariance matrices) of the two domains. You can switch between them in one line of config.
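To build intuition before the PyTorch versions, here is a framework-agnostic NumPy sketch of the two statistical losses (`mmd2` and `coral` are illustrative helper names, not the classes below): both should sit near zero for samples drawn from the same distribution and grow as the distributions diverge.

```python
import numpy as np

def mmd2(x, y, bandwidths=(0.1, 1.0, 10.0)):
    """Biased multi-kernel MMD^2 estimate between samples x (n, d) and y (m, d)."""
    def sq_dists(a, b):
        return ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    k = lambda d2: sum(np.exp(-d2 / (2.0 * bw)) for bw in bandwidths)
    return (k(sq_dists(x, x)).mean() + k(sq_dists(y, y)).mean()
            - 2.0 * k(sq_dists(x, y)).mean())

def coral(x, y):
    """Frobenius distance between feature covariances, scaled as in CORAL."""
    d = x.shape[1]
    cs = np.cov(x, rowvar=False)
    ct = np.cov(y, rowvar=False)
    return ((cs - ct) ** 2).sum() / (4.0 * d * d)

rng = np.random.default_rng(0)
src = rng.normal(0.0, 1.0, (256, 8))
same = rng.normal(0.0, 1.0, (256, 8))              # same distribution
shifted = 1.5 * rng.normal(0.0, 1.0, (256, 8)) + 2.0  # rescaled + offset

# Both losses see the domain shift
assert mmd2(src, shifted) > mmd2(src, same)
assert coral(src, shifted) > coral(src, same)
```

Note that CORAL only looks at second-order statistics, so a pure mean shift with identical covariance is nearly invisible to it; MMD with multiple bandwidths catches both.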
losses.py
"""
losses.py — Loss functions for domain-adaptive anomaly detection.
Includes:
- AnomalyDetectionLoss (BCE for anomaly classification)
- ReconstructionLoss (MSE for autoencoder)
- DomainAdversarialLoss (BCE for domain discrimination)
- MMDLoss (Maximum Mean Discrepancy with Gaussian kernel)
- CORALLoss (CORrelation ALignment)
- CombinedLoss (weighted combination of all losses)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class AnomalyDetectionLoss(nn.Module):
"""Binary cross-entropy loss for anomaly classification."""
def __init__(self):
super().__init__()
self.bce = nn.BCEWithLogitsLoss()
def forward(self, logits, labels):
"""
Args:
logits: (batch, 1) raw anomaly logits
labels: (batch,) binary labels (0=normal, 1=anomaly)
"""
return self.bce(logits.squeeze(-1), labels)
class ReconstructionLoss(nn.Module):
"""MSE loss between input and reconstruction."""
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def forward(self, reconstruction, original):
"""
Args:
reconstruction: (batch, num_features, window_size)
original: (batch, num_features, window_size)
"""
return self.mse(reconstruction, original)
class DomainAdversarialLoss(nn.Module):
"""BCE loss for domain classification (used with GRL for DANN)."""
def __init__(self):
super().__init__()
self.bce = nn.BCEWithLogitsLoss()
def forward(self, domain_logits, domain_labels):
"""
Args:
domain_logits: (batch, 1) raw domain logits
domain_labels: (batch,) domain labels (0=source, 1=target)
"""
return self.bce(domain_logits.squeeze(-1), domain_labels)
class MMDLoss(nn.Module):
"""
Maximum Mean Discrepancy loss with multi-scale Gaussian kernel.
Measures the distance between source and target feature distributions
in a reproducing kernel Hilbert space (RKHS).
"""
def __init__(self, kernel_bandwidths: list = None):
super().__init__()
if kernel_bandwidths is None:
self.kernel_bandwidths = [0.01, 0.1, 1.0, 10.0, 100.0]
else:
self.kernel_bandwidths = kernel_bandwidths
def gaussian_kernel(self, x, y):
"""
Compute multi-scale Gaussian kernel matrix between x and y.
Args:
x: (n, d) tensor
y: (m, d) tensor
        Returns:
            (k_xx, k_yy, k_xy): kernel matrices of shapes (n, n), (m, m),
            and (n, m), each summed over all bandwidths
        """
        # Pairwise squared distances, computed so that n != m also works
        xx = torch.mm(x, x.t())
        yy = torch.mm(y, y.t())
        xy = torch.mm(x, y.t())
        x_sq = x.pow(2).sum(dim=1, keepdim=True)  # (n, 1)
        y_sq = y.pow(2).sum(dim=1, keepdim=True)  # (m, 1)
        dxx = (x_sq + x_sq.t() - 2.0 * xx).clamp(min=0.0)
        dyy = (y_sq + y_sq.t() - 2.0 * yy).clamp(min=0.0)
        dxy = (x_sq + y_sq.t() - 2.0 * xy).clamp(min=0.0)
        k_xx = torch.zeros_like(dxx)
        k_yy = torch.zeros_like(dyy)
        k_xy = torch.zeros_like(dxy)
for bw in self.kernel_bandwidths:
k_xx += torch.exp(-dxx / (2.0 * bw))
k_yy += torch.exp(-dyy / (2.0 * bw))
k_xy += torch.exp(-dxy / (2.0 * bw))
return k_xx, k_yy, k_xy
def forward(self, source_features, target_features):
"""
Compute MMD^2 between source and target feature distributions.
Args:
source_features: (n, d) latent features from source domain
target_features: (m, d) latent features from target domain
Returns:
mmd_loss: scalar
"""
n = source_features.size(0)
m = target_features.size(0)
k_xx, k_yy, k_xy = self.gaussian_kernel(source_features, target_features)
mmd = (k_xx.sum() / (n * n)
+ k_yy.sum() / (m * m)
- 2.0 * k_xy.sum() / (n * m))
return mmd
class CORALLoss(nn.Module):
"""
CORrelation ALignment loss.
Aligns the second-order statistics (covariance matrices) of
source and target feature distributions.
"""
def __init__(self):
super().__init__()
def forward(self, source_features, target_features):
"""
Compute CORAL loss.
Args:
source_features: (n, d) latent features from source domain
target_features: (m, d) latent features from target domain
Returns:
coral_loss: scalar
"""
d = source_features.size(1)
n_s = source_features.size(0)
n_t = target_features.size(0)
# Compute covariance matrices
source_centered = source_features - source_features.mean(dim=0, keepdim=True)
target_centered = target_features - target_features.mean(dim=0, keepdim=True)
cov_source = (source_centered.t() @ source_centered) / (n_s - 1)
cov_target = (target_centered.t() @ target_centered) / (n_t - 1)
# Frobenius norm of covariance difference
diff = cov_source - cov_target
coral_loss = (diff * diff).sum() / (4 * d * d)
return coral_loss
class CombinedLoss(nn.Module):
"""
Combines anomaly detection, reconstruction, and domain adaptation losses.
total_loss = lambda_cls * anomaly_loss
+ lambda_recon * recon_loss
+ lambda_domain * domain_loss
The domain_loss component uses DANN, MMD, or CORAL depending on config.
"""
def __init__(self, config):
super().__init__()
self.anomaly_loss_fn = AnomalyDetectionLoss()
self.recon_loss_fn = ReconstructionLoss()
self.dann_loss_fn = DomainAdversarialLoss()
self.mmd_loss_fn = MMDLoss(kernel_bandwidths=config.mmd_kernel_bandwidth)
self.coral_loss_fn = CORALLoss()
self.lambda_cls = config.lambda_cls
self.lambda_recon = config.lambda_recon
self.lambda_domain = config.lambda_domain
self.method = config.adaptation_method
def forward(
self,
anomaly_logits,
anomaly_labels,
reconstruction,
original,
domain_logits=None,
domain_labels=None,
source_features=None,
target_features=None,
current_lambda=None,
):
"""
Compute combined loss.
Args:
anomaly_logits: (batch, 1) anomaly classification logits (source only)
anomaly_labels: (batch,) anomaly labels (source only)
reconstruction: (batch, num_features, window_size) reconstruction
original: (batch, num_features, window_size) original input
domain_logits: (batch, 1) domain logits (DANN only)
domain_labels: (batch,) domain labels (DANN only)
source_features: (n, d) source latent features (MMD/CORAL)
target_features: (m, d) target latent features (MMD/CORAL)
current_lambda: float — current domain adaptation weight
Returns:
total_loss, loss_dict (breakdown of individual losses)
"""
domain_weight = current_lambda if current_lambda is not None else self.lambda_domain
# Anomaly classification loss (source only)
cls_loss = self.anomaly_loss_fn(anomaly_logits, anomaly_labels)
# Reconstruction loss (both domains)
recon_loss = self.recon_loss_fn(reconstruction, original)
# Domain adaptation loss
if self.method == "dann" and domain_logits is not None:
domain_loss = self.dann_loss_fn(domain_logits, domain_labels)
elif self.method == "mmd" and source_features is not None:
domain_loss = self.mmd_loss_fn(source_features, target_features)
elif self.method == "coral" and source_features is not None:
domain_loss = self.coral_loss_fn(source_features, target_features)
else:
domain_loss = torch.tensor(0.0, device=anomaly_logits.device)
total_loss = (
self.lambda_cls * cls_loss
+ self.lambda_recon * recon_loss
+ domain_weight * domain_loss
)
loss_dict = {
"total": total_loss.item(),
"classification": cls_loss.item(),
"reconstruction": recon_loss.item(),
"domain": domain_loss.item(),
}
return total_loss, loss_dict
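Before wiring these losses into the full pipeline, it helps to sanity-check them on random tensors: both MMD and CORAL should be (near) zero when source and target are drawn from the same distribution, and clearly positive when the distributions differ. The snippet below reimplements both losses compactly so it runs standalone; it is a sketch mirroring the classes above, not an import of them.

```python
import torch

def mmd2(x, y, bandwidths=(0.01, 0.1, 1.0, 10.0, 100.0)):
    """Squared MMD with a multi-scale Gaussian kernel (mirrors MMDLoss)."""
    rx = (x * x).sum(1, keepdim=True)          # (n, 1)
    ry = (y * y).sum(1, keepdim=True)          # (m, 1)
    dxx = (rx + rx.t() - 2 * x @ x.t()).clamp(min=0)
    dyy = (ry + ry.t() - 2 * y @ y.t()).clamp(min=0)
    dxy = (rx + ry.t() - 2 * x @ y.t()).clamp(min=0)   # works for n != m
    k = lambda d: sum(torch.exp(-d / (2 * bw)) for bw in bandwidths)
    n, m = x.size(0), y.size(0)
    return (k(dxx).sum() / (n * n) + k(dyy).sum() / (m * m)
            - 2 * k(dxy).sum() / (n * m))

def coral(x, y):
    """CORAL: scaled Frobenius distance between covariances (mirrors CORALLoss)."""
    d = x.size(1)
    cx = (x - x.mean(0)).t() @ (x - x.mean(0)) / (x.size(0) - 1)
    cy = (y - y.mean(0)).t() @ (y - y.mean(0)) / (y.size(0) - 1)
    return ((cx - cy) ** 2).sum() / (4 * d * d)

torch.manual_seed(0)
same_a = torch.randn(64, 16)
same_b = torch.randn(48, 16)            # same distribution, different batch size
shifted = torch.randn(48, 16) * 3 + 2   # scaled and shifted distribution

print(float(mmd2(same_a, same_b)), float(mmd2(same_a, shifted)))
print(float(coral(same_a, same_b)), float(coral(same_a, shifted)))
```

Note that the batch sizes deliberately differ (64 vs 48): both losses must handle that case, since the source and target loaders rarely yield equal-sized final batches.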
The Main Training Script
This is where everything comes together. The training loop handles the delicate dance of simultaneously training the anomaly classifier (on labeled source data), the reconstruction decoder (on both domains), and the domain discriminator (adversarially, on both domains). The DANN lambda schedule progressively increases the domain adaptation strength over training, following the formula from the original paper: λp = 2 / (1 + exp(-γ · p)) - 1, where p is the training progress from 0 to 1.
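The schedule is easy to verify numerically: it starts at exactly 0 (no domain pressure while the encoder is still random) and saturates near 1 well before training ends. A quick standalone check of the formula, assuming γ = 10 as in the paper:

```python
import numpy as np

def dann_lambda(p, gamma=10.0):
    """lambda_p = 2 / (1 + exp(-gamma * p)) - 1, for progress p in [0, 1]."""
    return 2.0 / (1.0 + np.exp(-gamma * p)) - 1.0

for p in (0.0, 0.1, 0.25, 0.5, 1.0):
    print(f"p={p:.2f} -> lambda={dann_lambda(p):.4f}")
```

The curve rises steeply early on (already above 0.98 at the halfway point with γ = 10), so most of the "warm-up" happens in the first quarter of training.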
train.py
"""
train.py — Main training script for domain-adaptive anomaly detection.
Supports three adaptation methods: DANN, MMD, CORAL.
Uses progressive lambda scheduling for stable training.
"""
import argparse
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
from config import Config
from dataset import create_data_loaders
from model import DomainAdaptiveAnomalyDetector
from losses import CombinedLoss
from utils import (
set_seed,
EarlyStopping,
save_checkpoint,
MetricLogger,
)
def compute_dann_lambda(epoch: int, total_epochs: int, gamma: float = 10.0) -> float:
"""
Progressive lambda schedule from the DANN paper (Ganin et al., 2016).
Ramps from 0 to 1 over training using a sigmoid-like schedule.
lambda_p = 2 / (1 + exp(-gamma * p)) - 1, where p = epoch / total_epochs
"""
p = epoch / total_epochs
return float(2.0 / (1.0 + np.exp(-gamma * p)) - 1.0)
def train_one_epoch(
model,
source_loader,
target_loader,
criterion,
optimizer,
device,
epoch,
total_epochs,
config,
):
"""Train for one epoch with domain adaptation."""
model.train()
epoch_losses = {"total": 0, "classification": 0, "reconstruction": 0, "domain": 0}
n_batches = 0
# Compute current domain adaptation lambda
current_lambda = compute_dann_lambda(epoch, total_epochs, config.gamma) * config.lambda_domain
# Set the GRL lambda in the model
model.set_domain_lambda(current_lambda)
# Zip source and target loaders (cycle the shorter one)
target_iter = iter(target_loader)
for source_batch, source_labels in source_loader:
# Get target batch (cycle if exhausted)
try:
target_batch, _ = next(target_iter)
except StopIteration:
target_iter = iter(target_loader)
target_batch, _ = next(target_iter)
source_batch = source_batch.to(device)
source_labels = source_labels.to(device)
target_batch = target_batch.to(device)
# Determine actual batch sizes (may differ)
bs_s = source_batch.size(0)
bs_t = target_batch.size(0)
# Forward pass: source domain
s_anomaly_logits, s_recon, s_domain_logits, s_latent = model(source_batch)
# Forward pass: target domain
t_anomaly_logits, t_recon, t_domain_logits, t_latent = model(target_batch)
# Combine reconstructions and originals for loss
all_recon = torch.cat([s_recon, t_recon], dim=0)
all_original = torch.cat([source_batch, target_batch], dim=0)
# Domain labels: 0 for source, 1 for target
domain_labels = torch.cat([
torch.zeros(bs_s, device=device),
torch.ones(bs_t, device=device),
])
all_domain_logits = torch.cat([s_domain_logits, t_domain_logits], dim=0)
        # Compute combined loss. For DANN the GRL inside the model already
        # scales the reversed gradient by current_lambda, so the domain loss
        # term is weighted 1.0 here to avoid applying lambda twice to the encoder.
        total_loss, loss_dict = criterion(
            anomaly_logits=s_anomaly_logits,
            anomaly_labels=source_labels,
            reconstruction=all_recon,
            original=all_original,
            domain_logits=all_domain_logits,
            domain_labels=domain_labels,
            source_features=s_latent,
            target_features=t_latent,
            current_lambda=1.0 if config.adaptation_method == "dann" else current_lambda,
        )
# Backprop
optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
# Accumulate losses
for key in epoch_losses:
epoch_losses[key] += loss_dict[key]
n_batches += 1
# Average losses
for key in epoch_losses:
epoch_losses[key] /= max(n_batches, 1)
epoch_losses["lambda"] = current_lambda
return epoch_losses
@torch.no_grad()
def validate(model, loader, criterion, device, config):
"""Validate on a labeled dataset (source test or target test)."""
model.eval()
all_logits = []
all_labels = []
total_recon_loss = 0
n_batches = 0
for batch, labels in loader:
batch = batch.to(device)
labels = labels.to(device)
anomaly_logits, recon, _, latent = model(batch)
recon_loss = nn.MSELoss()(recon, batch)
all_logits.append(anomaly_logits.squeeze(-1).cpu())
all_labels.append(labels.cpu())
total_recon_loss += recon_loss.item()
n_batches += 1
all_logits = torch.cat(all_logits)
all_labels = torch.cat(all_labels)
# Compute metrics
probs = torch.sigmoid(all_logits)
preds = (probs > 0.5).float()
accuracy = (preds == all_labels).float().mean().item()
from sklearn.metrics import roc_auc_score, f1_score
try:
auroc = roc_auc_score(all_labels.numpy(), probs.numpy())
except ValueError:
auroc = 0.5 # Only one class present
f1 = f1_score(all_labels.numpy(), preds.numpy(), zero_division=0)
return {
"accuracy": accuracy,
"auroc": auroc,
"f1": f1,
"recon_loss": total_recon_loss / max(n_batches, 1),
}
def main():
parser = argparse.ArgumentParser(description="Train domain-adaptive anomaly detector")
parser.add_argument("--method", type=str, default="dann",
choices=["dann", "mmd", "coral"],
help="Domain adaptation method")
parser.add_argument("--epochs", type=int, default=None)
parser.add_argument("--batch_size", type=int, default=None)
parser.add_argument("--lr", type=float, default=None)
parser.add_argument("--lambda_domain", type=float, default=None)
parser.add_argument("--lambda_recon", type=float, default=None)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--data_dir", type=str, default=None)
parser.add_argument("--device", type=str, default=None)
args = parser.parse_args()
# Build config with CLI overrides
config = Config()
config.adaptation_method = args.method
if args.epochs is not None:
config.epochs = args.epochs
if args.batch_size is not None:
config.batch_size = args.batch_size
if args.lr is not None:
config.learning_rate = args.lr
if args.lambda_domain is not None:
config.lambda_domain = args.lambda_domain
if args.lambda_recon is not None:
config.lambda_recon = args.lambda_recon
if args.seed is not None:
config.seed = args.seed
if args.data_dir is not None:
config.data_dir = args.data_dir
if args.device is not None:
config.device = args.device
# Setup
set_seed(config.seed)
device = torch.device(config.device)
print(f"Using device: {device}")
print(f"Adaptation method: {config.adaptation_method}")
print(f"Epochs: {config.epochs}, Batch size: {config.batch_size}, LR: {config.learning_rate}")
# Data
print("\nLoading data...")
loaders, normalizer = create_data_loaders(config)
print(f"Source train batches: {len(loaders['source_train'])}")
print(f"Target train batches: {len(loaders['target_train'])}")
# Model
model = DomainAdaptiveAnomalyDetector(config).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"\nModel parameters: {total_params:,}")
# Optimizer (single optimizer for simplicity; separate LRs via param groups)
optimizer = Adam([
{"params": model.encoder.parameters(), "lr": config.learning_rate},
{"params": model.classifier.parameters(), "lr": config.learning_rate},
{"params": model.decoder.parameters(), "lr": config.learning_rate},
{"params": model.discriminator.parameters(), "lr": config.discriminator_lr},
], weight_decay=config.weight_decay)
scheduler = CosineAnnealingLR(optimizer, T_max=config.epochs, eta_min=1e-6)
# Loss
criterion = CombinedLoss(config)
# Early stopping
early_stopping = EarlyStopping(patience=config.patience, mode="max")
# Logging
logger = MetricLogger(config.results_dir)
# Training loop
best_target_auroc = 0.0
print("\n" + "=" * 60)
print("Starting training...")
print("=" * 60)
for epoch in range(config.epochs):
start_time = time.time()
# Train
train_losses = train_one_epoch(
model, loaders["source_train"], loaders["target_train"],
criterion, optimizer, device, epoch, config.epochs, config
)
# Validate on source test
source_metrics = validate(model, loaders["source_test"], criterion, device, config)
# Evaluate on target test (the real metric we care about)
target_metrics = validate(model, loaders["target_test"], criterion, device, config)
scheduler.step()
elapsed = time.time() - start_time
# Log
logger.log(epoch, train_losses, source_metrics, target_metrics)
# Print progress
if epoch % 5 == 0 or epoch == config.epochs - 1:
print(
f"Epoch {epoch:3d}/{config.epochs} ({elapsed:.1f}s) | "
f"Loss: {train_losses['total']:.4f} "
f"[cls={train_losses['classification']:.4f}, "
f"rec={train_losses['reconstruction']:.4f}, "
f"dom={train_losses['domain']:.4f}] | "
f"λ={train_losses['lambda']:.3f} | "
f"Src AUROC: {source_metrics['auroc']:.4f} | "
f"Tgt AUROC: {target_metrics['auroc']:.4f}"
)
# Save best model (based on target AUROC)
if target_metrics["auroc"] > best_target_auroc:
best_target_auroc = target_metrics["auroc"]
save_checkpoint(
model, optimizer, epoch, target_metrics,
os.path.join(config.checkpoint_dir, "best_model.pt")
)
# Early stopping on target AUROC
if early_stopping.step(target_metrics["auroc"]):
print(f"\nEarly stopping triggered at epoch {epoch}")
break
print("\n" + "=" * 60)
print(f"Training complete. Best target AUROC: {best_target_auroc:.4f}")
print(f"Best model saved to: {config.checkpoint_dir}/best_model.pt")
print("=" * 60)
# Save training curves
logger.save()
logger.plot_training_curves()
if __name__ == "__main__":
main()
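One detail in train_one_epoch worth calling out: the target iterator is re-created whenever it is exhausted, so every source batch gets paired with a target batch even when the two datasets differ in length. The pattern in isolation, as a small sketch:

```python
def paired_batches(source, target):
    """Yield (source_batch, target_batch) pairs, cycling target when exhausted."""
    target_iter = iter(target)
    for s in source:
        try:
            t = next(target_iter)
        except StopIteration:
            # Target loader ran out before source; restart it
            target_iter = iter(target)
            t = next(target_iter)
        yield s, t

pairs = list(paired_batches([1, 2, 3, 4, 5], ["a", "b"]))
print(pairs)  # [(1, 'a'), (2, 'b'), (3, 'a'), (4, 'b'), (5, 'a')]
```

The epoch length is therefore defined by the source loader; if the target set is much larger, each epoch only sees a rotating slice of it, which is usually acceptable since target data is unlabeled.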
Evaluation and Metrics
After training, we need rigorous evaluation on the target domain. Our evaluation script computes standard anomaly detection metrics, combines classifier and reconstruction scores, implements multiple threshold strategies, and generates diagnostic plots. This is where you find out if domain adaptation actually worked.
evaluate.py
"""
evaluate.py — Evaluation script for domain-adaptive anomaly detection.
Loads a trained model and evaluates on target domain test data.
Computes AUROC, AUPRC, F1, precision, recall.
Generates diagnostic plots and saves results to JSON.
"""
import argparse
import json
import os
import numpy as np
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.metrics import (
roc_auc_score,
average_precision_score,
f1_score,
precision_score,
recall_score,
accuracy_score,
confusion_matrix,
roc_curve,
precision_recall_curve,
)
from config import Config
from dataset import create_data_loaders
from model import DomainAdaptiveAnomalyDetector
from utils import set_seed, load_checkpoint
def compute_anomaly_scores(model, loader, device, alpha=0.7):
"""
Compute anomaly scores combining classifier output and reconstruction error.
anomaly_score = alpha * classifier_prob + (1 - alpha) * normalized_recon_error
Returns:
scores: numpy array of anomaly scores
labels: numpy array of ground truth labels
recon_errors: numpy array of per-sample reconstruction errors
classifier_probs: numpy array of classifier probabilities
latent_features: numpy array of latent features (for t-SNE)
"""
model.eval()
all_probs = []
all_labels = []
all_recon_errors = []
all_latent = []
with torch.no_grad():
for batch, labels in loader:
batch = batch.to(device)
anomaly_logits, recon, _, latent = model(batch)
# Classifier probability
probs = torch.sigmoid(anomaly_logits.squeeze(-1))
# Per-sample reconstruction error (mean across features and time)
recon_error = ((recon - batch) ** 2).mean(dim=(1, 2))
all_probs.append(probs.cpu().numpy())
all_labels.append(labels.numpy())
all_recon_errors.append(recon_error.cpu().numpy())
all_latent.append(latent.cpu().numpy())
all_probs = np.concatenate(all_probs)
all_labels = np.concatenate(all_labels)
all_recon_errors = np.concatenate(all_recon_errors)
all_latent = np.concatenate(all_latent)
# Normalize reconstruction errors to [0, 1]
re_min, re_max = all_recon_errors.min(), all_recon_errors.max()
if re_max - re_min > 1e-8:
norm_recon = (all_recon_errors - re_min) / (re_max - re_min)
else:
norm_recon = np.zeros_like(all_recon_errors)
# Combined anomaly score
scores = alpha * all_probs + (1 - alpha) * norm_recon
return scores, all_labels, all_recon_errors, all_probs, all_latent
def find_optimal_threshold(labels, scores):
"""Find the threshold that maximizes F1 score."""
thresholds = np.linspace(0, 1, 200)
best_f1 = 0
best_thresh = 0.5
for thresh in thresholds:
preds = (scores >= thresh).astype(int)
f1 = f1_score(labels, preds, zero_division=0)
if f1 > best_f1:
best_f1 = f1
best_thresh = thresh
return best_thresh, best_f1
def compute_all_metrics(labels, scores, threshold):
"""Compute all evaluation metrics at a given threshold."""
preds = (scores >= threshold).astype(int)
metrics = {
"auroc": float(roc_auc_score(labels, scores)),
"auprc": float(average_precision_score(labels, scores)),
"f1": float(f1_score(labels, preds, zero_division=0)),
"precision": float(precision_score(labels, preds, zero_division=0)),
"recall": float(recall_score(labels, preds, zero_division=0)),
"accuracy": float(accuracy_score(labels, preds)),
"threshold": float(threshold),
}
cm = confusion_matrix(labels, preds)
metrics["confusion_matrix"] = cm.tolist()
metrics["true_negatives"] = int(cm[0, 0])
metrics["false_positives"] = int(cm[0, 1])
metrics["false_negatives"] = int(cm[1, 0])
metrics["true_positives"] = int(cm[1, 1])
return metrics
def plot_roc_curve(labels, scores, save_path):
"""Plot and save ROC curve."""
fpr, tpr, _ = roc_curve(labels, scores)
auroc = roc_auc_score(labels, scores)
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, "b-", linewidth=2, label=f"AUROC = {auroc:.4f}")
ax.plot([0, 1], [0, 1], "k--", alpha=0.5, label="Random")
ax.set_xlabel("False Positive Rate", fontsize=12)
ax.set_ylabel("True Positive Rate", fontsize=12)
ax.set_title("ROC Curve — Target Domain", fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(save_path, dpi=150)
plt.close(fig)
print(f"ROC curve saved to {save_path}")
def plot_pr_curve(labels, scores, save_path):
"""Plot and save Precision-Recall curve."""
precision, recall, _ = precision_recall_curve(labels, scores)
auprc = average_precision_score(labels, scores)
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(recall, precision, "r-", linewidth=2, label=f"AUPRC = {auprc:.4f}")
baseline = labels.sum() / len(labels)
ax.axhline(y=baseline, color="k", linestyle="--", alpha=0.5, label=f"Baseline = {baseline:.3f}")
ax.set_xlabel("Recall", fontsize=12)
ax.set_ylabel("Precision", fontsize=12)
ax.set_title("Precision-Recall Curve — Target Domain", fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(save_path, dpi=150)
plt.close(fig)
print(f"PR curve saved to {save_path}")
def plot_score_distribution(labels, scores, threshold, save_path):
"""Plot anomaly score distribution for normal vs anomaly samples."""
fig, ax = plt.subplots(figsize=(10, 6))
normal_scores = scores[labels == 0]
anomaly_scores = scores[labels == 1]
ax.hist(normal_scores, bins=50, alpha=0.6, color="steelblue", label="Normal", density=True)
ax.hist(anomaly_scores, bins=50, alpha=0.6, color="indianred", label="Anomaly", density=True)
ax.axvline(x=threshold, color="black", linestyle="--", linewidth=2,
label=f"Threshold = {threshold:.3f}")
ax.set_xlabel("Anomaly Score", fontsize=12)
ax.set_ylabel("Density", fontsize=12)
ax.set_title("Anomaly Score Distribution — Target Domain", fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(save_path, dpi=150)
plt.close(fig)
print(f"Score distribution saved to {save_path}")
def plot_reconstruction_error(recon_errors, labels, save_path):
"""Plot reconstruction error over sample index, colored by label."""
fig, ax = plt.subplots(figsize=(14, 5))
indices = np.arange(len(recon_errors))
normal_mask = labels == 0
anomaly_mask = labels == 1
ax.scatter(indices[normal_mask], recon_errors[normal_mask],
s=2, alpha=0.4, c="steelblue", label="Normal")
ax.scatter(indices[anomaly_mask], recon_errors[anomaly_mask],
s=8, alpha=0.8, c="indianred", label="Anomaly")
ax.set_xlabel("Sample Index", fontsize=12)
ax.set_ylabel("Reconstruction Error", fontsize=12)
ax.set_title("Reconstruction Error Over Time — Target Domain", fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(save_path, dpi=150)
plt.close(fig)
print(f"Reconstruction error plot saved to {save_path}")
def main():
parser = argparse.ArgumentParser(description="Evaluate domain-adaptive anomaly detector")
parser.add_argument("--checkpoint", type=str,
default="checkpoints/best_model.pt",
help="Path to model checkpoint")
parser.add_argument("--data_dir", type=str, default="data",
help="Data directory")
parser.add_argument("--results_dir", type=str, default="results",
help="Output directory for results")
parser.add_argument("--alpha", type=float, default=0.7,
help="Weight for classifier score vs recon error")
parser.add_argument("--method", type=str, default="dann",
choices=["dann", "mmd", "coral"])
parser.add_argument("--device", type=str, default="")
args = parser.parse_args()
config = Config()
config.data_dir = args.data_dir
config.results_dir = args.results_dir
config.adaptation_method = args.method
if args.device:
config.device = args.device
set_seed(config.seed)
device = torch.device(config.device)
os.makedirs(config.results_dir, exist_ok=True)
print(f"Device: {device}")
print(f"Loading checkpoint: {args.checkpoint}")
# Load model
model = DomainAdaptiveAnomalyDetector(config).to(device)
checkpoint = load_checkpoint(args.checkpoint, model, device=device)
print(f"Loaded model from epoch {checkpoint.get('epoch', '?')}")
# Load data
loaders, normalizer = create_data_loaders(config)
# --- Evaluate on target test set ---
print("\n--- Target Domain Evaluation ---")
scores, labels, recon_errors, probs, latent_features = compute_anomaly_scores(
model, loaders["target_test"], device, alpha=args.alpha
)
# Find optimal threshold
optimal_thresh, optimal_f1 = find_optimal_threshold(labels, scores)
print(f"Optimal threshold: {optimal_thresh:.4f} (F1 = {optimal_f1:.4f})")
# Percentile-based threshold
percentile_thresh = np.percentile(scores, config.anomaly_threshold_percentile)
print(f"Percentile ({config.anomaly_threshold_percentile}%) threshold: {percentile_thresh:.4f}")
# Compute metrics at optimal threshold
metrics_optimal = compute_all_metrics(labels, scores, optimal_thresh)
metrics_optimal["threshold_method"] = "f1_optimal"
# Compute metrics at percentile threshold
metrics_percentile = compute_all_metrics(labels, scores, percentile_thresh)
metrics_percentile["threshold_method"] = "percentile"
# Print results
print(f"\n{'Metric':<20} {'F1-Optimal':>12} {'Percentile':>12}")
print("-" * 46)
for key in ["auroc", "auprc", "f1", "precision", "recall", "accuracy"]:
print(f"{key:<20} {metrics_optimal[key]:>12.4f} {metrics_percentile[key]:>12.4f}")
# Also evaluate on source test for comparison
print("\n--- Source Domain Evaluation (baseline) ---")
src_scores, src_labels, _, _, src_latent = compute_anomaly_scores(
model, loaders["source_test"], device, alpha=args.alpha
)
src_thresh, _ = find_optimal_threshold(src_labels, src_scores)
src_metrics = compute_all_metrics(src_labels, src_scores, src_thresh)
print(f"Source AUROC: {src_metrics['auroc']:.4f}, F1: {src_metrics['f1']:.4f}")
# --- Generate plots ---
print("\nGenerating plots...")
plot_roc_curve(labels, scores, os.path.join(config.results_dir, "roc_curve.png"))
plot_pr_curve(labels, scores, os.path.join(config.results_dir, "pr_curve.png"))
plot_score_distribution(labels, scores, optimal_thresh,
os.path.join(config.results_dir, "score_distribution.png"))
plot_reconstruction_error(recon_errors, labels,
os.path.join(config.results_dir, "recon_error.png"))
# --- Save results ---
results = {
"method": config.adaptation_method,
"alpha": args.alpha,
"target_metrics_optimal": metrics_optimal,
"target_metrics_percentile": metrics_percentile,
"source_metrics": src_metrics,
}
results_path = os.path.join(config.results_dir, "evaluation_results.json")
with open(results_path, "w") as f:
json.dump(results, f, indent=2)
print(f"\nResults saved to {results_path}")
if __name__ == "__main__":
main()
Utility Functions
The utility module handles reproducibility, early stopping, checkpointing, metric logging, and visualization including t-SNE plots of feature distributions.
utils.py
"""
utils.py — Utility functions for the DA anomaly detection pipeline.
Includes:
- Seed setting for reproducibility
- EarlyStopping class
- Checkpoint save/load
- MetricLogger with CSV output and plotting
- t-SNE visualization of domain features
"""
import os
import random
import json
import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def set_seed(seed: int = 42):
"""Set random seeds for reproducibility across all libraries."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
class EarlyStopping:
"""
Early stopping to halt training when a metric stops improving.
Args:
patience: number of epochs to wait before stopping
mode: 'min' or 'max' — whether lower or higher is better
min_delta: minimum improvement to count as progress
"""
def __init__(self, patience: int = 15, mode: str = "max", min_delta: float = 1e-4):
self.patience = patience
self.mode = mode
self.min_delta = min_delta
self.counter = 0
self.best_value = None
def step(self, value: float) -> bool:
"""
Check if training should stop.
Args:
value: current metric value
Returns:
True if training should stop
"""
if self.best_value is None:
self.best_value = value
return False
if self.mode == "max":
improved = value > self.best_value + self.min_delta
else:
improved = value < self.best_value - self.min_delta
if improved:
self.best_value = value
self.counter = 0
else:
self.counter += 1
return self.counter >= self.patience
def save_checkpoint(model, optimizer, epoch, metrics, filepath):
"""Save model checkpoint."""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
torch.save({
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"metrics": metrics,
}, filepath)
def load_checkpoint(filepath, model, optimizer=None, device="cpu"):
"""Load model checkpoint."""
checkpoint = torch.load(filepath, map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
if optimizer is not None and "optimizer_state_dict" in checkpoint:
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
return checkpoint
class MetricLogger:
"""
Logs training metrics to memory and saves to CSV/JSON.
Also generates training curve plots.
"""
def __init__(self, output_dir: str = "results"):
self.output_dir = output_dir
os.makedirs(output_dir, exist_ok=True)
self.history = {
"epoch": [],
"train_total_loss": [],
"train_cls_loss": [],
"train_recon_loss": [],
"train_domain_loss": [],
"train_lambda": [],
"source_auroc": [],
"source_f1": [],
"target_auroc": [],
"target_f1": [],
}
def log(self, epoch, train_losses, source_metrics, target_metrics):
"""Record one epoch of metrics."""
self.history["epoch"].append(epoch)
self.history["train_total_loss"].append(train_losses["total"])
self.history["train_cls_loss"].append(train_losses["classification"])
self.history["train_recon_loss"].append(train_losses["reconstruction"])
self.history["train_domain_loss"].append(train_losses["domain"])
self.history["train_lambda"].append(train_losses.get("lambda", 0))
self.history["source_auroc"].append(source_metrics["auroc"])
self.history["source_f1"].append(source_metrics["f1"])
self.history["target_auroc"].append(target_metrics["auroc"])
self.history["target_f1"].append(target_metrics["f1"])
def save(self):
"""Save metrics history to JSON."""
path = os.path.join(self.output_dir, "training_history.json")
with open(path, "w") as f:
json.dump(self.history, f, indent=2)
print(f"Training history saved to {path}")
def plot_training_curves(self):
"""Generate and save training curve plots."""
epochs = self.history["epoch"]
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Loss curves
ax = axes[0, 0]
ax.plot(epochs, self.history["train_total_loss"], label="Total", linewidth=2)
ax.plot(epochs, self.history["train_cls_loss"], label="Classification", linewidth=1.5)
ax.plot(epochs, self.history["train_recon_loss"], label="Reconstruction", linewidth=1.5)
ax.plot(epochs, self.history["train_domain_loss"], label="Domain", linewidth=1.5)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training Losses")
ax.legend()
ax.grid(True, alpha=0.3)
# AUROC
ax = axes[0, 1]
ax.plot(epochs, self.history["source_auroc"], label="Source AUROC", linewidth=2)
ax.plot(epochs, self.history["target_auroc"], label="Target AUROC", linewidth=2)
ax.set_xlabel("Epoch")
ax.set_ylabel("AUROC")
ax.set_title("AUROC Over Training")
ax.legend()
ax.grid(True, alpha=0.3)
# F1
ax = axes[1, 0]
ax.plot(epochs, self.history["source_f1"], label="Source F1", linewidth=2)
ax.plot(epochs, self.history["target_f1"], label="Target F1", linewidth=2)
ax.set_xlabel("Epoch")
ax.set_ylabel("F1 Score")
ax.set_title("F1 Score Over Training")
ax.legend()
ax.grid(True, alpha=0.3)
# Lambda schedule
ax = axes[1, 1]
ax.plot(epochs, self.history["train_lambda"], label="Domain λ", linewidth=2,
color="purple")
ax.set_xlabel("Epoch")
ax.set_ylabel("Lambda Value")
ax.set_title("Domain Adaptation Lambda Schedule")
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
path = os.path.join(self.output_dir, "training_curves.png")
fig.savefig(path, dpi=150)
plt.close(fig)
print(f"Training curves saved to {path}")
def plot_tsne_features(
source_features: np.ndarray,
target_features: np.ndarray,
save_path: str,
title: str = "t-SNE Feature Visualization",
max_samples: int = 2000,
):
"""
Create t-SNE plot showing source vs target feature distributions.
Args:
source_features: (n, d) source latent features
target_features: (m, d) target latent features
save_path: path to save the plot
title: plot title
max_samples: max samples per domain (for speed)
"""
from sklearn.manifold import TSNE
# Subsample if needed
if len(source_features) > max_samples:
idx = np.random.choice(len(source_features), max_samples, replace=False)
source_features = source_features[idx]
if len(target_features) > max_samples:
idx = np.random.choice(len(target_features), max_samples, replace=False)
target_features = target_features[idx]
# Combine and run t-SNE
combined = np.concatenate([source_features, target_features], axis=0)
n_source = len(source_features)
    # Perplexity must be smaller than the number of samples; cap it for small sets
    perplexity = min(30, max(2, (len(combined) - 1) // 3))
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
embedded = tsne.fit_transform(combined)
fig, ax = plt.subplots(figsize=(10, 8))
ax.scatter(embedded[:n_source, 0], embedded[:n_source, 1],
s=10, alpha=0.5, c="steelblue", label="Source")
ax.scatter(embedded[n_source:, 0], embedded[n_source:, 1],
s=10, alpha=0.5, c="indianred", label="Target")
ax.set_title(title, fontsize=14)
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(save_path, dpi=150)
plt.close(fig)
print(f"t-SNE plot saved to {save_path}")
Running the Full Pipeline
With all nine scripts in place, here is the complete workflow from data generation to final evaluation. Open a terminal in the da-anomaly-detection/ directory and run these commands in order.
Step-by-Step Commands
# Step 1: Install dependencies
pip install -r requirements.txt
# Step 2: Generate synthetic two-domain data
python generate_synthetic_data.py --output_dir data/ --n_samples 20000
# Step 3: Train with DANN (Domain-Adversarial Neural Network)
python train.py --method dann --epochs 100 --batch_size 64 --lr 0.001
# Step 4: Evaluate on target domain
python evaluate.py --checkpoint checkpoints/best_model.pt --data_dir data/ --method dann
# (Optional) Step 5: Train with MMD instead
python train.py --method mmd --epochs 100 --batch_size 64
# (Optional) Step 6: Train with CORAL instead
python train.py --method coral --epochs 100 --batch_size 64
Each training run will print progress every 5 epochs, save the best model checkpoint (based on target domain AUROC), and output training curves to the results/ directory. The evaluation script generates ROC curves, PR curves, score distribution histograms, and reconstruction error time plots.
Understanding the Results
You have run the pipeline and have a results/evaluation_results.json file with numbers. But what do those numbers mean, and how do you know if domain adaptation is actually helping?
Interpreting the Evaluation Metrics
AUROC (Area Under the ROC Curve) is the primary metric. It measures the probability that a randomly chosen anomaly scores higher than a randomly chosen normal sample. An AUROC of 0.5 is random, 1.0 is perfect. For domain adaptation to be considered successful, the target domain AUROC should be significantly higher than the “no adaptation” baseline (training only on source, evaluating on target with no domain adaptation).
AUPRC (Area Under the Precision-Recall Curve) is more informative when anomalies are rare. In highly imbalanced datasets (1% anomaly rate), AUROC can look good even when the model has a high false positive rate. AUPRC penalizes false positives more heavily.
F1 Score is the harmonic mean of precision and recall, computed at the optimal threshold. It gives you a single number that balances false positives and false negatives. For industrial applications, you typically care more about recall (do not miss anomalies) than precision (some false alarms are acceptable).
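As a concrete reference, all three metrics can be computed with scikit-learn. The snippet below uses synthetic scores and labels to show the mechanics — the variable names are illustrative, not taken from evaluate.py:

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve

rng = np.random.default_rng(0)
labels = (rng.random(1000) < 0.05).astype(int)    # ~5% anomalies
scores = rng.normal(0, 1, 1000) + 2.0 * labels    # anomalies score higher on average

auroc = roc_auc_score(labels, scores)
auprc = average_precision_score(labels, scores)

# F1 at the optimal threshold: sweep the PR curve and take the best harmonic mean.
prec, rec, _ = precision_recall_curve(labels, scores)
f1 = 2 * prec * rec / np.clip(prec + rec, 1e-12, None)
best_f1 = f1.max()
print(f"AUROC={auroc:.3f}  AUPRC={auprc:.3f}  best F1={best_f1:.3f}")
```

Note that the AUPRC baseline for a random model equals the anomaly prevalence (here 0.05), not 0.5 — which is exactly why it is the more honest metric under heavy imbalance.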
What Good vs. Bad Domain Adaptation Looks Like
| Scenario | Source AUROC | Target AUROC (no adapt) | Target AUROC (with DA) | Interpretation |
|---|---|---|---|---|
| Successful adaptation | 0.95 | 0.62 | 0.87 | Domain adaptation recovered most performance |
| Negative transfer | 0.95 | 0.65 | 0.58 | DA made things worse; domains may be too different |
| No domain shift | 0.93 | 0.91 | 0.92 | Little domain shift exists; DA not needed |
| Partial adaptation | 0.95 | 0.55 | 0.72 | DA helps but gap remains; try tuning or more target data |
Understanding t-SNE Plots
The t-SNE visualization is your most intuitive diagnostic tool. Run it on the latent features before and after domain adaptation:
- Before adaptation: You should see two distinct clusters—source samples clumped together in one region, target samples in another. This visual separation confirms that domain shift exists in the data.
- After successful adaptation: The source and target clusters should overlap significantly. The encoder has learned features that look the same regardless of which domain produced the input. If the anomaly classifier works on source features, it should now work on the (overlapping) target features too.
- After failed adaptation: Clusters remain separate, or worse, everything collapses to a single point (mode collapse in the discriminator).
When to Use DANN vs. MMD vs. CORAL
| Method | Mechanism | Strengths | Weaknesses | Best For |
|---|---|---|---|---|
| DANN | Adversarial training via GRL | Powerful; learns complex alignment | Unstable training; sensitive to hyperparameters | Large domain shifts; enough training data |
| MMD | Kernel-based distribution matching | Stable training; mathematically principled | Expensive for large batches; kernel selection matters | Moderate domain shifts; limited compute |
| CORAL | Covariance matrix alignment | Simple; fast; no extra hyperparameters | Only matches second-order statistics | Small domain shifts; quick baseline |
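To make the mechanism column concrete, here is a minimal sketch of the CORAL loss from the table — the squared Frobenius distance between source and target feature covariance matrices, scaled by 1/(4d²) as in Sun and Saenko (2016). This is a standalone illustration, not necessarily the exact implementation in the repository's loss module:

```python
import torch

torch.manual_seed(0)

def coral_loss(source_feats: torch.Tensor, target_feats: torch.Tensor) -> torch.Tensor:
    """CORAL: squared Frobenius distance between the two feature
    covariance matrices, scaled by 1/(4 d^2) per Sun & Saenko (2016)."""
    d = source_feats.size(1)

    def cov(x: torch.Tensor) -> torch.Tensor:
        x = x - x.mean(dim=0, keepdim=True)
        return (x.t() @ x) / (x.size(0) - 1)

    diff = cov(source_feats) - cov(target_feats)
    return (diff * diff).sum() / (4 * d * d)

# Toy check: identical feature distributions give zero loss,
# a scale mismatch between domains gives a clearly positive loss.
src = torch.randn(256, 64)
tgt = torch.randn(256, 64) * 3.0
print(coral_loss(src, src.clone()).item())   # 0.0
print(coral_loss(src, tgt).item())           # clearly positive
```

Because the loss only touches second-order statistics, it adds no extra hyperparameters beyond its weight in the total loss — which is why it makes such a good first baseline.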
Adapting to Your Own Data
The synthetic data is a sandbox. Here is how to plug in your own time-series data with minimal code changes.
Modifying dataset.py for Your Data Format
Your CSV files need to follow this structure: each row is a timestep, and each column (except label and timestamp) is a sensor channel. The sensor column names do not matter; only the label and timestamp columns must be named exactly (or omitted entirely). If your data uses a different format, modify the load_csv_data() function:
# Example: your data has columns named 'temp_1', 'temp_2', 'vibration_x', etc.
# and uses 'anomaly' instead of 'label'
def load_csv_data(filepath, has_labels=True):
    df = pd.read_csv(filepath)
    exclude = ["anomaly", "timestamp", "machine_id", "date"]
    feature_cols = [c for c in df.columns if c not in exclude]
    data = df[feature_cols].values.astype(np.float32)
    labels = df["anomaly"].values.astype(np.float32) if has_labels else None
    return data, labels
Adjusting Model Dimensions
If your sensor data has a different number of channels, you only need to change num_features in config.py. The model automatically adjusts. For different sampling rates, adjust window_size—as a rule of thumb, your window should span roughly one “cycle” of the normal operating pattern. For a machine cycling every 5 seconds sampled at 100 Hz, use window_size=500. For slow processes (daily patterns at hourly sampling), use window_size=24.
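The rule of thumb above can be captured in a short sketch. Both helpers below (window_size_for and make_windows) are hypothetical names for illustration, not functions from dataset.py:

```python
import numpy as np

def window_size_for(cycle_seconds: float, sampling_hz: float) -> int:
    """Pick a window that spans roughly one cycle of the normal pattern."""
    return int(round(cycle_seconds * sampling_hz))

def make_windows(data: np.ndarray, window_size: int, stride: int) -> np.ndarray:
    """Slice a (timesteps, channels) array into overlapping windows."""
    starts = range(0, len(data) - window_size + 1, stride)
    return np.stack([data[s:s + window_size] for s in starts])

print(window_size_for(5.0, 100.0))   # 500: the 5-second cycle at 100 Hz from the text
windows = make_windows(np.zeros((1000, 3), dtype=np.float32), 500, 250)
print(windows.shape)                 # (3, 500, 3): windows x timesteps x channels
```

A 50% stride (stride = window_size // 2) is a common default: it doubles the number of training windows without letting consecutive windows be near-duplicates.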
Handling Class Imbalance
Real anomaly data is heavily imbalanced—often 1% anomalies or less. Three strategies that work well with this codebase:
- Weighted BCE loss: Replace BCEWithLogitsLoss() with BCEWithLogitsLoss(pos_weight=torch.tensor([19.0])), where 19.0 is the ratio of normal to anomaly samples.
- Focal loss: Down-weights easy negatives. Replace the BCE in AnomalyDetectionLoss.
- Oversampling: Use PyTorch's WeightedRandomSampler to oversample anomaly windows in the source training loader.
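The oversampling option can be sketched with PyTorch's built-in WeightedRandomSampler on a toy imbalanced dataset — tensor shapes and names here are illustrative, not the repository's actual loader:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Toy windowed dataset: 950 normal windows, 50 anomalous ones (5% anomaly rate).
x = torch.randn(1000, 128, 6)                       # windows x timesteps x channels
y = torch.cat([torch.zeros(950), torch.ones(50)])   # per-window labels

# Weight each sample inversely to its class frequency, so anomaly
# windows are drawn roughly as often as normal ones.
class_counts = torch.bincount(y.long())
sample_weights = 1.0 / class_counts[y.long()].float()
sampler = WeightedRandomSampler(sample_weights, num_samples=len(y), replacement=True)

loader = DataLoader(TensorDataset(x, y), batch_size=64, sampler=sampler)
xb, yb = next(iter(loader))
print(f"anomaly fraction in batch: {yb.mean().item():.2f}")  # roughly balanced
```

Note that replacement=True is required here: without it the sampler would exhaust the 50 anomaly windows almost immediately.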
Hyperparameter Tuning Guide
The hyperparameters listed below are ordered by sensitivity—tune the top ones first:
- lambda_domain (0.1–2.0): The most sensitive parameter. Too high causes the encoder to learn domain-invariant features that are useless for anomaly detection. Too low means no adaptation. Start at 0.5 and adjust.
- learning_rate (1e-4–1e-2): Standard neural network tuning. Use cosine annealing.
- window_size (32–256): Must capture enough context for anomalies to be visible.
- latent_dim (64–256): Larger gives more capacity but risks overfitting.
- alpha (0.5–0.9): Anomaly scoring mix. Higher alpha trusts the classifier more; lower trusts reconstruction error more.
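Because lambda_domain is the most sensitive knob, a progressive schedule is usually safer than a fixed value: the encoder first learns discriminative features, then gradually faces the adversarial domain loss. The sketch below implements the ramp-up curve from Ganin et al. (2016); the function name is illustrative:

```python
import math

def dann_lambda(epoch: int, total_epochs: int, gamma: float = 10.0) -> float:
    """Progressive GRL lambda from Ganin et al. (2016):
    lambda_p = 2 / (1 + exp(-gamma * p)) - 1, where p is training
    progress in [0, 1]. Starts at 0 and saturates near 1."""
    p = epoch / max(total_epochs - 1, 1)
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0

for e in (0, 25, 50, 99):
    print(f"epoch {e:3d}: lambda = {dann_lambda(e, 100):.3f}")
```

Multiplying this schedule by your chosen lambda_domain gives the effective domain-loss weight at each epoch; the zero-start is what prevents the instability described in the troubleshooting table below.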
Common Issues and Solutions
Domain adaptation training is notoriously finicky. Here is a reference table of problems you will likely encounter and how to fix them.
| Problem | Symptom | Cause | Solution |
|---|---|---|---|
| Discriminator mode collapse | Domain loss stays at ~0.69 (ln 2) | Discriminator predicts 0.5 for every sample, so the GRL passes no useful gradient back to the encoder | Increase discriminator LR; add more layers; reduce GRL lambda |
| Training instability | Loss oscillates wildly or diverges | Lambda too high too early | Use progressive lambda schedule; reduce learning rate; increase gradient clipping |
| Negative transfer | Target AUROC decreases with DA | Domains are too different or share no useful structure | Reduce lambda_domain; try CORAL (less aggressive); verify domains share anomaly types |
| High false positive rate | Good recall but terrible precision | Threshold too low; recon error noisy | Increase alpha (trust classifier more); use percentile threshold; add recon error smoothing |
| Source AUROC drops during DA | Classification degrades on source | Domain-invariant features lose discriminative power | Increase lambda_cls; reduce lambda_domain; train classifier longer before starting DA |
| Out of memory (GPU) | CUDA OOM error | Batch size or model too large | Reduce batch_size; reduce latent_dim; use gradient accumulation |
| MMD loss is NaN | NaN in training | Kernel bandwidth mismatch with feature scale | Normalize features; adjust kernel_bandwidths in config; add epsilon to kernel computation |
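The fixes in the last row (feature normalization, an epsilon guard) can be illustrated with a minimal multi-bandwidth RBF MMD² estimate. This is a standalone sketch, not the repository's exact MMD implementation:

```python
import torch

torch.manual_seed(0)

def mmd_rbf(x: torch.Tensor, y: torch.Tensor,
            bandwidths=(1.0, 2.0, 4.0), eps: float = 1e-8) -> torch.Tensor:
    """Biased multi-bandwidth RBF MMD^2 estimate between two feature batches.
    The eps guard and the normalization are the NaN safeguards from the table."""
    # Normalize both domains with *source* statistics, so a genuine mean
    # shift between domains is preserved while the overall scale is tamed.
    mu, sigma = x.mean(0), x.std(0) + eps
    x = (x - mu) / sigma
    y = (y - mu) / sigma
    z = torch.cat([x, y], dim=0)
    d2 = torch.cdist(z, z).pow(2)
    k = sum(torch.exp(-d2 / (2.0 * b * b + eps)) for b in bandwidths)
    n = x.size(0)
    kxx, kyy, kxy = k[:n, :n], k[n:, n:], k[:n, n:]
    return kxx.mean() + kyy.mean() - 2.0 * kxy.mean()

src = torch.randn(128, 64)
tgt = torch.randn(128, 64) + 1.0   # mean-shifted target domain
print(mmd_rbf(src, tgt).item())    # positive and finite
print(mmd_rbf(src, src).item())    # zero for identical batches
```

If your features have much larger magnitudes than the kernel bandwidths, every kernel evaluation underflows toward zero and gradients vanish before anything goes NaN — which is why normalizing features and tuning kernel_bandwidths go together.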
Conclusion
You now have a complete, end-to-end implementation of domain-adaptive time-series anomaly detection. Let us recap what we built and where to go next.
The nine scripts in this guide cover the full pipeline: generating realistic synthetic data with domain shift, building a CNN-LSTM encoder with multi-head outputs, implementing three different domain adaptation strategies (DANN, MMD, CORAL), training with progressive lambda scheduling, and evaluating with comprehensive metrics and diagnostic plots. Every script is complete and runnable as-is.
The core insight is simple but powerful: instead of requiring expensive labeled data in every new domain, you can train a model to learn domain-invariant features—representations that capture the essence of “anomaly” regardless of which machine, factory, or sensor produced the signal. The Gradient Reversal Layer is the elegant mechanism that makes this adversarial training possible in a single unified model, while MMD and CORAL offer simpler, more stable alternatives.
Where should you go from here? Three directions are most promising. First, semi-supervised adaptation: if you can label even 5–10% of the target domain data, you can add a supervised loss on those labeled target samples alongside the unsupervised domain alignment, dramatically improving results. Second, multi-source adaptation: if you have data from machines A, B, and C, you can adapt to machine D by combining knowledge from all three sources, not just one. Third, continual adaptation: in production, the target domain drifts over time as machines age and wear. Implement online or periodic re-adaptation to keep the model current.
Domain adaptation is not a silver bullet. It works best when domains share the same underlying anomaly mechanisms but differ in superficial signal characteristics—exactly the scenario in most industrial settings. When it works, it can save months of labeling effort and accelerate deployment of anomaly detection to new equipment. The code in this guide gives you everything you need to start experimenting with your own data today.
References
- Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F., Marchand, M., and Lempitsky, V. (2016). “Domain-Adversarial Training of Neural Networks.” Journal of Machine Learning Research, 17(59), 1-35.
- Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., and Smola, A. J. (2012). “A Kernel Two-Sample Test.” Journal of Machine Learning Research, 13, 723-773.
- Sun, B. and Saenko, K. (2016). “Deep CORAL: Correlation Alignment for Deep Domain Adaptation.” Proceedings of the European Conference on Computer Vision (ECCV) Workshops.
- Ragab, M., Lu, Z., Chen, Z., Wu, M., Kwoh, C. K., and Li, X. (2023). “Time-Series Domain Adaptation: A Survey.” arXiv preprint.
- Chalapathy, R. and Chawla, S. (2019). “Deep Learning for Anomaly Detection: A Survey.” arXiv preprint.
- PyTorch Documentation. “Extending torch.autograd — Custom Function.”