some work on the query/invalidation engine

This commit is contained in:
2026-04-23 02:35:23 -05:00
parent c1c8950f24
commit e623ed1c52
2 changed files with 66 additions and 24 deletions
+41 -7
View File
@@ -1,6 +1,7 @@
package tether package tether
import ( import (
"encoding/json"
"log/slog" "log/slog"
"net/http" "net/http"
@@ -13,6 +14,7 @@ type Engine struct {
dbType string // sqlite or postgres dbType string // sqlite or postgres
mutations map[string]func(ctx *MutationCtx) error mutations map[string]func(ctx *MutationCtx) error
queries map[string]func(ctx *QueryCtx) error queries map[string]func(ctx *QueryCtx) error
dependencies map[string][]string
tracker *reactivity.Tracker tracker *reactivity.Tracker
} }
@@ -22,7 +24,14 @@ func NewEngine(db *gorm.DB, dbType string) *Engine {
if dbType != "sqlite" && dbType != "postgres" { if dbType != "sqlite" && dbType != "postgres" {
panic("Invalid database type") panic("Invalid database type")
} }
return &Engine{db: db, dbType: dbType, mutations: make(map[string]func(ctx *MutationCtx) error), queries: make(map[string]func(ctx *QueryCtx) error), tracker: tracker} e := &Engine{db: db, dbType: dbType, mutations: make(map[string]func(ctx *MutationCtx) error), queries: make(map[string]func(ctx *QueryCtx) error), dependencies: make(map[string][]string), tracker: tracker}
db.Callback().Create().After("gorm:create").Register("tether:after_create", func(tx *gorm.DB) {
if dbType == "postgres" {
return
}
e.InvalidateTable(tx.Statement.Table)
})
return e
} }
func (e *Engine) RegisterMutation(name string, mutation func(ctx *MutationCtx) error) { func (e *Engine) RegisterMutation(name string, mutation func(ctx *MutationCtx) error) {
@@ -32,6 +41,9 @@ func (e *Engine) RegisterMutation(name string, mutation func(ctx *MutationCtx) e
func (e *Engine) RegisterQuery(name string, query func(ctx *QueryCtx) error, dependencies []string) { func (e *Engine) RegisterQuery(name string, query func(ctx *QueryCtx) error, dependencies []string) {
e.queries[name] = query // stores the query in the list of valid queries e.queries[name] = query // stores the query in the list of valid queries
for _, dependency := range dependencies {
e.dependencies[dependency] = append(e.dependencies[dependency], name)
}
slog.Debug("Registered query", "name", name) slog.Debug("Registered query", "name", name)
} }
@@ -56,13 +68,36 @@ func (e *Engine) OnDisconnect(clientID string) error {
return nil return nil
} }
func (e *Engine) ExecuteQuery(query string) (interface{}, error) { func (e *Engine) GetDependentQueries(tableName string) []string {
return e.dependencies[tableName]
}
func (e *Engine) InvalidateTable(tableName string) {
slog.Debug("Invalidating table", "table", tableName)
dependentQueries := e.GetDependentQueries(tableName)
for _, query := range dependentQueries {
slog.Debug("Invalidating query", "query", query)
subscriptions := e.tracker.GetQuerySubscriptions(query)
for _, subscription := range subscriptions {
slog.Debug("Invalidating subscription", "subscription", subscription["clientID"])
params := map[string]interface{}{}
json.Unmarshal([]byte(subscription["params"]), &params)
_, err := e.ExecuteQuery(query, params)
if err != nil {
slog.Error("Failed to execute query", "error", err)
continue
}
}
}
}
func (e *Engine) ExecuteQuery(query string, params map[string]interface{}) (interface{}, error) {
/* /*
TODO: implement the logic to execute the query TODO: implement the logic to execute the query
Steps needed: Steps needed:
1. Check which tables updated 1. Check which tables updated
2. Get the queries that rely on the tables 2. Get the queries that rely on the tables
3. Get the subscriptions that need updating 3. Get the subscriptions that need updating
4. Calculate hash for every query 4. Calculate hash for every query
5. Send the updated queries if hash changed 5. Send the updated queries if hash changed
*/ */
@@ -78,8 +113,7 @@ func (e *Engine) OnReceiveMessage(clientID string, msg map[string]interface{}) e
slog.Debug("Received message", "from", clientID, "message", msg) slog.Debug("Received message", "from", clientID, "message", msg)
switch msg["type"] { switch msg["type"] {
case "query": case "query":
query := msg["location"].(string) + "?" + msg["params"].(string) e.tracker.SubscribeToQuery(clientID, msg["location"].(string), msg["params"].(map[string]string))
e.tracker.SubscribeToQuery(clientID, query)
case "mutation": case "mutation":
e.ExecuteMutation(msg["location"].(string), msg["params"].(map[string]interface{})) e.ExecuteMutation(msg["location"].(string), msg["params"].(map[string]interface{}))
} }
+20 -12
View File
@@ -1,6 +1,7 @@
package reactivity package reactivity
import ( import (
"encoding/json"
"log/slog" "log/slog"
"sync" "sync"
) )
@@ -13,11 +14,11 @@ type Tracker struct {
clients map[string]*Client clients map[string]*Client
// Maps a Query Hash (e.g. "getUser?id=1") to a Set of Client IDs // Maps a Query Hash (e.g. "getUser?id=1") to a Set of Client IDs
subscriptions map[string]map[string]bool subscriptions map[string][]map[string]string
} }
func NewTracker() *Tracker { func NewTracker() *Tracker {
return &Tracker{clients: make(map[string]*Client), subscriptions: make(map[string]map[string]bool)} return &Tracker{clients: make(map[string]*Client), subscriptions: make(map[string][]map[string]string)}
} }
func (t *Tracker) Track(c *Client) { func (t *Tracker) Track(c *Client) {
@@ -32,30 +33,37 @@ func (t *Tracker) Untrack(c *Client) {
delete(t.clients, c.ID) delete(t.clients, c.ID)
} }
func (t *Tracker) SubscribeToQuery(clientID string, query string) { func (t *Tracker) SubscribeToQuery(clientID string, query string, params map[string]string) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
if t.subscriptions[query] == nil { if t.subscriptions[query] == nil {
t.subscriptions[query] = make(map[string]bool) t.subscriptions[query] = make([]map[string]string, 0)
} }
t.subscriptions[query][clientID] = true // set t.subscriptions[query] to a map of client IDs and their params
paramsJSON, err := json.Marshal(params)
if err != nil {
slog.Error("Tracker: Failed to marshal params", "error", err)
return
}
t.subscriptions[query] = append(t.subscriptions[query], map[string]string{"clientID": clientID, "params": string(paramsJSON)})
} }
func (t *Tracker) UnsubscribeFromQuery(clientID string, query string) { func (t *Tracker) UnsubscribeFromQuery(clientID string, query string) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
delete(t.subscriptions[query], clientID) for i, subscription := range t.subscriptions[query] {
if subscription["clientID"] == clientID {
t.subscriptions[query] = append(t.subscriptions[query][:i], t.subscriptions[query][i+1:]...)
break
}
}
} }
func (t *Tracker) GetQuerySubscriptions(query string) []string { func (t *Tracker) GetQuerySubscriptions(query string) []map[string]string {
t.mu.RLock() t.mu.RLock()
defer t.mu.RUnlock() defer t.mu.RUnlock()
subscriptions := t.subscriptions[query] subscriptions := t.subscriptions[query]
subscriptionIDs := make([]string, 0, len(subscriptions)) return subscriptions
for clientID := range subscriptions {
subscriptionIDs = append(subscriptionIDs, clientID)
}
return subscriptionIDs
} }
func (t *Tracker) SendMessage(clientID string, message []byte) { func (t *Tracker) SendMessage(clientID string, message []byte) {