go-mongo/app/loaders.go

81 lines
2.4 KiB
Go

package app
import (
"context"
"fmt"
"git.farahty.com/nimer/go-mongo/models"
"github.com/graph-gophers/dataloader"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// Loaders holds all the dataloaders used in the application.
//
// Each field is a batched *dataloader.Loader built by NewLoaders on top of
// CreateBatch. Keys passed to these loaders are expected to be hex-encoded
// MongoDB ObjectIDs; values resolve to pointers to the corresponding model.
type Loaders struct {
// TodosLoader loads models.Todo documents from the "todos" collection.
TodosLoader *dataloader.Loader
// UsersLoader loads models.User documents from the "users" collection.
UsersLoader *dataloader.Loader
// CategoryLoader loads models.Category documents from the "categories" collection.
CategoryLoader *dataloader.Loader
}
// NewLoaders builds a fresh Loaders value with one batched dataloader per
// MongoDB collection ("todos", "users", "categories"), each backed by the
// generic CreateBatch function for the matching model type.
//
// NOTE(review): the dataloader package caches results per Loader by default;
// presumably a new Loaders is created per request so cached documents do not
// go stale — confirm at the call site.
func NewLoaders() *Loaders {
	todoBatch := CreateBatch[models.Todo]("todos")
	userBatch := CreateBatch[models.User]("users")
	categoryBatch := CreateBatch[models.Category]("categories")

	l := new(Loaders)
	l.TodosLoader = dataloader.NewBatchedLoader(todoBatch)
	l.UsersLoader = dataloader.NewBatchedLoader(userBatch)
	l.CategoryLoader = dataloader.NewBatchedLoader(categoryBatch)
	return l
}
// CreateBatch returns a dataloader.BatchFunc that loads documents of type T
// from the MongoDB collection coll, keyed by hex-encoded ObjectIDs.
//
// All valid keys in a batch are fetched with a single Find using an
// {"_id": {"$in": [...]}} filter. The returned slice has exactly one
// *dataloader.Result per key, in the same order as keys:
//   - a key that is not a valid hex ObjectID yields an "invalid object ID" error;
//   - a valid key with no matching document yields an "object not found" error;
//   - if the query itself fails, every valid key's result carries that error;
//   - otherwise Data holds a *T pointing at a copy of the matching document.
func CreateBatch[T models.Identifiable](coll string) dataloader.BatchFunc {
	return func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result {
		// A nil entry in results means "valid key, not yet resolved".
		results := make([]*dataloader.Result, len(keys))

		// Parse every key exactly once; ids[i] is meaningful only while results[i] is nil.
		ids := make([]primitive.ObjectID, len(keys))
		objectIDs := make([]primitive.ObjectID, 0, len(keys))
		for i, key := range keys {
			id, err := primitive.ObjectIDFromHex(key.String())
			if err != nil {
				results[i] = &dataloader.Result{Data: nil, Error: fmt.Errorf("invalid object ID: %s", key.String())}
				continue
			}
			ids[i] = id
			objectIDs = append(objectIDs, id)
		}

		// Nothing to look up (empty batch or all keys malformed): skip the round
		// trip instead of querying with an empty/null $in, which MongoDB rejects
		// when the operand is not an array.
		if len(objectIDs) == 0 {
			return results
		}

		filter := bson.M{"_id": bson.M{"$in": objectIDs}}
		data, err := Find[T](ctx, coll, filter)
		if err != nil {
			// The DB failed: report the error for every key that parsed correctly;
			// malformed keys keep their more precise "invalid object ID" error.
			for i := range results {
				if results[i] == nil {
					results[i] = &dataloader.Result{Data: nil, Error: err}
				}
			}
			return results
		}

		objByID := make(map[primitive.ObjectID]T, len(data))
		for _, item := range data {
			// Explicit dereference: methods from the constraint are on T, not *T.
			objByID[(*item).GetID()] = *item
		}

		for i, id := range ids {
			if results[i] != nil {
				continue // malformed key, already resolved above
			}
			if val, ok := objByID[id]; ok {
				// val is scoped to this if statement, so each result gets its own copy.
				results[i] = &dataloader.Result{Data: &val, Error: nil}
			} else {
				results[i] = &dataloader.Result{Data: nil, Error: fmt.Errorf("object not found: %s", keys[i].String())}
			}
		}
		return results
	}
}