import math
from einops import rearrange, reduce
import torch
import torch.nn as nn
from torch.autograd import Function
import torch.nn.functional as F


class DifferentiableEntropyFunction(Function):
    """Entropy of the empirical distribution of binary code indices, with a custom
    backward pass that routes the entropy gradient back to the codes."""

    @staticmethod
    def forward(ctx, zq, basis, K, eps):
        zb = (zq + 1) / 2
        zi = ((zb * basis).sum(-1)).to(torch.int64)
        cnt = torch.scatter_reduce(torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype),
                                   0,
                                   zi.flatten(),
                                   torch.ones_like(zi.flatten()).to(zq.dtype),
                                   'sum')
        prob = (cnt + eps) / (cnt + eps).sum()
        H = -(prob * torch.log(prob)).sum()
        ctx.save_for_backward(zq, zi, prob)
        ctx.K = K
        return H

    @staticmethod
    def backward(ctx, grad_output):
        zq, zi, prob = ctx.saved_tensors
        grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
        reord_grad = grad_array[zi.flatten()].reshape(zi.shape)
        grad_input = reord_grad.unsqueeze(-1) * zq
        return grad_input, None, None, None, None


def codebook_entropy(zq, basis, K, eps=1e-4):
    return DifferentiableEntropyFunction.apply(zq, basis, K, eps)
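

# Minimal usage sketch of codebook_entropy (illustrative, not part of the original
# module): a hypothetical 4-bit code, checking that the entropy estimate is
# differentiable with respect to the codes.
def _demo_codebook_entropy():
    K = 4
    basis = 2 ** torch.arange(K - 1, -1, -1)
    zq = torch.randint(0, 2, (8, 16, K)).float() * 2 - 1  # random codes in {-1, 1}
    zq.requires_grad_(True)
    H = codebook_entropy(zq, basis, K)
    H.backward()
    print(f"entropy={H.item():.4f}, grad shape={tuple(zq.grad.shape)}")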


class BinarySphericalQuantizer(nn.Module):
    def __init__(self, embed_dim, beta, gamma0, gamma, zeta,
                 input_format='bchw',
                 soft_entropy=True, group_size=9,
                 persample_entropy_compute='analytical',
                 cb_entropy_compute='group',
                 l2_norm=True,
                 inv_temperature=1):
        """
        Paper link: https://arxiv.org/pdf/2406.07548.pdf
        This follows the official implementation of the BinarySphericalQuantizer.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.beta = beta      # loss weight for the commit loss
        self.gamma0 = gamma0  # loss weight for the per-sample entropy term
        self.gamma = gamma    # loss weight for the codebook entropy term
        self.zeta = zeta      # loss weight for the overall entropy penalty
        self.input_format = input_format
        assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size"
        self.num_groups = self.embed_dim // group_size
        self.group_size = group_size
        assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'"
        assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'"
        self.persample_entropy_compute = persample_entropy_compute
        self.cb_entropy_compute = cb_entropy_compute
        self.l2_norm = l2_norm
        self.inv_temperature = inv_temperature
        self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1))
        self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1))
        self.num_dimensions = 2 ** embed_dim
        self.bits_per_index = embed_dim
        # we only need to keep the codebook portion up to the group size
        # because we approximate the H loss with this subcode
        group_codes = torch.arange(2 ** self.group_size)
        group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:]
        self.register_buffer('group_codebook', group_codebook, persistent=False)
        self.soft_entropy = soft_entropy  # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf

    def quantize(self, z):
        assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}"
        zhat = torch.where(z > 0,
                           torch.tensor(1, dtype=z.dtype, device=z.device),
                           torch.tensor(-1, dtype=z.dtype, device=z.device))
        return z + (zhat - z).detach()  # straight-through estimator

    def forward(self, z):
        # if self.input_format == 'bchw':
        #     z = rearrange(z, 'b c h w -> b h w c')
        zq = self.quantize(z)
        indices = self.codes_to_indexes(zq.detach())
        group_indices = self.codes_to_group_indexes(zq.detach())
        if not self.training:
            used_codes = torch.unique(indices, return_counts=False)
        else:
            used_codes = None

        q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.

        if self.soft_entropy:
            persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z)
            entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
        else:
            zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
            persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample)
            cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim)
            entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
            avg_prob = None  # not computed in the hard-entropy path

        zq = zq * q_scale

        # commit loss
        commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))

        # if self.input_format == 'bchw':
        #     zq = rearrange(zq, 'b h w c -> b c h w')

        return (
            zq,
            commit_loss + self.zeta * entropy_penalty / self.inv_temperature,
            {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices,
             "avg_prob": avg_prob}
        )

    def soft_entropy_loss(self, z):
        # if we divide the code into subgroups of size group_size, the codebook will be of size 2 ** group_size
        # the sub-code is the last group_size bits of the full code
        group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1)
        divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size)

        # we calculate the distance between the divided_z and the codebook for each subgroup
        distance = - 2 * torch.einsum('... g c, d c -> ... g d', divided_z, group_code_book)
        prob = (-distance * self.inv_temperature).softmax(dim=-1)
        if self.persample_entropy_compute == 'analytical':
            if self.l2_norm:
                p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature)
            else:
                p = torch.sigmoid(-4 * z * self.inv_temperature)
            prob = torch.stack([p, 1 - p], dim=-1)
            per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
        else:
            per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()

        # macro average of the probability of each subgroup
        avg_prob = reduce(prob, '... g d -> g d', 'mean')
        codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)

        # the approximation of the entropy is the sum of the entropy of each subgroup
        return per_sample_entropy, codebook_entropy.sum(), avg_prob

    def get_hard_per_sample_entropy(self, zb_by_sample):
        probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1]
        persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8)
        persample_entropy = persample_entropy.sum(-1)
        return persample_entropy.mean()

    def codes_to_indexes(self, zhat):
        """Converts a `code` to an index in the codebook.
        Args:
            zhat: A tensor of shape (B, ..., C) containing the codes; must be in {-1, 1}.
        """
        assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
        return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64)

    def codes_to_group_indexes(self, zhat):
        """Converts a `code` to a list of indexes (in groups) in the codebook.
        Args:
            zhat: A tensor of shape (B, ..., C) containing the codes; must be in {-1, 1}.
        """
        zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... g c', c=self.group_size)
        return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64)

    def indexes_to_codes(self, indices):
        """Inverse of `codes_to_indexes`."""
        indices = indices.unsqueeze(-1)
        codes_non_centered = torch.remainder(
            torch.floor_divide(indices, self.basis), 2
        )
        return codes_non_centered * 2 - 1

    def group_indexes_to_codes(self, group_indices):
        """Inverse of `codes_to_group_indexes`."""
        group_indices = group_indices.unsqueeze(-1)
        codes_non_centered = torch.remainder(
            torch.floor_divide(group_indices, self.group_basis), 2
        )
        codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)')
        return codes_non_centered * 2 - 1

    def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
        if normalize:
            probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
        else:
            probs = count
        H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
        return H

    def get_group_codebook_entry(self, group_indices):
        z_q = self.group_indexes_to_codes(group_indices)
        q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
        z_q = z_q * q_scale
        if self.input_format == 'bchw':
            h = w = int(z_q.shape[1] ** 0.5)
            assert h * w == z_q.shape[1], 'Invalid sequence length'
            z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
        return z_q

    def get_codebook_entry(self, indices):
        z_q = self.indexes_to_codes(indices)
        q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
        z_q = z_q * q_scale
        if self.input_format == 'bchw':
            h = w = int(z_q.shape[1] ** 0.5)
            assert h * w == z_q.shape[1], 'Invalid sequence length'
            z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
        return z_q
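

# Minimal usage sketch of BinarySphericalQuantizer (illustrative, not part of the
# original module). The hyperparameters below (embed_dim=18, beta=0.25, gamma0=1.0,
# gamma=1.0, zeta=1.0, group_size=9) are assumptions for the demo, not values
# prescribed by the paper.
def _demo_binary_spherical_quantizer():
    quantizer = BinarySphericalQuantizer(
        embed_dim=18, beta=0.25, gamma0=1.0, gamma=1.0, zeta=1.0, group_size=9)
    z = F.normalize(torch.randn(2, 64, 18), dim=-1)  # unit-norm latents, (B, L, C)
    zq, loss, info = quantizer(z)
    print(zq.shape, loss.item(), info["indices"].shape)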


class BSQuantizer(nn.Module):
    def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size):
        super().__init__()
        self.codebook_dim = s1_bits + s2_bits
        self.s1_bits = s1_bits
        self.s2_bits = s2_bits
        self.bsq = BinarySphericalQuantizer(self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size)

    def bits_to_indices(self, bits):
        bits = (bits >= 0).to(torch.long)
        indices = 2 ** torch.arange(0, bits.shape[-1], 1, dtype=torch.long, device=bits.device)
        return (bits * indices).sum(-1)

    def forward(self, z, half=False):
        z = F.normalize(z, dim=-1)
        quantized, bsq_loss, metrics = self.bsq(z)
        if half:
            q_pre = quantized[:, :, :self.s1_bits]
            q_post = quantized[:, :, self.s1_bits:]
            z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)]
        else:
            z_indices = self.bits_to_indices(quantized)
        return bsq_loss, quantized, z_indices
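

# Minimal usage sketch of BSQuantizer (illustrative, not part of the original module).
# The bit split (s1_bits=9, s2_bits=9) and group_size=9 are assumptions chosen so that
# s1_bits + s2_bits is divisible by group_size.
def _demo_bsq_quantizer():
    bsq = BSQuantizer(s1_bits=9, s2_bits=9, beta=0.25, gamma0=1.0, gamma=1.0, zeta=1.0, group_size=9)
    z = torch.randn(2, 64, 18)  # (B, L, s1_bits + s2_bits); normalized inside forward
    loss, quantized, z_indices = bsq(z, half=True)
    coarse_ids, fine_ids = z_indices  # each (B, L), indices over 2 ** 9 codes
    print(loss.item(), quantized.shape, coarse_ids.max().item() < 2 ** 9, fine_ids.shape)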


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
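

# Quick sanity sketch for RMSNorm (illustrative, not part of the original module):
# with the default unit weight, each normalized feature vector should have
# root-mean-square close to 1.
def _demo_rmsnorm():
    norm = RMSNorm(dim=32)
    x = torch.randn(4, 10, 32) * 5.0
    y = norm(x)
    rms = y.pow(2).mean(dim=-1).sqrt()
    print(torch.allclose(rms, torch.ones_like(rms), atol=1e-3))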


class FeedForward(nn.Module):
    def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0):
        super().__init__()
        self.w1 = nn.Linear(d_model, ff_dim, bias=False)
        self.w3 = nn.Linear(d_model, ff_dim, bias=False)
        self.w2 = nn.Linear(ff_dim, d_model, bias=False)
        self.ffn_dropout = nn.Dropout(ffn_dropout_p)

    def forward(self, x):
        # SwiGLU: silu(x W1) gated by (x W3), projected back with W2
        return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def _update_cos_sin_cache(self, x, seq_len):
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, None, :, :]
            self.sin_cached = emb.sin()[None, None, :, :]
        return self.cos_cached, self.sin_cached

    def forward(self, q, k):
        cos, sin = self._update_cos_sin_cache(q, q.shape[-2])
        return (
            (q * cos) + (self._rotate_half(q) * sin),
            (k * cos) + (self._rotate_half(k) * sin),
        )

    def _rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
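

# Quick sanity sketch for RotaryPositionalEmbedding (illustrative, not part of the
# original module): rotary embeddings are per-position rotations, so they should
# preserve the norm of each head vector.
def _demo_rope():
    rope = RotaryPositionalEmbedding(dim=64)
    q = torch.randn(2, 4, 16, 64)  # (batch, heads, seq_len, head_dim)
    k = torch.randn(2, 4, 16, 64)
    q_rot, k_rot = rope(q, k)
    print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4))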


def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                                 is_causal=False, scale=None, training=True) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device)
    if is_causal:
        # causal bias; an additional attn_mask (e.g. an expanded key-padding mask) is combined below
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).to(query.device)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    if attn_mask is not None:
        attn_mask_bias = torch.zeros_like(attn_weight)
        if attn_mask.dtype == torch.bool:
            attn_mask_bias.masked_fill_(attn_mask, float("-inf"))
        else:
            attn_mask_bias += attn_mask
        attn_weight += attn_mask_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=training)
    return attn_weight @ value
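

# Quick equivalence sketch (illustrative, not part of the original module): with
# dropout disabled, this function should match torch.nn.functional's
# scaled_dot_product_attention (PyTorch >= 2.0) on the same inputs, up to float error.
def _demo_sdpa_equivalence():
    q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))
    out_custom = scaled_dot_product_attention(q, k, v, is_causal=True, training=False)
    out_builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(torch.allclose(out_custom, out_builtin, atol=1e-4))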


class MultiHeadAttentionWithRoPE(nn.Module):
    def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.rotary = RotaryPositionalEmbedding(self.head_dim)
        self.attn_dropout_p = attn_dropout_p
        self.resid_dropout = nn.Dropout(resid_dropout_p)

    def forward(self, x, key_padding_mask=None):
        batch_size, seq_len, _ = x.shape
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        q, k = self.rotary(q, k)
        if key_padding_mask is not None:
            attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq_len]
            attn_mask = attn_mask.expand(-1, self.n_heads, seq_len, -1)  # [batch, n_heads, q_len, k_len]
        else:
            attn_mask = None
        attn_output = scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.attn_dropout_p,
            is_causal=True,
            training=self.training,
        )
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.resid_dropout(self.out_proj(attn_output))
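

# Minimal usage sketch of MultiHeadAttentionWithRoPE (illustrative, not part of the
# original module): causal self-attention over a padded batch. A boolean
# key_padding_mask with True at padded positions matches the masked_fill semantics in
# scaled_dot_product_attention above.
def _demo_self_attention():
    attn = MultiHeadAttentionWithRoPE(d_model=128, n_heads=4)
    x = torch.randn(2, 16, 128)
    key_padding_mask = torch.zeros(2, 16, dtype=torch.bool)
    key_padding_mask[:, 12:] = True  # last four positions are padding
    out = attn(x, key_padding_mask=key_padding_mask)
    print(out.shape)  # torch.Size([2, 16, 128])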


class MultiHeadCrossAttentionWithRoPE(nn.Module):
    def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout=0.0):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.rotary = RotaryPositionalEmbedding(self.head_dim)
        self.attn_dropout_p = attn_dropout_p
        self.resid_dropout = nn.Dropout(resid_dropout)

    def forward(self, query, key, value, key_padding_mask=None):
        batch_size, q_len, _ = query.shape
        _, seq_len, _ = key.shape
        q = self.q_proj(query).view(batch_size, q_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(key).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(value).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        # note: the rotary cache is built from the query length, so query and key are
        # expected to share the same sequence length
        q, k = self.rotary(q, k)
        if key_padding_mask is not None:
            attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
            attn_mask = attn_mask.expand(-1, self.n_heads, q_len, -1)
        else:
            attn_mask = None
        # causal masking is only applied during training
        is_causal_flag = self.training
        attn_output = scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.attn_dropout_p,
            is_causal=is_causal_flag,
            training=self.training,
        )
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
        return self.resid_dropout(self.out_proj(attn_output))


class HierarchicalEmbedding(nn.Module):
    def __init__(self, s1_bits, s2_bits, d_model=256):
        super().__init__()
        self.s1_bits = s1_bits
        self.s2_bits = s2_bits
        vocab_s1 = 2 ** s1_bits
        vocab_s2 = 2 ** s2_bits
        self.emb_s1 = nn.Embedding(vocab_s1, d_model)
        self.emb_s2 = nn.Embedding(vocab_s2, d_model)
        self.d_model = d_model
        self.fusion_proj = nn.Linear(d_model * 2, d_model)
        nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5)
        nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5)

    def split_token(self, token_ids: torch.Tensor, s2_bits: int):
        """Inputs:
            token_ids (torch.Tensor): Composite token IDs of shape [batch_size, seq_len] or [N],
                each in range [0, 2^(s1_bits + s2_bits) - 1].
            s2_bits (int): Number of low bits used for the fine token (s2).
        """
        assert isinstance(s2_bits, int) and s2_bits > 0, "s2_bits must be a positive integer"
        t = token_ids.long()
        mask = (1 << s2_bits) - 1
        s2_ids = t & mask      # extract low bits
        s1_ids = t >> s2_bits  # extract high bits
        return s1_ids, s2_ids

    def forward(self, token_ids):
        """Inputs:
            token_ids:
                - tuple or list: (s1_ids, s2_ids), each of shape [batch_size, seq_len], or
                - torch.Tensor: composite token IDs of shape [batch_size, seq_len], which will be
                  split into (s1_ids, s2_ids) internally.
        Output: [batch_size, seq_len, d_model]
        """
        if isinstance(token_ids, (tuple, list)):
            s1_ids, s2_ids = token_ids
        else:
            s1_ids, s2_ids = self.split_token(token_ids, self.s2_bits)
        s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model)
        s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model)
        return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1))
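

# Minimal usage sketch of HierarchicalEmbedding (illustrative, not part of the original
# module): splitting a composite token ID into coarse (s1) and fine (s2) parts and
# embedding them. With s1_bits=9 and s2_bits=9, composite IDs live in [0, 2 ** 18).
def _demo_hierarchical_embedding():
    embed = HierarchicalEmbedding(s1_bits=9, s2_bits=9, d_model=128)
    token_ids = torch.randint(0, 2 ** 18, (2, 16))
    s1_ids, s2_ids = embed.split_token(token_ids, s2_bits=9)
    assert torch.equal((s1_ids << 9) + s2_ids, token_ids)  # the split is lossless
    out = embed(token_ids)
    print(out.shape)  # torch.Size([2, 16, 128])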


class DependencyAwareLayer(nn.Module):
    def __init__(self, d_model, n_heads=4, attn_dropout_p=0.0, resid_dropout=0.0):
        super().__init__()
        self.cross_attn = MultiHeadCrossAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout)
        self.norm = RMSNorm(d_model)

    def forward(self, hidden_states, sibling_embed, key_padding_mask=None):
        """hidden_states: [batch, seq_len, d_model]
        sibling_embed: embedding from the other subtoken stream
        """
        attn_out = self.cross_attn(
            query=sibling_embed,
            key=hidden_states,
            value=hidden_states,
            key_padding_mask=key_padding_mask,
        )
        return self.norm(hidden_states + attn_out)


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.self_attn = MultiHeadAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout_p)
        self.norm2 = RMSNorm(d_model)
        self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p)

    def forward(self, x, key_padding_mask=None):
        residual = x
        x = self.norm1(x)
        attn_out = self.self_attn(x, key_padding_mask=key_padding_mask)
        x = residual + attn_out
        residual = x
        x = self.norm2(x)
        ffn_out = self.ffn(x)
        x = residual + ffn_out
        return x
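

# Minimal usage sketch of TransformerBlock (illustrative, not part of the original
# module): a pre-norm block (RMSNorm -> causal self-attention with RoPE -> RMSNorm ->
# SwiGLU feed-forward) applied to a random sequence.
def _demo_transformer_block():
    block = TransformerBlock(d_model=128, n_heads=4, ff_dim=256)
    x = torch.randn(2, 16, 128)
    y = block(x)
    print(y.shape)  # torch.Size([2, 16, 128])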


class DualHead(nn.Module):
    def __init__(self, s1_bits, s2_bits, d_model):
        super().__init__()
        self.vocab_s1 = 2 ** s1_bits
        self.vocab_s2 = 2 ** s2_bits
        self.proj_s1 = nn.Linear(d_model, self.vocab_s1)
        self.proj_s2 = nn.Linear(d_model, self.vocab_s2)

    def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None):
        if padding_mask is not None:
            valid_mask = (padding_mask == 0)
            s1_logits = s1_logits[valid_mask]
            s2_logits = s2_logits[valid_mask]
            s1_targets = s1_targets[valid_mask]
            s2_targets = s2_targets[valid_mask]
            ce_s1 = F.cross_entropy(s1_logits, s1_targets)
            ce_s2 = F.cross_entropy(s2_logits, s2_targets)
        else:
            ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1))
            ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1))
        ce_loss = (ce_s1 + ce_s2) / 2
        return ce_loss, ce_s1, ce_s2

    def forward(self, x):
        return self.proj_s1(x)

    def cond_forward(self, x2):
        return self.proj_s2(x2)
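

# Minimal usage sketch of DualHead (illustrative, not part of the original module):
# separate heads over the coarse (s1) and fine (s2) vocabularies, with the averaged
# cross-entropy computed on non-padded positions (padding_mask != 0 marks padding,
# matching the `padding_mask == 0` test in compute_loss).
def _demo_dual_head():
    head = DualHead(s1_bits=9, s2_bits=9, d_model=128)
    h = torch.randn(2, 16, 128)
    s1_logits, s2_logits = head(h), head.cond_forward(h)
    s1_targets = torch.randint(0, 2 ** 9, (2, 16))
    s2_targets = torch.randint(0, 2 ** 9, (2, 16))
    padding_mask = torch.zeros(2, 16, dtype=torch.long)
    padding_mask[:, 12:] = 1  # last four positions are padding
    ce_loss, ce_s1, ce_s2 = head.compute_loss(s1_logits, s2_logits, s1_targets, s2_targets, padding_mask)
    print(ce_loss.item(), ce_s1.item(), ce_s2.item())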


class FixedEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()
        # sinusoidal table, kept fixed (not trained)
        w = torch.zeros(c_in, d_model).float()
        w.requires_grad = False
        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        w[:, 0::2] = torch.sin(position * div_term)
        w[:, 1::2] = torch.cos(position * div_term)
        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        return self.emb(x).detach()


class TemporalEmbedding(nn.Module):
    def __init__(self, d_model, learn_pe):
        super(TemporalEmbedding, self).__init__()
        minute_size = 60
        hour_size = 24
        weekday_size = 7
        day_size = 32
        month_size = 13
        Embed = FixedEmbedding if not learn_pe else nn.Embedding
        self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def forward(self, x):
        # x: [batch, seq_len, 5] with integer (minute, hour, weekday, day, month) features
        x = x.long()
        minute_x = self.minute_embed(x[:, :, 0])
        hour_x = self.hour_embed(x[:, :, 1])
        weekday_x = self.weekday_embed(x[:, :, 2])
        day_x = self.day_embed(x[:, :, 3])
        month_x = self.month_embed(x[:, :, 4])
        return hour_x + weekday_x + day_x + month_x + minute_x
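

# Minimal usage sketch of TemporalEmbedding (illustrative, not part of the original
# module): the five calendar features are assumed to be ordered
# (minute, hour, weekday, day, month), matching the indexing in forward above.
def _demo_temporal_embedding():
    temb = TemporalEmbedding(d_model=128, learn_pe=False)
    minute = torch.randint(0, 60, (2, 16, 1))
    hour = torch.randint(0, 24, (2, 16, 1))
    weekday = torch.randint(0, 7, (2, 16, 1))
    day = torch.randint(1, 32, (2, 16, 1))
    month = torch.randint(1, 13, (2, 16, 1))
    x = torch.cat([minute, hour, weekday, day, month], dim=-1)
    print(temb(x).shape)  # torch.Size([2, 16, 128])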