
Go: How to mock SCS session authentication in tests

February 1, 2024


Introduction

While working on a personal project using Go and alexedwards/scs for handling sessions, I came across an issue when writing tests: after setting up the configuration and the LoadAndSave middleware, my tests started failing with a 401 (Unauthorized) response.

In this quick post, we'll build a small Go server using the go-chi router and implement a LoadAndSaveMock middleware that injects any session information we need into our tests.

Or, if you prefer, skip straight to the LoadAndSaveMock middleware section. :)

Server config

First, let's create a Server struct to hold our session manager and router:

main.go
```go
type Server struct {
	session *scs.SessionManager
	router  *chi.Mux
}
```

Now, let's create our main function, which sets up the session manager, the router, and the server:

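main.go
```go
package main

import (
	"github.com/alexedwards/scs/v2"
	"github.com/go-chi/chi/v5"
)

type Server struct {
	session *scs.SessionManager
	router  *chi.Mux
}

func main() {
	session := scs.New()
	router := chi.NewRouter()
	server := &Server{
		session: session,
		router:  router,
	}

	router.Use(session.LoadAndSave)

	// server is used once the routes are registered; this keeps the
	// snippet compiling until then.
	_ = server
}
```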

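Note that we register the session manager's LoadAndSave middleware. It reads the session token from the request cookie, loads the matching session data into the request context and, once the handler has run, saves any changes and writes the session cookie to the response.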

Server routes

We can now define two routes:

  • / - Will return 200 (OK) with the message "Hello, World!" and store a user_role value in the session
  • /admin - Will return 200 (OK) with the message "Welcome, <role>!"

If we request the /admin endpoint before requesting the root endpoint, we should get a 401 (Unauthorized) status code, because no role has been stored in the session yet.

Let's implement both handlers:

main.go
```go
func (s *Server) handleHelloWorld(w http.ResponseWriter, r *http.Request) {
	// Save "ADMIN" value into "user_role"
	s.session.Put(r.Context(), "user_role", "ADMIN")

	// Send the message (By default, will send a 200 status code)
	w.Write([]byte("Hello, World!"))
}

func (s *Server) handleAdmin(w http.ResponseWriter, r *http.Request) {
	// Get the "user_role" value from request context
	role := s.session.GetString(r.Context(), "user_role")

	// By default, GetString will return an empty string if no value is found
	if role == "" {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}

	// Send the message (By default, will send a 200 status code)
	w.Write([]byte("Welcome, " + role + "!"))
}
```

Registering the handlers on the router:

main.go
```go
package main

import (
	"log"
	"net/http"

	"github.com/alexedwards/scs/v2"
	"github.com/go-chi/chi/v5"
)

type Server struct {
	session *scs.SessionManager
	router  *chi.Mux
}

func main() {
	session := scs.New()
	router := chi.NewRouter()
	server := &Server{
		session: session,
		router:  router,
	}

	router.Use(session.LoadAndSave)

	// Routes
	router.Get("/", server.handleHelloWorld)
	router.Group(func(router chi.Router) {
		router.Get("/admin", server.handleAdmin)
	})

	// ListenAndServe only returns when the server fails, so log that error.
	log.Fatal(http.ListenAndServe(":6987", router))
}

func (s *Server) handleHelloWorld(w http.ResponseWriter, r *http.Request) {
	// Save "ADMIN" value into "user_role"
	s.session.Put(r.Context(), "user_role", "ADMIN")

	// Send the message (By default, will send a 200 status code)
	w.Write([]byte("Hello, World!"))
}

func (s *Server) handleAdmin(w http.ResponseWriter, r *http.Request) {
	// Get the "user_role" value from request context
	role := s.session.GetString(r.Context(), "user_role")

	// By default, GetString will return an empty string if no value is found
	if role == "" {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}

	// Send the message (By default, will send a 200 status code)
	w.Write([]byte("Welcome, " + role + "!"))
}
```

Middleware for user authorization

To complete our simple Go server, we need a middleware that handles user authorization:

main.go
```go
func (s *Server) RequireAdmin(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Get "user_role" value and validate
		if s.session.GetString(r.Context(), "user_role") != "ADMIN" {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		next.ServeHTTP(w, r)
	})
}
```
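RequireAdmin reads the role straight from SCS, so exercising it needs a loaded session. One possible alternative design, shown only as a sketch, is to pass the role lookup in as a function so the check can be unit-tested with a stub. `RequireRole` and `statusWithRole` are hypothetical names, and the block uses only the standard library.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// RequireRole (hypothetical) generalizes RequireAdmin: instead of reading
// SCS directly, it asks roleOf for the caller's role.
func RequireRole(roleOf func(*http.Request) string, want string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if roleOf(r) != want {
				w.WriteHeader(http.StatusUnauthorized)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// statusWithRole runs one request through RequireRole using a stub lookup
// that always reports the given role, and returns the status code.
func statusWithRole(role string) int {
	stub := func(*http.Request) string { return role }
	ok := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("ok"))
	})

	rr := httptest.NewRecorder()
	RequireRole(stub, "ADMIN")(ok).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/admin", nil))
	return rr.Code
}

func main() {
	for _, role := range []string{"", "OTHER", "ADMIN"} {
		fmt.Printf("%q -> %d\n", role, statusWithRole(role))
	}
}
```

In the server, the lookup could be `func(r *http.Request) string { return s.session.GetString(r.Context(), "user_role") }`. This post keeps RequireAdmin unchanged and mocks the session instead.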

We'll apply this middleware inside the router.Group, so it only protects the /admin route:

main.go
```go
// Routes
router.Get("/", server.handleHelloWorld)
router.Group(func(router chi.Router) {
	router.Use(server.RequireAdmin)
	router.Get("/admin", server.handleAdmin)
})
```

Testing

Now, the testing part.

Our test will cover four cases:

  • Public endpoint that will return 200 (OK)
  • Admin endpoint without role that will return 401 (Unauthorized)
  • Admin endpoint with wrong role that will return 401 (Unauthorized)
  • Admin endpoint with correct role that will return 200 (OK)

main_test.go
```go
func Test_main(t *testing.T) {
	tests := []struct {
		name     string
		url      string
		role     string
		wantCode int
	}{
		{
			name:     "public endpoint",
			url:      "/",
			wantCode: http.StatusOK,
		},
		{
			name:     "admin endpoint without role",
			url:      "/admin",
			wantCode: http.StatusUnauthorized,
		},
		{
			name:     "admin endpoint with wrong role",
			url:      "/admin",
			role:     "OTHER",
			wantCode: http.StatusUnauthorized,
		},
		{
			name:     "admin endpoint with correct role",
			url:      "/admin",
			role:     "ADMIN",
			wantCode: http.StatusOK,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// TODO: Add test logic here
		})
	}
}
```

Implementing the first part of the test logic (server configuration):

main_test.go
```go
// Set up
session := scs.New()
router := chi.NewRouter()
server := &Server{
	session: session,
	router:  router,
}

// TODO: Implement session middleware

// Routes
router.Get("/", server.handleHelloWorld)
router.Group(func(router chi.Router) {
	router.Use(server.RequireAdmin)
	router.Get("/admin", server.handleAdmin)
})
```

Second part of the test logic (HTTP request and status validation):

main_test.go
```go
// Request
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", tt.url, nil)
router.ServeHTTP(rr, req)

// Result
res := rr.Result()

if tt.wantCode != res.StatusCode {
	t.Errorf("Expected response code %d. Got %d\n", tt.wantCode, res.StatusCode)
}
```

LoadAndSaveMock middleware

In this section, we'll create a middleware that wraps the LoadAndSave middleware from the SCS package. Once LoadAndSave has loaded the session into the request context, this mock injects any key/value pair we want into it, before the request reaches our routes:

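main_test.go
```go
// LoadAndSaveMock returns a middleware that behaves like session.LoadAndSave,
// but stores key/value in the session before the next handler runs.
func LoadAndSaveMock(session *scs.SessionManager, key, value string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		// LoadAndSave loads the session into the request context first...
		return session.LoadAndSave(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// ...so the injected value can be stored in it here.
			session.Put(r.Context(), key, value)
			next.ServeHTTP(w, r)
		}))
	}
}
```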

Putting it all together

The final test code:

main_test.go
```go
package main

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/alexedwards/scs/v2"
	"github.com/go-chi/chi/v5"
)

// LoadAndSaveMock wraps scs's LoadAndSave and stores key/value in the
// session before passing the request on.
func LoadAndSaveMock(session *scs.SessionManager, key, value string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return session.LoadAndSave(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			session.Put(r.Context(), key, value)
			next.ServeHTTP(w, r)
		}))
	}
}

func Test_main(t *testing.T) {
	tests := []struct {
		name     string
		url      string
		role     string
		wantCode int
	}{
		{
			name:     "public endpoint",
			url:      "/",
			wantCode: http.StatusOK,
		},
		{
			name:     "admin endpoint without role",
			url:      "/admin",
			wantCode: http.StatusUnauthorized,
		},
		{
			name:     "admin endpoint with wrong role",
			url:      "/admin",
			role:     "OTHER",
			wantCode: http.StatusUnauthorized,
		},
		{
			name:     "admin endpoint with correct role",
			url:      "/admin",
			role:     "ADMIN",
			wantCode: http.StatusOK,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Set up
			session := scs.New()
			router := chi.NewRouter()
			server := &Server{
				session: session,
				router:  router,
			}

			// Session middleware
			router.Use(LoadAndSaveMock(session, "user_role", tt.role))

			// Routes
			router.Get("/", server.handleHelloWorld)
			router.Group(func(router chi.Router) {
				router.Use(server.RequireAdmin)
				router.Get("/admin", server.handleAdmin)
			})

			// Request
			rr := httptest.NewRecorder()
			req := httptest.NewRequest("GET", tt.url, nil)
			router.ServeHTTP(rr, req)

			// Result
			res := rr.Result()

			if tt.wantCode != res.StatusCode {
				t.Errorf("Expected response code %d. Got %d\n", tt.wantCode, res.StatusCode)
			}
		})
	}
}
```

Running the tests, we get this result:

```
go test -v ./...
=== RUN   Test_main
=== RUN   Test_main/public_endpoint
--- PASS: Test_main/public_endpoint (0.00s)
=== RUN   Test_main/admin_endpoint_without_role
--- PASS: Test_main/admin_endpoint_without_role (0.00s)
=== RUN   Test_main/admin_endpoint_with_wrong_role
--- PASS: Test_main/admin_endpoint_with_wrong_role (0.00s)
=== RUN   Test_main/admin_endpoint_with_correct_role
--- PASS: Test_main/admin_endpoint_with_correct_role (0.00s)
--- PASS: Test_main (0.00s)
PASS
ok  github.com/LucJosin/go-scs-test  0.002s
```

Conclusion

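In this post, I showed a solution to a problem I ran into while writing authentication tests with the SCS session package: wrapping LoadAndSave in a LoadAndSaveMock middleware lets each test inject whatever session value it needs.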


© 2023 Lucas Josino All Rights Reserved.