go-mongo/app/loaders.go
2025-05-31 18:38:25 +03:00

104 lines
3.0 KiB
Go

package app
import (
"context"
"fmt"
"net/http"
"reflect"
"git.farahty.com/nimer/go-mongo/models"
"github.com/graph-gophers/dataloader"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// Loaders bundles the per-collection batched loaders.
//
// TodosLoader, UsersLoader and CategoryLoader load documents from the
// "todos", "users" and "categories" collections respectively (see
// NewLoaders). Values of this type are placed in a request context by
// LoaderMiddleware and retrieved with LoaderFor.
type Loaders struct {
	TodosLoader, UsersLoader, CategoryLoader *dataloader.Loader
}
// NewLoaders builds one batched loader per MongoDB collection: "todos",
// "users" and "categories", each backed by CreateBatch for the matching
// model type.
//
// Every loader is created with dataloader.WithClearCacheOnBatch. Without any
// option, dataloader.NewBatchedLoader keeps an in-memory cache that is never
// evicted, and LoaderMiddleware attaches the *Loaders it was given to every
// request, so a single NewLoaders() result is effectively shared process-wide.
// That would serve documents from the cache forever (stale after any update)
// and grow memory without bound. Clearing the cache after each batch keeps
// request batching, including de-duplication of keys requested in the same
// batch, while dropping long-lived caching.
func NewLoaders() *Loaders {
	return &Loaders{
		TodosLoader: dataloader.NewBatchedLoader(
			CreateBatch[models.Todo]("todos"),
			dataloader.WithClearCacheOnBatch(),
		),
		UsersLoader: dataloader.NewBatchedLoader(
			CreateBatch[models.User]("users"),
			dataloader.WithClearCacheOnBatch(),
		),
		CategoryLoader: dataloader.NewBatchedLoader(
			CreateBatch[models.Category]("categories"),
			dataloader.WithClearCacheOnBatch(),
		),
	}
}
// LoaderMiddleware wraps next so that each incoming request's context
// carries loaders under LoadersKey, where LoaderFor can later find them.
//
// The handler does not build new loaders per request: the exact *Loaders
// value passed here is attached to every request it serves. Any caching
// those loaders do is therefore shared across requests.
func LoaderMiddleware(loaders *Loaders, next http.Handler) http.Handler {
	inject := func(w http.ResponseWriter, r *http.Request) {
		ctx := context.WithValue(r.Context(), LoadersKey, loaders)
		r = r.WithContext(ctx)
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(inject)
}
// LoaderFor returns the *Loaders stored in ctx under LoadersKey (normally
// put there by LoaderMiddleware).
//
// It panics with "dataloader not found in context" when ctx holds no
// *Loaders under that key. That indicates the request never went through
// LoaderMiddleware, which is a wiring bug rather than a recoverable
// condition.
func LoaderFor(ctx context.Context) *Loaders {
	loaders, found := ctx.Value(LoadersKey).(*Loaders)
	if !found {
		panic("dataloader not found in context")
	}
	return loaders
}
// CreateBatch returns a dataloader.BatchFunc that loads documents of type T
// from the MongoDB collection coll, keyed by hex-encoded ObjectIDs.
//
// For each batch it issues at most one Find with an {"_id": {"$in": [...]}}
// filter covering every valid key. It returns exactly one result per key,
// in the same order as keys, as dataloader requires:
//   - a key that is not a valid hex ObjectID gets an "invalid object ID" error;
//   - a valid key with no matching document gets an "object not found" error;
//   - if the query fails, every valid key gets the query's error.
//
// Returned documents are matched back to keys through T's field named "ID",
// which must hold a primitive.ObjectID. If T has no such field, or an item is
// nil, that item cannot be matched and its key reports "object not found"
// instead of panicking.
// NOTE(review): assumes Find returns pointers to the decoded documents
// ([]*T), as the original dereference implied — confirm against Find.
func CreateBatch[T any](coll string) dataloader.BatchFunc {
	// Resolve the position of T's ID field once per loader rather than once
	// per document. idIndex stays nil when T is not a struct or has no ID.
	var idIndex []int
	if t := reflect.TypeOf((*T)(nil)).Elem(); t.Kind() == reflect.Struct {
		if f, ok := t.FieldByName("ID"); ok {
			idIndex = f.Index
		}
	}

	// objectIDOf reads item's ID field, reporting false for a nil item or an
	// ID field that is missing, unreachable (e.g. behind a nil embedded
	// pointer or an unexported embedding), or not a primitive.ObjectID.
	objectIDOf := func(item *T) (primitive.ObjectID, bool) {
		if item == nil || idIndex == nil {
			return primitive.NilObjectID, false
		}
		f, err := reflect.ValueOf(item).Elem().FieldByIndexErr(idIndex)
		if err != nil || !f.CanInterface() {
			return primitive.NilObjectID, false
		}
		id, ok := f.Interface().(primitive.ObjectID)
		return id, ok
	}

	return func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result {
		results := make([]*dataloader.Result, len(keys))

		// Parse each key once. ids[i] holds the parsed ID for position i;
		// invalid keys get their error result immediately, so a non-nil
		// results[i] marks a key that must not be looked up.
		ids := make([]primitive.ObjectID, len(keys))
		objectIDs := make([]primitive.ObjectID, 0, len(keys))
		for i, key := range keys {
			id, err := primitive.ObjectIDFromHex(key.String())
			if err != nil {
				results[i] = &dataloader.Result{Error: fmt.Errorf("invalid object ID: %s", key.String())}
				continue
			}
			ids[i] = id
			objectIDs = append(objectIDs, id)
		}

		// With no valid IDs there is nothing to fetch. Skipping the query also
		// avoids sending {"$in": null}, which MongoDB rejects.
		if len(objectIDs) == 0 {
			return results
		}

		data, err := Find[T](ctx, coll, bson.M{"_id": bson.M{"$in": objectIDs}})
		if err != nil {
			// The lookup failed as a whole: report it for every key that was
			// actually queried, keeping the more specific invalid-ID errors.
			for i := range results {
				if results[i] == nil {
					results[i] = &dataloader.Result{Error: err}
				}
			}
			return results
		}

		byID := make(map[primitive.ObjectID]*T, len(data))
		for _, item := range data {
			if id, ok := objectIDOf(item); ok {
				byID[id] = item
			}
		}

		// Assemble results in the original key order.
		for i, key := range keys {
			if results[i] != nil {
				continue // invalid key, already reported
			}
			if item, ok := byID[ids[i]]; ok {
				results[i] = &dataloader.Result{Data: item}
			} else {
				results[i] = &dataloader.Result{Error: fmt.Errorf("object not found: %s", key.String())}
			}
		}
		return results
	}
}