
Add Self-Extend to the gemma.cpp #60

Open
namtranase opened this issue Feb 28, 2024 · 19 comments
Labels
Feature New feature or request good first issue Good for newcomers stat:awaiting response Status - Awaiting response from author

Comments

@namtranase

Hi team, I checked LocalLLaMA and found that Gemma works well with the Self-Extend method. It would be awesome if this technique could be added to gemma.cpp.
References:

@austinvhuang austinvhuang added the Feature New feature or request label Feb 28, 2024
@austinvhuang
Collaborator

This seems interesting and quite doable. I'll need to have a closer look at the paper and revisit tomorrow.

On the tactical side, we'll want to tidy up the APIs and dispatch mechanisms for multiple alternative inference graphs. The dispatch mechanisms are OK for the limited set of 7B/2B x IT/PT, but could use a refactor before we add more combinations of inference paths.

@ahxt

ahxt commented Feb 28, 2024

Glad to see that our method works well with Gemma!! Our python implementation is here https://github.com/datamllab/LongLM/blob/master/gemma_self_extend_patch.py and the llama.cpp implementation is here https://github.com/ggerganov/llama.cpp/blob/cb49e0f8c906e5da49e9f6d64a57742a9a241c6a/examples/main/main.cpp#L569

We are glad to help!!!

@Mooler0410

Mooler0410 commented Feb 28, 2024

Author here, glad to answer any questions about the details of our work.

@austinvhuang
Collaborator

If someone wants to take a stab at this as a flag, happy to have a look at the PR / provide suggestions (add yourself as the assignee for this issue).

One enhancement that I think would improve the usefulness of this is %save / %load commands for the KV cache state. Using the blob store headers, I think this wouldn't be that hard to implement. It might be a good first issue for someone who's comfortable with the codebase, and I think it would enable many use cases that would otherwise be impractical.

@jan-wassenberg
Member

+1, we'd welcome a pull request for this, also happy to discuss.

@jan-wassenberg jan-wassenberg added the good first issue Good for newcomers label Jul 15, 2024
@KumarGitesh2024 KumarGitesh2024 added the stat:awaiting response Status - Awaiting response from author label Jul 17, 2024
@jonpsy

jonpsy commented Aug 20, 2024

@austinvhuang @jan-wassenberg I'd like to take a stab at this, if nobody has objections?

My background: I've been trying to break into this field, and I've had the pleasure of collaborating with the Google team in the past on the TFLite Support repository.

@jan-wassenberg
Member

Nice, sounds great, we'd be happy to collaborate with you, discuss and review :)

FYI the KVCache internals will likely change a bit to use RowVectorBatch at some point, but no big deal.

Is there anything in the current code that you think will cause difficulties?

InferenceArgs is probably a good place to add the flag.

@jonpsy

jonpsy commented Aug 22, 2024

Perfect, sorry for the delay; I can spin something up over the weekend. Please allow me some time to read the codebase, and I'll get back with a proposal.

@jonpsy

jonpsy commented Aug 22, 2024

I had a first pass through the paper. It demonstrates the method only with RoPE position encodings, and the theory only covers relative position encodings, i.e. there's no proof of it working if the model was trained with sinusoidal positional encodings.

Shouldn't we have some kind of check for this?

cc: @Mooler0410 @ahxt

@jan-wassenberg
Member

Hi, https://arxiv.org/pdf/2401.01325 mentions an experiment showing that it also works with non-RoPE encodings. We mostly do use RoPE, though.
We can certainly mention in the flag description that this depends on the positional encoding.

@KumarGitesh2024 KumarGitesh2024 added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting response Status - Awaiting response from author labels Aug 29, 2024
@jonpsy

jonpsy commented Sep 6, 2024

@jan-wassenberg @austinvhuang @Mooler0410

Okay, it took me a long time to understand how transformers work, and even longer to understand how this repository implements them. So please forgive any mistakes below, and do comment if you spot any.

The paper argues that out-of-distribution (O.O.D.) behavior occurs because the model has never seen relative positions beyond its trained context window. It therefore uses grouped attention: position ids are floor-divided by the group size (group_size_1 in LongLM), so that tokens within a group share a position id. Grouped attention is applied only to tokens farther away than the neighbor window (group_size_2 in LongLM); tokens inside the window keep normal attention.
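To make the mapping concrete, here is a minimal standalone C++ sketch of the relative distance a query sees under Self-Extend. SelfExtendRelativePos is a hypothetical helper, not part of gemma.cpp; grp_size stands for the group size G (group_size_1) and ngb_size for the neighbor window W (group_size_2). The query shift of W - W / G is assumed to follow the LongLM formula and is worth double-checking against the reference code.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper (not part of gemma.cpp): the relative distance that
// Self-Extend assigns between a query at q_pos and a key at k_pos.
//   grp_size: group size G (LongLM's group_size_1)
//   ngb_size: neighbor window W (LongLM's group_size_2)
size_t SelfExtendRelativePos(size_t q_pos, size_t k_pos, size_t grp_size,
                             size_t ngb_size) {
  assert(k_pos <= q_pos && grp_size > 0);
  const size_t dist = q_pos - k_pos;
  // Neighbor attention: nearby keys keep their exact relative distance.
  if (dist < ngb_size) return dist;
  // Grouped attention: floor-divide both positions by G and shift the query
  // by W - W / G so that grouped distances continue from the window edge.
  const size_t g_q = q_pos / grp_size + ngb_size - ngb_size / grp_size;
  const size_t g_k = k_pos / grp_size;
  return g_q - g_k;
}
```

For example, with G = 4 and W = 16, a query at position 100 sees a key at position 0 at distance 100 / 4 + 16 - 4 = 37 instead of 100, while a key at position 90 (distance 10) is unaffected. The largest distance the model ever sees therefore grows roughly as pos / G, which is what keeps it inside the trained range.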

For reference, here is the code snippet from LongLM:

[image: LongLM code snippet]

We already have RoPE implemented as part of PostQK.

In Gemma, we have

  • activations: a cache that stores all the data for one batch of tbatch_size (token batch size) tokens. I propose adding a grp_q buffer:
```cpp
// gemma/activations.h
class Activations {
  // ... existing members ...

  // Proposed: space for the grouped queries (rotated at s_g_pos).
  // Shape: (batch_size, kHeads * kQKVDim); only the Q heads are stored.
  RowVectorBatch<float> grp_q;
};
```

We do the forward pass twice: once for PreFill and once for GenerateT.

I think the q matrix has the following shape (one row per token in the batch):

```
rows = token batch size
row 1 [ ---- head 1 (QKV flattened) ----, ---- head 2 (QKV flattened) ----, ---- head 3 (QKV flattened) ---- ]
row 2 [ ---- head 1 (QKV flattened) ----, ---- head 2 ---- ... (same as above)
```
  1. We would need a flag inside PostQK so that, when applying positional encodings to K, we also apply them to grp_k (the keys RoPE-encoded at their grouped positions).
```cpp
// Apply positional encodings for K (and copy KV to cache if MHA).
pool.Run(
    ...
        if constexpr (kIsMHA) {
          // For MHA, copy KV into the KV cache from scratch space (see above).
          const float* HWY_RESTRICT q =
              activations.q.Batch(interleaved_idx) + head * kQStride;

          // Proposed: a second key cache holding keys RoPE-encoded at grouped
          // positions. g_k_cache() is a hypothetical accessor; grp_k must hold
          // a copy of the raw (un-rotated) key before RoPE is applied below.
          float* HWY_RESTRICT grp_k = kv_caches[query_idx].g_k_cache();

          // Proposed Self-Extend config fields (ngb_size is used when scoring).
          const size_t ngb_size = TConfig::self_extend_ngb_size;
          const size_t grp_size = TConfig::self_extend_grp_size;

          // First, compute the grouped key position: floor(pos / grp_size).
          const size_t grp_k_pos = pos / grp_size;

          // Then apply RoPE at the grouped position; used later for scoring.
          RoPE(grp_k, kQKVDim, grp_k_pos);
          ...
        }
    ...
);
```
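  2. We also need to compute grp_q (stored in Activations, as above) and use it to calculate the scores for distant keys.
```cpp
// This should be done during our forward pass.
// For each head (token, query), compute Q.K, softmax, and weighted V.
pool.Run(
    0, kHeads * num_interleaved,
    [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
      ...
      KVCache& kv_cache = kv_caches[query_idx];
      float* HWY_RESTRICT q =
          activations.q.Batch(interleaved_idx) + head * kQStride;
      const size_t pos = batch_start + batch_idx;
      // If pos < ngb_size, every key is inside the neighbor window.
      const bool self_extend = TConfig::kEnableSelfExtend && pos >= ngb_size;

      // g_q must start from the raw (un-rotated) query, so copy q into the
      // grp_q scratch space before PostQK rotates q in place.
      float* HWY_RESTRICT g_q =
          activations.grp_q.Batch(interleaved_idx) + head * kQKVDim;
      if (self_extend) std::memcpy(g_q, q, kQKVDim * sizeof(float));

      // Apply rope and scaling to Q.
      PostQK<TConfig>(q, pos, layer);
      MulByConst(kQueryScale, q, kQKVDim);

      // Grouped query position (LongLM formula):
      // pos / grp_size + ngb_size - ngb_size / grp_size.
      if (self_extend) {
        const size_t s_g_pos = pos / grp_size + ngb_size - ngb_size / grp_size;
        PostQK<TConfig>(g_q, s_g_pos, layer);
        MulByConst(kQueryScale, g_q, kQKVDim);
      }

      // In the loop over cached key positions pos2 <= pos: keys inside the
      // neighbor window are scored with q and k, distant keys with g_q and grp_k.
      const bool use_grouped = self_extend && (pos - pos2) >= ngb_size;
      const float score = use_grouped ? Dot(g_q, grp_k, kQKVDim)
                                      : Dot(q, k, kQKVDim);
      ...
    });
```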
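The two steps above (RoPE on grouped key positions, RoPE on the shifted grouped query position, then scoring distant keys with that pair) can be checked in isolation with a small standalone C++ sketch. RopeInPlace, Dot and SelfExtendScore are hypothetical stand-ins rather than gemma.cpp's PostQK/Dot, and the RoPE here uses a rotate-half pairing with base 10000, which may differ from gemma.cpp's exact layout. Because a RoPE score depends only on the difference of the two positions, rotating q and k at their grouped positions reproduces the merged Self-Extend distance.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Minimal RoPE sketch (rotate-half pairing, base 10000): rotates each pair
// (x[i], x[i + d/2]) by the angle pos * base^(-2i/d). Not gemma.cpp's code.
void RopeInPlace(std::vector<float>& x, size_t pos, float base = 10000.0f) {
  const size_t half = x.size() / 2;
  const float dim = static_cast<float>(x.size());
  for (size_t i = 0; i < half; ++i) {
    const float theta = static_cast<float>(pos) *
                        std::pow(base, -2.0f * static_cast<float>(i) / dim);
    const float c = std::cos(theta);
    const float s = std::sin(theta);
    const float x0 = x[i];
    const float x1 = x[i + half];
    x[i] = x0 * c - x1 * s;
    x[i + half] = x0 * s + x1 * c;
  }
}

float Dot(const std::vector<float>& a, const std::vector<float>& b) {
  assert(a.size() == b.size());
  float sum = 0.0f;
  for (size_t i = 0; i < a.size(); ++i) sum += a[i] * b[i];
  return sum;
}

// Hypothetical Self-Extend score for one (query, key) pair, taking the raw
// (un-rotated) vectors by value. Nearby keys use their real positions;
// distant keys use grouped positions, with the query shifted as in LongLM.
float SelfExtendScore(std::vector<float> q, std::vector<float> k, size_t q_pos,
                      size_t k_pos, size_t grp_size, size_t ngb_size) {
  assert(k_pos <= q_pos && grp_size > 0);
  if (q_pos - k_pos < ngb_size) {
    RopeInPlace(q, q_pos);  // neighbor attention
    RopeInPlace(k, k_pos);
  } else {
    RopeInPlace(q, q_pos / grp_size + ngb_size - ngb_size / grp_size);
    RopeInPlace(k, k_pos / grp_size);  // grouped attention
  }
  return Dot(q, k);
}
```

The same identity is what makes a separate grp_k cache workable: keys are rotated once at pos / grp_size, and each query only needs one extra rotation at its shifted grouped position, rather than re-rotating every cached key per query.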

@jonpsy

jonpsy commented Sep 6, 2024

If you're interested in mentoring me to help get this merged, I can convert the above into a tech doc in Google Docs, where you can comment directly.

Looking forward to your response.

@jan-wassenberg
Member

Hi, great that you've mapped this to our code :) This sounds correct to me.
Note that activations.q shape depends on QStride. For MHA it is indeed Q,K,V, otherwise only Q.
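A small sketch of that stride arithmetic: QOffset is a hypothetical helper, and it assumes rows are stored contiguously without padding (gemma.cpp's RowVectorBatch may pad rows), so it only illustrates how the per-head stride changes between the MHA and non-MHA layouts.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: offset (in floats) of the Q vector for `head` in token
// row `row` of a flattened activations.q buffer. Assumes unpadded rows.
// MHA: each head stores Q, K, V back to back, so the stride is 3 * qkv_dim.
// Otherwise (GQA/MQA): only Q is stored per head, so the stride is qkv_dim.
size_t QOffset(size_t row, size_t head, size_t num_heads, size_t qkv_dim,
               bool is_mha) {
  assert(head < num_heads);
  const size_t q_stride = is_mha ? 3 * qkv_dim : qkv_dim;
  return row * num_heads * q_stride + head * q_stride;
}
```

With 8 heads and qkv_dim = 256, head 2 of row 1 starts at float 7680 in the MHA layout but at float 2560 when only Q is stored, which is why grp_q sizing and indexing cannot simply reuse the MHA stride.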

Would be happy to advise, but I will be without internet access from Sep 11-17. A doc sounds fine, or we can also proceed directly to a pull request if you prefer?

@KumarGitesh2024 KumarGitesh2024 added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting response Status - Awaiting response from author labels Sep 10, 2024
@jonpsy

jonpsy commented Sep 11, 2024

I'd love that. I'm aware that in the non-MHA cases the values are stored as [Q, Q, ..., K, K, ..., V, V, ...]. I assumed we would restrict this to the MHA case; I guess that's wrong.

I was thinking of a doc so that we could have a unified tracker of objectives and next steps, which you could easily comment on and edit. For reference, I did something similar with the TensorFlow team when I collaborated with them.

Great! I can definitely fast-track a pull request for this. I understand you won't be available Sept 11-17; since I'm also working a full-time job, I'll try to put together an MVP while you're away.

Let me know if this arrangement works for you.

@jonpsy

jonpsy commented Sep 17, 2024

I have holidays from the 23rd to the 27th, when I can dedicate solid time to this and get a large chunk of the work done. I'd love to go back and forth on feedback and suggestions so we can quickly integrate this amazing feature!

Just FYI, we'd also need to plan how to test that it's working and benchmark it to confirm its practicality. I can write up an .md file for the report.

@jan-wassenberg
Member

:) Yes, the various Gemma models can be any of MHA, GQA or MQA, so we have to handle both the MHA and non-MHA case.
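One way to cover MHA, GQA and MQA with a single code path is to map each query head to the KV head it shares; the grouped keys (grp_k) would then be looked up per KV head. KVHeadForQueryHead is a hypothetical helper, and it assumes query heads are split evenly into contiguous groups per KV head, which is the usual GQA convention.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: index of the KV head shared by query head `head`.
// MHA: num_kv_heads == num_heads (one KV head per query head).
// MQA: num_kv_heads == 1 (all query heads share one KV head).
// GQA: anything in between; contiguous groups of query heads share a KV head.
size_t KVHeadForQueryHead(size_t head, size_t num_heads, size_t num_kv_heads) {
  assert(num_kv_heads > 0 && num_heads % num_kv_heads == 0 && head < num_heads);
  const size_t heads_per_kv = num_heads / num_kv_heads;
  return head / heads_per_kv;
}
```

MHA and MQA then fall out as the two extreme cases of the same formula, so the Self-Extend scoring would not need a separate branch per attention variant.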

Sure, a doc sounds useful for commenting. Thanks for suggesting it :)

Yes indeed, I'm back and catching up. Working on this together next week sounds great!

plan how we'd test if its working and benchmark to confirm its practicality

I am not very familiar with evals in this space, but understand that summarization and needle-in-haystack tests might work. Do you or others have any particular suggestions?

@jonpsy

jonpsy commented Sep 22, 2024

@jan-wassenberg yes! The original work uses a needle-in-a-haystack task, and they also provide a dataset for it. See for example

They used Gemma-7B-it for one of their benchmarks, as is evident here, and that model is available to us.
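For running this eval against gemma.cpp, needle-in-a-haystack prompts can be generated programmatically. The sketch below is hypothetical: BuildHaystackPrompt, the filler sentence and the instruction wording are placeholders, not the format of the LongLM dataset. The depth parameter controls where among the filler sentences the needle is inserted, so retrieval accuracy can be measured across needle positions and context lengths, with and without Self-Extend.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical passkey-style prompt builder (needle-in-a-haystack).
// num_filler: number of filler sentences; depth in [0, 1]: needle position,
// where 0 places it before all filler and 1 places it after all filler.
std::string BuildHaystackPrompt(const std::string& needle, size_t num_filler,
                                double depth) {
  assert(depth >= 0.0 && depth <= 1.0);
  const std::string filler = "The grass is green. The sky is blue. ";
  const size_t needle_at =
      static_cast<size_t>(depth * static_cast<double>(num_filler));
  std::string prompt =
      "Find the important information hidden in the text below.\n";
  for (size_t i = 0; i <= num_filler; ++i) {
    if (i == needle_at) prompt += needle + " ";
    if (i < num_filler) prompt += filler;
  }
  prompt += "\nWhat is the pass key?";
  return prompt;
}
```

Scoring is then a matter of checking whether the model's answer contains the pass key; growing num_filler past the trained context length is where Self-Extend should make a difference.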

@jan-wassenberg
Member

Got it, needle in a haystack sounds like a good eval then :)

@jonpsy

jonpsy commented Oct 18, 2024

I've updated the code; could you please take a look, @jan-wassenberg? I'll create a tech doc and an eval strategy soon. I've been busy for a while, so this took longer than expected, but now I'm ready.

Projects
None yet
Development

No branches or pull requests

7 participants