Cross Attention is a popular method for retrieving information from a set of context tokens to make predictions. At inference time, for each prediction, Cross Attention scans the full set of $\mathcal{O}(N)$ tokens. In practice, however, only a small subset of tokens is often required for good performance. Methods such as Perceiver IO are cheap at inference because they distill the information into a smaller set of $L < N$ latent tokens on which cross attention is then applied, resulting in only $\mathcal{O}(L)$ complexity. However, as the number of input tokens and the amount of information to distill increase, the number of latent tokens needed also increases significantly. In this work, we propose Tree Cross Attention (TCA), a module based on Cross Attention that retrieves information from only a logarithmic, $\mathcal{O}(\log(N))$, number of tokens when performing inference. TCA organizes the data in a tree structure and performs a tree search at inference time to retrieve the relevant tokens for prediction. Leveraging TCA, we introduce ReTreever, a flexible architecture for token-efficient inference. We show empirically that TCA performs comparably to Cross Attention across various classification and uncertainty regression tasks while being significantly more token-efficient. Furthermore, we compare ReTreever against Perceiver IO, showing significant gains while using the same number of tokens for inference.
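To make the token-count contrast between full Cross Attention and tree-based retrieval concrete, here is a minimal NumPy sketch. It is an illustration under stated assumptions, not the authors' implementation. The tree is built by recursively halving the token array. Each internal node is summarised by averaging its two children. The search greedily follows the child with the higher dot-product score (the abstract only says TCA "performs a tree search"). The retrieved set is the final leaf plus the summary of every sibling the search passed over. The names `Node`, `build_tree`, and `tree_retrieve` are hypothetical.

```python
import numpy as np


def softmax(x):
    # Numerically stable softmax over the last axis.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def cross_attention(q, keys, values):
    # Single-query scaled dot-product attention over every row of `keys`.
    weights = softmax(keys @ q / np.sqrt(q.shape[-1]))
    return weights @ values


class Node:
    # Hypothetical tree node: `emb` summarises all tokens beneath it.
    def __init__(self, emb, left=None, right=None):
        self.emb = emb
        self.left = left
        self.right = right


def build_tree(tokens):
    # Assumption: balanced binary split; parent summary = average of child summaries.
    if len(tokens) == 1:
        return Node(tokens[0])
    mid = len(tokens) // 2
    left, right = build_tree(tokens[:mid]), build_tree(tokens[mid:])
    return Node((left.emb + right.emb) / 2, left, right)


def tree_retrieve(q, root):
    # Assumption: greedy descent toward the higher-scoring child; the sibling that
    # is not expanded contributes its summary so every token stays covered.
    retrieved = []
    node = root
    while node.left is not None:
        if q @ node.left.emb >= q @ node.right.emb:
            retrieved.append(node.right.emb)
            node = node.left
        else:
            retrieved.append(node.left.emb)
            node = node.right
    retrieved.append(node.emb)  # the leaf token reached by the search
    return np.stack(retrieved)


rng = np.random.default_rng(0)
N, d = 1024, 16
tokens = rng.normal(size=(N, d))
q = rng.normal(size=d)

full = cross_attention(q, tokens, tokens)          # attends to all N tokens
retrieved = tree_retrieve(q, build_tree(tokens))
approx = cross_attention(q, retrieved, retrieved)  # attends to O(log N) nodes
print(tokens.shape[0], retrieved.shape[0])         # prints: 1024 11
```

With N = 1024 tokens, full cross attention attends to all 1024 rows, while the sketch attends to 11: one sibling summary for each of the 10 levels of the tree plus the final leaf. Keeping the unexplored siblings' summaries, rather than discarding them, means the retrieved set in this sketch still covers every token, at a coarser granularity the farther it lies from the search path.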
BibTeX
@misc{feng2023tree,
title={Tree Cross Attention},
author={Leo Feng and Frederick Tung and Hossein Hajimirsadeghi and Yoshua Bengio and Mohamed Osama Ahmed},
year={2023},
eprint={2309.17388},
archivePrefix={arXiv},
primaryClass={cs.LG}
}