add hijacker

This commit is contained in:
Mohamad Tahir 2023-11-22 13:55:51 +03:00
parent 30e10ba8ec
commit 5b93b79349
Signed by: MohamadTahir
GPG Key ID: 116FAB02D35512FA

72
main.go
View File

@ -1,11 +1,56 @@
package traefik_response_header_forward_plugin package traefik_response_header_forward_plugin
import ( import (
"bufio"
"bytes"
"context" "context"
"fmt" "fmt"
"io"
"net"
"net/http" "net/http"
) )
var (
_ interface {
http.ResponseWriter
http.Hijacker
} = &wrappedResponseWriter{}
)
type wrappedResponseWriter struct {
rw http.ResponseWriter
buf *bytes.Buffer
code int
}
func (w *wrappedResponseWriter) Header() http.Header {
return w.rw.Header()
}
func (w *wrappedResponseWriter) Write(b []byte) (int, error) {
return w.buf.Write(b)
}
func (w *wrappedResponseWriter) WriteHeader(code int) {
w.code = code
}
func (w *wrappedResponseWriter) Flush() {
w.rw.WriteHeader(w.code)
io.Copy(w.rw, w.buf)
}
func (w *wrappedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := w.rw.(http.Hijacker)
if !ok {
return nil, nil, fmt.Errorf("%T is not an http.Hijacker", w.rw)
}
return hijacker.Hijack()
}
// ========================================
type RequestHeader struct { type RequestHeader struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
} }
@ -23,7 +68,7 @@ func CreateConfig() *Config {
type ResponseHeaderForward struct { type ResponseHeaderForward struct {
next http.Handler next http.Handler
name string name string
requestHeaders []RequestHeader config *Config
} }
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) { func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
@ -40,19 +85,26 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
return &ResponseHeaderForward{ return &ResponseHeaderForward{
next: next, next: next,
name: name, name: name,
requestHeaders: config.RequestHeaders, config: config,
}, nil }, nil
} }
func (a *ResponseHeaderForward) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (a *ResponseHeaderForward) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
a.next.ServeHTTP(rw, req) resp := &wrappedResponseWriter{
rw: rw,
buf: &bytes.Buffer{},
}
// for _, requestHeader := range a.requestHeaders { defer resp.Flush()
// headerValue := req.Header.Get(requestHeader.Name)
// if headerValue == "" {
// continue
// }
// rw.Header().Set(requestHeader.Name, headerValue) a.next.ServeHTTP(resp, req)
// }
for _, requestHeader := range a.config.RequestHeaders {
headerValue := req.Header.Get(requestHeader.Name)
if headerValue == "" {
continue
}
resp.Header().Set(requestHeader.Name, headerValue)
}
} }