diff --git a/engine.go b/engine.go index 7b1b148..1cdbbfd 100644 --- a/engine.go +++ b/engine.go @@ -1,6 +1,7 @@ package tether import ( + "encoding/json" "log/slog" "net/http" @@ -9,11 +10,12 @@ import ( ) type Engine struct { - db *gorm.DB - dbType string // sqlite or postgres - mutations map[string]func(ctx *MutationCtx) error - queries map[string]func(ctx *QueryCtx) error - tracker *reactivity.Tracker + db *gorm.DB + dbType string // sqlite or postgres + mutations map[string]func(ctx *MutationCtx) error + queries map[string]func(ctx *QueryCtx) error + dependencies map[string][]string + tracker *reactivity.Tracker } func NewEngine(db *gorm.DB, dbType string) *Engine { @@ -22,7 +24,14 @@ func NewEngine(db *gorm.DB, dbType string) *Engine { if dbType != "sqlite" && dbType != "postgres" { panic("Invalid database type") } - return &Engine{db: db, dbType: dbType, mutations: make(map[string]func(ctx *MutationCtx) error), queries: make(map[string]func(ctx *QueryCtx) error), tracker: tracker} + e := &Engine{db: db, dbType: dbType, mutations: make(map[string]func(ctx *MutationCtx) error), queries: make(map[string]func(ctx *QueryCtx) error), dependencies: make(map[string][]string), tracker: tracker} + db.Callback().Create().After("gorm:create").Register("tether:after_create", func(tx *gorm.DB) { + if dbType == "postgres" { + return + } + e.InvalidateTable(tx.Statement.Table) + }) + return e } func (e *Engine) RegisterMutation(name string, mutation func(ctx *MutationCtx) error) { @@ -32,6 +41,9 @@ func (e *Engine) RegisterMutation(name string, mutation func(ctx *MutationCtx) e func (e *Engine) RegisterQuery(name string, query func(ctx *QueryCtx) error, dependencies []string) { e.queries[name] = query // stores the query in the list of valid queries + for _, dependency := range dependencies { + e.dependencies[dependency] = append(e.dependencies[dependency], name) + } slog.Debug("Registered query", "name", name) } @@ -56,13 +68,36 @@ func (e *Engine) OnDisconnect(clientID string) error { return nil } -func (e *Engine) ExecuteQuery(query string) (interface{}, error) { +func (e *Engine) GetDependentQueries(tableName string) []string { + return e.dependencies[tableName] +} + +func (e *Engine) InvalidateTable(tableName string) { + slog.Debug("Invalidating table", "table", tableName) + dependentQueries := e.GetDependentQueries(tableName) + for _, query := range dependentQueries { + slog.Debug("Invalidating query", "query", query) + subscriptions := e.tracker.GetQuerySubscriptions(query) + for _, subscription := range subscriptions { + slog.Debug("Invalidating subscription", "subscription", subscription["clientID"]) + params := map[string]interface{}{} + json.Unmarshal([]byte(subscription["params"]), ¶ms) + _, err := e.ExecuteQuery(query, params) + if err != nil { + slog.Error("Failed to execute query", "error", err) + continue + } + } + } +} + +func (e *Engine) ExecuteQuery(query string, params map[string]interface{}) (interface{}, error) { /* TODO: implement the logic to execute the query Steps needed: - 1. Check which tables updated - 2. Get the queries that rely on the tables - 3. Get the subscriptions that need updating + 1. Check which tables updated ✅ + 2. Get the queries that rely on the tables ✅ + 3. Get the subscriptions that need updating ✅ 4. Calculate hash for every query 5. Send the updated queries if hash changed */ @@ -78,8 +113,7 @@ func (e *Engine) OnReceiveMessage(clientID string, msg map[string]interface{}) e slog.Debug("Received message", "from", clientID, "message", msg) switch msg["type"] { case "query": - query := msg["location"].(string) + "?" + msg["params"].(string) - e.tracker.SubscribeToQuery(clientID, query) + e.tracker.SubscribeToQuery(clientID, msg["location"].(string), msg["params"].(map[string]string)) case "mutation": e.ExecuteMutation(msg["location"].(string), msg["params"].(map[string]interface{})) } diff --git a/reactivity/tracker.go b/reactivity/tracker.go index e5907d0..658fcca 100644 --- a/reactivity/tracker.go +++ b/reactivity/tracker.go @@ -1,6 +1,7 @@ package reactivity import ( + "encoding/json" "log/slog" "sync" ) @@ -13,11 +14,11 @@ type Tracker struct { clients map[string]*Client // Maps a Query Hash (e.g. "getUser?id=1") to a Set of Client IDs - subscriptions map[string]map[string]bool + subscriptions map[string][]map[string]string } func NewTracker() *Tracker { - return &Tracker{clients: make(map[string]*Client), subscriptions: make(map[string]map[string]bool)} + return &Tracker{clients: make(map[string]*Client), subscriptions: make(map[string][]map[string]string)} } func (t *Tracker) Track(c *Client) { @@ -32,30 +33,37 @@ func (t *Tracker) Untrack(c *Client) { delete(t.clients, c.ID) } -func (t *Tracker) SubscribeToQuery(clientID string, query string) { +func (t *Tracker) SubscribeToQuery(clientID string, query string, params map[string]string) { t.mu.Lock() defer t.mu.Unlock() if t.subscriptions[query] == nil { - t.subscriptions[query] = make(map[string]bool) + t.subscriptions[query] = make([]map[string]string, 0) } - t.subscriptions[query][clientID] = true + // set t.subscriptions[query] to a map of client IDs and their params + paramsJSON, err := json.Marshal(params) + if err != nil { + slog.Error("Tracker: Failed to marshal params", "error", err) + return + } + t.subscriptions[query] = append(t.subscriptions[query], map[string]string{"clientID": clientID, "params": string(paramsJSON)}) } func (t *Tracker) UnsubscribeFromQuery(clientID string, query string) { t.mu.Lock() defer t.mu.Unlock() - delete(t.subscriptions[query], clientID) + for i, subscription := range t.subscriptions[query] { + if subscription["clientID"] == clientID { + t.subscriptions[query] = append(t.subscriptions[query][:i], t.subscriptions[query][i+1:]...) + break + } + } } -func (t *Tracker) GetQuerySubscriptions(query string) []string { +func (t *Tracker) GetQuerySubscriptions(query string) []map[string]string { t.mu.RLock() defer t.mu.RUnlock() subscriptions := t.subscriptions[query] - subscriptionIDs := make([]string, 0, len(subscriptions)) - for clientID := range subscriptions { - subscriptionIDs = append(subscriptionIDs, clientID) - } - return subscriptionIDs + return subscriptions } func (t *Tracker) SendMessage(clientID string, message []byte) {