Author name cluster

Albert Gu

Possible papers associated with this exact author name in Arrow. This page groups case-insensitive exact name matches and is not a full identity disambiguation profile.

32 papers
2 author rows

Possible papers

TMLR Journal 2025 Journal Article

Chimera: State Space Models Beyond Sequences

  • Aakash Lahoti
  • Tanya Marwah
  • Ratish Puduppully
  • Albert Gu

Transformer-based deep learning methods have emerged as the standard approach to model diverse data such as sequences, images, and graphs. These methods rely on self-attention, which treats data as an unordered set of elements. This ignores the neighborhood structure or graph topology of the data and requires the use of inductive biases, such as position embeddings in sequences and images, and random walks in graphs, to incorporate topology. However, developing bespoke inductive biases for each task requires significant effort and can also introduce side-effects hindering generalization. In this work, we introduce Chimera, a unified model that directly incorporates the data topology in a principled way, obviating the need for domain-specific biases. Central to Chimera is the observation that state-space models---which naturally do not require position embeddings---can be generalized to capture any general graph topology. Our experiments demonstrate the versatility of our approach---Chimera achieves strong performance across the domains of language, vision, and graphs, outperforming BERT on GLUE by 0.7 points, ViT on ImageNet-1k by 2.6%, and all the baselines on the Long Range Graph Benchmark. Our results validate Chimera's principled methodological contributions and affirm the long-held belief that data topology is a powerful inductive bias across modalities. We further propose algorithmic optimizations to improve Chimera's efficiency while maintaining performance: 1) For the subclass of Directed Acyclic Graphs we show that Chimera can be implemented as a linear time recurrence. 2) For general graphs, we relax the method with a simple mathematical approximation, achieving Transformer's quadratic complexity without relying on domain-specific biases.

ICLR Conference 2025 Conference Paper

On the Benefits of Memory for Modeling Time-Dependent PDEs

  • Ricardo Buitrago Ruiz
  • Tanya Marwah
  • Albert Gu
  • Andrej Risteski

Data-driven techniques have emerged as a promising alternative to traditional numerical methods for solving PDEs. For time-dependent PDEs, many approaches are Markovian---the evolution of the trained system only depends on the current state, and not the past states. In this work, we investigate the benefits of using memory for modeling time-dependent PDEs: that is, when past states are explicitly used to predict the future. Motivated by the Mori-Zwanzig theory of model reduction, we theoretically exhibit examples of simple (even linear) PDEs, in which a solution that uses memory is arbitrarily better than a Markovian solution. Additionally, we introduce Memory Neural Operator (MemNO), a neural operator architecture that combines recent state space models (specifically, S4) and Fourier Neural Operators (FNOs) to effectively model memory. We empirically demonstrate that when the PDEs are supplied in low resolution or contain observation noise at train and test time, MemNO significantly outperforms the baselines without memory---with up to $6 \times$ reduction in test error. Furthermore, we show that this benefit is particularly pronounced when the PDE solutions have significant high-frequency Fourier modes (e.g., low-viscosity fluid dynamics) and we construct a challenging benchmark dataset consisting of such PDEs.

ICML Conference 2025 Conference Paper

Understanding and Improving Length Generalization in Recurrent Models

  • Ricardo Buitrago Ruiz
  • Albert Gu

Recently, recurrent models such as state space models and linear attention have become popular due to their linear complexity in the sequence length. Thanks to their recurrent nature, in principle they can process arbitrarily long sequences, but their performance sometimes drops considerably beyond their training context lengths, i.e., they fail to length generalize. In this work, we provide comprehensive empirical and theoretical analysis to support the unexplored states hypothesis, which posits that models fail to length generalize when during training they are only exposed to a limited subset of the distribution of all attainable states (i.e., states that would be attained if the recurrence was applied to long sequences). Furthermore, we investigate simple training interventions that aim to increase the coverage of the states that the model is trained on, e.g., by initializing the state with Gaussian noise or with the final state of a different input sequence. With only 500 post-training steps ($\sim 0.1\%$ of the pre-training budget), these interventions enable length generalization for sequences that are orders of magnitude longer than the training context (e.g., $2k\longrightarrow 128k$) and show improved performance in long context tasks, thus presenting a simple and efficient way to enable robust length generalization in general recurrent models.

ICML Conference 2025 Conference Paper

Understanding the Skill Gap in Recurrent Language Models: The Role of the Gather-and-Aggregate Mechanism

  • Aviv Bick
  • Eric P. Xing
  • Albert Gu

State-space models (SSMs) offer efficient alternatives to Transformers for long sequences, but their fixed-size recurrent state limits capability on algorithmic tasks, such as retrieving past context. In this work, we examine how in-context retrieval operates in Transformer- and SSM-based language models and find that both rely on a Gather-and-Aggregate (G&A) mechanism: a Gather Head extracts relevant information from context, which an Aggregate Head integrates into representation. In both architectures, G&A concentrates in a few heads, forming bottlenecks even for simple retrieval. For example, disabling a single Gather or Aggregate Head in a pruned Llama-3.1-8B impairs retrieving the correct answer letter in MMLU, reducing its accuracy from 66% to 25%. Moreover, this retrieval bottleneck can obscure knowledge demands of tasks as the pruned model succeeds on MMLU with functioning G&A heads yet fails on other knowledge benchmarks. The bottleneck similarly extends to tasks where SSMs typically underperform, like GSM8K, BBH, and dialogue. We show that SSMs’ retrieval challenges manifest in these heads, creating smoother attention patterns instead of the sharp transitions effective G&A requires. Thus, the Transformer-SSM retrieval gap exists in just a few heads, rather than the entire language model. This suggests a unified explanation for the Transformer vs. SSM performance gap while showing how to merge their strengths. We find that pretrained hybrid models, where SSMs are combined with attention layers, delegate the role of Aggregate Heads to attention. Similarly, replacing a single G&A head in a pretrained SSM with an attention variant boosts retrieval and benchmark scores.

ICML Conference 2024 Conference Paper

Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling

  • Yair Schiff
  • Chia-Hsiang Kao
  • Aaron Gokaslan
  • Tri Dao
  • Albert Gu
  • Volodymyr Kuleshov

Large-scale sequence modeling has sparked rapid advances that now extend into biology and genomics. However, modeling genomic sequences introduces challenges such as the need to model long-range token interactions, the effects of upstream and downstream regions of the genome, and the reverse complementarity (RC) of DNA. Here, we propose an architecture motivated by these challenges that builds off the long-range Mamba block, and extends it to a BiMamba component that supports bi-directionality, and to a MambaDNA block that additionally supports RC equivariance. We use MambaDNA as the basis of Caduceus, the first family of RC equivariant bi-directional long-range DNA language models, and we introduce pre-training and fine-tuning strategies that yield Caduceus DNA foundation models. Caduceus outperforms previous long-range models on downstream benchmarks; on a challenging long-range variant effect prediction task, Caduceus exceeds the performance of 10x larger models that do not leverage bi-directionality or equivariance. Code to reproduce our experiments is available here: https://github.com/kuleshov-group/caduceus.

NeurIPS Conference 2024 Conference Paper

Hydra: Bidirectional State Space Models Through Generalized Matrix Mixers

  • Sukjun Hwang
  • Aakash Lahoti
  • Ratish Puduppully
  • Tri Dao
  • Albert Gu

A wide array of sequence models are built on a framework modeled after Transformers, comprising alternating sequence mixer and channel mixer layers. This paper studies a unifying matrix mixer view of sequence mixers that can be conceptualized as a linear map on the input sequence. This framework encompasses a broad range of well-known sequence models, including the self-attention of Transformers as well as recent strong alternatives such as structured state space models (SSMs), and allows understanding downstream characteristics such as efficiency and expressivity through properties of their structured matrix class. We identify a key axis of matrix parameterizations termed sequence alignment, which increases the flexibility and performance of matrix mixers, providing insights into the strong performance of Transformers and recent SSMs such as Mamba. Furthermore, the matrix mixer framework offers a systematic approach to developing sequence mixers with desired properties, allowing us to develop several new sub-quadratic sequence models. In particular, we propose a natural bidirectional extension of the Mamba model (Hydra), parameterized as a quasiseparable matrix mixer, which demonstrates superior performance over other sequence models including Transformers on non-causal tasks. As a drop-in replacement for attention layers, Hydra outperforms BERT by 0.8 points on the GLUE benchmark and ViT by 2% Top-1 accuracy on ImageNet.

ICML Conference 2024 Conference Paper

Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality

  • Tri Dao
  • Albert Gu

While Transformers have been the main architecture behind deep learning’s success in language modeling, state-space models (SSMs) such as Mamba have recently been shown to match or outperform Transformers at small to medium scale. We show that these families of models are actually quite closely related, and develop a rich framework of theoretical connections between SSMs and variants of attention, connected through various decompositions of a well-studied class of structured semiseparable matrices. Our state space duality (SSD) framework allows us to design a new architecture (Mamba-2) whose core layer is a refinement of Mamba’s selective SSM that is 2-8$\times$ faster, while continuing to be competitive with Transformers on language modeling.

NeurIPS Conference 2024 Conference Paper

Transformers to SSMs: Distilling Quadratic Knowledge to Subquadratic Models

  • Aviv Bick
  • Kevin Y. Li
  • Eric P. Xing
  • J. Z. Kolter
  • Albert Gu

Transformer architectures have become a dominant paradigm for domains like language modeling but suffer in many inference settings due to their quadratic-time self-attention. Recently proposed subquadratic architectures, such as Mamba, have shown promise, but have been pretrained with substantially less computational resources than the strongest Transformer models. In this work, we present a method that is able to distill a pretrained Transformer architecture into alternative architectures such as state space models (SSMs). The key idea to our approach is that we can view both Transformers and SSMs as applying different forms of mixing matrices over the token sequences. We can thus progressively distill the Transformer architecture by matching different degrees of granularity in the SSM: first matching the mixing matrices themselves, then the hidden units at each block, and finally the end-to-end predictions. Our method, called MOHAWK, is able to distill a Mamba-2 variant based on the Phi-1.5 architecture (Phi-Mamba) using only 3B tokens. Despite using less than 1% of the training data typically used to train models from scratch, Phi-Mamba boasts substantially stronger performance compared to all past open-source non-Transformer models. MOHAWK allows models like SSMs to leverage computational resources invested in training Transformer-based architectures, highlighting a new avenue for building such models.

ICLR Conference 2023 Conference Paper

How to Train your HIPPO: State Space Models with Generalized Orthogonal Basis Projections

  • Albert Gu
  • Isys Johnson
  • Aman Timalsina
  • Atri Rudra
  • Christopher Ré

Linear time-invariant state space models (SSM) are a classical model from engineering and statistics, that have recently been shown to be very promising in machine learning through the Structured State Space sequence model (S4). A core component of S4 involves initializing the SSM state matrix to a particular matrix called a HiPPO matrix, which was empirically important for S4's ability to handle long sequences. However, the specific matrix that S4 uses was actually derived in previous work for a particular *time-varying* dynamical system, and the use of this matrix as a *time-invariant* SSM had no known mathematical interpretation. Consequently, the theoretical mechanism by which S4 models long-range dependencies actually remains unexplained. We derive a more general and intuitive formulation of the HiPPO framework, which provides a simple mathematical interpretation of S4 as a decomposition onto exponentially-warped Legendre polynomials, explaining its ability to capture long dependencies. Our generalization introduces a theoretically rich class of SSMs that also lets us derive more intuitive S4 variants for other bases such as the Fourier basis, and explains other aspects of training S4, such as how to initialize the important timescale parameter. These insights improve S4's performance to 86% on the Long Range Arena benchmark, with 96% on the most difficult Path-X task.

ICLR Conference 2023 Conference Paper

Modelling Long Range Dependencies in $N$D: From Task-Specific to a General Purpose CNN

  • David M. Knigge
  • David W. Romero
  • Albert Gu
  • Efstratios Gavves
  • Erik J. Bekkers
  • Jakub M. Tomczak
  • Mark Hoogendoorn
  • Jan-Jakob Sonke

Performant Convolutional Neural Network (CNN) architectures must be tailored to specific tasks in order to consider the length, resolution, and dimensionality of the input data. In this work, we tackle the need for problem-specific CNN architectures. We present the Continuous Convolutional Neural Network (CCNN): a single CNN able to process data of arbitrary resolution, dimensionality and length without any structural changes. Its key components are its continuous convolutional kernels, which model long-range dependencies at every layer and thus remove the need of current CNN architectures for task-dependent downsampling and depths. We showcase the generality of our method by using the same architecture for tasks on sequential ($1{\rm D}$), visual ($2{\rm D}$) and point-cloud ($3{\rm D}$) data. Our CCNN matches and often outperforms the current state-of-the-art across all tasks considered.

ICML Conference 2023 Conference Paper

Resurrecting Recurrent Neural Networks for Long Sequences

  • Antonio Orvieto
  • Samuel L. Smith
  • Albert Gu
  • Anushan Fernando
  • Çaglar Gülçehre
  • Razvan Pascanu
  • Soham De

Recurrent Neural Networks (RNNs) offer fast inference on long sequences but are hard to optimize and slow to train. Deep state-space models (SSMs) have recently been shown to perform remarkably well on long sequence modeling tasks, and have the added benefits of fast parallelizable training and RNN-like fast inference. However, while SSMs are superficially similar to RNNs, there are important differences that make it unclear where their performance boost over RNNs comes from. We show that careful design of deep RNNs using standard signal propagation arguments can recover the impressive performance of deep SSMs on long-range reasoning tasks, while matching their training speed. To achieve this, we analyze and ablate a series of changes to standard RNNs including linearizing and diagonalizing the recurrence, using better parameterizations and initializations, and ensuring careful normalization of the forward pass. Our results provide new insights on the origins of the impressive performance of deep SSMs, and introduce an RNN block called the Linear Recurrent Unit (or LRU) that matches both their performance on the Long Range Arena benchmark and their computational efficiency.

NeurIPS Conference 2023 Conference Paper

Structured State Space Models for In-Context Reinforcement Learning

  • Chris Lu
  • Yannick Schroecker
  • Albert Gu
  • Emilio Parisotto
  • Jakob Foerster
  • Satinder Singh
  • Feryal Behbahani

Structured state space sequence (S4) models have recently achieved state-of-the-art performance on long-range sequence modeling tasks. These models also have fast inference speeds and parallelisable training, making them potentially useful in many reinforcement learning settings. We propose a modification to a variant of S4 that enables us to initialise and reset the hidden state in parallel, allowing us to tackle reinforcement learning tasks. We show that our modified architecture runs asymptotically faster than Transformers in sequence length and performs better than RNNs on a simple memory-based task. We evaluate our modified architecture on a set of partially-observable environments and find that, in practice, our model outperforms RNNs while also running over five times faster. Then, by leveraging the model’s ability to handle long-range sequences, we achieve strong performance on a challenging meta-learning task in which the agent is given a randomly-sampled continuous control environment, combined with a randomly-sampled linear projection of the environment's observations and actions. Furthermore, we show the resulting model can adapt to out-of-distribution held-out tasks. Overall, the results presented in this paper show that structured state space models are fast and performant for in-context reinforcement learning tasks. We provide code at https://github.com/luchris429/s5rl.

NeurIPS Conference 2022 Conference Paper

Diagonal State Spaces are as Effective as Structured State Spaces

  • Ankit Gupta
  • Albert Gu
  • Jonathan Berant

Modeling long range dependencies in sequential data is a fundamental step towards attaining human-level performance in many modalities such as text, vision, audio and video. While attention-based models are a popular and effective choice in modeling short-range interactions, their performance on tasks requiring long range reasoning has been largely inadequate. In an exciting result, Gu et al. (ICLR 2022) proposed the $\textit{Structured State Space}$ (S4) architecture delivering large gains over state-of-the-art models on several long-range tasks across various modalities. The core proposition of S4 is the parameterization of state matrices via a diagonal plus low rank structure, allowing efficient computation. In this work, we show that one can match the performance of S4 even without the low rank correction and thus assuming the state matrices to be diagonal. Our $\textit{Diagonal State Space}$ (DSS) model matches the performance of S4 on Long Range Arena tasks, speech classification on Speech Commands dataset, while being conceptually simpler and straightforward to implement.

ICLR Conference 2022 Conference Paper

Efficiently Modeling Long Sequences with Structured State Spaces

  • Albert Gu
  • Karan Goel
  • Christopher Ré

A central goal of sequence modeling is designing a single principled model that can address sequence data across a range of modalities and tasks, particularly on long-range dependencies. Although conventional models including RNNs, CNNs, and Transformers have specialized variants for capturing long dependencies, they still struggle to scale to very long sequences of $10000$ or more steps. A promising recent approach proposed modeling sequences by simulating the fundamental state space model (SSM) \( x'(t) = Ax(t) + Bu(t), y(t) = Cx(t) + Du(t) \), and showed that for appropriate choices of the state matrix \( A \), this system could handle long-range dependencies mathematically and empirically. However, this method has prohibitive computation and memory requirements, rendering it infeasible as a general sequence modeling solution. We propose the Structured State Space sequence model (S4) based on a new parameterization for the SSM, and show that it can be computed much more efficiently than prior approaches while preserving their theoretical strengths. Our technique involves conditioning \( A \) with a low-rank correction, allowing it to be diagonalized stably and reducing the SSM to the well-studied computation of a Cauchy kernel. S4 achieves strong empirical results across a diverse range of established benchmarks, including (i) 91\% accuracy on sequential CIFAR-10 with no data augmentation or auxiliary losses, on par with a larger 2-D ResNet, (ii) substantially closing the gap to Transformers on image and language modeling tasks, while performing generation $60\times$ faster (iii) SoTA on every task from the Long Range Arena benchmark, including solving the challenging Path-X task of length 16k that all prior work fails on, while being as efficient as all competitors.

ICML Conference 2022 Conference Paper

It's Raw! Audio Generation with State-Space Models

  • Karan Goel
  • Albert Gu
  • Chris Donahue
  • Christopher Ré

Developing architectures suitable for modeling raw audio is a challenging problem due to the high sampling rates of audio waveforms. Standard sequence modeling approaches like RNNs and CNNs have previously been tailored to fit the demands of audio, but the resultant architectures make undesirable computational tradeoffs and struggle to model waveforms effectively. We propose SaShiMi, a new multi-scale architecture for waveform modeling built around the recently introduced S4 model for long sequence modeling. We identify that S4 can be unstable during autoregressive generation, and provide a simple improvement to its parameterization by drawing connections to Hurwitz matrices. SaShiMi yields state-of-the-art performance for unconditional waveform generation in the autoregressive setting. Additionally, SaShiMi improves non-autoregressive generation performance when used as the backbone architecture for a diffusion model. Compared to prior architectures in the autoregressive generation setting, SaShiMi generates piano and speech waveforms which humans find more musical and coherent respectively, e.g., 2$\times$ better mean opinion scores than WaveNet on an unconditional speech generation task. On a music generation task, SaShiMi outperforms WaveNet on density estimation and speed at both training and inference even when using 3$\times$ fewer parameters.

NeurIPS Conference 2022 Conference Paper

On the Parameterization and Initialization of Diagonal State Space Models

  • Albert Gu
  • Karan Goel
  • Ankit Gupta
  • Christopher Ré

State space models (SSM) have recently been shown to be very effective as a deep learning layer as a promising alternative to sequence models such as RNNs, CNNs, or Transformers. The first version to show this potential was the S4 model, which is particularly effective on tasks involving long-range dependencies by using a prescribed state matrix called the HiPPO matrix. While this has an interpretable mathematical mechanism for modeling long dependencies, it also requires a custom representation and algorithm that makes the model difficult to understand and implement. On the other hand, a recent variant of S4 called DSS showed that restricting the state matrix to be fully diagonal can still preserve the performance of the original model when using a specific initialization based on approximating S4's matrix. This work seeks to systematically understand how to parameterize and initialize diagonal state space models. While it follows from classical results that almost all SSMs have an equivalent diagonal form, we show that the initialization is critical for performance. First, we explain why DSS works mathematically, as the diagonal approximation to S4 surprisingly recovers the same dynamics in the limit of infinite state dimension. We then systematically describe various design choices in parameterizing and computing diagonal SSMs, and perform a controlled empirical study ablating the effects of these choices. Our final model S4D is a simple diagonal version of S4 whose kernel computation requires just 3 lines of code and performs comparably to S4 in almost all settings, with state-of-the-art results in image, audio, and medical time-series domains, and 85\% average on the Long Range Arena benchmark.
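The abstract's remark about a very short kernel computation can be illustrated with a minimal sketch. It is not the authors' code: it assumes a zero-order-hold discretization, takes only the real part of the sum, and uses hypothetical parameter values and function names (`diagonal_ssm_kernel`, `diagonal_ssm_recurrence`) chosen purely for illustration. It computes the convolution kernel of a diagonal SSM and checks it against the step-by-step recurrence.

```python
import cmath
import math


def discretize_zoh(A, B, dt):
    """Zero-order hold for a diagonal system: A_bar = exp(dt*A), B_bar = (A_bar - 1)/A * B."""
    A_bar = [cmath.exp(dt * a) for a in A]
    B_bar = [(ab - 1.0) / a * b for ab, a, b in zip(A_bar, A, B)]
    return A_bar, B_bar


def diagonal_ssm_kernel(A, B, C, dt, length):
    """Kernel K_l = Re(sum_n C_n * A_bar_n**l * B_bar_n), one value per time lag l."""
    A_bar, B_bar = discretize_zoh(A, B, dt)
    return [
        sum(c * (ab ** lag) * bb for c, ab, bb in zip(C, A_bar, B_bar)).real
        for lag in range(length)
    ]


def diagonal_ssm_recurrence(A, B, C, dt, u):
    """Same map computed step by step: x_k = A_bar * x_{k-1} + B_bar * u_k, y_k = Re(C . x_k)."""
    A_bar, B_bar = discretize_zoh(A, B, dt)
    x = [0j] * len(A)
    ys = []
    for u_k in u:
        x = [ab * xn + bb * u_k for ab, xn, bb in zip(A_bar, x, B_bar)]
        ys.append(sum(c * xn for c, xn in zip(C, x)).real)
    return ys


# Hypothetical diagonal parameters (state size 4), for illustration only.
A = [complex(-0.5, math.pi * n) for n in range(4)]
B = [1.0 + 0j] * 4
C = [complex(1.0, -0.5 * n) for n in range(4)]
dt = 0.05
u = [1.0, 0.0, -2.0, 0.5, 3.0, -1.0]

K = diagonal_ssm_kernel(A, B, C, dt, len(u))
y_conv = [sum(K[j] * u[k - j] for j in range(k + 1)) for k in range(len(u))]
y_rec = diagonal_ssm_recurrence(A, B, C, dt, u)
```

Because the state matrix is diagonal, each state coordinate evolves independently, so the kernel reduces to a sum of scaled powers of the discretized diagonal entries. In this sketch `y_conv` and `y_rec` agree up to floating-point error.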

NeurIPS Conference 2022 Conference Paper

S4ND: Modeling Images and Videos as Multidimensional Signals with State Spaces

  • Eric Nguyen
  • Karan Goel
  • Albert Gu
  • Gordon Downs
  • Preey Shah
  • Tri Dao
  • Stephen Baccus
  • Christopher Ré

Visual data such as images and videos are typically modeled as discretizations of inherently continuous, multidimensional signals. Existing continuous-signal models attempt to exploit this fact by modeling the underlying signals of visual (e.g., image) data directly. However, these models have not yet been able to achieve competitive performance on practical vision tasks such as large-scale image and video classification. Building on a recent line of work on deep state space models (SSMs), we propose S4ND, a new multidimensional SSM layer that extends the continuous-signal modeling ability of SSMs to multidimensional data including images and videos. We show that S4ND can model large-scale visual data in $1$D, $2$D, and $3$D as continuous multidimensional signals and demonstrates strong performance by simply swapping Conv2D and self-attention layers with S4ND layers in existing state-of-the-art models. On ImageNet-1k, S4ND exceeds the performance of a Vision Transformer baseline by $1.5\%$ when training with a $1$D sequence of patches, and matches ConvNeXt when modeling images in $2$D. For videos, S4ND improves on an inflated $3$D ConvNeXt in activity classification on HMDB-51 by $4\%$. S4ND implicitly learns global, continuous convolutional kernels that are resolution invariant by construction, providing an inductive bias that enables generalization across multiple resolutions. By developing a simple bandlimiting modification to S4 to overcome aliasing, S4ND achieves strong zero-shot (unseen at training time) resolution performance, outperforming a baseline Conv2D by $40\%$ on CIFAR-10 when trained on $8 \times 8$ and tested on $32 \times 32$ images. When trained with progressive resizing, S4ND comes within $\sim 1\%$ of a high-resolution model while training $22\%$ faster.

ICML Conference 2021 Conference Paper

Catformer: Designing Stable Transformers via Sensitivity Analysis

  • Jared Quincy Davis
  • Albert Gu
  • Krzysztof Choromanski
  • Tri Dao
  • Christopher Ré
  • Chelsea Finn
  • Percy Liang

Transformer architectures are widely used, but training them is non-trivial, requiring custom learning rate schedules, scaling terms, residual connections, careful placement of submodules such as normalization, and so on. In this paper, we improve upon recent analysis of Transformers and formalize a notion of sensitivity to capture the difficulty of training. Sensitivity characterizes how the variance of activation and gradient norms change in expectation when parameters are randomly perturbed. We analyze the sensitivity of previous Transformer architectures and design a new architecture, the Catformer, which replaces residual connections or RNN-based gating mechanisms with concatenation. We prove that Catformers are less sensitive than other Transformer variants and demonstrate that this leads to more stable training. On DMLab30, a suite of high-dimension reinforcement tasks, Catformer outperforms other transformers, including Gated Transformer-XL—the state-of-the-art architecture designed to address stability—by 13%.

NeurIPS Conference 2021 Conference Paper

Combining Recurrent, Convolutional, and Continuous-time Models with Linear State Space Layers

  • Albert Gu
  • Isys Johnson
  • Karan Goel
  • Khaled Saab
  • Tri Dao
  • Atri Rudra
  • Christopher Ré

Recurrent neural networks (RNNs), temporal convolutions, and neural differential equations (NDEs) are popular families of deep learning models for time-series data, each with unique strengths and tradeoffs in modeling power and computational efficiency. We introduce a simple sequence model inspired by control systems that generalizes these approaches while addressing their shortcomings. The Linear State-Space Layer (LSSL) maps a sequence $u \mapsto y$ by simply simulating a linear continuous-time state-space representation $\dot{x} = Ax + Bu, y = Cx + Du$. Theoretically, we show that LSSL models are closely related to the three aforementioned families of models and inherit their strengths. For example, they generalize convolutions to continuous-time, explain common RNN heuristics, and share features of NDEs such as time-scale adaptation. We then incorporate and generalize recent theory on continuous-time memorization to introduce a trainable subset of structured matrices $A$ that endow LSSLs with long-range memory. Empirically, stacking LSSL layers into a simple deep neural network obtains state-of-the-art results across time series benchmarks for long dependencies in sequential image classification, real-world healthcare regression tasks, and speech. On a difficult speech classification task with length-16000 sequences, LSSL outperforms prior approaches by 24 accuracy points, and even outperforms baselines that use hand-crafted features on 100x shorter sequences.
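The link between the recurrent and convolutional views described above can be illustrated with a minimal sketch. It is not the LSSL implementation: it assumes a scalar state, a bilinear discretization of $\dot{x} = Ax + Bu, y = Cx + Du$, and hypothetical parameter values and function names. It simulates the discretized system once as a recurrence and once as a causal convolution, and the two outputs coincide.

```python
def discretize_bilinear(a, b, dt):
    """Bilinear (Tustin) discretization of the scalar system x' = a*x + b*u."""
    denom = 1.0 - dt * a / 2.0
    a_bar = (1.0 + dt * a / 2.0) / denom
    b_bar = dt * b / denom
    return a_bar, b_bar


def run_recurrent(a_bar, b_bar, c, d, u):
    """Step x_k = a_bar * x_{k-1} + b_bar * u_k and emit y_k = c * x_k + d * u_k."""
    x, ys = 0.0, []
    for u_k in u:
        x = a_bar * x + b_bar * u_k
        ys.append(c * x + d * u_k)
    return ys


def run_convolutional(a_bar, b_bar, c, d, u):
    """Same map as a causal convolution with kernel K_j = c * a_bar**j * b_bar."""
    kernel = [c * (a_bar ** j) * b_bar for j in range(len(u))]
    ys = []
    for k in range(len(u)):
        acc = sum(kernel[j] * u[k - j] for j in range(k + 1))
        ys.append(acc + d * u[k])
    return ys


# Hypothetical values chosen only for illustration.
a, b, c, d, dt = -0.5, 1.0, 2.0, 0.1, 0.1
a_bar, b_bar = discretize_bilinear(a, b, dt)
u = [1.0, 0.0, -2.0, 0.5, 3.0]
y_rec = run_recurrent(a_bar, b_bar, c, d, u)
y_conv = run_convolutional(a_bar, b_bar, c, d, u)
```

Unrolling the recurrence from a zero initial state gives $x_k = \sum_{j=0}^{k} \bar{a}^{j} \bar{b}\, u_{k-j}$, which is why the convolution with kernel $K_j = c\,\bar{a}^{j}\bar{b}$ reproduces the recurrent output. The recurrent form suits step-by-step inference, while the convolutional form can be evaluated over the whole sequence at once.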

ICML Conference 2021 Conference Paper

HoroPCA: Hyperbolic Dimensionality Reduction via Horospherical Projections

  • Ines Chami
  • Albert Gu
  • Dat Nguyen
  • Christopher Ré

This paper studies Principal Component Analysis (PCA) for data lying in hyperbolic spaces. Given directions, PCA relies on: (1) a parameterization of subspaces spanned by these directions, (2) a method of projection onto subspaces that preserves information in these directions, and (3) an objective to optimize, namely the variance explained by projections. We generalize each of these concepts to the hyperbolic space and propose HoroPCA, a method for hyperbolic dimensionality reduction. By focusing on the core problem of extracting principal directions, HoroPCA theoretically better preserves information in the original data such as distances, compared to previous generalizations of PCA. Empirically, we validate that HoroPCA outperforms existing dimensionality reduction methods, significantly reducing error in distance preservation. As a data whitening method, it improves downstream classification by up to 3.9% compared to methods that don’t use whitening. Finally, we show that HoroPCA can be used to visualize hyperbolic data in two dimensions.

ICLR Conference 2021 Conference Paper

Model Patching: Closing the Subgroup Performance Gap with Data Augmentation

  • Karan Goel
  • Albert Gu
  • Yixuan Li
  • Christopher Ré

Classifiers in machine learning are often brittle when deployed. Particularly concerning are models with inconsistent performance on specific subgroups of a class, e.g., exhibiting disparities in skin cancer classification in the presence or absence of a spurious bandage. To mitigate these performance differences, we introduce model patching, a two-stage framework for improving robustness that encourages the model to be invariant to subgroup differences, and focus on class information shared by subgroups. Model patching first models subgroup features within a class and learns semantic transformations between them, and then trains a classifier with data augmentations that deliberately manipulate subgroup features. We instantiate model patching with CAMEL, which (1) uses a CycleGAN to learn the intra-class, inter-subgroup augmentations, and (2) balances subgroup performance using a theoretically-motivated subgroup consistency regularizer, accompanied by a new robust objective. We demonstrate CAMEL’s effectiveness on 3 benchmark datasets, with reductions in robust error of up to 33% relative to the best baseline. Lastly, CAMEL successfully patches a model that fails due to spurious features on a real-world skin cancer dataset.

NeurIPS Conference 2020 Conference Paper

From Trees to Continuous Embeddings and Back: Hyperbolic Hierarchical Clustering

  • Ines Chami
  • Albert Gu
  • Vaggos Chatziafratis
  • Christopher Ré

Similarity-based Hierarchical Clustering (HC) is a classical unsupervised machine learning algorithm that has traditionally been solved with heuristic algorithms like Average-Linkage. Recently, Dasgupta reframed HC as a discrete optimization problem by introducing a global cost function measuring the quality of a given tree. In this work, we provide the first continuous relaxation of Dasgupta's discrete optimization problem with provable quality guarantees. The key idea of our method, HypHC, is showing a direct correspondence from discrete trees to continuous representations (via the hyperbolic embeddings of their leaf nodes) and back (via a decoding algorithm that maps leaf embeddings to a dendrogram), allowing us to search the space of discrete binary trees with continuous optimization. Building on analogies between trees and hyperbolic space, we derive a continuous analogue for the notion of lowest common ancestor, which leads to a continuous relaxation of Dasgupta's discrete objective. We can show that after decoding, the global minimizer of our continuous relaxation yields a discrete tree with a (1+eps)-factor approximation for Dasgupta's optimal tree, where eps can be made arbitrarily small and controls optimization challenges. We experimentally evaluate HypHC on a variety of HC benchmarks and find that even approximate solutions found with gradient descent achieve better clustering quality than agglomerative heuristics or other gradient-based algorithms. Finally, we highlight the flexibility of HypHC using end-to-end training in a downstream classification task.

NeurIPS Conference 2020 Conference Paper

HiPPO: Recurrent Memory with Optimal Polynomial Projections

  • Albert Gu
  • Tri Dao
  • Stefano Ermon
  • Atri Rudra
  • Christopher Ré

A central problem in learning from sequential data is representing cumulative history in an incremental fashion as more data is processed. We introduce a general framework (HiPPO) for the online compression of continuous signals and discrete time series by projection onto polynomial bases. Given a measure that specifies the importance of each time step in the past, HiPPO produces an optimal solution to a natural online function approximation problem. As special cases, our framework yields a short derivation of the recent Legendre Memory Unit (LMU) from first principles, and generalizes the ubiquitous gating mechanism of recurrent neural networks such as GRUs. This formal framework yields a new memory update mechanism (HiPPO-LegS) that scales through time to remember all history, avoiding priors on the timescale. HiPPO-LegS enjoys the theoretical benefits of timescale robustness, fast updates, and bounded gradients. By incorporating the memory dynamics into recurrent neural networks, HiPPO RNNs can empirically capture complex temporal dependencies. On the benchmark permuted MNIST dataset, HiPPO-LegS sets a new state-of-the-art accuracy of 98.3%. Finally, on a novel trajectory classification task testing robustness to out-of-distribution timescales and missing data, HiPPO-LegS outperforms RNN and neural ODE baselines by 25-40% accuracy.

ICML Conference 2020 Conference Paper

Improving the Gating Mechanism of Recurrent Neural Networks

  • Albert Gu
  • Çaglar Gülçehre
  • Thomas Paine
  • Matthew Hoffman
  • Razvan Pascanu

Gating mechanisms are widely used in neural network models, where they allow gradients to backpropagate easily through depth or time. However, their saturation property introduces problems of its own. For example, in recurrent models these gates need to have outputs near 1 to propagate information over long time-delays, which requires them to operate in their saturation regime and hinders gradient-based learning of the gate mechanism. We address this problem by deriving two synergistic modifications to the standard gating mechanism that are easy to implement, introduce no additional hyperparameters, and improve learnability of the gates when they are close to saturation. We show how these changes are related to and improve on alternative recently proposed gating mechanisms such as chrono-initialization and Ordered Neurons. Empirically, our simple gating mechanisms robustly improve the performance of recurrent models on a range of applications, including synthetic memorization tasks, sequential image classification, language modeling, and reinforcement learning, particularly when long-term dependencies are involved.

ICLR Conference 2020 Conference Paper

Kaleidoscope: An Efficient, Learnable Representation For All Structured Linear Maps

  • Tri Dao
  • Nimit Sharad Sohoni
  • Albert Gu
  • Matthew Eichhorn
  • Amit Blonder
  • Megan Leszczynski
  • Atri Rudra
  • Christopher Ré

Modern neural network architectures use structured linear transformations, such as low-rank matrices, sparse matrices, permutations, and the Fourier transform, to improve inference speed and reduce memory usage compared to general linear maps. However, choosing which of the myriad structured transformations to use (and its associated parameterization) is a laborious task that requires trading off speed, space, and accuracy. We consider a different approach: we introduce a family of matrices called kaleidoscope matrices (K-matrices) that provably capture any structured matrix with near-optimal space (parameter) and time (arithmetic operation) complexity. We empirically validate that K-matrices can be automatically learned within end-to-end pipelines to replace hand-crafted procedures, in order to improve model quality. For example, replacing channel shuffles in ShuffleNet improves classification accuracy on ImageNet by up to 5%. K-matrices can also simplify hand-engineered pipelines---we replace filter bank feature computation in speech data preprocessing with a learnable kaleidoscope layer, resulting in only 0.4% loss in accuracy on the TIMIT speech recognition task. In addition, K-matrices can capture latent structure in models: for a challenging permuted image classification task, adding a K-matrix to a standard convolutional architecture can enable learning the latent permutation and improve accuracy by over 8 points. We provide a practically efficient implementation of our approach, and use K-matrices in a Transformer network to attain 36% faster end-to-end inference speed on a language translation task.

NeurIPS Conference 2020 Conference Paper

No Subclass Left Behind: Fine-Grained Robustness in Coarse-Grained Classification Problems

  • Nimit Sohoni
  • Jared Dunnmon
  • Geoffrey Angus
  • Albert Gu
  • Christopher Ré

In real-world classification tasks, each class often comprises multiple finer-grained "subclasses." As the subclass labels are frequently unavailable, models trained using only the coarser-grained class labels often exhibit highly variable performance across different subclasses. This phenomenon, known as hidden stratification, has important consequences for models deployed in safety-critical applications such as medicine. We propose GEORGE, a method to both measure and mitigate hidden stratification even when subclass labels are unknown. We first observe that unlabeled subclasses are often separable in the feature space of deep models, and exploit this fact to estimate subclass labels for the training data via clustering techniques. We then use these approximate subclass labels as a form of noisy supervision in a distributionally robust optimization objective. We theoretically characterize the performance of GEORGE in terms of the worst-case generalization error across any subclass. We empirically validate GEORGE on a mix of real-world and benchmark image classification datasets, and show that our approach boosts worst-case subclass accuracy by up to 15 percentage points compared to standard training techniques, without requiring any information about the subclasses.

ICML Conference 2019 Conference Paper

A Kernel Theory of Modern Data Augmentation

  • Tri Dao
  • Albert Gu
  • Alexander Ratner
  • Virginia Smith
  • Christopher De Sa
  • Christopher Ré

Data augmentation, a technique in which a training set is expanded with class-preserving transformations, is ubiquitous in modern machine learning pipelines. In this paper, we seek to establish a theoretical framework for understanding data augmentation. We approach this from two directions: First, we provide a general model of augmentation as a Markov process, and show that kernels appear naturally with respect to this model, even when we do not employ kernel classification. Next, we analyze more directly the effect of augmentation on kernel classifiers, showing that data augmentation can be approximated by first-order feature averaging and second-order variance regularization components. These frameworks both serve to illustrate the ways in which data augmentation affects the downstream learning model, and the resulting analyses provide novel connections between prior work in invariant kernels, tangent propagation, and robust optimization. Finally, we provide several proof-of-concept applications showing that our theory can be useful for accelerating machine learning workflows, such as reducing the amount of computation needed to train using augmented data, and predicting the utility of a transformation prior to training.

ICML Conference 2019 Conference Paper

Learning Fast Algorithms for Linear Transforms Using Butterfly Factorizations

  • Tri Dao
  • Albert Gu
  • Matthew Eichhorn
  • Atri Rudra
  • Christopher Ré

Fast linear transforms are ubiquitous in machine learning, including the discrete Fourier transform, discrete cosine transform, and other structured transformations such as convolutions. All of these transforms can be represented by dense matrix-vector multiplication, yet each has a specialized and highly efficient (subquadratic) algorithm. We ask to what extent hand-crafting these algorithms and implementations is necessary, what structural prior they encode, and how much knowledge is required to automatically learn a fast algorithm for a provided structured transform. Motivated by a characterization of fast matrix-vector multiplication as products of sparse matrices, we introduce a parameterization of divide-and-conquer methods that is capable of representing a large class of transforms. This generic formulation can automatically learn an efficient algorithm for many important transforms; for example, it recovers the $O(N \log N)$ Cooley-Tukey FFT algorithm to machine precision, for dimensions $N$ up to $1024$. Furthermore, our method can be incorporated as a lightweight replacement of generic matrices in machine learning pipelines to learn efficient and compressible transformations. On a standard task of compressing a single hidden-layer network, our method exceeds the classification accuracy of unconstrained matrices on CIFAR-10 by 3.9 points, the first time a structured approach has done so, with 4X faster inference speed and 40X fewer parameters.

SODA Conference 2018 Conference Paper

A Two-pronged Progress in Structured Dense Matrix Vector Multiplication

  • Christopher De Sa
  • Albert Gu
  • Rohan Puttagunta
  • Christopher Ré
  • Atri Rudra

Matrix-vector multiplication is one of the most fundamental computing primitives. Given an $N \times N$ matrix $A$ and a vector $b$ over a field $\mathbb{F}$, it is known that in the worst case $\Theta(N^2)$ operations over $\mathbb{F}$ are needed to compute $Ab$. Many types of structured matrices do admit faster multiplication. However, even given a matrix $A$ that is known to have this property, it is hard in general to recover a representation of $A$ exposing the actual fast multiplication algorithm. Additionally, it is not known in general whether the inverses of such structured matrices can be computed or multiplied quickly. A broad question is thus to identify classes of structured dense matrices that can be represented with $O(N)$ parameters, and for which matrix-vector multiplication (and ideally other operations such as solvers) can be performed in a sub-quadratic number of operations. One such class of structured matrices that admit near-linear matrix-vector multiplication are the orthogonal polynomial transforms whose rows correspond to a family of orthogonal polynomials. Other well known classes include the Toeplitz, Hankel, Vandermonde, Cauchy matrices and their extensions (e.g., confluent Cauchy-like matrices) that are all special cases of a low displacement rank property. In this paper, we make progress on two fronts: 1. We introduce the notion of recurrence width of matrices. For matrices $A$ with constant recurrence width, we design algorithms to compute both $Ab$ and $A^T b$ with a near-linear number of operations. This notion of width is finer than all the above classes of structured matrices and thus we can compute near-linear matrix-vector multiplication for all of them using the same core algorithm. Furthermore, we show that it is possible to solve the harder problems of recovering the structured parameterization of a matrix with low recurrence width, and computing matrix-vector product with its inverse in near-linear time. 2. We additionally adapt our algorithm to a matrix-vector multiplication algorithm for a much more general class of matrices with displacement structure: those with low displacement rank with respect to quasiseparable matrices. This result is a novel connection between matrices with displacement structure and those with rank structure, two large but previously separate classes of structured matrices. This class includes Toeplitz-plus-Hankel-like matrices, the Discrete Trigonometric Transforms, and more, and captures all previously known matrices with displacement structure under a unified parameterization and algorithm. Our work unifies, generalizes, and simplifies existing state-of-the-art results in structured matrix-vector multiplication. Finally, we show how applications in areas such as multipoint evaluations of multivariate polynomials can be reduced to problems involving low recurrence width matrices.

NeurIPS Conference 2018 Conference Paper

Learning Compressed Transforms with Low Displacement Rank

  • Anna Thomas
  • Albert Gu
  • Tri Dao
  • Atri Rudra
  • Christopher Ré

The low displacement rank (LDR) framework for structured matrices represents a matrix through two displacement operators and a low-rank residual. Existing use of LDR matrices in deep learning has applied fixed displacement operators encoding forms of shift invariance akin to convolutions. We introduce a rich class of LDR matrices with more general displacement operators, and explicitly learn over both the operators and the low-rank component. This class generalizes several previous constructions while preserving compression and efficient computation. We prove bounds on the VC dimension of multi-layer neural networks with structured weight matrices and show empirically that our compact parameterization can reduce the sample complexity of learning. When replacing weight layers in fully-connected, convolutional, and recurrent neural networks for image classification and language modeling tasks, our new classes exceed the accuracy of existing compression approaches, and on some tasks even outperform general unstructured layers while using more than 20x fewer parameters.

ICML Conference 2018 Conference Paper

Representation Tradeoffs for Hyperbolic Embeddings

  • Frederic Sala
  • Christopher De Sa
  • Albert Gu
  • Christopher Ré

Hyperbolic embeddings offer excellent quality with few dimensions when embedding hierarchical data structures. We give a combinatorial construction that embeds trees into hyperbolic space with arbitrarily low distortion without optimization. On WordNet, this algorithm obtains a mean-average-precision of 0.989 with only two dimensions, outperforming existing work by 0.11 points. We provide bounds characterizing the precision-dimensionality tradeoff inherent in any hyperbolic embedding. To embed general metric spaces, we propose a hyperbolic generalization of multidimensional scaling (h-MDS). We show how to perform exact recovery of hyperbolic points from distances, provide a perturbation analysis, and give a recovery result that enables us to reduce dimensionality. Finally, we extract lessons from the algorithms and theory above to design a scalable PyTorch-based implementation that can handle incomplete information.