diff --git a/controllers/gce/networkendpointgroup/fakes.go b/controllers/gce/networkendpointgroup/fakes.go
new file mode 100644
index 000000000..df9735e35
--- /dev/null
+++ b/controllers/gce/networkendpointgroup/fakes.go
@@ -0,0 +1,194 @@
+/*
+Copyright 2017 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package networkendpointgroup
+
+import (
+	"fmt"
+	"reflect"
+	"sync"
+
+	computealpha "google.golang.org/api/compute/v0.alpha"
+	"k8s.io/apimachinery/pkg/util/sets"
+)
+
+const (
+	TestZone1     = "zone1"
+	TestZone2     = "zone2"
+	TestInstance1 = "instance1"
+	TestInstance2 = "instance2"
+	TestInstance3 = "instance3"
+	TestInstance4 = "instance4"
+)
+
+// fakeZoneGetter implements ZoneGetter backed by a static zone-to-instances map.
+type fakeZoneGetter struct {
+	zoneInstanceMap map[string]sets.String
+}
+
+func NewFakeZoneGetter() *fakeZoneGetter {
+	return &fakeZoneGetter{
+		zoneInstanceMap: map[string]sets.String{
+			TestZone1: sets.NewString(TestInstance1, TestInstance2),
+			TestZone2: sets.NewString(TestInstance3, TestInstance4),
+		},
+	}
+}
+
+func (f *fakeZoneGetter) ListZones() ([]string, error) {
+	ret := []string{}
+	for key := range f.zoneInstanceMap {
+		ret = append(ret, key)
+	}
+	return ret, nil
+}
+
+func (f *fakeZoneGetter) GetZoneForNode(name string) (string, error) {
+	for zone, instances := range f.zoneInstanceMap {
+		if instances.Has(name) {
+			return zone, nil
+		}
+	}
+	return "", NotFoundError
+}
+
+// FakeNetworkEndpointGroupCloud is an in-memory implementation of
+// NetworkEndpointGroupCloud for tests.
+type FakeNetworkEndpointGroupCloud struct {
+	NetworkEndpointGroups map[string][]*computealpha.NetworkEndpointGroup
+	NetworkEndpoints      map[string][]*computealpha.NetworkEndpoint
+	Subnetwork            string
+	Network               string
+	mu                    sync.Mutex
+}
+
+func NewFakeNetworkEndpointGroupCloud(subnetwork, network string) NetworkEndpointGroupCloud {
+	return &FakeNetworkEndpointGroupCloud{
+		Subnetwork:            subnetwork,
+		Network:               network,
+		NetworkEndpointGroups: map[string][]*computealpha.NetworkEndpointGroup{},
+		NetworkEndpoints:      map[string][]*computealpha.NetworkEndpoint{},
+	}
+}
+
+// NotFoundError is returned by the fakes when a requested resource does not exist.
+var NotFoundError = fmt.Errorf("not found")
+
+func (cloud *FakeNetworkEndpointGroupCloud) GetNetworkEndpointGroup(name string, zone string) (*computealpha.NetworkEndpointGroup, error) {
+	cloud.mu.Lock()
+	defer cloud.mu.Unlock()
+	negs, ok := cloud.NetworkEndpointGroups[zone]
+	if ok {
+		for _, neg := range negs {
+			if neg.Name == name {
+				return neg, nil
+			}
+		}
+	}
+	return nil, NotFoundError
+}
+
+func networkEndpointKey(name, zone string) string {
+	return fmt.Sprintf("%s-%s", zone, name)
+}
+
+func (cloud *FakeNetworkEndpointGroupCloud) ListNetworkEndpointGroup(zone string) ([]*computealpha.NetworkEndpointGroup, error) {
+	cloud.mu.Lock()
+	defer cloud.mu.Unlock()
+	return cloud.NetworkEndpointGroups[zone], nil
+}
+
+func (cloud *FakeNetworkEndpointGroupCloud) AggregatedListNetworkEndpointGroup() (map[string][]*computealpha.NetworkEndpointGroup, error) {
+	cloud.mu.Lock()
+	defer cloud.mu.Unlock()
+	return cloud.NetworkEndpointGroups, nil
+}
+
+func (cloud *FakeNetworkEndpointGroupCloud) CreateNetworkEndpointGroup(neg *computealpha.NetworkEndpointGroup, zone string) error {
+	cloud.mu.Lock()
+ defer cloud.mu.Unlock() + if _, ok := cloud.NetworkEndpointGroups[zone]; !ok { + cloud.NetworkEndpointGroups[zone] = []*computealpha.NetworkEndpointGroup{} + } + cloud.NetworkEndpointGroups[zone] = append(cloud.NetworkEndpointGroups[zone], neg) + cloud.NetworkEndpoints[networkEndpointKey(neg.Name, zone)] = []*computealpha.NetworkEndpoint{} + return nil +} + +func (cloud *FakeNetworkEndpointGroupCloud) DeleteNetworkEndpointGroup(name string, zone string) error { + cloud.mu.Lock() + defer cloud.mu.Unlock() + delete(cloud.NetworkEndpoints, networkEndpointKey(name, zone)) + negs := cloud.NetworkEndpointGroups[zone] + newList := []*computealpha.NetworkEndpointGroup{} + found := false + for _, neg := range negs { + if neg.Name == name { + found = true + continue + } + newList = append(newList, neg) + } + if !found { + return NotFoundError + } + cloud.NetworkEndpointGroups[zone] = newList + return nil +} + +func (cloud *FakeNetworkEndpointGroupCloud) AttachNetworkEndpoints(name, zone string, endpoints []*computealpha.NetworkEndpoint) error { + cloud.mu.Lock() + defer cloud.mu.Unlock() + cloud.NetworkEndpoints[networkEndpointKey(name, zone)] = append(cloud.NetworkEndpoints[networkEndpointKey(name, zone)], endpoints...) + return nil +} + +func (cloud *FakeNetworkEndpointGroupCloud) DetachNetworkEndpoints(name, zone string, endpoints []*computealpha.NetworkEndpoint) error { + cloud.mu.Lock() + defer cloud.mu.Unlock() + newList := []*computealpha.NetworkEndpoint{} + for _, ne := range cloud.NetworkEndpoints[networkEndpointKey(name, zone)] { + found := false + for _, remove := range endpoints { + if reflect.DeepEqual(*ne, *remove) { + found = true + break + } + } + if found { + continue + } + newList = append(newList, ne) + } + cloud.NetworkEndpoints[networkEndpointKey(name, zone)] = newList + return nil +} + +func (cloud *FakeNetworkEndpointGroupCloud) ListNetworkEndpoints(name, zone string, showHealthStatus bool) ([]*computealpha.NetworkEndpointWithHealthStatus, error) { + cloud.mu.Lock() + defer cloud.mu.Unlock() + ret := []*computealpha.NetworkEndpointWithHealthStatus{} + nes, ok := cloud.NetworkEndpoints[networkEndpointKey(name, zone)] + if !ok { + return nil, NotFoundError + } + for _, ne := range nes { + ret = append(ret, &computealpha.NetworkEndpointWithHealthStatus{NetworkEndpoint: ne}) + } + return ret, nil +} + +func (cloud *FakeNetworkEndpointGroupCloud) NetworkURL() string { + return cloud.Network +} + +func (cloud *FakeNetworkEndpointGroupCloud) SubnetworkURL() string { + return cloud.Subnetwork +} diff --git a/controllers/gce/networkendpointgroup/interfaces.go b/controllers/gce/networkendpointgroup/interfaces.go new file mode 100644 index 000000000..4e7334697 --- /dev/null +++ b/controllers/gce/networkendpointgroup/interfaces.go @@ -0,0 +1,66 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package networkendpointgroup
+
+import (
+	computealpha "google.golang.org/api/compute/v0.alpha"
+	"k8s.io/apimachinery/pkg/util/sets"
+)
+
+// NetworkEndpointGroupCloud is an interface for managing GCE network endpoint groups.
+type NetworkEndpointGroupCloud interface {
+	GetNetworkEndpointGroup(name string, zone string) (*computealpha.NetworkEndpointGroup, error)
+	ListNetworkEndpointGroup(zone string) ([]*computealpha.NetworkEndpointGroup, error)
+	AggregatedListNetworkEndpointGroup() (map[string][]*computealpha.NetworkEndpointGroup, error)
+	CreateNetworkEndpointGroup(neg *computealpha.NetworkEndpointGroup, zone string) error
+	DeleteNetworkEndpointGroup(name string, zone string) error
+	AttachNetworkEndpoints(name, zone string, endpoints []*computealpha.NetworkEndpoint) error
+	DetachNetworkEndpoints(name, zone string, endpoints []*computealpha.NetworkEndpoint) error
+	ListNetworkEndpoints(name, zone string, showHealthStatus bool) ([]*computealpha.NetworkEndpointWithHealthStatus, error)
+	NetworkURL() string
+	SubnetworkURL() string
+}
+
+// NetworkEndpointGroupNamer is an interface for generating network endpoint group names.
+type NetworkEndpointGroupNamer interface {
+	NEGName(namespace, name, port string) string
+	NEGPrefix() string
+}
+
+// ZoneGetter is an interface for retrieving zone-related information.
+type ZoneGetter interface {
+	ListZones() ([]string, error)
+	GetZoneForNode(name string) (string, error)
+}
+
+// Syncer is an interface for a NEG syncer that keeps the NEGs for one service port in sync.
+type Syncer interface {
+	Start() error
+	Stop()
+	Sync() bool
+	IsStopped() bool
+	IsShuttingDown() bool
+}
+
+// SyncerManager is an interface for controllers to manage Syncers.
+type SyncerManager interface {
+	EnsureSyncer(namespace, name string, targetPorts sets.String) error
+	StopSyncer(namespace, name string)
+	Sync(namespace, name string)
+	GC() error
+	ShutDown()
+}
diff --git a/controllers/gce/networkendpointgroup/manager.go b/controllers/gce/networkendpointgroup/manager.go
new file mode 100644
index 000000000..9a3693405
--- /dev/null
+++ b/controllers/gce/networkendpointgroup/manager.go
@@ -0,0 +1,256 @@
+/*
+Copyright 2017 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package networkendpointgroup
+
+import (
+	"fmt"
+	"strings"
+	"sync"
+
+	"github.com/golang/glog"
+	utilerrors "k8s.io/apimachinery/pkg/util/errors"
+	"k8s.io/apimachinery/pkg/util/sets"
+	"k8s.io/client-go/tools/cache"
+	"k8s.io/client-go/tools/record"
+)
+
+// syncerManager manages the lifecycle of Syncers and ensures thread safety.
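+// Access to svcPortMap and syncerMap is guarded by a single global mutex, and
+// syncer keys are produced by encodeSyncerKey ("namespace||name||port").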
+type syncerManager struct {
+	namer      NetworkEndpointGroupNamer
+	recorder   record.EventRecorder
+	cloud      NetworkEndpointGroupCloud
+	zoneGetter ZoneGetter
+
+	serviceLister  cache.Indexer
+	endpointLister cache.Indexer
+
+	// TODO: lock per service instead of global lock
+	mu sync.Mutex
+	// svcPortMap is the canonical indicator of whether a service needs NEGs.
+	// The key is the service namespace/name; the value is the set of target ports that require a NEG.
+	svcPortMap map[string]sets.String
+	// syncerMap stores the NEG syncers.
+	// The key is service namespace/name/targetPort; the value is the corresponding syncer.
+	syncerMap map[string]Syncer
+}
+
+func newSyncerManager(namer NetworkEndpointGroupNamer, recorder record.EventRecorder, cloud NetworkEndpointGroupCloud, zoneGetter ZoneGetter, serviceLister cache.Indexer, endpointLister cache.Indexer) *syncerManager {
+	return &syncerManager{
+		namer:          namer,
+		recorder:       recorder,
+		cloud:          cloud,
+		zoneGetter:     zoneGetter,
+		serviceLister:  serviceLister,
+		endpointLister: endpointLister,
+		svcPortMap:     make(map[string]sets.String),
+		syncerMap:      make(map[string]Syncer),
+	}
+}
+
+// EnsureSyncer starts and stops syncers based on the input service ports.
+func (manager *syncerManager) EnsureSyncer(namespace, name string, targetPorts sets.String) error {
+	manager.mu.Lock()
+	defer manager.mu.Unlock()
+	key := serviceKeyFunc(namespace, name)
+	currentPorts, ok := manager.svcPortMap[key]
+	if !ok {
+		currentPorts = sets.NewString()
+	}
+
+	removes := currentPorts.Difference(targetPorts).List()
+	adds := targetPorts.Difference(currentPorts).List()
+	manager.svcPortMap[key] = targetPorts
+
+	// Stop syncers for removed ports
+	for _, port := range removes {
+		syncer, ok := manager.syncerMap[encodeSyncerKey(namespace, name, port)]
+		if ok {
+			syncer.Stop()
+		}
+	}
+
+	errList := []error{}
+	// Start syncers for added ports
+	for _, port := range adds {
+		syncer, ok := manager.syncerMap[encodeSyncerKey(namespace, name, port)]
+		if !ok {
+			syncer = newSyncer(
+				servicePort{
+					namespace:  namespace,
+					name:       name,
+					targetPort: port,
+				},
+				manager.namer.NEGName(namespace, name, port),
+				manager.recorder,
+				manager.cloud,
+				manager.zoneGetter,
+				manager.serviceLister,
+				manager.endpointLister,
+			)
+			manager.syncerMap[encodeSyncerKey(namespace, name, port)] = syncer
+		}
+
+		if syncer.IsStopped() {
+			if err := syncer.Start(); err != nil {
+				errList = append(errList, err)
+			}
+		}
+	}
+	return utilerrors.NewAggregate(errList)
+}
+
+// StopSyncer stops all syncers for the input service.
+func (manager *syncerManager) StopSyncer(namespace, name string) {
+	manager.mu.Lock()
+	defer manager.mu.Unlock()
+	key := serviceKeyFunc(namespace, name)
+	if ports, ok := manager.svcPortMap[key]; ok {
+		glog.V(2).Infof("Stopping NEG syncer for service %q", key)
+		for _, port := range ports.List() {
+			syncer, ok := manager.syncerMap[encodeSyncerKey(namespace, name, port)]
+			if ok {
+				syncer.Stop()
+			}
+		}
+		delete(manager.svcPortMap, key)
+	}
+}
+
+// Sync signals all syncers related to the service to sync.
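+// Ports without a running syncer are skipped.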
+func (manager *syncerManager) Sync(namespace, name string) {
+	manager.mu.Lock()
+	defer manager.mu.Unlock()
+	key := serviceKeyFunc(namespace, name)
+	if portList, ok := manager.svcPortMap[key]; ok {
+		for _, port := range portList.List() {
+			if syncer, ok := manager.syncerMap[encodeSyncerKey(namespace, name, port)]; ok {
+				if !syncer.IsStopped() {
+					syncer.Sync()
+				}
+			}
+		}
+	}
+}
+
+// ShutDown signals all syncers to stop.
+func (manager *syncerManager) ShutDown() {
+	manager.mu.Lock()
+	defer manager.mu.Unlock()
+	for _, s := range manager.syncerMap {
+		s.Stop()
+	}
+}
+
+// GC garbage collects syncers and NEGs.
+func (manager *syncerManager) GC() error {
+	glog.V(2).Infof("Start NEG garbage collection.")
+	defer glog.V(2).Infof("NEG garbage collection finished.")
+	// Garbage collect syncers
+	for _, key := range manager.getAllStoppedSyncerKeys().List() {
+		manager.garbageCollectSyncer(key)
+	}
+
+	// Garbage collect NEGs
+	if err := manager.garbageCollectNEG(); err != nil {
+		return fmt.Errorf("failed to garbage collect NEGs: %v", err)
+	}
+	return nil
+}
+
+func (manager *syncerManager) garbageCollectSyncer(key string) {
+	manager.mu.Lock()
+	defer manager.mu.Unlock()
+	if syncer, ok := manager.syncerMap[key]; ok && syncer.IsStopped() && !syncer.IsShuttingDown() {
+		delete(manager.syncerMap, key)
+	}
+}
+
+func (manager *syncerManager) getAllStoppedSyncerKeys() sets.String {
+	manager.mu.Lock()
+	defer manager.mu.Unlock()
+	ret := sets.NewString()
+	for key, syncer := range manager.syncerMap {
+		if syncer.IsStopped() {
+			ret.Insert(key)
+		}
+	}
+	return ret
+}
+
+func (manager *syncerManager) garbageCollectNEG() error {
+	// Retrieve the aggregated NEG list from the cloud, compare it against svcPortMap,
+	// and remove unintended NEGs on a best-effort basis.
+	zoneNEGList, err := manager.cloud.AggregatedListNetworkEndpointGroup()
+	if err != nil {
+		return fmt.Errorf("failed to retrieve aggregated NEG list: %v", err)
+	}
+
+	negNames := sets.String{}
+	for _, list := range zoneNEGList {
+		for _, neg := range list {
+			if strings.HasPrefix(neg.Name, manager.namer.NEGPrefix()) {
+				negNames.Insert(neg.Name)
+			}
+		}
+	}
+
+	func() {
+		manager.mu.Lock()
+		defer manager.mu.Unlock()
+		for key, ports := range manager.svcPortMap {
+			namespace, name, err := cache.SplitMetaNamespaceKey(key)
+			if err != nil {
+				glog.Errorf("Failed to parse service key %q: %v", key, err)
+				continue
+			}
+			for _, port := range ports.List() {
+				negName := manager.namer.NEGName(namespace, name, port)
+				negNames.Delete(negName)
+			}
+		}
+	}()
+
+	// There is a potential race here between deleting a NEG and a user adding the NEG
+	// annotation back to the service. The worst outcome is that the NEG is deleted even
+	// though the user wants one; this is corrected (the NEG is re-synced) when the next
+	// endpoint update or resync arrives.
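+	// Everything still in negNames is assumed orphaned and is removed from every
+	// zone; ensureDeleteNetworkEndpointGroup tolerates NEGs that do not exist.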
+	// TODO: avoid race condition here
+	for zone := range zoneNEGList {
+		for _, name := range negNames.List() {
+			if err := manager.ensureDeleteNetworkEndpointGroup(name, zone); err != nil {
+				return fmt.Errorf("failed to delete NEG %q in %q: %v", name, zone, err)
+			}
+		}
+	}
+	return nil
+}
+
+// ensureDeleteNetworkEndpointGroup ensures that the NEG is deleted from the given zone.
+func (manager *syncerManager) ensureDeleteNetworkEndpointGroup(name, zone string) error {
+	_, err := manager.cloud.GetNetworkEndpointGroup(name, zone)
+	if err != nil {
+		// Assume the error means the NEG does not exist.
+		return nil
+	}
+	glog.V(2).Infof("Deleting NEG %q in %q.", name, zone)
+	return manager.cloud.DeleteNetworkEndpointGroup(name, zone)
+}
+
+// encodeSyncerKey encodes a service namespace, name and targetPort into a string key.
+func encodeSyncerKey(namespace, name, port string) string {
+	return fmt.Sprintf("%s||%s||%s", namespace, name, port)
+}
diff --git a/controllers/gce/networkendpointgroup/manager_test.go b/controllers/gce/networkendpointgroup/manager_test.go
new file mode 100644
index 000000000..cd9a0b39e
--- /dev/null
+++ b/controllers/gce/networkendpointgroup/manager_test.go
@@ -0,0 +1,187 @@
+/*
+Copyright 2017 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package networkendpointgroup
+
+import (
+	"testing"
+	"time"
+
+	compute "google.golang.org/api/compute/v0.alpha"
+	apiv1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/util/sets"
+	"k8s.io/apimachinery/pkg/util/wait"
+	"k8s.io/client-go/kubernetes"
+	"k8s.io/client-go/kubernetes/fake"
+	"k8s.io/client-go/tools/record"
+	"k8s.io/ingress/controllers/gce/utils"
+)
+
+const (
+	ClusterID = "clusterid"
+)
+
+func NewTestSyncerManager(kubeClient kubernetes.Interface) *syncerManager {
+	context := utils.NewControllerContext(kubeClient, apiv1.NamespaceAll, 1*time.Second, true)
+	manager := newSyncerManager(
+		utils.NewNamer(ClusterID, ""),
+		record.NewFakeRecorder(100),
+		NewFakeNetworkEndpointGroupCloud("test-subnetwork", "test-network"),
+		NewFakeZoneGetter(),
+		context.ServiceInformer.GetIndexer(),
+		context.EndpointInformer.GetIndexer(),
+	)
+	return manager
+}
+
+func TestEnsureAndStopSyncer(t *testing.T) {
+	testCases := []struct {
+		namespace string
+		name      string
+		ports     sets.String
+		stop      bool
+		expect    sets.String // keys of running syncers
+	}{
+		{
+			"ns1",
+			"n1",
+			sets.NewString("80", "443"),
+			false,
+			sets.NewString(
+				encodeSyncerKey("ns1", "n1", "80"),
+				encodeSyncerKey("ns1", "n1", "443"),
+			),
+		},
+		{
+			"ns1",
+			"n1",
+			sets.NewString("80", "namedport"),
+			false,
+			sets.NewString(
+				encodeSyncerKey("ns1", "n1", "80"),
+				encodeSyncerKey("ns1", "n1", "namedport"),
+			),
+		},
+		{
+			"ns2",
+			"n1",
+			sets.NewString("80"),
+			false,
+			sets.NewString(
+				encodeSyncerKey("ns1", "n1", "80"),
+				encodeSyncerKey("ns1", "n1", "namedport"),
+				encodeSyncerKey("ns2", "n1", "80"),
+			),
+		},
+		{
+			"ns1",
+			"n1",
+			sets.NewString(),
+			true,
+			sets.NewString(
+				encodeSyncerKey("ns2", "n1", "80"),
+			),
+		},
+	}
+
+	manager := NewTestSyncerManager(fake.NewSimpleClientset())
+	for _, tc := range testCases {
+		if tc.stop {
+			manager.StopSyncer(tc.namespace, tc.name)
+		} else {
+			if err := manager.EnsureSyncer(tc.namespace, tc.name, tc.ports); err != nil {
+				t.Errorf("Failed to ensure syncer %s/%s-%v: %v", tc.namespace, tc.name, tc.ports, err)
+			}
+		}
+
+		for _, key := range tc.expect.List() {
+			syncer, ok := manager.syncerMap[key]
+			if !ok {
+				t.Errorf("Expect syncer key %q to be present.", key)
+				continue
+			}
+			if syncer.IsStopped() || syncer.IsShuttingDown() {
+				t.Errorf("Expect syncer %q to be running.", key)
+			}
+		}
+		for key, syncer := range manager.syncerMap {
+			if tc.expect.Has(key) {
+				continue
+			}
+			if !syncer.IsStopped() {
+				t.Errorf("Expect syncer %q to be stopped.", key)
+			}
+		}
+	}
+
+	// make sure there are no leaked goroutines
+	manager.StopSyncer("ns1", "n1")
+	manager.StopSyncer("ns2", "n1")
+}
+
+func TestGarbageCollectionSyncer(t *testing.T) {
+	manager := NewTestSyncerManager(fake.NewSimpleClientset())
+	if err := manager.EnsureSyncer("ns1", "n1", sets.NewString("80", "namedport")); err != nil {
+		t.Fatalf("Failed to ensure syncer: %v", err)
+	}
+	manager.StopSyncer("ns1", "n1")
+
+	syncer1 := manager.syncerMap[encodeSyncerKey("ns1", "n1", "80")]
+	syncer2 := manager.syncerMap[encodeSyncerKey("ns1", "n1", "namedport")]
+
+	if err := wait.PollImmediate(time.Second, 30*time.Second, func() (bool, error) {
+		return !syncer1.IsShuttingDown() && syncer1.IsStopped() && !syncer2.IsShuttingDown() && syncer2.IsStopped(), nil
+	}); err != nil {
+		t.Fatalf("Syncer failed to shutdown: %v", err)
+	}
+
+	if err := manager.GC(); err != nil {
+		t.Fatalf("Failed to GC: %v", err)
+	}
+
+	if len(manager.syncerMap) != 0 {
+		t.Fatalf("Expect 0 syncers left, but got %v", len(manager.syncerMap))
+	}
+}
+
+func TestGarbageCollectionNEG(t *testing.T) {
+	kubeClient := fake.NewSimpleClientset()
+	if _, err := kubeClient.Core().Endpoints(ServiceNamespace).Create(getDefaultEndpoint()); err != nil {
+		t.Fatalf("Failed to create endpoint: %v", err)
+	}
+	manager := NewTestSyncerManager(kubeClient)
+	if err := manager.EnsureSyncer(ServiceNamespace, ServiceName, sets.NewString("80")); err != nil {
+		t.Fatalf("Failed to ensure syncer: %v", err)
+	}
+
+	negName := manager.namer.NEGName("test", "test", "80")
+	if err := manager.cloud.CreateNetworkEndpointGroup(&compute.NetworkEndpointGroup{
+		Name: negName,
+	}, TestZone1); err != nil {
+		t.Fatalf("Failed to create NEG: %v", err)
+	}
+
+	if err := manager.GC(); err != nil {
+		t.Fatalf("Failed to GC: %v", err)
+	}
+
+	negs, _ := manager.cloud.ListNetworkEndpointGroup(TestZone1)
+	for _, neg := range negs {
+		if neg.Name == negName {
+			t.Errorf("Expect NEG %q to be GCed.", negName)
+		}
+	}
+
+	// make sure there are no leaked goroutines
+	manager.StopSyncer(ServiceNamespace, ServiceName)
+}
diff --git a/controllers/gce/networkendpointgroup/syncer.go b/controllers/gce/networkendpointgroup/syncer.go
new file mode 100644
index 000000000..8932890ba
--- /dev/null
+++ b/controllers/gce/networkendpointgroup/syncer.go
@@ -0,0 +1,524 @@
+/*
+Copyright 2017 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package networkendpointgroup
+
+import (
+	"fmt"
+	"math"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/golang/glog"
+	compute "google.golang.org/api/compute/v0.alpha"
+	apiv1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/util/clock"
+	utilerrors "k8s.io/apimachinery/pkg/util/errors"
+	"k8s.io/apimachinery/pkg/util/sets"
+	"k8s.io/client-go/tools/cache"
+	"k8s.io/client-go/tools/record"
+	"k8s.io/kubernetes/pkg/cloudprovider/providers/gce"
+)
+
+const (
+	// maxNetworkEndpointsPerBatch caps how many endpoints are attached or detached per API call.
+	maxNetworkEndpointsPerBatch = 500
+	minRetryDelay               = 5 * time.Second
+	maxRetryDelay               = 300 * time.Second
+)
+
+// servicePort includes the information needed to uniquely identify a NEG.
+type servicePort struct {
+	namespace string
+	name      string
+	// Service target port
+	targetPort string
+}
+
+// syncer handles synchronizing NEGs for one service port. It handles sync, resync and retries on error.
+type syncer struct {
+	servicePort
+	negName string
+
+	serviceLister  cache.Indexer
+	endpointLister cache.Indexer
+
+	recorder   record.EventRecorder
+	cloud      NetworkEndpointGroupCloud
+	zoneGetter ZoneGetter
+
+	stateLock    sync.Mutex
+	stopped      bool
+	shuttingDown bool
+
+	clock          clock.Clock
+	syncCh         chan interface{}
+	lastRetryDelay time.Duration
+	retryCount     int
+}
+
+func newSyncer(svcPort servicePort, networkEndpointGroupName string, recorder record.EventRecorder, cloud NetworkEndpointGroupCloud, zoneGetter ZoneGetter, serviceLister cache.Indexer, endpointLister cache.Indexer) *syncer {
+	glog.V(2).Infof("New syncer for service %s/%s port %s NEG %q", svcPort.namespace, svcPort.name, svcPort.targetPort, networkEndpointGroupName)
+	return &syncer{
+		servicePort:    svcPort,
+		negName:        networkEndpointGroupName,
+		recorder:       recorder,
+		serviceLister:  serviceLister,
+		cloud:          cloud,
+		endpointLister: endpointLister,
+		zoneGetter:     zoneGetter,
+		stopped:        true,
+		shuttingDown:   false,
+		clock:          clock.RealClock{},
+		lastRetryDelay: time.Duration(0),
+		retryCount:     0,
+	}
+}
+
+func (s *syncer) init() {
+	s.stateLock.Lock()
+	defer s.stateLock.Unlock()
+	s.stopped = false
+	s.syncCh = make(chan interface{}, 1)
+}
+
+// Start starts the syncer goroutine if it has not been started.
+func (s *syncer) Start() error {
+	if !s.IsStopped() {
+		return fmt.Errorf("NEG syncer for %s/%s-%s is already running", s.namespace, s.name, s.targetPort)
+	}
+	if s.IsShuttingDown() {
+		return fmt.Errorf("NEG syncer for %s/%s-%s is shutting down", s.namespace, s.name, s.targetPort)
", s.namespace, s.name, s.targetPort) + } + + glog.V(2).Infof("Starting NEG syncer for service port %s/%s-%s", s.namespace, s.name, s.targetPort) + s.init() + go func() { + for { + // equivalent to never retry + retryCh := make(<-chan time.Time) + err := s.sync() + if err != nil { + retryMesg := "" + if s.retryCount > maxRetries { + retryMesg = "(will not retry)" + } else { + retryCh = s.clock.After(s.nextRetryDelay()) + retryMesg = "(will retry)" + } + + if svc := getService(s.serviceLister, s.namespace, s.name); svc != nil { + s.recorder.Eventf(svc, apiv1.EventTypeWarning, "SyncNetworkEndpiontGroupFailed", "Failed to sync NEG %q %s: %v", s.negName, retryMesg, err) + } + } else { + s.resetRetryDelay() + } + + select { + case _, open := <-s.syncCh: + if !open { + s.stateLock.Lock() + s.shuttingDown = false + s.stateLock.Unlock() + glog.V(2).Infof("Stopping NEG syncer for %s/%s-%s", s.namespace, s.name, s.targetPort) + return + } + case <-retryCh: + // continue to sync + } + } + }() + return nil +} + +// Stop stops syncer and return only when syncer shutdown completes. +func (s *syncer) Stop() { + s.stateLock.Lock() + defer s.stateLock.Unlock() + if !s.stopped { + s.stopped = true + s.shuttingDown = true + close(s.syncCh) + } +} + +// Sync informs syncer to run sync loop as soon as possible. +func (s *syncer) Sync() bool { + if s.IsStopped() { + glog.Warningf("NEG syncer for %s/%s-%s is already stopped.", s.namespace, s.name, s.targetPort) + return false + } + glog.V(2).Infof("=======Sync %s/%s-%s", s.namespace, s.name, s.targetPort) + select { + case s.syncCh <- struct{}{}: + return true + default: + return false + } +} + +func (s *syncer) IsStopped() bool { + s.stateLock.Lock() + defer s.stateLock.Unlock() + return s.stopped +} + +func (s *syncer) IsShuttingDown() bool { + s.stateLock.Lock() + defer s.stateLock.Unlock() + return s.shuttingDown +} + +func (s *syncer) sync() error { + if s.IsStopped() || s.IsShuttingDown() { + glog.V(4).Infof("Skip syncing NEG %q for %s/%s-%s.", s.negName, s.namespace, s.name, s.targetPort) + return nil + } + + glog.V(2).Infof("Sync NEG %q for %s/%s-%s", s.negName, s.namespace, s.name, s.targetPort) + ep, exists, err := s.endpointLister.Get( + &apiv1.Endpoints{ + ObjectMeta: metav1.ObjectMeta{ + Name: s.name, + Namespace: s.namespace, + }, + }, + ) + if err != nil { + return err + } + + if !exists { + glog.Warningf("Endpoint %s/%s does not exists. Skipping NEG sync") + return nil + } + + err = s.ensureNetworkEndpointGroups() + if err != nil { + return err + } + + targetMap, err := s.toZoneNetworkEndpointMap(ep.(*apiv1.Endpoints)) + if err != nil { + return err + } + + currentMap, err := s.retrieveExistingZoneNetworkEndpointMap() + if err != nil { + return err + } + + addEndpoints, removeEndpoints := calculateDifference(targetMap, currentMap) + if len(addEndpoints) == 0 && len(removeEndpoints) == 0 { + glog.V(4).Infof("No endpoint change for %s/%s, skip syncing NEG. ", s.namespace, s.name) + return nil + } + + return s.syncNetworkEndpoints(addEndpoints, removeEndpoints) +} + +// ensureNetworkEndpointGroups ensures negs are created in the related zones. 
+func (s *syncer) ensureNetworkEndpointGroups() error {
+	zones, err := s.zoneGetter.ListZones()
+	if err != nil {
+		return err
+	}
+
+	var errList []error
+	for _, zone := range zones {
+		neg, err := s.cloud.GetNetworkEndpointGroup(s.negName, zone)
+		if err != nil {
+			// Most likely because the NEG does not exist.
+			glog.V(4).Infof("Error while retrieving %q in zone %q: %v", s.negName, zone, err)
+		}
+
+		needToCreate := false
+		if neg == nil {
+			needToCreate = true
+		} else if neg.LoadBalancer == nil ||
+			retrieveName(neg.LoadBalancer.Network) != retrieveName(s.cloud.NetworkURL()) ||
+			retrieveName(neg.LoadBalancer.Subnetwork) != retrieveName(s.cloud.SubnetworkURL()) {
+			// Only compare network and subnetwork names, so that API endpoint differences do not cause the NEG to be deleted accidentally.
+			// TODO: compare the network/subnetwork URLs instead of names once the NEG API reaches GA.
+			needToCreate = true
+			glog.V(2).Infof("NEG %q in %q does not match the network and subnetwork of the cluster. Deleting NEG.", s.negName, zone)
+			err = s.cloud.DeleteNetworkEndpointGroup(s.negName, zone)
+			if err != nil {
+				errList = append(errList, err)
+			} else {
+				if svc := getService(s.serviceLister, s.namespace, s.name); svc != nil {
+					s.recorder.Eventf(svc, apiv1.EventTypeNormal, "Delete", "Deleted NEG %q for %s/%s-%s in %q.", s.negName, s.namespace, s.name, s.targetPort, zone)
+				}
+			}
+		}
+
+		if needToCreate {
+			glog.V(2).Infof("Creating NEG %q for %s/%s in %q.", s.negName, s.namespace, s.name, zone)
+			err = s.cloud.CreateNetworkEndpointGroup(&compute.NetworkEndpointGroup{
+				Name:                s.negName,
+				Type:                gce.NEGLoadBalancerType,
+				NetworkEndpointType: gce.NEGIPPortNetworkEndpointType,
+				LoadBalancer: &compute.NetworkEndpointGroupLbNetworkEndpointGroup{
+					Network:    s.cloud.NetworkURL(),
+					Subnetwork: s.cloud.SubnetworkURL(),
+				},
+			}, zone)
+			if err != nil {
+				errList = append(errList, err)
+			} else {
+				if svc := getService(s.serviceLister, s.namespace, s.name); svc != nil {
+					s.recorder.Eventf(svc, apiv1.EventTypeNormal, "Create", "Created NEG %q for %s/%s-%s in %q.", s.negName, s.namespace, s.name, s.targetPort, zone)
+				}
+			}
+		}
+	}
+	return utilerrors.NewAggregate(errList)
+}
+
+// toZoneNetworkEndpointMap translates the addresses in an Endpoints object into a zone-to-endpoint-set map.
+func (s *syncer) toZoneNetworkEndpointMap(endpoints *apiv1.Endpoints) (map[string]sets.String, error) {
+	zoneNetworkEndpointMap := map[string]sets.String{}
+	// targetPort is 0 when s.targetPort is a named port rather than a number.
+	targetPort, _ := strconv.Atoi(s.targetPort)
+	for _, subset := range endpoints.Subsets {
+		matchPort := ""
+		// The service spec allows the target port to be either a port number or a
+		// named port; support both.
+		for _, port := range subset.Ports {
+			if targetPort != 0 {
+				// targetPort is a port number
+				if int(port.Port) == targetPort {
+					matchPort = s.targetPort
+				}
+			} else {
+				// targetPort is a named port
+				if port.Name == s.targetPort {
+					matchPort = strconv.Itoa(int(port.Port))
+				}
+			}
+			if len(matchPort) > 0 {
+				break
+			}
+		}
+
+		// The subset does not contain the target port.
+		if len(matchPort) == 0 {
+			continue
+		}
+		for _, address := range subset.Addresses {
+			if address.NodeName == nil {
+				// An endpoint without a node cannot be placed in a zonal NEG; skip it.
+				glog.Warningf("Endpoint address %q in %s/%s has no node name. Skipping.", address.IP, s.namespace, s.name)
+				continue
+			}
+			zone, err := s.zoneGetter.GetZoneForNode(*address.NodeName)
+			if err != nil {
+				return nil, err
+			}
+			if zoneNetworkEndpointMap[zone] == nil {
+				zoneNetworkEndpointMap[zone] = sets.String{}
+			}
+			zoneNetworkEndpointMap[zone].Insert(encodeEndpoint(address.IP, *address.NodeName, matchPort))
+		}
+	}
+	return zoneNetworkEndpointMap, nil
+}
+
+// retrieveExistingZoneNetworkEndpointMap lists the existing network endpoints in the NEG and returns a zone-to-endpoint-set map.
+func (s *syncer) retrieveExistingZoneNetworkEndpointMap() (map[string]sets.String, error) {
+	zones, err := s.zoneGetter.ListZones()
+	if err != nil {
+		return nil, err
+	}
+
+	zoneNetworkEndpointMap := map[string]sets.String{}
+	for _, zone := range zones {
+		zoneNetworkEndpointMap[zone] = sets.String{}
+		networkEndpointsWithHealthStatus, err := s.cloud.ListNetworkEndpoints(s.negName, zone, false)
+		if err != nil {
+			return nil, err
+		}
+		for _, ne := range networkEndpointsWithHealthStatus {
+			zoneNetworkEndpointMap[zone].Insert(encodeEndpoint(ne.NetworkEndpoint.IpAddress, ne.NetworkEndpoint.Instance, strconv.FormatInt(ne.NetworkEndpoint.Port, 10)))
+		}
+	}
+	return zoneNetworkEndpointMap, nil
+}
+
+// ErrorList is a concurrency-safe error collector.
+type ErrorList struct {
+	errList []error
+	lock    sync.Mutex
+}
+
+func (e *ErrorList) Add(err error) {
+	e.lock.Lock()
+	defer e.lock.Unlock()
+	e.errList = append(e.errList, err)
+}
+
+func (e *ErrorList) List() []error {
+	e.lock.Lock()
+	defer e.lock.Unlock()
+	return e.errList
+}
+
+// syncNetworkEndpoints attaches and detaches endpoints for the NEGs in batches.
+func (s *syncer) syncNetworkEndpoints(addEndpoints, removeEndpoints map[string]sets.String) error {
+	var wg sync.WaitGroup
+	errList := &ErrorList{}
+
+	// Detach endpoints
+	for zone, endpointSet := range removeEndpoints {
+		for endpointSet.Len() > 0 {
+			networkEndpoints, err := s.toNetworkEndpointBatch(endpointSet)
+			if err != nil {
+				return err
+			}
+			s.detachNetworkEndpoints(&wg, zone, networkEndpoints, errList)
+		}
+	}
+
+	// Attach endpoints
+	for zone, endpointSet := range addEndpoints {
+		for endpointSet.Len() > 0 {
+			networkEndpoints, err := s.toNetworkEndpointBatch(endpointSet)
+			if err != nil {
+				return err
+			}
+			s.attachNetworkEndpoints(&wg, zone, networkEndpoints, errList)
+		}
+	}
+	wg.Wait()
+	return utilerrors.NewAggregate(errList.List())
+}
+
+// toNetworkEndpointBatch drains up to maxNetworkEndpointsPerBatch entries from the endpoint set and translates them into network endpoint objects.
+func (s *syncer) toNetworkEndpointBatch(endpoints sets.String) ([]*compute.NetworkEndpoint, error) {
+	var ok bool
+	list := make([]string, int(math.Min(float64(endpoints.Len()), float64(maxNetworkEndpointsPerBatch))))
+	for i := range list {
+		list[i], ok = endpoints.PopAny()
+		if !ok {
+			break
+		}
+	}
+	networkEndpointList := make([]*compute.NetworkEndpoint, len(list))
+	for i, enc := range list {
+		ip, instance, port := decodeEndpoint(enc)
+		portNum, err := strconv.Atoi(port)
+		if err != nil {
+			return nil, fmt.Errorf("failed to decode endpoint %q: %v", enc, err)
+		}
+		networkEndpointList[i] = &compute.NetworkEndpoint{
+			Instance:  instance,
+			IpAddress: ip,
+			Port:      int64(portNum),
+		}
+	}
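+	// Callers loop until the endpoint set is drained, issuing one attach/detach
+	// call per batch.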
+	return networkEndpointList, nil
+}
+
+func (s *syncer) attachNetworkEndpoints(wg *sync.WaitGroup, zone string, networkEndpoints []*compute.NetworkEndpoint, errList *ErrorList) {
+	wg.Add(1)
+	go s.operationInternal(wg, zone, networkEndpoints, errList, s.cloud.AttachNetworkEndpoints, "Attach")
+}
+
+func (s *syncer) detachNetworkEndpoints(wg *sync.WaitGroup, zone string, networkEndpoints []*compute.NetworkEndpoint, errList *ErrorList) {
+	wg.Add(1)
+	go s.operationInternal(wg, zone, networkEndpoints, errList, s.cloud.DetachNetworkEndpoints, "Detach")
+}
+
+// operationInternal runs the attach/detach call and records an event with the outcome.
+func (s *syncer) operationInternal(wg *sync.WaitGroup, zone string, networkEndpoints []*compute.NetworkEndpoint, errList *ErrorList, syncFunc func(name, zone string, endpoints []*compute.NetworkEndpoint) error, operationName string) {
+	defer wg.Done()
+	err := syncFunc(s.negName, zone, networkEndpoints)
+	if err != nil {
+		errList.Add(err)
+	}
+	if svc := getService(s.serviceLister, s.namespace, s.name); svc != nil {
+		if err == nil {
+			s.recorder.Eventf(svc, apiv1.EventTypeNormal, operationName, "%s %d network endpoints to NEG %q in %q.", operationName, len(networkEndpoints), s.negName, zone)
+		} else {
+			s.recorder.Eventf(svc, apiv1.EventTypeWarning, operationName+"Failed", "Failed to %s %d network endpoints to NEG %q in %q: %v", operationName, len(networkEndpoints), s.negName, zone, err)
+		}
+	}
+}
+
+// nextRetryDelay doubles the retry delay, clamped to [minRetryDelay, maxRetryDelay].
+func (s *syncer) nextRetryDelay() time.Duration {
+	s.retryCount++
+	s.lastRetryDelay *= 2
+	if s.lastRetryDelay < minRetryDelay {
+		s.lastRetryDelay = minRetryDelay
+	} else if s.lastRetryDelay > maxRetryDelay {
+		s.lastRetryDelay = maxRetryDelay
+	}
+	return s.lastRetryDelay
+}
+
+func (s *syncer) resetRetryDelay() {
+	s.retryCount = 0
+	s.lastRetryDelay = time.Duration(0)
+}
+
+// encodeEndpoint encodes ip, instance and port into a single string.
+func encodeEndpoint(ip, instance, port string) string {
+	return fmt.Sprintf("%s||%s||%s", ip, instance, port)
+}
+
+// decodeEndpoint decodes an encoded string back into ip, instance and port.
+func decodeEndpoint(str string) (string, string, string) {
+	strs := strings.Split(str, "||")
+	return strs[0], strs[1], strs[2]
+}
+
+// calculateDifference determines which endpoints need to be added and removed to move the current state to the target state.
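+// For example, given target {zone1: {a, b}} and current {zone1: {b, c}}, it
+// returns add {zone1: {a}} and remove {zone1: {c}}.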
+func calculateDifference(targetMap, currentMap map[string]sets.String) (map[string]sets.String, map[string]sets.String) {
+	addSet := map[string]sets.String{}
+	removeSet := map[string]sets.String{}
+	for zone, endpointSet := range targetMap {
+		diff := endpointSet.Difference(currentMap[zone])
+		if len(diff) > 0 {
+			addSet[zone] = diff
+		}
+	}
+
+	for zone, endpointSet := range currentMap {
+		diff := endpointSet.Difference(targetMap[zone])
+		if len(diff) > 0 {
+			removeSet[zone] = diff
+		}
+	}
+	return addSet, removeSet
+}
+
+// retrieveName extracts the trailing resource name from a resource URL.
+func retrieveName(url string) string {
+	strs := strings.Split(url, "/")
+	return strs[len(strs)-1]
+}
+
+// getService retrieves the service object from the serviceLister by namespace and name.
+func getService(serviceLister cache.Indexer, namespace, name string) *apiv1.Service {
+	service, exists, err := serviceLister.GetByKey(serviceKeyFunc(namespace, name))
+	if exists && err == nil {
+		return service.(*apiv1.Service)
+	}
+	if err != nil {
+		glog.Errorf("Failed to retrieve service %s/%s from store: %v", namespace, name, err)
+	}
+	return nil
+}
diff --git a/controllers/gce/networkendpointgroup/syncer_test.go b/controllers/gce/networkendpointgroup/syncer_test.go
new file mode 100644
index 000000000..0abe0df37
--- /dev/null
+++ b/controllers/gce/networkendpointgroup/syncer_test.go
@@ -0,0 +1,439 @@
+/*
+Copyright 2017 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package networkendpointgroup
+
+import (
+	"reflect"
+	"testing"
+	"time"
+
+	apiv1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/util/sets"
+	"k8s.io/apimachinery/pkg/util/wait"
+	"k8s.io/client-go/kubernetes/fake"
+	"k8s.io/client-go/tools/record"
+	"k8s.io/ingress/controllers/gce/utils"
+)
+
+const (
+	NegName          = "test-neg-name"
+	ServiceNamespace = "test-ns"
+	ServiceName      = "test-name"
+	NamedPort        = "named-port"
+)
+
+func NewTestSyncer() *syncer {
+	kubeClient := fake.NewSimpleClientset()
+	context := utils.NewControllerContext(kubeClient, apiv1.NamespaceAll, 1*time.Second, true)
+	svcPort := servicePort{
+		namespace:  ServiceNamespace,
+		name:       ServiceName,
+		targetPort: "80",
+	}
+
+	return newSyncer(svcPort,
+		NegName,
+		record.NewFakeRecorder(100),
+		NewFakeNetworkEndpointGroupCloud("test-subnetwork", "test-network"),
+		NewFakeZoneGetter(),
+		context.ServiceInformer.GetIndexer(),
+		context.EndpointInformer.GetIndexer())
+}
+
+func TestStartAndStopSyncer(t *testing.T) {
+	syncer := NewTestSyncer()
+	if !syncer.IsStopped() {
+		t.Fatalf("Syncer is not stopped after creation.")
+	}
+	if syncer.IsShuttingDown() {
+		t.Fatalf("Syncer is shutting down after creation.")
+	}
+
+	if err := syncer.Start(); err != nil {
+		t.Fatalf("Failed to start syncer: %v", err)
+	}
+	if syncer.IsStopped() {
+		t.Fatalf("Syncer is stopped after Start.")
+	}
+	if syncer.IsShuttingDown() {
+		t.Fatalf("Syncer is shutting down after Start.")
+	}
+
+	syncer.Stop()
+	if !syncer.IsStopped() {
+		t.Fatalf("Syncer is not stopped after Stop.")
+	}
+
+	if err := wait.PollImmediate(time.Second, 30*time.Second, func() (bool, error) {
+		return !syncer.IsShuttingDown() && syncer.IsStopped(), nil
+	}); err != nil {
+		t.Fatalf("Syncer failed to shutdown: %v", err)
+	}
+
+	if err := syncer.Start(); err != nil {
+		t.Fatalf("Failed to restart syncer: %v", err)
+	}
+	if syncer.IsStopped() {
+		t.Fatalf("Syncer is stopped after restart.")
+	}
+	if syncer.IsShuttingDown() {
+		t.Fatalf("Syncer is shutting down after restart.")
+	}
+
+	syncer.Stop()
+	if !syncer.IsStopped() {
+		t.Fatalf("Syncer is not stopped after Stop.")
+	}
+}
+
+func
TestEnsureNetworkEndpointGroups(t *testing.T) { + syncer := NewTestSyncer() + if err := syncer.ensureNetworkEndpointGroups(); err != nil { + t.Errorf("Failed to ensure NEGs: %v", err) + } + + ret, _ := syncer.cloud.AggregatedListNetworkEndpointGroup() + expectZones := []string{TestZone1, TestZone2} + for _, zone := range expectZones { + negs, ok := ret[zone] + if !ok { + t.Errorf("Failed to find zone %q from ret %v", zone, ret) + continue + } + + if len(negs) != 1 { + t.Errorf("Unexpected negs %v", negs) + } else { + if negs[0].Name != NegName { + t.Errorf("Unexpected neg %q", negs[0].Name) + } + } + } +} + +func TestToZoneNetworkEndpointMap(t *testing.T) { + syncer := NewTestSyncer() + testCases := []struct { + targetPort string + expect map[string]sets.String + }{ + { + targetPort: "80", + expect: map[string]sets.String{ + TestZone1: sets.NewString("10.100.1.1||instance1||80", "10.100.1.2||instance1||80", "10.100.2.1||instance2||80"), + TestZone2: sets.NewString("10.100.3.1||instance3||80"), + }, + }, + { + targetPort: NamedPort, + expect: map[string]sets.String{ + TestZone1: sets.NewString("10.100.2.2||instance2||81"), + TestZone2: sets.NewString("10.100.4.1||instance4||81", "10.100.3.2||instance3||8081", "10.100.4.2||instance4||8081"), + }, + }, + } + + for _, tc := range testCases { + syncer.targetPort = tc.targetPort + res, _ := syncer.toZoneNetworkEndpointMap(getDefaultEndpoint()) + + if !reflect.DeepEqual(res, tc.expect) { + t.Errorf("Expect %v, but got %v.", tc.expect, res) + } + } +} + +func TestEncodeDecodeEndpoint(t *testing.T) { + ip := "10.0.0.10" + instance := "somehost" + port := "8080" + + retIp, retInstance, retPort := decodeEndpoint(encodeEndpoint(ip, instance, port)) + + if ip != retIp || instance != retInstance || retPort != port { + t.Fatalf("Encode and decode endpoint failed. 
Expect %q, %q, %q but got %q, %q, %q.", ip, instance, port, retIp, retInstance, retPort) + } +} + +func TestCalculateDifference(t *testing.T) { + testCases := []struct { + targetSet map[string]sets.String + currentSet map[string]sets.String + addSet map[string]sets.String + removeSet map[string]sets.String + }{ + // unchanged + { + targetSet: map[string]sets.String{ + TestZone1: sets.NewString("a", "b", "c"), + }, + currentSet: map[string]sets.String{ + TestZone1: sets.NewString("a", "b", "c"), + }, + addSet: map[string]sets.String{}, + removeSet: map[string]sets.String{}, + }, + // unchanged + { + targetSet: map[string]sets.String{}, + currentSet: map[string]sets.String{}, + addSet: map[string]sets.String{}, + removeSet: map[string]sets.String{}, + }, + // add in one zone + { + targetSet: map[string]sets.String{ + TestZone1: sets.NewString("a", "b", "c"), + }, + currentSet: map[string]sets.String{}, + addSet: map[string]sets.String{ + TestZone1: sets.NewString("a", "b", "c"), + }, + removeSet: map[string]sets.String{}, + }, + // add in 2 zones + { + targetSet: map[string]sets.String{ + TestZone1: sets.NewString("a", "b", "c"), + TestZone2: sets.NewString("e", "f", "g"), + }, + currentSet: map[string]sets.String{}, + addSet: map[string]sets.String{ + TestZone1: sets.NewString("a", "b", "c"), + TestZone2: sets.NewString("e", "f", "g"), + }, + removeSet: map[string]sets.String{}, + }, + // remove in one zone + { + targetSet: map[string]sets.String{}, + currentSet: map[string]sets.String{ + TestZone1: sets.NewString("a", "b", "c"), + }, + addSet: map[string]sets.String{}, + removeSet: map[string]sets.String{ + TestZone1: sets.NewString("a", "b", "c"), + }, + }, + // remove in 2 zones + { + targetSet: map[string]sets.String{}, + currentSet: map[string]sets.String{ + TestZone1: sets.NewString("a", "b", "c"), + TestZone2: sets.NewString("e", "f", "g"), + }, + addSet: map[string]sets.String{}, + removeSet: map[string]sets.String{ + TestZone1: sets.NewString("a", "b", "c"), + TestZone2: sets.NewString("e", "f", "g"), + }, + }, + // add and delete in one zone + { + targetSet: map[string]sets.String{ + TestZone1: sets.NewString("a", "b", "c"), + }, + currentSet: map[string]sets.String{ + TestZone1: sets.NewString("b", "c", "d"), + }, + addSet: map[string]sets.String{ + TestZone1: sets.NewString("a"), + }, + removeSet: map[string]sets.String{ + TestZone1: sets.NewString("d"), + }, + }, + // add and delete in 2 zones + { + targetSet: map[string]sets.String{ + TestZone1: sets.NewString("a", "b", "c"), + TestZone2: sets.NewString("a", "b", "c"), + }, + currentSet: map[string]sets.String{ + TestZone1: sets.NewString("b", "c", "d"), + TestZone2: sets.NewString("b", "c", "d"), + }, + addSet: map[string]sets.String{ + TestZone1: sets.NewString("a"), + TestZone2: sets.NewString("a"), + }, + removeSet: map[string]sets.String{ + TestZone1: sets.NewString("d"), + TestZone2: sets.NewString("d"), + }, + }, + } + + for _, tc := range testCases { + addSet, removeSet := calculateDifference(tc.targetSet, tc.currentSet) + + if !reflect.DeepEqual(addSet, tc.addSet) { + t.Errorf("Failed to calculate difference for add, expecting %v, but got %v", tc.addSet, addSet) + } + + if !reflect.DeepEqual(removeSet, tc.removeSet) { + t.Errorf("Failed to calculate difference for remove, expecting %v, but got %v", tc.removeSet, removeSet) + } + } +} + +func TestSyncNetworkEndpoints(t *testing.T) { + syncer := NewTestSyncer() + if err := syncer.ensureNetworkEndpointGroups(); err != nil { + t.Fatalf("Failed to ensure NEG: %v", err) + } 
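+	// The cases below run in order; each applies its add/remove sets on top of
+	// the NEG state left by the previous case.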
+ + testCases := []struct { + expectSet map[string]sets.String + addSet map[string]sets.String + removeSet map[string]sets.String + }{ + { + expectSet: map[string]sets.String{ + TestZone1: sets.NewString("10.100.1.1||instance1||80", "10.100.2.1||instance2||80"), + TestZone2: sets.NewString("10.100.3.1||instance3||80", "10.100.4.1||instance4||80"), + }, + addSet: map[string]sets.String{ + TestZone1: sets.NewString("10.100.1.1||instance1||80", "10.100.2.1||instance2||80"), + TestZone2: sets.NewString("10.100.3.1||instance3||80", "10.100.4.1||instance4||80"), + }, + removeSet: map[string]sets.String{}, + }, + { + expectSet: map[string]sets.String{ + TestZone1: sets.NewString("10.100.1.2||instance1||80"), + TestZone2: sets.NewString(), + }, + addSet: map[string]sets.String{ + TestZone1: sets.NewString("10.100.1.2||instance1||80"), + }, + removeSet: map[string]sets.String{ + TestZone1: sets.NewString("10.100.1.1||instance1||80", "10.100.2.1||instance2||80"), + TestZone2: sets.NewString("10.100.3.1||instance3||80", "10.100.4.1||instance4||80"), + }, + }, + { + expectSet: map[string]sets.String{ + TestZone1: sets.NewString("10.100.1.2||instance1||80"), + TestZone2: sets.NewString("10.100.3.2||instance3||80"), + }, + addSet: map[string]sets.String{ + TestZone2: sets.NewString("10.100.3.2||instance3||80"), + }, + removeSet: map[string]sets.String{}, + }, + } + + for _, tc := range testCases { + if err := syncer.syncNetworkEndpoints(tc.addSet, tc.removeSet); err != nil { + t.Fatalf("Failed to sync network endpoints: %v", err) + } + examineNetworkEndpoints(tc.expectSet, syncer, t) + } +} + +func examineNetworkEndpoints(expectSet map[string]sets.String, syncer *syncer, t *testing.T) { + for zone, endpoints := range expectSet { + expectEndpoints, err := syncer.toNetworkEndpointBatch(endpoints) + if err != nil { + t.Fatalf("Failed to convert endpoints to network endpoints: %v", err) + } + if cloudEndpoints, err := syncer.cloud.ListNetworkEndpoints(syncer.negName, zone, false); err == nil { + if len(expectEndpoints) != len(cloudEndpoints) { + t.Errorf("Expect number of endpoints to be %v, but got %v.", len(expectEndpoints), len(cloudEndpoints)) + } + for _, expectEp := range expectEndpoints { + found := false + for _, cloudEp := range cloudEndpoints { + if reflect.DeepEqual(*expectEp, *cloudEp.NetworkEndpoint) { + found = true + break + } + } + if !found { + t.Errorf("Endpoint %v not found.", expectEp) + } + } + } else { + t.Errorf("Failed to list network endpoints in zone %q: %v.", zone, err) + } + } +} + +func getDefaultEndpoint() *apiv1.Endpoints { + instance1 := TestInstance1 + instance2 := TestInstance2 + instance3 := TestInstance3 + instance4 := TestInstance4 + return &apiv1.Endpoints{ + ObjectMeta: metav1.ObjectMeta{ + Name: ServiceName, + Namespace: ServiceNamespace, + }, + Subsets: []apiv1.EndpointSubset{ + { + Addresses: []apiv1.EndpointAddress{ + { + IP: "10.100.1.1", + NodeName: &instance1, + }, + { + IP: "10.100.1.2", + NodeName: &instance1, + }, + { + IP: "10.100.2.1", + NodeName: &instance2, + }, + { + IP: "10.100.3.1", + NodeName: &instance3, + }, + }, + Ports: []apiv1.EndpointPort{ + { + Name: "", + Port: int32(80), + Protocol: apiv1.ProtocolTCP, + }, + }, + }, + { + Addresses: []apiv1.EndpointAddress{ + { + IP: "10.100.2.2", + NodeName: &instance2, + }, + { + IP: "10.100.4.1", + NodeName: &instance4, + }, + }, + Ports: []apiv1.EndpointPort{ + { + Name: NamedPort, + Port: int32(81), + Protocol: apiv1.ProtocolTCP, + }, + }, + }, + { + Addresses: []apiv1.EndpointAddress{ + { + IP: 
"10.100.3.2", + NodeName: &instance3, + }, + { + IP: "10.100.4.2", + NodeName: &instance4, + }, + }, + Ports: []apiv1.EndpointPort{ + { + Name: NamedPort, + Port: int32(8081), + Protocol: apiv1.ProtocolTCP, + }, + }, + }, + }, + } +}