generalize cidr parsing and improve lua tests
This commit is contained in:
parent
2254a91866
commit
2cff9fa41d
7 changed files with 68 additions and 70 deletions
|
@ -19,7 +19,6 @@ package ratelimit
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
networking "k8s.io/api/networking/v1beta1"
|
networking "k8s.io/api/networking/v1beta1"
|
||||||
|
@ -164,7 +163,7 @@ func (a ratelimit) Parse(ing *networking.Ingress) (interface{}, error) {
|
||||||
|
|
||||||
val, _ := parser.GetStringAnnotation("limit-whitelist", ing)
|
val, _ := parser.GetStringAnnotation("limit-whitelist", ing)
|
||||||
|
|
||||||
cidrs, err := parseCIDRs(val)
|
cidrs, err := net.ParseCIDRs(val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -208,32 +207,6 @@ func (a ratelimit) Parse(ing *networking.Ingress) (interface{}, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseCIDRs(s string) ([]string, error) {
|
|
||||||
if s == "" {
|
|
||||||
return []string{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
values := strings.Split(s, ",")
|
|
||||||
|
|
||||||
ipnets, ips, err := net.ParseIPNets(values...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cidrs := []string{}
|
|
||||||
for k := range ipnets {
|
|
||||||
cidrs = append(cidrs, k)
|
|
||||||
}
|
|
||||||
|
|
||||||
for k := range ips {
|
|
||||||
cidrs = append(cidrs, k)
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Strings(cidrs)
|
|
||||||
|
|
||||||
return cidrs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func encode(s string) string {
|
func encode(s string) string {
|
||||||
str := base64.URLEncoding.EncodeToString([]byte(s))
|
str := base64.URLEncoding.EncodeToString([]byte(s))
|
||||||
return strings.Replace(str, "=", "", -1)
|
return strings.Replace(str, "=", "", -1)
|
||||||
|
|
|
@ -17,8 +17,6 @@ limitations under the License.
|
||||||
package ratelimit
|
package ratelimit
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
|
||||||
"sort"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
api "k8s.io/api/core/v1"
|
api "k8s.io/api/core/v1"
|
||||||
|
@ -85,23 +83,6 @@ func TestWithoutAnnotations(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseCIDRs(t *testing.T) {
|
|
||||||
cidr, _ := parseCIDRs("invalid.com")
|
|
||||||
if cidr != nil {
|
|
||||||
t.Errorf("expected %v but got %v", nil, cidr)
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := []string{"192.0.0.1", "192.0.1.0/24"}
|
|
||||||
cidr, err := parseCIDRs("192.0.0.1, 192.0.1.0/24")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("unexpected error %v", err)
|
|
||||||
}
|
|
||||||
sort.Strings(cidr)
|
|
||||||
if !reflect.DeepEqual(expected, cidr) {
|
|
||||||
t.Errorf("expected %v but got %v", expected, cidr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRateLimiting(t *testing.T) {
|
func TestRateLimiting(t *testing.T) {
|
||||||
ing := buildIngress()
|
ing := buildIngress()
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ package net
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -51,3 +52,30 @@ func ParseIPNets(specs ...string) (IPNet, IP, error) {
|
||||||
|
|
||||||
return ipnetset, ipset, nil
|
return ipnetset, ipset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParseCIDRs parses comma separated CIDRs into a sorted string array
|
||||||
|
func ParseCIDRs(s string) ([]string, error) {
|
||||||
|
if s == "" {
|
||||||
|
return []string{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
values := strings.Split(s, ",")
|
||||||
|
|
||||||
|
ipnets, ips, err := ParseIPNets(values...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cidrs := []string{}
|
||||||
|
for k := range ipnets {
|
||||||
|
cidrs = append(cidrs, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k := range ips {
|
||||||
|
cidrs = append(cidrs, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(cidrs)
|
||||||
|
|
||||||
|
return cidrs, nil
|
||||||
|
}
|
||||||
|
|
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||||
package net
|
package net
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -32,3 +34,20 @@ func TestNewIPSet(t *testing.T) {
|
||||||
t.Errorf("Expected len=1: %d", len(ips))
|
t.Errorf("Expected len=1: %d", len(ips))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseCIDRs(t *testing.T) {
|
||||||
|
cidr, _ := ParseCIDRs("invalid.com")
|
||||||
|
if cidr != nil {
|
||||||
|
t.Errorf("expected %v but got %v", nil, cidr)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []string{"192.0.0.1", "192.0.1.0/24"}
|
||||||
|
cidr, err := ParseCIDRs("192.0.0.1, 192.0.1.0/24")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error %v", err)
|
||||||
|
}
|
||||||
|
sort.Strings(cidr)
|
||||||
|
if !reflect.DeepEqual(expected, cidr) {
|
||||||
|
t.Errorf("expected %v but got %v", expected, cidr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,15 +1,12 @@
|
||||||
function mock_ngx(mock)
|
|
||||||
local _ngx = mock
|
|
||||||
setmetatable(_ngx, {__index = _G.ngx})
|
|
||||||
_G.ngx = _ngx
|
|
||||||
end
|
|
||||||
|
|
||||||
describe("Balancer chash", function()
|
describe("Balancer chash", function()
|
||||||
|
after_each(function()
|
||||||
|
reset_ngx()
|
||||||
|
end)
|
||||||
|
|
||||||
describe("balance()", function()
|
describe("balance()", function()
|
||||||
it("uses correct key for given backend", function()
|
it("uses correct key for given backend", function()
|
||||||
mock_ngx({var = { request_uri = "/alma/armud"}})
|
ngx.var = { request_uri = "/alma/armud"}
|
||||||
local balancer_chash = require("balancer.chash")
|
local balancer_chash = require_without_cache("balancer.chash")
|
||||||
|
|
||||||
local resty_chash = package.loaded["resty.chash"]
|
local resty_chash = package.loaded["resty.chash"]
|
||||||
resty_chash.new = function(self, nodes)
|
resty_chash.new = function(self, nodes)
|
||||||
|
|
|
@ -11,6 +11,8 @@ do
|
||||||
-- if there's more constants need to be whitelisted for test runs, add here.
|
-- if there's more constants need to be whitelisted for test runs, add here.
|
||||||
local GLOBALS_ALLOWED_IN_TEST = {
|
local GLOBALS_ALLOWED_IN_TEST = {
|
||||||
helpers = true,
|
helpers = true,
|
||||||
|
require_without_cache = true,
|
||||||
|
reset_ngx = true,
|
||||||
}
|
}
|
||||||
local newindex = function(table, key, value)
|
local newindex = function(table, key, value)
|
||||||
rawset(table, key, value)
|
rawset(table, key, value)
|
||||||
|
@ -69,6 +71,15 @@ end
|
||||||
|
|
||||||
ngx.log = function(...) end
|
ngx.log = function(...) end
|
||||||
ngx.print = function(...) end
|
ngx.print = function(...) end
|
||||||
|
local original_ngx = ngx
|
||||||
|
_G.reset_ngx = function()
|
||||||
|
ngx = original_ngx
|
||||||
|
end
|
||||||
|
|
||||||
|
_G.require_without_cache = function(module)
|
||||||
|
package.loaded[module] = nil
|
||||||
|
return require(module)
|
||||||
|
end
|
||||||
|
|
||||||
lua_ingress.init_worker()
|
lua_ingress.init_worker()
|
||||||
|
|
||||||
|
|
|
@ -1,27 +1,16 @@
|
||||||
local original_ngx = ngx
|
|
||||||
local util
|
local util
|
||||||
|
|
||||||
local function reset_ngx()
|
|
||||||
_G.ngx = original_ngx
|
|
||||||
end
|
|
||||||
|
|
||||||
local function mock_ngx(mock)
|
|
||||||
local _ngx = mock
|
|
||||||
setmetatable(_ngx, { __index = ngx })
|
|
||||||
_G.ngx = _ngx
|
|
||||||
end
|
|
||||||
|
|
||||||
describe("utility", function()
|
describe("utility", function()
|
||||||
|
before_each(function()
|
||||||
|
ngx.var = { remote_addr = "192.168.1.1", [1] = "nginx/regexp/1/group/capturing" }
|
||||||
|
util = require_without_cache("util")
|
||||||
|
end)
|
||||||
|
|
||||||
after_each(function()
|
after_each(function()
|
||||||
reset_ngx()
|
reset_ngx()
|
||||||
end)
|
end)
|
||||||
|
|
||||||
describe("ngx_complex_value", function()
|
describe("ngx_complex_value", function()
|
||||||
before_each(function()
|
|
||||||
mock_ngx({ var = { remote_addr = "192.168.1.1", [1] = "nginx/regexp/1/group/capturing" } })
|
|
||||||
util = require("util")
|
|
||||||
end)
|
|
||||||
|
|
||||||
local ngx_complex_value = function(data)
|
local ngx_complex_value = function(data)
|
||||||
local ret, err = util.parse_complex_value(data)
|
local ret, err = util.parse_complex_value(data)
|
||||||
|
|
Loading…
Reference in a new issue