Arrow Research search

Author name cluster

Christopher Ré

Possible papers associated with this exact author name in Arrow. This page groups case-insensitive exact name matches and is not a full identity disambiguation profile.

110 papers
2 author rows

Possible papers

ICLR Conference 2025 Conference Paper

Aioli: A Unified Optimization Framework for Language Model Data Mixing

  • Mayee F. Chen
  • Michael Y. Hu
  • Nicholas Lourie
  • Kyunghyun Cho
  • Christopher Ré

Language model performance depends on identifying the optimal mixture of data groups to train on (e.g., law, code, math). Prior work has proposed a diverse set of methods to efficiently learn mixture proportions, ranging from fitting regression models over training runs to dynamically updating proportions throughout training. Surprisingly, we find that no existing method consistently outperforms a simple stratified sampling baseline in terms of average test perplexity. To understand this inconsistency, we unify existing methods into a standard framework, showing they are equivalent to solving a common optimization problem: minimize average loss subject to a method-specific mixing law---an implicit assumption on the relationship between loss and mixture proportions. This framework suggests that measuring the fidelity of a method's mixing law can offer insights into its performance. Empirically, we find that existing methods set their mixing law parameters inaccurately, resulting in the inconsistent mixing performance we observe. Using this insight, we derive a new online method named Aioli, which directly estimates the mixing law parameters throughout training and uses them to dynamically adjust proportions. Empirically, Aioli outperforms stratified sampling on 6 out of 6 datasets by an average of 0.27 test perplexity points, whereas existing methods fail to consistently beat stratified sampling, doing up to 6.9 points worse. Moreover, in a practical setting where proportions are learned on shorter runs due to computational constraints, Aioli can dynamically adjust these proportions over the full training run, consistently improving performance over existing methods by up to 12.012 test perplexity points.

ICML Conference 2025 Conference Paper

An Architecture Search Framework for Inference-Time Techniques

  • Jon Saad-Falcon
  • Adrian Gamarra Lafuente
  • Shlok Natarajan
  • Nahum Maru
  • Hristo Todorov
  • Etash Kumar Guha
  • Estefany Kelly Buchanan
  • Mayee F. Chen

Inference-time techniques, such as repeated sampling or iterative revisions, are emerging as powerful ways to enhance large language models (LLMs) at test time. However, best practices for developing systems that combine these techniques remain underdeveloped due to our limited understanding of the utility of each technique across models and tasks, the interactions between them, and the massive search space for combining them. To address these challenges, we introduce Archon, a modular and automated framework for optimizing the process of selecting and combining inference-time techniques and LLMs. Given a compute budget and a set of available LLMs, Archon explores a large design space to discover optimized configurations tailored to target benchmarks. It can design custom or general-purpose architectures that advance the Pareto frontier of accuracy vs. maximum token budget compared to top-performing baselines. Across instruction-following, reasoning, and coding tasks, we show that Archon can leverage additional inference compute budget to design systems that outperform frontier models such as OpenAI’s o1, GPT-4o, and Claude 3.5 Sonnet by an average of 15.1%.

ICLR Conference 2025 Conference Paper

Context Clues: Evaluating Long Context Models for Clinical Prediction Tasks on EHR Data

  • Michael Wornow
  • Suhana Bedi
  • Miguel Angel Fuentes Hernandez
  • Ethan Steinberg
  • Jason Alan Fries
  • Christopher Ré
  • Sanmi Koyejo
  • Nigam H. Shah

Foundation Models (FMs) trained on Electronic Health Records (EHRs) have achieved state-of-the-art results on numerous clinical prediction tasks. However, prior EHR FMs typically have context windows of $<$1k tokens, which prevents them from modeling full patient EHRs which can exceed 10k's of events. For making clinical predictions, both model performance and robustness to the unique properties of EHR data are crucial. Recent advancements in subquadratic long-context architectures (e.g. Mamba) offer a promising solution. However, their application to EHR data has not been well-studied. We address this gap by presenting the first systematic evaluation of the effect of context length on modeling EHR data. We find that longer context models improve predictive performance -- our Mamba-based model surpasses the prior state-of-the-art on 9/14 tasks on the EHRSHOT prediction benchmark. Additionally, we measure robustness to three unique, previously underexplored properties of EHR data: (1) the prevalence of "copy-forwarded" diagnoses which create artificial token repetition in EHR sequences; (2) the irregular time intervals between EHR events which can lead to a wide range of timespans within a context window; and (3) the natural increase in disease complexity over time which makes later tokens in the EHR harder to predict than earlier ones. Stratifying our EHRSHOT results, we find that higher levels of each property correlate negatively with model performance (e.g., a 14% higher Brier loss between the least and most irregular patients), but that longer context models are more robust to more extreme levels of these properties. Our work highlights the potential for using long-context architectures to model EHR data, and offers a case study on how to identify and quantify new challenges in modeling sequential data motivated by domains outside of natural language. We release all of our model checkpoints and code.

ICML Conference 2025 Conference Paper

Cost-efficient Collaboration between On-device and Cloud Language Models

  • Avanika Narayan
  • Dan Biderman
  • Sabri Eyuboglu
  • Avner May
  • Scott W. Linderman
  • James Y. Zou
  • Christopher Ré

We investigate an emerging setup in which a small, on-device language model (LM) with access to local data collaborates with a frontier, cloud-hosted LM to solve real-world tasks involving financial, medical, and scientific reasoning over long documents. Can a local-remote collaboration reduce cloud inference costs while preserving quality? First, we consider a naïve collaboration protocol, coined MINION, where the local and remote models simply chat back and forth. Because only the local model ingests the full context, this protocol reduces cloud costs by 30.4x, but recovers only 87% of the performance of the frontier model. We identify two key limitations of this protocol: the local model struggles to (1) follow the remote model’s multi-step instructions and (2) reason over long contexts. Motivated by these observations, we propose MINIONS, a protocol in which the remote model decomposes the task into easier subtasks over shorter chunks of the document, that are executed locally in parallel. MINIONS reduces costs by 5.7$\times$ on average while recovering 97.9% of the remote-only performance. Our analysis reveals several key design choices that influence the trade-off between cost and performance in local-remote systems.

ICML Conference 2025 Conference Paper

KernelBench: Can LLMs Write Efficient GPU Kernels?

  • Anne Ouyang
  • Simon Guo
  • Simran Arora
  • Alex L. Zhang
  • William Hu
  • Christopher Ré
  • Azalia Mirhoseini

Efficient GPU kernels are crucial for building performant machine learning architectures, but writing them is a time-consuming challenge that requires significant expertise; therefore, we explore using language models (LMs) to automate kernel generation. We introduce KernelBench, an open-source framework for evaluating LMs’ ability to write fast and correct kernels on a suite of 250 carefully selected PyTorch ML workloads. KernelBench represents a real-world engineering environment and making progress on the introduced benchmark directly translates to faster practical kernels. We introduce a new evaluation metric $\text{fast}_p$, which measures the percentage of generated kernels that are functionally correct and offer a speedup greater than an adjustable threshold $p$ over baseline. Our experiments across various state-of-the-art models and test-time methods show that frontier reasoning models perform the best out of the box but still fall short overall, matching the PyTorch baseline in less than 20% of the cases. While we show that results can improve by leveraging execution and profiling feedback during iterative refinement, KernelBench remains a challenging benchmark, with its difficulty increasing as we raise speedup threshold $p$.
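
The $\text{fast}_p$ metric described in this abstract can be illustrated with a minimal sketch. The `KernelResult` record, its field names, and the sample timings are hypothetical and chosen only for illustration; they are not KernelBench's actual data structures or API. The sketch returns a fraction in [0, 1] rather than a percentage.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class KernelResult:
    """Hypothetical record for one generated kernel (not the KernelBench API)."""
    correct: bool          # did the kernel pass functional-correctness checks?
    baseline_time: float   # runtime of the baseline implementation
    kernel_time: float     # runtime of the generated kernel


def fast_p(results: List[KernelResult], p: float) -> float:
    """Fraction of kernels that are correct AND have speedup strictly greater than p."""
    if not results:
        return 0.0
    hits = sum(
        1
        for r in results
        if r.correct and (r.baseline_time / r.kernel_time) > p
    )
    return hits / len(results)


# Made-up timings: speedups are 2.0, 1.0, 4.0 (but incorrect), and 0.5.
results = [
    KernelResult(correct=True, baseline_time=2.0, kernel_time=1.0),
    KernelResult(correct=True, baseline_time=1.0, kernel_time=1.0),
    KernelResult(correct=False, baseline_time=4.0, kernel_time=1.0),
    KernelResult(correct=True, baseline_time=1.0, kernel_time=2.0),
]
```

With these made-up results, `fast_p(results, 0.0)` counts every correct kernel, and raising `p` can only shrink the count, which matches the abstract's remark that the benchmark gets harder as the threshold increases.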

ICLR Conference 2025 Conference Paper

LoLCATs: On Low-Rank Linearizing of Large Language Models

  • Michael Zhang
  • Simran Arora
  • Rahul Chalamala
  • Benjamin Frederick Spector
  • Alan Wu
  • Krithik Ramesh
  • Aaryan Singhal
  • Christopher Ré

Recent works show we can linearize large language models (LLMs)—swapping the quadratic attentions of popular Transformer-based LLMs with subquadratic analogs, such as linear attention—avoiding the expensive pretraining costs. However, linearizing LLMs often significantly degrades model quality, still requires training over billions of tokens, and remains limited to smaller 1.3B to 7B LLMs. We thus propose Low-rank Linear Conversion via Attention Transfer (LoLCATs), a simple two-step method that improves LLM linearizing quality with orders of magnitude less memory and compute. We base these steps on two findings. First, we can replace an LLM's softmax attentions with closely-approximating linear attentions, simply by *training* the linear attentions to match their softmax counterparts with an output MSE loss (“attention transfer”). Then, this enables adjusting for approximation errors and recovering LLM quality simply with *low-rank* adaptation (LoRA). LoLCATs significantly improves linearizing quality, training efficiency, and scalability. We significantly reduce the linearizing quality gap and produce state-of-the-art subquadratic LLMs from Llama 3 8B and Mistral 7B v0.1, leading to 20+ points of improvement on 5-shot MMLU. Furthermore, LoLCATs does so with only 0.2% of past methods' model parameters and 0.04-0.2% of their training tokens. Finally, we apply LoLCATs to create the first linearized 70B and 405B LLMs (50$\times$ that of prior work). When compared with prior approaches under the same compute budgets, LoLCATs significantly improves linearizing quality, closing the gap between linearized and original Llama 3.1 70B and 405B LLMs by 77.8% and 78.1% on 5-shot MMLU.

ICLR Conference 2025 Conference Paper

Restructuring Vector Quantization with the Rotation Trick

  • Christopher Fifty
  • Ronald Guenther Junkins
  • Dennis Duan
  • Aniketh Iyengar
  • Jerry Weihong Liu
  • Ehsan Amid
  • Sebastian Thrun
  • Christopher Ré

Vector Quantized Variational AutoEncoders (VQ-VAEs) are designed to compress a continuous input to a discrete latent space and reconstruct it with minimal distortion. They operate by maintaining a set of vectors---often referred to as the codebook---and quantizing each encoder output to the nearest vector in the codebook. However, as vector quantization is non-differentiable, the gradient to the encoder flows _around_ the vector quantization layer rather than _through_ it in a straight-through approximation. This approximation may be undesirable as all information from the vector quantization operation is lost. In this work, we propose a way to propagate gradients through the vector quantization layer of VQ-VAEs. We smoothly transform each encoder output into its corresponding codebook vector via a rotation and rescaling linear transformation that is treated as a constant during backpropagation. As a result, the relative magnitude and angle between encoder output and codebook vector becomes encoded into the gradient as it propagates through the vector quantization layer and back to the encoder. Across 11 different VQ-VAE training paradigms, we find this restructuring improves reconstruction metrics, codebook utilization, and quantization error.

ICLR Conference 2025 Conference Paper

Scaling Laws for Precision

  • Tanishq Kumar
  • Zachary Ankner
  • Benjamin Frederick Spector
  • Blake Bordelon
  • Niklas Muennighoff
  • Mansheej Paul
  • Cengiz Pehlevan
  • Christopher Ré

Low precision training and inference affect both the quality and cost of language models, but current scaling laws do not account for this. In this work, we devise "precision-aware" scaling laws for both training and inference. We propose that training in lower precision reduces the model's "effective parameter count," allowing us to predict the additional loss incurred from training in low precision and post-train quantization. For inference, we find that the degradation introduced by post-training quantization increases as models are trained on more data, eventually making additional pretraining data actively harmful. For training, our scaling laws allow us to predict the loss of a model with different parts in different precisions, and suggest that training larger models in lower precision can be compute optimal. We unify the scaling laws for post and pretraining quantization to arrive at a single functional form that predicts degradation from training and inference in varied precisions. We fit on over 465 pretraining runs and validate our predictions on model sizes up to 1.7B parameters trained on up to 26B tokens.

ICLR Conference 2025 Conference Paper

ThunderKittens: Simple, Fast, and Adorable Kernels

  • Benjamin Frederick Spector
  • Simran Arora
  • Aaryan Singhal
  • Arjun Parthasarathy
  • Daniel Y. Fu
  • Christopher Ré

The challenge of mapping AI architectures to GPU hardware is creating a critical bottleneck in AI progress. Despite substantial efforts, hand-written custom kernels fail to meet their theoretical performance thresholds, even on well-established operations like linear attention. The diverse capabilities of GPUs suggest we might need a wide variety of techniques to achieve high performance. However, our work explores if a small number of key abstractions can drastically simplify the process. We present ThunderKittens (TK), a framework for writing performant AI kernels while remaining easy to use. Our abstractions map to the three levels of the GPU hierarchy: (1) at the warp-level, we provide 16x16 matrix tiles as basic data structures and PyTorch-like operations, (2) at the thread-block level, we provide templates for asynchronously overlapping operations, and (3) at the grid-level, TK helps hide block launch, tear-down, and memory costs. We show the value of TK by providing simple & diverse kernels that match or outperform prior art. We match CuBLAS and FlashAttention-3 on GEMM and attention inference performance and outperform the strongest baselines by 10-40% on attention backwards, $8\times$ on state space models, and $14\times$ on linear attention.

ICLR Conference 2025 Conference Paper

Towards Learning High-Precision Least Squares Algorithms with Sequence Models

  • Jerry Weihong Liu
  • Jessica Grogan
  • Owen M. Dugan
  • Ashish Rao
  • Simran Arora
  • Atri Rudra
  • Christopher Ré

This paper investigates whether sequence models can learn to perform numerical algorithms, e.g. gradient descent, on the fundamental problem of least squares. Our goal is to inherit two properties of standard algorithms from numerical analysis: (1) machine precision, i.e. we want to obtain solutions that are accurate to near floating point error, and (2) numerical generality, i.e. we want them to apply broadly across problem instances. We find that prior approaches using Transformers fail to meet these criteria, and identify limitations present in existing architectures and training procedures. First, we show that softmax Transformers struggle to perform high-precision multiplications, which prevents them from precisely learning numerical algorithms. Second, we identify an alternate class of architectures, comprised entirely of polynomials, that can efficiently represent high-precision gradient descent iterates. Finally, we investigate precision bottlenecks during training and address them via a high-precision training recipe that reduces stochastic gradient noise. Our recipe enables us to train two polynomial architectures, gated convolutions and linear attention, to perform gradient descent iterates on least squares problems. For the first time, we demonstrate the ability to train to near machine precision. Applied iteratively, our models obtain $100,000\times$ lower MSE than standard Transformers trained end-to-end and they incur a $10,000\times$ smaller generalization gap on out-of-distribution problems. We make progress towards end-to-end learning of numerical algorithms for least squares.

NeurIPS Conference 2025 Conference Paper

Weaver: Shrinking the Generation-Verification Gap by Scaling Compute for Verification

  • Jon Saad-Falcon
  • Estefany Kelly Buchanan
  • Mayee Chen
  • Tzu-Heng (Brian) Huang
  • Brendan McLaughlin
  • Tanvir Bhathal
  • Shang Zhu
  • Ben Athiwaratkun

Verifiers can improve language model (LM) capabilities by providing feedback or selecting the best response from a pool of generated candidates. Currently, high-quality verifiers are either unscalable (e.g., humans) or limited in utility (e.g., tools like Lean for formal proofs). While LM judges and reward models have become broadly useful as general-purpose verifiers, a significant performance gap remains between them and oracle verifiers. To help close this gap, we introduce Weaver, a framework for designing a strong verifier by combining multiple weak, imperfect verifiers. First we find that weighted ensembles of verifiers, which typically require learning from labeled data, significantly outperform unweighted combinations due to differences in the verifiers. To reduce the dependency on labeled data, Weaver leverages weak supervision to estimate each verifier’s accuracy and combines their outputs into a unified score that better reflects true response quality. However, directly applying weak supervision algorithms poses several challenges, including inconsistent verifier output formats and handling low-quality verifiers. Weaver addresses these challenges by using dataset statistics to normalize outputs and filter specific verifiers. We study the effectiveness of Weaver in repeated sampling settings, where a model generates multiple candidate responses at test time and a verifier is used to select the correct one. Our evaluations demonstrate that Weaver significantly improves the pass@1 performance across several reasoning and math tasks, achieving o3-mini level accuracy with Llama 3.3 70B Instruct (a much cheaper non-reasoning model) as the generator, and an ensemble of smaller judge and reward models as the verifiers (86.2% average). This gain mirrors the jump achieved between GPT-4o and o3-mini (69.0% vs. 86.7%), which required extensive finetuning and post-training interventions.
To make Weaver more efficient, we train a compact 400M cross-encoder using Weaver's combined output scores. This distilled model retains 98.7% of Weaver's full accuracy while reducing verification compute by up to 99.97%.

ICML Conference 2024 Conference Paper

Benchmarking and Building Long-Context Retrieval Models with LoCo and M2-BERT

  • Jon Saad-Falcon
  • Daniel Y. Fu
  • Simran Arora
  • Neel Guha
  • Christopher Ré

Retrieval pipelines are an integral component of many machine learning systems. However, they perform poorly in domains where documents are long (e.g., 10K tokens or more) and where identifying the relevant document requires synthesizing information across the entire text. Developing long-context retrieval encoders suitable for these domains raises three challenges: (1) how to evaluate long-context retrieval performance, (2) how to pretrain a base language model to represent both short contexts (corresponding to queries) and long contexts (corresponding to documents), and (3) how to finetune this model for retrieval under the batch size limitations imposed by GPU memory constraints. To address these challenges, we first introduce LoCoV1, a 12-task benchmark constructed to measure long-context retrieval where chunking is not possible or not effective. We next present the M2-BERT retrieval encoder, an 80M parameter state-space encoder model built from the Monarch Mixer architecture, capable of scaling to documents up to 32K tokens long. We describe a pretraining data mixture which allows this encoder to process both short and long context sequences, and a finetuning approach that adapts this base model to retrieval with only single-sample batches. Finally, we validate the M2-BERT retrieval encoder on LoCoV1, finding that it outperforms competitive Transformer-based models by at least 22.2 points, despite containing 90× fewer parameters.

ICLR Conference 2024 Conference Paper

Context-Aware Meta-Learning

  • Christopher Fifty
  • Dennis Duan
  • Ronald Guenther Junkins
  • Ehsan Amid
  • Jure Leskovec
  • Christopher Ré
  • Sebastian Thrun

Large Language Models like ChatGPT demonstrate a remarkable capacity to learn new concepts during inference without any fine-tuning. However, visual models trained to detect new objects during inference have been unable to replicate this ability, and instead either perform poorly or require meta-training and/or fine-tuning on similar objects. In this work, we propose a meta-learning algorithm that emulates Large Language Models by learning new visual concepts during inference without fine-tuning. Our approach leverages a frozen pre-trained feature extractor, and analogous to in-context learning, recasts meta-learning as sequence modeling over datapoints with known labels and a test datapoint with an unknown label. On 8 out of 11 meta-learning benchmarks, our approach---without meta-training or fine-tuning---exceeds or matches the state-of-the-art algorithm, P>M>F, which is meta-trained on these benchmarks.

ICLR Conference 2024 Conference Paper

FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores

  • Daniel Y. Fu
  • Hermann Kumbong
  • Eric Nguyen
  • Christopher Ré

Convolution models with long filters have demonstrated state-of-the-art reasoning abilities in many long-sequence tasks but lag behind the most optimized Transformers in wall-clock time. A major bottleneck is the Fast Fourier Transform (FFT)---which allows long convolutions to run in $O(N\log N)$ time in sequence length $N$ but has poor hardware utilization. In this paper, we study how to optimize the FFT convolution. We find two key bottlenecks: the FFT does not effectively use specialized matrix multiply units, and it incurs expensive I/O between layers of the memory hierarchy. In response, we propose FlashFFTConv. FlashFFTConv uses a matrix decomposition that computes the FFT using matrix multiply units and enables kernel fusion for long sequences, reducing I/O. We also present two sparse convolution algorithms---1) partial convolutions and 2) frequency-sparse convolutions---which can be implemented simply by skipping blocks in the matrix decomposition, enabling further opportunities for memory and compute savings. FlashFFTConv speeds up exact FFT convolutions by up to 8.7$\times$ over PyTorch and achieves up to 4.4$\times$ speedup end-to-end. Given the same compute budget, FlashFFTConv allows Hyena-GPT-s to achieve 2.3 points better perplexity and M2-BERT-base to achieve 3.3 points higher GLUE score---matching models with twice the parameter count. FlashFFTConv also achieves 96.1% accuracy on Path-512, a high-resolution vision task where no model had previously achieved better than 50%. Furthermore, partial convolutions enable longer-sequence models---yielding the first DNA model that can process the longest human genes (2.3M base pairs)---and frequency-sparse convolutions speed up pretrained models while maintaining or improving model quality.
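
The $O(N\log N)$ FFT convolution that this abstract takes as its starting point can be sketched in a few lines of NumPy. This is only the textbook baseline, not FlashFFTConv's matrix-decomposition or kernel-fusion method; the function names are illustrative assumptions.

```python
import numpy as np


def fft_conv(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Causal convolution of a length-N signal u with a length-N filter k via FFT.

    Zero-padding to 2N keeps the circular convolution computed by the FFT
    from wrapping around, so the first N outputs equal the linear convolution.
    """
    n = len(u)
    size = 2 * n
    u_f = np.fft.rfft(u, n=size)   # O(N log N)
    k_f = np.fft.rfft(k, n=size)   # O(N log N)
    y = np.fft.irfft(u_f * k_f, n=size)
    return y[:n]


def direct_conv(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """O(N^2) reference: direct linear convolution, truncated to the first N outputs."""
    return np.convolve(u, k)[: len(u)]


rng = np.random.default_rng(0)
u = rng.standard_normal(64)
k = rng.standard_normal(64)
```

Both functions agree up to floating-point error on the sample inputs. The FFT path replaces the quadratic sum with three transforms and an elementwise product, which is the operation the abstract says has poor hardware utilization when implemented naively.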

ICML Conference 2024 Conference Paper

Mechanistic Design and Scaling of Hybrid Architectures

  • Michael Poli
  • Armin W. Thomas
  • Eric Nguyen
  • Pragaash Ponnusamy
  • Björn Deiseroth
  • Kristian Kersting
  • Taiji Suzuki
  • Brian L. Hie

The development of deep learning architectures is a resource-demanding process, due to a vast design space, long prototyping times, and high compute costs associated with at-scale model training and evaluation. We set out to simplify this process by grounding it in an end-to-end mechanistic architecture design (MAD) pipeline, encompassing small-scale capability unit tests predictive of scaling laws. Through a suite of synthetic token manipulation tasks such as compression and recall, designed to probe capabilities, we identify and test new hybrid architectures constructed from a variety of computational primitives. We experimentally validate the resulting architectures via an extensive compute-optimal and a new state-optimal scaling law analysis, training over 500 language models between 70M to 7B parameters. Surprisingly, we find MAD synthetics to correlate with compute-optimal perplexity, enabling accurate evaluation of new architectures via isolated proxy tasks. The new architectures found via MAD, based on simple ideas such as hybridization and sparsity, outperform state-of-the-art Transformer, convolutional, and recurrent architectures (Transformer++, Hyena, Mamba) in scaling, both at compute-optimal budgets and in overtrained regimes. Overall, these results provide evidence that performance on curated synthetic tasks can be predictive of scaling laws, and that an optimal architecture should leverage specialized layers via a hybrid topology.

ICML Conference 2024 Conference Paper

Prospector Heads: Generalized Feature Attribution for Large Models & Data

  • Gautam Machiraju
  • Alexander Derry
  • Arjun D. Desai
  • Neel Guha
  • Amir-Hossein Karimi
  • James Y. Zou
  • Russ B. Altman
  • Christopher Ré

Feature attribution, the ability to localize regions of the input data that are relevant for classification, is an important capability for ML models in scientific and biomedical domains. Current methods for feature attribution, which rely on "explaining" the predictions of end-to-end classifiers, suffer from imprecise feature localization and are inadequate for use with small sample sizes and high-dimensional datasets due to computational challenges. We introduce prospector heads, an efficient and interpretable alternative to explanation-based attribution methods that can be applied to any encoder and any data modality. Prospector heads generalize across modalities through experiments on sequences (text), images (pathology), and graphs (protein structures), outperforming baseline attribution methods by up to 26.3 points in mean localization AUPRC. We also demonstrate how prospector heads enable improved interpretation and discovery of class-specific patterns in input data. Through their high performance, flexibility, and generalizability, prospectors provide a framework for improving trust and transparency for ML models in complex domains.

NeurIPS Conference 2024 Conference Paper

RedPajama: an Open Dataset for Training Large Language Models

  • Maurice Weber
  • Daniel Y. Fu
  • Quentin Anthony
  • Yonatan Oren
  • Shane Adams
  • Anton Alexandrov
  • Xiaozhong Lyu
  • Huu Nguyen

Large language models are increasingly becoming a cornerstone technology in artificial intelligence, the sciences, and society as a whole, yet the optimal strategies for dataset composition and filtering remain largely elusive. Many of the top-performing models lack transparency in their dataset curation and model development processes, posing an obstacle to the development of fully open language models. In this paper, we identify three core data-related challenges that must be addressed to advance open-source language models. These include (1) transparency in model development, including the data curation process, (2) access to large quantities of high-quality data, and (3) availability of artifacts and metadata for dataset curation and analysis. To address these challenges, we release RedPajama-V1, an open reproduction of the LLaMA training dataset. In addition, we release RedPajama-V2, a massive web-only dataset consisting of raw, unfiltered text data together with quality signals and metadata. Together, the RedPajama datasets comprise over 100 trillion tokens spanning multiple domains and with their quality signals facilitate the filtering of data, aiming to inspire the development of numerous new datasets. To date, these datasets have already been used in the training of strong language models used in production, such as Snowflake Arctic, Salesforce's XGen and AI2's OLMo. To provide insight into the quality of RedPajama, we present a series of analyses and ablation studies with decoder-only language models with up to 1.6B parameters. Our findings demonstrate how quality signals for web data can be effectively leveraged to curate high-quality subsets of the dataset, underscoring the potential of RedPajama to advance the development of transparent and high-performing language models at scale.

ICML Conference 2024 Conference Paper

Simple linear attention language models balance the recall-throughput tradeoff

  • Simran Arora
  • Sabri Eyuboglu
  • Michael Zhang
  • Aman Timalsina
  • Silas Alberti
  • James Y. Zou
  • Atri Rudra
  • Christopher Ré

Recent work has shown that attention-based language models excel at "recall", the ability to ground generations in tokens previously seen in context. However, the efficiency of attention-based models is bottlenecked during inference by the KV-cache’s aggressive memory consumption. In this work, we explore whether we can improve language model efficiency (e.g., by reducing memory consumption) without compromising on recall. By applying experiments and theory to a broad set of architectures, we identify a key tradeoff between a model’s recurrent state size and recall ability. We show that efficient alternatives to attention (e.g., H3, Mamba, RWKV) maintain a fixed-size recurrent state, but struggle at recall. We propose BASED, a simple architecture combining linear and sliding window attention. By varying BASED window size and linear attention feature dimension, we can dial the state size and traverse the Pareto frontier of the recall-memory tradeoff curve, recovering the full quality of attention on one end and the small state size of attention-alternatives on the other. We train language models up to 1.3b parameters and show that BASED matches the strongest sub-quadratic models (e.g., Mamba) in perplexity and outperforms them on real-world recall-intensive tasks by 10.36 accuracy points. We further develop IO-aware algorithms that enable BASED to provide 24× higher throughput on language generation than FlashAttention-2, when generating 1024 tokens using 1.3b parameter models. Overall, BASED expands the Pareto frontier of the throughput-recall tradeoff space beyond prior architectures.

NeurIPS Conference 2024 Conference Paper

Smoothie: Label Free Language Model Routing

  • Neel Guha
  • Mayee F. Chen
  • Trevor Chow
  • Ishan S. Khare
  • Christopher Ré

Large language models (LLMs) are increasingly used in applications where LLM inputs may span many different tasks. Recent work has found that the choice of LLM is consequential, and different LLMs may be good for different input samples. Prior approaches have thus explored how engineers might select an LLM to use for each sample (i.e., routing). While existing routing methods mostly require training auxiliary models on human-annotated data, our work explores whether it is possible to perform unsupervised routing. We propose Smoothie, a weak supervision-inspired routing approach that requires no labeled data. Given a set of outputs from different LLMs, Smoothie constructs a latent variable graphical model over embedding representations of observable LLM outputs and unknown “true” outputs. Using this graphical model, we estimate sample-dependent quality scores for each LLM, and route each sample to the LLM with the highest corresponding score. We find that Smoothie's LLM quality-scores correlate with ground-truth model quality (correctly identifying the optimal model on 9/14 tasks), and that Smoothie outperforms baselines for routing by up to 10 points accuracy.

ICML Conference 2024 Conference Paper

State-Free Inference of State-Space Models: The *Transfer Function* Approach

  • Rom N. Parnichkun
  • Stefano Massaroli
  • Alessandro Moro
  • Jimmy T. H. Smith
  • Ramin M. Hasani
  • Mathias Lechner
  • Qi An
  • Christopher Ré

We approach designing a state-space model for deep learning applications through its dual representation, the transfer function, and uncover a highly efficient sequence parallel inference algorithm that is state-free: unlike other proposed algorithms, state-free inference does not incur any significant memory or computational cost with an increase in state size. We achieve this using properties of the proposed frequency domain transfer function parametrization, which enables direct computation of its corresponding convolutional kernel’s spectrum via a single Fast Fourier Transform. Our experimental results across multiple sequence lengths and state sizes illustrate, on average, a 35% training speed improvement over S4 layers – parametrized in time-domain – on the Long Range Arena benchmark, while delivering state-of-the-art downstream performance over other attention-free approaches. Moreover, we report improved perplexity in language modeling over a long convolutional Hyena baseline, by simply introducing our transfer function parametrization. Our code is available at https://github.com/ruke1ire/RTF.

ICLR Conference 2024 Conference Paper

The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry

  • Michael Zhang
  • Kush Bhatia
  • Hermann Kumbong
  • Christopher Ré

Linear attentions have shown promise for improving Transformer efficiency, reducing attention's quadratic complexity to linear in sequence length. This holds exciting promise for (1) training linear Transformers from scratch, (2) finetuned-conversion of task-specific Transformers into linear versions that recover task performance, and (3) pretrained-conversion of Transformers, such as language models, into linear versions readily finetunable on downstream tasks. However, linear attentions often underperform compared to standard softmax attention. To close this performance gap, we study the behaviors of softmax and linear attentions in various train-from-scratch and finetuned-conversion settings. We find prior linear attentions lack key properties of softmax attention tied to good performance: low-entropy (or spiky) weights and dot-product monotonicity. We further observe surprisingly simple feature maps that retain these properties match softmax performance, but are inefficient to compute in linear attention. We thus propose Hedgehog, a learnable linear attention that retains the spiky and monotonic properties of softmax attention while maintaining linear complexity. Hedgehog uses simple, trainable MLPs to produce attention weights mimicking softmax attention. Experiments show Hedgehog recovers over 99% of standard Transformer performance in train-from-scratch and finetuned-conversion settings, outperforming prior linear attentions by up to 6 perplexity points on WikiText-103 when training causal GPT models from scratch, and up to 8.7 GLUE score points when converting finetuned bidirectional BERT models. Hedgehog also enables pretrained-conversion. Converting a pretrained GPT-2 into a linear attention variant achieves state-of-the-art 16.7 perplexity on WikiText-103 for 125M subquadratic decoder models. We finally turn a pretrained Llama-2 7B into a viable linear attention Llama. With low-rank adaptation, Hedgehog-Llama-2 7B achieves 28.1 higher ROUGE-1 points over the base standard attention model, where prior linear attentions lead to 16.5 point drops.

NeurIPS Conference 2024 Conference Paper

WONDERBREAD: A Benchmark for Evaluating Multimodal Foundation Models on Business Process Management Tasks

  • Michael Wornow
  • Avanika Narayan
  • Ben Viggiano
  • Ishan S. Khare
  • Tathagat Verma
  • Tibor Thompson
  • Miguel A. Hernandez
  • Sudharsan Sundar

Existing ML benchmarks lack the depth and diversity of annotations needed for evaluating models on business process management (BPM) tasks. BPM is the practice of documenting, measuring, improving, and automating enterprise workflows. However, research has focused almost exclusively on one task -- full end-to-end automation using agents based on multimodal foundation models (FMs) like GPT-4. This focus on automation ignores the reality of how most BPM tools are applied today -- simply documenting the relevant workflow takes 60% of the time of the typical process optimization project. To address this gap we present WONDERBREAD, the first benchmark for evaluating multimodal FMs on BPM tasks beyond automation. Our contributions are: (1) a dataset containing 2928 documented workflow demonstrations; (2) 6 novel BPM tasks sourced from real-world applications ranging from workflow documentation to knowledge transfer to process improvement; and (3) an automated evaluation harness. Our benchmark shows that while state-of-the-art FMs can automatically generate documentation (e.g., recalling 88% of the steps taken in a video demonstration of a workflow), they struggle to re-apply that knowledge towards finer-grained validation of workflow completion (F1 < 0.3). We hope WONDERBREAD encourages the development of more "human-centered" AI tooling for enterprise applications and furthers the exploration of multimodal FMs for the broader universe of BPM tasks. We publish our dataset and experiments here: https://github.com/HazyResearch/wonderbread

ICLR Conference 2024 Conference Paper

Zoology: Measuring and Improving Recall in Efficient Language Models

  • Simran Arora
  • Sabri Eyuboglu
  • Aman Timalsina
  • Isys Johnson
  • Michael Poli
  • James Y. Zou
  • Atri Rudra
  • Christopher Ré

Attention-free language models that combine gating and convolutions are growing in popularity due to their efficiency and increasingly competitive performance. To better understand these architectures, we pretrain a suite of 17 attention and gated-convolution language models, finding that SoTA gated-convolution architectures still underperform attention by up to 2.1 perplexity points on the Pile. In fine-grained analysis, we find 82% of the gap is explained by each model's ability to recall information that is previously mentioned in-context, e.g. "Hakuna Matata means no worries Hakuna Matata it means no" -> ??. On this task, termed "associative recall", we find that attention outperforms gated-convolutions by a large margin: a 70M parameter attention model outperforms a 1.4 billion parameter gated-convolution model on associative recall. This is surprising because prior work shows gated convolutions can perfectly solve synthetic tests for AR capability. To close the gap between synthetics and real language, we develop a new formalization of the task called multi-query associative recall (MQAR) that better reflects actual language. We perform an empirical and theoretical study of MQAR that elucidates differences in the parameter-efficiency of attention and gated-convolution recall. Informed by our analysis, we evaluate simple convolution-attention hybrids and show that hybrids with input-dependent sparse attention patterns can close 97.4% of the gap to attention, while maintaining sub-quadratic scaling. Code is at: https://github.com/HazyResearch/zoology.
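As a rough illustration of the multi-query associative recall setup described above, the sketch below builds a toy sequence of key-value pairs followed by several queries, each of which must be answered with the value bound to that key earlier in the context. The function names, vocabulary split, and sizes are assumptions made for illustration and are not taken from the Zoology codebase.

```python
import random


def make_mqar_example(num_pairs, num_queries, vocab_size=64, seed=0):
    """Build one toy multi-query associative recall example.

    Keys come from the lower half of the vocabulary and values from the
    upper half, so a key token can never be mistaken for a value token.
    (This split is an illustrative assumption.)
    """
    rng = random.Random(seed)
    half = vocab_size // 2
    keys = rng.sample(range(0, half), num_pairs)           # distinct keys
    values = [rng.randrange(half, vocab_size) for _ in keys]
    mapping = dict(zip(keys, values))

    # Context is the flattened list: k1 v1 k2 v2 ...
    context = [tok for pair in zip(keys, values) for tok in pair]

    # Several keys are queried later; each target is the bound value.
    queries = rng.sample(keys, num_queries)
    targets = [mapping[q] for q in queries]
    return context, queries, targets


def lookup_baseline(context, queries):
    """Exact-recall reference: read the pairs back out of the context."""
    table = {context[i]: context[i + 1] for i in range(0, len(context), 2)}
    return [table[q] for q in queries]


if __name__ == "__main__":
    context, queries, targets = make_mqar_example(num_pairs=8, num_queries=4)
    print("context:", context)
    print("queries:", queries)
    print("targets:", targets)
```

The lookup baseline plays the role of a model with perfect recall; a sequence model would be scored on how many of the query targets it predicts correctly.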

NeurIPS Conference 2023 Conference Paper

A case for reframing automated medical image classification as segmentation

  • Sarah Hooper
  • Mayee Chen
  • Khaled Saab
  • Kush Bhatia
  • Curtis Langlotz
  • Christopher Ré

Image classification and segmentation are common applications of deep learning to radiology. While many tasks can be framed using either classification or segmentation, classification has historically been cheaper to label and more widely used. However, recent work has drastically reduced the cost of training segmentation networks. In light of this recent work, we reexamine the choice of training classification vs. segmentation models. First, we use an information theoretic approach to analyze why segmentation vs. classification models may achieve different performance on the same dataset and overarching task. We then implement multiple methods for using segmentation models to classify medical images, which we call segmentation-for-classification, and compare these methods against traditional classification on three retrospective datasets. We use our analysis and experiments to summarize the benefits of switching from classification to segmentation, including: improved sample efficiency, enabling improved performance with fewer labeled images (up to an order of magnitude lower), on low-prevalence classes, and on certain rare subgroups (up to 161.1% improved recall); improved robustness to spurious correlations (up to 44.8% improved robust AUROC); and improved model interpretability, evaluation, and error analysis.

ICLR Conference 2023 Conference Paper

Ask Me Anything: A simple strategy for prompting language models

  • Simran Arora
  • Avanika Narayan
  • Mayee F. Chen
  • Laurel J. Orr
  • Neel Guha
  • Kush Bhatia
  • Ines Chami
  • Christopher Ré

Large language models (LLMs) transfer well to new tasks out-of-the-box simply given a natural language prompt that demonstrates how to perform the task and no additional training. Prompting is a brittle process wherein small modifications to the prompt can cause large variations in the model predictions, and therefore significant effort is dedicated towards designing a painstakingly crafted "perfect prompt" for a task. To mitigate the high degree of effort, we instead ask whether collecting multiple decent, yet imperfect, prompts and aggregating them can lead to a high quality prompting strategy. Our observations motivate our proposed method, Ask Me Anything (AMA). We first develop an understanding of the effective prompt formats, finding that question-answering (QA) prompts, which encourage open-ended generation ("Who went to the park?"), tend to outperform those that restrict the model outputs ("John went to the park. True or False?"). AMA recursively uses the LLM to transform task inputs to the effective QA format. AMA generates multiple questions per input and applies these prompts to collect several noisy "votes" for the input's true label. We find the prompts have varying accuracies and dependencies and thus propose to use weak supervision, a procedure for combining the noisy predictions, to produce the final predictions. We evaluate AMA across open-source model families (EleutherAI, BLOOM, OPT, and T0) and sizes (125M-175B parameters), demonstrating an average performance lift of 10.2% over the few-shot baseline. This simple strategy enables the open-source GPT-J-6B model to match and exceed the performance of few-shot GPT3-175B on 15 of 20 popular benchmarks. Averaged across these tasks, the GPT-J-6B model outperforms few-shot GPT3-175B. We release our code here: https://github.com/HazyResearch/ama_prompting.

ICML Conference 2023 Conference Paper

CocktailSGD: Fine-tuning Foundation Models over 500Mbps Networks

  • Jue Wang
  • Yucheng Lu 0003
  • Binhang Yuan
  • Beidi Chen
  • Percy Liang
  • Christopher De Sa
  • Christopher Ré
  • Ce Zhang 0001

Distributed training of foundation models, especially large language models (LLMs), is communication-intensive and so has heavily relied on centralized data centers with fast interconnects. Can we train on slow networks and unlock the potential of decentralized infrastructure for foundation models? In this paper, we propose CocktailSGD, a novel communication-efficient training framework that combines three distinct compression techniques – random sparsification, top-K sparsification, and quantization – to achieve much greater compression than each individual technique alone. We justify the benefit of such a hybrid approach through a theoretical analysis of convergence. Empirically, we show that CocktailSGD achieves up to 117$\times$ compression in fine-tuning LLMs up to 20 billion parameters without hurting convergence. On a 500Mbps network, CocktailSGD only incurs $\sim$1.2$\times$ slowdown compared with data center networks.
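To make the idea of stacking the three compressors concrete, here is a minimal sketch that applies random sparsification, then top-K selection, then uniform quantization to a single update vector. The order of composition, the parameter names (`keep_prob`, `k`, `bits`), and the symmetric quantizer are all assumptions for illustration; the abstract does not specify these details, and this is not the CocktailSGD implementation.

```python
import random


def random_sparsify(vec, keep_prob, rng):
    """Keep each coordinate independently with probability keep_prob."""
    return {i: v for i, v in enumerate(vec) if rng.random() < keep_prob}


def top_k(sparse, k):
    """Of the surviving coordinates, keep the k largest in magnitude."""
    ranked = sorted(sparse.items(), key=lambda kv: abs(kv[1]), reverse=True)
    return dict(ranked[:k])


def quantize(sparse, bits):
    """Symmetric uniform quantization to signed integers of `bits` bits."""
    levels = 2 ** (bits - 1) - 1
    max_abs = max((abs(v) for v in sparse.values()), default=0.0)
    scale = max_abs / levels if max_abs > 0 else 1.0
    codes = {i: round(v / scale) for i, v in sparse.items()}
    return codes, scale


def compress(vec, keep_prob=0.5, k=4, bits=4, seed=0):
    """Hypothetical pipeline: random sparsify -> top-K -> quantize."""
    rng = random.Random(seed)
    return quantize(top_k(random_sparsify(vec, keep_prob, rng), k), bits)


def decompress(codes, scale, length):
    """Rebuild a dense vector; dropped coordinates become zero."""
    dense = [0.0] * length
    for i, q in codes.items():
        dense[i] = q * scale
    return dense


if __name__ == "__main__":
    update = [0.03, -1.2, 0.5, 0.0, 2.4, -0.7, 0.09, 1.1, -0.2, 0.6]
    codes, scale = compress(update)
    print(codes, scale)
    print(decompress(codes, scale, len(update)))
```

Each stage shrinks what must be sent: fewer candidate coordinates, then fewer kept coordinates, then fewer bits per kept value. A practical system would also need to handle the error introduced by the dropped coordinates, which this sketch omits.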

ICML Conference 2023 Conference Paper

Deja Vu: Contextual Sparsity for Efficient LLMs at Inference Time

  • Zichang Liu
  • Jue Wang
  • Tri Dao
  • Tianyi Zhou 0002
  • Binhang Yuan
  • Zhao Song 0002
  • Anshumali Shrivastava
  • Ce Zhang 0001

Large language models (LLMs) with hundreds of billions of parameters have sparked a new wave of exciting AI applications. However, they are computationally expensive at inference time. Sparsity is a natural approach to reduce this cost, but existing methods either require costly retraining, have to forgo LLM’s in-context learning ability, or do not yield wall-clock time speedup on modern hardware. We hypothesize that contextual sparsity, which are small, input-dependent sets of attention heads and MLP parameters that yield approximately the same output as the dense model for a given input, can address these issues. We show that contextual sparsity exists, that it can be accurately predicted, and that we can exploit it to speed up LLM inference in wall-clock time without compromising LLM’s quality or in-context learning ability. Based on these insights, we propose DejaVu, a system that uses a low-cost algorithm to predict contextual sparsity on the fly given inputs to each layer, along with an asynchronous and hardware-aware implementation that speeds up LLM inference. We validate that DejaVu can reduce the inference latency of OPT-175B by over 2$\times$ compared to the state-of-the-art FasterTransformer, and over 6$\times$ compared to the widely used Hugging Face implementation, without compromising model quality. The code is available at https://github.com/FMInference/DejaVu.

ICLR Conference 2023 Conference Paper

Effectively Modeling Time Series with Simple Discrete State Spaces

  • Michael Zhang
  • Khaled Kamal Saab
  • Michael Poli
  • Tri Dao
  • Karan Goel
  • Christopher Ré

Time series modeling is a well-established problem, which often requires that methods (1) expressively represent complicated dependencies, (2) forecast long horizons, and (3) efficiently train over long sequences. State-space models (SSMs) are classical models for time series, and prior works combine SSMs with deep learning layers for efficient sequence modeling. However, we find fundamental limitations with these prior approaches, proving their SSM representations cannot express autoregressive time series processes. We thus introduce SpaceTime, a new state-space time series architecture that improves all three criteria. For expressivity, we propose a new SSM parameterization based on the companion matrix---a canonical representation for discrete-time processes---which enables SpaceTime's SSM layers to learn desirable autoregressive processes. For long horizon forecasting, we introduce a "closed-loop" variation of the companion SSM, which enables SpaceTime to predict many future time-steps by generating its own layer-wise inputs. For efficient training and inference, we introduce an algorithm that reduces the memory and compute of a forward pass with the companion matrix. With sequence length $\ell$ and state-space size $d$, we go from $\tilde{O}(d \ell)$ naïvely to $\tilde{O}(d + \ell)$. In experiments, our contributions lead to state-of-the-art results on extensive and diverse benchmarks, with best or second-best AUROC on 6 / 7 ECG and speech time series classification, and best MSE on 14 / 16 Informer forecasting tasks. Furthermore, we find SpaceTime (1) fits AR($p$) processes that prior deep SSMs fail on, (2) forecasts notably more accurately on longer horizons than prior state-of-the-art, and (3) speeds up training on real-world ETTh1 data by 73% and 80% relative wall-clock time over Transformers and LSTMs.
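The link between a companion matrix and autoregressive processes can be checked numerically. The sketch below uses one common convention (AR coefficients in the first row, a shifted identity below) and feeds each prediction back in as the next input to forecast a noise-free AR(p) process, then compares the result with the direct recurrence. The function names and this convention are assumptions for illustration; SpaceTime's actual parameterization may differ.

```python
def companion_matrix(coeffs):
    """p x p companion matrix: coefficients in row 0, shift below it."""
    p = len(coeffs)
    matrix = [[0.0] * p for _ in range(p)]
    matrix[0] = [float(c) for c in coeffs]
    for i in range(1, p):
        matrix[i][i - 1] = 1.0  # moves each past value one slot back
    return matrix


def matvec(matrix, vector):
    """Plain matrix-vector product."""
    return [sum(a * x for a, x in zip(row, vector)) for row in matrix]


def forecast_companion(coeffs, history, horizon):
    """Forecast by repeatedly applying the companion matrix to the state.

    The state holds the p most recent values, newest first, and each
    prediction is fed back in as the next input.
    """
    p = len(coeffs)
    state = list(reversed(history[-p:]))
    matrix = companion_matrix(coeffs)
    out = []
    for _ in range(horizon):
        state = matvec(matrix, state)
        out.append(state[0])
    return out


def forecast_direct(coeffs, history, horizon):
    """Reference: y_t = sum_j coeffs[j] * y_{t-1-j}, with no noise term."""
    values = list(history)
    for _ in range(horizon):
        values.append(sum(c * values[-1 - j] for j, c in enumerate(coeffs)))
    return values[len(history):]


if __name__ == "__main__":
    coeffs = [0.5, -0.3, 0.1]   # hypothetical AR(3) coefficients
    history = [1.0, 0.8, 1.2]
    print(forecast_companion(coeffs, history, horizon=5))
    print(forecast_direct(coeffs, history, horizon=5))
```

Both functions produce the same forecasts, which shows that a state-space layer whose state matrix has this structure can represent an AR(p) process exactly.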

NeurIPS Conference 2023 Conference Paper

Embroid: Unsupervised Prediction Smoothing Can Improve Few-Shot Classification

  • Neel Guha
  • Mayee Chen
  • Kush Bhatia
  • Azalia Mirhoseini
  • Frederic Sala
  • Christopher Ré

Recent work has shown that language models' (LMs) prompt-based learning capabilities make them well suited for automating data labeling in domains where manual annotation is expensive. The challenge is that while writing an initial prompt is cheap, improving a prompt is costly---practitioners often require significant labeled data in order to evaluate the impact of prompt modifications. Our work asks whether it is possible to improve prompt-based learning without additional labeled data. We approach this problem by attempting to modify the predictions of a prompt, rather than the prompt itself. Our intuition is that accurate predictions should also be consistent: samples which are similar under some feature representation should receive the same prompt prediction. We propose Embroid, a method which computes multiple representations of a dataset under different embedding functions, and uses the consistency between the LM predictions for neighboring samples to identify mispredictions. Embroid then uses these neighborhoods to create additional predictions for each sample, and combines these predictions with a simple latent variable graphical model in order to generate a final corrected prediction. In addition to providing a theoretical analysis of Embroid, we conduct a rigorous empirical evaluation across six different LMs and up to 95 different tasks. We find that (1) Embroid substantially improves performance over original prompts (e.g., by an average of 7.3 points on GPT-JT), (2) also realizes improvements for more sophisticated prompting strategies (e.g., chain-of-thought), and (3) can be specialized to domains like law through the embedding functions.

ICML Conference 2023 Conference Paper

FlexGen: High-Throughput Generative Inference of Large Language Models with a Single GPU

  • Ying Sheng 0007
  • Lianmin Zheng
  • Binhang Yuan
  • Zhuohan Li 0001
  • Max Ryabinin
  • Beidi Chen
  • Percy Liang
  • Christopher Ré

The high computational and memory requirements of large language model (LLM) inference make it feasible only with multiple high-end accelerators. Motivated by the emerging demand for latency-insensitive tasks with batched processing, this paper initiates the study of high-throughput LLM inference using limited resources, such as a single commodity GPU. We present FlexGen, a high-throughput generation engine for running LLMs with limited GPU memory. FlexGen can be flexibly configured under various hardware resource constraints by aggregating memory and computation from the GPU, CPU, and disk. By solving a linear programming problem, it searches for efficient patterns to store and access tensors. FlexGen further compresses the weights and the attention cache to 4 bits with negligible accuracy loss. These techniques enable FlexGen to have a larger space of batch size choices and thus significantly increase maximum throughput. As a result, when running OPT-175B on a single 16GB GPU, FlexGen achieves significantly higher throughput compared to state-of-the-art offloading systems, reaching a generation throughput of 1 token/s for the first time with an effective batch size of 144. On the HELM benchmark, FlexGen can benchmark a 30B model with a 16GB GPU on 7 representative sub-scenarios in 21 hours. The code is available at https://github.com/FMInference/FlexGen.

NeurIPS Conference 2023 Conference Paper

H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models

  • Zhenyu Zhang
  • Ying Sheng
  • Tianyi Zhou
  • Tianlong Chen
  • Lianmin Zheng
  • Ruisi Cai
  • Zhao Song
  • Yuandong Tian

Large Language Models (LLMs), despite their recent impressive accomplishments, are notably cost-prohibitive to deploy, particularly for applications involving long-content generation, such as dialogue systems and story writing. Often, a large amount of transient state information, referred to as the $\mathsf{KV}$ $\mathsf{cache}$, is stored in GPU memory in addition to model parameters, scaling linearly with the sequence length and batch size. In this paper, we introduce a novel approach for implementing the $\mathsf{KV}$ $\mathsf{cache}$ which significantly reduces its memory footprint. Our approach is based on the noteworthy observation that a small portion of tokens contributes most of the value when computing attention scores. We call these tokens Heavy Hitters ($\mathsf{H_2}$). Through a comprehensive investigation, we find that ($i$) the emergence of $\mathsf{H_2}$ is natural and strongly correlates with the frequent co-occurrence of tokens in the text, and ($ii$) removing them results in significant performance degradation. Based on these insights, we propose Heavy Hitter Oracle ($\mathsf{H_2O}$), a $\mathsf{KV}$ $\mathsf{cache}$ eviction policy that dynamically retains a balance of recent and $\mathsf{H_2}$ tokens. We formulate the $\mathsf{KV}$ $\mathsf{cache}$ eviction as a dynamic submodular problem and prove (under mild assumptions) a theoretical guarantee for our novel eviction algorithm which could help guide future work. We validate the accuracy of our algorithm with OPT, LLaMA, and GPT-NeoX across a wide range of tasks. Our implementation of $\mathsf{H_2O}$ with 20% heavy hitters improves the throughput over three leading inference systems, DeepSpeed Zero-Inference, Hugging Face Accelerate, and FlexGen, by up to $29\times$, $29\times$, and $3\times$ on OPT-6.7B and OPT-30B. With the same batch size, $\mathsf{H_2O}$ can reduce the latency by up to $1.9\times$.
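An eviction policy that keeps a balance of recent and heavy-hitter tokens can be sketched in a few lines. The version below assumes each cached position carries an accumulated attention score and keeps the most recent positions plus the highest-scoring older ones. The scoring rule, the function name, and the two budget parameters are illustrative assumptions rather than the paper's exact algorithm.

```python
def select_kept_positions(acc_scores, budget_recent, budget_heavy):
    """Choose which cache positions to keep under a fixed budget.

    acc_scores[i] is an assumed accumulated attention score for the
    token at position i. The last `budget_recent` positions are always
    kept; among the older ones, the `budget_heavy` highest-scoring
    positions are kept as heavy hitters. Everything else is evicted.
    """
    n = len(acc_scores)
    recent = set(range(max(0, n - budget_recent), n))
    older = [i for i in range(n) if i not in recent]
    heavy = sorted(older, key=lambda i: acc_scores[i], reverse=True)
    return sorted(recent | set(heavy[:budget_heavy]))


if __name__ == "__main__":
    scores = [0.1, 5.0, 0.2, 3.0, 0.1, 0.3, 0.2, 0.1]
    kept = select_kept_positions(scores, budget_recent=2, budget_heavy=2)
    print(kept)  # [1, 3, 6, 7]
```

In a decoding loop this selection would be re-run as new tokens arrive and scores accumulate, so the set of retained heavy hitters can change over time while the cache size stays fixed.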

AAAI Conference 2023 System Paper

HAPI Explorer: Comprehension, Discovery, and Explanation on History of ML APIs

  • Lingjiao Chen
  • Zhihua Jin
  • Sabri Eyuboglu
  • Huamin Qu
  • Christopher Ré
  • Matei Zaharia
  • James Zou

Machine learning prediction APIs offered by Google, Microsoft, Amazon, and many other providers have been continuously adopted in a plethora of applications, such as visual object detection, natural language comprehension, and speech recognition. Despite the importance of a systematic study and comparison of different APIs over time, this topic is currently under-explored because of the lack of data and user-friendly exploration tools. To address this issue, we present HAPI Explorer (History of API Explorer), an interactive system that offers easy access to millions of instances of commercial API applications collected over three years, prioritizes attention on user-defined instance regimes, and explains interesting patterns across different APIs, subpopulations, and time periods via visualizations and natural language. HAPI Explorer can facilitate further comprehension and exploitation of ML prediction APIs.

ICLR Conference 2023 Conference Paper

How to Train your HIPPO: State Space Models with Generalized Orthogonal Basis Projections

  • Albert Gu
  • Isys Johnson
  • Aman Timalsina
  • Atri Rudra
  • Christopher Ré

Linear time-invariant state space models (SSM) are a classical model from engineering and statistics, that have recently been shown to be very promising in machine learning through the Structured State Space sequence model (S4). A core component of S4 involves initializing the SSM state matrix to a particular matrix called a HiPPO matrix, which was empirically important for S4's ability to handle long sequences. However, the specific matrix that S4 uses was actually derived in previous work for a particular *time-varying* dynamical system, and the use of this matrix as a *time-invariant* SSM had no known mathematical interpretation. Consequently, the theoretical mechanism by which S4 models long-range dependencies actually remains unexplained. We derive a more general and intuitive formulation of the HiPPO framework, which provides a simple mathematical interpretation of S4 as a decomposition onto exponentially-warped Legendre polynomials, explaining its ability to capture long dependencies. Our generalization introduces a theoretically rich class of SSMs that also lets us derive more intuitive S4 variants for other bases such as the Fourier basis, and explains other aspects of training S4, such as how to initialize the important timescale parameter. These insights improve S4's performance to 86% on the Long Range Arena benchmark, with 96% on the most difficult Path-X task.

ICLR Conference 2023 Conference Paper

Hungry Hungry Hippos: Towards Language Modeling with State Space Models

  • Daniel Y. Fu
  • Tri Dao
  • Khaled Kamal Saab
  • Armin W. Thomas
  • Atri Rudra
  • Christopher Ré

State space models (SSMs) have demonstrated state-of-the-art sequence modeling performance in some modalities, but underperform attention in language modeling. Moreover, despite scaling nearly linearly in sequence length instead of quadratically, SSMs are still slower than Transformers due to poor hardware utilization. In this paper, we make progress on understanding the expressivity gap between SSMs and attention in language modeling, and on reducing the hardware barrier between SSMs and attention. First, we use synthetic language modeling tasks to understand the gap between SSMs and attention. We find that existing SSMs struggle with two capabilities: recalling earlier tokens in the sequence and comparing tokens across the sequence. To understand the impact on language modeling, we propose a new SSM layer, H3, that is explicitly designed for these abilities. H3 matches attention on the synthetic languages and comes within 0.4 PPL of Transformers on OpenWebText. Furthermore, a hybrid 125M-parameter H3-attention model that retains two attention layers surprisingly outperforms Transformers on OpenWebText by 1.0 PPL. Next, to improve the efficiency of training SSMs on modern hardware, we propose FlashConv. FlashConv uses a fused block FFT algorithm to improve efficiency on sequences up to 8K, and introduces a novel state passing algorithm that exploits the recurrent properties of SSMs to scale to longer sequences. FlashConv yields 2$\times$ speedup on the long-range arena benchmark and allows hybrid language models to generate text 2.4$\times$ faster than Transformers. Using FlashConv, we scale hybrid H3-attention language models up to 2.7B parameters on the Pile and find promising initial results, achieving lower perplexity than Transformers and outperforming Transformers in zero- and few-shot learning on a majority of tasks in the SuperGLUE benchmark.

ICML Conference 2023 Conference Paper

Hyena Hierarchy: Towards Larger Convolutional Language Models

  • Michael Poli
  • Stefano Massaroli
  • Eric Nguyen
  • Daniel Y. Fu
  • Tri Dao
  • Stephen Baccus
  • Yoshua Bengio
  • Stefano Ermon

Recent advances in deep learning have relied heavily on the use of large Transformers due to their ability to learn at scale. However, the core building block of Transformers, the attention operator, exhibits quadratic cost in sequence length, limiting the amount of context accessible. Existing subquadratic methods based on low-rank and sparse approximations need to be combined with dense attention layers to match Transformers at scale, indicating a gap in capability. In this work, we propose Hyena, a subquadratic drop-in replacement for attention constructed by interleaving implicitly parametrized long convolutions and data-controlled gating. In challenging reasoning tasks on sequences of thousands to hundreds of thousands of tokens, Hyena improves accuracy by more than 50 points over operators relying on state-space models, transfer functions, and other implicit and explicit methods, matching attention-based models. We set a new state-of-the-art for dense-attention-free architectures on language modeling in standard datasets WikiText103 and The Pile, reaching Transformer quality with a 20% reduction in training compute required at sequence length 2k. Hyena operators are 2x faster than highly optimized attention at sequence length 8k, with speedups of 100x at 64k.

NeurIPS Conference 2023 Conference Paper

HyenaDNA: Long-Range Genomic Sequence Modeling at Single Nucleotide Resolution

  • Eric Nguyen
  • Michael Poli
  • Marjan Faizi
  • Armin Thomas
  • Michael Wornow
  • Callum Birch-Sykes
  • Stefano Massaroli
  • Aman Patel

Genomic (DNA) sequences encode an enormous amount of information for gene regulation and protein synthesis. Similar to natural language models, researchers have proposed foundation models in genomics to learn generalizable features from unlabeled genome data that can then be fine-tuned for downstream tasks such as identifying regulatory elements. Due to the quadratic scaling of attention, previous Transformer-based genomic models have used 512 to 4k tokens as context (<0.001% of the human genome), significantly limiting the modeling of long-range interactions in DNA. In addition, these methods rely on tokenizers or fixed k-mers to aggregate meaningful DNA units, losing single nucleotide resolution (i.e., DNA "characters") where subtle genetic variations can completely alter protein function via single nucleotide polymorphisms (SNPs). Recently, Hyena, a large language model based on implicit convolutions, was shown to match attention in quality while allowing longer context lengths and lower time complexity. Leveraging Hyena’s new long-range capabilities, we present HyenaDNA, a genomic foundation model pretrained on the human reference genome with context lengths of up to 1 million tokens at the single nucleotide-level – an up to 500x increase over previous dense attention-based models. HyenaDNA scales sub-quadratically in sequence length (training up to 160x faster than Transformer), uses single nucleotide tokens, and has full global context at each layer. We explore what longer context enables - including the first use of in-context learning in genomics for simple adaptation to novel tasks without updating pretrained model weights. On fine-tuned benchmarks from the Nucleotide Transformer, HyenaDNA reaches state-of-the-art (SotA) on 12 of 18 datasets using a model with orders of magnitude fewer parameters and less pretraining data. On the GenomicBenchmarks, HyenaDNA surpasses SotA on 7 of 8 datasets on average by +10 accuracy points. Code at https://github.com/HazyResearch/hyena-dna.

NeurIPS Conference 2023 Conference Paper

Laughing Hyena Distillery: Extracting Compact Recurrences From Convolutions

  • Stefano Massaroli
  • Michael Poli
  • Dan Fu
  • Hermann Kumbong
  • Rom Parnichkun
  • David Romero
  • Aman Timalsina
  • Quinn McIntyre

Recent advances in attention-free sequence models rely on convolutions as alternatives to the attention operator at the core of Transformers. In particular, long convolution sequence models have achieved state-of-the-art performance in many domains, but incur a significant cost during auto-regressive inference workloads -- naively requiring a full pass (or caching of activations) over the input sequence for each generated token -- similarly to attention-based models. In this paper, we seek to enable $\mathcal O(1)$ compute and memory cost per token in any pre-trained long convolution architecture to reduce memory footprint and increase throughput during generation. Concretely, our methods consist in extracting low-dimensional linear state-space models from each convolution layer, building upon rational interpolation and model-order reduction techniques. We further introduce architectural improvements to convolution-based layers such as Hyena: by weight-tying the filters across channels into heads, we achieve higher pre-training quality and reduce the number of filters to be distilled. The resulting model achieves 10x higher throughput than Transformers and 1.5x higher than Hyena at 1.3B parameters, without any loss in quality after distillation.

NeurIPS Conference 2023 Conference Paper

LegalBench: A Collaboratively Built Benchmark for Measuring Legal Reasoning in Large Language Models

  • Neel Guha
  • Julian Nyarko
  • Daniel Ho
  • Christopher Ré
  • Adam Chilton
  • Aditya K
  • Alex Chohlas-Wood
  • Austin Peters

The advent of large language models (LLMs) and their adoption by the legal community has given rise to the question: what types of legal reasoning can LLMs perform? To enable greater study of this question, we present LegalBench: a collaboratively constructed legal reasoning benchmark consisting of 162 tasks covering six different types of legal reasoning. LegalBench was built through an interdisciplinary process, in which we collected tasks designed and hand-crafted by legal professionals. Because these subject matter experts took a leading role in construction, tasks either measure legal reasoning capabilities that are practically useful, or measure reasoning skills that lawyers find interesting. To enable cross-disciplinary conversations about LLMs in the law, we additionally show how popular legal frameworks for describing legal reasoning—which distinguish between its many forms—correspond to LegalBench tasks, thus giving lawyers and LLM developers a common vocabulary. This paper describes LegalBench, presents an empirical evaluation of 20 open-source and commercial LLMs, and illustrates the types of research explorations LegalBench enables.

NeurIPS Conference 2023 Conference Paper

Monarch Mixer: A Simple Sub-Quadratic GEMM-Based Architecture

  • Dan Fu
  • Simran Arora
  • Jessica Grogan
  • Isys Johnson
  • Evan Sabri Eyuboglu
  • Armin Thomas
  • Benjamin Spector
  • Michael Poli

Machine learning models are increasingly being scaled in both sequence length and model dimension to reach longer contexts and better performance. However, existing architectures such as Transformers scale quadratically along both these axes. We ask: are there performant architectures that can scale sub-quadratically along sequence length and model dimension? We introduce Monarch Mixer (M2), a new architecture that uses the same sub-quadratic primitive along both sequence length and model dimension: Monarch matrices, a simple class of expressive structured matrices that captures many linear transforms, achieves high hardware efficiency on GPUs, and scales sub-quadratically. As a proof of concept, we explore the performance of M2 in three domains: non-causal BERT-style language modeling, ViT-style image classification, and causal GPT-style language modeling. For non-causal BERT-style modeling, M2 matches BERT-base and BERT-large in downstream GLUE quality with up to 27% fewer parameters, and achieves up to 9.1$\times$ higher throughput at sequence length 4K. On ImageNet, M2 outperforms ViT-b by 1% in accuracy, with only half the parameters. Causal GPT-style models introduce a technical challenge: enforcing causality via masking introduces a quadratic bottleneck. To alleviate this bottleneck, we develop a novel theoretical view of Monarch matrices based on multivariate polynomial evaluation and interpolation, which lets us parameterize M2 to be causal while remaining sub-quadratic. Using this parameterization, M2 matches GPT-style Transformers at 360M parameters in pretraining perplexity on The PILE, showing for the first time that it may be possible to match Transformer quality without attention or MLPs.

ICML Conference 2023 Conference Paper

Simple Hardware-Efficient Long Convolutions for Sequence Modeling

  • Daniel Y. Fu
  • Elliot L. Epstein
  • Eric Nguyen
  • Armin W. Thomas
  • Michael Zhang
  • Tri Dao
  • Atri Rudra
  • Christopher Ré

State space models (SSMs) have high performance on long sequence modeling but require sophisticated initialization techniques and specialized implementations for high quality and runtime performance. We study whether a simple alternative can match SSMs in performance and efficiency: directly learning long convolutions over the sequence. We find that a key requirement to achieving high performance is keeping the convolution kernels smooth. We find that simple interventions, such as squashing the kernel weights, result in smooth kernels and recover SSM performance on a range of tasks including the long range arena, image classification, language modeling, and brain data modeling. Next, we develop FlashButterfly, an IO-aware algorithm to improve the runtime performance of long convolutions. FlashButterfly appeals to classic Butterfly decompositions of the convolution to reduce GPU memory IO and increase FLOP utilization. FlashButterfly speeds up convolutions by 2.2$\times$, and allows us to train on Path256, a challenging task with sequence length 64K, where we set state-of-the-art by 29.1 points while training 7.2$\times$ faster than prior work. Lastly, we introduce an extension to FlashButterfly that learns the coefficients of the Butterfly decomposition, increasing expressivity without increasing runtime. Using this extension, we outperform a Transformer on WikiText103 by 0.2 PPL with 30% fewer parameters.

NeurIPS Conference 2023 Conference Paper

Skill-it! A data-driven skills framework for understanding and training language models

  • Mayee Chen
  • Nicholas Roberts
  • Kush Bhatia
  • Jue Wang
  • Ce Zhang
  • Frederic Sala
  • Christopher Ré

The quality of training data impacts the performance of pre-trained large language models (LMs). Given a fixed budget of tokens, we study how to best select data that leads to good downstream model performance across tasks. We develop a new framework based on a simple hypothesis: just as humans acquire interdependent skills in a deliberate order, language models also follow a natural order when learning a set of skills from their training data. If such an order exists, it can be utilized for improved understanding of LMs and for data-efficient training. Using this intuition, our framework formalizes the notion of a skill and of an ordered set of skills in terms of the associated data. First, using both synthetic and real data, we demonstrate that these ordered skill sets exist, and that their existence enables more advanced skills to be learned with less data when we train on their prerequisite skills. Second, using our proposed framework, we introduce an online data sampling algorithm, Skill-It, over mixtures of skills for both continual pre-training and fine-tuning regimes, where the objective is to efficiently learn multiple skills in the former and an individual skill in the latter. On the LEGO synthetic in the continual pre-training setting, Skill-It obtains 37.5 points higher accuracy than random sampling. On the Natural Instructions dataset in the fine-tuning setting, Skill-It reduces the validation loss on the target skill by 13.6% versus training on data associated with the target skill itself. We apply our skills framework on the RedPajama dataset to continually pre-train a 3B-parameter LM, achieving higher accuracy on the LM Evaluation Harness with 1B tokens than the baseline approach of sampling uniformly over data sources with 3B tokens.

NeurIPS Conference 2023 Conference Paper

TART: A plug-and-play Transformer module for task-agnostic reasoning

  • Kush Bhatia
  • Avanika Narayan
  • Christopher M. De Sa
  • Christopher Ré

Large language models (LLMs) exhibit in-context learning abilities which enable the same model to perform several tasks without any task-specific training. In contrast, traditional adaptation approaches, such as fine-tuning, modify the underlying models for each specific task. In-context learning, however, consistently underperforms task-specific tuning approaches even when presented with the same examples. While most existing approaches (e.g., prompt engineering) focus on the LLM's learned representations to patch this performance gap, our experiments actually reveal that LLM representations contain sufficient information to make good predictions. As such, we focus on the LLM's reasoning abilities and demonstrate that this performance gap exists due to their inability to perform simple probabilistic reasoning tasks. This raises an intriguing question: Are LLMs actually capable of learning how to reason in a task-agnostic manner? We answer this in the affirmative and, as a proof of concept, propose TART which generically improves an LLM's reasoning abilities using a synthetically trained reasoning module. TART trains this Transformer-based reasoning module in a task-agnostic manner using only synthetic logistic regression tasks and composes it with an arbitrary real-world pre-trained model without any additional training. With a single inference module, TART improves performance across different model families (GPT-Neo, Pythia, Bloom), model sizes (100M - 6B), tasks (14 NLP classification tasks), and even across different modalities (audio and vision). On the RAFT Benchmark, TART improves GPT-Neo (125M)'s performance such that it outperforms Bloom (176B), and is within 4% of GPT-3.

NeurIPS Conference 2022 Conference Paper

Contrastive Adapters for Foundation Model Group Robustness

  • Michael Zhang
  • Christopher Ré

While large pretrained foundation models (FMs) have shown remarkable zero-shot classification robustness to dataset-level distribution shifts, their robustness to subpopulation or group shifts is relatively underexplored. We study this problem, and find that foundation models such as CLIP may not be robust to various group shifts. Across 9 robustness benchmarks, zero-shot classification with their embeddings results in gaps of up to 80.7 percentage points (pp) between average and worst-group accuracy. Unfortunately, existing methods to improve robustness require retraining, which can be prohibitively expensive on large foundation models. We also find that efficient ways to improve model inference (e.g. via adapters, lightweight networks that transform FM embeddings) do not consistently improve and can sometimes hurt group robustness compared to zero-shot. We therefore develop an adapter training strategy to effectively and efficiently improve FM group robustness. Our motivating observation is that while poor robustness results from groups in the same class being embedded far apart in the foundation model "embedding space," standard adapter training may not actually bring these points closer together. We thus propose contrastive adapting, which contrastively trains adapters to bring sample embeddings close to both their ground-truth class embeddings and same-class sample embeddings. Across the 9 robustness benchmarks, contrastive adapting consistently improves group robustness, raising worst-group accuracy by 8.5 to 56.0 pp over zero-shot. Our approach is also efficient, doing so without any FM finetuning and only a fixed set of FM embeddings. On popular benchmarks such as Waterbirds and CelebA, this leads to worst-group accuracy comparable to state-of-the-art methods, while only training <1% of the model parameters.
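
The gap between average and worst-group accuracy that this abstract reports can be illustrated with a small sketch. The function names, the toy labels, and the choice to average over samples rather than over groups are all assumptions made for illustration; the abstract does not specify the exact evaluation protocol.

```python
import numpy as np

def group_accuracies(y_true, y_pred, groups):
    """Return {group_id: accuracy} for every group id present in `groups`."""
    y_true, y_pred, groups = map(np.asarray, (y_true, y_pred, groups))
    correct = y_true == y_pred
    return {int(g): float(correct[groups == g].mean()) for g in np.unique(groups)}

def robustness_gap(y_true, y_pred, groups):
    """Average accuracy minus worst-group accuracy, in percentage points.

    Assumption: "average" means the mean over all samples.
    """
    per_group = group_accuracies(y_true, y_pred, groups)
    average = float((np.asarray(y_true) == np.asarray(y_pred)).mean())
    worst = min(per_group.values())
    return 100.0 * (average - worst)

# Hypothetical toy data: group 0 is large and easy, group 1 is small and hard.
y_true = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
y_pred = [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
groups = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]

per_group = group_accuracies(y_true, y_pred, groups)
gap = robustness_gap(y_true, y_pred, groups)
```

In this toy case the overall accuracy looks high because the easy group dominates the sample count, while the small group is only half correct; that is the kind of disparity a worst-group metric is meant to surface.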

ICML Conference 2022 Conference Paper

Correct-N-Contrast: a Contrastive Approach for Improving Robustness to Spurious Correlations

  • Michael Zhang
  • Nimit Sharad Sohoni
  • Hongyang Zhang 0005
  • Chelsea Finn
  • Christopher Ré

Spurious correlations pose a major challenge for robust machine learning. Models trained with empirical risk minimization (ERM) may learn to rely on correlations between class labels and spurious attributes, leading to poor performance on data groups without these correlations. This is challenging to address when the spurious attribute labels are unavailable. To improve worst-group performance on spuriously correlated data without training attribute labels, we propose Correct-N-Contrast (CNC), a contrastive approach to directly learn representations robust to spurious correlations. As ERM models can be good spurious attribute predictors, CNC works by (1) using a trained ERM model’s outputs to identify samples with the same class but dissimilar spurious features, and (2) training a robust model with contrastive learning to learn similar representations for these samples. To support CNC, we introduce new connections between worst-group error and a representation alignment loss that CNC aims to minimize. We empirically observe that worst-group error closely tracks with alignment loss, and prove that the alignment loss over a class helps upper-bound the class’s worst-group vs. average error gap. On popular benchmarks, CNC reduces alignment loss drastically, and achieves state-of-the-art worst-group accuracy by 3.6% average absolute lift. CNC is also competitive with oracle methods that require group labels.

NeurIPS Conference 2022 Conference Paper

Decentralized Training of Foundation Models in Heterogeneous Environments

  • Binhang Yuan
  • Yongjun He
  • Jared Davis
  • Tianyi Zhang
  • Tri Dao
  • Beidi Chen
  • Percy S. Liang
  • Christopher Ré

Training foundation models, such as GPT-3 and PaLM, can be extremely expensive, often involving tens of thousands of GPUs running continuously for months. These models are typically trained in specialized clusters featuring fast, homogeneous interconnects and using carefully designed software systems that support both data parallelism and model/pipeline parallelism. Such dedicated clusters can be costly and difficult to obtain. Can we instead leverage the much greater amount of decentralized, heterogeneous, and lower-bandwidth interconnected compute? Previous works examining the heterogeneous, decentralized setting focus on relatively small models that can be trained in a purely data parallel manner. State-of-the-art schemes for model parallel foundation model training, such as Megatron and Deepspeed, only consider the homogeneous data center setting. In this paper, we present the first study of training large foundation models with model parallelism in a decentralized regime over a heterogeneous network. Our key technical contribution is a scheduling algorithm that allocates different computational “tasklets” in the training of foundation models to a group of decentralized GPU devices connected by a slow heterogeneous network. We provide a formal cost model and further propose an efficient evolutionary algorithm to find the optimal allocation strategy. We conduct extensive experiments that represent different scenarios for learning over geo-distributed devices simulated using real-world network measurements. In the most extreme case, across 8 different cities spanning 3 continents, our approach is 4.8× faster than prior state-of-the-art training systems.

ICLR Conference 2022 Conference Paper

Domino: Discovering Systematic Errors with Cross-Modal Embeddings

  • Sabri Eyuboglu
  • Maya Varma
  • Khaled Kamal Saab
  • Jean-Benoit Delbrouck
  • Christopher Lee-Messer
  • Jared Dunnmon
  • James Y. Zou
  • Christopher Ré

Machine learning models that achieve high overall accuracy often make systematic errors on important subsets (or slices) of data. Identifying underperforming slices is particularly challenging when working with high-dimensional inputs (e.g. images, audio), where important slices are often unlabeled. In order to address this issue, recent studies have proposed automated slice discovery methods (SDMs), which leverage learned model representations to mine input data for slices on which a model performs poorly. To be useful to a practitioner, these methods must identify slices that are both underperforming and coherent (i.e. united by a human-understandable concept). However, no quantitative evaluation framework currently exists for rigorously assessing SDMs with respect to these criteria. Additionally, prior qualitative evaluations have shown that SDMs often identify slices that are incoherent. In this work, we address these challenges by first designing a principled evaluation framework that enables a quantitative comparison of SDMs across 1,235 slice discovery settings in three input domains (natural images, medical images, and time-series data). Then, motivated by the recent development of powerful cross-modal representation learning approaches, we present Domino, an SDM that leverages cross-modal embeddings and a novel error-aware mixture model to discover and describe coherent slices. We find that Domino accurately identifies 36% of the 1,235 slices in our framework -- a 12 percentage point improvement over prior methods. Further, Domino is the first SDM that can provide natural language descriptions of identified slices, correctly generating the exact name of the slice in 35% of settings.

ICLR Conference 2022 Conference Paper

Efficiently Modeling Long Sequences with Structured State Spaces

  • Albert Gu
  • Karan Goel
  • Christopher Ré

A central goal of sequence modeling is designing a single principled model that can address sequence data across a range of modalities and tasks, particularly on long-range dependencies. Although conventional models including RNNs, CNNs, and Transformers have specialized variants for capturing long dependencies, they still struggle to scale to very long sequences of $10000$ or more steps. A promising recent approach proposed modeling sequences by simulating the fundamental state space model (SSM) \( x'(t) = Ax(t) + Bu(t), y(t) = Cx(t) + Du(t) \), and showed that for appropriate choices of the state matrix \( A \), this system could handle long-range dependencies mathematically and empirically. However, this method has prohibitive computation and memory requirements, rendering it infeasible as a general sequence modeling solution. We propose the Structured State Space sequence model (S4) based on a new parameterization for the SSM, and show that it can be computed much more efficiently than prior approaches while preserving their theoretical strengths. Our technique involves conditioning \( A \) with a low-rank correction, allowing it to be diagonalized stably and reducing the SSM to the well-studied computation of a Cauchy kernel. S4 achieves strong empirical results across a diverse range of established benchmarks, including (i) 91% accuracy on sequential CIFAR-10 with no data augmentation or auxiliary losses, on par with a larger 2-D ResNet, (ii) substantially closing the gap to Transformers on image and language modeling tasks, while performing generation $60\times$ faster, and (iii) SoTA on every task from the Long Range Arena benchmark, including solving the challenging Path-X task of length 16k that all prior work fails on, while being as efficient as all competitors.
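
As a rough illustration of the state space model in this abstract, the sketch below discretizes \( x'(t) = Ax(t) + Bu(t), y(t) = Cx(t) \) and checks that stepping the recurrence gives the same output as convolving the input with the kernel \( (C\bar{B}, C\bar{A}\bar{B}, C\bar{A}^2\bar{B}, \dots) \). It is a naive reference computation, not the S4 algorithm: it uses a small random dense \( A \), sets \( D = 0 \), assumes a bilinear discretization with a hypothetical step size `dt`, and materializes the kernel by repeated matrix products rather than through a Cauchy kernel.

```python
import numpy as np

def discretize_bilinear(A, B, dt):
    """Bilinear (Tustin) discretization of x' = Ax + Bu with step size dt."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - dt / 2 * A)
    return inv @ (I + dt / 2 * A), inv @ (dt * B)

def run_recurrence(Ab, Bb, C, u):
    """Step x_k = Ab x_{k-1} + Bb u_k, y_k = C x_k, starting from x = 0."""
    x = np.zeros(Ab.shape[0])
    ys = []
    for u_k in u:
        x = Ab @ x + Bb[:, 0] * u_k
        ys.append(C[0] @ x)
    return np.array(ys)

def ssm_kernel(Ab, Bb, C, length):
    """Materialize K_j = C Ab^j Bb for j = 0..length-1 (naive O(length * N^2))."""
    kernel, v = [], Bb[:, 0]
    for _ in range(length):
        kernel.append(C[0] @ v)
        v = Ab @ v
    return np.array(kernel)

rng = np.random.default_rng(0)
N, L = 4, 32  # hypothetical state size and sequence length
A = -np.eye(N) + 0.1 * rng.standard_normal((N, N))
B = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))
u = rng.standard_normal(L)

Ab, Bb = discretize_bilinear(A, B, dt=0.1)
y_rec = run_recurrence(Ab, Bb, C, u)
K = ssm_kernel(Ab, Bb, C, L)
y_conv = np.convolve(u, K)[:L]  # causal convolution, truncated to L outputs
```

The two views serve different purposes: the recurrence needs only the current state per step, while the convolution exposes the whole sequence to parallel computation. The naive kernel construction here is the expensive part that a structured parameterization is meant to avoid.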

NeurIPS Conference 2022 Conference Paper

Fine-tuning Language Models over Slow Networks using Activation Quantization with Guarantees

  • Jue Wang
  • Binhang Yuan
  • Luka Rimanic
  • Yongjun He
  • Tri Dao
  • Beidi Chen
  • Christopher Ré
  • Ce Zhang

Communication compression is a crucial technique for modern distributed learning systems to alleviate their communication bottlenecks over slower networks. Despite recent intensive studies of gradient compression for data parallel-style training, compressing the activations for models trained with pipeline parallelism is still an open problem. In this paper, we propose AQ-SGD, a novel activation compression algorithm for communication-efficient pipeline parallelism training over slow networks. Different from previous efforts in activation compression, instead of compressing activation values directly, AQ-SGD compresses the changes of the activations. This allows us to show, to the best of our knowledge for the first time, that one can still achieve $O(1/\sqrt{T})$ convergence rate for non-convex objectives under activation compression, without making assumptions on gradient unbiasedness that do not hold for deep learning models with non-linear activation functions. We then show that AQ-SGD can be optimized and implemented efficiently, without additional end-to-end runtime overhead. We evaluated AQ-SGD to fine-tune language models with up to 1.5 billion parameters, compressing activation to 2-4 bits. AQ-SGD provides up to $4.3\times$ end-to-end speed-up in slower networks, without sacrificing model quality. Moreover, we also show that AQ-SGD can be combined with state-of-the-art gradient compression algorithms to enable end-to-end communication compression: All communications between machines, including model gradients, forward activations, and backward gradients are compressed into lower precision. This provides up to $4.9\times$ end-to-end speed-up, without sacrificing model quality.

NeurIPS Conference 2022 Conference Paper

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

  • Tri Dao
  • Dan Fu
  • Stefano Ermon
  • Atri Rudra
  • Christopher Ré

Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length. Approximate attention methods have attempted to address this problem by trading off model quality to reduce the compute complexity, but often do not achieve wall-clock speedup. We argue that a missing principle is making attention algorithms IO-aware---accounting for reads and writes between levels of GPU memory. We propose FlashAttention, an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM. We analyze the IO complexity of FlashAttention, showing that it requires fewer HBM accesses than standard attention, and is optimal for a range of SRAM sizes. We also extend FlashAttention, yielding an approximate attention algorithm that is faster than any existing approximate attention method. FlashAttention trains Transformers faster than existing baselines, with a 3x speedup on GPT-2 (seq. length 1K) and a 2.4x speedup on long-range arena (seq. length 1K-4K). FlashAttention also enables longer context in Transformers, yielding higher quality models (0.7 better perplexity on GPT-2 and 6.4 points of lift on long-document classification) and entirely new capabilities: the first Transformers to achieve better-than-chance performance on the Path-X challenge (seq. length 16K, 61.4% accuracy) and Path-256 (seq. length 64K, 63.1% accuracy).
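
The tiling idea in this abstract can be sketched on the CPU with NumPy. The version below processes keys and values one block at a time and keeps a running row maximum and row sum, so the full attention matrix is never stored, yet the result matches standard softmax attention. It is only a numerical illustration under stated assumptions: the block size and function names are hypothetical, only the key/value axis is tiled, and nothing here models GPU HBM or SRAM traffic, which is what the actual algorithm is designed around.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Reference softmax attention that materializes the full score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=16):
    """Exact attention computed one key/value block at a time."""
    n, d = Q.shape
    out = np.zeros((n, V.shape[1]))
    row_max = np.full(n, -np.inf)  # running max of scores per query row
    row_sum = np.zeros(n)          # running softmax denominator per query row
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d)               # scores for this block only
        new_max = np.maximum(row_max, S.max(axis=1))
        scale = np.exp(row_max - new_max)       # rescale earlier accumulators
        P = np.exp(S - new_max[:, None])
        row_sum = row_sum * scale + P.sum(axis=1)
        out = out * scale[:, None] + P @ Vb
        row_max = new_max
    return out / row_sum[:, None]

rng = np.random.default_rng(0)
Q = rng.standard_normal((50, 8))
K = rng.standard_normal((50, 8))
V = rng.standard_normal((50, 8))

reference = standard_attention(Q, K, V)
tiled = tiled_attention(Q, K, V, block=16)
```

Because the rescaling step corrects earlier partial sums whenever a larger score appears, the blockwise result is exact rather than approximate, which is the property the abstract emphasizes.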

NeurIPS Conference 2022 Conference Paper

HAPI: A Large-scale Longitudinal Dataset of Commercial ML API Predictions

  • Lingjiao Chen
  • Zhihua Jin
  • Evan Sabri Eyuboglu
  • Christopher Ré
  • Matei Zaharia
  • James Y. Zou

Commercial ML APIs offered by providers such as Google, Amazon and Microsoft have dramatically simplified ML adoption in many applications. Numerous companies and academics pay to use ML APIs for tasks such as object detection, OCR and sentiment analysis. Different ML APIs tackling the same task can have very heterogeneous performance. Moreover, the ML models underlying the APIs also evolve over time. As ML APIs rapidly become a valuable marketplace and an integral part of analytics, it is critical to systematically study and compare different APIs with each other and to characterize how individual APIs change over time. However, this practically important topic is currently underexplored due to the lack of data. In this paper, we present HAPI (History of APIs), a longitudinal dataset of 1,761,417 instances of commercial ML API applications (involving APIs from Amazon, Google, IBM, Microsoft and other providers) across diverse tasks including image tagging, speech recognition, and text mining from 2020 to 2022. Each instance consists of a query input for an API (e.g., an image or text) along with the API’s output prediction/annotation and confidence scores. HAPI is the first large-scale dataset of ML API usages and is a unique resource for studying ML as-a-service (MLaaS). As examples of the types of analyses that HAPI enables, we show that ML APIs’ performance changes substantially over time: several APIs’ accuracies dropped on specific benchmark datasets. Even when the API’s aggregate performance stays steady, its error modes can shift across different subtypes of data between 2020 and 2022. Such changes can substantially impact the entire analytics pipelines that use some ML API as a component. We further use HAPI to study commercial APIs’ performance disparities across demographic subgroups over time. HAPI can stimulate more research in the growing field of MLaaS.

ICML Conference 2022 Conference Paper

It's Raw! Audio Generation with State-Space Models

  • Karan Goel
  • Albert Gu
  • Chris Donahue
  • Christopher Ré

Developing architectures suitable for modeling raw audio is a challenging problem due to the high sampling rates of audio waveforms. Standard sequence modeling approaches like RNNs and CNNs have previously been tailored to fit the demands of audio, but the resultant architectures make undesirable computational tradeoffs and struggle to model waveforms effectively. We propose SaShiMi, a new multi-scale architecture for waveform modeling built around the recently introduced S4 model for long sequence modeling. We identify that S4 can be unstable during autoregressive generation, and provide a simple improvement to its parameterization by drawing connections to Hurwitz matrices. SaShiMi yields state-of-the-art performance for unconditional waveform generation in the autoregressive setting. Additionally, SaShiMi improves non-autoregressive generation performance when used as the backbone architecture for a diffusion model. Compared to prior architectures in the autoregressive generation setting, SaShiMi generates piano and speech waveforms which humans find more musical and coherent respectively, e.g. 2$\times$ better mean opinion scores than WaveNet on an unconditional speech generation task. On a music generation task, SaShiMi outperforms WaveNet on density estimation and speed at both training and inference even when using 3$\times$ fewer parameters.

JMLR Journal 2022 Journal Article

Machine Learning on Graphs: A Model and Comprehensive Taxonomy

  • Ines Chami
  • Sami Abu-El-Haija
  • Bryan Perozzi
  • Christopher Ré
  • Kevin Murphy

There has been a surge of recent interest in graph representation learning (GRL). GRL methods have generally fallen into three main categories, based on the availability of labeled data. The first, network embedding, focuses on learning unsupervised representations of relational structure. The second, graph regularized neural networks, leverages graphs to augment neural network losses with a regularization objective for semi-supervised learning. The third, graph neural networks, aims to learn differentiable functions over discrete topologies with arbitrary structure. However, despite the popularity of these areas there has been surprisingly little work on unifying the three paradigms. Here, we aim to bridge the gap between network embedding, graph regularization and graph neural networks. We propose a comprehensive taxonomy of GRL methods, aiming to unify several disparate bodies of work. Specifically, we propose the GraphEDM framework, which generalizes popular algorithms for semi-supervised learning (e.g. GraphSage, GCN, GAT), and unsupervised learning (e.g. DeepWalk, node2vec) of graph representations into a single consistent approach. To illustrate the generality of GraphEDM, we fit over thirty existing methods into this framework. We believe that this unifying view both provides a solid foundation for understanding the intuition behind these methods, and enables future research in the area.

ICML Conference 2022 Conference Paper

Monarch: Expressive Structured Matrices for Efficient and Accurate Training

  • Tri Dao
  • Beidi Chen
  • Nimit Sharad Sohoni
  • Arjun D. Desai
  • Michael Poli
  • Jessica Grogan
  • Alexander Liu
  • Aniruddh Rao

Large neural networks excel in many domains, but they are expensive to train and fine-tune. A popular approach to reduce their compute or memory requirements is to replace dense weight matrices with structured ones (e.g., sparse, low-rank, Fourier transform). These methods have not seen widespread adoption (1) in end-to-end training due to unfavorable efficiency–quality tradeoffs, and (2) in dense-to-sparse fine-tuning due to lack of tractable algorithms to approximate a given dense weight matrix. To address these issues, we propose a class of matrices (Monarch) that is hardware-efficient (they are parameterized as products of two block-diagonal matrices for better hardware utilization) and expressive (they can represent many commonly used transforms). Surprisingly, the problem of approximating a dense weight matrix with a Monarch matrix, though nonconvex, has an analytical optimal solution. These properties of Monarch matrices unlock new ways to train and fine-tune sparse and dense models. We empirically validate that Monarch can achieve favorable accuracy-efficiency tradeoffs in several end-to-end sparse training applications: speeding up ViT and GPT-2 training on ImageNet classification and Wikitext-103 language modeling by 2x with comparable model quality, and reducing the error on PDE solving and MRI reconstruction tasks by 40%. In sparse-to-dense training, with a simple technique called "reverse sparsification," Monarch matrices serve as a useful intermediate representation to speed up GPT-2 pretraining on OpenWebText by 2x without quality drop. The same technique brings 23% faster BERT pretraining than even the very optimized implementation from Nvidia that set the MLPerf 1.1 record. In dense-to-sparse fine-tuning, as a proof-of-concept, our Monarch approximation algorithm speeds up BERT fine-tuning on GLUE by 1.7x with comparable accuracy.
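
A small sketch may help make "products of two block-diagonal matrices" concrete. It assumes a square matrix of size n = m * m, two block-diagonal factors with m blocks of size m x m each, and a reshape-and-transpose permutation interleaved between them, i.e. M = P L P^T R. The abstract only states the block-diagonal product, so the permutation convention, the function names, and the sizes below are assumptions for illustration.

```python
import numpy as np

def block_diag_matvec(blocks, x):
    """Apply a block-diagonal matrix (m blocks, each m x m) to a length m*m vector."""
    m = blocks.shape[0]
    return np.einsum("bij,bj->bi", blocks, x.reshape(m, m)).reshape(-1)

def permute(x, m):
    """Reshape to (m, m), transpose, flatten. This permutation is its own inverse."""
    return x.reshape(m, m).T.reshape(-1)

def monarch_matvec(L_blocks, R_blocks, x):
    """Compute (P L P^T R) x without forming the dense n x n matrix."""
    m = L_blocks.shape[0]
    y = block_diag_matvec(R_blocks, x)
    y = permute(y, m)                   # P^T, equal to P for this permutation
    y = block_diag_matvec(L_blocks, y)
    return permute(y, m)                # P

def dense_block_diag(blocks):
    """Materialize a block-diagonal matrix, for checking only."""
    m = blocks.shape[0]
    D = np.zeros((m * m, m * m))
    for b in range(m):
        D[b * m:(b + 1) * m, b * m:(b + 1) * m] = blocks[b]
    return D

rng = np.random.default_rng(0)
m = 4
n = m * m
L_blocks = rng.standard_normal((m, m, m))
R_blocks = rng.standard_normal((m, m, m))
x = rng.standard_normal(n)

# Dense reference built from an explicit permutation matrix.
perm = np.arange(n).reshape(m, m).T.reshape(-1)
P = np.eye(n)[perm]
M_dense = P @ dense_block_diag(L_blocks) @ P.T @ dense_block_diag(R_blocks)

y_fast = monarch_matvec(L_blocks, R_blocks, x)
n_params = L_blocks.size + R_blocks.size  # 2 * m**3, versus n * n for a dense matrix
```

Under these assumptions the structured form stores 2 * m^3 = 2 * n^1.5 numbers instead of n^2, and each factor is a batch of small dense multiplications, which is the shape of work that maps well onto batched matrix-multiply hardware.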

NeurIPS Conference 2022 Conference Paper

On the Parameterization and Initialization of Diagonal State Space Models

  • Albert Gu
  • Karan Goel
  • Ankit Gupta
  • Christopher Ré

State space models (SSMs) have recently been shown to be very effective as a deep learning layer and a promising alternative to sequence models such as RNNs, CNNs, or Transformers. The first version to show this potential was the S4 model, which is particularly effective on tasks involving long-range dependencies by using a prescribed state matrix called the HiPPO matrix. While this has an interpretable mathematical mechanism for modeling long dependencies, it also requires a custom representation and algorithm that makes the model difficult to understand and implement. On the other hand, a recent variant of S4 called DSS showed that restricting the state matrix to be fully diagonal can still preserve the performance of the original model when using a specific initialization based on approximating S4's matrix. This work seeks to systematically understand how to parameterize and initialize diagonal state space models. While it follows from classical results that almost all SSMs have an equivalent diagonal form, we show that the initialization is critical for performance. First, we explain why DSS works mathematically, as the diagonal approximation to S4 surprisingly recovers the same dynamics in the limit of infinite state dimension. We then systematically describe various design choices in parameterizing and computing diagonal SSMs, and perform a controlled empirical study ablating the effects of these choices. Our final model S4D is a simple diagonal version of S4 whose kernel computation requires just 3 lines of code and performs comparably to S4 in almost all settings, with state-of-the-art results in image, audio, and medical time-series domains, and 85% average on the Long Range Arena benchmark.

ICML Conference 2022 Conference Paper

Perfectly Balanced: Improving Transfer and Robustness of Supervised Contrastive Learning

  • Mayee F. Chen
  • Daniel Y. Fu
  • Avanika Narayan
  • Michael Zhang
  • Zhao Song 0002
  • Kayvon Fatahalian
  • Christopher Ré

An ideal learned representation should display transferability and robustness. Supervised contrastive learning (SupCon) is a promising method for training accurate models, but produces representations that do not capture these properties due to class collapse, when all points in a class map to the same representation. Recent work suggests that "spreading out" these representations improves them, but the precise mechanism is poorly understood. We argue that creating spread alone is insufficient for better representations, since spread is invariant to permutations within classes. Instead, both the correct degree of spread and a mechanism for breaking this invariance are necessary. We first prove that adding a weighted class-conditional InfoNCE loss to SupCon controls the degree of spread. Next, we study three mechanisms to break permutation invariance: using a constrained encoder, adding a class-conditional autoencoder, and using data augmentation. We show that the latter two encourage clustering of latent subclasses under more realistic conditions than the former. Using these insights, we show that adding a properly-weighted class-conditional InfoNCE loss and a class-conditional autoencoder to SupCon achieves 11.1 points of lift on coarse-to-fine transfer across 5 standard datasets and 4.7 points on worst-group robustness on 3 datasets, setting state-of-the-art on CelebA by 11.5 points.

ICLR Conference 2022 Conference Paper

Pixelated Butterfly: Simple and Efficient Sparse training for Neural Network Models

  • Beidi Chen
  • Tri Dao
  • Kaizhao Liang
  • Jiaming Yang
  • Zhao Song 0002
  • Atri Rudra
  • Christopher Ré

Overparameterized neural networks generalize well but are expensive to train. Ideally one would like to reduce their computational cost while retaining their generalization benefits. Sparse model training is a simple and promising approach to achieve this, but there remain challenges as existing methods struggle with accuracy loss, slow training runtime, or difficulty in sparsifying all model components. The core problem is that searching for a sparsity mask over a discrete set of sparse matrices is difficult and expensive. To address this, our main insight is to optimize over a continuous superset of sparse matrices with a fixed structure known as products of butterfly matrices. As butterfly matrices are not hardware efficient, we propose simple variants of butterfly (block and flat) to take advantage of modern hardware. Our method (Pixelated Butterfly) uses a simple fixed sparsity pattern based on flat block butterfly and low-rank matrices to sparsify most network layers (e.g., attention, MLP). We empirically validate that Pixelated Butterfly is $3\times$ faster than Butterfly and speeds up training to achieve favorable accuracy--efficiency tradeoffs. On the ImageNet classification and WikiText-103 language modeling tasks, our sparse models train up to 2.3$\times$ faster than the dense MLP-Mixer, Vision Transformer, and GPT-2 small with no drop in accuracy.

NeurIPS Conference 2022 Conference Paper

S4ND: Modeling Images and Videos as Multidimensional Signals with State Spaces

  • Eric Nguyen
  • Karan Goel
  • Albert Gu
  • Gordon Downs
  • Preey Shah
  • Tri Dao
  • Stephen Baccus
  • Christopher Ré

Visual data such as images and videos are typically modeled as discretizations of inherently continuous, multidimensional signals. Existing continuous-signal models attempt to exploit this fact by modeling the underlying signals of visual (e.g., image) data directly. However, these models have not yet been able to achieve competitive performance on practical vision tasks such as large-scale image and video classification. Building on a recent line of work on deep state space models (SSMs), we propose S4ND, a new multidimensional SSM layer that extends the continuous-signal modeling ability of SSMs to multidimensional data including images and videos. We show that S4ND can model large-scale visual data in $1$D, $2$D, and $3$D as continuous multidimensional signals and demonstrate strong performance by simply swapping Conv2D and self-attention layers with S4ND layers in existing state-of-the-art models. On ImageNet-1k, S4ND exceeds the performance of a Vision Transformer baseline by $1.5\%$ when training with a $1$D sequence of patches, and matches ConvNeXt when modeling images in $2$D. For videos, S4ND improves on an inflated $3$D ConvNeXt in activity classification on HMDB-51 by $4\%$. S4ND implicitly learns global, continuous convolutional kernels that are resolution invariant by construction, providing an inductive bias that enables generalization across multiple resolutions. By developing a simple bandlimiting modification to S4 to overcome aliasing, S4ND achieves strong zero-shot (unseen at training time) resolution performance, outperforming a baseline Conv2D by $40\%$ on CIFAR-10 when trained on $8 \times 8$ and tested on $32 \times 32$ images. When trained with progressive resizing, S4ND comes within $\sim 1\%$ of a high-resolution model while training $22\%$ faster.

NeurIPS Conference 2022 Conference Paper

Self-Supervised Learning of Brain Dynamics from Broad Neuroimaging Data

  • Armin Thomas
  • Christopher Ré
  • Russell Poldrack

Self-supervised learning techniques are celebrating immense success in natural language processing (NLP) by enabling models to learn from broad language data at unprecedented scales. Here, we aim to leverage the success of these techniques for mental state decoding, where researchers aim to identify specific mental states (e.g., the experience of anger or joy) from brain activity. To this end, we devise a set of novel self-supervised learning frameworks for neuroimaging data inspired by prominent learning frameworks in NLP. At their core, these frameworks learn the dynamics of brain activity by modeling sequences of activity akin to how sequences of text are modeled in NLP. We evaluate the frameworks by pre-training models on a broad neuroimaging dataset spanning functional Magnetic Resonance Imaging data from 11,980 experimental runs of 1,726 individuals across 34 datasets, and subsequently adapting the pre-trained models to benchmark mental state decoding datasets. The pre-trained models transfer well, generally outperforming baseline models trained from scratch, while models trained in a learning framework based on causal language modeling clearly outperform the others.

UAI Conference 2022 Conference Paper

Shoring up the foundations: fusing model embeddings and weak supervision

  • Mayee F. Chen
  • Daniel Y. Fu
  • Dyah Adila
  • Michael Zhang
  • Frederic Sala
  • Kayvon Fatahalian
  • Christopher Ré

Foundation models offer an exciting new paradigm for constructing models with out-of-the-box embeddings and a few labeled examples. However, it is not clear how to best apply foundation models without labeled data. A potential approach is to fuse foundation models with weak supervision frameworks, which use weak label sources (pre-trained models, heuristics, crowd-workers) to construct pseudolabels. The challenge is building a combination that best exploits the signal available in both foundation models and weak sources. We propose LIGER, a combination that uses foundation model embeddings to improve two crucial elements of existing weak supervision techniques. First, we produce finer estimates of weak source quality by partitioning the embedding space and learning per-part source accuracies. Second, we improve source coverage by extending source votes in embedding space. Despite the black-box nature of foundation models, we prove results characterizing how our approach improves performance and show that lift scales with the smoothness of label distributions in embedding space. On six benchmark NLP and video tasks, LIGER outperforms vanilla weak supervision by 14.1 points, weakly-supervised kNN and adapters by 11.8 points, and kNN and adapters supervised by traditional hand labels by 7.2 points.

NeurIPS Conference 2022 Conference Paper

Transform Once: Efficient Operator Learning in Frequency Domain

  • Michael Poli
  • Stefano Massaroli
  • Federico Berto
  • Jinkyoo Park
  • Tri Dao
  • Christopher Ré
  • Stefano Ermon

Spectral analysis provides one of the most effective paradigms for information-preserving dimensionality reduction, as simple descriptions of naturally occurring signals are often obtained via few terms of periodic basis functions. In this work, we study deep neural networks designed to harness the structure in frequency domain for efficient learning of long-range correlations in space or time: frequency-domain models (FDMs). Existing FDMs are based on complex-valued transforms, i.e., Fourier Transforms (FT), and layers that perform computation on the spectrum and input data separately. This design introduces considerable computational overhead: for each layer, a forward and inverse FT. Instead, this work introduces a blueprint for frequency domain learning through a single transform: transform once (T1). To enable efficient, direct learning in the frequency domain we derive a variance preserving weight initialization scheme and investigate methods for frequency selection in reduced-order FDMs. Our results noticeably streamline the design process of FDMs, pruning redundant transforms, and leading to speedups of 3x to 10x that increase with data resolution and model size. We perform extensive experiments on learning the solution operator of spatio-temporal dynamics, including incompressible Navier-Stokes, turbulent flows around airfoils and high-resolution video of smoke. T1 models improve on the test performance of FDMs while requiring significantly less computation (5 hours instead of 32 for our large-scale experiment), with over 20% reduction in predictive error across tasks.

ICML Conference 2021 Conference Paper

Catformer: Designing Stable Transformers via Sensitivity Analysis

  • Jared Quincy Davis
  • Albert Gu
  • Krzysztof Choromanski
  • Tri Dao
  • Christopher Ré
  • Chelsea Finn
  • Percy Liang

Transformer architectures are widely used, but training them is non-trivial, requiring custom learning rate schedules, scaling terms, residual connections, careful placement of submodules such as normalization, and so on. In this paper, we improve upon recent analysis of Transformers and formalize a notion of sensitivity to capture the difficulty of training. Sensitivity characterizes how the variance of activation and gradient norms change in expectation when parameters are randomly perturbed. We analyze the sensitivity of previous Transformer architectures and design a new architecture, the Catformer, which replaces residual connections or RNN-based gating mechanisms with concatenation. We prove that Catformers are less sensitive than other Transformer variants and demonstrate that this leads to more stable training. On DMLab30, a suite of high-dimension reinforcement tasks, Catformer outperforms other transformers, including Gated Transformer-XL—the state-of-the-art architecture designed to address stability—by 13%.

NeurIPS Conference 2021 Conference Paper

Combining Recurrent, Convolutional, and Continuous-time Models with Linear State Space Layers

  • Albert Gu
  • Isys Johnson
  • Karan Goel
  • Khaled Saab
  • Tri Dao
  • Atri Rudra
  • Christopher Ré

Recurrent neural networks (RNNs), temporal convolutions, and neural differential equations (NDEs) are popular families of deep learning models for time-series data, each with unique strengths and tradeoffs in modeling power and computational efficiency. We introduce a simple sequence model inspired by control systems that generalizes these approaches while addressing their shortcomings. The Linear State-Space Layer (LSSL) maps a sequence $u \mapsto y$ by simply simulating a linear continuous-time state-space representation $\dot{x} = Ax + Bu, y = Cx + Du$. Theoretically, we show that LSSL models are closely related to the three aforementioned families of models and inherit their strengths. For example, they generalize convolutions to continuous-time, explain common RNN heuristics, and share features of NDEs such as time-scale adaptation. We then incorporate and generalize recent theory on continuous-time memorization to introduce a trainable subset of structured matrices $A$ that endow LSSLs with long-range memory. Empirically, stacking LSSL layers into a simple deep neural network obtains state-of-the-art results across time series benchmarks for long dependencies in sequential image classification, real-world healthcare regression tasks, and speech. On a difficult speech classification task with length-16000 sequences, LSSL outperforms prior approaches by 24 accuracy points, and even outperforms baselines that use hand-crafted features on 100x shorter sequences.
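To make the state-space map $u \mapsto y$ concrete, here is a minimal sketch for a scalar state. It assumes a bilinear-transform discretization with step `dt`, which is one common choice and not something the abstract fixes; the function name `simulate_ssm` and all parameter values are hypothetical.

```python
def simulate_ssm(u, a, b, c, d, dt):
    """Simulate x' = a*x + b*u, y = c*x + d*u for a scalar state.

    Uses a bilinear-transform discretization (an assumption of this sketch).
    Returns the list of outputs y_t, one per input sample u_t.
    """
    denom = 1.0 - 0.5 * dt * a
    a_bar = (1.0 + 0.5 * dt * a) / denom  # discrete state transition
    b_bar = dt * b / denom                # discrete input gain
    x = 0.0
    ys = []
    for u_t in u:
        # Recurrent view: one state update per input sample.
        x = a_bar * x + b_bar * u_t
        ys.append(c * x + d * u_t)
    return ys
```

For a stable system (`a < 0`) driven by a constant input, the output settles at `-c * b / a + d`, matching the continuous-time steady state. For example, `simulate_ssm([1.0] * 400, -1.0, 1.0, 1.0, 0.0, 0.1)` approaches 1.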

ICLR Conference 2021 Conference Paper

Cut out the annotator, keep the cutout: better segmentation with weak supervision

  • Sarah M. Hooper
  • Michael Wornow
  • Ying Hang Seah
  • Peter Kellman
  • Hui Xue 0006
  • Frederic Sala
  • Curtis Langlotz
  • Christopher Ré

Constructing large, labeled training datasets for segmentation models is an expensive and labor-intensive process. This is a common challenge in machine learning, addressed by methods that require few or no labeled data points such as few-shot learning (FSL) and weakly-supervised learning (WS). Such techniques, however, have limitations when applied to image segmentation---FSL methods often produce noisy results and are strongly dependent on which few datapoints are labeled, while WS models struggle to fully exploit rich image information. We propose a framework that fuses FSL and WS for segmentation tasks, enabling users to train high-performing segmentation networks with very few hand-labeled training points. We use FSL models as weak sources in a WS framework, requiring a very small set of reference labeled images, and introduce a new WS model that focuses on key areas---areas with contention among noisy labels---of the image to fuse these weak sources. Empirically, we evaluate our proposed approach over seven well-motivated segmentation tasks. We show that our methods can achieve within 1.4 Dice points compared to fully supervised networks while only requiring five hand-labeled training points. Compared to existing FSL methods, our approach improves performance by a mean 3.6 Dice points over the next-best method.

ICML Conference 2021 Conference Paper

HoroPCA: Hyperbolic Dimensionality Reduction via Horospherical Projections

  • Ines Chami
  • Albert Gu
  • Dat Nguyen
  • Christopher Ré

This paper studies Principal Component Analysis (PCA) for data lying in hyperbolic spaces. Given directions, PCA relies on: (1) a parameterization of subspaces spanned by these directions, (2) a method of projection onto subspaces that preserves information in these directions, and (3) an objective to optimize, namely the variance explained by projections. We generalize each of these concepts to the hyperbolic space and propose HoroPCA, a method for hyperbolic dimensionality reduction. By focusing on the core problem of extracting principal directions, HoroPCA theoretically better preserves information in the original data such as distances, compared to previous generalizations of PCA. Empirically, we validate that HoroPCA outperforms existing dimensionality reduction methods, significantly reducing error in distance preservation. As a data whitening method, it improves downstream classification by up to 3.9% compared to methods that don’t use whitening. Finally, we show that HoroPCA can be used to visualize hyperbolic data in two dimensions.

ICML Conference 2021 Conference Paper

Mandoline: Model Evaluation under Distribution Shift

  • Mayee F. Chen
  • Karan Goel
  • Nimit Sharad Sohoni
  • Fait Poms
  • Kayvon Fatahalian
  • Christopher Ré

Machine learning models are often deployed in different settings than they were trained and validated on, posing a challenge to practitioners who wish to predict how well the deployed model will perform on a target distribution. If an unlabeled sample from the target distribution is available, along with a labeled sample from a possibly different source distribution, standard approaches such as importance weighting can be applied to estimate performance on the target. However, importance weighting struggles when the source and target distributions have non-overlapping support or are high-dimensional. Taking inspiration from fields such as epidemiology and polling, we develop Mandoline, a new evaluation framework that mitigates these issues. Our key insight is that practitioners may have prior knowledge about the ways in which the distribution shifts, which we can use to better guide the importance weighting procedure. Specifically, users write simple "slicing functions" (noisy, potentially correlated binary functions intended to capture possible axes of distribution shift) to compute reweighted performance estimates. We further describe a density ratio estimation framework for the slices and show how its estimation error scales with slice quality and dataset size. Empirical validation on NLP and vision tasks shows that Mandoline can estimate performance on the target distribution up to 3x more accurately compared to standard baselines.
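The idea of reweighting by slices can be sketched in a deliberately simplified form. The code below estimates the density ratio by counting how often each slice pattern occurs in the source and target samples; this stands in for, and is much cruder than, the density ratio estimation framework the abstract describes. The function name `reweighted_accuracy` and the data layout are hypothetical, and every target slice pattern is assumed to also appear in the source sample.

```python
from collections import Counter


def reweighted_accuracy(source_slices, source_correct, target_slices):
    """Estimate target accuracy by reweighting labeled source examples.

    source_slices / target_slices: one tuple of binary slice values per example.
    source_correct: 1 if the model was right on that source example, else 0.
    """
    src_counts = Counter(source_slices)
    tgt_counts = Counter(target_slices)
    n_src, n_tgt = len(source_slices), len(target_slices)
    total = 0.0
    for s, correct in zip(source_slices, source_correct):
        # Density ratio p_target(s) / p_source(s), estimated by counting.
        weight = (tgt_counts[s] / n_tgt) / (src_counts[s] / n_src)
        total += weight * correct
    return total / n_src
```

As a small example, suppose a single slice is active on 20% of source examples but 50% of target examples, and the model is 75% accurate off the slice and 50% accurate on it. The unweighted source accuracy is 70%, while the reweighted estimate is 62.5%, reflecting the heavier presence of the harder slice in the target.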

ICLR Conference 2021 Conference Paper

Model Patching: Closing the Subgroup Performance Gap with Data Augmentation

  • Karan Goel
  • Albert Gu
  • Yixuan Li 0001
  • Christopher Ré

Classifiers in machine learning are often brittle when deployed. Particularly concerning are models with inconsistent performance on specific subgroups of a class, e.g., exhibiting disparities in skin cancer classification in the presence or absence of a spurious bandage. To mitigate these performance differences, we introduce model patching, a two-stage framework for improving robustness that encourages the model to be invariant to subgroup differences, and focus on class information shared by subgroups. Model patching first models subgroup features within a class and learns semantic transformations between them, and then trains a classifier with data augmentations that deliberately manipulate subgroup features. We instantiate model patching with CAMEL, which (1) uses a CycleGAN to learn the intra-class, inter-subgroup augmentations, and (2) balances subgroup performance using a theoretically-motivated subgroup consistency regularizer, accompanied by a new robust objective. We demonstrate CAMEL’s effectiveness on 3 benchmark datasets, with reductions in robust error of up to 33% relative to the best baseline. Lastly, CAMEL successfully patches a model that fails due to spurious features on a real-world skin cancer dataset.

ICLR Conference 2021 Conference Paper

MONGOOSE: A Learnable LSH Framework for Efficient Neural Network Training

  • Beidi Chen
  • Zichang Liu
  • Binghui Peng
  • Zhaozhuo Xu
  • Jonathan Lingjie Li
  • Tri Dao
  • Zhao Song 0002
  • Anshumali Shrivastava

Recent advances by practitioners in the deep learning community have breathed new life into Locality Sensitive Hashing (LSH), using it to reduce memory and time bottlenecks in neural network (NN) training. However, while LSH has sub-linear guarantees for approximate near-neighbor search in theory, it is known to have inefficient query time in practice due to its use of random hash functions. Moreover, when model parameters are changing, LSH suffers from update overhead. This work is motivated by an observation that model parameters evolve slowly, such that the changes do not always require an LSH update to maintain performance. This phenomenon points to the potential for a reduction in update time and allows for a modified learnable version of data-dependent LSH to improve query time at a low cost. We use the above insights to build MONGOOSE, an end-to-end LSH framework for efficient NN training. In particular, MONGOOSE is equipped with a scheduling algorithm to adaptively perform LSH updates with provable guarantees and learnable hash functions to improve query efficiency. Empirically, we validate MONGOOSE on large-scale deep learning models for recommendation systems and language modeling. We find that it achieves up to 8% better accuracy compared to previous LSH approaches, with $6.5 \times$ speed-up and $6\times$ reduction in memory usage.

NeurIPS Conference 2021 Conference Paper

Personalized Benchmarking with the Ludwig Benchmarking Toolkit

  • Avanika Narayan
  • Piero Molino
  • Karan Goel
  • Willie Neiswanger
  • Christopher Ré

The rapid proliferation of machine learning models across domains and deployment settings has given rise to various communities (e.g., industry practitioners) which seek to benchmark models across tasks and objectives of personal value. Unfortunately, these users cannot use standard benchmark results to perform such value-driven comparisons as traditional benchmarks evaluate models on a single objective (e.g., average accuracy) and fail to facilitate a standardized training framework that controls for confounding variables (e.g., computational budget), making fair comparisons difficult. To address these challenges, we introduce the open-source Ludwig Benchmarking Toolkit (LBT), a personalized benchmarking toolkit for running end-to-end benchmark studies (from hyperparameter optimization to evaluation) across an easily extensible set of tasks, deep learning models, datasets and evaluation metrics. LBT provides a configurable interface for controlling training and customizing evaluation, a standardized training framework for eliminating confounding variables, and support for multi-objective evaluation. We demonstrate how LBT can be used to create personalized benchmark studies with a large-scale comparative analysis for text classification across 7 models and 9 datasets. We explore the trade-offs between inference latency and performance, relationships between dataset attributes and performance, and the effects of pretraining on convergence and robustness, showing how LBT can be used to satisfy various benchmarking objectives.

NeurIPS Conference 2021 Conference Paper

Rethinking Neural Operations for Diverse Tasks

  • Nicholas Roberts
  • Mikhail Khodak
  • Tri Dao
  • Liam Li
  • Christopher Ré
  • Ameet Talwalkar

An important goal of AutoML is to automate-away the design of neural networks on new tasks in under-explored domains. Motivated by this goal, we study the problem of enabling users to discover the right neural operations given data from their specific domain. We introduce a search space of operations called XD-Operations that mimic the inductive bias of standard multi-channel convolutions while being much more expressive: we prove that it includes many named operations across multiple application areas. Starting with any standard backbone such as ResNet, we show how to transform it into a search space over XD-operations and how to traverse the space using a simple weight sharing scheme. On a diverse set of tasks—solving PDEs, distance prediction for protein folding, and music modeling—our approach consistently yields models with lower error than baseline networks and often even lower error than expert-designed domain-specific approaches.

NeurIPS Conference 2021 Conference Paper

Scatterbrain: Unifying Sparse and Low-rank Attention

  • Beidi Chen
  • Tri Dao
  • Eric Winsor
  • Zhao Song
  • Atri Rudra
  • Christopher Ré

Recent advances in efficient Transformers have exploited either the sparsity or low-rank properties of attention matrices to reduce the computational and memory bottlenecks of modeling long sequences. However, it is still challenging to balance the trade-off between model quality and efficiency to perform a one-size-fits-all approximation for different tasks. To better understand this trade-off, we observe that sparse and low-rank approximations excel in different regimes, determined by the softmax temperature in attention, and sparse + low-rank can outperform each individually. Inspired by the classical robust-PCA algorithm for sparse and low-rank decomposition, we propose Scatterbrain, a novel way to unify sparse (via locality sensitive hashing) and low-rank (via kernel feature map) attention for accurate and efficient approximation. The estimation is unbiased with provably low error. We empirically show that Scatterbrain can achieve $2.1\times$ lower error than baselines when serving as a drop-in replacement in BigGAN image generation and pre-trained T2T-ViT. On a pre-trained T2T Vision transformer, even without fine-tuning, Scatterbrain can reduce $98\%$ of attention memory at the cost of only $1\%$ drop in accuracy. We demonstrate Scatterbrain for end-to-end training with up to $4$ points better perplexity and 5 points better average accuracy than sparse or low-rank efficient transformers on language modeling and long-range-arena tasks.

NeurIPS Conference 2021 Conference Paper

SKM-TEA: A Dataset for Accelerated MRI Reconstruction with Dense Image Labels for Quantitative Clinical Evaluation

  • Arjun Desai
  • Andrew Schmidt
  • Elka Rubin
  • Christopher Sandino
  • Marianne Black
  • Valentina Mazzoli
  • Kathryn Stevens
  • Robert Boutin

Magnetic resonance imaging (MRI) is a cornerstone of modern medical imaging. However, long image acquisition times, the need for qualitative expert analysis, and the lack of (and difficulty extracting) quantitative indicators that are sensitive to tissue health have curtailed widespread clinical and research studies. While recent machine learning methods for MRI reconstruction and analysis have shown promise for reducing this burden, these techniques are primarily validated with imperfect image quality metrics, which are discordant with clinically-relevant measures that ultimately hamper clinical deployment and clinician trust. To mitigate this challenge, we present the Stanford Knee MRI with Multi-Task Evaluation (SKM-TEA) dataset, a collection of quantitative knee MRI (qMRI) scans that enables end-to-end, clinically-relevant evaluation of MRI reconstruction and analysis tools. This 1.6TB dataset consists of raw-data measurements of ~25,000 slices (155 patients) of anonymized patient MRI scans, the corresponding scanner-generated DICOM images, manual segmentations of four tissues, and bounding box annotations for sixteen clinically relevant pathologies. We provide a framework for using qMRI parameter maps, along with image reconstructions and dense image labels, for measuring the quality of qMRI biomarker estimates extracted from MRI reconstruction, segmentation, and detection techniques. Finally, we use this framework to benchmark state-of-the-art baselines on this dataset. We hope our SKM-TEA dataset and code can enable a broad spectrum of research for modular image reconstruction and image analysis in a clinically informed manner. Dataset access, code, and benchmarks are available at https://github.com/StanfordMIMI/skm-tea.

ICML Conference 2020 Conference Paper

Fast and Three-rious: Speeding Up Weak Supervision with Triplet Methods

  • Daniel Y. Fu
  • Mayee F. Chen
  • Frederic Sala
  • Sarah M. Hooper
  • Kayvon Fatahalian
  • Christopher Ré

Weak supervision is a popular method for building machine learning models without relying on ground truth annotations. Instead, it generates probabilistic training labels by estimating the accuracies of multiple noisy labeling sources (e.g., heuristics, crowd workers). Existing approaches use latent variable estimation to model the noisy sources, but these methods can be computationally expensive, scaling superlinearly in the data. In this work, we show that, for a class of latent variable models highly applicable to weak supervision, we can find a closed-form solution to model parameters, obviating the need for iterative solutions like stochastic gradient descent (SGD). We use this insight to build FlyingSquid, a weak supervision framework that runs orders of magnitude faster than previous weak supervision approaches and requires fewer assumptions. In particular, we prove bounds on generalization error without assuming that the latent variable model can exactly parameterize the underlying data distribution. Empirically, we validate FlyingSquid on benchmark weak supervision datasets and find that it achieves the same or higher quality compared to previous approaches without the need to tune an SGD procedure, recovers model parameters 170 times faster on average, and enables new video analysis and online learning applications.
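A closed-form estimate of source accuracies from triplets can be sketched as follows. The sketch assumes three sources voting in {-1, +1} that are conditionally independent given the true label and each better than random. Under those assumptions E[λ_i λ_j] = E[λ_i Y] E[λ_j Y], so each accuracy parameter follows from three pairwise agreement rates without any iterative fitting. The function names and the simulation helper are hypothetical and are not the FlyingSquid library API.

```python
import math
import random


def triplet_accuracies(votes):
    """Closed-form estimate of E[lambda_i * Y] for three +/-1 labeling sources.

    votes: list of (l1, l2, l3) tuples with entries in {-1, +1}.
    Assumes conditional independence given the true label and better-than-random
    sources, so every estimate is taken as positive.
    """
    n = len(votes)
    # Pairwise agreement moments; none of these need the true labels.
    m12 = sum(v[0] * v[1] for v in votes) / n
    m13 = sum(v[0] * v[2] for v in votes) / n
    m23 = sum(v[1] * v[2] for v in votes) / n
    a1 = math.sqrt(abs(m12 * m13 / m23))
    a2 = math.sqrt(abs(m12 * m23 / m13))
    a3 = math.sqrt(abs(m13 * m23 / m12))
    return a1, a2, a3


def simulate_votes(n, probs, seed=0):
    """Draw labels Y in {-1, +1} and votes that agree with Y with prob. probs[i]."""
    rng = random.Random(seed)
    votes = []
    for _ in range(n):
        y = rng.choice([-1, 1])
        votes.append(tuple(y if rng.random() < p else -y for p in probs))
    return votes
```

A source that is correct with probability p has E[λ Y] = 2p - 1, so simulated sources with accuracies 0.9, 0.8, and 0.7 should yield estimates near 0.8, 0.6, and 0.4, up to sampling noise.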

NeurIPS Conference 2020 Conference Paper

From Trees to Continuous Embeddings and Back: Hyperbolic Hierarchical Clustering

  • Ines Chami
  • Albert Gu
  • Vaggos Chatziafratis
  • Christopher Ré

Similarity-based Hierarchical Clustering (HC) is a classical unsupervised machine learning algorithm that has traditionally been solved with heuristic algorithms like Average-Linkage. Recently, Dasgupta reframed HC as a discrete optimization problem by introducing a global cost function measuring the quality of a given tree. In this work, we provide the first continuous relaxation of Dasgupta's discrete optimization problem with provable quality guarantees. The key idea of our method, HypHC, is showing a direct correspondence from discrete trees to continuous representations (via the hyperbolic embeddings of their leaf nodes) and back (via a decoding algorithm that maps leaf embeddings to a dendrogram), allowing us to search the space of discrete binary trees with continuous optimization. Building on analogies between trees and hyperbolic space, we derive a continuous analogue for the notion of lowest common ancestor, which leads to a continuous relaxation of Dasgupta's discrete objective. We can show that after decoding, the global minimizer of our continuous relaxation yields a discrete tree with a (1+eps)-factor approximation for Dasgupta's optimal tree, where eps can be made arbitrarily small and controls optimization challenges. We experimentally evaluate HypHC on a variety of HC benchmarks and find that even approximate solutions found with gradient descent have superior clustering quality than agglomerative heuristics or other gradient based algorithms. Finally, we highlight the flexibility of HypHC using end-to-end training in a downstream classification task.
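The discrete objective being relaxed can be written down directly. The sketch below computes Dasgupta's cost for a binary tree given as nested tuples: each pair of leaves contributes its similarity times the number of leaves under the pair's lowest common ancestor. The tree encoding, the `sim` dictionary layout, and the function names are assumptions made for illustration, and this is the discrete cost only, not the continuous relaxation used by HypHC.

```python
def leaves(tree):
    """Return the leaf labels under a nested-tuple binary tree."""
    if not isinstance(tree, tuple):
        return [tree]
    left, right = tree
    return leaves(left) + leaves(right)


def dasgupta_cost(tree, sim):
    """Sum over leaf pairs of similarity times the size of their LCA subtree.

    sim maps frozenset({i, j}) to a non-negative similarity; missing pairs
    count as zero.
    """
    if not isinstance(tree, tuple):
        return 0.0
    left, right = tree
    l_leaves, r_leaves = leaves(left), leaves(right)
    size = len(l_leaves) + len(r_leaves)
    # Pairs split at this node have this node as their lowest common ancestor.
    split = sum(sim.get(frozenset((i, j)), 0.0) for i in l_leaves for j in r_leaves)
    return size * split + dasgupta_cost(left, sim) + dasgupta_cost(right, sim)
```

With similarities of 1.0 between leaves 0 and 1 and 0.1 for the other two pairs, the tree `((0, 1), 2)` has cost 2.6 while `((0, 2), 1)` has cost 3.5, so the cost prefers merging the most similar pair first.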

NeurIPS Conference 2020 Conference Paper

HiPPO: Recurrent Memory with Optimal Polynomial Projections

  • Albert Gu
  • Tri Dao
  • Stefano Ermon
  • Atri Rudra
  • Christopher Ré

A central problem in learning from sequential data is representing cumulative history in an incremental fashion as more data is processed. We introduce a general framework (HiPPO) for the online compression of continuous signals and discrete time series by projection onto polynomial bases. Given a measure that specifies the importance of each time step in the past, HiPPO produces an optimal solution to a natural online function approximation problem. As special cases, our framework yields a short derivation of the recent Legendre Memory Unit (LMU) from first principles, and generalizes the ubiquitous gating mechanism of recurrent neural networks such as GRUs. This formal framework yields a new memory update mechanism (HiPPO-LegS) that scales through time to remember all history, avoiding priors on the timescale. HiPPO-LegS enjoys the theoretical benefits of timescale robustness, fast updates, and bounded gradients. By incorporating the memory dynamics into recurrent neural networks, HiPPO RNNs can empirically capture complex temporal dependencies. On the benchmark permuted MNIST dataset, HiPPO-LegS sets a new state-of-the-art accuracy of 98.3%. Finally, on a novel trajectory classification task testing robustness to out-of-distribution timescales and missing data, HiPPO-LegS outperforms RNN and neural ODE baselines by 25-40% accuracy.

ICLR Conference 2020 Conference Paper

Kaleidoscope: An Efficient, Learnable Representation For All Structured Linear Maps

  • Tri Dao
  • Nimit Sharad Sohoni
  • Albert Gu
  • Matthew Eichhorn
  • Amit Blonder
  • Megan Leszczynski
  • Atri Rudra
  • Christopher Ré

Modern neural network architectures use structured linear transformations, such as low-rank matrices, sparse matrices, permutations, and the Fourier transform, to improve inference speed and reduce memory usage compared to general linear maps. However, choosing which of the myriad structured transformations to use (and its associated parameterization) is a laborious task that requires trading off speed, space, and accuracy. We consider a different approach: we introduce a family of matrices called kaleidoscope matrices (K-matrices) that provably capture any structured matrix with near-optimal space (parameter) and time (arithmetic operation) complexity. We empirically validate that K-matrices can be automatically learned within end-to-end pipelines to replace hand-crafted procedures, in order to improve model quality. For example, replacing channel shuffles in ShuffleNet improves classification accuracy on ImageNet by up to 5%. K-matrices can also simplify hand-engineered pipelines---we replace filter bank feature computation in speech data preprocessing with a learnable kaleidoscope layer, resulting in only 0.4% loss in accuracy on the TIMIT speech recognition task. In addition, K-matrices can capture latent structure in models: for a challenging permuted image classification task, adding a K-matrix to a standard convolutional architecture can enable learning the latent permutation and improve accuracy by over 8 points. We provide a practically efficient implementation of our approach, and use K-matrices in a Transformer network to attain 36% faster end-to-end inference speed on a language translation task.

NeurIPS Conference 2020 Conference Paper

No Subclass Left Behind: Fine-Grained Robustness in Coarse-Grained Classification Problems

  • Nimit Sohoni
  • Jared Dunnmon
  • Geoffrey Angus
  • Albert Gu
  • Christopher Ré

In real-world classification tasks, each class often comprises multiple finer-grained "subclasses." As the subclass labels are frequently unavailable, models trained using only the coarser-grained class labels often exhibit highly variable performance across different subclasses. This phenomenon, known as hidden stratification, has important consequences for models deployed in safety-critical applications such as medicine. We propose GEORGE, a method to both measure and mitigate hidden stratification even when subclass labels are unknown. We first observe that unlabeled subclasses are often separable in the feature space of deep models, and exploit this fact to estimate subclass labels for the training data via clustering techniques. We then use these approximate subclass labels as a form of noisy supervision in a distributionally robust optimization objective. We theoretically characterize the performance of GEORGE in terms of the worst-case generalization error across any subclass. We empirically validate GEORGE on a mix of real-world and benchmark image classification datasets, and show that our approach boosts worst-case subclass accuracy by up to 15 percentage points compared to standard training techniques, without requiring any information about the subclasses.

ICML Conference 2020 Conference Paper

On the Generalization Effects of Linear Transformations in Data Augmentation

  • Sen Wu 0002
  • Hongyang Zhang 0005
  • Gregory Valiant
  • Christopher Ré

Data augmentation is a powerful technique to improve performance in applications such as image and text classification tasks. Yet, there is little rigorous understanding of why and how various augmentations work. In this work, we consider a family of linear transformations and study their effects on the ridge estimator in an over-parametrized linear regression setting. First, we show that transformations which preserve the labels of the data can improve estimation by enlarging the span of the training data. Second, we show that transformations which mix data can improve estimation by playing a regularization effect. Finally, we validate our theoretical insights on MNIST. Based on the insights, we propose an augmentation scheme that searches over the space of transformations by how uncertain the model is about the transformed data. We validate our proposed scheme on image and text datasets. For example, our method outperforms RandAugment by 1.24% on CIFAR-100 using Wide-ResNet-28-10. Furthermore, we achieve comparable accuracy to the SoTA Adversarial AutoAugment on CIFAR datasets.

ICLR Conference 2020 Conference Paper

Understanding and Improving Information Transfer in Multi-Task Learning

  • Sen Wu 0002
  • Hongyang Zhang 0005
  • Christopher Ré

We investigate multi-task learning approaches that use a shared feature representation for all tasks. To better understand the transfer of task information, we study an architecture with a shared module for all tasks and a separate output module for each task. We study the theory of this setting on linear and ReLU-activated models. Our key observation is that whether or not tasks' data are well-aligned can significantly affect the performance of multi-task learning. We show that misalignment between task data can cause negative transfer (or hurt performance) and provide sufficient conditions for positive transfer. Inspired by the theoretical insights, we show that aligning tasks' embedding layers leads to performance gains for multi-task training and transfer learning on the GLUE benchmark and sentiment analysis tasks; for example, we obtained a 2.35% GLUE score average improvement on 5 GLUE tasks over BERT LARGE using our alignment method. We also design an SVD-based task re-weighting scheme and show that it improves the robustness of multi-task training on a multi-label image dataset.

ICML Conference 2019 Conference Paper

A Kernel Theory of Modern Data Augmentation

  • Tri Dao
  • Albert Gu
  • Alexander Ratner
  • Virginia Smith
  • Christopher De Sa
  • Christopher Ré

Data augmentation, a technique in which a training set is expanded with class-preserving transformations, is ubiquitous in modern machine learning pipelines. In this paper, we seek to establish a theoretical framework for understanding data augmentation. We approach this from two directions: First, we provide a general model of augmentation as a Markov process, and show that kernels appear naturally with respect to this model, even when we do not employ kernel classification. Next, we analyze more directly the effect of augmentation on kernel classifiers, showing that data augmentation can be approximated by first-order feature averaging and second-order variance regularization components. These frameworks both serve to illustrate the ways in which data augmentation affects the downstream learning model, and the resulting analyses provide novel connections between prior work in invariant kernels, tangent propagation, and robust optimization. Finally, we provide several proof-of-concept applications showing that our theory can be useful for accelerating machine learning workflows, such as reducing the amount of computation needed to train using augmented data, and predicting the utility of a transformation prior to training.

NeurIPS Conference 2019 Conference Paper

Hyperbolic Graph Convolutional Neural Networks

  • Ines Chami
  • Zhitao Ying
  • Christopher Ré
  • Jure Leskovec

Graph convolutional neural networks (GCNs) embed nodes in a graph into Euclidean space, which has been shown to incur a large distortion when embedding real-world graphs with scale-free or hierarchical structure. Hyperbolic geometry offers an exciting alternative, as it enables embeddings with much smaller distortion. However, extending GCNs to hyperbolic geometry presents several unique challenges because it is not clear how to define neural network operations, such as feature transformation and aggregation, in hyperbolic space. Furthermore, since input features are often Euclidean, it is unclear how to transform the features into hyperbolic embeddings with the right amount of curvature. Here we propose Hyperbolic Graph Convolutional Neural Network (HGCN), the first inductive hyperbolic GCN that leverages both the expressiveness of GCNs and hyperbolic geometry to learn inductive node representations for hierarchical and scale-free graphs. We derive GCN operations in the hyperboloid model of hyperbolic space and map Euclidean input features to embeddings in hyperbolic spaces with different trainable curvature at each layer. Experiments demonstrate that HGCN learns embeddings that preserve hierarchical structure, and leads to improved performance when compared to Euclidean analogs, even with very low dimensional embeddings: compared to state-of-the-art GCNs, HGCN achieves an error reduction of up to 63.1% in ROC AUC for link prediction and of up to 47.5% in F1 score for node classification, also improving state-of-the-art on the Pubmed dataset.

ICML Conference 2019 Conference Paper

Learning Dependency Structures for Weak Supervision Models

  • Paroma Varma
  • Frederic Sala
  • Ann He
  • Alexander Ratner
  • Christopher Ré

Labeling training data is a key bottleneck in the modern machine learning pipeline. Recent weak supervision approaches combine labels from multiple noisy sources by estimating their accuracies without access to ground truth labels; however, estimating the dependencies among these sources is a critical challenge. We focus on a robust PCA-based algorithm for learning these dependency structures, establish improved theoretical recovery rates, and outperform existing methods on various real-world tasks. Under certain conditions, we show that the amount of unlabeled data needed can scale sublinearly or even logarithmically with the number of sources m, improving over previous efforts that ignore the sparsity pattern in the dependency structure and scale linearly in m. We provide an information-theoretic lower bound on the minimum sample complexity of the weak supervision setting. Our method outperforms weak supervision approaches that assume conditionally-independent sources by up to 4.64 F1 points and previous structure learning approaches by up to 4.41 F1 points on real-world relation extraction and image classification tasks.

ICML Conference 2019 Conference Paper

Learning Fast Algorithms for Linear Transforms Using Butterfly Factorizations

  • Tri Dao
  • Albert Gu
  • Matthew Eichhorn
  • Atri Rudra
  • Christopher Ré

Fast linear transforms are ubiquitous in machine learning, including the discrete Fourier transform, discrete cosine transform, and other structured transformations such as convolutions. All of these transforms can be represented by dense matrix-vector multiplication, yet each has a specialized and highly efficient (subquadratic) algorithm. We ask to what extent hand-crafting these algorithms and implementations is necessary, what structural prior they encode, and how much knowledge is required to automatically learn a fast algorithm for a provided structured transform. Motivated by a characterization of fast matrix-vector multiplication as products of sparse matrices, we introduce a parameterization of divide-and-conquer methods that is capable of representing a large class of transforms. This generic formulation can automatically learn an efficient algorithm for many important transforms; for example, it recovers the $O(N \log N)$ Cooley-Tukey FFT algorithm to machine precision, for dimensions $N$ up to $1024$. Furthermore, our method can be incorporated as a lightweight replacement of generic matrices in machine learning pipelines to learn efficient and compressible transformations. On a standard task of compressing a single hidden-layer network, our method exceeds the classification accuracy of unconstrained matrices on CIFAR-10 by 3.9 points (the first time a structured approach has done so) with 4X faster inference speed and 40X fewer parameters.
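As an illustration of the "products of sparse matrices" view mentioned in this abstract, the following is a minimal NumPy sketch that builds the DFT matrix from the textbook radix-2 Cooley-Tukey factorization. It uses fixed, hand-written factors rather than the learned parameterization the paper describes, and the function names (`butterfly_factor`, `even_odd_permutation`, `dft_via_butterflies`) are hypothetical, chosen only for this example.

```python
import numpy as np


def butterfly_factor(n):
    """Return the n x n matrix [[I, D], [I, -D]], which has two nonzeros per row."""
    half = n // 2
    # Twiddle factors exp(-2*pi*i*j/n) for j = 0..n/2-1.
    twiddle = np.exp(-2j * np.pi * np.arange(half) / n)
    eye = np.eye(half)
    d = np.diag(twiddle)
    return np.block([[eye, d], [eye, -d]])


def even_odd_permutation(n):
    """Permutation matrix that places even-indexed entries first, then odd ones."""
    order = np.concatenate([np.arange(0, n, 2), np.arange(1, n, 2)])
    return np.eye(n)[order]


def dft_via_butterflies(n):
    """Build the n x n DFT matrix (n a power of two) as a product of sparse factors."""
    if n == 1:
        return np.ones((1, 1), dtype=complex)
    sub = dft_via_butterflies(n // 2)
    # Two half-size DFTs applied independently to the even and odd parts.
    block_diag = np.kron(np.eye(2), sub)
    return butterfly_factor(n) @ block_diag @ even_odd_permutation(n)


if __name__ == "__main__":
    n = 16
    dense_dft = np.fft.fft(np.eye(n))
    print(np.allclose(dft_via_butterflies(n), dense_dft))
```

Unrolling the recursion gives about log2(N) butterfly stages, each with 2N nonzeros, which is where the $O(N \log N)$ operation count comes from.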

NeurIPS Conference 2019 Conference Paper

Multi-Resolution Weak Supervision for Sequential Data

  • Paroma Varma
  • Frederic Sala
  • Shiori Sagawa
  • Jason Fries
  • Daniel Fu
  • Saelig Khattar
  • Ashwini Ramamoorthy
  • Ke Xiao

Since manually labeling training data is slow and expensive, recent industrial and scientific research efforts have turned to weaker or noisier forms of supervision sources. However, existing weak supervision approaches fail to model multi-resolution sources for sequential data, like video, that can assign labels to individual elements or collections of elements in a sequence. A key challenge in weak supervision is estimating the unknown accuracies and correlations of these sources without using labeled data. Multi-resolution sources exacerbate this challenge due to complex correlations and sample complexity that scales in the length of the sequence. We propose Dugong, the first framework to model multi-resolution weak supervision sources with complex correlations to assign probabilistic labels to training data. Theoretically, we prove that Dugong, under mild conditions, can uniquely recover the unobserved accuracy and correlation parameters and use parameter sharing to improve sample complexity. Our method assigns clinician-validated labels to population-scale biomedical video repositories, helping outperform traditional supervision by 36.8 F1 points and addressing a key use case where machine learning has been severely limited by the lack of expert labeled data. On average, Dugong improves over traditional supervision by 16.0 F1 points and existing weak supervision approaches by 24.2 F1 points across several video and sensor classification tasks.

NeurIPS Conference 2019 Conference Paper

On the Downstream Performance of Compressed Word Embeddings

  • Avner May
  • Jian Zhang
  • Tri Dao
  • Christopher Ré

Compressing word embeddings is important for deploying NLP models in memory-constrained settings. However, understanding what makes compressed embeddings perform well on downstream tasks is challenging---existing measures of compression quality often fail to distinguish between embeddings that perform well and those that do not. We thus propose the eigenspace overlap score as a new measure. We relate the eigenspace overlap score to downstream performance by developing generalization bounds for the compressed embeddings in terms of this score, in the context of linear and logistic regression. We then show that we can lower bound the eigenspace overlap score for a simple uniform quantization compression method, helping to explain the strong empirical performance of this method. Finally, we show that by using the eigenspace overlap score as a selection criterion between embeddings drawn from a representative set we compressed, we can efficiently identify the better performing embedding with up to 2x lower selection error rates than the next best measure of compression quality, and avoid the cost of training a separate model for each task of interest.

NeurIPS Conference 2019 Conference Paper

Slice-based Learning: A Programming Model for Residual Learning in Critical Data Slices

  • Vincent Chen
  • Sen Wu
  • Alexander Ratner
  • Jen Weng
  • Christopher Ré

In real-world machine learning applications, data subsets correspond to especially critical outcomes: vulnerable cyclist detections are safety-critical in an autonomous driving task, and "question" sentences might be important to a dialogue agent's language understanding for product purposes. While machine learning models can achieve quality performance on coarse-grained metrics like F1-score and overall accuracy, they may underperform on these critical subsets---we define these as slices, the key abstraction in our approach. To address slice-level performance, practitioners often train separate "expert" models on slice subsets or use multi-task hard parameter sharing. We propose Slice-based Learning, a new programming model in which the slicing function (SF), a programmer abstraction, is used to specify additional model capacity for each slice. Any model can leverage SFs to learn slice-specific representations, which are combined with an attention mechanism to make slice-aware predictions. We show that our approach improves over baselines in terms of computational complexity and slice-specific performance by up to 19.0 points, and overall performance by up to 4.6 F1 points on applications spanning natural language understanding and computer vision benchmarks as well as production-scale industrial systems.

AAAI Conference 2019 Conference Paper

Training Complex Models with Multi-Task Weak Supervision

  • Alexander Ratner
  • Braden Hancock
  • Jared Dunnmon
  • Frederic Sala
  • Shreyash Pandey
  • Christopher Ré

As machine learning models continue to increase in complexity, collecting large hand-labeled training sets has become one of the biggest roadblocks in practice. Instead, weaker forms of supervision that provide noisier but cheaper labels are often used. However, these weak supervision sources have diverse and unknown accuracies, may output correlated labels, and may label different tasks or apply at different levels of granularity. We propose a framework for integrating and modeling such weak supervision sources by viewing them as labeling different related sub-tasks of a problem, which we refer to as the multi-task weak supervision setting. We show that by solving a matrix completion-style problem, we can recover the accuracies of these multi-task sources given their dependency structure, but without any labeled data, leading to higher-quality supervision for training an end model. Theoretically, we show that the generalization error of models trained with this approach improves with the number of unlabeled data points, and characterize the scaling with respect to the task and dependency structures. On three fine-grained classification problems, we show that our approach leads to average gains of 20.2 points in accuracy over a traditional supervised approach, 6.8 points over a majority vote baseline, and 4.1 points over a previously proposed weak supervision method that models tasks separately.

SODA Conference 2018 Conference Paper

A Two-pronged Progress in Structured Dense Matrix Vector Multiplication

  • Christopher De Sa
  • Albert Gu
  • Rohan Puttagunta
  • Christopher Ré
  • Atri Rudra

Matrix-vector multiplication is one of the most fundamental computing primitives. Given a matrix A and a vector b, it is known that in the worst case Θ(N^2) operations over F are needed to compute Ab. Many types of structured matrices do admit faster multiplication. However, even given a matrix A that is known to have this property, it is hard in general to recover a representation of A exposing the actual fast multiplication algorithm. Additionally, it is not known in general whether the inverses of such structured matrices can be computed or multiplied quickly. A broad question is thus to identify classes of structured dense matrices that can be represented with O(N) parameters, and for which matrix-vector multiplication (and ideally other operations such as solvers) can be performed in a sub-quadratic number of operations. One such class of structured matrices that admit near-linear matrix-vector multiplication are the orthogonal polynomial transforms whose rows correspond to a family of orthogonal polynomials. Other well known classes include the Toeplitz, Hankel, Vandermonde, Cauchy matrices and their extensions (e.g., confluent Cauchy-like matrices) that are all special cases of a low displacement rank property. In this paper, we make progress on two fronts: 1. We introduce the notion of recurrence width of matrices. For matrices A with constant recurrence width, we design algorithms to compute both Ab and A^T b with a near-linear number of operations. This notion of width is finer than all the above classes of structured matrices and thus we can compute near-linear matrix-vector multiplication for all of them using the same core algorithm. Furthermore, we show that it is possible to solve the harder problems of recovering the structured parameterization of a matrix with low recurrence width, and computing matrix-vector product with its inverse in near-linear time. 2. We additionally adapt our algorithm to a matrix-vector multiplication algorithm for a much more general class of matrices with displacement structure: those with low displacement rank with respect to quasiseparable matrices. This result is a novel connection between matrices with displacement structure and those with rank structure, two large but previously separate classes of structured matrices. This class includes Toeplitz-plus-Hankel-like matrices, the Discrete Trigonometric Transforms, and more, and captures all previously known matrices with displacement structure under a unified parameterization and algorithm. Our work unifies, generalizes, and simplifies existing state-of-the-art results in structured matrix-vector multiplication. Finally, we show how applications in areas such as multipoint evaluations of multivariate polynomials can be reduced to problems involving low recurrence width matrices.

NeurIPS Conference 2018 Conference Paper

Learning Compressed Transforms with Low Displacement Rank

  • Anna Thomas
  • Albert Gu
  • Tri Dao
  • Atri Rudra
  • Christopher Ré

The low displacement rank (LDR) framework for structured matrices represents a matrix through two displacement operators and a low-rank residual. Existing use of LDR matrices in deep learning has applied fixed displacement operators encoding forms of shift invariance akin to convolutions. We introduce a rich class of LDR matrices with more general displacement operators, and explicitly learn over both the operators and the low-rank component. This class generalizes several previous constructions while preserving compression and efficient computation. We prove bounds on the VC dimension of multi-layer neural networks with structured weight matrices and show empirically that our compact parameterization can reduce the sample complexity of learning. When replacing weight layers in fully-connected, convolutional, and recurrent neural networks for image classification and language modeling tasks, our new classes exceed the accuracy of existing compression approaches, and on some tasks even outperform general unstructured layers while using more than 20x fewer parameters.

ICML Conference 2018 Conference Paper

Representation Tradeoffs for Hyperbolic Embeddings

  • Frederic Sala
  • Christopher De Sa
  • Albert Gu
  • Christopher Ré

Hyperbolic embeddings offer excellent quality with few dimensions when embedding hierarchical data structures. We give a combinatorial construction that embeds trees into hyperbolic space with arbitrarily low distortion without optimization. On WordNet, this algorithm obtains a mean-average-precision of 0.989 with only two dimensions, outperforming existing work by 0.11 points. We provide bounds characterizing the precision-dimensionality tradeoff inherent in any hyperbolic embedding. To embed general metric spaces, we propose a hyperbolic generalization of multidimensional scaling (h-MDS). We show how to perform exact recovery of hyperbolic points from distances, provide a perturbation analysis, and give a recovery result that enables us to reduce dimensionality. Finally, we extract lessons from the algorithms and theory above to design a scalable PyTorch-based implementation that can handle incomplete information.
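For readers unfamiliar with hyperbolic distances, here is a minimal sketch of the standard distance function in the Poincaré ball model with curvature -1. The choice of model and the function name `poincare_distance` are assumptions made for illustration; the abstract itself does not specify them.

```python
import math


def poincare_distance(x, y):
    """Distance between two points inside the open unit ball (Poincare model)."""
    sq_norm_x = sum(a * a for a in x)
    sq_norm_y = sum(b * b for b in y)
    if sq_norm_x >= 1.0 or sq_norm_y >= 1.0:
        raise ValueError("points must lie strictly inside the unit ball")
    sq_diff = sum((a - b) ** 2 for a, b in zip(x, y))
    # d(x, y) = arcosh(1 + 2 * |x - y|^2 / ((1 - |x|^2) * (1 - |y|^2)))
    return math.acosh(1.0 + 2.0 * sq_diff / ((1.0 - sq_norm_x) * (1.0 - sq_norm_y)))


if __name__ == "__main__":
    origin = (0.0, 0.0)
    # The same Euclidean step costs far more hyperbolic distance near the boundary.
    print(poincare_distance(origin, (0.1, 0.0)))
    print(poincare_distance((0.85, 0.0), (0.95, 0.0)))
```

Because distances blow up near the boundary, there is room to place exponentially many tree nodes far apart from each other, which is the intuition behind low-distortion tree embeddings in few dimensions.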

JMLR Journal 2018 Journal Article

Weighted SGD for $\ell_p$ Regression with Randomized Preconditioning

  • Jiyan Yang
  • Yin-Lam Chow
  • Christopher Ré
  • Michael W. Mahoney

In recent years, stochastic gradient descent (SGD) methods and randomized linear algebra (RLA) algorithms have been applied to many large-scale problems in machine learning and data analysis. SGD methods are easy to implement and applicable to a wide range of convex optimization problems. In contrast, RLA algorithms provide much stronger performance guarantees but are applicable to a narrower class of problems. We aim to bridge the gap between these two methods in solving constrained overdetermined linear regression problems---e.g., $\ell_2$ and $\ell_1$ regression problems. We propose a hybrid algorithm named PWSGD that uses RLA techniques for preconditioning and constructing an importance sampling distribution, and then performs an SGD-like iterative process with weighted sampling on the preconditioned system. By rewriting a deterministic $\ell_p$ regression problem as a stochastic optimization problem, we connect PWSGD to several existing $\ell_p$ solvers including RLA methods with algorithmic leveraging (RLA for short). We prove that PWSGD inherits faster convergence rates that only depend on the lower dimension of the linear system, while maintaining low computation complexity. Such SGD convergence rates are superior to other related SGD algorithms such as the weighted randomized Kaczmarz algorithm. Particularly, when solving $\ell_1$ regression with size $n$ by $d$, PWSGD returns an approximate solution with $\epsilon$ relative error in the objective value in $\mathcal{O}(\log n \cdot \text{nnz}(A) + \text{poly}(d)/\epsilon^2)$ time. This complexity is uniformly better than that of RLA methods in terms of both $\epsilon$ and $d$ when the problem is unconstrained. In the presence of constraints, PWSGD only has to solve a sequence of much simpler and smaller optimization problems over the same constraints. In general this is more efficient than solving the constrained subproblem required in RLA. For $\ell_2$ regression, PWSGD returns an approximate solution with $\epsilon$ relative error in the objective value and the solution vector measured in prediction norm in $\mathcal{O}(\log n \cdot \text{nnz}(A) + \text{poly}(d) \log(1/\epsilon) /\epsilon)$ time. We show that for unconstrained $\ell_2$ regression, this complexity is comparable to that of RLA and is asymptotically better over several state-of-the-art solvers in the regime where the desired accuracy $\epsilon$, high dimension $n$ and low dimension $d$ satisfy $d\geq 1/\epsilon$ and $n \geq d^2/\epsilon$. We also provide lower bounds on the coreset complexity for more general regression problems, indicating that new ideas will still be needed to extend similar RLA preconditioning ideas to weighted SGD algorithms for more general regression problems. Finally, the effectiveness of such algorithms is illustrated numerically on both synthetic and real datasets, and the results are consistent with our theoretical findings and demonstrate that PWSGD converges to a medium-precision solution, e.g., $\epsilon=10^{-3}$, more quickly.

IJCAI Conference 2017 Conference Paper

Ensuring Rapid Mixing and Low Bias for Asynchronous Gibbs Sampling

  • Christopher De Sa
  • Kunle Olukotun
  • Christopher Ré

Gibbs sampling is a Markov chain Monte Carlo technique commonly used for estimating marginal distributions. To speed up Gibbs sampling, there has recently been interest in parallelizing it by executing asynchronously. While empirical results suggest that many models can be efficiently sampled asynchronously, traditional Markov chain analysis does not apply to the asynchronous case, and thus asynchronous Gibbs sampling is poorly understood. In this paper, we derive a better understanding of the two main challenges of asynchronous Gibbs: bias and mixing time. We show experimentally that our theoretical results match practical outcomes.

NeurIPS Conference 2017 Conference Paper

Gaussian Quadrature for Kernel Features

  • Tri Dao
  • Christopher De Sa
  • Christopher Ré

Kernel methods have recently attracted resurgent interest, showing performance competitive with deep neural networks in tasks such as speech recognition. The random Fourier features map is a technique commonly used to scale up kernel machines, but employing the randomized feature map means that $O(\epsilon^{-2})$ samples are required to achieve an approximation error of at most $\epsilon$. We investigate some alternative schemes for constructing feature maps that are deterministic, rather than random, by approximating the kernel in the frequency domain using Gaussian quadrature. We show that deterministic feature maps can be constructed, for any $\gamma > 0$, to achieve error $\epsilon$ with $O(e^{e^\gamma} + \epsilon^{-1/\gamma})$ samples as $\epsilon$ goes to 0. Our method works particularly well with sparse ANOVA kernels, which are inspired by the convolutional layer of CNNs. We validate our methods on datasets in different domains, such as MNIST and TIMIT, showing that deterministic features are faster to generate and achieve accuracy comparable to the state-of-the-art kernel methods based on random Fourier features.

NeurIPS Conference 2017 Conference Paper

Inferring Generative Model Structure with Static Analysis

  • Paroma Varma
  • Bryan He
  • Payal Bajaj
  • Nishith Khandwala
  • Imon Banerjee
  • Daniel Rubin
  • Christopher Ré

Obtaining enough labeled data to robustly train complex discriminative models is a major bottleneck in the machine learning pipeline. A popular solution is combining multiple sources of weak supervision using generative models. The structure of these models affects the quality of the training labels, but is difficult to learn without any ground truth labels. We instead rely on weak supervision sources having some structure by virtue of being encoded programmatically. We present Coral, a paradigm that infers generative model structure by statically analyzing the code for these heuristics, thus significantly reducing the amount of data required to learn structure. We prove that Coral's sample complexity scales quasilinearly with the number of heuristics and number of relations identified, improving over the standard sample complexity, which is exponential in n for learning n-th degree relations. Empirically, Coral matches or outperforms traditional structure learning approaches by up to 3.81 F1 points. Using Coral to model dependencies instead of assuming independence results in better performance than a fully supervised model by 3.07 accuracy points when heuristics are used to label radiology data without ground truth labels.

ICML Conference 2017 Conference Paper

Learning the Structure of Generative Models without Labeled Data

  • Stephen H. Bach
  • Bryan Dawei He
  • Alexander Ratner
  • Christopher Ré

Curating labeled training data has become the primary bottleneck in machine learning. Recent frameworks address this bottleneck with generative models to synthesize labels at scale from weak supervision sources. The generative model’s dependency structure directly affects the quality of the estimated labels, but selecting a structure automatically without any labeled data is a distinct challenge. We propose a structure estimation method that maximizes the l1-regularized marginal pseudolikelihood of the observed data. Our analysis shows that the amount of unlabeled data required to identify the true structure scales sublinearly in the number of possible dependencies for a broad class of models. Simulations show that our method is 100x faster than a maximum likelihood approach and selects 1/4 as many extraneous dependencies. We also show that our method provides an average of 1.5 F1 points of improvement over existing, user-developed information extraction applications on real-world data such as PubMed journal abstracts.

NeurIPS Conference 2017 Conference Paper

Learning to Compose Domain-Specific Transformations for Data Augmentation

  • Alexander Ratner
  • Henry Ehrenberg
  • Zeshan Hussain
  • Jared Dunnmon
  • Christopher Ré

Data augmentation is a ubiquitous technique for increasing the size of labeled training sets by leveraging task-specific data transformations that preserve class labels. While it is often easy for domain experts to specify individual transformations, constructing and tuning the more sophisticated compositions typically needed to achieve state-of-the-art results is a time-consuming manual task in practice. We propose a method for automating this process by learning a generative sequence model over user-specified transformation functions using a generative adversarial approach. Our method can make use of arbitrary, non-deterministic transformation functions, is robust to misspecified user input, and is trained on unlabeled data. The learned transformation model can then be used to perform data augmentation for any end discriminative model. In our experiments, we show the efficacy of our approach on both image and text datasets, achieving improvements of 4.0 accuracy points on CIFAR-10, 1.4 F1 points on the ACE relation extraction task, and 3.4 accuracy points when using domain-specific transformation operations on a medical imaging dataset as compared to standard heuristic augmentation approaches.

NeurIPS Conference 2016 Conference Paper

Cyclades: Conflict-free Asynchronous Machine Learning

  • Xinghao Pan
  • Maximilian Lam
  • Stephen Tu
  • Dimitris Papailiopoulos
  • Ce Zhang
  • Michael Jordan
  • Kannan Ramchandran
  • Christopher Ré

We present Cyclades, a general framework for parallelizing stochastic optimization algorithms in a shared memory setting. Cyclades is asynchronous during model updates, and requires no memory locking mechanisms, similar to Hogwild!-type algorithms. Unlike Hogwild!, Cyclades introduces no conflicts during parallel execution, and offers a black-box analysis for provable speedups across a large family of algorithms. Due to its inherent cache locality and conflict-free nature, our multi-core implementation of Cyclades consistently outperforms Hogwild!-type algorithms on sufficiently sparse datasets, leading to up to 40% speedup gains compared to Hogwild!, and up to 5× gains over asynchronous implementations of variance reduction algorithms.
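The abstract does not spell out how conflicts are avoided. One plausible way to obtain conflict-free parallel work, sketched below as an assumption rather than as the paper's algorithm, is to group a batch of sparse updates into connected components of their conflict graph so that no two groups touch the same model coordinate. The function name `conflict_free_groups` and the input format are hypothetical.

```python
def conflict_free_groups(updates):
    """Group update indices so that different groups share no coordinates.

    `updates` is a list of sets; updates[i] holds the model coordinates that
    update i reads or writes. Uses union-find over the update indices.
    """
    parent = list(range(len(updates)))

    def find(i):
        # Path-halving find.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    first_user = {}  # coordinate -> first update seen touching it
    for idx, coords in enumerate(updates):
        for c in coords:
            if c in first_user:
                # Two updates share coordinate c, so they must be in one group.
                parent[find(idx)] = find(first_user[c])
            else:
                first_user[c] = idx

    groups = {}
    for idx in range(len(updates)):
        groups.setdefault(find(idx), []).append(idx)
    return sorted(groups.values())


if __name__ == "__main__":
    batch = [{0, 1}, {1, 2}, {3}, {4, 5}, {5}]
    # Each printed group could be handed to a different worker without locks.
    print(conflict_free_groups(batch))
```

Within a group the updates run serially on one worker; across groups no coordinate is shared, so no locking is needed. On sparse data the groups tend to be small, which matches the abstract's emphasis on sufficiently sparse datasets.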

NeurIPS Conference 2016 Conference Paper

Data Programming: Creating Large Training Sets, Quickly

  • Alexander Ratner
  • Christopher De Sa
  • Sen Wu
  • Daniel Selsam
  • Christopher Ré

Large labeled training sets are the critical building blocks of supervised learning methods and are key enablers of deep learning techniques. For some applications, creating labeled training sets is the most time-consuming and expensive part of applying machine learning. We therefore propose a paradigm for the programmatic creation of training sets called data programming in which users provide a set of labeling functions, which are programs that heuristically label subsets of the data, but that are noisy and may conflict. By viewing these labeling functions as implicitly describing a generative model for this noise, we show that we can recover the parameters of this model to "denoise" the generated training set, and establish theoretically that we can recover the parameters of these generative models in a handful of settings. We then show how to modify a discriminative loss function to make it noise-aware, and demonstrate our method over a range of discriminative models including logistic regression and LSTMs. Experimentally, on the 2014 TAC-KBP Slot Filling challenge, we show that data programming would have led to a new winning score, and also show that applying data programming to an LSTM model leads to a TAC-KBP score almost 6 F1 points over a state-of-the-art LSTM baseline (and into second place in the competition). Additionally, in initial user studies we observed that data programming may be an easier way for non-experts to create machine learning models when training data is limited or unavailable.
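The idea of labeling functions that are noisy, may abstain, and may conflict can be illustrated with a minimal sketch. Everything below is an assumption made for illustration: the toy spam task, the three rules, and the function names are hypothetical, and the votes are combined with a plain unweighted majority vote rather than the generative model the abstract describes.

```python
# Hypothetical labeling functions for a toy spam task. The rules and names
# are illustrative assumptions, not taken from the paper.
ABSTAIN, NEG, POS = 0, -1, 1

def lf_contains_free(text):
    # Votes "spam" when the word "free" appears anywhere in the text.
    return POS if "free" in text.lower() else ABSTAIN

def lf_has_greeting(text):
    # Votes "not spam" when the text opens with a personal greeting.
    return NEG if text.lower().startswith(("hi", "hello")) else ABSTAIN

def lf_many_exclamations(text):
    # Votes "spam" when the text has three or more exclamation marks.
    return POS if text.count("!") >= 3 else ABSTAIN

LABELING_FUNCTIONS = [lf_contains_free, lf_has_greeting, lf_many_exclamations]

def apply_lfs(texts):
    """Return one row of votes per text, one column per labeling function."""
    return [[lf(t) for lf in LABELING_FUNCTIONS] for t in texts]

def majority_vote(votes):
    """Unweighted vote: sign of the sum, ABSTAIN on ties or when all abstain."""
    total = sum(votes)
    if total > 0:
        return POS
    if total < 0:
        return NEG
    return ABSTAIN

texts = [
    "Hello, are we still meeting tomorrow?",
    "FREE prize!!! Click now",
    "Hi, free lunch in the kitchen",
    "Quarterly report attached",
]
label_matrix = apply_lfs(texts)
labels = [majority_vote(row) for row in label_matrix]
print(labels)
```

The third text shows a conflict (one function votes positive, another negative) and the fourth shows full abstention. A learned model of labeling-function accuracies, as proposed in the abstract, would replace the unweighted vote used here.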

ICML Conference 2016 Conference Paper

Ensuring Rapid Mixing and Low Bias for Asynchronous Gibbs Sampling

  • Christopher De Sa
  • Christopher Ré
  • Kunle Olukotun

Gibbs sampling is a Markov chain Monte Carlo technique commonly used for estimating marginal distributions. To speed up Gibbs sampling, there has recently been interest in parallelizing it by executing asynchronously. While empirical results suggest that many models can be efficiently sampled asynchronously, traditional Markov chain analysis does not apply to the asynchronous case, and thus asynchronous Gibbs sampling is poorly understood. In this paper, we derive a better understanding of the two main challenges of asynchronous Gibbs: bias and mixing time. We show experimentally that our theoretical results match practical outcomes.

NeurIPS Conference 2016 Conference Paper

Scan Order in Gibbs Sampling: Models in Which it Matters and Bounds on How Much

  • Bryan He
  • Christopher De Sa
  • Ioannis Mitliagkas
  • Christopher Ré

Gibbs sampling is a Markov Chain Monte Carlo sampling technique that iteratively samples variables from their conditional distributions. There are two common scan orders for the variables: random scan and systematic scan. Due to the benefits of locality in hardware, systematic scan is commonly used, even though most statistical guarantees are only for random scan. While it has been conjectured that the mixing times of random scan and systematic scan do not differ by more than a logarithmic factor, we show by counterexample that this is not the case, and we prove that the mixing times do not differ by more than a polynomial factor under mild conditions. To prove these relative bounds, we introduce a method of augmenting the state space to study systematic scan using conductance.
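The difference between the two scan orders can be made concrete with a small sketch. The model below (a short Ising chain with a weak coupling), the function names, and all parameter values are assumptions chosen for illustration. This is not one of the paper's counterexamples, and on a model this easy both scan orders mix quickly.

```python
import math
import random

def conditional_prob_one(state, i, coupling):
    """P(x_i = 1 | neighbors) for a chain Ising model with values in {0, 1}.

    A value of 1 is treated as spin +1 and a value of 0 as spin -1.
    """
    field = 0.0
    for j in (i - 1, i + 1):
        if 0 <= j < len(state):
            field += coupling * (1 if state[j] == 1 else -1)
    return 1.0 / (1.0 + math.exp(-2.0 * field))

def gibbs(n_vars, n_sweeps, coupling, scan, rng):
    """Run Gibbs sampling and return the fraction of sweeps with x_0 = 1.

    scan="systematic" updates variables in the fixed order 0, 1, ..., n-1.
    scan="random" picks a uniformly random variable at every step.
    Both perform n_vars single-variable updates per sweep.
    """
    state = [rng.randint(0, 1) for _ in range(n_vars)]
    ones_at_zero = 0
    for _ in range(n_sweeps):
        for step in range(n_vars):
            i = step if scan == "systematic" else rng.randrange(n_vars)
            p = conditional_prob_one(state, i, coupling)
            state[i] = 1 if rng.random() < p else 0
        ones_at_zero += state[0]
    return ones_at_zero / n_sweeps

rng = random.Random(0)
systematic_estimate = gibbs(4, 5000, 0.3, "systematic", rng)
random_estimate = gibbs(4, 5000, 0.3, "random", rng)
print(systematic_estimate, random_estimate)
```

By symmetry the true marginal of each variable is 0.5, so both estimates should land near that value. The only difference between the two runs is how the index `i` is chosen inside a sweep.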

NeurIPS Conference 2016 Conference Paper

Sub-sampled Newton Methods with Non-uniform Sampling

  • Peng Xu
  • Jiyan Yang
  • Fred Roosta
  • Christopher Ré
  • Michael Mahoney

We consider the problem of finding the minimizer of a convex function $F: \mathbb{R}^d \rightarrow \mathbb{R}$ of the form $F(w) := \sum_{i=1}^n f_i(w) + R(w)$ where a low-rank factorization of $\nabla^2 f_i(w)$ is readily available. We consider the regime where $n \gg d$. We propose randomized Newton-type algorithms that exploit non-uniform sub-sampling of $\{\nabla^2 f_i(w)\}_{i=1}^{n}$, as well as inexact updates, as means to reduce the computational complexity, and are applicable to a wide range of problems in machine learning. Two non-uniform sampling distributions based on block norm squares and block partial leverage scores are considered. Under certain assumptions, we show that our algorithms inherit a linear-quadratic convergence rate in $w$ and achieve a lower computational complexity compared to similar existing methods. In addition, we show that our algorithms exhibit more robustness and better dependence on problem specific quantities, such as the condition number. We numerically demonstrate the advantages of our algorithms on several real datasets.

JMLR Journal 2015 Journal Article

An Asynchronous Parallel Stochastic Coordinate Descent Algorithm

  • Ji Liu
  • Stephen J. Wright
  • Christopher Ré
  • Victor Bittorf
  • Srikrishna Sridhar

We describe an asynchronous parallel stochastic coordinate descent algorithm for minimizing smooth unconstrained or separably constrained functions. The method achieves a linear convergence rate on functions that satisfy an essential strong convexity property and a sublinear rate ($1/K$) on general convex functions. Near-linear speedup on a multicore system can be expected if the number of processors is $O(n^{1/2})$ in unconstrained optimization and $O(n^{1/4})$ in the separable-constrained case, where $n$ is the number of variables. We describe results from implementation on 40-core processors.
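For readers unfamiliar with the basic update, here is a minimal sequential sketch of stochastic coordinate descent on a small strongly convex quadratic $f(x) = \frac{1}{2} x^\top A x - b^\top x$. It is deliberately not asynchronous or parallel, so it shows only the per-coordinate step that such a method would run concurrently. The matrix, the vector, the step size $1/L_{\max}$ (with $L_{\max}$ the largest diagonal entry of $A$), and the function names are assumptions for illustration.

```python
import random

def coordinate_descent(A, b, n_iters, rng):
    """Minimize 0.5 * x^T A x - b^T x by updating one random coordinate per step.

    A is assumed symmetric positive definite. The step size 1 / l_max uses the
    largest diagonal entry of A as a bound on every coordinate-wise curvature.
    """
    n = len(b)
    x = [0.0] * n
    l_max = max(A[i][i] for i in range(n))
    for _ in range(n_iters):
        i = rng.randrange(n)
        # Partial derivative of the objective with respect to x_i.
        grad_i = sum(A[i][j] * x[j] for j in range(n)) - b[i]
        x[i] -= grad_i / l_max
    return x

def residual_norm(A, b, x):
    """Largest absolute entry of A x - b, which is zero at the minimizer."""
    n = len(b)
    return max(abs(sum(A[i][j] * x[j] for j in range(n)) - b[i]) for i in range(n))

A = [[4.0, 1.0, 0.0],
     [1.0, 3.0, 1.0],
     [0.0, 1.0, 2.0]]
b = [1.0, 2.0, 3.0]

x = coordinate_descent(A, b, 2000, random.Random(0))
print(residual_norm(A, b, x))
```

In an asynchronous variant, several workers would run the loop body at once on a shared `x`, reading possibly stale values of the other coordinates. The sketch above does not attempt to model that.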

NeurIPS Conference 2015 Conference Paper

Asynchronous stochastic convex optimization: the noise is in the noise and SGD don't care

  • Sorathan Chaturapruek
  • John Duchi
  • Christopher Ré

We show that asymptotically, completely asynchronous stochastic gradient procedures achieve optimal (even to constant factors) convergence rates for the solution of convex optimization problems under nearly the same conditions required for asymptotic optimality of standard stochastic gradient procedures. Roughly, the noise inherent to the stochastic approximation scheme dominates any noise from asynchrony. We also give empirical evidence demonstrating the strong performance of asynchronous, parallel stochastic optimization schemes, demonstrating that the robustness inherent to stochastic approximation problems allows substantially faster parallel and asynchronous solution methods. In short, we show that for many stochastic approximation problems, as Freddie Mercury sings in Queen's Bohemian Rhapsody, "Nothing really matters."

ICML Conference 2015 Conference Paper

Global Convergence of Stochastic Gradient Descent for Some Non-convex Matrix Problems

  • Christopher De Sa
  • Christopher Ré
  • Kunle Olukotun

Stochastic gradient descent (SGD) on a low-rank factorization is commonly employed to speed up matrix problems including matrix completion, subspace tracking, and SDP relaxation. In this paper, we exhibit a step size scheme for SGD on a low-rank least-squares problem, and we prove that, under broad sampling conditions, our method converges globally from a random starting point within $O(\epsilon^{-1} n \log n)$ steps with constant probability for constant-rank problems. Our modification of SGD relates it to stochastic power iteration. We also show some experiments to illustrate the runtime and convergence of the algorithm.

NeurIPS Conference 2015 Conference Paper

Rapidly Mixing Gibbs Sampling for a Class of Factor Graphs Using Hierarchy Width

  • Christopher De Sa
  • Ce Zhang
  • Kunle Olukotun
  • Christopher Ré

Gibbs sampling on factor graphs is a widely used inference technique, which often produces good empirical results. Theoretical guarantees for its performance are weak: even for tree structured graphs, the mixing time of Gibbs may be exponential in the number of variables. To help understand the behavior of Gibbs sampling, we introduce a new (hyper)graph property, called hierarchy width. We show that under suitable conditions on the weights, bounded hierarchy width ensures polynomial mixing time. Our study of hierarchy width is in part motivated by a class of factor graph templates, hierarchical templates, which have bounded hierarchy width—regardless of the data used to instantiate them. We demonstrate a rich application from natural language processing in which Gibbs sampling provably mixes rapidly and achieves accuracy that exceeds human volunteers.

NeurIPS Conference 2015 Conference Paper

Taming the Wild: A Unified Analysis of Hogwild-Style Algorithms

  • Christopher De Sa
  • Ce Zhang
  • Kunle Olukotun
  • Christopher Ré

Stochastic gradient descent (SGD) is a ubiquitous algorithm for a variety of machine learning problems. Researchers and industry have developed several techniques to optimize SGD's runtime performance, including asynchronous execution and reduced precision. Our main result is a martingale-based analysis that enables us to capture the rich noise models that may arise from such techniques. Specifically, we use our new analysis in three ways: (1) we derive convergence rates for the convex case (Hogwild) with relaxed assumptions on the sparsity of the problem; (2) we analyze asynchronous SGD algorithms for non-convex matrix problems including matrix completion; and (3) we design and analyze an asynchronous SGD algorithm, called Buckwild, that uses lower-precision arithmetic. We show experimentally that our algorithms run efficiently for a variety of problems on modern hardware.

ICML Conference 2014 Conference Paper

An Asynchronous Parallel Stochastic Coordinate Descent Algorithm

  • Ji Liu
  • Stephen J. Wright
  • Christopher Ré
  • Victor Bittorf
  • Srikrishna Sridhar

We describe an asynchronous parallel stochastic coordinate descent algorithm for minimizing smooth unconstrained or separably constrained functions. The method achieves a linear convergence rate on functions that satisfy an essential strong convexity property and a sublinear rate ($1/K$) on general convex functions. Near-linear speedup on a multicore system can be expected if the number of processors is $O(n^{1/2})$ in unconstrained optimization and $O(n^{1/4})$ in the separable-constrained case, where $n$ is the number of variables. We describe results from implementation on 40-core processors.

NeurIPS Conference 2014 Conference Paper

Parallel Feature Selection Inspired by Group Testing

  • Yingbo Zhou
  • Utkarsh Porwal
  • Ce Zhang
  • Hung Ngo
  • XuanLong Nguyen
  • Christopher Ré
  • Venu Govindaraju

This paper presents a parallel feature selection method for classification that scales up to very high dimensions and large data sizes. Our original method is inspired by group testing theory, under which the feature selection procedure consists of a collection of randomized tests to be performed in parallel. Each test corresponds to a subset of features, for which a scoring function may be applied to measure the relevance of the features in a classification task. We develop a general theory providing sufficient conditions under which true features are guaranteed to be correctly identified. Superior performance of our method is demonstrated on a challenging relation extraction task from a very large data set that has both redundant features and a sample size on the order of millions. We present comprehensive comparisons with state-of-the-art feature selection methods on a range of data sets, for which our method exhibits competitive performance in terms of running time and accuracy. Moreover, it also yields substantial speedup when used as a pre-processing step for most other existing methods.