Inference-Friendly Models with MixAttention | Databricks Blog

Transformer models, the backbone of modern language AI, rely on the attention mechanism to process context when generating output. During inference, the attention mechanism works by computing the key and value vectors for each token seen so far, and using these vectors to update the internal representation of the next token that will be output. Because the same key and value vectors of the past tokens get reused every time the model outputs a new token, it is standard practice to cache them in a data structure called the Key-Value (KV) cache. Since the KV cache grows proportionally to the number of tokens seen so far, KV cache size is a major factor in determining both the maximum context length (i.e., the maximum number of tokens) and the maximum number of concurrent requests that can be supported for inference on modern language models. Particularly for long inputs, LLM inference can also be dominated by the I/O cost of moving the KV cache from High Bandwidth Memory (HBM) to the GPU’s shared memory. Therefore, reducing the KV cache size has the potential to be a powerful way to speed up and reduce the cost of inference on modern language models. In this post, we explore ideas recently proposed by Character.AI for reducing KV cache size by replacing most of the layers in the network with sliding window attention (a form of local attention that only uses the key and value vectors of a small number of the most recent tokens) and sharing the KV cache among layers. We call this architecture MixAttention; our experiments with different variants of this architecture have demonstrated that it maintains both short and long context model quality while improving inference speed and memory footprint.
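To make the scale concrete, a rough, implementation-independent estimate of the KV cache footprint for a single sequence is: 2 (one each for keys and values) × number of layers × number of KV heads × head dimension × number of tokens × bytes per element. For a hypothetical 24-layer model with 3 KV heads of dimension 128 each, stored in bfloat16, that works out to roughly 36 KB per token, or about 1.2 GB for a single 32K-token sequence, and it scales linearly with both context length and the number of concurrent requests being served.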

 

Figure 2 bar charts
Fig. 1: Speed and accuracy of MixAttention model variants (model variants shown in Fig. 2). Top: MixAttention models are faster and use less memory during inference at 32K context length. Bottom: MixAttention models maintain quality – they match the standard attention model on most evals. The models are all Mixture of Experts with 2B active and 5B total parameters.

We found that KV cache sharing between layers and adding sliding window layers can speed up inference and reduce inference memory usage while maintaining model quality, though some eval metrics show some degradation. In addition, our ablation experiments showed the following:

 

  • Having a few standard attention layers is crucial for the model’s long context abilities. In particular, having the standard KV cache computed in the deeper layers is more important for long context abilities than the standard KV cache of the first few layers.
  • The KV cache of standard attention layers can be shared between non-consecutive layers without any observed degradation in long context abilities.
  • Increasing the KV cache sharing between sliding window layers too much also hurts long context abilities.

We have provided a guide to configuring and training MixAttention models using LLM Foundry in the appendix of this blog post.

 

Image 2 Mix Attention Blog
Figure 2: (Left) A standard transformer model where all layers are standard attention layers. (Middle) Inference-friendly models with MixAttention. Green bars represent sliding window attention and the lines connecting bars represent KV cache sharing. (Right) A model where all layers are sliding window attention.

MixAttention Architecture Overview

Standard transformer models use global attention in each layer. To create inference-friendly model architectures, we used a combination of sliding window attention layers, standard attention layers, and KV cache reuse layers. Below is a brief discussion of each component:

 

  • Sliding Window Attention Layers: In Sliding Window Attention (or Local Attention) with window size s, the query only attends to the last s keys instead of all the keys preceding it. This means that during inference, the KV cache only needs to store the KV tensors for the past s tokens instead of the KV tensors for all the past tokens. In our experiments, we set a window size of s=1024 tokens.
  • Standard Attention Layers: We found that even though Standard Attention Layers lead to bigger KV caches and slower attention computation compared to Sliding Window Attention, having a few Standard Attention Layers is crucial for the model’s long context abilities.
  • KV cache reuse: This refers to a layer in the transformer network that reuses the KV cache computed by a previous layer. Hence, if every l layers share KV tensors, then the size of the KV cache is reduced by a factor of 1/l. (A minimal configuration sketch combining these three components follows this list; the appendix describes the full syntax.)
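As a preview of how these three components are expressed in an LLM Foundry config, here is a minimal illustrative three-layer sketch using the keys covered in the appendix; it is not one of the configurations we trained, and the exact layout may differ slightly:

```yaml
block_overrides:
  order:
  - name: default               # standard attention layer with its own KV cache
  - name: sliding_window_layer  # local attention over the last 1024 tokens
  - name: kv_reuse_layer        # reuses the KV cache of the layer before it
  overrides:
    sliding_window_layer:
      attn_config:
        sliding_window_size: 1024
    kv_reuse_layer:
      attn_config:
        reuse_kv_layer_idx: -1
```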

 

We experimented with different combinations of the components above to ablate the effects of each of them. (More combinations are described in the appendices.) We found that not only does each of the above components play an important role in long context abilities, inference speed, and memory consumption, but their relative positions and counts also have significant effects on those metrics.

 

The models we trained are 24-layer Mixture of Experts (MoE) models with 1.64B active and 5.21B total parameters. We used RoPE positional embeddings and increased the RoPE base theta as we increased the context length during training. We used Grouped Query Attention with 12 attention heads and 3 KV heads.

 

Training

We used LLM Foundry to train MixAttention models. Similar to prior work on training long context models, we followed a multi-stage training procedure to impart long context abilities to the models.

 

  1. We pretrained the models with a RoPE theta of 0.5M on 101B tokens, where each sequence was truncated to 4k token length.
  2. To increase the context length, we then trained the model on 9B tokens from a mix of natural language and code data, where the sequences were truncated to 32k tokens. We increased the RoPE theta to 8M for this stage. When training at 32k context length, we trained only the attention weights and froze the rest of the network. We found that this delivered better results than full network training.
  3. Finally, we trained the model on a 32k-length, synthetic, long-context QA dataset.
    • To create the dataset, we took natural language documents and chunked them into 1k-token chunks. Each chunk was then fed to a pretrained instruction model, and the model was prompted to generate a question-answer pair based on the chunk. Then, we concatenated chunks from different documents together to serve as the “long context.” At the end of this long context, the question-answer pairs for each of the chunks were added. The loss gradients were computed only on the answer parts of these sequences.
    • This phase of training was carried out on 500M tokens (this number includes the tokens from the context, questions, and answers). The RoPE theta was kept at 8M for this stage.

Evaluation

The models were evaluated on the Mosaic Evaluation Gauntlet to measure model quality across various metrics including reading comprehension, commonsense reasoning, world knowledge, symbolic problem solving, and language understanding. To evaluate the models’ long context abilities, we used RULER at a context length of 32000 tokens. RULER is a composite benchmark consisting of 13 individual evals of the following types:

 

  • Needle-in-a-haystack (NIAH): These types of evals hide single or multiple keys and values in a long text, and the model is evaluated on its ability to retrieve the correct value(s) from the long context for a given key(s).
  • Variable Tracking (VT): This eval provides the model with a long context containing variable assignment statements, and the model is tasked with determining which variables have a specific value by the end of all the variable assignments.
  • Common and Frequent Word Extraction (CWE and FWE): These tasks ask the model to extract the most common or frequent words from the text.
  • Question Answering (QA): Given a long context, the model is asked a question about information from somewhere in the context and is evaluated on whether it can correctly answer that question.

 

We used SGLang to deploy our models on 1 NVIDIA H100 GPU to run RULER and measure inference speed and memory consumption.

Results

Position and Count of Standard Attention KV Caches

To measure the effect of the position and count of the standard attention KV caches, we tried four variants. All the configurations are variants of the configuration proposed in Character.AI’s blog post.

MixAttention Image 3
Figure 3: KV cache positions and counts. To measure the effect of the position and count of the standard attention KV caches on MixAttention’s long context abilities, we trained and evaluated the four models shown above.
  1. MA: This variant has a single standard attention KV cache, which is the KV cache of the first layer. All the other standard attention layers share this KV cache.
  2. MA-EndSlide: This variant is the same as MA, but the last layer is a sliding window attention layer. This was done to measure how much having standard attention in the last layer affects long-context abilities.
  3. MA-Offset: This variant is similar to MA, but the first standard attention layer is offset to a later layer to allow the model to process the local context for a few layers before the standard attention layer is used to look at longer contexts.
  4. MA-Pairs: This variant computes two standard attention KV caches (at the first and thirteenth layers), each of which is then shared with one other standard attention layer.

We compared these models to a transformer model with Standard Attention in all layers and a transformer model with Sliding Window Attention in all layers.

MixAttention image 4

MixAttention_image5
Fig. 4 and 5: Effect of Standard Attention Layers. (Top) Loss curves of the models when fine-tuning on the long context QA dataset. (Bottom) RULER evals for the models. MA and MA-EndSlide perform poorly on long context tasks while MA-Offset and MA-Pairs perform well. This indicates that having a standard attention KV cache which is computed in later layers is important for long context abilities. We also found that the loss on the long context QA dataset correlates well with the model’s long context abilities.

While the loss curves in Stages 1 and 2 of training were close for all the models, we found that in Stage 3 (training on the long context QA dataset), there was a clear bifurcation in the loss curves. In particular, we see that configurations MA and MA-EndSlide show much worse loss than the others. These results are consistent with the long context RULER evals, where we found that MA and MA-EndSlide performed much worse than the others. Their performance was similar to that of the network with only sliding window attention in all layers. We think the loss in Stage 3 correlates well with RULER evals because, unlike Stages 1 and 2, which were next-word prediction tasks where local context was sufficient to predict the next word most of the time, in Stage 3 the model needed to retrieve the correct information from potentially long-distance context to answer the questions.

 

As we see from the RULER evals, MA-Offset and MA-Pairs have better long-context abilities than MA and MA-EndSlide across all the categories. Both MA and MA-EndSlide have only one standard attention KV cache, which is computed in the first layer, whereas both MA-Offset and MA-Pairs have at least one standard attention KV cache which is computed in deeper layers. Hence, this indicates that having at least one standard attention KV cache computed in the deeper layers of a transformer model is necessary for good long-context abilities.

KV cache sharing in sliding window layers

MixAttention Image 6
Fig. 6: Increasing KV cache sharing in sliding window layers. To measure the effect of KV cache sharing in the sliding window layers, we compared the architectures shown in the figure above.

Mix Attention Image 7

Mix Attention Image 8
Fig. 7 and 8: Effect of increasing KV cache sharing in sliding window layers. (Top) Loss curves of the models when fine-tuning on the long context QA dataset. (Bottom) RULER evals for the models. We found that increasing the KV cache sharing in sliding window layers worsened the long context abilities of MixAttention models.

We found that increasing the sharing between sliding window layers degraded the model’s long context performance: MA-Offset-SlideShare was worse than MA-Offset, and MA-Pairs-SlideShare was worse than MA-Pairs. This shows that the KV cache sharing pattern among the sliding window layers is also important for long context abilities.

 

We have provided the results of some additional ablation experiments in the appendices.

Gauntlet Evals

Using the Mosaic Eval Gauntlet v0.3.0, we also measured the performance of MixAttention models on standard tasks like MMLU, HellaSwag, etc., to verify that they retain good shorter context abilities. All of the tasks in this eval suite have context lengths of less than a few thousand tokens.

MixAttention Figure 9
Fig. 9: Performance of MixAttention models on the Eval Gauntlet. We found that MixAttention models have similar eval metrics to the baseline model on commonsense reasoning, language understanding, and world knowledge. However, we see that they perform worse on reading comprehension.

We found that MixAttention models have similar eval metrics to the baseline model on commonsense reasoning, language understanding, and world knowledge; however, they performed worse on reading comprehension. An interesting open question is whether reading comprehension abilities could be improved with a different MixAttention configuration or by training MixAttention models longer.

Inference Speed and Memory Consumption

Mix Attention Image 10

MixAttention Image 11
Fig. 10 and 11: (Top) MixAttention models have significantly faster inference than standard transformers. (Bottom) MixAttention models can support more tokens, and thus larger batch sizes, during inference.

We benchmarked the inference speed and memory consumption of MixAttention models by deploying them on a single NVIDIA H100 GPU using SGLang and querying them with 300 prompts, with an input length of 31000 and an output length of 1000. In the figure, we show that the inference speed of MixAttention models is much faster than that of standard attention models. We also show that with MixAttention, we can support a much larger inference batch size in terms of the total number of tokens.

 

We found that the current implementation of Sliding Window Attention in SGLang does not optimize the memory consumption for sliding window attention; hence, sliding window attention has the same maximum number of tokens as the standard attention model. Optimizing the memory consumption for sliding window attention should further increase the maximum number of tokens that MixAttention can support during inference.

Conclusion

We found that MixAttention models are competitive with standard attention models on both long- and short-context abilities while being faster during inference and supporting larger batch sizes. We also observed that on some long context tasks like Variable Tracking and Common Word Extraction, neither MixAttention nor standard attention models performed well. We believe this was because our models were not trained long enough or the models need a different kind of long context data to be trained for such tasks. More research needs to be done to measure the impact of MixAttention architectures on those metrics.

 

We encourage others to explore more MixAttention architectures to learn more about them. Below are a few observations to help with further research:

 

  • Adding a standard attention layer in the initial layers by itself does not seem to help long context abilities (for example, see MA-NoShare-1 in the appendix), even if the KV cache from that layer is reused in layers deeper into the network (MA and MA-EndSlide). Hence, we recommend placing the first standard attention layer deeper in the network (like MA-Offset) or having multiple standard attention layers, at least one of which is computed at a deeper layer (like MA-Pairs).
  • Sliding window layers also contribute to the model’s long context abilities. Increasing the KV cache sharing among the sliding window layers worsened long context abilities (MA-Offset-SlideShare and MA-Pairs-SlideShare). For that reason, we think the 2-3 sharing pattern in sliding window layers seems to strike a good balance.
  • Sharing full attention KV caches between consecutive layers gave mixed results, with slightly worse accuracy on long context QA tasks (see the appendix).
  • In our experiments, MA-Offset and MA-Pairs showed great speedup and memory savings during inference, while also maintaining long and short context abilities. Hence, MA-Offset and MA-Pairs might be good configurations for further research.
  • MixAttention models can be trained with LLM Foundry. Please see the appendix for guidelines.

 

In general, there is a large hyperparameter space to explore, and we look forward to seeing a variety of new strategies for reducing the cost of inference via combinations of sliding window attention and KV cache reuse.

Appendix: Using LLM Foundry to train MixAttention models

The way to configure MixAttention models with LLM Foundry is to use the block_overrides feature. The block_overrides definition consists of two sections: order and overrides. The order key defines the ordering and the names of the layers in the network, while the overrides key contains the custom configuration of each named layer.

 

For example, to create a 5 layer network with the first two layers being standard attention layers, the next two being sliding window layers, and the last one being a standard attention layer, we use the following YAML:

CodeSnippet1
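A sketch of this configuration, based on the order, overrides, repeat, and sliding_window_size keys described in this appendix (the exact layout of the snippet may differ slightly):

```yaml
block_overrides:
  order:
  - name: default               # layers 1-2: standard attention
    repeat: 2
  - name: sliding_window_layer  # layers 3-4: sliding window attention
    repeat: 2
  - name: default               # layer 5: standard attention
  overrides:
    sliding_window_layer:
      attn_config:
        sliding_window_size: 1024
```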

Here, the order section conveys that the first two layers are of type ‘default’, the next two are of type ‘sliding_window_layer’, and the last is of type ‘default’ again. The definitions of each of these types are contained in the overrides section using the names defined in the order section. It says that ‘sliding_window_layer’ should have a sliding_window_size of 1024. Note that ‘default’ is a special type, which does not need a definition in the overrides section because it just refers to the default layer (in this case, a standard attention layer). Also, note that ‘sliding_window_layer‘ is just a custom name and can be replaced with any other arbitrary name as long as that name is correspondingly also defined in the overrides section.

 

The model configuration is printed in the logs, which can be used to confirm that the model is configured correctly. For example, the above YAML will result in the following being printed in the logs:

CodeSnippet2

We can also configure the two sliding window layers to have different sliding window sizes as follows:

CodeSnippet3
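A sketch consistent with the description below: two separately named sliding window overrides with different window sizes, each used once with an explicit repeat of 1 (the names here are arbitrary placeholders):

```yaml
block_overrides:
  order:
  - name: default
    repeat: 2
  - name: sliding_window_layer_1024
    repeat: 1
  - name: sliding_window_layer_512
    repeat: 1
  - name: default
  overrides:
    sliding_window_layer_1024:
      attn_config:
        sliding_window_size: 1024
    sliding_window_layer_512:
      attn_config:
        sliding_window_size: 512
```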

The above will result in the third layer having a sliding window size of 1024, and the fourth layer having a sliding window size of 512. Note that the repeat keyword defaults to 1. So, the above YAML can also be written as:

CodeSnippet4
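The equivalent sketch with the implicit repeat of 1 omitted:

```yaml
block_overrides:
  order:
  - name: default
    repeat: 2
  - name: sliding_window_layer_1024
  - name: sliding_window_layer_512
  - name: default
  overrides:
    sliding_window_layer_1024:
      attn_config:
        sliding_window_size: 1024
    sliding_window_layer_512:
      attn_config:
        sliding_window_size: 512
```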

The repeat keyword is also applicable to the order keyword. So, if we want to create a 4 layer network with alternating standard and sliding window attention layers like the following,

MixAttention Appendix 1

then we can use the following YAML:

CodeSnippet5
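One way this might be written, assuming repeat can wrap a nested order sub-block (the exact syntax of the original snippet may differ):

```yaml
block_overrides:
  order:
  - repeat: 2    # repeat the two-layer pattern below twice -> 4 layers total
    order:
    - name: default
    - name: sliding_window_layer
  overrides:
    sliding_window_layer:
      attn_config:
        sliding_window_size: 1024
```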

To make a layer reuse the KV cache of a previous layer, we use reuse_kv_layer_idx in the attn_config in the override definition. The key reuse_kv_layer_idx contains the relative layer index whose KV cache we want this layer to reuse. To make a two layered network where the second layer reuses the first layer’s KV cache, we can use the following YAML:

CodeSnippet6
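A sketch of the two-layer reuse configuration described here (kv_reuse_layer is the name used in the text; the override keys follow the description above):

```yaml
block_overrides:
  order:
  - name: default          # layer 1: standard attention, computes its own KV cache
  - name: kv_reuse_layer   # layer 2: reuses layer 1's KV cache
  overrides:
    kv_reuse_layer:
      attn_config:
        reuse_kv_layer_idx: -1   # relative index: one layer before this one
```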

The value -1 indicates that the layer named kv_reuse_layer reuses the KV cache of the layer that is one layer before it. To create a 5 layer network with the following configuration

Mix Attention Appendix Image 2

we can use the following YAML:

CodeSnippet7
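The exact five-layer layout follows the figure above; a plausible sketch consistent with the note below, where layers #3 and #4 each point one layer back so the reuse chains from #4 to #3 to #2:

```yaml
block_overrides:
  order:
  - name: default          # layers 1-2: standard attention
    repeat: 2
  - name: kv_reuse_layer   # layer 3 reuses layer 2; layer 4 reuses layer 3
    repeat: 2
  - name: default          # layer 5: standard attention
  overrides:
    kv_reuse_layer:
      attn_config:
        reuse_kv_layer_idx: -1
```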

Note that in the above configuration, layer #4 reuses the KV cache of layer #3, which in turn reuses the KV cache of layer #2. Hence, layer #4 ends up reusing the KV cache of layer #2.

 

Finally, note that order can be defined recursively; that is, the order can contain another order sub-block. For example, MA-Offset-SlideShare

Appendix 3 image

can be defined as follows:

CodeSnippet8
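The exact MA-Offset-SlideShare layout follows the figure above; purely as an illustration of the recursive pattern, with made-up layer counts, a nested order might look like:

```yaml
block_overrides:
  order:
  - name: sliding_window_layer
  - repeat: 5            # repeat the nested sub-block as a unit
    order:
    - name: default
    - name: sliding_window_layer
      repeat: 3
  overrides:
    sliding_window_layer:
      attn_config:
        sliding_window_size: 1024
```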

Appendix: Other Ablation Experiments

Sharing Standard Attention KV Caches between Consecutive Layers

Since the transformer layers progressively update the latent representation of a token as it moves through the layers, the Query, Key, and Value tensors might have significantly different representations for layers that are far apart. Hence, it might make more sense to share KV caches between consecutive layers. To test this, we compared four such configurations, MA-Successive-1, MA-Successive-2, MA-Successive-3, and MA-Successive-4, against MA-Pairs. These configurations vary the positions of the standard KV attention layers and the distance between the consecutive pairs of standard KV attention layers.

MixAttention image 4
KV cache sharing between consecutive layers: To measure the effect of KV cache sharing between consecutive layers, we tried the four configurations shown above.

 


MixAttention appendix 5

MixAttention appendix 6
Effect of KV cache sharing between consecutive layers: (Top) Loss curves of the models when fine-tuning on the long context QA dataset. (Bottom) RULER evals for the models. We found that KV cache sharing between consecutive layers does not consistently improve long context abilities across all evals. However, for tasks like SQuAD QA and HotpotQA, which can be indicative of long context RAG abilities, the performance was slightly worse when sharing KV cache between consecutive layers.

We found that all the models have similar loss curves and similar performance on the NIAH single 1, 2, and 3 tasks, which we consider to be the easiest long context tasks. However, we did not see a consistent pattern across the other NIAH tasks. For long context QA tasks, we found that MA-Pairs was slightly better than the others. These results indicate that sharing standard attention KV cache between layers that are further apart does not lead to any significant degradation in long context abilities as compared to sharing standard attention KV cache between consecutive layers.

Effect of Sharing Standard Attention KV Cache

MixAttention appendix 7
No standard attention KV-cache sharing: To measure the effect of KV cache sharing between standard attention layers, we compare the architectures shown in the figure above.

MixAttention appendix 8

MixAttention appendix 9
Effect of no standard attention KV-cache sharing: (Top) Loss curves of the models when fine-tuning on the long context QA dataset. (Bottom) RULER evals for the models. We found that both MA-NoShare-2 and MA-NoShare-3 were comparable with MA-Offset.

 

To test the effect of sharing the KV cache between standard attention layers, we tried out three configurations: MA-NoShare-1, MA-NoShare-2, and MA-NoShare-3. We found that MA-NoShare-1 performed very badly on RULER, indicating its lack of long context abilities. However, MA-NoShare-2 and MA-NoShare-3 were comparable to MA-Offset on long context tasks. Hence, we think that further research is needed to ascertain the effects of sharing standard attention KV cache.
