TY - JOUR
T1 - Transformers learn to implement preconditioned gradient descent for in-context learning
AU - Ahn, Kwangjun
AU - Cheng, Xiang
AU - Daneshmand, Hadi
AU - Sra, Suvrit
N1 - Publisher Copyright:
© 2023 Neural information processing systems foundation. All rights reserved.
PY - 2023
Y1 - 2023
N2 - Several recent works demonstrate that transformers can implement algorithms like gradient descent. By a careful construction of weights, these works show that multiple layers of transformers are expressive enough to simulate iterations of gradient descent. Going beyond the question of expressivity, we ask: Can transformers learn to implement such algorithms by training over random problem instances? To our knowledge, we make the first theoretical progress on this question via an analysis of the loss landscape for linear transformers trained over random instances of linear regression. For a single attention layer, we prove the global minimum of the training objective implements a single iteration of preconditioned gradient descent. Notably, the preconditioning matrix not only adapts to the input distribution but also to the variance induced by data inadequacy. For a transformer with L attention layers, we prove certain critical points of the training objective implement L iterations of preconditioned gradient descent. Our results call for future theoretical studies on learning algorithms by training transformers.
UR - http://www.scopus.com/inward/record.url?scp=85205445511&partnerID=8YFLogxK
M3 - Conference article
AN - SCOPUS:85205445511
SN - 1049-5258
VL - 36
SP - 45614
EP - 45650
JO - Advances in Neural Information Processing Systems
JF - Advances in Neural Information Processing Systems
T2 - 37th Conference on Neural Information Processing Systems, NeurIPS 2023
Y2 - 10 December 2023 through 16 December 2023
ER -