add initContext() function

This commit is contained in:
yuli
2024-11-04 09:14:18 +02:00
parent f235bed0a8
commit 5752ade642
4 changed files with 9 additions and 4 deletions

View File

@@ -422,10 +422,10 @@ func (w *CustomResponseWriter) WriteHeader(statusCode int) {
var statusCounter = 0 var statusCounter = 0
var statusErrorCounter = 0 var statusErrorCounter = 0
func reqMiddleware(handler http.Handler) http.Handler { func (e mainEnv) reqMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//log.Printf("Set host %s\n", r.Host) //log.Printf("Set host %s\n", r.Host)
autocontext.Set(r, "host", r.Host) e.initContext(r)
w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("X-Frame-Options", "SAMEORIGIN") w.Header().Set("X-Frame-Options", "SAMEORIGIN")
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload") w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")

View File

@@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/securitybunker/databunker/src/autocontext"
) )
func (e mainEnv) setupConfRouter(router *httprouter.Router) *httprouter.Router { func (e mainEnv) setupConfRouter(router *httprouter.Router) *httprouter.Router {
@@ -15,6 +16,10 @@ func (e mainEnv) setupConfRouter(router *httprouter.Router) *httprouter.Router {
return router return router
} }
func (e mainEnv) initContext(r *http.Request) {
autocontext.Set(r, "host", r.Host)
}
func (e mainEnv) cookieSettings(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { func (e mainEnv) cookieSettings(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
resultJSON, scriptsJSON, _, err := e.db.getLegalBasisCookieConf() resultJSON, scriptsJSON, _, err := e.db.getLegalBasisCookieConf()
if err != nil { if err != nil {

View File

@@ -113,7 +113,7 @@ func loadService() {
}, },
} }
listener := cfg.Server.Host + ":" + cfg.Server.Port listener := cfg.Server.Host + ":" + cfg.Server.Port
srv := &http.Server{Addr: listener, Handler: reqMiddleware(router), TLSConfig: tlsConfig} srv := &http.Server{Addr: listener, Handler: e.reqMiddleware(router), TLSConfig: tlsConfig}
stop := make(chan os.Signal, 2) stop := make(chan os.Signal, 2)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM) signal.Notify(stop, os.Interrupt, syscall.SIGTERM)

View File

@@ -94,7 +94,7 @@ func TestUtilGetJSONPost(t *testing.T) {
} }
func TestUtilSMS(t *testing.T) { func TestUtilSMS(t *testing.T) {
server := httptest.NewServer(reqMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { server := httptest.NewServer(e.reqMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Content-Type", "application/json") rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(200) rw.WriteHeader(200)
defer req.Body.Close() defer req.Body.Close()