package app import ( "context" "fmt" "net/http" "git.farahty.com/nimer/go-mongo/models" "github.com/graph-gophers/dataloader" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" ) // Loaders holds all the dataloaders used in the application type Loaders struct { TodosLoader *dataloader.Loader UsersLoader *dataloader.Loader CategoryLoader *dataloader.Loader } // NewLoaders initializes all batch loaders func NewLoaders() *Loaders { return &Loaders{ TodosLoader: dataloader.NewBatchedLoader(CreateBatch[models.Todo]("todos")), UsersLoader: dataloader.NewBatchedLoader(CreateBatch[models.User]("users")), CategoryLoader: dataloader.NewBatchedLoader(CreateBatch[models.Category]("categories")), } } // Middleware injects dataloaders into context for each HTTP request func LoaderMiddleware(loaders *Loaders, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctxWithLoaders := context.WithValue(r.Context(), LoadersKey, loaders) next.ServeHTTP(w, r.WithContext(ctxWithLoaders)) }) } // LoaderFor retrieves the dataloaders from context func LoaderFor(ctx context.Context) *Loaders { if loaders, ok := ctx.Value(LoadersKey).(*Loaders); ok { return loaders } panic("dataloader not found in context") } // CreateBatch creates a generic batched loader for a MongoDB collection func CreateBatch[T models.Identifiable](coll string) dataloader.BatchFunc { return func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { // Convert all keys to MongoDB ObjectIDs var objectIDs []primitive.ObjectID keyOrder := make(map[primitive.ObjectID]int) // Track original key order for i, key := range keys { id, err := primitive.ObjectIDFromHex(key.String()) if err != nil { continue } objectIDs = append(objectIDs, id) keyOrder[id] = i } // Query the collection for the documents filter := bson.M{"_id": bson.M{"$in": objectIDs}} data, err := Find[T](ctx, coll, filter) if err != nil { // If the DB fails, return error for all keys results := make([]*dataloader.Result, len(keys)) for i := range keys { results[i] = &dataloader.Result{Data: nil, Error: err} } return results } // Build map of result keyed by ID objByID := make(map[primitive.ObjectID]T) for _, item := range data { objByID[(*item).GetID()] = *item // dereference pointer } // Assemble results in original key order results := make([]*dataloader.Result, len(keys)) for i, key := range keys { id, err := primitive.ObjectIDFromHex(key.String()) if err != nil { results[i] = &dataloader.Result{Data: nil, Error: fmt.Errorf("invalid object ID: %s", key.String())} continue } if val, ok := objByID[id]; ok { results[i] = &dataloader.Result{Data: &val, Error: nil} } else { results[i] = &dataloader.Result{Data: nil, Error: fmt.Errorf("object not found: %s", key.String())} } } return results } }