We train embedding tables with fewer parameters by combining multiple sketches of the same data, iteratively.
Henry Tsang and Thomas Ahle
Embedding tables are used by machine learning systems to work with categorical features. In modern Recommendation Systems, these tables can be very large, necessitating the development of new methods for fitting them in memory, even during training. We propose Clustered Compositional Embeddings (CCE), which combine clustering-based compression, like quantization to codebooks, with dynamic methods like the Hashing Trick and Compositional Embeddings. Experimentally, CCE achieves the best of both worlds: the high compression rate of codebook-based quantization, while remaining dynamic like hashing-based methods, so it can be used during training. Theoretically, we prove that CCE is guaranteed to converge to the optimal codebook and give a tight bound on the number of iterations required.
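The Hashing Trick mentioned above can be sketched in a few lines. This is a minimal NumPy illustration, not code from the cce library: the class name `HashingTrickEmbedding`, the universal hash h(x) = ((a·x + b) mod p) mod m, and the sizes used are our own assumptions. Every id is mapped to one of `n_rows` shared rows, so the table has `n_rows × dim` parameters no matter how many ids there are, at the cost of unrelated ids colliding.

```python
import numpy as np

P = 2_147_483_647  # the Mersenne prime 2**31 - 1


class HashingTrickEmbedding:
    """Hypothetical sketch: any number of ids share one n_rows x dim table."""

    def __init__(self, n_rows, dim, seed=0):
        rng = np.random.default_rng(seed)
        # Parameters of a universal hash h(x) = ((a*x + b) mod P) mod n_rows.
        self.a = int(rng.integers(1, P))
        self.b = int(rng.integers(0, P))
        self.n_rows = n_rows
        self.table = rng.normal(scale=0.1, size=(n_rows, dim))

    def hash(self, ids):
        ids = np.asarray(ids, dtype=np.int64)  # ids assumed to lie in [0, 2**31)
        return (self.a * ids + self.b) % P % self.n_rows

    def __call__(self, ids):
        # Look up the (shared) row each id hashes to.
        return self.table[self.hash(ids)]


emb = HashingTrickEmbedding(n_rows=1_000, dim=16)
ids = np.arange(1_000_000)
print(emb(ids[:3]).shape)  # (3, 16)
print(emb.table.size)      # 16000 parameters instead of 16,000,000
```

Because the hash is random, the ids that end up sharing a row have nothing in common; CCE's clustering aims to replace such random collisions with data-dependent ones.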
import torch
import cce

class GMF(torch.nn.Module):
    """ A simple Generalized Matrix Factorization model """

    def __init__(self, n_users, n_items, dim, num_params):
        super().__init__()
        # Compressed user and item tables, both using CCE with 4 chunks
        self.user_embedding = cce.make_embedding(n_users, num_params, dim, 'cce', n_chunks=4)
        self.item_embedding = cce.make_embedding(n_items, num_params, dim, 'cce', n_chunks=4)

    def forward(self, user, item):
        user_emb = self.user_embedding(user)
        item_emb = self.item_embedding(item)
        # Predicted interaction probability: sigmoid of the user-item dot product
        return torch.sigmoid((user_emb * item_emb).sum(-1))

    def epoch_end(self):
        # Re-cluster both tables at the end of each training epoch
        self.user_embedding.cluster()
        self.item_embedding.cluster()
Other than the Clustered Compositional Embedding, the library also contains many other compressed embedding methods, such as ce.RobeEmbedding, ce.CompositionalEmbedding, ce.TensorTrainEmbedding and ce.DeepHashEmbedding.
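For comparison with the hashing trick, one well-known compositional scheme is the quotient-remainder trick of Shi et al. ("Compositional Embeddings Using Complementary Partitions for Memory-Efficient Recommendation Systems"). The sketch below is a hypothetical NumPy illustration of that idea; we have not checked that `ce.CompositionalEmbedding` works this way, and the class name and sizes are our own. Id x is split into x // m and x mod m, each indexing its own small table, and the two rows are combined by element-wise multiplication, one of the combination operations that paper considers.

```python
import numpy as np


class QREmbedding:
    """Hypothetical quotient-remainder compositional embedding."""

    def __init__(self, n_ids, n_buckets, dim, seed=0):
        rng = np.random.default_rng(seed)
        self.m = n_buckets
        n_quotients = -(-n_ids // n_buckets)  # ceil(n_ids / n_buckets)
        self.q_table = rng.normal(scale=0.1, size=(n_quotients, dim))
        self.r_table = rng.normal(scale=0.1, size=(n_buckets, dim))

    def num_params(self):
        return self.q_table.size + self.r_table.size

    def __call__(self, ids):
        ids = np.asarray(ids, dtype=np.int64)
        # Every id gets a unique (quotient, remainder) pair, so unlike the
        # hashing trick no two ids share exactly the same pair of rows.
        return self.q_table[ids // self.m] * self.r_table[ids % self.m]


emb = QREmbedding(n_ids=1_000_000, n_buckets=1_000, dim=16)
print(emb([0, 1, 1_000]).shape)  # (3, 16)
print(emb.num_params())          # (1000 + 1000) * 16 = 32000
```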
We adapted the Deep Learning Recommendation Model (DLRM) to use CCE. Even when reducing the number of parameters by a factor of 8,500, we obtained the same test loss (binary cross-entropy) as the full DLRM model.
Note that previous compressed training methods were not significantly better than simply using the hashing trick. Also note that most compressed embeddings actually achieve a lower loss than the full embedding table of the baseline. This indicates that the reduced number of parameters acts as a useful regularizer. However, post-training compression methods, like Product Quantization, are unable to take advantage of this.
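The test loss in these comparisons is binary cross-entropy, BCE = -(1/N) Σ [y log p + (1 - y) log(1 - p)], averaged over examples with predicted probability p and 0/1 label y. A small stdlib-only sketch follows; the function name and the clamping constant `eps` are our own choices.

```python
import math


def binary_cross_entropy(probs, labels, eps=1e-12):
    """Mean binary cross-entropy between predicted probabilities and 0/1 labels."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clamp so log() never sees 0
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(probs)


print(binary_cross_entropy([0.5, 0.5], [1, 0]))  # log(2), about 0.6931
```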
Single iteration of CCE:
For more details, see cce/cce.py in the GitHub repository.
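CCE alternates hashing with clustering to learn a data-dependent hash function. Below is a minimal NumPy sketch of one clustering step under stated assumptions; it is not the code in cce/cce.py. The names `CCESketch` and `kmeans` are hypothetical, the method `cluster` is only named after the `.cluster()` call in the GMF example, and the layout in which each chunk sums a clustered table `A` and a freshly hashed, zero-reset table `B` is our assumption. The gradient-descent training that would happen between calls to `cluster()` is omitted.

```python
import numpy as np


def kmeans(x, k, n_iter=20, seed=0):
    """Plain Lloyd's k-means on the rows of x; returns (centroids [k, d], assignment [n])."""
    rng = np.random.default_rng(seed)
    centroids = x[rng.choice(len(x), size=k, replace=False)]  # needs len(x) >= k
    for _ in range(n_iter):
        dists = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)  # [n, k]
        assign = dists.argmin(1)
        for j in range(k):
            members = x[assign == j]
            if len(members):  # an empty cluster keeps its previous centroid
                centroids[j] = members.mean(0)
    dists = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
    return centroids, dists.argmin(1)


class CCESketch:
    """Hypothetical layout: chunk c of id i is A[c][h[c][i]] + B[c][g[c][i]], where
    h[c] is the learned (clustered) hash and g[c] a random hash, both stored as arrays."""

    def __init__(self, n_ids, n_rows, dim, n_chunks, seed=0):
        self.rng = np.random.default_rng(seed)
        self.n_ids, self.n_rows = n_ids, n_rows
        sub_dim = dim // n_chunks

        def rand_hash():
            return self.rng.integers(0, n_rows, n_ids)

        def rand_table():
            return self.rng.normal(scale=0.1, size=(n_rows, sub_dim))

        self.h = [rand_hash() for _ in range(n_chunks)]  # random until first clustering
        self.g = [rand_hash() for _ in range(n_chunks)]
        self.A = [rand_table() for _ in range(n_chunks)]
        self.B = [rand_table() for _ in range(n_chunks)]

    def __call__(self, ids):
        ids = np.asarray(ids)
        parts = [A[h[ids]] + B[g[ids]] for A, B, h, g in zip(self.A, self.B, self.h, self.g)]
        return np.concatenate(parts, axis=-1)

    def cluster(self):
        """One clustering step, applied to each chunk independently."""
        for c in range(len(self.A)):
            # 1. Current sub-vectors of every id (materialized here for simplicity).
            sub = self.A[c][self.h[c]] + self.B[c][self.g[c]]
            # 2. Group them into n_rows clusters.
            centroids, assign = kmeans(sub, self.n_rows, seed=c)
            # 3. The assignment becomes the new data-dependent hash, the centroids its table.
            self.h[c], self.A[c] = assign, centroids
            # 4. Fresh random hash with a zeroed table for the next round of training.
            self.g[c] = self.rng.integers(0, self.n_rows, self.n_ids)
            self.B[c] = np.zeros_like(self.B[c])


emb = CCESketch(n_ids=1_000, n_rows=16, dim=8, n_chunks=2)
emb.cluster()  # during training, this would run e.g. at the end of each epoch
vecs = emb(np.arange(1_000))
print(vecs.shape)  # (1000, 8)
```

The k-means here is plain Lloyd's algorithm with a fixed number of iterations, chosen only for brevity, and it assumes `n_ids >= n_rows`.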
Context: Modern Recommendation Systems require large embedding tables, which are challenging to fit in memory during training.
Solution: CCE combines hashing/sketching methods with clustering during training, to learn an efficient sparse, data-dependent hash function.
Contributions:
@inproceedings{tsang2023clustering,
  title={Clustering Embedding Tables, Without First Learning Them},
  author={Tsang, Henry Ling-Hei and Ahle, Thomas Dybdahl},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2023}
}