Compare commits

...

6 Commits

33 changed files with 831 additions and 338 deletions

View File

@ -3,6 +3,7 @@ package app
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"git.farahty.com/nimer/go-mongo/models" "git.farahty.com/nimer/go-mongo/models"
) )
@ -15,9 +16,33 @@ var (
UserKey = &contextKey{"user"} UserKey = &contextKey{"user"}
StatusKey = &contextKey{"status"} StatusKey = &contextKey{"status"}
ExpiryKey = &contextKey{"expiry"} ExpiryKey = &contextKey{"expiry"}
LoadersKey = &contextKey{"dataloaders"} LoadersKey = &contextKey{"dataLoaders"}
WriterKye = &contextKey{"writer"}
RequestKey = &contextKey{"request"}
) )
// LoaderFor retrieves the dataLoaders from context
func LoaderFor(ctx context.Context) *Loaders {
if loaders, ok := ctx.Value(LoadersKey).(*Loaders); ok {
return loaders
}
panic("dataloader not found in context")
}
func WriterFor(ctx context.Context) *http.ResponseWriter {
if writer, ok := ctx.Value(WriterKye).(*http.ResponseWriter); ok {
return writer
}
panic("no writer found in context")
}
func RequestFor(ctx context.Context) *http.Request {
if req, ok := ctx.Value(RequestKey).(*http.Request); ok {
return req
}
panic("no request found in context")
}
// Retrieves the current user from the context // Retrieves the current user from the context
func CurrentUser(ctx context.Context) (*models.UserJWT, error) { func CurrentUser(ctx context.Context) (*models.UserJWT, error) {
user, _ := ctx.Value(UserKey).(*models.UserJWT) user, _ := ctx.Value(UserKey).(*models.UserJWT)
@ -32,6 +57,7 @@ func CurrentUser(ctx context.Context) (*models.UserJWT, error) {
// Check if the token was marked as expired // Check if the token was marked as expired
func IsTokenExpired(ctx context.Context) bool { func IsTokenExpired(ctx context.Context) bool {
if expired, ok := ctx.Value(ExpiryKey).(bool); ok { if expired, ok := ctx.Value(ExpiryKey).(bool); ok {
return expired return expired
} }

View File

@ -5,9 +5,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"os"
"git.farahty.com/nimer/go-mongo/models" "git.farahty.com/nimer/go-mongo/models"
"git.farahty.com/nimer/go-mongo/utils"
"github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/handler/transport" "github.com/99designs/gqlgen/graphql/handler/transport"
"github.com/casbin/casbin/v2" "github.com/casbin/casbin/v2"
@ -22,8 +22,7 @@ var (
) )
func LoadAuthorizer(ctx context.Context) error { func LoadAuthorizer(ctx context.Context) error {
a, err := mongodbadapter.NewAdapterWithClientOption(options.Client().ApplyURI(os.Getenv("MONGO_URI")), os.Getenv("MONGO_DB")) a, err := mongodbadapter.NewAdapterWithClientOption(options.Client().ApplyURI(Config.MongoURI), Config.MongoDB)
if err != nil { if err != nil {
return err return err
} }
@ -32,18 +31,15 @@ func LoadAuthorizer(ctx context.Context) error {
return err return err
} else { } else {
Authorizer = enforcer Authorizer = enforcer
} }
a.AddPolicy("mongodb", "p", []string{"*", "login", "mutation"}) a.AddPolicy("mongodb", "p", []string{"*", "login", "mutation"})
a.AddPolicy("mongodb", "p", []string{"todos-admin", "createTodo", "mutation"}) a.AddPolicy("mongodb", "p", []string{"todos-admin", "createTodo", "mutation"})
a.AddPolicy("mongodb", "p", []string{"todos-admin", "todos", "query"}) a.AddPolicy("mongodb", "p", []string{"todos-admin", "todos", "query"})
a.AddPolicy("mongodb", "p", []string{"todos-admin", "todo", "query"}) a.AddPolicy("mongodb", "p", []string{"todos-admin", "todo", "query"})
a.AddPolicy("mongodb", "p", []string{"category-admin", "createCategory", "mutation"}) a.AddPolicy("mongodb", "p", []string{"category-admin", "createCategory", "mutation"})
a.AddPolicy("mongodb", "p", []string{"category-admin", "categories", "query"}) a.AddPolicy("mongodb", "p", []string{"category-admin", "categories", "query"})
a.AddPolicy("mongodb", "p", []string{"category-admin", "category", "query"}) a.AddPolicy("mongodb", "p", []string{"category-admin", "category", "query"})
a.AddPolicy("mongodb", "p", []string{"users-admin", "users", "query"}) a.AddPolicy("mongodb", "p", []string{"users-admin", "users", "query"})
a.AddPolicy("mongodb", "p", []string{"users-admin", "createUser", "mutation"}) a.AddPolicy("mongodb", "p", []string{"users-admin", "createUser", "mutation"})
@ -51,8 +47,8 @@ func LoadAuthorizer(ctx context.Context) error {
Authorizer.AddGroupingPolicy("admin", "todos-admin") Authorizer.AddGroupingPolicy("admin", "todos-admin")
Authorizer.AddGroupingPolicy("admin", "category-admin") Authorizer.AddGroupingPolicy("admin", "category-admin")
email := os.Getenv("ADMIN_EMAIL") email := Config.AdminEmail
password, _ := models.MakeHash(os.Getenv("ADMIN_PASSWORD")) password, _ := models.MakeHash(Config.AdminPass)
if _, err := FindOne[models.User](ctx, "users", bson.M{"email": email}); err != nil { if _, err := FindOne[models.User](ctx, "users", bson.M{"email": email}); err != nil {
log.Println("Creating admin user") log.Println("Creating admin user")
@ -67,7 +63,6 @@ func LoadAuthorizer(ctx context.Context) error {
} }
Authorizer.AddRoleForUser(admin.ID.Hex(), "admin") Authorizer.AddRoleForUser(admin.ID.Hex(), "admin")
} }
return nil return nil
@ -82,7 +77,7 @@ func AuthorizeWebSocket(ctx context.Context, initPayload transport.InitPayload)
return ctx, &initPayload, nil return ctx, &initPayload, nil
} }
user, err := getUserFromToken(token) user, err := utils.GetUserFromToken[models.UserJWT](token, Config.AccessSecret)
if err != nil { if err != nil {
ctx = SetStatus(ctx, err.Error()) ctx = SetStatus(ctx, err.Error())
@ -119,6 +114,10 @@ func AuthorizeOperation(ctx context.Context) error {
return nil return nil
} }
if IsTokenExpired(ctx) && object != "login" {
return fmt.Errorf("token expired")
}
if obj, err := CurrentUser(ctx); err == nil { if obj, err := CurrentUser(ctx); err == nil {
user = string(obj.ID) user = string(obj.ID)
} }

90
app/config.go Normal file
View File

@ -0,0 +1,90 @@
package app
import (
"fmt"
"log"
"os"
"github.com/joho/godotenv"
)
var (
Config *ConfigStruct
)
type ConfigStruct struct {
Port string
MongoURI string
MongoDB string
RedisHost string
RedisPort string
RedisPass string
AdminEmail string
AdminPass string
RefreshSecret string
RefreshExpiry string
AccessSecret string
AccessExpiry string
Env string
}
// LoadConfig loads environment variables into Config struct and validates required fields.
func LoadConfig() error {
if _, exists := os.LookupEnv("MONGO_URI"); !exists {
if err := godotenv.Load(); err != nil {
log.Fatal("🔴 Failed to load .env file: ", err)
}
}
Config = &ConfigStruct{
Port: getEnv("PORT", "8080"),
MongoURI: os.Getenv("MONGO_URI"),
MongoDB: getEnv("MONGO_DB", "test"),
RedisHost: os.Getenv("REDIS_HOST"),
RedisPort: os.Getenv("REDIS_PORT"),
RedisPass: os.Getenv("REDIS_PASSWORD"),
AdminEmail: os.Getenv("ADMIN_EMAIL"),
AdminPass: os.Getenv("ADMIN_PASSWORD"),
RefreshSecret: os.Getenv("REFRESH_SECRET"),
RefreshExpiry: os.Getenv("REFRESH_EXPIRY"),
AccessSecret: os.Getenv("ACCESS_SECRET"),
AccessExpiry: os.Getenv("ACCESS_EXPIRY"),
Env: os.Getenv("ENV"),
}
missing := ""
if Config.Port == "" {
missing += "PORT "
}
if Config.MongoURI == "" {
missing += "MONGO_URI "
}
if Config.RedisHost == "" {
missing += "REDIS_HOST "
}
if Config.RedisPort == "" {
missing += "REDIS_PORT "
}
if Config.AccessSecret == "" {
missing += "ACCESS_SECRET "
}
if Config.RefreshSecret == "" {
missing += "REFRESH_SECRET "
}
if missing != "" {
return fmt.Errorf("missing required environment variables: %s", missing)
}
return nil
}
// GetConfig returns the global config
func GetConfig() *ConfigStruct {
return Config
}
func getEnv(key, def string) string {
if v := os.Getenv(key); v != "" {
return v
}
return def
}

View File

@ -3,7 +3,6 @@ package app
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"time" "time"
"github.com/fatih/color" "github.com/fatih/color"
@ -22,7 +21,7 @@ func Connect() (context.CancelFunc, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
var err error var err error
Mongo, err = mongo.Connect(ctx, options.Client().ApplyURI(os.Getenv("MONGO_URI"))) Mongo, err = mongo.Connect(ctx, options.Client().ApplyURI(Config.MongoURI))
if err != nil { if err != nil {
color.Red("❌ Database connection failed\n" + err.Error()) color.Red("❌ Database connection failed\n" + err.Error())
return cancel, err return cancel, err
@ -45,7 +44,7 @@ func Disconnect(ctx context.Context) error {
// Collection returns a MongoDB collection // Collection returns a MongoDB collection
func Collection(name string) *mongo.Collection { func Collection(name string) *mongo.Collection {
return Mongo.Database(os.Getenv("MONGO_DB")).Collection(name) return Mongo.Database(Config.MongoDB).Collection(name)
} }
// Find returns all matching documents from a collection // Find returns all matching documents from a collection

View File

@ -3,7 +3,6 @@ package app
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"git.farahty.com/nimer/go-mongo/models" "git.farahty.com/nimer/go-mongo/models"
"github.com/graph-gophers/dataloader" "github.com/graph-gophers/dataloader"
@ -27,22 +26,6 @@ func NewLoaders() *Loaders {
} }
} }
// Middleware injects dataloaders into context for each HTTP request
func LoaderMiddleware(loaders *Loaders, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxWithLoaders := context.WithValue(r.Context(), LoadersKey, loaders)
next.ServeHTTP(w, r.WithContext(ctxWithLoaders))
})
}
// LoaderFor retrieves the dataloaders from context
func LoaderFor(ctx context.Context) *Loaders {
if loaders, ok := ctx.Value(LoadersKey).(*Loaders); ok {
return loaders
}
panic("dataloader not found in context")
}
// CreateBatch creates a generic batched loader for a MongoDB collection // CreateBatch creates a generic batched loader for a MongoDB collection
func CreateBatch[T models.Identifiable](coll string) dataloader.BatchFunc { func CreateBatch[T models.Identifiable](coll string) dataloader.BatchFunc {
return func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { return func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result {

View File

@ -3,32 +3,67 @@ package app
import ( import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"git.farahty.com/nimer/go-mongo/models"
"git.farahty.com/nimer/go-mongo/utils"
"github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
) )
// Middleware injects dataLoaders into context for each HTTP request
func LoaderMiddleware(loaders *Loaders, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxWithLoaders := context.WithValue(r.Context(), LoadersKey, loaders)
next.ServeHTTP(w, r.WithContext(ctxWithLoaders))
})
}
// ExpiryMiddleware checks for expired tokens in GraphQL resolvers // ExpiryMiddleware checks for expired tokens in GraphQL resolvers
func ExpiryMiddleware(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { func ExpiryMiddleware(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
if IsTokenExpired(ctx) { if IsTokenExpired(ctx) {
return graphql.ErrorResponse(ctx, "token expired") return graphql.ErrorResponse(ctx, "token expired")
} }
return next(ctx) return next(ctx)
} }
// add response writer to context for GraphQL resolvers
func WriterMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), WriterKye, &rw)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
// add request to context for GraphQL resolvers
func RequestMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), RequestKey, r)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
// AuthMiddleware parses JWT token and injects user context for HTTP requests // AuthMiddleware parses JWT token and injects user context for HTTP requests
func AuthMiddleware(next http.Handler) http.Handler { func AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
tokenStr, err := getTokenFromHeader(r) headerToken, headerErr := utils.GetTokenFromHeader(r)
if err != nil { cookieToken, cookieErr := utils.GetTokenFromCookie(r)
ctx := SetStatus(r.Context(), err.Error())
if headerErr != nil && cookieErr != nil {
ctx := SetStatus(r.Context(), headerErr.Error())
next.ServeHTTP(rw, r.WithContext(ctx)) next.ServeHTTP(rw, r.WithContext(ctx))
return return
} }
user, err := getUserFromToken(tokenStr) token := headerToken
if token == "" {
token = cookieToken
}
user, err := utils.GetUserFromToken[models.UserJWT](token, Config.AccessSecret)
ctx := r.Context() ctx := r.Context()
if err != nil { if err != nil {

24
app/redis-client.go Normal file
View File

@ -0,0 +1,24 @@
package app
import (
"context"
"github.com/redis/go-redis/v9"
)
var (
RedisClient redis.UniversalClient
)
// InitRedis initializes the Redis client and assigns it to RedisClient
func InitRedis(ctx context.Context) error {
RedisClient = redis.NewClient(&redis.Options{
Addr: Config.RedisHost + ":" + Config.RedisPort,
Password: Config.RedisPass,
DB: 0,
})
if _, err := RedisClient.Ping(ctx).Result(); err != nil {
return err
}
return nil
}

View File

@ -8,7 +8,7 @@ import (
"github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql"
) )
func Auth(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) { func Auth(ctx context.Context, obj any, next graphql.Resolver) (res any, err error) {
if _, err := app.CurrentUser(ctx); err != nil { if _, err := app.CurrentUser(ctx); err != nil {
return nil, fmt.Errorf("access denied, %s", err.Error()) return nil, fmt.Errorf("access denied, %s", err.Error())

20
directives/has-role.go Normal file
View File

@ -0,0 +1,20 @@
package directives
import (
"context"
"fmt"
"git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/models"
"github.com/99designs/gqlgen/graphql"
)
func HasRole(ctx context.Context, obj any, next graphql.Resolver, role models.Role) (res any, err error) {
if _, err := app.CurrentUser(ctx); err != nil {
return nil, fmt.Errorf("access denied, %s", err.Error())
}
return next(ctx)
}

View File

@ -7,6 +7,7 @@ import (
"os" "os"
"strings" "strings"
"git.farahty.com/nimer/go-mongo/utils"
"github.com/99designs/gqlgen/api" "github.com/99designs/gqlgen/api"
"github.com/99designs/gqlgen/codegen/config" "github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/plugin/modelgen" "github.com/99designs/gqlgen/plugin/modelgen"
@ -29,7 +30,7 @@ func mutateHook(b *modelgen.ModelBuild) *modelgen.ModelBuild {
field.Tag = strings.TrimSuffix(field.Tag, `"`) + `,omitempty" bson:"-"` field.Tag = strings.TrimSuffix(field.Tag, `"`) + `,omitempty" bson:"-"`
} }
} else { } else {
field.Tag = strings.TrimSuffix(field.Tag, `"`) + `" bson:"` + field.Name + `,omitempty"` field.Tag = strings.TrimSuffix(field.Tag, `"`) + `" bson:"` + utils.Camelize(field.Name) + `,omitempty"`
} }
} }
} }

View File

@ -13,8 +13,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/models" "git.farahty.com/nimer/go-mongo/models"
"git.farahty.com/nimer/go-mongo/utils"
"github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/introspection" "github.com/99designs/gqlgen/graphql/introspection"
gqlparser "github.com/vektah/gqlparser/v2" gqlparser "github.com/vektah/gqlparser/v2"
@ -52,11 +52,13 @@ type ResolverRoot interface {
type DirectiveRoot struct { type DirectiveRoot struct {
Auth func(ctx context.Context, obj any, next graphql.Resolver) (res any, err error) Auth func(ctx context.Context, obj any, next graphql.Resolver) (res any, err error)
HasRole func(ctx context.Context, obj any, next graphql.Resolver, role models.Role) (res any, err error)
} }
type ComplexityRoot struct { type ComplexityRoot struct {
Category struct { Category struct {
Body func(childComplexity int) int Body func(childComplexity int) int
Children func(childComplexity int) int
CreatedAt func(childComplexity int) int CreatedAt func(childComplexity int) int
CreatedBy func(childComplexity int) int CreatedBy func(childComplexity int) int
CreatedByID func(childComplexity int) int CreatedByID func(childComplexity int) int
@ -119,10 +121,8 @@ type ComplexityRoot struct {
User struct { User struct {
Email func(childComplexity int) int Email func(childComplexity int) int
ID func(childComplexity int) int ID func(childComplexity int) int
Password func(childComplexity int) int
Phone func(childComplexity int) int Phone func(childComplexity int) int
Status func(childComplexity int) int Status func(childComplexity int) int
Token func(childComplexity int) int
Type func(childComplexity int) int Type func(childComplexity int) int
Verified func(childComplexity int) int Verified func(childComplexity int) int
} }
@ -131,6 +131,8 @@ type ComplexityRoot struct {
type CategoryResolver interface { type CategoryResolver interface {
Parent(ctx context.Context, obj *models.Category) (*models.Category, error) Parent(ctx context.Context, obj *models.Category) (*models.Category, error)
Children(ctx context.Context, obj *models.Category) ([]*models.Category, error)
CreatedBy(ctx context.Context, obj *models.Category) (*models.User, error) CreatedBy(ctx context.Context, obj *models.Category) (*models.User, error)
UpdatedBy(ctx context.Context, obj *models.Category) (*models.User, error) UpdatedBy(ctx context.Context, obj *models.Category) (*models.User, error)
@ -190,6 +192,13 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
return e.complexity.Category.Body(childComplexity), true return e.complexity.Category.Body(childComplexity), true
case "Category.children":
if e.complexity.Category.Children == nil {
break
}
return e.complexity.Category.Children(childComplexity), true
case "Category.createdAt": case "Category.createdAt":
if e.complexity.Category.CreatedAt == nil { if e.complexity.Category.CreatedAt == nil {
break break
@ -507,13 +516,6 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
return e.complexity.User.ID(childComplexity), true return e.complexity.User.ID(childComplexity), true
case "User.password":
if e.complexity.User.Password == nil {
break
}
return e.complexity.User.Password(childComplexity), true
case "User.phone": case "User.phone":
if e.complexity.User.Phone == nil { if e.complexity.User.Phone == nil {
break break
@ -528,13 +530,6 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
return e.complexity.User.Status(childComplexity), true return e.complexity.User.Status(childComplexity), true
case "User.token":
if e.complexity.User.Token == nil {
break
}
return e.complexity.User.Token(childComplexity), true
case "User.type": case "User.type":
if e.complexity.User.Type == nil { if e.complexity.User.Type == nil {
break break
@ -678,6 +673,13 @@ func (ec *executionContext) introspectType(name string) (*introspection.Type, er
var sources = []*ast.Source{ var sources = []*ast.Source{
{Name: "../gql/auth.gql", Input: `directive @auth on FIELD_DEFINITION {Name: "../gql/auth.gql", Input: `directive @auth on FIELD_DEFINITION
directive @hasRole(role: Role!) on FIELD_DEFINITION
enum Role {
ADMIN
USER
}
type LoginResponse { type LoginResponse {
user: User! user: User!
accessToken: String! accessToken: String!
@ -693,15 +695,31 @@ extend type Mutation {
login(input: LoginInput!): LoginResponse! login(input: LoginInput!): LoginResponse!
} }
`, BuiltIn: false}, `, BuiltIn: false},
{Name: "../gql/base.gql", Input: `directive @goField( {Name: "../gql/base.gql", Input: `directive @goModel(
model: String
models: [String!]
forceGenerate: Boolean
) on OBJECT | INPUT_OBJECT | SCALAR | ENUM | INTERFACE | UNION
directive @goField(
forceResolver: Boolean forceResolver: Boolean
name: String name: String
omittable: Boolean
type: String
) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION ) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION
directive @goTag( directive @goTag(
key: String! key: String!
value: String value: String
) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION ) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION
directive @goExtraField(
name: String
type: String!
overrideTags: String
description: String
) repeatable on OBJECT | INPUT_OBJECT
scalar Time scalar Time
interface Base { interface Base {
@ -744,6 +762,9 @@ input TranslatedInput {
parent: Category @goField(forceResolver: true) parent: Category @goField(forceResolver: true)
parentId: ID parentId: ID
"#bson:ignore"
children: [Category] @goField(forceResolver: true)
createdAt: Time! createdAt: Time!
updatedAt: Time! updatedAt: Time!
@ -813,15 +834,15 @@ extend type Subscription {
onTodo: Todo onTodo: Todo
} }
`, BuiltIn: false}, `, BuiltIn: false},
{Name: "../gql/user.gql", Input: `type User { {Name: "../gql/user.gql", Input: `type User
@goExtraField(name: "Password", type: "*string")
@goExtraField(name: "Token", type: "*string") {
id: ID! id: ID!
phone: String phone: String
email: String email: String
type: String type: String
status: String status: String
verified: Boolean @goField(forceResolver: true) verified: Boolean @goField(forceResolver: true)
password: String
token: String
} }
input CreateUserInput { input CreateUserInput {
@ -846,6 +867,34 @@ var parsedSchema = gqlparser.MustLoadSchema(sources...)
// region ***************************** args.gotpl ***************************** // region ***************************** args.gotpl *****************************
func (ec *executionContext) dir_hasRole_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
var err error
args := map[string]any{}
arg0, err := ec.dir_hasRole_argsRole(ctx, rawArgs)
if err != nil {
return nil, err
}
args["role"] = arg0
return args, nil
}
func (ec *executionContext) dir_hasRole_argsRole(
ctx context.Context,
rawArgs map[string]any,
) (models.Role, error) {
if _, ok := rawArgs["role"]; !ok {
var zeroVal models.Role
return zeroVal, nil
}
ctx = graphql.WithPathContext(ctx, graphql.NewPathWithField("role"))
if tmp, ok := rawArgs["role"]; ok {
return ec.unmarshalNRole2gitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐRole(ctx, tmp)
}
var zeroVal models.Role
return zeroVal, nil
}
func (ec *executionContext) field_Mutation_createCategory_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) { func (ec *executionContext) field_Mutation_createCategory_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
var err error var err error
args := map[string]any{} args := map[string]any{}
@ -1353,6 +1402,8 @@ func (ec *executionContext) fieldContext_Category_parent(_ context.Context, fiel
return ec.fieldContext_Category_parent(ctx, field) return ec.fieldContext_Category_parent(ctx, field)
case "parentId": case "parentId":
return ec.fieldContext_Category_parentId(ctx, field) return ec.fieldContext_Category_parentId(ctx, field)
case "children":
return ec.fieldContext_Category_children(ctx, field)
case "createdAt": case "createdAt":
return ec.fieldContext_Category_createdAt(ctx, field) return ec.fieldContext_Category_createdAt(ctx, field)
case "updatedAt": case "updatedAt":
@ -1417,6 +1468,77 @@ func (ec *executionContext) fieldContext_Category_parentId(_ context.Context, fi
return fc, nil return fc, nil
} }
func (ec *executionContext) _Category_children(ctx context.Context, field graphql.CollectedField, obj *models.Category) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_Category_children(ctx, field)
if err != nil {
return graphql.Null
}
ctx = graphql.WithFieldContext(ctx, fc)
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (any, error) {
ctx = rctx // use context from middleware stack in children
return ec.resolvers.Category().Children(rctx, obj)
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
return graphql.Null
}
res := resTmp.([]*models.Category)
fc.Result = res
return ec.marshalOCategory2ᚕᚖgitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐCategory(ctx, field.Selections, res)
}
func (ec *executionContext) fieldContext_Category_children(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "Category",
Field: field,
IsMethod: true,
IsResolver: true,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
switch field.Name {
case "id":
return ec.fieldContext_Category_id(ctx, field)
case "title":
return ec.fieldContext_Category_title(ctx, field)
case "body":
return ec.fieldContext_Category_body(ctx, field)
case "parent":
return ec.fieldContext_Category_parent(ctx, field)
case "parentId":
return ec.fieldContext_Category_parentId(ctx, field)
case "children":
return ec.fieldContext_Category_children(ctx, field)
case "createdAt":
return ec.fieldContext_Category_createdAt(ctx, field)
case "updatedAt":
return ec.fieldContext_Category_updatedAt(ctx, field)
case "createdBy":
return ec.fieldContext_Category_createdBy(ctx, field)
case "createdById":
return ec.fieldContext_Category_createdById(ctx, field)
case "updatedBy":
return ec.fieldContext_Category_updatedBy(ctx, field)
case "updatedById":
return ec.fieldContext_Category_updatedById(ctx, field)
case "owner":
return ec.fieldContext_Category_owner(ctx, field)
case "ownerId":
return ec.fieldContext_Category_ownerId(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type Category", field.Name)
},
}
return fc, nil
}
func (ec *executionContext) _Category_createdAt(ctx context.Context, field graphql.CollectedField, obj *models.Category) (ret graphql.Marshaler) { func (ec *executionContext) _Category_createdAt(ctx context.Context, field graphql.CollectedField, obj *models.Category) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_Category_createdAt(ctx, field) fc, err := ec.fieldContext_Category_createdAt(ctx, field)
if err != nil { if err != nil {
@ -1556,10 +1678,6 @@ func (ec *executionContext) fieldContext_Category_createdBy(_ context.Context, f
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -1662,10 +1780,6 @@ func (ec *executionContext) fieldContext_Category_updatedBy(_ context.Context, f
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -1765,10 +1879,6 @@ func (ec *executionContext) fieldContext_Category_owner(_ context.Context, field
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -1871,10 +1981,6 @@ func (ec *executionContext) fieldContext_LoginResponse_user(_ context.Context, f
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -2082,6 +2188,8 @@ func (ec *executionContext) fieldContext_Mutation_createCategory(ctx context.Con
return ec.fieldContext_Category_parent(ctx, field) return ec.fieldContext_Category_parent(ctx, field)
case "parentId": case "parentId":
return ec.fieldContext_Category_parentId(ctx, field) return ec.fieldContext_Category_parentId(ctx, field)
case "children":
return ec.fieldContext_Category_children(ctx, field)
case "createdAt": case "createdAt":
return ec.fieldContext_Category_createdAt(ctx, field) return ec.fieldContext_Category_createdAt(ctx, field)
case "updatedAt": case "updatedAt":
@ -2240,10 +2348,6 @@ func (ec *executionContext) fieldContext_Mutation_createUser(ctx context.Context
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -2311,6 +2415,8 @@ func (ec *executionContext) fieldContext_Query_categories(_ context.Context, fie
return ec.fieldContext_Category_parent(ctx, field) return ec.fieldContext_Category_parent(ctx, field)
case "parentId": case "parentId":
return ec.fieldContext_Category_parentId(ctx, field) return ec.fieldContext_Category_parentId(ctx, field)
case "children":
return ec.fieldContext_Category_children(ctx, field)
case "createdAt": case "createdAt":
return ec.fieldContext_Category_createdAt(ctx, field) return ec.fieldContext_Category_createdAt(ctx, field)
case "updatedAt": case "updatedAt":
@ -2383,6 +2489,8 @@ func (ec *executionContext) fieldContext_Query_category(ctx context.Context, fie
return ec.fieldContext_Category_parent(ctx, field) return ec.fieldContext_Category_parent(ctx, field)
case "parentId": case "parentId":
return ec.fieldContext_Category_parentId(ctx, field) return ec.fieldContext_Category_parentId(ctx, field)
case "children":
return ec.fieldContext_Category_children(ctx, field)
case "createdAt": case "createdAt":
return ec.fieldContext_Category_createdAt(ctx, field) return ec.fieldContext_Category_createdAt(ctx, field)
case "updatedAt": case "updatedAt":
@ -2606,10 +2714,6 @@ func (ec *executionContext) fieldContext_Query_users(_ context.Context, field gr
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -3092,10 +3196,6 @@ func (ec *executionContext) fieldContext_Todo_createdBy(_ context.Context, field
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -3198,10 +3298,6 @@ func (ec *executionContext) fieldContext_Todo_updatedBy(_ context.Context, field
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -3301,10 +3397,6 @@ func (ec *executionContext) fieldContext_Todo_owner(_ context.Context, field gra
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -3737,88 +3829,6 @@ func (ec *executionContext) fieldContext_User_verified(_ context.Context, field
return fc, nil return fc, nil
} }
func (ec *executionContext) _User_password(ctx context.Context, field graphql.CollectedField, obj *models.User) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_User_password(ctx, field)
if err != nil {
return graphql.Null
}
ctx = graphql.WithFieldContext(ctx, fc)
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (any, error) {
ctx = rctx // use context from middleware stack in children
return obj.Password, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
return graphql.Null
}
res := resTmp.(*string)
fc.Result = res
return ec.marshalOString2ᚖstring(ctx, field.Selections, res)
}
func (ec *executionContext) fieldContext_User_password(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "User",
Field: field,
IsMethod: false,
IsResolver: false,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
return nil, errors.New("field of type String does not have child fields")
},
}
return fc, nil
}
func (ec *executionContext) _User_token(ctx context.Context, field graphql.CollectedField, obj *models.User) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_User_token(ctx, field)
if err != nil {
return graphql.Null
}
ctx = graphql.WithFieldContext(ctx, fc)
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (any, error) {
ctx = rctx // use context from middleware stack in children
return obj.Token, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
return graphql.Null
}
res := resTmp.(*string)
fc.Result = res
return ec.marshalOString2ᚖstring(ctx, field.Selections, res)
}
func (ec *executionContext) fieldContext_User_token(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "User",
Field: field,
IsMethod: false,
IsResolver: false,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
return nil, errors.New("field of type String does not have child fields")
},
}
return fc, nil
}
func (ec *executionContext) ___Directive_name(ctx context.Context, field graphql.CollectedField, obj *introspection.Directive) (ret graphql.Marshaler) { func (ec *executionContext) ___Directive_name(ctx context.Context, field graphql.CollectedField, obj *introspection.Directive) (ret graphql.Marshaler) {
fc, err := ec.fieldContext___Directive_name(ctx, field) fc, err := ec.fieldContext___Directive_name(ctx, field)
if err != nil { if err != nil {
@ -6061,6 +6071,39 @@ func (ec *executionContext) _Category(ctx context.Context, sel ast.SelectionSet,
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
case "parentId": case "parentId":
out.Values[i] = ec._Category_parentId(ctx, field, obj) out.Values[i] = ec._Category_parentId(ctx, field, obj)
case "children":
field := field
innerFunc := func(ctx context.Context, _ *graphql.FieldSet) (res graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
}
}()
res = ec._Category_children(ctx, field, obj)
return res
}
if field.Deferrable != nil {
dfs, ok := deferred[field.Deferrable.Label]
di := 0
if ok {
dfs.AddField(field)
di = len(dfs.Values) - 1
} else {
dfs = graphql.NewFieldSet([]graphql.CollectedField{field})
deferred[field.Deferrable.Label] = dfs
}
dfs.Concurrently(di, func(ctx context.Context) graphql.Marshaler {
return innerFunc(ctx, dfs)
})
// don't run the out.Concurrently() call below
out.Values[i] = graphql.Null
continue
}
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
case "createdAt": case "createdAt":
out.Values[i] = ec._Category_createdAt(ctx, field, obj) out.Values[i] = ec._Category_createdAt(ctx, field, obj)
if out.Values[i] == graphql.Null { if out.Values[i] == graphql.Null {
@ -6777,10 +6820,6 @@ func (ec *executionContext) _User(ctx context.Context, sel ast.SelectionSet, obj
} }
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
case "password":
out.Values[i] = ec._User_password(ctx, field, obj)
case "token":
out.Values[i] = ec._User_token(ctx, field, obj)
default: default:
panic("unknown field " + strconv.Quote(field.Name)) panic("unknown field " + strconv.Quote(field.Name))
} }
@ -7223,13 +7262,13 @@ func (ec *executionContext) unmarshalNCreateUserInput2gitᚗfarahtyᚗcomᚋnime
} }
func (ec *executionContext) unmarshalNID2goᚗmongodbᚗorgᚋmongoᚑdriverᚋbsonᚋprimitiveᚐObjectID(ctx context.Context, v any) (primitive.ObjectID, error) { func (ec *executionContext) unmarshalNID2goᚗmongodbᚗorgᚋmongoᚑdriverᚋbsonᚋprimitiveᚐObjectID(ctx context.Context, v any) (primitive.ObjectID, error) {
res, err := app.UnmarshalObjectID(v) res, err := utils.UnmarshalObjectID(v)
return res, graphql.ErrorOnPath(ctx, err) return res, graphql.ErrorOnPath(ctx, err)
} }
func (ec *executionContext) marshalNID2goᚗmongodbᚗorgᚋmongoᚑdriverᚋbsonᚋprimitiveᚐObjectID(ctx context.Context, sel ast.SelectionSet, v primitive.ObjectID) graphql.Marshaler { func (ec *executionContext) marshalNID2goᚗmongodbᚗorgᚋmongoᚑdriverᚋbsonᚋprimitiveᚐObjectID(ctx context.Context, sel ast.SelectionSet, v primitive.ObjectID) graphql.Marshaler {
_ = sel _ = sel
res := app.MarshalObjectID(v) res := utils.MarshalObjectID(v)
if res == graphql.Null { if res == graphql.Null {
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
ec.Errorf(ctx, "the requested element is null which the schema does not allow") ec.Errorf(ctx, "the requested element is null which the schema does not allow")
@ -7257,6 +7296,16 @@ func (ec *executionContext) marshalNLoginResponse2ᚖgitᚗfarahtyᚗcomᚋnimer
return ec._LoginResponse(ctx, sel, v) return ec._LoginResponse(ctx, sel, v)
} }
func (ec *executionContext) unmarshalNRole2gitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐRole(ctx context.Context, v any) (models.Role, error) {
var res models.Role
err := res.UnmarshalGQL(v)
return res, graphql.ErrorOnPath(ctx, err)
}
func (ec *executionContext) marshalNRole2gitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐRole(ctx context.Context, sel ast.SelectionSet, v models.Role) graphql.Marshaler {
return v
}
func (ec *executionContext) unmarshalNString2string(ctx context.Context, v any) (string, error) { func (ec *executionContext) unmarshalNString2string(ctx context.Context, v any) (string, error) {
res, err := graphql.UnmarshalString(v) res, err := graphql.UnmarshalString(v)
return res, graphql.ErrorOnPath(ctx, err) return res, graphql.ErrorOnPath(ctx, err)
@ -7660,6 +7709,47 @@ func (ec *executionContext) marshalOBoolean2ᚖbool(ctx context.Context, sel ast
return res return res
} }
func (ec *executionContext) marshalOCategory2ᚕᚖgitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐCategory(ctx context.Context, sel ast.SelectionSet, v []*models.Category) graphql.Marshaler {
if v == nil {
return graphql.Null
}
ret := make(graphql.Array, len(v))
var wg sync.WaitGroup
isLen1 := len(v) == 1
if !isLen1 {
wg.Add(len(v))
}
for i := range v {
i := i
fc := &graphql.FieldContext{
Index: &i,
Result: &v[i],
}
ctx := graphql.WithFieldContext(ctx, fc)
f := func(i int) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = nil
}
}()
if !isLen1 {
defer wg.Done()
}
ret[i] = ec.marshalOCategory2ᚖgitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐCategory(ctx, sel, v[i])
}
if isLen1 {
f(i)
} else {
go f(i)
}
}
wg.Wait()
return ret
}
func (ec *executionContext) marshalOCategory2ᚖgitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐCategory(ctx context.Context, sel ast.SelectionSet, v *models.Category) graphql.Marshaler { func (ec *executionContext) marshalOCategory2ᚖgitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐCategory(ctx context.Context, sel ast.SelectionSet, v *models.Category) graphql.Marshaler {
if v == nil { if v == nil {
return graphql.Null return graphql.Null
@ -7671,7 +7761,7 @@ func (ec *executionContext) unmarshalOID2ᚖgoᚗmongodbᚗorgᚋmongoᚑdriver
if v == nil { if v == nil {
return nil, nil return nil, nil
} }
res, err := app.UnmarshalObjectID(v) res, err := utils.UnmarshalObjectID(v)
return &res, graphql.ErrorOnPath(ctx, err) return &res, graphql.ErrorOnPath(ctx, err)
} }
@ -7681,7 +7771,7 @@ func (ec *executionContext) marshalOID2ᚖgoᚗmongodbᚗorgᚋmongoᚑdriverᚋ
} }
_ = sel _ = sel
_ = ctx _ = ctx
res := app.MarshalObjectID(*v) res := utils.MarshalObjectID(*v)
return res return res
} }
@ -7701,6 +7791,42 @@ func (ec *executionContext) marshalOStatus2ᚖgitᚗfarahtyᚗcomᚋnimerᚋgo
return v return v
} }
func (ec *executionContext) unmarshalOString2ᚕstringᚄ(ctx context.Context, v any) ([]string, error) {
if v == nil {
return nil, nil
}
var vSlice []any
vSlice = graphql.CoerceList(v)
var err error
res := make([]string, len(vSlice))
for i := range vSlice {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i))
res[i], err = ec.unmarshalNString2string(ctx, vSlice[i])
if err != nil {
return nil, err
}
}
return res, nil
}
func (ec *executionContext) marshalOString2ᚕstringᚄ(ctx context.Context, sel ast.SelectionSet, v []string) graphql.Marshaler {
if v == nil {
return graphql.Null
}
ret := make(graphql.Array, len(v))
for i := range v {
ret[i] = ec.marshalNString2string(ctx, sel, v[i])
}
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}
func (ec *executionContext) unmarshalOString2ᚖstring(ctx context.Context, v any) (*string, error) { func (ec *executionContext) unmarshalOString2ᚖstring(ctx context.Context, v any) (*string, error) {
if v == nil { if v == nil {
return nil, nil return nil, nil

View File

@ -1,5 +1,12 @@
directive @auth on FIELD_DEFINITION directive @auth on FIELD_DEFINITION
directive @hasRole(role: Role!) on FIELD_DEFINITION
enum Role {
ADMIN
USER
}
type LoginResponse { type LoginResponse {
user: User! user: User!
accessToken: String! accessToken: String!

View File

@ -1,12 +1,28 @@
directive @goModel(
model: String
models: [String!]
forceGenerate: Boolean
) on OBJECT | INPUT_OBJECT | SCALAR | ENUM | INTERFACE | UNION
directive @goField( directive @goField(
forceResolver: Boolean forceResolver: Boolean
name: String name: String
omittable: Boolean
type: String
) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION ) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION
directive @goTag( directive @goTag(
key: String! key: String!
value: String value: String
) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION ) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION
directive @goExtraField(
name: String
type: String!
overrideTags: String
description: String
) repeatable on OBJECT | INPUT_OBJECT
scalar Time scalar Time
interface Base { interface Base {

View File

@ -8,6 +8,9 @@ type Category implements Base {
parent: Category @goField(forceResolver: true) parent: Category @goField(forceResolver: true)
parentId: ID parentId: ID
"#bson:ignore"
children: [Category] @goField(forceResolver: true)
createdAt: Time! createdAt: Time!
updatedAt: Time! updatedAt: Time!

View File

@ -1,12 +1,12 @@
type User { type User
@goExtraField(name: "Password", type: "*string")
@goExtraField(name: "Token", type: "*string") {
id: ID! id: ID!
phone: String phone: String
email: String email: String
type: String type: String
status: String status: String
verified: Boolean @goField(forceResolver: true) verified: Boolean @goField(forceResolver: true)
password: String
token: String
} }
input CreateUserInput { input CreateUserInput {

View File

@ -1,12 +1,10 @@
schema: schema:
- gql/*.gql - gql/*.gql
exec: exec:
filename: generated/generated.go filename: generated/generated.go
package: generated package: generated
model: model:
filename: models/models_gen.go filename: models/models_gen.go
package: models package: models
@ -27,7 +25,7 @@ autobind:
models: models:
ID: ID:
model: model:
- git.farahty.com/nimer/go-mongo/app.ObjectID - git.farahty.com/nimer/go-mongo/utils.ObjectID
- github.com/99designs/gqlgen/graphql.ID - github.com/99designs/gqlgen/graphql.ID
Int: Int:
model: model:

117
main.go
View File

@ -2,98 +2,101 @@ package main
import ( import (
"context" "context"
"log"
"net/http"
"os" "os"
"os/signal" "os/signal"
"time" "time"
"log"
"net/http"
"git.farahty.com/nimer/go-mongo/app" "git.farahty.com/nimer/go-mongo/app"
"github.com/fatih/color" "github.com/fatih/color"
"github.com/joho/godotenv"
"github.com/redis/go-redis/v9"
) )
func main() { func main() {
// Setup cancelable root context
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
color.Yellow("Starting server ...\n") // Panic recovery
defer func() {
if r := recover(); r != nil {
color.Red("🔴 Panic occurred: %v\n", r)
os.Exit(1)
}
}()
if _, exists := os.LookupEnv("MONGO_URI"); !exists { color.Yellow("🚀 Starting server ...\n")
err := godotenv.Load()
// Load and validate config
err := app.LoadConfig()
if err != nil { if err != nil {
log.Fatal("Error loading .env file\n") log.Fatal("🔴 Config error: ", err)
} }
color.Green("✅ .env loaded\n") // Connect to Mongo
dbCancel, err := app.Connect()
if err != nil {
log.Fatalf("🔴 MongoDB connection error: %v", err)
} }
port := os.Getenv("PORT")
if cancel, err := app.Connect(); err != nil {
cancel()
log.Fatal(err)
} else {
defer func() { defer func() {
color.Red("❌ Database Connection Closed\n") color.Red("❌ Closing MongoDB connection...\n")
cancel() dbCancel()
}() if err := app.Mongo.Disconnect(ctx); err != nil {
} log.Fatal("🔴 MongoDB disconnection error: ", err)
defer func() {
if err := app.Mongo.Disconnect(context.Background()); err != nil {
log.Fatal("MogoDB Errors" + err.Error())
} }
}() }()
color.Green("✅ Connected to MongoDB successfully\n")
color.Green("✅ Connected to Database successfully\n") // Load authorization policies using root context
if err := app.LoadAuthorizer(context.Background()); err != nil { if err := app.LoadAuthorizer(ctx); err != nil {
log.Fatal("Authorizer Errors : " + err.Error()) log.Fatal("🔴 Authorizer error: ", err)
} }
color.Green("✅ Authorization policies loaded successfully\n") color.Green("✅ Authorization policies loaded successfully\n")
redisClient := redis.NewClient(&redis.Options{ // Redis
Addr: os.Getenv("REDIS_HOST") + ":" + os.Getenv("REDIS_PORT"), if err := app.InitRedis(ctx); err != nil {
Password: os.Getenv("REDIS_PASSWORD"), // no password set log.Fatal("🔴 Redis connection error: ", err)
})
if _, err := redisClient.Ping(context.Background()).Result(); err != nil {
log.Fatal("Redis Error : " + err.Error())
} }
defer func() {
defer redisClient.Close() color.Red("❌ Closing Redis connection...\n")
_ = app.RedisClient.Close()
}()
color.Green("✅ Connected to Redis cache successfully\n") color.Green("✅ Connected to Redis cache successfully\n")
graphqlServer := createGraphqlServer(redisClient) // Create GraphQL server
graphqlServer := createGraphqlServer()
color.Green("🚀 Server Started at http://localhost:" + port + "\n")
//http.ListenAndServe(":"+port, createRouter(graphqlServer))
// Start HTTP server
server := &http.Server{ server := &http.Server{
Addr: ":" + port, Addr: ":" + app.Config.Port,
WriteTimeout: time.Second * 30,
ReadTimeout: time.Second * 30,
IdleTimeout: time.Second * 30,
Handler: createRouter(graphqlServer), Handler: createRouter(graphqlServer),
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 30 * time.Second,
} }
go server.ListenAndServe() go func() {
color.Green("🌐 Server listening at http://localhost:%s\n", app.Config.Port)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("🔴 Server failed: %v", err)
}
}()
// Wait for interrupt signal to gracefully shut down the server with // Graceful shutdown
// a timeout of 15 seconds.
quit := make(chan os.Signal, 1) quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt) signal.Notify(quit, os.Interrupt)
<-quit <-quit
color.Yellow(" 🎬 Start Shutdown Signal ... ")
ctx, cancelShutdown := context.WithTimeout(context.Background(), 15*time.Second) color.Yellow("🟡 Shutdown signal received, initiating cleanup...")
defer cancelShutdown()
if err := server.Shutdown(ctx); err != nil { // Cancel root context and wait for graceful shutdown
log.Fatal("Server Shutdown:", err) shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second)
defer shutdownCancel()
if err := server.Shutdown(shutdownCtx); err != nil {
log.Fatalf("🔴 Server forced to shutdown: %v", err)
} }
color.Red("❌ Server Exiting")
color.Green("✅ Server shutdown completed gracefully")
} }

View File

@ -29,6 +29,8 @@ type Category struct {
// #bson:ignore // #bson:ignore
Parent *Category `json:"parent,omitempty,omitempty" bson:"-"` Parent *Category `json:"parent,omitempty,omitempty" bson:"-"`
ParentID *primitive.ObjectID `json:"parentId,omitempty" bson:"parentId,omitempty"` ParentID *primitive.ObjectID `json:"parentId,omitempty" bson:"parentId,omitempty"`
// #bson:ignore
Children []*Category `json:"children,omitempty,omitempty" bson:"-"`
CreatedAt time.Time `json:"createdAt" bson:"createdAt,omitempty"` CreatedAt time.Time `json:"createdAt" bson:"createdAt,omitempty"`
UpdatedAt time.Time `json:"updatedAt" bson:"updatedAt,omitempty"` UpdatedAt time.Time `json:"updatedAt" bson:"updatedAt,omitempty"`
// #bson:ignore // #bson:ignore
@ -132,8 +134,63 @@ type User struct {
Type *string `json:"type,omitempty" bson:"type,omitempty"` Type *string `json:"type,omitempty" bson:"type,omitempty"`
Status *string `json:"status,omitempty" bson:"status,omitempty"` Status *string `json:"status,omitempty" bson:"status,omitempty"`
Verified *bool `json:"verified,omitempty" bson:"verified,omitempty"` Verified *bool `json:"verified,omitempty" bson:"verified,omitempty"`
Password *string `json:"password,omitempty" bson:"password,omitempty"` Password *string `json:"-" bson:"Password,omitempty"`
Token *string `json:"token,omitempty" bson:"token,omitempty"` Token *string `json:"-" bson:"Token,omitempty"`
}
type Role string
const (
RoleAdmin Role = "ADMIN"
RoleUser Role = "USER"
)
var AllRole = []Role{
RoleAdmin,
RoleUser,
}
func (e Role) IsValid() bool {
switch e {
case RoleAdmin, RoleUser:
return true
}
return false
}
func (e Role) String() string {
return string(e)
}
func (e *Role) UnmarshalGQL(v any) error {
str, ok := v.(string)
if !ok {
return fmt.Errorf("enums must be strings")
}
*e = Role(str)
if !e.IsValid() {
return fmt.Errorf("%s is not a valid Role", str)
}
return nil
}
func (e Role) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}
func (e *Role) UnmarshalJSON(b []byte) error {
s, err := strconv.Unquote(string(b))
if err != nil {
return err
}
return e.UnmarshalGQL(s)
}
func (e Role) MarshalJSON() ([]byte, error) {
var buf bytes.Buffer
e.MarshalGQL(&buf)
return buf.Bytes(), nil
} }
type Status string type Status string

View File

@ -11,6 +11,10 @@ func MakeHash(password string) (string, error) {
} }
func (u *User) CheckPassword(password string) bool { func (u *User) CheckPassword(password string) bool {
if u.Password == nil {
return false
}
err := bcrypt.CompareHashAndPassword([]byte(*u.Password), []byte(password)) err := bcrypt.CompareHashAndPassword([]byte(*u.Password), []byte(password))
return err == nil return err == nil
} }

View File

@ -7,6 +7,7 @@ package resolvers
import ( import (
"context" "context"
"git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/generated" "git.farahty.com/nimer/go-mongo/generated"
"git.farahty.com/nimer/go-mongo/models" "git.farahty.com/nimer/go-mongo/models"
authService "git.farahty.com/nimer/go-mongo/services/auth" authService "git.farahty.com/nimer/go-mongo/services/auth"
@ -14,7 +15,7 @@ import (
// Login is the resolver for the login field. // Login is the resolver for the login field.
func (r *mutationResolver) Login(ctx context.Context, input models.LoginInput) (*models.LoginResponse, error) { func (r *mutationResolver) Login(ctx context.Context, input models.LoginInput) (*models.LoginResponse, error) {
return authService.Login(ctx, &input) return authService.Login(ctx, &input, app.RedisClient)
} }
// Mutation returns generated.MutationResolver implementation. // Mutation returns generated.MutationResolver implementation.

View File

@ -22,6 +22,11 @@ func (r *categoryResolver) Parent(ctx context.Context, obj *models.Category) (*m
return categoryService.FindByID(ctx, *obj.ParentID) return categoryService.FindByID(ctx, *obj.ParentID)
} }
// Children is the resolver for the children field.
func (r *categoryResolver) Children(ctx context.Context, obj *models.Category) ([]*models.Category, error) {
return categoryService.FindChildren(ctx, obj.ID)
}
// CreatedBy is the resolver for the createdBy field. // CreatedBy is the resolver for the createdBy field.
func (r *categoryResolver) CreatedBy(ctx context.Context, obj *models.Category) (*models.User, error) { func (r *categoryResolver) CreatedBy(ctx context.Context, obj *models.Category) (*models.User, error) {
return userService.FindById(ctx, obj.CreatedByID) return userService.FindById(ctx, obj.CreatedByID)

View File

@ -15,15 +15,19 @@ import (
// It serves as dependency injection for your app, add any dependencies you require here. // It serves as dependency injection for your app, add any dependencies you require here.
type Resolver struct { type Resolver struct {
Redis *redis.Client Redis redis.UniversalClient
} }
func Subscribe[T any](ctx context.Context, redis *redis.Client, event string) (<-chan *T, error) { func Subscribe[T any](ctx context.Context, redisClient redis.UniversalClient, event string) (<-chan *T, error) {
client, ok := redisClient.(*redis.Client)
if !ok {
return nil, nil // or handle error
}
clientChannel := make(chan *T, 1) clientChannel := make(chan *T, 1)
go func() { go func() {
sub := redis.Subscribe(ctx, event) sub := client.Subscribe(ctx, event)
if _, err := sub.Receive(ctx); err != nil { if _, err := sub.Receive(ctx); err != nil {
return return
@ -52,5 +56,4 @@ func Subscribe[T any](ctx context.Context, redis *redis.Client, event string) (<
}() }()
return clientChannel, nil return clientChannel, nil
} }

View File

@ -12,6 +12,7 @@ import (
"git.farahty.com/nimer/go-mongo/models" "git.farahty.com/nimer/go-mongo/models"
todoService "git.farahty.com/nimer/go-mongo/services/todo" todoService "git.farahty.com/nimer/go-mongo/services/todo"
userService "git.farahty.com/nimer/go-mongo/services/user" userService "git.farahty.com/nimer/go-mongo/services/user"
redis "github.com/redis/go-redis/v9"
"go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/bson/primitive"
) )
@ -24,7 +25,9 @@ func (r *mutationResolver) CreateTodo(ctx context.Context, input models.CreateTo
} }
if objJson, err := json.Marshal(obj); err == nil { if objJson, err := json.Marshal(obj); err == nil {
r.Redis.Publish(ctx, "NEW_TODO_EVENT", objJson) if client, ok := r.Redis.(*redis.Client); ok {
client.Publish(ctx, "NEW_TODO_EVENT", objJson)
}
} }
return obj, nil return obj, nil
@ -42,7 +45,10 @@ func (r *queryResolver) Todo(ctx context.Context, id primitive.ObjectID) (*model
// OnTodo is the resolver for the onTodo field. // OnTodo is the resolver for the onTodo field.
func (r *subscriptionResolver) OnTodo(ctx context.Context) (<-chan *models.Todo, error) { func (r *subscriptionResolver) OnTodo(ctx context.Context) (<-chan *models.Todo, error) {
return Subscribe[models.Todo](ctx, r.Redis, "NEW_TODO_EVENT") if client, ok := r.Redis.(*redis.Client); ok {
return Subscribe[models.Todo](ctx, client, "NEW_TODO_EVENT")
}
return nil, nil
} }
// CreatedBy is the resolver for the createdBy field. // CreatedBy is the resolver for the createdBy field.

View File

@ -3,14 +3,13 @@ package main
import ( import (
"log" "log"
"net/http" "net/http"
"os"
"time" "time"
"git.farahty.com/nimer/go-mongo/app" "git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/controllers" "git.farahty.com/nimer/go-mongo/controllers"
"git.farahty.com/nimer/go-mongo/directives" "git.farahty.com/nimer/go-mongo/directives"
"git.farahty.com/nimer/go-mongo/generated" "git.farahty.com/nimer/go-mongo/generated"
"git.farahty.com/nimer/go-mongo/helpers" "git.farahty.com/nimer/go-mongo/utils"
"github.com/99designs/gqlgen/graphql/handler" "github.com/99designs/gqlgen/graphql/handler"
"github.com/99designs/gqlgen/graphql/handler/extension" "github.com/99designs/gqlgen/graphql/handler/extension"
"github.com/99designs/gqlgen/graphql/handler/transport" "github.com/99designs/gqlgen/graphql/handler/transport"
@ -22,14 +21,13 @@ import (
"github.com/go-chi/httprate" "github.com/go-chi/httprate"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/redis/go-redis/v9"
) )
func createRouter(graphqlServer http.Handler) chi.Router { func createRouter(graphqlServer http.Handler) chi.Router {
router := chi.NewRouter() router := chi.NewRouter()
// Apply middleware // Apply middleware
if os.Getenv("ENV") == "production" { if app.Config.Env == "production" {
router.Use(httprate.LimitByIP(100, 1*time.Minute)) // 100 requests/minute/IP router.Use(httprate.LimitByIP(100, 1*time.Minute)) // 100 requests/minute/IP
} }
@ -41,6 +39,9 @@ func createRouter(graphqlServer http.Handler) chi.Router {
// Custom middleware for Auth // Custom middleware for Auth
router.Use(app.AuthMiddleware) router.Use(app.AuthMiddleware)
router.Use(app.WriterMiddleware)
router.Use(app.RequestMiddleware)
// REST routes // REST routes
router.Mount("/users", controllers.UserRouter()) router.Mount("/users", controllers.UserRouter())
@ -51,10 +52,8 @@ func createRouter(graphqlServer http.Handler) chi.Router {
return router return router
} }
func createGraphqlServer(redisClient *redis.Client) http.Handler { func createGraphqlServer() http.Handler {
cache, err := utils.NewCache(app.RedisClient, 24*time.Hour, "apq:")
cache, err := helpers.NewCache(redisClient, 24*time.Hour, "apq:")
if err != nil { if err != nil {
log.Fatalf("cannot create APQ redis cache: %v", err) log.Fatalf("cannot create APQ redis cache: %v", err)
} }
@ -62,7 +61,7 @@ func createGraphqlServer(redisClient *redis.Client) http.Handler {
// Setup gqlgen with resolvers and Redis client // Setup gqlgen with resolvers and Redis client
schema := generated.Config{ schema := generated.Config{
Resolvers: &resolvers.Resolver{ Resolvers: &resolvers.Resolver{
Redis: redisClient, Redis: app.RedisClient,
}, },
} }
@ -88,10 +87,11 @@ func createGraphqlServer(redisClient *redis.Client) http.Handler {
srv.AddTransport(transport.MultipartForm{}) srv.AddTransport(transport.MultipartForm{})
srv.Use(extension.AutomaticPersistedQuery{Cache: cache}) srv.Use(extension.AutomaticPersistedQuery{Cache: cache})
srv.Use(extension.Introspection{})
// Apply global middleware // Apply global middleware
srv.AroundRootFields(app.RootFieldsAuthorizer) // Check for @auth at root fields srv.AroundRootFields(app.RootFieldsAuthorizer) // Check for @auth at root fields
srv.AroundResponses(app.ExpiryMiddleware) // Token expiry validation //srv.AroundResponses(app.ExpiryMiddleware) // Token expiry validation
// Inject DataLoaders into request context // Inject DataLoaders into request context
return app.LoaderMiddleware(app.NewLoaders(), srv) return app.LoaderMiddleware(app.NewLoaders(), srv)
@ -99,4 +99,6 @@ func createGraphqlServer(redisClient *redis.Client) http.Handler {
func mapDirectives(config *generated.Config) { func mapDirectives(config *generated.Config) {
config.Directives.Auth = directives.Auth config.Directives.Auth = directives.Auth
config.Directives.HasRole = directives.HasRole
} }

View File

@ -3,15 +3,23 @@ package authService
import ( import (
"context" "context"
"fmt" "fmt"
"time"
"git.farahty.com/nimer/go-mongo/app" "git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/models" "git.farahty.com/nimer/go-mongo/models"
"github.com/redis/go-redis/v9"
"go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson"
) )
func Login(ctx context.Context, loginInput *models.LoginInput) (*models.LoginResponse, error) { const maxAttempts = 5
const lockoutDuration = 15 * time.Minute
func Login(ctx context.Context, loginInput *models.LoginInput, redisClient redis.UniversalClient) (*models.LoginResponse, error) {
if locked, _ := isLockedOut(loginInput.Identity, redisClient); locked {
return nil, fmt.Errorf("too many failed attempts, please try again later")
}
// todo : fix the security threats here
filter := bson.D{ filter := bson.D{
{ {
Key: "$or", Key: "$or",
@ -25,12 +33,37 @@ func Login(ctx context.Context, loginInput *models.LoginInput) (*models.LoginRes
user, err := app.FindOne[models.User](ctx, "users", filter) user, err := app.FindOne[models.User](ctx, "users", filter)
if err != nil { if err != nil {
logFailedLoginAttempt(ctx, loginInput.Identity, redisClient) // optional
return nil, err return nil, err
} }
if !user.CheckPassword(loginInput.Password) { if !user.CheckPassword(loginInput.Password) {
return nil, fmt.Errorf("incorrect password") logFailedLoginAttempt(ctx, loginInput.Identity, redisClient) // optional
return nil, fmt.Errorf("invalid identity or password")
} }
clearFailedAttempts(ctx, loginInput.Identity, redisClient)
return successLogin(ctx, user) return successLogin(ctx, user)
} }
func isLockedOut(identity string, redisClient redis.UniversalClient) (bool, error) {
key := fmt.Sprintf("login:fail:%s", identity)
attempts, err := redisClient.Get(context.Background(), key).Int()
if err != nil && err != redis.Nil {
return false, err
}
return attempts >= maxAttempts, nil
}
func logFailedLoginAttempt(ctx context.Context, identity string, redisClient redis.UniversalClient) {
key := fmt.Sprintf("login:fail:%s", identity)
redisClient.Incr(ctx, key)
redisClient.Expire(ctx, key, lockoutDuration)
}
func clearFailedAttempts(ctx context.Context, identity string, redisClient redis.UniversalClient) {
key := fmt.Sprintf("login:fail:%s", identity)
redisClient.Del(ctx, key)
}

View File

@ -1,29 +0,0 @@
package authService
import (
"time"
"github.com/golang-jwt/jwt/v4"
)
func createToken(sub, secret, expiry string, payload interface{}) (*string, error) {
duration, err := time.ParseDuration(expiry)
if err != nil {
return nil, err
}
token := jwt.New(jwt.SigningMethodHS256)
claims := token.Claims.(jwt.MapClaims)
claims["sub"] = sub
claims["exp"] = time.Now().Add(duration).Unix()
claims["data"] = payload
signedToken, err := token.SignedString([]byte(secret))
if err != nil {
return nil, err
}
return &signedToken, nil
}

View File

@ -3,35 +3,41 @@ package authService
import ( import (
"context" "context"
"encoding/hex" "encoding/hex"
"os" "fmt"
"net/http"
"git.farahty.com/nimer/go-mongo/app" "git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/models" "git.farahty.com/nimer/go-mongo/models"
"git.farahty.com/nimer/go-mongo/utils"
"github.com/google/uuid" "github.com/google/uuid"
"go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson"
) )
func successLogin(ctx context.Context, user *models.User) (*models.LoginResponse, error) { func successLogin(ctx context.Context, user *models.User) (*models.LoginResponse, error) {
refresh_secret := os.Getenv("REFRESH_SECRET") refresh_secret := app.Config.RefreshSecret
refresh_expiry := os.Getenv("REFRESH_EXPIRY") refresh_expiry := app.Config.RefreshExpiry
access_secret := os.Getenv("ACCESS_SECRET") access_secret := app.Config.AccessSecret
access_expiry := os.Getenv("ACCESS_EXPIRY") access_expiry := app.Config.AccessExpiry
identity := user.Email var identity string
if identity == nil { if user.Email != nil {
identity = user.Phone identity = *user.Email
} else if user.Phone != nil {
identity = *user.Phone
} else {
return nil, fmt.Errorf("user identity not found")
} }
refreshHandle := hex.EncodeToString([]byte(uuid.NewString())) refreshHandle := hex.EncodeToString([]byte(uuid.NewString()))
refreshToken, err := createToken( refreshToken, err := utils.CreateToken(
refreshHandle, refreshHandle,
refresh_secret, refresh_secret,
refresh_expiry, refresh_expiry,
models.UserJWT{ models.UserJWT{
ID: user.ID.Hex(), ID: user.ID.Hex(),
Identity: *identity, Identity: identity,
}, },
) )
@ -39,13 +45,13 @@ func successLogin(ctx context.Context, user *models.User) (*models.LoginResponse
return nil, err return nil, err
} }
accessToken, err := createToken( accessToken, err := utils.CreateToken(
user.ID.Hex(), user.ID.Hex(),
access_secret, access_secret,
access_expiry, access_expiry,
models.UserJWT{ models.UserJWT{
ID: user.ID.Hex(), ID: user.ID.Hex(),
Identity: *identity, Identity: identity,
}, },
) )
@ -53,14 +59,27 @@ func successLogin(ctx context.Context, user *models.User) (*models.LoginResponse
return nil, err return nil, err
} }
user.Token = &refreshHandle _, err = app.Collection("users").UpdateByID(ctx, user.ID, bson.D{
{Key: "$set", Value: bson.D{
_, err = app.Collection("users").UpdateByID(ctx, user.ID, bson.D{{Key: "$set", Value: user}}) {Key: "token", Value: refreshHandle},
}},
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
w := app.WriterFor(ctx)
http.SetCookie(*w, &http.Cookie{
Name: "access_token",
Value: *accessToken,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
return &models.LoginResponse{ return &models.LoginResponse{
AccessToken: *accessToken, AccessToken: *accessToken,
RefreshToken: *refreshToken, RefreshToken: *refreshToken,

View File

@ -16,6 +16,10 @@ func Find(ctx context.Context) ([]*models.Category, error) {
return app.Find[models.Category](ctx, coll, bson.D{}) return app.Find[models.Category](ctx, coll, bson.D{})
} }
func FindChildren(ctx context.Context, id primitive.ObjectID) ([]*models.Category, error) {
return app.Find[models.Category](ctx, coll, bson.M{"parentId": id})
}
func Create(ctx context.Context, input models.CreateCategoryInput) (*models.Category, error) { func Create(ctx context.Context, input models.CreateCategoryInput) (*models.Category, error) {
return app.InsertOne[models.Category](ctx, coll, input) return app.InsertOne[models.Category](ctx, coll, input)
} }

View File

@ -1,4 +1,4 @@
package helpers package utils
import ( import (
"context" "context"

37
utils/camelize.go Normal file
View File

@ -0,0 +1,37 @@
package utils
import (
"strings"
"unicode"
)
func Camelize(s string) string {
// If input has no delimiters, assume it's already in camelCase or PascalCase
if !strings.ContainsAny(s, "_- ") {
return s[:1] + s[1:] // No change (optionally lowercase first letter: strings.ToLower(s[:1]) + s[1:])
}
// Otherwise, split and camelCase it
parts := strings.FieldsFunc(s, func(r rune) bool {
return r == '_' || r == '-' || unicode.IsSpace(r)
})
if len(parts) == 0 {
return ""
}
var b strings.Builder
b.WriteString(strings.ToLower(parts[0]))
for _, part := range parts[1:] {
if len(part) == 0 {
continue
}
b.WriteString(strings.ToUpper(part[:1]))
if len(part) > 1 {
b.WriteString(part[1:])
}
}
return b.String()
}

View File

@ -1,4 +1,4 @@
package helpers package utils
import ( import (
"crypto/aes" "crypto/aes"

View File

@ -1,4 +1,4 @@
package app package utils
import ( import (
"errors" "errors"
@ -18,7 +18,7 @@ func MarshalObjectID(value primitive.ObjectID) graphql.Marshaler {
} }
func UnmarshalObjectID(value interface{}) (primitive.ObjectID, error) { func UnmarshalObjectID(value any) (primitive.ObjectID, error) {
if str, ok := value.(string); ok { if str, ok := value.(string); ok {

View File

@ -1,17 +1,16 @@
package app package utils
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"os"
"strings" "strings"
"time"
"git.farahty.com/nimer/go-mongo/models"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
) )
func getTokenFromHeader(r *http.Request) (string, error) { func GetTokenFromHeader(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization") authHeader := r.Header.Get("Authorization")
@ -28,9 +27,9 @@ func getTokenFromHeader(r *http.Request) (string, error) {
return strings.TrimSpace(authSlice[1]), nil return strings.TrimSpace(authSlice[1]), nil
} }
func getTokenFromCookie(r *http.Request) (string, error) { func GetTokenFromCookie(r *http.Request) (string, error) {
cookie, err := r.Cookie("app-access-token") cookie, err := r.Cookie("access_token")
if err != nil { if err != nil {
return "", fmt.Errorf("there is no authorization cookie provided") return "", fmt.Errorf("there is no authorization cookie provided")
@ -39,13 +38,13 @@ func getTokenFromCookie(r *http.Request) (string, error) {
return strings.TrimSpace(cookie.Value), nil return strings.TrimSpace(cookie.Value), nil
} }
func getUserFromToken(tokenString string) (*models.UserJWT, error) { func GetUserFromToken[T any](tokenString string, secret string) (*T, error) {
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) { token, err := jwt.Parse(tokenString, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("wrong token format ") return nil, fmt.Errorf("wrong token format ")
} }
return []byte(os.Getenv("ACCESS_SECRET")), nil return []byte(secret), nil
}) })
if err != nil { if err != nil {
@ -56,7 +55,7 @@ func getUserFromToken(tokenString string) (*models.UserJWT, error) {
return nil, fmt.Errorf("token is not valid") return nil, fmt.Errorf("token is not valid")
} }
var user *models.UserJWT var user *T
claims := token.Claims.(jwt.MapClaims) claims := token.Claims.(jwt.MapClaims)
if err := mapstructure.Decode(claims["data"], &user); err != nil { if err := mapstructure.Decode(claims["data"], &user); err != nil {
@ -65,3 +64,25 @@ func getUserFromToken(tokenString string) (*models.UserJWT, error) {
return user, nil return user, nil
} }
func CreateToken(sub, secret, expiry string, payload any) (*string, error) {
duration, err := time.ParseDuration(expiry)
if err != nil {
return nil, err
}
token := jwt.New(jwt.SigningMethodHS256)
claims := token.Claims.(jwt.MapClaims)
claims["sub"] = sub
claims["exp"] = time.Now().Add(duration).Unix()
claims["data"] = payload
signedToken, err := token.SignedString([]byte(secret))
if err != nil {
return nil, err
}
return &signedToken, nil
}