Add Self-Extend to gemma.cpp #60
This seems interesting and quite doable. I'll need to have a closer look at the paper and revisit tomorrow. On the tactical side, we'll want to tidy up the APIs + dispatch mechanisms for multiple alternative inference graphs. The dispatch mechanisms are OK for the limited set of 7B/2B x IT/PT, but could use a refactor before we add more combinations of inference paths. |
Glad to see that our method works well with Gemma! Our Python implementation is here (https://github.com/datamllab/LongLM/blob/master/gemma_self_extend_patch.py) and the llama.cpp implementation is here (https://github.com/ggerganov/llama.cpp/blob/cb49e0f8c906e5da49e9f6d64a57742a9a241c6a/examples/main/main.cpp#L569). We are glad to help! |
Author here, glad to answer any questions about the details of our work. |
If someone wants to take a stab at this as a flag, happy to have a look at the PR / provide suggestions (add yourself as the assignee for this issue). There's an enhancement that I think would improve the usefulness of this: %save / %load commands for KV cache state. Using the blob store headers, I think this wouldn't be that hard to implement, and it might be a good first issue for someone who's comfortable with the codebase. I think it would enable a lot of use cases that would otherwise be impractical. |
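For what it's worth, here is a minimal sketch of what %save / %load could boil down to, assuming a plain binary dump rather than the blob store headers mentioned above; the KVCacheSnapshot struct and its fields are purely illustrative, not the real gemma.cpp KVCache layout.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical snapshot of a KV cache: the number of cached positions plus
// the flattened key/value data for all layers.
struct KVCacheSnapshot {
  uint64_t seq_len = 0;   // number of positions currently cached
  std::vector<float> kv;  // flattened [layers * seq_len * kv_dim] values
};

// Writes the snapshot to disk: header (seq_len, element count), then raw data.
bool SaveKVCache(const KVCacheSnapshot& cache, const char* path) {
  FILE* f = std::fopen(path, "wb");
  if (!f) return false;
  const uint64_t num_floats = cache.kv.size();
  const bool ok =
      std::fwrite(&cache.seq_len, sizeof(cache.seq_len), 1, f) == 1 &&
      std::fwrite(&num_floats, sizeof(num_floats), 1, f) == 1 &&
      std::fwrite(cache.kv.data(), sizeof(float), num_floats, f) == num_floats;
  return std::fclose(f) == 0 && ok;
}

// Reads the snapshot back; returns false on any short read.
bool LoadKVCache(KVCacheSnapshot& cache, const char* path) {
  FILE* f = std::fopen(path, "rb");
  if (!f) return false;
  uint64_t num_floats = 0;
  bool ok = std::fread(&cache.seq_len, sizeof(cache.seq_len), 1, f) == 1 &&
            std::fread(&num_floats, sizeof(num_floats), 1, f) == 1;
  if (ok) {
    cache.kv.resize(num_floats);
    ok = std::fread(cache.kv.data(), sizeof(float), num_floats, f) == num_floats;
  }
  std::fclose(f);
  return ok;
}

A real implementation would go through the blob store format so the file is self-describing and versioned, but the flow (persist the cache after prefill, restore it instead of re-running the prompt) would be the same. |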
+1, we'd welcome a pull request for this, also happy to discuss. |
@austinvhuang @jan-wassenberg I'd like to take a stab at this, if nobody has objections? My background: I've been trying to break into this field, and I've had the pleasure of collaborating with the Google team in the past on the TFLite Support repository. |
Nice, sounds great, we'd be happy to collaborate with you, discuss and review :) FYI the KVCache internals will likely change a bit to use RowVectorBatch at some point, but no big deal. Is there anything in the current code that you think will cause difficulties? |
Perfect, sorry for the delay; I can spin something up over the weekend. Please allow some time for me to read the codebase and get back with a proposal. |
Had a first pass through the paper. It demonstrates the method only on RoPE position encodings, and the theory is argued only for relative position encodings, i.e. there's no evidence that it works for a model trained with sinusoidal positional encodings. Shouldn't we have some kind of check for this? cc: @Mooler0410 @ahxt |
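A minimal sketch of such a check, assuming hypothetical per-model config members kEnableSelfExtend and kUseRope (neither exists in the current configs); the idea is just to fail the build rather than silently produce wrong outputs for a non-RoPE model.

// Hypothetical compile-time guard: reject a config that enables Self-Extend
// but does not use RoPE position encodings. kEnableSelfExtend and kUseRope
// are illustrative names, not existing gemma.cpp config fields.
template <class TConfig>
void AssertSelfExtendSupported() {
  static_assert(!TConfig::kEnableSelfExtend || TConfig::kUseRope,
                "Self-Extend is only validated for RoPE-based configs.");
}

Calling this once from each instantiated inference path would be enough. |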
Hi, https://arxiv.org/pdf/2401.01325 mentions an experiment with non-RoPE also working. We mostly do use RoPE, though. |
@jan-wassenberg @austinvhuang @Mooler0410 Okay, it took me a long time to understand how transformers work, and even longer to understand how this repository implements them, so please forgive any mistakes below and do comment if you spot them. The paper argues that the O.O.D. behavior occurs because the model has never seen positions outside of its trained context window. Self-Extend therefore uses grouped attention, mapping distant positions onto group positions via floor division by a group size (the LongLM repository linked above has the reference implementation). We already have RoPE in the codebase (applied via PostQK below). In Gemma, we would have something like:
// gemma/activations.h
class Activations {
  // Same as before, plus scratch space for the grouped-position (s_g_pos)
  // query states. Shape: (batch_size, kHeads * kQKVDim); only the Q heads
  // are flattened here.
  RowVectorBatch<float> grp_q;
};

We would then do the forward pass twice: once with the normal positions (for the neighbor attention) and once with the grouped positions (for the grouped attention), as sketched below.
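To make the position mapping concrete before touching the attention code, here is a tiny standalone sketch; the grp_size and ngb_size values are illustrative only, and the shifted query formula follows the LongLM merge (key: pos / grp_size; query: pos / grp_size + ngb_size - ngb_size / grp_size).

#include <cstddef>
#include <cstdio>

int main() {
  const size_t grp_size = 8;     // group size (illustrative)
  const size_t ngb_size = 1024;  // neighbor window (illustrative)
  const size_t positions[] = {512, 1024, 1025, 4096, 16384};
  for (const size_t pos : positions) {
    // Key side: plain group position.
    const size_t grp_k_pos = pos / grp_size;
    // Query side: shifted group position so grouped scores line up with the
    // neighbor-window scores at the boundary.
    const size_t s_g_pos = pos / grp_size + ngb_size - ngb_size / grp_size;
    std::printf("pos=%zu -> key group pos=%zu, query group pos=%zu\n", pos,
                grp_k_pos, s_g_pos);
  }
  return 0;
}

With that mapping in mind, I think the change on the K side (inside the MHA branch where RoPE is applied to the keys) would look roughly like this: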
// Apply positional encodings for K (and copy KV to cache if MHA).
pool.Run(
    ....
    if constexpr (kIsMHA) {
      // For MHA, copy KV into the KV cache from scratch space (see above).
      const float* HWY_RESTRICT q =
          activations.q.Batch(interleaved_idx) + head * kQStride;
      // Another float pointer to store the grouped key cache?
      float* HWY_RESTRICT grp_k = kv_caches[query_idx].g_k_cache();
      // Apply RoPE to the grouped keys.
      const size_t ngb_size = TConfig::self_extend_ngb_size;
      const size_t grp_size = TConfig::self_extend_grp_size;
      // First, map the key position onto its group.
      const size_t grp_k_pos = pos / grp_size;
      // Now apply RoPE with the grouped position; this comes in handy later.
      RoPE(grp_k, kQKVDim, grp_k_pos);
// This should be done during our fwd pass.
// For each head (token, query), compute Q.K, softmax, and weighted V.
pool.Run(
    0, kHeads * num_interleaved,
    [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
      ...
      KVCache& kv_cache = kv_caches[query_idx];
      float* HWY_RESTRICT q =
          activations.q.Batch(interleaved_idx) + head * kQStride;
      // Apply RoPE and scaling to Q.
      const size_t pos = batch_start + batch_idx;
      PostQK<TConfig>(q, pos, layer);
      MulByConst(kQueryScale, q, kQKVDim);
      // Now compute g_q from the shifted group position.
      float* HWY_RESTRICT g_q = nullptr;
      if (pos > ngb_size && TConfig::kEnableSelfExtend) {
        // g_q: the same Q we got from the MatMul, kept in a separate buffer
        // (activations.grp_q) so it can be RoPE'd with the group position
        // without clobbering q.
        g_q = activations.grp_q.Batch(interleaved_idx) + head * kQStride;
        // Shifted group position for the query, following the LongLM merge:
        // pos / grp_size + ngb_size - ngb_size / grp_size.
        const size_t s_g_pos = pos / grp_size + ngb_size - ngb_size / grp_size;
        PostQK<TConfig>(g_q, s_g_pos, layer);
        MulByConst(kQueryScale, g_q, kQKVDim);
      }
      // Score using g_q and the grouped key (read back from g_k_cache)
      // instead of q and k for distant tokens.
      if (pos > ngb_size && TConfig::kEnableSelfExtend) {
        const float score = Dot(g_q, grp_k, kQKVDim);
      }
|
If you're interested in mentoring me to help merge this, I can convert the above into a tech doc in Google Docs and you can point things out there. Awaiting your response! |
Hi, great that you've mapped this to our code :) This sounds correct to me. Would be happy to advise, but I will be without internet access from Sep 11-17. A doc sounds fine, or we can also proceed directly to a pull request if you prefer? |
I'd love that. I'm aware of values being stored as [Q, Q, ..., K, K, ..., V, V, ...] in the non-MHA case. I was under the assumption that we'd be restricting this to the purely MHA case; I guess that's wrong. I was thinking of a doc so that we could have a unified tracker with objectives and next steps, where you could easily comment/edit. For reference, I've done something similar with the TensorFlow team when I collaborated with them. Great, I can definitely fast-track a pull request for this! I understand you won't be available from Sep 11-17; given that I'm also handling a full-time job on the side, I'll try to spin up an MVP while you're unavailable. Let me know if this arrangement works for you. |
I have holidays from the 23rd to the 27th, where I can dedicate good time to this and get a bulk of the work done. I'd love to go back and forth with feedback and suggestions so we can quickly integrate this feature! Just FYI, we'd also need to plan how to test that it's working, and benchmark it to confirm its practicality. I can make an .md file for the report. |
:) Yes, the various Gemma models can be any of MHA, GQA or MQA, so we have to handle both the MHA and non-MHA case. Sure, a doc sounds useful for commenting. Thanks for suggesting it :) Yes indeed, I'm back and catching up. Working on this together next week sounds great!
I am not very familiar with evals in this space, but understand that summarization and needle-in-haystack tests might work. Do you or others have any particular suggestions? |
@jan-wassenberg yes! The original repo proposed the needle-in-a-haystack problem, and they also provide a dataset for it. They used Gemma-7B-it for one of their benchmarks, and we have that model available as well. |
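For planning purposes, here is a rough sketch of what a single needle-in-a-haystack trial could look like. The model call is abstracted behind a hypothetical generate callback and the filler/needle text is made up; none of this uses actual gemma.cpp APIs or the LongLM dataset.

#include <cstddef>
#include <functional>
#include <random>
#include <string>

// Hypothetical needle-in-a-haystack trial: bury a "needle" sentence at a
// random depth inside filler text, ask the model to retrieve it, and check
// whether the answer contains the needle. `generate` stands in for whatever
// gemma.cpp call ends up producing text for a prompt.
bool NeedleInHaystackTrial(
    const std::function<std::string(const std::string&)>& generate,
    size_t filler_sentences, std::mt19937& rng) {
  const std::string needle = "The secret passphrase is amber-falcon-42.";
  std::uniform_int_distribution<size_t> where(0, filler_sentences);
  const size_t needle_at = where(rng);

  std::string context;
  for (size_t i = 0; i <= filler_sentences; ++i) {
    if (i == needle_at) context += needle + " ";
    context += "The sky was a calm and uneventful shade of blue that day. ";
  }
  const std::string prompt =
      context + "\nQuestion: What is the secret passphrase?\nAnswer:";
  return generate(prompt).find("amber-falcon-42") != std::string::npos;
}

Sweeping the filler length and the needle depth, with and without Self-Extend, should make any benefit (or regression) visible. |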
Got it, needle in a haystack sounds like a good eval then :) |
Updated with the code; could you please check, @jan-wassenberg? Will create a tech doc and an eval strategy soon. I've been busy for a while and this took longer than expected, but I'm ready now. |
Hi team, I came across the LocalLLaMA discussion and found that Gemma works well with the Self-Extend method. It would be awesome if this technique could be added to gemma.cpp.