package leaf

import (
	"context"
	"fmt"
	"time"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)

// RepositoryInterface is the set of collection operations implemented by
// Repository for documents of type T.
type RepositoryInterface[T Document] interface {
	Collection() *mongo.Collection
	Find(ctx context.Context, filter any, opts ...*options.FindOptions) ([]T, error)
	FindOne(ctx context.Context, filter any, opts ...*options.FindOneOptions) (T, error)
	FindByID(ctx context.Context, id string, opts ...*options.FindOneOptions) (T, error)
	Create(ctx context.Context, data T, opts ...*options.InsertOneOptions) (T, error)
	Update(ctx context.Context, data T, opts ...*options.UpdateOptions) (T, error)
	UpdateOne(ctx context.Context, filter any, data T, opts ...*options.UpdateOptions) (T, error)
	Delete(ctx context.Context, data T, opts ...*options.DeleteOptions) error
	DeleteOne(ctx context.Context, filter any, opts ...*options.DeleteOptions) error
	CountDocuments(ctx context.Context, filter any, opts ...*options.CountOptions) (int, error)
}

// Repository is a generic wrapper around a single MongoDB collection whose
// documents decode into T. All errors it returns are prefixed with
// "<collection>.<operation>:" and wrap the underlying driver error.
type Repository[T Document] struct {
	collection *mongo.Collection
	indexes    []mongo.IndexModel
}

// NewRepository returns a Repository bound to collection. The given index
// models are passed through withTimestampIndexes before being stored for
// EnsureIndexes.
// NOTE(review): presumably withTimestampIndexes appends indexes on the
// created/updated timestamp fields — confirm against its definition.
func NewRepository[T Document](collection *mongo.Collection, indexes []mongo.IndexModel) *Repository[T] {
	return &Repository[T]{collection: collection, indexes: withTimestampIndexes(indexes)}
}

// Collection returns the underlying MongoDB collection.
func (s *Repository[T]) Collection() *mongo.Collection {
	return s.collection
}

// Find returns every document matching filter, decoded into T. When nothing
// matches, the returned slice is nil and the error is nil.
func (s *Repository[T]) Find(ctx context.Context, filter any, opts ...*options.FindOptions) ([]T, error) {
	cursor, err := s.collection.Find(ctx, filter, opts...)
	if err != nil {
		return nil, repositoryError(s.collection, "Find", err)
	}
	// cursor.All closes the cursor itself; this Close only matters if All is
	// never reached, so its error is intentionally ignored.
	defer cursor.Close(ctx)

	var data []T
	if err := cursor.All(ctx, &data); err != nil {
		return nil, repositoryError(s.collection, "Find", err)
	}
	return data, nil
}

// FindOne returns the first document matching filter. When no document
// matches, the returned error wraps mongo.ErrNoDocuments and can be detected
// with errors.Is.
func (s *Repository[T]) FindOne(ctx context.Context, filter any, opts ...*options.FindOneOptions) (T, error) {
	var data T
	if err := s.collection.FindOne(ctx, filter, opts...).Decode(&data); err != nil {
		return data, repositoryError(s.collection, "FindOne", err)
	}
	return data, nil
}

// FindByID returns the document whose _id equals the ObjectID parsed from the
// hex string id.
// NOTE(review): how an invalid hex id is handled depends on ObjectIDFromHex,
// which is defined elsewhere — confirm.
func (s *Repository[T]) FindByID(ctx context.Context, id string, opts ...*options.FindOneOptions) (T, error) {
	return s.FindOne(ctx, bson.M{"_id": ObjectIDFromHex(id)}, opts...)
}

// Create stamps data's created/updated timestamps with the current UTC time,
// inserts it, and stores the inserted ObjectID back on data via SetID.
func (s *Repository[T]) Create(ctx context.Context, data T, opts ...*options.InsertOneOptions) (T, error) {
	// A single timestamp keeps CreatedAt and UpdatedAt identical on a new
	// document.
	now := time.Now().UTC()
	data.SetCreatedAt(now)
	data.SetUpdatedAt(now)

	res, err := s.collection.InsertOne(ctx, data, opts...)
	if err != nil {
		return data, repositoryError(s.collection, "Create", err)
	}
	// InsertedID is an untyped interface value. Guard the assertion so a
	// non-ObjectID _id does not panic after a successful insert.
	if id, ok := res.InsertedID.(primitive.ObjectID); ok {
		data.SetID(id)
	}
	return data, nil
}

// Update replaces the fields of the document whose _id equals data.ID() with
// the fields of data. See UpdateOne.
func (s *Repository[T]) Update(ctx context.Context, data T, opts ...*options.UpdateOptions) (T, error) {
	return s.UpdateOne(ctx, bson.M{"_id": data.ID()}, data, opts...)
}

// UpdateOne refreshes data's updated timestamp and applies {"$set": data} to
// the first document matching filter. It returns data on success.
//
// The update result is not inspected, so a filter that matches no document
// is not reported as an error.
// NOTE(review): $set with the whole document also sets _id; if data.ID()
// differs from the matched document's _id the server rejects the update —
// confirm T's bson tags (e.g. omitempty on _id).
func (s *Repository[T]) UpdateOne(ctx context.Context, filter any, data T, opts ...*options.UpdateOptions) (T, error) {
	data.SetUpdatedAt(time.Now().UTC())
	if _, err := s.collection.UpdateOne(ctx, filter, bson.M{"$set": data}, opts...); err != nil {
		return data, repositoryError(s.collection, "UpdateOne", err)
	}
	return data, nil
}

// Delete removes the document whose _id equals data.ID(). See DeleteOne.
func (s *Repository[T]) Delete(ctx context.Context, data T, opts ...*options.DeleteOptions) error {
	return s.DeleteOne(ctx, bson.M{"_id": data.ID()}, opts...)
}

// DeleteOne removes the first document matching filter. The delete result is
// not inspected, so a filter that matches nothing is not an error.
func (s *Repository[T]) DeleteOne(ctx context.Context, filter any, opts ...*options.DeleteOptions) error {
	if _, err := s.collection.DeleteOne(ctx, filter, opts...); err != nil {
		return repositoryError(s.collection, "DeleteOne", err)
	}
	return nil
}

// CountDocuments returns the number of documents matching filter.
func (s *Repository[T]) CountDocuments(ctx context.Context, filter any, opts ...*options.CountOptions) (int, error) {
	count, err := s.collection.CountDocuments(ctx, filter, opts...)
	if err != nil {
		return 0, repositoryError(s.collection, "CountDocuments", err)
	}
	return int(count), nil
}

// EnsureIndexes delegates to ensureIndexes with this repository's collection
// and index models, forwarding createIndexes.
// NOTE(review): the meaning of createIndexes and of the returned
// IndexMessage values is defined by ensureIndexes elsewhere — confirm there.
func (s *Repository[T]) EnsureIndexes(ctx context.Context, createIndexes bool) ([]IndexMessage, error) {
	return ensureIndexes(ctx, s.collection, s.indexes, createIndexes)
}

// repositoryError prefixes err with "<collection>.<op>:" and wraps it with %w,
// so callers can still match the cause with errors.Is / errors.As (e.g.
// mongo.ErrNoDocuments).
func repositoryError(collection *mongo.Collection, op string, err error) error {
	return fmt.Errorf("%s.%s: %w", collection.Name(), op, err)
}