leaf/repository.go
2024-01-12 12:20:15 +03:00

118 lines
4.0 KiB
Go

package leaf
import (
"context"
"fmt"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// RepositoryInterface is the set of persistence operations available for a
// document type T. Repository (below) provides a MongoDB implementation.
// Errors returned by that implementation are prefixed with
// "<collection>.<operation>:".
type RepositoryInterface[T Document] interface {
// Find returns every document matching filter.
Find(ctx context.Context, filter interface{}, opts ...*options.FindOptions) ([]T, error)
// FindOne returns the single document matching filter.
FindOne(ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) (T, error)
// FindByID looks a document up by the hex string form of its _id.
FindByID(ctx context.Context, id string, opts ...*options.FindOneOptions) (T, error)
// Create inserts data and returns it with its ID and timestamps populated.
Create(ctx context.Context, data T, opts ...*options.InsertOneOptions) (T, error)
// Update writes data over the stored document with the same _id.
Update(ctx context.Context, data T, opts ...*options.UpdateOptions) (T, error)
// UpdateOne writes data over the document matching filter.
UpdateOne(ctx context.Context, filter interface{}, data T, opts ...*options.UpdateOptions) (T, error)
// Delete removes the stored document with data's _id.
Delete(ctx context.Context, data T, opts ...*options.DeleteOptions) error
// DeleteOne removes one document matching filter.
DeleteOne(ctx context.Context, filter interface{}, opts ...*options.DeleteOptions) error
// CountDocuments returns how many documents match filter.
CountDocuments(ctx context.Context, filter interface{}, opts ...*options.CountOptions) (int, error)
}
// Repository is a MongoDB-backed store for documents of type T. Build it with
// NewRepository; the zero value has no collection and is not usable.
type Repository[T Document] struct {
// collection is the Mongo collection every operation runs against.
collection *mongo.Collection
// indexes are the index models EnsureIndexes hands to ensureIndexes;
// NewRepository populates this via withTimestampIndexes.
indexes []mongo.IndexModel
}
// NewRepository returns a Repository that operates on collection.
//
// The caller's index models are passed through withTimestampIndexes before
// being stored. The stored set may therefore differ from the argument
// (NOTE(review): presumably timestamp indexes are appended — confirm in
// withTimestampIndexes).
func NewRepository[T Document](collection *mongo.Collection, indexes []mongo.IndexModel) *Repository[T] {
	repo := new(Repository[T])
	repo.collection = collection
	repo.indexes = withTimestampIndexes(indexes)
	return repo
}
// Find runs filter against the collection and decodes every matching document
// into a slice of T.
//
// Both a failed query and a failed decode are reported as "<collection>.Find"
// errors. On error the returned slice is nil.
func (s *Repository[T]) Find(ctx context.Context, filter interface{}, opts ...*options.FindOptions) ([]T, error) {
	cur, findErr := s.collection.Find(ctx, filter, opts...)
	if findErr != nil {
		return nil, repositoryError(s.collection, "Find", findErr)
	}
	defer cur.Close(ctx)

	var results []T
	if decodeErr := cur.All(ctx, &results); decodeErr != nil {
		return nil, repositoryError(s.collection, "Find", decodeErr)
	}
	return results, nil
}
// FindOne decodes the single document matching filter into a T.
//
// When the lookup or decode fails, for example when the driver reports
// mongo.ErrNoDocuments because nothing matched, the error is wrapped as
// "<collection>.FindOne". The T returned alongside it is whatever Decode
// left behind.
func (s *Repository[T]) FindOne(ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) (T, error) {
	var doc T
	res := s.collection.FindOne(ctx, filter, opts...)
	err := res.Decode(&doc)
	if err == nil {
		return doc, nil
	}
	return doc, repositoryError(s.collection, "FindOne", err)
}
// FindByID fetches the document whose _id is the ObjectID encoded by the hex
// string id, delegating to FindOne.
//
// NOTE(review): the conversion is done by ObjectIDFromHex, which returns no
// error here. Confirm how it treats malformed hex before relying on a
// "not found" result to mean the ID was valid.
func (s *Repository[T]) FindByID(ctx context.Context, id string, opts ...*options.FindOneOptions) (T, error) {
	oid := ObjectIDFromHex(id)
	byID := bson.M{"_id": oid}
	return s.FindOne(ctx, byID, opts...)
}
// Create stamps data with one UTC timestamp for both createdAt and updatedAt,
// inserts it, and copies the inserted ObjectID back into data.
//
// If the insert fails, data is returned with its timestamps already set,
// together with a "<collection>.Create" error.
func (s *Repository[T]) Create(ctx context.Context, data T, opts ...*options.InsertOneOptions) (T, error) {
	// Read the clock once. Two separate time.Now calls can differ, which would
	// make a never-updated document look as if it had been updated.
	now := time.Now().UTC()
	data.SetCreatedAt(now)
	data.SetUpdatedAt(now)

	res, err := s.collection.InsertOne(ctx, data, opts...)
	if err != nil {
		return data, repositoryError(s.collection, "Create", err)
	}
	// InsertedID is the _id that ended up on the stored document. It is an
	// ObjectID when the driver generated it. It can be another type only if
	// T's BSON encoding supplied its own non-ObjectID _id, and in that case
	// data already carries it. An unchecked assertion would panic here even
	// though the insert succeeded, so only copy back an actual ObjectID.
	if oid, ok := res.InsertedID.(primitive.ObjectID); ok {
		data.SetID(oid)
	}
	return data, nil
}
// Update persists data over the stored document whose _id equals data.ID().
// It is a thin wrapper around UpdateOne with an _id filter, so updatedAt is
// refreshed there.
func (s *Repository[T]) Update(ctx context.Context, data T, opts ...*options.UpdateOptions) (T, error) {
	byID := bson.M{"_id": data.ID()}
	return s.UpdateOne(ctx, byID, data, opts...)
}
// UpdateOne sets data's updatedAt to the current UTC time, then applies the
// whole document as a $set to a single document matching filter.
//
// The returned T is data with the new timestamp. Match and modify counts are
// not inspected, so a filter that matches nothing still returns a nil error.
//
// NOTE(review): $set with the full document also writes _id and createdAt as
// they appear in data. Confirm that callers always pass a fully loaded
// document (or that T's BSON tags omit empty fields).
func (s *Repository[T]) UpdateOne(ctx context.Context, filter interface{}, data T, opts ...*options.UpdateOptions) (T, error) {
	data.SetUpdatedAt(time.Now().UTC())
	update := bson.M{"$set": data}
	_, err := s.collection.UpdateOne(ctx, filter, update, opts...)
	if err != nil {
		return data, repositoryError(s.collection, "UpdateOne", err)
	}
	return data, nil
}
// Delete removes the stored document whose _id equals data.ID(), delegating
// to DeleteOne.
func (s *Repository[T]) Delete(ctx context.Context, data T, opts ...*options.DeleteOptions) error {
	byID := bson.M{"_id": data.ID()}
	return s.DeleteOne(ctx, byID, opts...)
}
// DeleteOne removes a single document matching filter.
//
// The driver's DeleteResult is discarded, so a filter that matches nothing is
// not reported as an error. Driver failures are wrapped as
// "<collection>.DeleteOne".
func (s *Repository[T]) DeleteOne(ctx context.Context, filter interface{}, opts ...*options.DeleteOptions) error {
	_, err := s.collection.DeleteOne(ctx, filter, opts...)
	if err == nil {
		return nil
	}
	return repositoryError(s.collection, "DeleteOne", err)
}
// CountDocuments returns the number of documents matching filter.
//
// Like every other operation on Repository, driver errors are wrapped with
// repositoryError, so they carry the collection and operation name
// ("<collection>.CountDocuments").
func (s *Repository[T]) CountDocuments(ctx context.Context, filter interface{}, opts ...*options.CountOptions) (int, error) {
	count, err := s.collection.CountDocuments(ctx, filter, opts...)
	if err != nil {
		return 0, repositoryError(s.collection, "CountDocuments", err)
	}
	// NOTE(review): int64 -> int narrows on 32-bit platforms. The int return
	// type is fixed by RepositoryInterface.
	return int(count), nil
}
// EnsureIndexes passes the repository's collection and stored index models,
// together with createIndexes, to ensureIndexes. It returns that helper's
// messages and error unchanged.
func (s *Repository[T]) EnsureIndexes(ctx context.Context, createIndexes bool) ([]IndexMessage, error) {
	coll, models := s.collection, s.indexes
	return ensureIndexes(ctx, coll, models, createIndexes)
}
// repositoryError annotates err with the collection name and the repository
// operation that failed, producing "<collection>.<op>: <err>".
//
// err is wrapped with %w rather than flattened with %v. Callers can then
// still match the underlying cause, e.g. errors.Is(err, mongo.ErrNoDocuments)
// after FindOne/FindByID, or errors.Is(err, context.DeadlineExceeded).
// The rendered message is unchanged.
func repositoryError(collection *mongo.Collection, op string, err error) error {
	return fmt.Errorf("%s.%s: %w", collection.Name(), op, err)
}