Arrow Research search

Author name cluster

Kaifeng Lyu

Possible papers associated with this exact author name in Arrow. This page groups case-insensitive exact name matches and is not a full identity disambiguation profile.

25 papers
2 author rows

Possible papers


ICLR Conference 2025 Conference Paper

A Multi-Power Law for Loss Curve Prediction Across Learning Rate Schedules

  • Kairong Luo
  • Haodong Wen
  • Shengding Hu
  • Zhenbo Sun
  • Zhiyuan Liu 0001
  • Maosong Sun 0001
  • Kaifeng Lyu
  • Wenguang Chen

Training large models is both resource-intensive and time-consuming, making it crucial to understand the quantitative relationship between model performance and hyperparameters. In this paper, we derive an empirical law that predicts pretraining loss for large language models for every intermediate training step across various learning rate schedules, including constant, cosine, and step decay schedules. Our proposed law takes a multi-power form, combining a power law based on the sum of learning rates and additional power laws to account for a loss reduction effect as learning rate decays. We validate this law extensively on Llama-2 models of varying sizes and demonstrate that, after fitting on a few learning rate schedules, it accurately predicts the loss curves for unseen schedules of different shapes and horizons. Moreover, by minimizing the predicted final pretraining loss across learning rate schedules, we are able to find a schedule that outperforms the widely-used cosine learning rate schedule. Interestingly, this automatically discovered schedule bears some resemblance to the recently proposed Warmup-Stable-Decay (WSD) schedule (Hu et al., 2024) but achieves a slightly lower final loss. We believe these results could offer valuable insights for understanding the dynamics of pretraining and for designing learning rate schedules to improve efficiency.

NeurIPS Conference 2025 Conference Paper

Adam Reduces a Unique Form of Sharpness: Theoretical Insights Near the Minimizer Manifold

  • Xinghan Li
  • Haodong Wen
  • Kaifeng Lyu

Despite the popularity of the Adam optimizer in practice, most theoretical analyses study SGD as a proxy and little is known about how the solutions found by Adam differ. In this paper, we show that Adam reduces a specific form of sharpness measure shaped by its adaptive updates, leading to qualitatively different solutions from SGD. When the training loss is small, Adam wanders around the manifold of minimizers and takes semi-gradients to minimize this sharpness measure in an adaptive manner, a behavior we rigorously characterize via a continuous-time approximation using stochastic differential equations. We further illustrate how this behavior differs from that of SGD in a well-studied setting: when training overparameterized models with label noise, SGD has been shown to minimize the trace of the Hessian matrix, $\text{tr}(\textbf{H})$, whereas we prove that Adam minimizes $\text{tr}(\text{diag}(\textbf{H})^{1/2})$ instead. In solving sparse linear regression with diagonal linear networks, Adam provably achieves better sparsity and generalization than SGD due to this difference. Finally, we note that our proof framework applies not only to Adam but also to a broad class of adaptive gradient methods, including but not limited to RMSProp, Adam-mini, and Adalayer. This provides a unified perspective for analyzing how adaptive optimizers reduce sharpness and may offer insights for future optimizer design.
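As a small illustration of the two sharpness measures named in the abstract above, the sketch below evaluates $\text{tr}(\textbf{H})$ and $\text{tr}(\text{diag}(\textbf{H})^{1/2})$ on two toy Hessians. The matrices and function names are hypothetical and chosen only to show that the two measures can rank solutions differently; nothing here is taken from the paper's experiments.

```python
import math

def trace_sharpness(hessian):
    """tr(H): sum of the diagonal entries of a square matrix."""
    return sum(hessian[i][i] for i in range(len(hessian)))

def sqrt_diag_sharpness(hessian):
    """tr(diag(H)^{1/2}): sum of square roots of the diagonal entries.

    Assumes the diagonal is non-negative, as for a positive semidefinite Hessian.
    """
    return sum(math.sqrt(hessian[i][i]) for i in range(len(hessian)))

# Two hypothetical Hessians with the same trace but different diagonal spread.
H_spread = [[4.0, 0.0], [0.0, 4.0]]
H_concentrated = [[8.0, 0.0], [0.0, 0.0]]

for name, H in [("spread", H_spread), ("concentrated", H_concentrated)]:
    print(name, trace_sharpness(H), sqrt_diag_sharpness(H))
```

In this toy case the trace cannot tell the two matrices apart, while the square-root measure is smaller when the curvature is concentrated on one coordinate.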

NeurIPS Conference 2025 Conference Paper

Data Mixing Can Induce Phase Transitions in Knowledge Acquisition

  • Xinran Gu
  • Kaifeng Lyu
  • Jiazheng Li
  • Jingzhao Zhang

Large Language Models (LLMs) are typically trained on data mixtures: most data come from web scrapes, while a small portion is curated from high-quality sources with dense domain-specific knowledge. In this paper, we show that when training LLMs on such data mixtures, knowledge acquisition from knowledge-dense datasets—unlike training exclusively on knowledge-dense data—does not always follow a smooth scaling law but can exhibit phase transitions with respect to the mixing ratio and model size. Through controlled experiments on a synthetic biography dataset mixed with web-scraped data, we demonstrate that: (1) as we increase the model size to a critical value, the model suddenly transitions from memorizing very few to most of the biographies; (2) below a critical mixing ratio, the model memorizes almost nothing even with extensive training, but beyond this threshold, it rapidly memorizes more biographies. We attribute these phase transitions to a capacity allocation phenomenon: a model with bounded capacity must act like a knapsack problem solver to minimize the overall test loss, and the optimal allocation across datasets can change discontinuously as the model size or mixing ratio varies. We formalize this intuition in an information-theoretic framework and reveal that these phase transitions are predictable, with the critical mixing ratio following a power-law relationship with the model size. Our findings highlight a concrete case where a good mixing recipe for large models may not be optimal for small models, and vice versa.

ICLR Conference 2025 Conference Paper

Efficient stagewise pretraining via progressive subnetworks

  • Abhishek Panigrahi
  • Nikunj Saunshi
  • Kaifeng Lyu
  • Sobhan Miryoosefi
  • Sashank J. Reddi
  • Satyen Kale
  • Sanjiv Kumar

Recent developments in large language models have sparked interest in efficient pretraining methods. Stagewise training approaches to improve efficiency, like gradual stacking and layer dropping (Reddi et al., 2023; Zhang & He, 2020), have recently garnered attention. The prevailing view suggests that stagewise dropping strategies, such as layer dropping, are ineffective, especially when compared to stacking-based approaches. This paper challenges this notion by demonstrating that, with proper design, dropping strategies can be competitive with, if not better than, stacking methods. Specifically, we develop a principled stagewise training framework, progressive subnetwork training, which only trains subnetworks within the model and progressively increases the size of subnetworks during training, until it trains the full network. We propose an instantiation of this framework, Random Part Training (RAPTR), that selects and trains only a random subnetwork (e.g. depth-wise, width-wise) of the network at each step, progressively increasing the size in stages. We show that this approach not only generalizes prior works like layer dropping but also fixes their key issues. Furthermore, we establish a theoretical basis for such approaches and provide justification for (a) increasing complexity of subnetworks in stages, conceptually diverging from prior works on layer dropping, and (b) stability in loss across stage transitions in the presence of key modern architecture components like residual connections and layer norms. Through comprehensive experiments, we demonstrate that RAPTR can significantly speed up training of standard benchmarks like BERT and UL2, up to 33% compared to standard training and, surprisingly, also shows better downstream performance on UL2, improving QA tasks and SuperGLUE by 1.5%, thereby providing evidence of better inductive bias.

ICLR Conference 2025 Conference Paper

Feature Averaging: An Implicit Bias of Gradient Descent Leading to Non-Robustness in Neural Networks

  • Binghui Li
  • Zhixuan Pan
  • Kaifeng Lyu
  • Jian Li

In this work, we investigate a particular implicit bias in gradient descent training, which we term “Feature Averaging,” and argue that it is one of the principal factors contributing to the non-robustness of deep neural networks. We show that, even when multiple discriminative features are present in the input data, neural networks trained by gradient descent tend to rely on an average (or a certain combination) of these features for classification, rather than distinguishing and leveraging each feature individually. Specifically, we provide a detailed theoretical analysis of the training dynamics of two-layer ReLU networks on a binary classification task, where the data distribution consists of multiple clusters with mutually orthogonal centers. We rigorously prove that gradient descent biases the network towards feature averaging, where the weights of each hidden neuron represent an average of the cluster centers (each corresponding to a distinct feature), thereby making the network vulnerable to input perturbations aligned with the negative direction of the averaged features. On the positive side, we demonstrate that this vulnerability can be mitigated through more granular supervision. In particular, we prove that a two-layer ReLU network can achieve optimal robustness when trained to classify individual features rather than merely the original binary classes. Finally, we validate our theoretical findings with experiments on synthetic datasets, MNIST, and CIFAR-10, and confirm the prevalence of feature averaging and its impact on adversarial robustness. We hope these theoretical and empirical insights deepen the understanding of how gradient descent shapes feature learning and adversarial robustness, and how more detailed supervision can enhance robustness.

NeurIPS Conference 2025 Conference Paper

How Far Are We from Optimal Reasoning Efficiency?

  • Jiaxuan Gao
  • Shu Yan
  • Qixin Tan
  • Lu Yang
  • Shusheng Xu
  • Wei Fu
  • Zhiyu Mei
  • Kaifeng Lyu

Large Reasoning Models (LRMs) demonstrate remarkable problem-solving capabilities through extended Chain-of-Thought (CoT) reasoning but often produce excessively verbose and redundant reasoning traces. This inefficiency incurs high inference costs and limits practical deployment. While existing fine-tuning methods aim to improve reasoning efficiency, assessing their efficiency gains remains challenging due to inconsistent evaluations. In this work, we introduce the reasoning efficiency frontiers, empirical upper bounds derived from fine-tuning a base LRM (DeepSeek-R1-Distill-Qwen-1.5B/7B) across diverse approaches and training configurations. Based on these frontiers, we propose the Reasoning Efficiency Gap (REG), a unified metric quantifying deviations of any fine-tuned LRMs from these frontiers. Systematic evaluation on challenging mathematical benchmarks, AMC23, AIME24, and AIME25, reveals significant gaps in current methods: they either sacrifice accuracy for short length or use excessive tokens to achieve sub-optimal accuracies despite high overall accuracy. To reduce the efficiency gap, we propose REO-RL, a Reinforcement Learning algorithm that optimizes reasoning efficiency by targeting a sparse set of token budgets. Leveraging numerical integration over strategically selected budgets, REO-RL approximates the full efficiency objective with low error using a small set of token budgets. Experiments show that, compared to vanilla RL with outcome reward, REO-RL reduces the reasoning efficiency gap by 74.5% and 64.2% in the 1.5B and 7B settings, respectively. The 7B LRM fine-tuned with REO-RL achieves reasoning conciseness surpassing frontier LRMs like Qwen3 and Claude Sonnet 3.7. Ablation studies confirm the efficacy of our token budget strategy and highlight REO-RL’s flexibility across design choices. This work establishes a systematic framework for evaluating and optimizing reasoning efficiency in LRMs. We will release the related code, data, and models to support future research on efficient reasoning in LRMs.

ICLR Conference 2025 Conference Paper

RNNs are not Transformers (Yet): The Key Bottleneck on In-Context Retrieval

  • Kaiyue Wen
  • Xingyu Dang
  • Kaifeng Lyu

This paper investigates the gap in representation powers of Transformers and Recurrent Neural Networks (RNNs), which are more memory efficient than Transformers. We aim to understand whether RNNs can match the performance of Transformers, particularly when enhanced with Chain-of-Thought (CoT) prompting. Our theoretical analysis reveals that CoT improves RNNs but is insufficient to close the gap with Transformers. A key bottleneck lies in the inability of RNNs to perfectly retrieve information from the context, even with CoT: for several tasks that explicitly or implicitly require this capability, such as associative recall and determining if a graph is a tree, we prove that RNNs are not expressive enough to solve the tasks while Transformers can solve them with ease. Conversely, we prove that adopting techniques to enhance the in-context retrieval capability of RNNs, including Retrieval-Augmented Generation (RAG) and adding a single Transformer layer, can elevate RNNs to be capable of solving all polynomial-time solvable problems with CoT, hence closing the representation gap with Transformers. We validate our theory on synthetic and natural language experiments.

ICLR Conference 2025 Conference Paper

Safety Alignment Should be Made More Than Just a Few Tokens Deep

  • Xiangyu Qi
  • Ashwinee Panda
  • Kaifeng Lyu
  • Xiao Ma 0010
  • Subhrajit Roy
  • Ahmad Beirami
  • Prateek Mittal
  • Peter Henderson 0002

The safety alignment of current Large Language Models (LLMs) is vulnerable. Simple attacks, or even benign fine-tuning, can jailbreak aligned models. We note that many of these vulnerabilities are related to a shared underlying issue: safety alignment can take shortcuts, wherein the alignment adapts a model's generative distribution primarily over only its very first few output tokens. We refer to this issue collectively as shallow safety alignment. In this paper, we present case studies to explain why shallow safety alignment can exist and show how this issue universally contributes to multiple recently discovered vulnerabilities in LLMs, including the susceptibility to adversarial suffix attacks, prefilling attacks, decoding parameter attacks, and fine-tuning attacks. The key contribution of this work is that we demonstrate how this consolidated notion of shallow safety alignment sheds light on promising research directions for mitigating these vulnerabilities. We show that deepening the safety alignment beyond the first few tokens can meaningfully improve robustness against some common exploits. We also design a regularized fine-tuning objective that makes the safety alignment more persistent against fine-tuning attacks by constraining updates on initial tokens. Overall, we advocate that future safety alignment should be made more than just a few tokens deep.

ICLR Conference 2025 Conference Paper

Towards Understanding Text Hallucination of Diffusion Models via Local Generation Bias

  • Rui Lu 0001
  • Runzhe Wang
  • Kaifeng Lyu
  • Xitai Jiang
  • Gao Huang 0001
  • Mengdi Wang 0001

Score-based diffusion models have achieved incredible performance in generating realistic images, audio, and video data. While these models produce high-quality samples with impressive details, they often introduce unrealistic artifacts, such as distorted fingers or hallucinated texts with no meaning. This paper focuses on textual hallucinations, where diffusion models correctly generate individual symbols but assemble them in a nonsensical manner. Through experimental probing, we consistently observe that this phenomenon is attributable to the network's local generation bias. Denoising networks tend to produce outputs that rely heavily on highly correlated local regions, particularly when different dimensions of the data distribution are nearly pairwise independent. This behavior leads to a generation process that decomposes the global distribution into separate, independent distributions for each symbol, ultimately failing to capture the global structure, including underlying grammar. Intriguingly, this bias persists across various denoising network architectures, including MLPs and transformers, which have the structure to model global dependency. These findings also provide insights into understanding other types of hallucinations, extending beyond text, as a result of implicit biases in the denoising models. Additionally, we theoretically analyze the training dynamics for a specific case involving a two-layer MLP learning parity points on a hypercube, offering an explanation of its underlying mechanism.

ICML Conference 2025 Conference Paper

Weak-to-Strong Generalization Even in Random Feature Networks, Provably

  • Marko Medvedev
  • Kaifeng Lyu
  • Dingli Yu
  • Sanjeev Arora
  • Zhiyuan Li 0005
  • Nathan Srebro

Weak-to-Strong Generalization (Burns et al., 2024) is the phenomenon whereby a strong student, say GPT-4, learns a task from a weak teacher, say GPT-2, and ends up significantly outperforming the teacher. We show that this phenomenon does not require a complex and pretrained learner like GPT-4, and can arise even in simple non-pretrained models, simply due to the size advantage of the student. But we also show that there are inherent limits to the extent of such weak-to-strong generalization. We consider students and teachers that are random feature models, described by two-layer networks with a random and fixed bottom layer and trained top layer. A ‘weak’ teacher, with a small number of units (i.e., random features), is trained on the population, and a ‘strong’ student, with a much larger number of units (i.e., random features), is trained only on labels generated by the weak teacher. We demonstrate, prove, and understand how the student can outperform the teacher, even though trained only on data labeled by the teacher. We also explain how such weak-to-strong generalization is enabled by early stopping. We then show the quantitative limits of weak-to-strong generalization in this model, and in fact in a much broader class of models, for arbitrary teacher and student feature spaces and a broad class of learning rules, including when the student features are pre-trained or otherwise more informative. In particular, we show that in such models the student’s error can only approach zero if the teacher’s error approaches zero, and a strong student cannot “boost” a slightly-better-than-chance teacher to obtain a small error.

ICLR Conference 2024 Conference Paper

A Quadratic Synchronization Rule for Distributed Deep Learning

  • Xinran Gu
  • Kaifeng Lyu
  • Sanjeev Arora
  • Jingzhao Zhang
  • Longbo Huang

In distributed deep learning with data parallelism, synchronizing gradients at each training step can cause a huge communication overhead, especially when many nodes work together to train large models. Local gradient methods, such as Local SGD, address this issue by allowing workers to compute locally for $H$ steps without synchronizing with others, hence reducing communication frequency. While $H$ has been viewed as a hyperparameter to trade optimization efficiency for communication cost, recent research indicates that setting a proper $H$ value can lead to generalization improvement. Yet, selecting a proper $H$ is elusive. This work proposes a theory-grounded method for determining $H$, named the Quadratic Synchronization Rule (QSR), which recommends dynamically setting $H$ in proportion to $\frac{1}{\eta^2}$ as the learning rate $\eta$ decays over time. Extensive ImageNet experiments on ResNet and ViT show that local gradient methods with QSR consistently improve the test accuracy over other synchronization strategies. Compared to the standard data parallel training, QSR enables Local AdamW to cut the training time on 16 or 64 GPUs down from 26.7 to 20.2 hours or from 8.6 to 5.5 hours and, at the same time, achieves 1.16% or 0.84% higher top-1 validation accuracy.
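To make the scaling in the abstract above concrete, here is a minimal sketch of choosing the number of local steps $H$ in proportion to $\frac{1}{\eta^2}$ as the learning rate decays. The proportionality constant `alpha`, the lower bound `h_base`, the rounding down to an integer, and the cosine schedule are all assumptions made for illustration; the abstract states only the proportionality.

```python
import math

def cosine_lr(step, total_steps, peak_lr):
    """Standard cosine decay from peak_lr toward zero (illustrative schedule)."""
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * step / total_steps))

def qsr_sync_period(lr, alpha, h_base):
    """Local steps H proportional to 1/lr^2.

    alpha (proportionality constant) and h_base (minimum period) are
    hypothetical knobs; truncating to an integer is also an assumption.
    """
    if lr <= 0:
        raise ValueError("learning rate must be positive")
    return max(h_base, int((alpha / lr) ** 2))

total_steps, peak_lr = 1000, 0.1
for step in (0, 250, 500, 750, 900):
    lr = cosine_lr(step, total_steps, peak_lr)
    print(step, round(lr, 5), qsr_sync_period(lr, alpha=0.2, h_base=2))
```

Under these assumptions the synchronization period grows as the learning rate decays, so workers communicate less often late in training.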

ICLR Conference 2024 Conference Paper

Dichotomy of Early and Late Phase Implicit Biases Can Provably Induce Grokking

  • Kaifeng Lyu
  • Jikai Jin
  • Zhiyuan Li 0005
  • Simon S. Du
  • Jason D. Lee
  • Wei Hu 0014

Recent work by Power et al. (2022) highlighted a surprising "grokking" phenomenon in learning arithmetic tasks: a neural net first "memorizes" the training set, resulting in perfect training accuracy but near-random test accuracy, and after training for sufficiently longer, it suddenly transitions to perfect test accuracy. This paper studies the grokking phenomenon in theoretical setups and shows that it can be induced by a dichotomy of early and late phase implicit biases. Specifically, when training homogeneous neural nets with large initialization and small weight decay on both classification and regression tasks, we prove that the training process gets trapped at a solution corresponding to a kernel predictor for a long time, and then a very sharp transition to min-norm/max-margin predictors occurs, leading to a dramatic change in test accuracy.

ICLR Conference 2024 Conference Paper

DistillSpec: Improving Speculative Decoding via Knowledge Distillation

  • Yongchao Zhou
  • Kaifeng Lyu
  • Ankit Singh Rawat
  • Aditya Krishna Menon
  • Afshin Rostamizadeh
  • Sanjiv Kumar
  • Jean-François Kagy
  • Rishabh Agarwal

Speculative decoding (SD) accelerates large language model inference by employing a faster draft model for generating multiple tokens, which are then verified in parallel by the larger target model, resulting in the text generated according to the target model distribution. However, identifying a compact draft model that is well-aligned with the target model is challenging. To tackle this issue, we propose DistillSpec that uses knowledge distillation to better align the draft model with the target model, before applying SD. DistillSpec makes two key design choices, which we demonstrate via systematic study to be crucial to improve the draft and target alignment: utilizing on-policy data generation from the draft model, and tailoring the divergence function to the task and decoding strategy. Notably, DistillSpec yields impressive 10-45% speedups over standard SD on a range of standard benchmarks, using both greedy and non-greedy sampling. Furthermore, we combine DistillSpec with lossy SD to achieve fine-grained control over the latency vs. task performance trade-off. Finally, in practical scenarios with models of varying sizes, first using distillation to boost the performance of the target model and then applying DistillSpec to train a well-aligned draft model can reduce decoding latency by 6-10× with minimal performance drop, compared to standard decoding without distillation.
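The verification step mentioned in the abstract above can be sketched for a single token. The code uses the standard speculative sampling accept/reject rule (accept a draft token with probability min(1, p/q), otherwise resample from the normalized positive part of p - q); this rule is general background rather than something stated in the abstract, and the toy distributions and function names are hypothetical.

```python
import random

def residual_distribution(p_target, q_draft):
    """Normalized max(0, p - q), used to resample after a rejection."""
    raw = [max(0.0, p - q) for p, q in zip(p_target, q_draft)]
    total = sum(raw)
    # If p == q everywhere, a rejection cannot occur; fall back to p.
    return [r / total for r in raw] if total > 0 else list(p_target)

def speculative_step(p_target, q_draft, rng):
    """Return (token, accepted) with token distributed according to p_target."""
    vocab = list(range(len(p_target)))
    proposal = rng.choices(vocab, weights=q_draft)[0]
    accept_prob = min(1.0, p_target[proposal] / q_draft[proposal])
    if rng.random() < accept_prob:
        return proposal, True
    residual = residual_distribution(p_target, q_draft)
    return rng.choices(vocab, weights=residual)[0], False

def simulate(p_target, q_draft, num_samples=20000, seed=0):
    """Empirical token frequencies and acceptance rate over many draws."""
    rng = random.Random(seed)
    counts = [0] * len(p_target)
    accepted = 0
    for _ in range(num_samples):
        token, was_accepted = speculative_step(p_target, q_draft, rng)
        counts[token] += 1
        accepted += was_accepted
    return [c / num_samples for c in counts], accepted / num_samples

p = [0.6, 0.3, 0.1]  # toy target-model distribution
q = [0.3, 0.5, 0.2]  # toy draft-model distribution
freqs, accept_rate = simulate(p, q)
print("empirical:", freqs, "acceptance rate:", accept_rate)
```

The acceptance rate equals the overlap between the two distributions (the sum of min(p, q) over tokens), which is one way to see why a better-aligned draft model speeds up decoding.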

NeurIPS Conference 2024 Conference Paper

Keeping LLMs Aligned After Fine-tuning: The Crucial Role of Prompt Templates

  • Kaifeng Lyu
  • Haoyu Zhao
  • Xinran Gu
  • Dingli Yu
  • Anirudh Goyal
  • Sanjeev Arora

Public LLMs such as Llama 2-Chat underwent alignment training and were considered safe. Recently Qi et al. (2024) reported that even benign fine-tuning on seemingly safe datasets can give rise to unsafe behaviors in the models. The current paper is about methods and best practices to mitigate such loss of alignment. We focus on the setting where a public model is fine-tuned before serving users for specific usage, where the model should improve on the downstream task while maintaining alignment. Through extensive experiments on several chat models (Meta's Llama 2-Chat, Mistral AI's Mistral 7B Instruct v0.2, and OpenAI's GPT-3.5 Turbo), this paper uncovers that the prompt templates used during fine-tuning and inference play a crucial role in preserving safety alignment, and proposes the “Pure Tuning, Safe Testing” (PTST) strategy: fine-tune models without a safety prompt, but include it at test time. This seemingly counterintuitive strategy incorporates an intended distribution shift to encourage alignment preservation. Fine-tuning experiments on GSM8K, ChatDoctor, and OpenOrca show that PTST significantly reduces the rise of unsafe behaviors.

ICLR Conference 2024 Conference Paper

The Marginal Value of Momentum for Small Learning Rate SGD

  • Runzhe Wang
  • Sadhika Malladi
  • Tianhao Wang 0017
  • Kaifeng Lyu
  • Zhiyuan Li 0005

Momentum is known to accelerate the convergence of gradient descent in strongly convex settings without stochastic gradient noise. In stochastic optimization, such as training neural networks, folklore suggests that momentum may help deep learning optimization by reducing the variance of the stochastic gradient update, but previous theoretical analyses do not find momentum to offer any provable acceleration. Theoretical results in this paper clarify the role of momentum in stochastic settings where the learning rate is small and gradient noise is the dominant source of instability, suggesting that SGD with and without momentum behave similarly in the short and long time horizons. Experiments show that momentum indeed has limited benefits for both optimization and generalization in practical training regimes where the optimal learning rate is not very large, including small- to medium-batch training from scratch on ImageNet and fine-tuning language models on downstream tasks.

ICML Conference 2023 Conference Paper

Understanding Incremental Learning of Gradient Descent: A Fine-grained Analysis of Matrix Sensing

  • Jikai Jin
  • Zhiyuan Li 0005
  • Kaifeng Lyu
  • Simon S. Du
  • Jason D. Lee

It is believed that Gradient Descent (GD) induces an implicit bias towards good generalization in training machine learning models. This paper provides a fine-grained analysis of the dynamics of GD for the matrix sensing problem, whose goal is to recover a low-rank ground-truth matrix from near-isotropic linear measurements. It is shown that GD with small initialization behaves similarly to the greedy low-rank learning heuristics and follows an incremental learning procedure: GD sequentially learns solutions with increasing ranks until it recovers the ground truth matrix. Compared to existing works which only analyze the first learning phase for rank-1 solutions, our result provides characterizations for the whole learning process. Moreover, besides the over-parameterized regime that many prior works focused on, our analysis of the incremental learning procedure also applies to the under-parameterized regime. Finally, we conduct numerical experiments to confirm our theoretical findings.

ICLR Conference 2023 Conference Paper

Why (and When) does Local SGD Generalize Better than SGD?

  • Xinran Gu
  • Kaifeng Lyu
  • Longbo Huang
  • Sanjeev Arora

Local SGD is a communication-efficient variant of SGD for large-scale training, where multiple GPUs perform SGD independently and average the model parameters periodically. It has been recently observed that Local SGD can not only achieve the design goal of reducing the communication overhead but also lead to higher test accuracy than the corresponding SGD baseline (Lin et al., 2020b), though the training regimes for this to happen are still in debate (Ortiz et al., 2021). This paper aims to understand why (and when) Local SGD generalizes better based on Stochastic Differential Equation (SDE) approximation. The main contributions of this paper include (i) the derivation of an SDE that captures the long-term behavior of Local SGD in the small learning rate regime, showing how noise drives the iterate to drift and diffuse after it has reached close to the manifold of local minima, (ii) a comparison between the SDEs of Local SGD and SGD, showing that Local SGD induces a stronger drift term that can result in a stronger effect of regularization, e.g., a faster reduction of sharpness, and (iii) empirical evidence validating that having a small learning rate and long enough training time enables the generalization improvement over SGD but removing either of the two conditions leads to no improvement.
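The procedure described at the start of the abstract above (workers run SGD independently and average parameters periodically) can be sketched on a toy one-dimensional problem. The quadratic loss, the Gaussian gradient noise, and all names and hyperparameter values are assumptions for illustration only.

```python
import random

def noisy_grad(w, rng, noise_std=0.1):
    """Stochastic gradient of the toy loss 0.5 * w**2 with additive noise."""
    return w + rng.gauss(0.0, noise_std)

def local_sgd(w0, num_workers, local_steps, rounds, lr, seed=0):
    """Each round: every worker takes local_steps SGD steps, then parameters are averaged."""
    rng = random.Random(seed)
    w_global = w0
    for _ in range(rounds):
        local_params = []
        for _ in range(num_workers):
            w = w_global  # every worker starts the round from the shared model
            for _ in range(local_steps):
                w -= lr * noisy_grad(w, rng)
            local_params.append(w)
        w_global = sum(local_params) / num_workers  # periodic parameter averaging
    return w_global

print(local_sgd(w0=5.0, num_workers=4, local_steps=8, rounds=20, lr=0.1))
```

With `local_steps=1` the averaging step coincides with averaging the workers' gradients at every step, which corresponds to the synchronized SGD baseline in this toy setting.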

NeurIPS Conference 2022 Conference Paper

New Definitions and Evaluations for Saliency Methods: Staying Intrinsic, Complete and Sound

  • Arushi Gupta
  • Nikunj Saunshi
  • Dingli Yu
  • Kaifeng Lyu
  • Sanjeev Arora

Saliency methods compute heat maps that highlight portions of an input that were most important for the label assigned to it by a deep net. Evaluations of saliency methods convert this heat map into a new masked input by retaining the $k$ highest-ranked pixels of the original input and replacing the rest with "uninformative" pixels, and checking if the net's output is mostly unchanged. This is usually seen as an explanation of the output, but the current paper highlights reasons why this inference of causality may be suspect. Inspired by logic concepts of completeness & soundness, it observes that the above type of evaluation focuses on completeness of the explanation, but ignores soundness. New evaluation metrics are introduced to capture both notions, while staying in an intrinsic framework, i.e., using the dataset and the net, but no separately trained nets, human evaluations, etc. A simple saliency method is described that matches or outperforms prior methods in the evaluations. Experiments also suggest new intrinsic justifications, based on soundness, for popular heuristic tricks such as TV regularization and upsampling.
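The masking step used in the evaluations described above can be sketched as follows: keep the $k$ highest-ranked pixels and replace the rest. Treating an image as a flat list and using a constant `fill_value` as the "uninformative" pixel are assumptions made here; the abstract does not say how uninformative pixels are chosen.

```python
def mask_top_k(pixels, saliency, k, fill_value=0.0):
    """Keep the k pixels with the highest saliency; replace the rest with fill_value."""
    ranked = sorted(range(len(pixels)), key=lambda i: saliency[i], reverse=True)
    keep = set(ranked[:k])
    return [pixels[i] if i in keep else fill_value for i in range(len(pixels))]

# Hypothetical flattened 2x2 image and its saliency scores.
pixels = [0.9, 0.1, 0.5, 0.7]
saliency = [0.2, 0.8, 0.1, 0.6]
print(mask_top_k(pixels, saliency, k=2))
```

An evaluation of the kind the abstract describes would then feed the masked input back to the net and compare its output with the output on the original input.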

NeurIPS Conference 2022 Conference Paper

On the SDEs and Scaling Rules for Adaptive Gradient Algorithms

  • Sadhika Malladi
  • Kaifeng Lyu
  • Abhishek Panigrahi
  • Sanjeev Arora

Approximating Stochastic Gradient Descent (SGD) as a Stochastic Differential Equation (SDE) has allowed researchers to enjoy the benefits of studying a continuous optimization trajectory while carefully preserving the stochasticity of SGD. Analogous study of adaptive gradient methods, such as RMSprop and Adam, has been challenging because there were no rigorously proven SDE approximations for these methods. This paper derives the SDE approximations for RMSprop and Adam, giving theoretical guarantees of their correctness as well as experimental validation of their applicability to common large-scale vision and language settings. A key practical result is the derivation of a square root scaling rule to adjust the optimization hyperparameters of RMSprop and Adam when changing batch size, and its empirical validation in deep learning settings.
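A minimal sketch of the learning-rate part of a square root scaling rule: when the batch size changes by a factor, the learning rate changes by the square root of that factor. The abstract says the rule adjusts the optimization hyperparameters (plural) of RMSprop and Adam; only the learning rate is shown here, the other adjustments are omitted, and the function name and example values are hypothetical.

```python
import math

def sqrt_scaled_lr(base_lr, base_batch_size, new_batch_size):
    """Scale the learning rate by sqrt(new_batch_size / base_batch_size)."""
    kappa = new_batch_size / base_batch_size
    return base_lr * math.sqrt(kappa)

# Hypothetical example: going from batch size 256 to 1024 (a factor of 4).
print(sqrt_scaled_lr(1e-3, 256, 1024))
```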

NeurIPS Conference 2022 Conference Paper

Understanding the Generalization Benefit of Normalization Layers: Sharpness Reduction

  • Kaifeng Lyu
  • Zhiyuan Li
  • Sanjeev Arora

Normalization layers (e.g., Batch Normalization, Layer Normalization) were introduced to help with optimization difficulties in very deep nets, but they clearly also help generalization, even in not-so-deep nets. Motivated by the long-held belief that flatter minima lead to better generalization, this paper gives mathematical analysis and supporting experiments suggesting that normalization (together with accompanying weight-decay) encourages GD to reduce the sharpness of loss surface. Here "sharpness" is carefully defined given that the loss is scale-invariant, a known consequence of normalization. Specifically, for a fairly broad class of neural nets with normalization, our theory explains how GD with a finite learning rate enters the so-called Edge of Stability (EoS) regime, and characterizes the trajectory of GD in this regime via a continuous sharpness-reduction flow.

NeurIPS Conference 2021 Conference Paper

Gradient Descent on Two-layer Nets: Margin Maximization and Simplicity Bias

  • Kaifeng Lyu
  • Zhiyuan Li
  • Runzhe Wang
  • Sanjeev Arora

The generalization mystery of overparametrized deep nets has motivated efforts to understand how gradient descent (GD) converges to low-loss solutions that generalize well. Real-life neural networks are initialized from small random values and trained with cross-entropy loss for classification (unlike the "lazy" or "NTK" regime of training where analysis was more successful), and a recent sequence of results (Lyu and Li, 2020; Chizat and Bach, 2020; Ji and Telgarsky, 2020) provide theoretical evidence that GD may converge to the "max-margin" solution with zero loss, which presumably generalizes well. However, the global optimality of margin is proved only in some settings where neural nets are infinitely or exponentially wide. The current paper is able to establish this global optimality for two-layer Leaky ReLU nets trained with gradient flow on linearly separable and symmetric data, regardless of the width. The analysis also gives some theoretical justification for recent empirical findings (Kalimeris et al., 2019) on the so-called simplicity bias of GD towards linear or other "simple" classes of solutions, especially early in training. On the pessimistic side, the paper suggests that such results are fragile. A simple data manipulation can make gradient flow converge to a linear classifier with suboptimal margin.

ICLR Conference 2021 Conference Paper

Towards Resolving the Implicit Bias of Gradient Descent for Matrix Factorization: Greedy Low-Rank Learning

  • Zhiyuan Li 0005
  • Yuping Luo
  • Kaifeng Lyu

Matrix factorization is a simple and natural test-bed to investigate the implicit regularization of gradient descent. Gunasekar et al. (2017) conjectured that gradient flow with infinitesimal initialization converges to the solution that minimizes the nuclear norm, but a series of recent papers argued that the language of norm minimization is not sufficient to give a full characterization for the implicit regularization. In this work, we provide theoretical and empirical evidence that for depth-2 matrix factorization, gradient flow with infinitesimal initialization is mathematically equivalent to a simple heuristic rank minimization algorithm, Greedy Low-Rank Learning, under some reasonable assumptions. This generalizes the rank minimization view from previous works to a much broader setting and enables us to construct counter-examples to refute the conjecture from Gunasekar et al. (2017). We also extend the results to the case where depth >= 3, and we show that the benefit of being deeper is that the above convergence has a much weaker dependence on initialization magnitude, so that this rank minimization is more likely to take effect for initialization with practical scale.

ICLR Conference 2020 Conference Paper

Gradient Descent Maximizes the Margin of Homogeneous Neural Networks

  • Kaifeng Lyu
  • Jian Li 0015

In this paper, we study the implicit regularization of the gradient descent algorithm in homogeneous neural networks, including fully-connected and convolutional neural networks with ReLU or LeakyReLU activations. In particular, we study the gradient descent or gradient flow (i.e., gradient descent with infinitesimal step size) optimizing the logistic loss or cross-entropy loss of any homogeneous model (possibly non-smooth), and show that if the training loss decreases below a certain threshold, then we can define a smoothed version of the normalized margin which increases over time. We also formulate a natural constrained optimization problem related to margin maximization, and prove that both the normalized margin and its smoothed version converge to the objective value at a KKT point of the optimization problem. Our results generalize the previous results for logistic regression with one-layer or multi-layer linear networks, and provide more quantitative convergence results with weaker assumptions than previous results for homogeneous smooth neural networks. We conduct several experiments to justify our theoretical finding on MNIST and CIFAR-10 datasets. Finally, as margin is closely related to robustness, we discuss potential benefits of training longer for improving the robustness of the model.

NeurIPS Conference 2020 Conference Paper

Reconciling Modern Deep Learning with Traditional Optimization Analyses: The Intrinsic Learning Rate

  • Zhiyuan Li
  • Kaifeng Lyu
  • Sanjeev Arora

Recent works (e.g., Li & Arora, 2020) suggest that the use of popular normalization schemes (including Batch Normalization) in today's deep learning can move it far from a traditional optimization viewpoint, e.g., use of exponentially increasing learning rates. The current paper highlights other ways in which behavior of normalized nets departs from traditional viewpoints, and then initiates a formal framework for studying their mathematics via suitable adaptation of the conventional framework, namely, modeling the SGD-induced training trajectory via a suitable stochastic differential equation (SDE) with a noise term that captures gradient noise. This yields: (a) A new "intrinsic learning rate" parameter that is the product of the normal learning rate $\eta$ and weight decay factor $\lambda$. Analysis of the SDE shows how the effective speed of learning varies and equilibrates over time under the control of intrinsic LR. (b) A challenge, via theory and experiments, to the popular belief that good generalization requires large learning rates at the start of training. (c) New experiments, backed by mathematical intuition, suggesting the number of steps to equilibrium (in function space) scales as the inverse of the intrinsic learning rate, as opposed to the exponential time convergence bound implied by SDE analysis. We name it the Fast Equilibrium Conjecture and suggest it holds the key to why Batch Normalization is effective.