104 lines
3.0 KiB
Go
104 lines
3.0 KiB
Go
package app
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"reflect"
|
|
|
|
"git.farahty.com/nimer/go-mongo/models"
|
|
"github.com/graph-gophers/dataloader"
|
|
"go.mongodb.org/mongo-driver/bson"
|
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
|
)
|
|
|
|
// Loaders holds all the dataloaders used in the application
|
|
type Loaders struct {
|
|
TodosLoader *dataloader.Loader
|
|
UsersLoader *dataloader.Loader
|
|
CategoryLoader *dataloader.Loader
|
|
}
|
|
|
|
// NewLoaders initializes all batch loaders
|
|
func NewLoaders() *Loaders {
|
|
return &Loaders{
|
|
TodosLoader: dataloader.NewBatchedLoader(CreateBatch[models.Todo]("todos")),
|
|
UsersLoader: dataloader.NewBatchedLoader(CreateBatch[models.User]("users")),
|
|
CategoryLoader: dataloader.NewBatchedLoader(CreateBatch[models.Category]("categories")),
|
|
}
|
|
}
|
|
|
|
// Middleware injects dataloaders into context for each HTTP request
|
|
func LoaderMiddleware(loaders *Loaders, next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctxWithLoaders := context.WithValue(r.Context(), LoadersKey, loaders)
|
|
next.ServeHTTP(w, r.WithContext(ctxWithLoaders))
|
|
})
|
|
}
|
|
|
|
// LoaderFor retrieves the dataloaders from context
|
|
func LoaderFor(ctx context.Context) *Loaders {
|
|
if loaders, ok := ctx.Value(LoadersKey).(*Loaders); ok {
|
|
return loaders
|
|
}
|
|
panic("dataloader not found in context")
|
|
}
|
|
|
|
// CreateBatch creates a generic batched loader for a MongoDB collection
|
|
func CreateBatch[T any](coll string) dataloader.BatchFunc {
|
|
return func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result {
|
|
// Convert all keys to MongoDB ObjectIDs
|
|
var objectIDs []primitive.ObjectID
|
|
keyOrder := make(map[primitive.ObjectID]int) // Track original key order
|
|
|
|
for i, key := range keys {
|
|
id, err := primitive.ObjectIDFromHex(key.String())
|
|
if err != nil {
|
|
continue
|
|
}
|
|
objectIDs = append(objectIDs, id)
|
|
keyOrder[id] = i
|
|
}
|
|
|
|
// Query the collection for the documents
|
|
filter := bson.M{"_id": bson.M{"$in": objectIDs}}
|
|
data, err := Find[T](ctx, coll, filter)
|
|
if err != nil {
|
|
// If the DB fails, return error for all keys
|
|
results := make([]*dataloader.Result, len(keys))
|
|
for i := range keys {
|
|
results[i] = &dataloader.Result{Data: nil, Error: err}
|
|
}
|
|
return results
|
|
}
|
|
|
|
// Build map of result keyed by ID
|
|
objByID := make(map[primitive.ObjectID]T)
|
|
for _, item := range data {
|
|
val := reflect.ValueOf(item).Elem()
|
|
idValue := reflect.Indirect(val).FieldByName("ID").Interface()
|
|
|
|
if id, ok := idValue.(primitive.ObjectID); ok {
|
|
objByID[id] = *item
|
|
}
|
|
}
|
|
|
|
// Assemble results in original key order
|
|
results := make([]*dataloader.Result, len(keys))
|
|
for i, key := range keys {
|
|
id, err := primitive.ObjectIDFromHex(key.String())
|
|
if err != nil {
|
|
results[i] = &dataloader.Result{Data: nil, Error: fmt.Errorf("invalid object ID: %s", key.String())}
|
|
continue
|
|
}
|
|
if val, ok := objByID[id]; ok {
|
|
results[i] = &dataloader.Result{Data: &val, Error: nil}
|
|
} else {
|
|
results[i] = &dataloader.Result{Data: nil, Error: fmt.Errorf("object not found: %s", key.String())}
|
|
}
|
|
}
|
|
|
|
return results
|
|
}
|
|
}
|