112 lines
2.5 KiB
Go
112 lines
2.5 KiB
Go
![]() |
package httpexpect
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"net"
|
||
|
"net/http"
|
||
|
"net/http/httptest"
|
||
|
"sync"
|
||
|
|
||
|
"github.com/gorilla/websocket"
|
||
|
"github.com/valyala/fasthttp"
|
||
|
)
|
||
|
|
||
|
// NewWebsocketDialer produces new websocket.Dialer which dials to bound
|
||
|
// http.Handler without creating a real net.Conn.
|
||
|
func NewWebsocketDialer(handler http.Handler) *websocket.Dialer {
|
||
|
return &websocket.Dialer{
|
||
|
NetDial: func(network, addr string) (net.Conn, error) {
|
||
|
hc := newHandlerConn()
|
||
|
hc.runHandler(handler)
|
||
|
return hc, nil
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// NewFastWebsocketDialer produces new websocket.Dialer which dials to bound
|
||
|
// fasthttp.RequestHandler without creating a real net.Conn.
|
||
|
func NewFastWebsocketDialer(handler fasthttp.RequestHandler) *websocket.Dialer {
|
||
|
return &websocket.Dialer{
|
||
|
NetDial: func(network, addr string) (net.Conn, error) {
|
||
|
hc := newHandlerConn()
|
||
|
hc.runFastHandler(handler)
|
||
|
return hc, nil
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type handlerConn struct {
|
||
|
net.Conn // returned from dialer
|
||
|
backConn net.Conn // passed to the background goroutine
|
||
|
|
||
|
wg sync.WaitGroup
|
||
|
}
|
||
|
|
||
|
func newHandlerConn() *handlerConn {
|
||
|
dialConn, backConn := net.Pipe()
|
||
|
|
||
|
return &handlerConn{
|
||
|
Conn: dialConn,
|
||
|
backConn: backConn,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (hc *handlerConn) Close() error {
|
||
|
err := hc.Conn.Close()
|
||
|
hc.wg.Wait() // wait the background goroutine
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (hc *handlerConn) runHandler(handler http.Handler) {
|
||
|
hc.wg.Add(1)
|
||
|
|
||
|
go func() {
|
||
|
defer hc.wg.Done()
|
||
|
|
||
|
recorder := &hijackRecorder{conn: hc.backConn}
|
||
|
|
||
|
for {
|
||
|
req, err := http.ReadRequest(bufio.NewReader(hc.backConn))
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
handler.ServeHTTP(recorder, req)
|
||
|
}
|
||
|
}()
|
||
|
}
|
||
|
|
||
|
func (hc *handlerConn) runFastHandler(handler fasthttp.RequestHandler) {
|
||
|
hc.wg.Add(1)
|
||
|
|
||
|
go func() {
|
||
|
defer hc.wg.Done()
|
||
|
|
||
|
_ = fasthttp.ServeConn(hc.backConn, handler)
|
||
|
}()
|
||
|
}
|
||
|
|
||
|
// hijackRecorder it similar to httptest.ResponseRecorder,
|
||
|
// but with Hijack capabilities.
|
||
|
//
|
||
|
// Original idea is stolen from https://github.com/posener/wstest
|
||
|
type hijackRecorder struct {
|
||
|
httptest.ResponseRecorder
|
||
|
conn net.Conn
|
||
|
}
|
||
|
|
||
|
// Hijack the connection for caller.
|
||
|
//
|
||
|
// Implements http.Hijacker interface.
|
||
|
func (r *hijackRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||
|
rw := bufio.NewReadWriter(bufio.NewReader(r.conn), bufio.NewWriter(r.conn))
|
||
|
return r.conn, rw, nil
|
||
|
}
|
||
|
|
||
|
// WriteHeader write HTTP header to the client and closes the connection
|
||
|
//
|
||
|
// Implements http.ResponseWriter interface.
|
||
|
func (r *hijackRecorder) WriteHeader(code int) {
|
||
|
resp := http.Response{StatusCode: code, Header: r.Header()}
|
||
|
_ = resp.Write(r.conn)
|
||
|
}
|