diff --git a/pkg/collector/clock.go b/pkg/collector/clock.go new file mode 100644 index 0000000..60f00ae --- /dev/null +++ b/pkg/collector/clock.go @@ -0,0 +1,156 @@ +// Copyright 2024 VMware, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package collector + +import ( + "sync" + "time" +) + +// timer allows for injecting fake or real timers into code that needs to do arbitrary things based +// on time. We do not include the C() method, as we only support timers created with AfterFunc. +type timer interface { + Stop() bool + Reset(d time.Duration) bool +} + +// clock allows for injecting fake or real clocks into code that needs to do arbitrary things based +// on time. We only support a very limited interface at the moment, with only the methods required +// by CollectingProcess. +type clock interface { + Now() time.Time + AfterFunc(d time.Duration, f func()) timer +} + +// realClock implements the clock interface using functions from the time package. +type realClock struct{} + +func (realClock) Now() time.Time { + return time.Now() +} + +func (realClock) AfterFunc(d time.Duration, f func()) timer { + return time.AfterFunc(d, f) +} + +type fakeTimer struct { + targetTime time.Time + f func() + clock *fakeClock +} + +func (t *fakeTimer) Stop() bool { + clock := t.clock + clock.m.Lock() + defer clock.m.Unlock() + newTimers := make([]*fakeTimer, 0, len(clock.timers)) + fired := true + for i := range clock.timers { + if clock.timers[i] != t { + newTimers = append(newTimers, t) + continue + } + // timer is found so it hasn't been fired yet + fired = false + } + clock.timers = newTimers + return !fired +} + +func (t *fakeTimer) Reset(d time.Duration) bool { + clock := t.clock + clock.m.Lock() + defer clock.m.Unlock() + fired := true + for i := range clock.timers { + if clock.timers[i] != t { + continue + } + // timer is found so it hasn't been fired yet + fired = false + t.targetTime = clock.now.Add(d) + } + return !fired +} + +// fakeClock implements the clock interface as a virtual clock meant to be used in tests. Time can +// be advanced arbitrarily, but does not change on its own. +type fakeClock struct { + m sync.RWMutex + isAdvancing bool + now time.Time + timers []*fakeTimer +} + +func newFakeClock(t time.Time) *fakeClock { + return &fakeClock{ + now: t, + } +} + +func (c *fakeClock) Now() time.Time { + c.m.RLock() + defer c.m.RUnlock() + return c.now +} + +func (c *fakeClock) AfterFunc(d time.Duration, f func()) timer { + if d <= 0 { + panic("negative duration not supported") + } + c.m.Lock() + defer c.m.Unlock() + t := &fakeTimer{ + targetTime: c.now.Add(d), + f: f, + clock: c, + } + c.timers = append(c.timers, t) + return t +} + +func (c *fakeClock) Step(d time.Duration) { + if d < 0 { + panic("invalid duration") + } + timerFuncs := []func(){} + func() { + c.m.Lock() + defer c.m.Unlock() + if c.isAdvancing { + panic("concurrent calls to Step() not allowed") + } + c.isAdvancing = true + c.now = c.now.Add(d) + // Collect timer functions to run and remove them from list. + newTimers := make([]*fakeTimer, 0, len(c.timers)) + for _, t := range c.timers { + if !t.targetTime.After(c.now) { + timerFuncs = append(timerFuncs, t.f) + } else { + newTimers = append(newTimers, t) + } + } + c.timers = newTimers + }() + // Run the timer functions, without holding a lock. This allows these functions to call + // clock.Now(), but also timer.Stop(). + for _, f := range timerFuncs { + f() + } + c.m.Lock() + defer c.m.Unlock() + c.isAdvancing = false +} diff --git a/pkg/collector/process.go b/pkg/collector/process.go index 48fe208..0b2fdab 100644 --- a/pkg/collector/process.go +++ b/pkg/collector/process.go @@ -50,9 +50,15 @@ const ( DecodingModeLenientDropUnknown DecodingMode = "LenientDropUnknown" ) +type template struct { + ies []*entities.InfoElement + expiryTime time.Time + expiryTimer timer +} + type CollectingProcess struct { // for each obsDomainID, there is a map of templates - templatesMap map[uint32]map[uint16][]*entities.InfoElement + templatesMap map[uint32]map[uint16]*template // mutex allows multiple readers or one writer at the same time mutex sync.RWMutex // template lifetime @@ -85,6 +91,8 @@ type CollectingProcess struct { serverKey []byte wg sync.WaitGroup numOfRecordsReceived uint64 + // clock implementation: enables injecting a fake clock for testing + clock clock } type CollectorInput struct { @@ -113,7 +121,7 @@ type clientHandler struct { closeClientChan chan struct{} } -func InitCollectingProcess(input CollectorInput) (*CollectingProcess, error) { +func initCollectingProcess(input CollectorInput, clock clock) (*CollectingProcess, error) { templateTTLSeconds := input.TemplateTTL if input.Protocol == "udp" && templateTTLSeconds == 0 { templateTTLSeconds = entities.TemplateTTL @@ -128,8 +136,8 @@ func InitCollectingProcess(input CollectorInput) (*CollectingProcess, error) { "encrypted", input.IsEncrypted, "address", input.Address, "protocol", input.Protocol, "maxBufferSize", input.MaxBufferSize, "templateTTL", templateTTL, "numExtraElements", input.NumExtraElements, "decodingMode", decodingMode, ) - collectProc := &CollectingProcess{ - templatesMap: make(map[uint32]map[uint16][]*entities.InfoElement), + cp := &CollectingProcess{ + templatesMap: make(map[uint32]map[uint16]*template), mutex: sync.RWMutex{}, templateTTL: templateTTL, address: input.Address, @@ -144,8 +152,13 @@ func InitCollectingProcess(input CollectorInput) (*CollectingProcess, error) { serverKey: input.ServerKey, numExtraElements: input.NumExtraElements, decodingMode: decodingMode, + clock: clock, } - return collectProc, nil + return cp, nil +} + +func InitCollectingProcess(input CollectorInput) (*CollectingProcess, error) { + return initCollectingProcess(input, realClock{}) } func (cp *CollectingProcess) Start() { @@ -321,7 +334,7 @@ func (cp *CollectingProcess) decodeTemplateSet(templateBuffer *bytes.Buffer, obs func (cp *CollectingProcess) decodeDataSet(dataBuffer *bytes.Buffer, obsDomainID uint32, templateID uint16) (entities.Set, error) { // make sure template exists - template, err := cp.getTemplate(obsDomainID, templateID) + template, err := cp.getTemplateIEs(obsDomainID, templateID) if err != nil { return nil, fmt.Errorf("template %d with obsDomainID %d does not exist", templateID, obsDomainID) } @@ -361,47 +374,95 @@ func (cp *CollectingProcess) decodeDataSet(dataBuffer *bytes.Buffer, obsDomainID func (cp *CollectingProcess) addTemplate(obsDomainID uint32, templateID uint16, elementsWithValue []entities.InfoElementWithValue) { cp.mutex.Lock() defer cp.mutex.Unlock() - if _, exists := cp.templatesMap[obsDomainID]; !exists { - cp.templatesMap[obsDomainID] = make(map[uint16][]*entities.InfoElement) + if _, ok := cp.templatesMap[obsDomainID]; !ok { + cp.templatesMap[obsDomainID] = make(map[uint16]*template) } elements := make([]*entities.InfoElement, 0) for _, elementWithValue := range elementsWithValue { elements = append(elements, elementWithValue.GetInfoElement()) } - cp.templatesMap[obsDomainID][templateID] = elements - // template lifetime management + tpl, ok := cp.templatesMap[obsDomainID][templateID] + if !ok { + tpl = &template{} + cp.templatesMap[obsDomainID][templateID] = tpl + } + tpl.ies = elements + klog.V(4).InfoS("Added template to template map", "obsDomainID", obsDomainID, "templateID", templateID) + // Template lifetime management for UDP. if cp.protocol != "udp" { return } - // Handle udp template expiration - go func() { - ticker := time.NewTicker(cp.templateTTL) - defer ticker.Stop() - select { - case <-ticker.C: + tpl.expiryTime = cp.clock.Now().Add(cp.templateTTL) + if tpl.expiryTimer == nil { + tpl.expiryTimer = cp.clock.AfterFunc(cp.templateTTL, func() { klog.Infof("Template with id %d, and obsDomainID %d is expired.", templateID, obsDomainID) - cp.deleteTemplate(obsDomainID, templateID) - break + now := cp.clock.Now() + // From the Go documentation: + // For a func-based timer created with AfterFunc(d, f), Reset either + // reschedules when f will run, in which case Reset returns true, or + // schedules f to run again, in which case it returns false. When Reset + // returns false, Reset neither waits for the prior f to complete before + // returning nor does it guarantee that the subsequent goroutine running f + // does not run concurrently with the prior one. If the caller needs to + // know whether the prior execution of f is completed, it must coordinate + // with f explicitly. + // In our case, when f executes, we have to verify that the record is indeed + // scheduled for deletion by checking expiryTime. We cannot just + // automatically delete the template. + cp.deleteTemplateWithConds(obsDomainID, templateID, func(tpl *template) bool { + // lock will be held when this executes + return !tpl.expiryTime.After(now) + }) + }) + } else { + tpl.expiryTimer.Reset(cp.templateTTL) + } +} + +// deleteTemplate returns true iff a template was actually deleted. +func (cp *CollectingProcess) deleteTemplate(obsDomainID uint32, templateID uint16) bool { + return cp.deleteTemplateWithConds(obsDomainID, templateID) +} + +// deleteTemplateWithConds returns true iff a template was actually deleted. +func (cp *CollectingProcess) deleteTemplateWithConds(obsDomainID uint32, templateID uint16, condFns ...func(*template) bool) bool { + cp.mutex.Lock() + defer cp.mutex.Unlock() + template, ok := cp.templatesMap[obsDomainID][templateID] + if !ok { + return false + } + for _, condFn := range condFns { + if !condFn(template) { + return false } - }() + } + // expiryTimer will be nil when the protocol is UDP. + if template.expiryTimer != nil { + // expiryTimer may have been stopped already (if the timer + // expired and is the reason why the template is being deleted), + // but it is safe to call Stop() on an expired timer. + template.expiryTimer.Stop() + } + delete(cp.templatesMap[obsDomainID], templateID) + klog.V(4).InfoS("Deleted template from template map", "obsDomainID", obsDomainID, "templateID", templateID) + if len(cp.templatesMap[obsDomainID]) == 0 { + delete(cp.templatesMap, obsDomainID) + klog.V(4).InfoS("No more templates for observation domain", "obsDomainID", obsDomainID) + } + return true } -func (cp *CollectingProcess) getTemplate(obsDomainID uint32, templateID uint16) ([]*entities.InfoElement, error) { +func (cp *CollectingProcess) getTemplateIEs(obsDomainID uint32, templateID uint16) ([]*entities.InfoElement, error) { cp.mutex.RLock() defer cp.mutex.RUnlock() - if elements, exists := cp.templatesMap[obsDomainID][templateID]; exists { - return elements, nil + if template, ok := cp.templatesMap[obsDomainID][templateID]; ok { + return template.ies, nil } else { return nil, fmt.Errorf("template %d with obsDomainID %d does not exist", templateID, obsDomainID) } } -func (cp *CollectingProcess) deleteTemplate(obsDomainID uint32, templateID uint16) { - cp.mutex.Lock() - defer cp.mutex.Unlock() - delete(cp.templatesMap[obsDomainID], templateID) -} - func (cp *CollectingProcess) updateAddress(address net.Addr) { cp.mutex.Lock() defer cp.mutex.Unlock() diff --git a/pkg/collector/process_test.go b/pkg/collector/process_test.go index c7ba19e..360cf31 100644 --- a/pkg/collector/process_test.go +++ b/pkg/collector/process_test.go @@ -80,7 +80,7 @@ func TestTCPCollectingProcess_ReceiveTemplateRecord(t *testing.T) { }() <-cp.GetMsgChan() cp.Stop() - template, _ := cp.getTemplate(1, 256) + template, _ := cp.getTemplateIEs(1, 256) assert.NotNil(t, template, "TCP Collecting Process should receive and store the received template.") assert.Equal(t, int64(1), cp.GetNumRecordsReceived()) } @@ -110,7 +110,7 @@ func TestUDPCollectingProcess_ReceiveTemplateRecord(t *testing.T) { }() <-cp.GetMsgChan() cp.Stop() - template, _ := cp.getTemplate(1, 256) + template, _ := cp.getTemplateIEs(1, 256) assert.NotNil(t, template, "UDP Collecting Process should receive and store the received template.") assert.Equal(t, int64(1), cp.GetNumRecordsReceived()) } @@ -350,7 +350,7 @@ func TestUDPCollectingProcess_DecodePacketError(t *testing.T) { func TestCollectingProcess_DecodeTemplateRecord(t *testing.T) { cp := CollectingProcess{} - cp.templatesMap = make(map[uint32]map[uint16][]*entities.InfoElement) + cp.templatesMap = make(map[uint32]map[uint16]*template) cp.mutex = sync.RWMutex{} address, err := net.ResolveTCPAddr(tcpTransport, hostPortIPv4) if err != nil { @@ -381,7 +381,7 @@ func TestCollectingProcess_DecodeTemplateRecord(t *testing.T) { assert.NotNil(t, err, "Error should be logged for invalid version") // Malformed record templateRecord = []byte{0, 10, 0, 40, 95, 40, 211, 236, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 24, 1, 0, 0, 3, 0, 8, 0, 4, 0, 12, 0, 4, 128, 105, 255, 255, 0, 0} - cp.templatesMap = make(map[uint32]map[uint16][]*entities.InfoElement) + cp.templatesMap = make(map[uint32]map[uint16]*template) _, err = cp.decodePacket(bytes.NewBuffer(templateRecord), address.String()) assert.NotNil(t, err, "Error should be logged for malformed template record") if _, exist := cp.templatesMap[uint32(1)]; exist { @@ -391,7 +391,7 @@ func TestCollectingProcess_DecodeTemplateRecord(t *testing.T) { func TestCollectingProcess_DecodeDataRecord(t *testing.T) { cp := CollectingProcess{} - cp.templatesMap = make(map[uint32]map[uint16][]*entities.InfoElement) + cp.templatesMap = make(map[uint32]map[uint16]*template) cp.mutex = sync.RWMutex{} address, err := net.ResolveTCPAddr(tcpTransport, hostPortIPv4) if err != nil { @@ -426,19 +426,15 @@ func TestCollectingProcess_DecodeDataRecord(t *testing.T) { } func TestUDPCollectingProcess_TemplateExpire(t *testing.T) { + clock := newFakeClock(time.Now()) input := CollectorInput{ Address: hostPortIPv4, Protocol: udpTransport, MaxBufferSize: 1024, TemplateTTL: 1, - IsEncrypted: false, - ServerCert: nil, - ServerKey: nil, - } - cp, err := InitCollectingProcess(input) - if err != nil { - t.Fatalf("UDP Collecting Process does not start correctly: %v", err) } + cp, err := initCollectingProcess(input, clock) + require.NoError(t, err) go cp.Start() // wait until collector is ready waitForCollectorReady(t, cp) @@ -460,13 +456,109 @@ func TestUDPCollectingProcess_TemplateExpire(t *testing.T) { }() <-cp.GetMsgChan() cp.Stop() - template, err := cp.getTemplate(1, 256) + template, err := cp.getTemplateIEs(1, 256) assert.NotNil(t, template, "Template should be stored in the template map.") assert.Nil(t, err, "Template should be stored in the template map.") - time.Sleep(2 * time.Second) - template, err = cp.getTemplate(1, 256) - assert.Nil(t, template, "Template should be deleted after 5 seconds.") - assert.NotNil(t, err, "Template should be deleted after 5 seconds.") + clock.Step(time.Duration(input.TemplateTTL) * time.Second) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + _, err := cp.getTemplateIEs(1, 256) + assert.ErrorContains(t, err, "does not exist", "template should be deleted after timeout") + }, 1*time.Second, 50*time.Millisecond) +} + +func TestUDPCollectingProcess_TemplateAddAndDelete(t *testing.T) { + const ( + templateID = 100 + obsDomainID = 0xabcd + ) + clock := newFakeClock(time.Now()) + input := CollectorInput{ + Address: hostPortIPv4, + Protocol: udpTransport, + MaxBufferSize: 1024, + TemplateTTL: 1, + } + cp, err := initCollectingProcess(input, clock) + require.NoError(t, err) + cp.addTemplate(obsDomainID, templateID, elementsWithValueIPv4) + // Get a copy of the stored template + tpl := func() template { + cp.mutex.RLock() + defer cp.mutex.RUnlock() + return *cp.templatesMap[obsDomainID][templateID] + }() + require.NotNil(t, tpl.expiryTimer) + require.True(t, cp.deleteTemplate(obsDomainID, templateID)) + // Stop returns false if the timer has already been stopped, which + // should be done by the call to deleteTemplate. + assert.False(t, tpl.expiryTimer.Stop()) + // Deleting the template a second time should return false + assert.False(t, cp.deleteTemplate(obsDomainID, templateID)) +} + +// TestUDPCollectingProcess_TemplateUpdate checks the behavior of addTemplate +// when a template is refreshed. +func TestUDPCollectingProcess_TemplateUpdate(t *testing.T) { + const ( + templateID = 100 + obsDomainID = 0xabcd + ) + now := time.Now() + clock := newFakeClock(now) + input := CollectorInput{ + Address: hostPortIPv4, + Protocol: udpTransport, + MaxBufferSize: 1024, + TemplateTTL: 1, + } + cp, err := initCollectingProcess(input, clock) + require.NoError(t, err) + cp.addTemplate(obsDomainID, templateID, elementsWithValueIPv4) + // Get a copy of the stored template + getTemplate := func() template { + cp.mutex.RLock() + defer cp.mutex.RUnlock() + return *cp.templatesMap[obsDomainID][templateID] + } + tpl := getTemplate() + require.NotNil(t, tpl.expiryTimer) + assert.Equal(t, now.Add(time.Duration(input.TemplateTTL)*time.Second), tpl.expiryTime) + // Advance the clock by half the TTL + clock.Step(500 * time.Millisecond) + // Template should still be present in map + _, err = cp.getTemplateIEs(obsDomainID, templateID) + require.NoError(t, err) + // "Update" the template (template is being refreshed) + cp.addTemplate(obsDomainID, templateID, elementsWithValueIPv4) + tpl = getTemplate() + assert.Equal(t, clock.Now().Add(time.Duration(input.TemplateTTL)*time.Second), tpl.expiryTime) + // Advance the clock by half the TTL again, template should still be present + clock.Step(500 * time.Millisecond) + _, err = cp.getTemplateIEs(obsDomainID, templateID) + require.NoError(t, err) + // Advance the clock by half the TTL again, template should be expired + clock.Step(500 * time.Millisecond) + _, err = cp.getTemplateIEs(obsDomainID, templateID) + assert.Error(t, err) +} + +func BenchmarkAddTemplateUDP(b *testing.B) { + input := CollectorInput{ + Address: hostPortIPv4, + Protocol: udpTransport, + MaxBufferSize: 1024, + IsEncrypted: false, + ServerCert: nil, + ServerKey: nil, + } + cp, err := initCollectingProcess(input, newFakeClock(time.Now())) + require.NoError(b, err) + obsDomainID := uint32(1) + b.ResetTimer() + for range b.N { + cp.addTemplate(obsDomainID, 256, elementsWithValueIPv4) + obsDomainID = (obsDomainID + 1) % 1000 + } } func TestTLSCollectingProcess(t *testing.T) { @@ -573,7 +665,7 @@ func TestTCPCollectingProcessIPv6(t *testing.T) { <-cp.GetMsgChan() message := <-cp.GetMsgChan() cp.Stop() - template, _ := cp.getTemplate(1, 256) + template, _ := cp.getTemplateIEs(1, 256) assert.NotNil(t, template) ie, _, exist := message.GetSet().GetRecords()[0].GetInfoElementWithValue("sourceIPv6Address") assert.True(t, exist) @@ -602,7 +694,7 @@ func TestUDPCollectingProcessIPv6(t *testing.T) { <-cp.GetMsgChan() message := <-cp.GetMsgChan() cp.Stop() - template, _ := cp.getTemplate(1, 256) + template, _ := cp.getTemplateIEs(1, 256) assert.NotNil(t, template) ie, _, exist := message.GetSet().GetRecords()[0].GetInfoElementWithValue("sourceIPv6Address") assert.True(t, exist)