Arrow Research search

Author name cluster

Kwangjun Ahn

Possible papers associated with this exact author name in Arrow. This page groups case-insensitive exact name matches and is not a full identity disambiguation profile.

19 papers
2 author rows

Possible papers

ICLR Conference 2025 Conference Paper

Does SGD really happen in tiny subspaces?

  • Minhak Song
  • Kwangjun Ahn
  • Chulhee Yun

Understanding the training dynamics of deep neural networks is challenging due to their high-dimensional nature and intricate loss landscapes. Recent studies have revealed that, along the training trajectory, the gradient approximately aligns with a low-rank top eigenspace of the training loss Hessian, referred to as the dominant subspace. Given this alignment, this paper explores whether neural networks can be trained within the dominant subspace, which, if feasible, could lead to more efficient training methods. Our primary observation is that when the SGD update is projected onto the dominant subspace, the training loss does not decrease further. This suggests that the observed alignment between the gradient and the dominant subspace is spurious. Surprisingly, projecting out the dominant subspace proves to be just as effective as the original update, despite removing the majority of the original update component. We observe similar behavior across practical setups, including the large learning rate regime (also known as Edge of Stability), Sharpness-Aware Minimization, momentum, and adaptive optimizers. We discuss the main causes and implications of this spurious alignment, shedding light on the dynamics of neural network training.
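The projection described in this abstract can be illustrated with a minimal sketch. It assumes a toy quadratic loss whose Hessian is known in closed form; the function name `split_update` and the matrix `H` are hypothetical and are not taken from the paper, which works with neural network losses.

```python
import numpy as np

def split_update(grad, hessian, k):
    """Split grad into its part inside the top-k Hessian eigenspace and the remainder."""
    # np.linalg.eigh returns eigenvalues of a symmetric matrix in ascending order
    eigvals, eigvecs = np.linalg.eigh(hessian)
    top = eigvecs[:, -k:]              # columns spanning the dominant subspace
    dominant = top @ (top.T @ grad)    # update projected onto the dominant subspace
    bulk = grad - dominant             # update with the dominant subspace projected out
    return dominant, bulk

# Toy quadratic loss 0.5 * w^T H w, whose Hessian equals H everywhere (illustrative only)
H = np.diag([100.0, 50.0, 1.0, 0.5])
w = np.ones(4)
grad = H @ w
dominant, bulk = split_update(grad, H, k=2)

# Fraction of the squared gradient norm that lies in the dominant subspace
dominant_fraction = float(dominant @ dominant) / float(grad @ grad)
```

In this toy case almost all of the gradient norm sits in the dominant subspace, which is the alignment the abstract refers to. Whether training progresses along `dominant` or along `bulk` is the empirical question the paper studies, and this sketch does not reproduce that result.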

ICML Conference 2025 Conference Paper

General framework for online-to-nonconvex conversion: Schedule-free SGD is also effective for nonconvex optimization

  • Kwangjun Ahn
  • Gagik Magakyan
  • Ashok Cutkosky

This work investigates the effectiveness of schedule-free methods, developed by A. Defazio et al. (NeurIPS 2024), in nonconvex optimization settings, inspired by their remarkable empirical success in training neural networks. Specifically, we show that schedule-free SGD achieves optimal iteration complexity for nonsmooth, non-convex optimization problems. Our proof begins with the development of a general framework for online-to-nonconvex conversion, which converts a given online learning algorithm into an optimization algorithm for nonconvex losses. Our general framework not only recovers existing conversions but also leads to two novel conversion schemes. Notably, one of these new conversions corresponds directly to schedule-free SGD, allowing us to establish its optimality. Additionally, our analysis provides valuable insights into the parameter choices for schedule-free SGD, addressing a theoretical gap that the convex theory cannot explain.

ICLR Conference 2025 Conference Paper

The Belief State Transformer

  • Edward S. Hu
  • Kwangjun Ahn
  • Qinghua Liu
  • Haoran Xu
  • Manan Tomar
  • Ada Langford
  • Dinesh Jayaraman
  • Alex Lamb

We introduce the "Belief State Transformer", a next-token predictor that takes both a prefix and suffix as inputs, with a novel objective of predicting both the next token for the prefix and the previous token for the suffix. The Belief State Transformer effectively learns to solve challenging problems that conventional forward-only transformers struggle with, in a domain-independent fashion. Key to this success is learning a compact belief state that captures all relevant information necessary for accurate predictions. Empirical ablations show that each component of the model is essential in difficult scenarios where standard Transformers fall short. For the task of story writing with known prefixes and suffixes, our approach outperforms the Fill-in-the-Middle method for reaching known goals and demonstrates improved performance even when the goals are unknown. Altogether, the Belief State Transformer enables more efficient goal-conditioned decoding, better test-time inference, and high-quality text representations on small scale problems. Website: https://edwhu.github.io/bst-website

NeurIPS Conference 2025 Conference Paper

Through the River: Understanding the Benefit of Schedule-Free Methods for Language Model Training

  • Minhak Song
  • Beomhan Baek
  • Kwangjun Ahn
  • Chulhee Yun

As both model and dataset sizes continue to scale rapidly, conventional pretraining strategies with fixed compute budgets (such as cosine learning rate schedules) are increasingly inadequate for large-scale training. Recent alternatives, including warmup-stable-decay (WSD) schedules and weight averaging, offer greater flexibility. However, WSD relies on explicit decay phases to track progress, while weight averaging addresses this limitation at the cost of additional memory. In search of a more principled and scalable alternative, we revisit the Schedule-Free (SF) method [Defazio et al., 2024], which has shown strong empirical performance across diverse settings. We show that SF-AdamW effectively navigates the "river" structure of the loss landscape without decay phases or auxiliary averaging, making it particularly suitable for continuously scaling training workloads. To understand this behavior, we conduct a theoretical and empirical analysis of SF dynamics, revealing that it implicitly performs weight averaging without memory overhead. Guided by this analysis, we propose a refined variant of SF that improves robustness to momentum and performs better under large batch sizes, addressing key limitations of the original method. Together, these results establish SF as a practical, scalable, and theoretically grounded approach for language model training.

NeurIPS Conference 2024 Conference Paper

Adam with model exponential moving average is effective for nonconvex optimization

  • Kwangjun Ahn
  • Ashok Cutkosky

In this work, we offer a theoretical analysis of two modern optimization techniques for training large and complex models: (i) adaptive optimization algorithms, such as Adam, and (ii) the model exponential moving average (EMA). Specifically, we demonstrate that a clipped version of Adam with model EMA achieves the optimal convergence rates in various nonconvex optimization settings, both smooth and nonsmooth. Moreover, when the scale varies significantly across different coordinates, we demonstrate that the coordinate-wise adaptivity of Adam is provably advantageous. Notably, unlike previous analyses of Adam, our analysis crucially relies on its core elements (momentum and discounting factors) as well as model EMA, motivating their wide applications in practice.

ICML Conference 2024 Conference Paper

How to Escape Sharp Minima with Random Perturbations

  • Kwangjun Ahn
  • Ali Jadbabaie
  • Suvrit Sra

Modern machine learning applications have witnessed the remarkable success of optimization algorithms that are designed to find flat minima. Motivated by this design choice, we undertake a formal study that (i) formulates the notion of flat minima, and (ii) studies the complexity of finding them. Specifically, we adopt the trace of the Hessian of the cost function as a measure of flatness, and use it to formally define the notion of approximate flat minima. Under this notion, we then analyze algorithms that find approximate flat minima efficiently. For general cost functions, we discuss a gradient-based algorithm that finds an approximate flat local minimum efficiently. The main component of the algorithm is to use gradients computed from randomly perturbed iterates to estimate a direction that leads to flatter minima. For the setting where the cost function is an empirical risk over training data, we present a faster algorithm that is inspired by a recently proposed practical algorithm called sharpness-aware minimization, supporting its success in practice.

ICLR Conference 2024 Conference Paper

Linear attention is (maybe) all you need (to understand Transformer optimization)

  • Kwangjun Ahn
  • Xiang Cheng
  • Minhak Song
  • Chulhee Yun
  • Ali Jadbabaie
  • Suvrit Sra

Transformer training is notoriously difficult, requiring a careful design of optimizers and use of various heuristics. We make progress towards understanding the subtleties of training Transformers by carefully studying a simple yet canonical linearized *shallow* Transformer model. Specifically, we train linear Transformers to solve regression tasks, inspired by J. von Oswald et al. (ICML 2023), and K. Ahn et al. (NeurIPS 2023). Most importantly, we observe that our proposed linearized models can reproduce several prominent aspects of Transformer training dynamics. Consequently, the results obtained in this paper suggest that a simple linearized Transformer model could actually be a valuable, realistic abstraction for understanding Transformer optimization.

ICML Conference 2024 Conference Paper

Understanding Adam Optimizer via Online Learning of Updates: Adam is FTRL in Disguise

  • Kwangjun Ahn
  • Zhiyu Zhang
  • Yunbum Kook
  • Yan Dai

Despite the success of the Adam optimizer in practice, the theoretical understanding of its algorithmic components still remains limited. In particular, most existing analyses of Adam show a convergence rate that can already be achieved by non-adaptive algorithms like SGD. In this work, we provide a different perspective based on online learning that underscores the importance of Adam’s algorithmic components. Inspired by Cutkosky et al. (2023), we consider the framework called online learning of updates/increments, where we choose the updates/increments of an optimizer based on an online learner. With this framework, the design of a good optimizer is reduced to the design of a good online learner. Our main observation is that Adam corresponds to a principled online learning framework called Follow-the-Regularized-Leader (FTRL). Building on this observation, we study the benefits of its algorithmic components from the online learning perspective.

JMLR Journal 2023 Journal Article

A Unified Approach to Controlling Implicit Regularization via Mirror Descent

  • Haoyuan Sun
  • Khashayar Gatmiry
  • Kwangjun Ahn
  • Navid Azizan

Inspired by the remarkable success of large neural networks, there has been significant interest in understanding the generalization performance of over-parameterized models. Substantial efforts have been invested in characterizing how optimization algorithms impact generalization through their "preferred" solutions, a phenomenon commonly referred to as implicit regularization. In particular, it has been argued that gradient descent (GD) induces an implicit $\ell_2$-norm regularization in regression and classification problems. However, the implicit regularization of different algorithms is confined to either a specific geometry or a particular class of learning problems, indicating a gap in a general approach for controlling the implicit regularization. To address this, we present a unified approach using mirror descent (MD), a notable generalization of GD, to control implicit regularization in both regression and classification settings. More specifically, we show that MD with the general class of homogeneous potential functions converges in direction to a generalized maximum-margin solution for linear classification problems, thereby answering a long-standing question in the classification setting. Further, we show that MD can be implemented efficiently and enjoys fast convergence under suitable conditions. Through comprehensive experiments, we demonstrate that MD is a versatile method to produce learned models with different regularizers, which in turn have different generalization performances.

NeurIPS Conference 2023 Conference Paper

Learning threshold neurons via edge of stability

  • Kwangjun Ahn
  • Sebastien Bubeck
  • Sinho Chewi
  • Yin Tat Lee
  • Felipe Suarez
  • Yi Zhang

Existing analyses of neural network training often operate under the unrealistic assumption of an extremely small learning rate. This lies in stark contrast to practical wisdom and empirical studies, such as the work of J. Cohen et al. (ICLR 2021), which exhibit startling new phenomena (the "edge of stability" or "unstable convergence") and potential benefits for generalization in the large learning rate regime. Despite a flurry of recent works on this topic, however, the latter effect is still poorly understood. In this paper, we take a step towards understanding genuinely non-convex training dynamics with large learning rates by performing a detailed analysis of gradient descent for simplified models of two-layer neural networks. For these models, we provably establish the edge of stability phenomenon and discover a sharp phase transition for the step size below which the neural network fails to learn "threshold-like" neurons (i.e., neurons with a non-zero first-layer bias). This elucidates one possible mechanism by which the edge of stability can in fact lead to better generalization, as threshold neurons are basic building blocks with useful inductive bias for many tasks.

NeurIPS Conference 2023 Conference Paper

The Crucial Role of Normalization in Sharpness-Aware Minimization

  • Yan Dai
  • Kwangjun Ahn
  • Suvrit Sra

Sharpness-Aware Minimization (SAM) is a recently proposed gradient-based optimizer (Foret et al., ICLR 2021) that greatly improves the prediction performance of deep neural networks. Consequently, there has been a surge of interest in explaining its empirical success. We focus, in particular, on understanding the role played by normalization, a key component of the SAM updates. We theoretically and empirically study the effect of normalization in SAM for both convex and non-convex functions, revealing two key roles played by normalization: i) it helps in stabilizing the algorithm; and ii) it enables the algorithm to drift along a continuum (manifold) of minima, a property identified by recent theoretical works that is the key to better performance. We further argue that these two properties of normalization make SAM robust against the choice of hyper-parameters, supporting the practicality of SAM. Our conclusions are backed by various experiments.
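For readers unfamiliar with the update, here is a minimal sketch of one SAM step with and without the normalization discussed above. It follows the commonly stated form of the update (perturb along the gradient, then descend using the gradient at the perturbed point). The helper names `sam_step` and `grad_quadratic` and the toy objective are assumptions for illustration and do not come from the paper.

```python
import math

def sam_step(w, grad_fn, lr, rho, normalize=True):
    """One SAM step: take the gradient at a perturbed point, then descend from w."""
    g = grad_fn(w)
    if normalize:
        # Normalized SAM: the perturbation has length rho regardless of gradient size
        norm = math.sqrt(sum(gi * gi for gi in g))
        scale = rho / norm if norm > 0 else 0.0
    else:
        # Un-normalized variant: the perturbation scales with the gradient
        scale = rho
    w_perturbed = [wi + scale * gi for wi, gi in zip(w, g)]
    g_perturbed = grad_fn(w_perturbed)
    return [wi - lr * gi for wi, gi in zip(w, g_perturbed)]

def grad_quadratic(w):
    # Gradient of the toy objective f(w) = 0.5 * (w[0]**2 + 10 * w[1]**2)
    return [w[0], 10.0 * w[1]]

w = [1.0, 1.0]
for _ in range(100):
    w = sam_step(w, grad_quadratic, lr=0.05, rho=0.01)
```

On this convex toy problem the iterate settles in a small neighborhood of the minimizer whose size is governed by `rho`. The stabilizing and drift effects analyzed in the paper concern richer settings than this sketch covers.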

NeurIPS Conference 2023 Conference Paper

Transformers learn to implement preconditioned gradient descent for in-context learning

  • Kwangjun Ahn
  • Xiang Cheng
  • Hadi Daneshmand
  • Suvrit Sra

Several recent works demonstrate that transformers can implement algorithms like gradient descent. By a careful construction of weights, these works show that multiple layers of transformers are expressive enough to simulate iterations of gradient descent. Going beyond the question of expressivity, we ask: Can transformers learn to implement such algorithms by training over random problem instances? To our knowledge, we make the first theoretical progress on this question via an analysis of the loss landscape for linear transformers trained over random instances of linear regression. For a single attention layer, we prove the global minimum of the training objective implements a single iteration of preconditioned gradient descent. Notably, the preconditioning matrix not only adapts to the input distribution but also to the variance induced by data inadequacy. For a transformer with $L$ attention layers, we prove certain critical points of the training objective implement $L$ iterations of preconditioned gradient descent. Our results call for future theoretical studies on learning algorithms by training transformers.

ICML Conference 2022 Conference Paper

Agnostic Learnability of Halfspaces via Logistic Loss

  • Ziwei Ji
  • Kwangjun Ahn
  • Pranjal Awasthi
  • Satyen Kale
  • Stefani Karp

We investigate approximation guarantees provided by logistic regression for the fundamental problem of agnostic learning of homogeneous halfspaces. Previously, for a certain broad class of “well-behaved” distributions on the examples, Diakonikolas et al. (2020) proved an $\tilde{\Omega}(\mathrm{OPT})$ lower bound, while Frei et al. (2021) proved an $\tilde{O}(\sqrt{\mathrm{OPT}})$ upper bound, where $\mathrm{OPT}$ denotes the best zero-one/misclassification risk of a homogeneous halfspace. In this paper, we close this gap by constructing a well-behaved distribution such that the global minimizer of the logistic risk over this distribution only achieves $\Omega(\sqrt{\mathrm{OPT}})$ misclassification risk, matching the upper bound in (Frei et al., 2021). On the other hand, we also show that if we impose a radial-Lipschitzness condition in addition to well-behaved-ness on the distribution, logistic regression on a ball of bounded radius reaches $\tilde{O}(\mathrm{OPT})$ misclassification risk. Our techniques also show for any well-behaved distribution, regardless of radial Lipschitzness, we can overcome the $\Omega(\sqrt{\mathrm{OPT}})$ lower bound for logistic loss simply at the cost of one additional convex optimization step involving the hinge loss and attain $\tilde{O}(\mathrm{OPT})$ misclassification risk. This two-step convex optimization algorithm is simpler than previous methods obtaining this guarantee, all of which require solving $O(\log(1/\mathrm{OPT}))$ minimization problems.

NeurIPS Conference 2022 Conference Paper

Mirror Descent Maximizes Generalized Margin and Can Be Implemented Efficiently

  • Haoyuan Sun
  • Kwangjun Ahn
  • Christos Thrampoulidis
  • Navid Azizan

Driven by the empirical success and wide use of deep neural networks, understanding the generalization performance of overparameterized models has become an increasingly popular question. To this end, there has been substantial effort to characterize the implicit bias of the optimization algorithms used, such as gradient descent (GD), and the structural properties of their preferred solutions. This paper answers an open question in this literature: For the classification setting, what solution does mirror descent (MD) converge to? Specifically, motivated by its efficient implementation, we consider the family of mirror descent algorithms with potential function chosen as the $p$-th power of the $\ell_p$-norm, which is an important generalization of GD. We call this algorithm $p$-$\textsf{GD}$. For this family, we characterize the solutions it obtains and show that it converges in direction to a generalized maximum-margin solution with respect to the $\ell_p$-norm for linearly separable classification. While the MD update rule is in general expensive to compute and not suitable for deep learning, $p$-$\textsf{GD}$ is fully parallelizable in the same manner as SGD and can be used to train deep neural networks with virtually no additional computational overhead. Using comprehensive experiments with both linear and deep neural network models, we demonstrate that $p$-$\textsf{GD}$ can noticeably affect the structure and the generalization performance of the learned models.
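The coordinate-wise nature of the update can be seen in a minimal sketch. It assumes the potential is scaled as $\psi(w) = \frac{1}{p}\|w\|_p^p$ with $p > 1$ (the $1/p$ factor is an assumption here), so that the mirror map acts on each coordinate as $\mathrm{sign}(w_i)|w_i|^{p-1}$. The function name `p_gd_step` is hypothetical.

```python
import math

def p_gd_step(w, grad, lr, p):
    """One mirror descent step with potential (1/p) * ||w||_p^p, assuming p > 1."""
    new_w = []
    for wi, gi in zip(w, grad):
        # Map to the dual space, take a gradient step there
        z = math.copysign(abs(wi) ** (p - 1), wi) - lr * gi
        # Map back with the inverse of the mirror map
        new_w.append(math.copysign(abs(z) ** (1.0 / (p - 1)), z))
    return new_w

# With p = 2 the mirror map is the identity and the step is plain gradient descent
gd_like = p_gd_step([1.0, -2.0], [0.5, 0.5], lr=0.1, p=2)
# With p = 3 the same gradient produces a different step
cubic = p_gd_step([2.0], [1.0], lr=1.0, p=3)
```

Each coordinate is updated independently, which is why the abstract describes the method as parallelizable in the same manner as SGD.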

NeurIPS Conference 2022 Conference Paper

Reproducibility in Optimization: Theoretical Framework and Limits

  • Kwangjun Ahn
  • Prateek Jain
  • Ziwei Ji
  • Satyen Kale
  • Praneeth Netrapalli
  • Gil I Shamir

We initiate a formal study of reproducibility in optimization. We define a quantitative measure of reproducibility of optimization procedures in the face of noisy or error-prone operations such as inexact or stochastic gradient computations or inexact initialization. We then analyze several convex optimization settings of interest such as smooth, non-smooth, and strongly-convex objective functions and establish tight bounds on the limits of reproducibility in each setting. Our analysis reveals a fundamental trade-off between computation and reproducibility: more computation is necessary (and sufficient) for better reproducibility.

ICML Conference 2022 Conference Paper

Understanding the unstable convergence of gradient descent

  • Kwangjun Ahn
  • Jingzhao Zhang
  • Suvrit Sra

Most existing analyses of (stochastic) gradient descent rely on the condition that for $L$-smooth costs, the step size is less than $2/L$. However, many works have observed that in machine learning applications step sizes often do not fulfill this condition, yet (stochastic) gradient descent still converges, albeit in an unstable manner. We investigate this unstable convergence phenomenon from first principles, and discuss key causes behind it. We also identify its main characteristics, and how they interrelate based on both theory and experiments, offering a principled view toward understanding the phenomenon.
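The classical $2/L$ threshold mentioned above is easy to check on a one-dimensional quadratic. The sketch below assumes the toy cost $f(x) = \frac{L}{2}x^2$, for which gradient descent multiplies the iterate by $(1 - \eta L)$ at every step; the function name `run_gd` is hypothetical. It shows only the textbook threshold, not the unstable-convergence behavior on non-quadratic losses that the paper investigates.

```python
def run_gd(step_size, smoothness=1.0, x0=1.0, num_steps=50):
    """Run gradient descent on f(x) = 0.5 * smoothness * x**2 and return the last iterate."""
    x = x0
    for _ in range(num_steps):
        x = x - step_size * smoothness * x  # gradient of f at x is smoothness * x
    return x

below_threshold = run_gd(step_size=1.9)  # 1.9 < 2/L with L = 1: iterates shrink
above_threshold = run_gd(step_size=2.1)  # 2.1 > 2/L with L = 1: iterates grow
```

On a quadratic, exceeding $2/L$ leads to divergence, so convergence in practice with such step sizes must come from the non-quadratic structure of the loss.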

NeurIPS Conference 2021 Conference Paper

Efficient constrained sampling via the mirror-Langevin algorithm

  • Kwangjun Ahn
  • Sinho Chewi

We propose a new discretization of the mirror-Langevin diffusion and give a crisp proof of its convergence. Our analysis uses relative convexity/smoothness and self-concordance, ideas which originated in convex optimization, together with a new result in optimal transport that generalizes the displacement convexity of the entropy. Unlike prior works, our result both (1) requires much weaker assumptions on the mirror map and the target distribution, and (2) has vanishing bias as the step size tends to zero. In particular, for the task of sampling from a log-concave distribution supported on a compact set, our theoretical results are significantly better than the existing guarantees.

NeurIPS Conference 2020 Conference Paper

SGD with shuffling: optimal rates without component convexity and large epoch requirements

  • Kwangjun Ahn
  • Chulhee Yun
  • Suvrit Sra

We study without-replacement SGD for solving finite-sum optimization problems. Specifically, depending on how the indices of the finite-sum are shuffled, we consider the RandomShuffle (shuffle at the beginning of each epoch) and SingleShuffle (shuffle only once) algorithms. First, we establish minimax optimal convergence rates of these algorithms up to poly-log factors. Notably, our analysis is general enough to cover gradient dominated nonconvex costs, and does not rely on the convexity of individual component functions unlike existing optimal convergence results. Secondly, assuming convexity of the individual components, we further sharpen the tight convergence results for RandomShuffle by removing the drawbacks common to all prior work: the large number of epochs required for the results to hold, and extra poly-log factor gaps to the lower bound.
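The difference between the two shuffling schemes comes down to how the per-epoch index order is generated. The following is a minimal sketch; the function name `epoch_orders` and its arguments are hypothetical and chosen only for illustration.

```python
import random

def epoch_orders(n, num_epochs, scheme, seed=0):
    """Return the component-index order used in each epoch of without-replacement SGD."""
    if scheme not in ("RandomShuffle", "SingleShuffle"):
        raise ValueError("scheme must be 'RandomShuffle' or 'SingleShuffle'")
    rng = random.Random(seed)
    fixed = list(range(n))
    rng.shuffle(fixed)  # the single permutation reused by SingleShuffle
    orders = []
    for _ in range(num_epochs):
        if scheme == "RandomShuffle":
            order = list(range(n))
            rng.shuffle(order)  # fresh permutation at the start of every epoch
        else:
            order = list(fixed)  # same permutation in every epoch
        orders.append(order)
    return orders

random_orders = epoch_orders(5, 3, "RandomShuffle")
single_orders = epoch_orders(5, 3, "SingleShuffle")
```

In both schemes every component index is visited exactly once per epoch, which is what distinguishes them from with-replacement sampling.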

NeurIPS Conference 2018 Conference Paper

Binary Rating Estimation with Graph Side Information

  • Kwangjun Ahn
  • Kangwook Lee
  • Hyunseung Cha
  • Changho Suh

Rich experimental evidence shows that one can better estimate users' unknown ratings with the aid of graph side information such as social graphs. However, the gain is not theoretically quantified. In this work, we study the binary rating estimation problem to understand the fundamental value of graph side information. Considering a simple correlation model between a rating matrix and a graph, we characterize the sharp threshold on the number of observed entries required to recover the rating matrix (called the optimal sample complexity) as a function of the quality of graph side information (to be detailed). To the best of our knowledge, we are the first to reveal how much the graph side information reduces sample complexity. Further, we propose a computationally efficient algorithm that achieves the limit. Our experimental results demonstrate that the algorithm performs well even with real-world graphs.