236 lines
7.1 KiB
Go
236 lines
7.1 KiB
Go
|
|
package auth_test
|
||
|
|
|
||
|
|
import (
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"testing"
|
||
|
|
|
||
|
|
"github.com/gin-gonic/gin"
|
||
|
|
|
||
|
|
"dragonrelay/internal/auth"
|
||
|
|
"dragonrelay/internal/db"
|
||
|
|
)
|
||
|
|
|
||
|
|
// makeHandler returns a *gin.Engine for exercising auth.RequireHostAccess in
// isolation. Requests to GET /hosts/:ip/action pass through:
//
//  1. A stub middleware that stores claims in the Gin context under the key
//     "claims" when claims is non-nil, standing in for what auth.Middleware
//     would do after JWT verification.
//  2. auth.RequireHostAccess(store, minLevel).
//  3. A terminal handler that answers 200 {"ok":true} once both have passed.
func makeHandler(store *db.Store, claims *auth.Claims, minLevel string) *gin.Engine {
	injectClaims := func(c *gin.Context) {
		if claims == nil {
			// Anonymous request: leave the context untouched.
			c.Next()
			return
		}
		c.Set("claims", claims)
		c.Next()
	}
	respondOK := func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	}

	engine := gin.New()
	engine.Use(injectClaims)
	engine.GET("/hosts/:ip/action", auth.RequireHostAccess(store, minLevel), respondOK)
	return engine
}
|
||
|
|
|
||
|
|
// doRequest fires a GET against /hosts/<ip>/action on r and returns the HTTP
// status code written by the handler chain.
//
// httptest.NewRequest is used rather than http.NewRequest so that an ip that
// produces an unparseable target fails loudly (NewRequest panics with the
// parse error). The previous form discarded that error and would have handed
// a nil *http.Request to ServeHTTP.
func doRequest(r *gin.Engine, ip string) int {
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/hosts/"+ip+"/action", nil)
	r.ServeHTTP(rec, req)
	return rec.Code
}
|
||
|
|
|
||
|
|
// ---------------------------------------------------------------------------
|
||
|
|
// TestLevelRank
|
||
|
|
// ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
func TestLevelRank(t *testing.T) {
|
||
|
|
cases := []struct {
|
||
|
|
level string
|
||
|
|
want int
|
||
|
|
}{
|
||
|
|
{"view", 1},
|
||
|
|
{"connect", 2},
|
||
|
|
{"manage", 3},
|
||
|
|
{"unknown", 0},
|
||
|
|
{"", 0},
|
||
|
|
{"ADMIN", 0},
|
||
|
|
}
|
||
|
|
for _, tc := range cases {
|
||
|
|
got := auth.LevelRank(tc.level)
|
||
|
|
if got != tc.want {
|
||
|
|
t.Errorf("LevelRank(%q) = %d, want %d", tc.level, got, tc.want)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Ordinal assertions: view < connect < manage
|
||
|
|
if !(auth.LevelRank("view") < auth.LevelRank("connect")) {
|
||
|
|
t.Error("expected LevelRank(view) < LevelRank(connect)")
|
||
|
|
}
|
||
|
|
if !(auth.LevelRank("connect") < auth.LevelRank("manage")) {
|
||
|
|
t.Error("expected LevelRank(connect) < LevelRank(manage)")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// ---------------------------------------------------------------------------
// TestRequireHostAccess_AdminBypasses
//
// A user whose claims carry the "admin" role must be let through at the
// highest minLevel ("manage") even though the store holds no ACL row for the
// requested host: admins bypass the per-host checks entirely.
// ---------------------------------------------------------------------------

func TestRequireHostAccess_AdminBypasses(t *testing.T) {
	gin.SetMode(gin.TestMode)

	st, err := db.Open(":memory:")
	if err != nil {
		t.Fatalf("db.Open: %v", err)
	}

	// Nothing is seeded, so there is no ACL row for 192.168.1.1.
	engine := makeHandler(st, &auth.Claims{Username: "alice", Role: "admin"}, "manage")

	got := doRequest(engine, "192.168.1.1")
	if got != http.StatusOK {
		t.Errorf("admin bypass: got %d, want 200", got)
	}
}
|
||
|
|
|
||
|
|
// ---------------------------------------------------------------------------
|
||
|
|
// TestRequireHostAccess_MemberWithSufficientGrant
|
||
|
|
//
|
||
|
|
// A regular member whose ACL row grants "connect" access should pass when
|
||
|
|
// minLevel is "connect".
|
||
|
|
// ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
func TestRequireHostAccess_MemberWithSufficientGrant(t *testing.T) {
|
||
|
|
gin.SetMode(gin.TestMode)
|
||
|
|
|
||
|
|
store, err := db.Open(":memory:")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("db.Open: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Seed: user "bob" with a "connect" grant on 192.168.1.2.
|
||
|
|
if err := seedUserWithGrant(store, "bob", "member", "192.168.1.2", "connect"); err != nil {
|
||
|
|
t.Fatalf("seed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
claims := &auth.Claims{Username: "bob", Role: "member"}
|
||
|
|
r := makeHandler(store, claims, "connect")
|
||
|
|
|
||
|
|
if code := doRequest(r, "192.168.1.2"); code != http.StatusOK {
|
||
|
|
t.Errorf("sufficient grant: got %d, want 200", code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// ---------------------------------------------------------------------------
|
||
|
|
// TestRequireHostAccess_MemberWithInsufficientGrant
|
||
|
|
//
|
||
|
|
// A member whose ACL row grants only "view" should be rejected when
|
||
|
|
// minLevel is "connect".
|
||
|
|
// ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
func TestRequireHostAccess_MemberWithInsufficientGrant(t *testing.T) {
|
||
|
|
gin.SetMode(gin.TestMode)
|
||
|
|
|
||
|
|
store, err := db.Open(":memory:")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("db.Open: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := seedUserWithGrant(store, "carol", "member", "192.168.1.3", "view"); err != nil {
|
||
|
|
t.Fatalf("seed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
claims := &auth.Claims{Username: "carol", Role: "member"}
|
||
|
|
r := makeHandler(store, claims, "connect")
|
||
|
|
|
||
|
|
if code := doRequest(r, "192.168.1.3"); code != http.StatusForbidden {
|
||
|
|
t.Errorf("insufficient grant: got %d, want 403", code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// ---------------------------------------------------------------------------
// TestRequireHostAccess_MemberNoGrant
//
// A member who exists in the store but has no ACL row for the requested host
// must be refused (403) even at the lowest minLevel, "view".
// ---------------------------------------------------------------------------

func TestRequireHostAccess_MemberNoGrant(t *testing.T) {
	gin.SetMode(gin.TestMode)

	st, err := db.Open(":memory:")
	if err != nil {
		t.Fatalf("db.Open: %v", err)
	}

	// dan gets a user record only — deliberately no grant on any host.
	if err = seedUser(st, "dan", "member"); err != nil {
		t.Fatalf("seed: %v", err)
	}

	engine := makeHandler(st, &auth.Claims{Username: "dan", Role: "member"}, "view")

	got := doRequest(engine, "10.0.0.1")
	if got != http.StatusForbidden {
		t.Errorf("no grant: got %d, want 403", got)
	}
}
|
||
|
|
|
||
|
|
// ---------------------------------------------------------------------------
|
||
|
|
// TestRequireHostAccess_ViewerCappedAtView
|
||
|
|
//
|
||
|
|
// Viewers can never exceed the "view" level, regardless of what their ACL row
|
||
|
|
// says. Two sub-cases:
|
||
|
|
// - minLevel "connect" → 403 (capped before ACL lookup)
|
||
|
|
// - minLevel "view" → 200 (viewer passes when the bar is low enough)
|
||
|
|
// ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
func TestRequireHostAccess_ViewerCappedAtView(t *testing.T) {
|
||
|
|
gin.SetMode(gin.TestMode)
|
||
|
|
|
||
|
|
store, err := db.Open(":memory:")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("db.Open: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Give eve a "connect" ACL row — the role cap should still block her at
|
||
|
|
// minLevel "connect".
|
||
|
|
if err := seedUserWithGrant(store, "eve", "viewer", "192.168.1.5", "connect"); err != nil {
|
||
|
|
t.Fatalf("seed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
claims := &auth.Claims{Username: "eve", Role: "viewer"}
|
||
|
|
|
||
|
|
t.Run("minLevel_connect_blocked", func(t *testing.T) {
|
||
|
|
r := makeHandler(store, claims, "connect")
|
||
|
|
if code := doRequest(r, "192.168.1.5"); code != http.StatusForbidden {
|
||
|
|
t.Errorf("viewer cap (connect): got %d, want 403", code)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
|
||
|
|
t.Run("minLevel_view_allowed", func(t *testing.T) {
|
||
|
|
r := makeHandler(store, claims, "view")
|
||
|
|
if code := doRequest(r, "192.168.1.5"); code != http.StatusOK {
|
||
|
|
t.Errorf("viewer at view level: got %d, want 200", code)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
// ---------------------------------------------------------------------------
|
||
|
|
// Seed helpers
|
||
|
|
// ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
// seedUser creates a user record in the in-memory store and returns any error.
|
||
|
|
func seedUser(store *db.Store, username, role string) error {
|
||
|
|
_, err := store.CreateUser(username, role)
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
// seedUserWithGrant creates a user and an associated host_acl row.
|
||
|
|
func seedUserWithGrant(store *db.Store, username, role, hostIP, level string) error {
|
||
|
|
user, err := store.CreateUser(username, role)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
return store.SetHostAccess(user.ID, hostIP, level)
|
||
|
|
}
|