package database

import (
	"context"
	"fmt"
	"time"

	"dd_fiber_api/config"

	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)

const (
	// mongoDefaultTimeout bounds connect+ping when cfg.Timeout is empty,
	// unparseable, or not positive.
	mongoDefaultTimeout = 10 * time.Second
	// mongoDisconnectTimeout bounds how long a Disconnect call may take,
	// both in Close and in the cleanup after a failed ping.
	mongoDisconnectTimeout = 5 * time.Second
)

// MongoDBClient bundles a connected MongoDB client with the database
// selected from the configuration.
type MongoDBClient struct {
	Client   *mongo.Client
	Database *mongo.Database
}

// NewMongoDBClient connects to MongoDB using cfg, verifies the connection
// with a ping, and selects cfg.Database.
//
// cfg must be non-nil with non-empty URI and Database. cfg.Timeout is parsed
// with time.ParseDuration and bounds both the connect and the ping; an empty,
// invalid, or non-positive value falls back to mongoDefaultTimeout (10s).
//
// If the ping fails, the client created by Connect is disconnected before the
// error is returned so its resources are not leaked. Returned errors wrap the
// underlying driver error (%w), so callers may use errors.Is / errors.As.
func NewMongoDBClient(cfg *config.MongoDBConfig) (*MongoDBClient, error) {
	if cfg == nil {
		return nil, fmt.Errorf("MongoDB 配置不能为空")
	}
	if cfg.URI == "" {
		return nil, fmt.Errorf("MongoDB URI 不能为空")
	}
	if cfg.Database == "" {
		return nil, fmt.Errorf("MongoDB 数据库名不能为空")
	}

	ctx, cancel := context.WithTimeout(context.Background(), parseMongoTimeout(cfg.Timeout))
	defer cancel()

	client, err := mongo.Connect(ctx, options.Client().ApplyURI(cfg.URI))
	if err != nil {
		return nil, fmt.Errorf("连接 MongoDB 失败: %w", err)
	}

	if err := client.Ping(ctx, nil); err != nil {
		// Best-effort cleanup: the ping error is the one worth reporting, and
		// disconnectMongo uses its own context because ctx may already have
		// expired if the ping timed out.
		_ = disconnectMongo(client)
		return nil, fmt.Errorf("MongoDB 连接测试失败: %w", err)
	}

	return &MongoDBClient{
		Client:   client,
		Database: client.Database(cfg.Database),
	}, nil
}

// parseMongoTimeout converts the configured timeout string to a duration,
// returning mongoDefaultTimeout for empty, malformed, or non-positive input.
// A non-positive duration would otherwise produce an already-expired context.
func parseMongoTimeout(raw string) time.Duration {
	if raw == "" {
		return mongoDefaultTimeout
	}
	d, err := time.ParseDuration(raw)
	if err != nil || d <= 0 {
		return mongoDefaultTimeout
	}
	return d
}

// disconnectMongo disconnects client, waiting at most mongoDisconnectTimeout.
func disconnectMongo(client *mongo.Client) error {
	ctx, cancel := context.WithTimeout(context.Background(), mongoDisconnectTimeout)
	defer cancel()
	return client.Disconnect(ctx)
}

// Close disconnects the underlying MongoDB client, waiting at most 5 seconds.
// It is a no-op returning nil on a nil receiver or when no client is set.
func (c *MongoDBClient) Close() error {
	if c == nil || c.Client == nil {
		return nil
	}
	return disconnectMongo(c.Client)
}

// Collection returns a handle to the named collection in the configured
// database.
func (c *MongoDBClient) Collection(name string) *mongo.Collection {
	return c.Database.Collection(name)
}