218 lines
4.9 KiB
Go
218 lines
4.9 KiB
Go
package httpexpect
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
// Binder implements networkless http.RoundTripper attached directly to
// http.Handler.
//
// Binder emulates network communication by invoking given http.Handler
// directly. It passes httptest.ResponseRecorder as http.ResponseWriter
// to the handler, and then constructs http.Response from recorded data.
type Binder struct {
	// HTTP handler invoked for every request.
	Handler http.Handler
	// TLS connection state used for https:// requests.
	TLS *tls.ConnectionState
}

// NewBinder returns a new Binder given a http.Handler.
//
// Example:
//   client := &http.Client{
//       Transport: NewBinder(handler),
//   }
func NewBinder(handler http.Handler) Binder {
	return Binder{Handler: handler}
}

// RoundTrip implements http.RoundTripper.RoundTrip.
//
// The handler receives a shallow copy of req, so the caller's request is
// never modified (as the http.RoundTripper contract requires). In that copy:
//   - an empty Proto is filled in from ProtoMajor/ProtoMinor;
//   - a nil Body is replaced with an empty body, and a body of unknown
//     length (ContentLength == -1) is marked as chunked;
//   - TLS is set to binder.TLS for https:// URLs when binder.TLS is non-nil;
//   - an empty RequestURI is derived from the URL, when present.
//
// A non-nil req.Body is closed once the handler has returned.
//
// The returned response references req, always has a non-nil Body holding
// the recorded output, and reports ContentLength as the recorded body size.
// It is marked as chunked with ContentLength -1 if the handler flushed.
// It never returns a non-nil error.
func (binder Binder) RoundTrip(req *http.Request) (*http.Response, error) {
	if req.Body != nil {
		// RoundTrip must always close the request body. The close error
		// is ignored: the body is only an in-memory request payload here.
		defer req.Body.Close()
	}

	// Work on a copy: RoundTrip must not modify the caller's request.
	srvReq := *req

	if srvReq.Proto == "" {
		srvReq.Proto = fmt.Sprintf("HTTP/%d.%d", srvReq.ProtoMajor, srvReq.ProtoMinor)
	}

	if srvReq.Body != nil {
		if srvReq.ContentLength == -1 {
			srvReq.TransferEncoding = []string{"chunked"}
		}
	} else {
		srvReq.Body = ioutil.NopCloser(bytes.NewReader(nil))
	}

	if srvReq.URL != nil && srvReq.URL.Scheme == "https" && binder.TLS != nil {
		srvReq.TLS = binder.TLS
	}

	// Guard against a nil URL, which would otherwise panic here.
	if srvReq.RequestURI == "" && srvReq.URL != nil {
		srvReq.RequestURI = srvReq.URL.RequestURI()
	}

	recorder := httptest.NewRecorder()

	binder.Handler.ServeHTTP(recorder, &srvReq)

	resp := &http.Response{
		Request:    req,
		StatusCode: recorder.Code,
		Status:     http.StatusText(recorder.Code),
		Header:     recorder.Result().Header,
	}

	body := recorder.Body
	if body == nil {
		// Consumers of http.Response expect Body to be non-nil.
		body = new(bytes.Buffer)
	}
	resp.Body = ioutil.NopCloser(body)

	if recorder.Flushed {
		resp.TransferEncoding = []string{"chunked"}
		// The length of a chunked response is unknown.
		resp.ContentLength = -1
	} else {
		resp.ContentLength = int64(body.Len())
	}

	return resp, nil
}
|
|
|
|
// FastBinder implements networkless http.RoundTripper attached directly
|
|
// to fasthttp.RequestHandler.
|
|
//
|
|
// FastBinder emulates network communication by invoking given fasthttp.RequestHandler
|
|
// directly. It converts http.Request to fasthttp.Request, invokes handler, and then
|
|
// converts fasthttp.Response to http.Response.
|
|
type FastBinder struct {
|
|
// FastHTTP handler invoked for every request.
|
|
Handler fasthttp.RequestHandler
|
|
// TLS connection state used for https:// requests.
|
|
TLS *tls.ConnectionState
|
|
}
|
|
|
|
// NewFastBinder returns a new FastBinder given a fasthttp.RequestHandler.
|
|
//
|
|
// Example:
|
|
// client := &http.Client{
|
|
// Transport: NewFastBinder(fasthandler),
|
|
// }
|
|
func NewFastBinder(handler fasthttp.RequestHandler) FastBinder {
|
|
return FastBinder{Handler: handler}
|
|
}
|
|
|
|
// RoundTrip implements http.RoundTripper.RoundTrip.
|
|
func (binder FastBinder) RoundTrip(stdreq *http.Request) (*http.Response, error) {
|
|
fastreq := std2fast(stdreq)
|
|
|
|
var conn net.Conn
|
|
if stdreq.URL != nil && stdreq.URL.Scheme == "https" && binder.TLS != nil {
|
|
conn = connTLS{state: binder.TLS}
|
|
} else {
|
|
conn = connNonTLS{}
|
|
}
|
|
|
|
ctx := fasthttp.RequestCtx{}
|
|
ctx.Init2(conn, fastLogger{}, true)
|
|
fastreq.CopyTo(&ctx.Request)
|
|
|
|
if stdreq.ContentLength >= 0 {
|
|
ctx.Request.Header.SetContentLength(int(stdreq.ContentLength))
|
|
} else {
|
|
ctx.Request.Header.Add("Transfer-Encoding", "chunked")
|
|
}
|
|
|
|
if stdreq.Body != nil {
|
|
b, err := ioutil.ReadAll(stdreq.Body)
|
|
if err == nil {
|
|
ctx.Request.SetBody(b)
|
|
}
|
|
}
|
|
|
|
binder.Handler(&ctx)
|
|
|
|
return fast2std(stdreq, &ctx.Response), nil
|
|
}
|
|
|
|
func std2fast(stdreq *http.Request) *fasthttp.Request {
|
|
fastreq := &fasthttp.Request{}
|
|
fastreq.SetRequestURI(stdreq.URL.String())
|
|
|
|
fastreq.Header.SetMethod(stdreq.Method)
|
|
|
|
for k, a := range stdreq.Header {
|
|
for n, v := range a {
|
|
if n == 0 {
|
|
fastreq.Header.Set(k, v)
|
|
} else {
|
|
fastreq.Header.Add(k, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
return fastreq
|
|
}
|
|
|
|
func fast2std(stdreq *http.Request, fastresp *fasthttp.Response) *http.Response {
|
|
status := fastresp.Header.StatusCode()
|
|
body := fastresp.Body()
|
|
|
|
stdresp := &http.Response{
|
|
Request: stdreq,
|
|
StatusCode: status,
|
|
Status: http.StatusText(status),
|
|
}
|
|
|
|
fastresp.Header.VisitAll(func(k, v []byte) {
|
|
sk := string(k)
|
|
sv := string(v)
|
|
if stdresp.Header == nil {
|
|
stdresp.Header = make(http.Header)
|
|
}
|
|
stdresp.Header.Add(sk, sv)
|
|
})
|
|
|
|
if fastresp.Header.ContentLength() == -1 {
|
|
stdresp.TransferEncoding = []string{"chunked"}
|
|
}
|
|
|
|
if body != nil {
|
|
stdresp.Body = ioutil.NopCloser(bytes.NewReader(body))
|
|
} else {
|
|
stdresp.Body = ioutil.NopCloser(bytes.NewReader(nil))
|
|
}
|
|
|
|
return stdresp
|
|
}
|
|
|
|
// fastLogger is the logger handed to fasthttp.RequestCtx.Init2. It
// discards every message.
type fastLogger struct{}

// Printf ignores its format string and arguments and writes nothing.
func (fastLogger) Printf(string, ...interface{}) {}
|
|
|
|
// connNonTLS is the placeholder net.Conn passed to fasthttp.RequestCtx.Init2
// for requests that are not given a TLS state.
//
// Only RemoteAddr and LocalAddr are implemented. The embedded net.Conn is
// left nil, so calling any other net.Conn method on it (Read, Write,
// Close, ...) would panic.
type connNonTLS struct {
	net.Conn
}

// RemoteAddr reports the unspecified IPv4 address (0.0.0.0) with port 0.
func (connNonTLS) RemoteAddr() net.Addr {
	return unspecifiedTCPAddr()
}

// LocalAddr reports the unspecified IPv4 address (0.0.0.0) with port 0.
func (connNonTLS) LocalAddr() net.Addr {
	return unspecifiedTCPAddr()
}

// unspecifiedTCPAddr allocates and returns a new *net.TCPAddr for 0.0.0.0:0.
func unspecifiedTCPAddr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4zero}
}

// connTLS is the placeholder net.Conn used for https:// requests when a TLS
// connection state is configured. Its addresses come from the embedded
// connNonTLS. It also provides Handshake and ConnectionState.
// NOTE(review): presumably fasthttp looks for these two methods to treat
// the connection as TLS — confirm against the fasthttp version in use.
type connTLS struct {
	connNonTLS
	state *tls.ConnectionState
}

// Handshake does nothing and always succeeds; there is no real TLS session.
func (connTLS) Handshake() error {
	return nil
}

// ConnectionState returns a copy of the configured TLS state. The state
// pointer must be non-nil, otherwise this panics.
func (tc connTLS) ConnectionState() tls.ConnectionState {
	state := *tc.state
	return state
}
|