diff --git a/context.go b/context.go index fb7960d..ce299fb 100644 --- a/context.go +++ b/context.go @@ -18,3 +18,7 @@ type MutationCtx struct { AuthCtx *AuthCtx Params map[string]interface{} } + +type Auth interface { + GetUserID(token string) (string, error) +} diff --git a/engine.go b/engine.go index 324d7da..480674f 100644 --- a/engine.go +++ b/engine.go @@ -20,15 +20,20 @@ type Engine struct { hashMu sync.RWMutex queryHashes map[string]uint64 tracker *reactivity.Tracker + auth Auth } +type defaultAuth struct{} + +func (defaultAuth) GetUserID(_ string) (string, error) { return "", nil } + func NewEngine(db *gorm.DB, dbType string) *Engine { slog.SetLogLoggerLevel(slog.LevelDebug) tracker := reactivity.NewTracker() if dbType != "sqlite" && dbType != "postgres" { panic("Invalid database type") } - e := &Engine{db: db, dbType: dbType, mutations: make(map[string]func(ctx *MutationCtx) interface{}), queries: make(map[string]func(ctx *QueryCtx) interface{}), dependencies: make(map[string][]string), queryHashes: make(map[string]uint64), tracker: tracker} + e := &Engine{db: db, dbType: dbType, mutations: make(map[string]func(ctx *MutationCtx) interface{}), queries: make(map[string]func(ctx *QueryCtx) interface{}), dependencies: make(map[string][]string), queryHashes: make(map[string]uint64), tracker: tracker, auth: defaultAuth{}} db.Callback().Create().After("gorm:create").Register("tether:after_create", func(tx *gorm.DB) { if dbType == "postgres" { return @@ -38,6 +43,10 @@ func NewEngine(db *gorm.DB, dbType string) *Engine { return e } +func (e *Engine) SetAuth(auth Auth) { + e.auth = auth +} + func (e *Engine) RegisterMutation(name string, mutation func(ctx *MutationCtx) interface{}) { e.mutations[name] = mutation // stores the mutation in the list of valid mutations slog.Debug("Registered mutation", "name", name) @@ -200,6 +209,13 @@ func (e *Engine) OnReceiveMessage(clientID string, msg map[string]interface{}) e e.ExecuteQuery(msg["location"].(string), msg["params"].(map[string]interface{}), clientID, true) case "mutation": e.ExecuteMutation(msg["location"].(string), msg["params"].(map[string]interface{}), clientID, msg["mutation_id"].(string)) + case "auth": + userID, err := e.auth.GetUserID(msg["token"].(string)) + if err != nil { + slog.Error("Failed to get user ID", "error", err) + return err + } + e.tracker.SetAuthID(clientID, userID) } return nil }