81 lines
2.4 KiB
Go
81 lines
2.4 KiB
Go
package app
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"git.farahty.com/nimer/go-mongo/models"
|
|
"github.com/graph-gophers/dataloader"
|
|
"go.mongodb.org/mongo-driver/bson"
|
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
|
)
|
|
|
|
// Loaders holds all the dataloaders used in the application
|
|
type Loaders struct {
|
|
TodosLoader *dataloader.Loader
|
|
UsersLoader *dataloader.Loader
|
|
CategoryLoader *dataloader.Loader
|
|
}
|
|
|
|
// NewLoaders initializes all batch loaders
|
|
func NewLoaders() *Loaders {
|
|
return &Loaders{
|
|
TodosLoader: dataloader.NewBatchedLoader(CreateBatch[models.Todo]("todos")),
|
|
UsersLoader: dataloader.NewBatchedLoader(CreateBatch[models.User]("users")),
|
|
CategoryLoader: dataloader.NewBatchedLoader(CreateBatch[models.Category]("categories")),
|
|
}
|
|
}
|
|
|
|
// CreateBatch creates a generic batched loader for a MongoDB collection
|
|
func CreateBatch[T models.Identifiable](coll string) dataloader.BatchFunc {
|
|
return func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result {
|
|
// Convert all keys to MongoDB ObjectIDs
|
|
var objectIDs []primitive.ObjectID
|
|
keyOrder := make(map[primitive.ObjectID]int) // Track original key order
|
|
|
|
for i, key := range keys {
|
|
id, err := primitive.ObjectIDFromHex(key.String())
|
|
if err != nil {
|
|
continue
|
|
}
|
|
objectIDs = append(objectIDs, id)
|
|
keyOrder[id] = i
|
|
}
|
|
|
|
// Query the collection for the documents
|
|
filter := bson.M{"_id": bson.M{"$in": objectIDs}}
|
|
data, err := Find[T](ctx, coll, filter)
|
|
if err != nil {
|
|
// If the DB fails, return error for all keys
|
|
results := make([]*dataloader.Result, len(keys))
|
|
for i := range keys {
|
|
results[i] = &dataloader.Result{Data: nil, Error: err}
|
|
}
|
|
return results
|
|
}
|
|
|
|
// Build map of result keyed by ID
|
|
objByID := make(map[primitive.ObjectID]T)
|
|
for _, item := range data {
|
|
objByID[(*item).GetID()] = *item // dereference pointer
|
|
}
|
|
|
|
// Assemble results in original key order
|
|
results := make([]*dataloader.Result, len(keys))
|
|
for i, key := range keys {
|
|
id, err := primitive.ObjectIDFromHex(key.String())
|
|
if err != nil {
|
|
results[i] = &dataloader.Result{Data: nil, Error: fmt.Errorf("invalid object ID: %s", key.String())}
|
|
continue
|
|
}
|
|
if val, ok := objByID[id]; ok {
|
|
results[i] = &dataloader.Result{Data: &val, Error: nil}
|
|
} else {
|
|
results[i] = &dataloader.Result{Data: nil, Error: fmt.Errorf("object not found: %s", key.String())}
|
|
}
|
|
}
|
|
|
|
return results
|
|
}
|
|
}
|