129 lines
3.1 KiB
Go
129 lines
3.1 KiB
Go
package config
|
||
|
||
import (
|
||
"errors"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"gopkg.in/yaml.v3"
|
||
)
|
||
|
||
// Config 是 YAML 配置结构(只用 YAML,不依赖环境变量)。
|
||
type Config struct {
|
||
Server ServerConfig `yaml:"server"`
|
||
JWT JWTConfig `yaml:"jwt"`
|
||
DB DBConfig `yaml:"db"`
|
||
}
|
||
|
||
type ServerConfig struct {
|
||
Addr string `yaml:"addr"`
|
||
GinMode string `yaml:"gin_mode"`
|
||
GracefulTimeoutSeconds int `yaml:"graceful_timeout_seconds"`
|
||
ReadHeaderTimeoutSeconds int `yaml:"read_header_timeout_seconds"`
|
||
}
|
||
|
||
type JWTConfig struct {
|
||
Secret string `yaml:"secret"`
|
||
ExpirySeconds int `yaml:"expiry_seconds"`
|
||
}
|
||
|
||
type DBConfig struct {
|
||
Driver string `yaml:"driver"`
|
||
DSN string `yaml:"dsn"`
|
||
MaxOpenConns int `yaml:"max_open_conns"`
|
||
MaxIdleConns int `yaml:"max_idle_conns"`
|
||
ConnMaxLifetimeSeconds int `yaml:"conn_max_lifetime_seconds"`
|
||
}
|
||
|
||
// Load 读取 config/config.yaml。
|
||
func Load() Config {
|
||
path := filepath.Join("config", "config.yaml")
|
||
raw, err := os.ReadFile(path)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
var cfg Config
|
||
if err := yaml.Unmarshal(raw, &cfg); err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
// defaults
|
||
if cfg.Server.Addr == "" {
|
||
cfg.Server.Addr = ":8080"
|
||
}
|
||
if cfg.Server.GinMode == "" {
|
||
cfg.Server.GinMode = gin.ReleaseMode
|
||
}
|
||
|
||
ginModeNorm := strings.ToLower(strings.TrimSpace(cfg.Server.GinMode))
|
||
switch ginModeNorm {
|
||
case "debug":
|
||
cfg.Server.GinMode = gin.DebugMode
|
||
case "test":
|
||
cfg.Server.GinMode = gin.TestMode
|
||
case "release":
|
||
cfg.Server.GinMode = gin.ReleaseMode
|
||
default:
|
||
// 若未知,按 release 兜底。
|
||
cfg.Server.GinMode = gin.ReleaseMode
|
||
}
|
||
|
||
if cfg.Server.GracefulTimeoutSeconds <= 0 {
|
||
cfg.Server.GracefulTimeoutSeconds = 5
|
||
}
|
||
if cfg.Server.ReadHeaderTimeoutSeconds <= 0 {
|
||
cfg.Server.ReadHeaderTimeoutSeconds = 5
|
||
}
|
||
|
||
if cfg.JWT.Secret == "" {
|
||
cfg.JWT.Secret = "dev-secret-change-me"
|
||
}
|
||
if cfg.JWT.ExpirySeconds <= 0 {
|
||
cfg.JWT.ExpirySeconds = 3600
|
||
}
|
||
|
||
if cfg.DB.Driver == "" {
|
||
cfg.DB.Driver = "sqlite"
|
||
}
|
||
if cfg.DB.DSN == "" {
|
||
return cfg // 交给 model 层校验并给出更明确错误
|
||
}
|
||
|
||
// 连接池 defaults(可选)
|
||
if cfg.DB.MaxOpenConns <= 0 {
|
||
cfg.DB.MaxOpenConns = 10
|
||
}
|
||
if cfg.DB.MaxIdleConns <= 0 {
|
||
cfg.DB.MaxIdleConns = 5
|
||
}
|
||
if cfg.DB.ConnMaxLifetimeSeconds <= 0 {
|
||
cfg.DB.ConnMaxLifetimeSeconds = 3600
|
||
}
|
||
|
||
return cfg
|
||
}
|
||
|
||
// ServerDuration 便于下层直接使用 time.Duration。
|
||
func (c Config) ServerDuration() (gracefulTimeout time.Duration, readHeaderTimeout time.Duration) {
|
||
if c.Server.GracefulTimeoutSeconds <= 0 {
|
||
return 5 * time.Second, 5 * time.Second
|
||
}
|
||
if c.Server.ReadHeaderTimeoutSeconds <= 0 {
|
||
return time.Duration(c.Server.GracefulTimeoutSeconds) * time.Second, 5 * time.Second
|
||
}
|
||
return time.Duration(c.Server.GracefulTimeoutSeconds) * time.Second, time.Duration(c.Server.ReadHeaderTimeoutSeconds) * time.Second
|
||
}
|
||
|
||
// Validate 校验配置的必填项(用于启动时报错)。
|
||
func (c Config) Validate() error {
|
||
if c.DB.DSN == "" {
|
||
return errors.New("db.dsn is empty in config/config.yaml")
|
||
}
|
||
return nil
|
||
}
|
||
|