import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import math
from ice.anomaly_detection.models.base import BaseAnomalyDetection
"""
The code of Anomaly Transformer is taken from official implementation:
https://github.com/thuml/Anomaly-Transformer
"""


class TriangularCausalMask():
    def __init__(self, B, L, device="cpu"):
        mask_shape = [B, 1, L, L]
        with torch.no_grad():
            # True above the diagonal: positions that causal attention may not attend to.
            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)

    @property
    def mask(self):
        return self._mask


class AnomalyAttention(nn.Module):
    def __init__(self, win_size, mask_flag=True, scale=None, attention_dropout=0.0,
                 output_attention=True, device="cpu"):
        super(AnomalyAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        window_size = win_size
        # Pre-computed |i - j| distance matrix used to build the Gaussian prior.
        self.distances = torch.zeros((window_size, window_size)).to(device)
        for i in range(window_size):
            for j in range(window_size):
                self.distances[i][j] = abs(i - j)

    def forward(self, queries, keys, values, sigma, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / math.sqrt(E)
        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)
        attn = scale * scores
        sigma = sigma.transpose(1, 2)  # B L H -> B H L
        window_size = attn.shape[-1]
        # Map sigma to a positive range before building the Gaussian kernel.
        sigma = torch.sigmoid(sigma * 5) + 1e-5
        sigma = torch.pow(3, sigma) - 1
        sigma = sigma.unsqueeze(-1).repeat(1, 1, 1, window_size)  # B H L L
        # Prior association: a Gaussian over relative distance with learned sigma.
        prior = self.distances.unsqueeze(0).unsqueeze(0).repeat(sigma.shape[0], sigma.shape[1], 1, 1).to(queries.device)
        prior = 1.0 / (math.sqrt(2 * math.pi) * sigma) * torch.exp(-prior ** 2 / 2 / (sigma ** 2))
        # Series association: the usual softmax attention distribution.
        series = self.dropout(torch.softmax(attn, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", series, values)
        if self.output_attention:
            return (V.contiguous(), series, prior, sigma)
        else:
            return (V.contiguous(), None)


class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super(AttentionLayer, self).__init__()
        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)
        self.norm = nn.LayerNorm(d_model)
        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.sigma_projection = nn.Linear(d_model, n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads
        x = queries
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)
        sigma = self.sigma_projection(x).view(B, L, H)
        out, series, prior, sigma = self.inner_attention(
            queries,
            keys,
            values,
            sigma,
            attn_mask
        )
        out = out.view(B, L, -1)
        return self.out_projection(out), series, prior, sigma


class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.requires_grad = False
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Return the encodings for the first x.size(1) positions: [1, L, d_model].
        return self.pe[:, :x.size(1)]


class TokenEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # torch >= 1.5.0 needs padding=1 to keep the sequence length with
        # kernel_size=3; a plain string comparison of versions misreads
        # releases such as '1.10.0', so the major/minor pair is parsed instead.
        major, minor = (int(v) for v in torch.__version__.split('.')[:2])
        padding = 1 if (major, minor) >= (1, 5) else 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        # [B, L, c_in] -> [B, c_in, L] for Conv1d, then back to [B, L, d_model].
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x


class DataEmbedding(nn.Module):
    def __init__(self, c_in, d_model, dropout=0.0):
        super(DataEmbedding, self).__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.value_embedding(x) + self.position_embedding(x)
        return self.dropout(x)
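

# Illustrative only, not part of the original module: a quick shape sketch for
# the embedding stack. Batch size, window length and feature count below are
# arbitrary assumptions.
def _demo_embedding_shapes():
    emb = DataEmbedding(c_in=10, d_model=64)
    x = torch.randn(2, 100, 10)           # [B, L, c_in]
    out = emb(x)                          # token conv + positional encoding
    assert out.shape == (2, 100, 64)      # [B, L, d_model]
    return out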


class EncoderLayer(nn.Module):
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        # The attention layer returns the series association, the Gaussian prior
        # and sigma alongside the attended values.
        new_x, series, prior, sigma = self.attention(
            x, x, x,
            attn_mask=attn_mask
        )
        x = x + self.dropout(new_x)
        y = x = self.norm1(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        return self.norm2(x + y), series, prior, sigma


class Encoder(nn.Module):
    def __init__(self, attn_layers, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        # x: [B, L, D]
        series_list = []
        prior_list = []
        sigma_list = []
        for attn_layer in self.attn_layers:
            x, series, prior, sigma = attn_layer(x, attn_mask=attn_mask)
            series_list.append(series)
            prior_list.append(prior)
            sigma_list.append(sigma)
        if self.norm is not None:
            x = self.norm(x)
        return x, series_list, prior_list, sigma_list
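

# Sketch, not part of the original module: the series/prior lists returned by
# the Encoder are what Anomaly Transformer scores with an association
# discrepancy, i.e. a symmetric KL divergence between the learned series
# association and the row-normalised Gaussian prior, averaged over layers.
# The helper below follows the formulation in the official repository but is
# only an illustration; it ignores the min-max (detach) training strategy, and
# the name `_association_discrepancy` is introduced here, not by the library.
def _association_discrepancy(series_list, prior_list, eps=1e-4):
    def _kl(p, q):
        # KL(p || q), summed over the key axis and averaged over heads.
        return torch.mean(torch.sum(p * (torch.log(p + eps) - torch.log(q + eps)), dim=-1), dim=1)

    disc = 0.0
    for series, prior in zip(series_list, prior_list):
        # Normalise each prior row into a proper probability distribution.
        prior = prior / torch.sum(prior, dim=-1, keepdim=True)
        disc = disc + _kl(series, prior) + _kl(prior, series)
    return disc / len(series_list)  # [B, L]: per-position discrepancy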


class AnomalyTransformerModel(nn.Module):
    def __init__(
        self,
        win_size,
        enc_in,
        c_out,
        d_model,
        n_heads,
        e_layers,
        d_ff,
        dropout,
        activation,
        device
    ):
        super(AnomalyTransformerModel, self).__init__()
        # Encoding
        self.embedding = DataEmbedding(enc_in, d_model, dropout)
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        AnomalyAttention(win_size, False, attention_dropout=dropout, device=device),
                        d_model, n_heads),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation
                ) for _ in range(e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(d_model)
        )
        self.projection = nn.Linear(d_model, c_out, bias=True)

    def forward(self, x):
        enc_out = self.embedding(x)
        enc_out, series, prior, sigmas = self.encoder(enc_out)
        enc_out = self.projection(enc_out)
        return enc_out  # [B, L, D]
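

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): every hyperparameter below is an
    # arbitrary assumption, not a default of this library.
    win_size, n_features = 100, 10
    model = AnomalyTransformerModel(
        win_size=win_size,
        enc_in=n_features,
        c_out=n_features,
        d_model=64,
        n_heads=4,
        e_layers=2,
        d_ff=128,
        dropout=0.1,
        activation="gelu",
        device="cpu",
    )
    x = torch.randn(8, win_size, n_features)  # [B, L, enc_in]
    recon = model(x)                          # reconstruction, [B, L, c_out]
    print(recon.shape)                        # torch.Size([8, 100, 10])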