diff --git a/internal/ingress/annotations/ratelimit/main.go b/internal/ingress/annotations/ratelimit/main.go index 3decd59d3..7b7d6f4db 100644 --- a/internal/ingress/annotations/ratelimit/main.go +++ b/internal/ingress/annotations/ratelimit/main.go @@ -19,7 +19,6 @@ package ratelimit import ( "encoding/base64" "fmt" - "sort" "strings" networking "k8s.io/api/networking/v1beta1" @@ -164,7 +163,7 @@ func (a ratelimit) Parse(ing *networking.Ingress) (interface{}, error) { val, _ := parser.GetStringAnnotation("limit-whitelist", ing) - cidrs, err := parseCIDRs(val) + cidrs, err := net.ParseCIDRs(val) if err != nil { return nil, err } @@ -208,32 +207,6 @@ func (a ratelimit) Parse(ing *networking.Ingress) (interface{}, error) { }, nil } -func parseCIDRs(s string) ([]string, error) { - if s == "" { - return []string{}, nil - } - - values := strings.Split(s, ",") - - ipnets, ips, err := net.ParseIPNets(values...) - if err != nil { - return nil, err - } - - cidrs := []string{} - for k := range ipnets { - cidrs = append(cidrs, k) - } - - for k := range ips { - cidrs = append(cidrs, k) - } - - sort.Strings(cidrs) - - return cidrs, nil -} - func encode(s string) string { str := base64.URLEncoding.EncodeToString([]byte(s)) return strings.Replace(str, "=", "", -1) diff --git a/internal/ingress/annotations/ratelimit/main_test.go b/internal/ingress/annotations/ratelimit/main_test.go index 1975ddb96..7ffbac3ff 100644 --- a/internal/ingress/annotations/ratelimit/main_test.go +++ b/internal/ingress/annotations/ratelimit/main_test.go @@ -17,8 +17,6 @@ limitations under the License. package ratelimit import ( - "reflect" - "sort" "testing" api "k8s.io/api/core/v1" @@ -85,23 +83,6 @@ func TestWithoutAnnotations(t *testing.T) { } } -func TestParseCIDRs(t *testing.T) { - cidr, _ := parseCIDRs("invalid.com") - if cidr != nil { - t.Errorf("expected %v but got %v", nil, cidr) - } - - expected := []string{"192.0.0.1", "192.0.1.0/24"} - cidr, err := parseCIDRs("192.0.0.1, 192.0.1.0/24") - if err != nil { - t.Errorf("unexpected error %v", err) - } - sort.Strings(cidr) - if !reflect.DeepEqual(expected, cidr) { - t.Errorf("expected %v but got %v", expected, cidr) - } -} - func TestRateLimiting(t *testing.T) { ing := buildIngress() diff --git a/internal/net/ipnet.go b/internal/net/ipnet.go index b48d13db8..1cd8832ca 100644 --- a/internal/net/ipnet.go +++ b/internal/net/ipnet.go @@ -18,6 +18,7 @@ package net import ( "net" + "sort" "strings" ) @@ -51,3 +52,30 @@ func ParseIPNets(specs ...string) (IPNet, IP, error) { return ipnetset, ipset, nil } + +// ParseCIDRs parses comma separated CIDRs into a sorted string array +func ParseCIDRs(s string) ([]string, error) { + if s == "" { + return []string{}, nil + } + + values := strings.Split(s, ",") + + ipnets, ips, err := ParseIPNets(values...) + if err != nil { + return nil, err + } + + cidrs := []string{} + for k := range ipnets { + cidrs = append(cidrs, k) + } + + for k := range ips { + cidrs = append(cidrs, k) + } + + sort.Strings(cidrs) + + return cidrs, nil +} diff --git a/internal/net/ipnet_test.go b/internal/net/ipnet_test.go index 3ce1345c6..95e6b9c32 100644 --- a/internal/net/ipnet_test.go +++ b/internal/net/ipnet_test.go @@ -17,6 +17,8 @@ limitations under the License. package net import ( + "reflect" + "sort" "testing" ) @@ -32,3 +34,20 @@ func TestNewIPSet(t *testing.T) { t.Errorf("Expected len=1: %d", len(ips)) } } + +func TestParseCIDRs(t *testing.T) { + cidr, _ := ParseCIDRs("invalid.com") + if cidr != nil { + t.Errorf("expected %v but got %v", nil, cidr) + } + + expected := []string{"192.0.0.1", "192.0.1.0/24"} + cidr, err := ParseCIDRs("192.0.0.1, 192.0.1.0/24") + if err != nil { + t.Errorf("unexpected error %v", err) + } + sort.Strings(cidr) + if !reflect.DeepEqual(expected, cidr) { + t.Errorf("expected %v but got %v", expected, cidr) + } +} diff --git a/rootfs/etc/nginx/lua/test/balancer/chash_test.lua b/rootfs/etc/nginx/lua/test/balancer/chash_test.lua index 94379e7dd..dda3f848c 100644 --- a/rootfs/etc/nginx/lua/test/balancer/chash_test.lua +++ b/rootfs/etc/nginx/lua/test/balancer/chash_test.lua @@ -1,15 +1,12 @@ -function mock_ngx(mock) - local _ngx = mock - setmetatable(_ngx, {__index = _G.ngx}) - _G.ngx = _ngx -end - describe("Balancer chash", function() + after_each(function() + reset_ngx() + end) describe("balance()", function() it("uses correct key for given backend", function() - mock_ngx({var = { request_uri = "/alma/armud"}}) - local balancer_chash = require("balancer.chash") + ngx.var = { request_uri = "/alma/armud"} + local balancer_chash = require_without_cache("balancer.chash") local resty_chash = package.loaded["resty.chash"] resty_chash.new = function(self, nodes) diff --git a/rootfs/etc/nginx/lua/test/run.lua b/rootfs/etc/nginx/lua/test/run.lua index 8e20ea3f5..d43ce8612 100644 --- a/rootfs/etc/nginx/lua/test/run.lua +++ b/rootfs/etc/nginx/lua/test/run.lua @@ -11,6 +11,8 @@ do -- if there's more constants need to be whitelisted for test runs, add here. local GLOBALS_ALLOWED_IN_TEST = { helpers = true, + require_without_cache = true, + reset_ngx = true, } local newindex = function(table, key, value) rawset(table, key, value) @@ -69,6 +71,15 @@ end ngx.log = function(...) end ngx.print = function(...) end +local original_ngx = ngx +_G.reset_ngx = function() + ngx = original_ngx +end + +_G.require_without_cache = function(module) + package.loaded[module] = nil + return require(module) +end lua_ingress.init_worker() diff --git a/rootfs/etc/nginx/lua/test/util_test.lua b/rootfs/etc/nginx/lua/test/util_test.lua index 0b8d48ae4..1aca67fa1 100644 --- a/rootfs/etc/nginx/lua/test/util_test.lua +++ b/rootfs/etc/nginx/lua/test/util_test.lua @@ -1,27 +1,16 @@ -local original_ngx = ngx local util -local function reset_ngx() - _G.ngx = original_ngx -end - -local function mock_ngx(mock) - local _ngx = mock - setmetatable(_ngx, { __index = ngx }) - _G.ngx = _ngx -end - describe("utility", function() + before_each(function() + ngx.var = { remote_addr = "192.168.1.1", [1] = "nginx/regexp/1/group/capturing" } + util = require_without_cache("util") + end) after_each(function() reset_ngx() end) describe("ngx_complex_value", function() - before_each(function() - mock_ngx({ var = { remote_addr = "192.168.1.1", [1] = "nginx/regexp/1/group/capturing" } }) - util = require("util") - end) local ngx_complex_value = function(data) local ret, err = util.parse_complex_value(data)