// Package timewheel implements a single-level hashed timing wheel that runs
// one-shot and cyclic callbacks with tick granularity.
package timewheel

import (
	"sync"
	"time"
)

// TaskStatus is the lifecycle state of a task scheduled on a TimeWheel.
type TaskStatus int

const (
	TaskStatusPending   TaskStatus = iota // waiting for its slot to come due
	TaskStatusRunning                     // reserved; TimeWheel does not currently report this state
	TaskStatusCompleted                   // a one-shot task that has fired
	TaskStatusCancelled                   // removed via RemoveTask or replaced by a task with the same ID
)

// Task is a callback scheduled on a TimeWheel. Tasks are created by AddTask
// and AddCyclicTask; the wheel's API never hands a *Task back to callers.
type Task struct {
	ID       string
	Delay    time.Duration // initial delay of a one-shot task
	Interval time.Duration // period of a cyclic task
	Callback func()
	IsCyclic bool
	Status   TaskStatus // guarded by mu
	mu       sync.RWMutex

	// Placement bookkeeping, guarded by the owning TimeWheel's mu.
	slot   int // index of the slot currently holding the task
	rounds int // full wheel revolutions to skip before the task fires
}

// status returns the task's current status under its own lock.
func (t *Task) status() TaskStatus {
	t.mu.RLock()
	defer t.mu.RUnlock()
	return t.Status
}

// setStatus updates the task's status under its own lock.
func (t *Task) setStatus(s TaskStatus) {
	t.mu.Lock()
	t.Status = s
	t.mu.Unlock()
}

// TimeWheel is a hashed timing wheel: slotNum slots visited round-robin, one
// per tickInterval. Delays longer than one revolution are handled with a
// per-task round counter, so any delay is supported.
//
// All mutable wheel state (slots, currentSlot, tasks, ticker, stopped and the
// tasks' placement fields) is guarded by mu. Lock order is TimeWheel.mu
// before Task.mu.
type TimeWheel struct {
	tickInterval time.Duration    // duration of one tick; immutable after New
	slotNum      int              // number of slots; immutable after New
	slots        []*slot          // slots[i] holds tasks due when the wheel reaches i
	currentSlot  int              // slot processed by the most recent tick
	ticker       *time.Ticker     // non-nil once Start has run
	stopCh       chan struct{}    // closed by Stop to end the run goroutine
	stopped      bool             // set by Stop; makes Stop idempotent and Start a no-op
	tasks        map[string]*Task // registered tasks by ID
	mu           sync.RWMutex
}

// slot is one bucket of the wheel. Its map is guarded by TimeWheel.mu.
type slot struct {
	tasks map[string]*Task
}

// New creates a stopped time wheel with slotNum slots advancing once per
// tickInterval. Call Start to begin ticking.
//
// New panics if tickInterval or slotNum is not positive; such a wheel could
// never run (the original code would panic later with a division by zero or
// from time.NewTicker).
func New(tickInterval time.Duration, slotNum int) *TimeWheel {
	if tickInterval <= 0 {
		panic("timewheel: tickInterval must be positive")
	}
	if slotNum <= 0 {
		panic("timewheel: slotNum must be positive")
	}
	slots := make([]*slot, slotNum)
	for i := range slots {
		slots[i] = &slot{tasks: make(map[string]*Task)}
	}
	return &TimeWheel{
		tickInterval: tickInterval,
		slotNum:      slotNum,
		slots:        slots,
		stopCh:       make(chan struct{}),
		tasks:        make(map[string]*Task),
	}
}

// Start launches the background goroutine that advances the wheel once per
// tickInterval. Calling Start again, or after Stop, is a no-op.
func (tw *TimeWheel) Start() {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.stopped || tw.ticker != nil {
		return
	}
	tw.ticker = time.NewTicker(tw.tickInterval)
	go tw.run()
}

// Stop halts the wheel: no further ticks are processed and the wheel cannot
// be restarted. Callbacks that were already launched keep running. Stop is
// safe to call more than once and before Start.
func (tw *TimeWheel) Stop() {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.stopped {
		return
	}
	tw.stopped = true
	close(tw.stopCh)
	if tw.ticker != nil {
		tw.ticker.Stop()
	}
}

// run is the ticking loop started by Start; it exits when Stop closes stopCh.
func (tw *TimeWheel) run() {
	// Read the ticker under the lock that Start wrote it under.
	tw.mu.RLock()
	tickC := tw.ticker.C
	tw.mu.RUnlock()

	for {
		select {
		case <-tickC:
			tw.tick()
		case <-tw.stopCh:
			return
		}
	}
}

// tick advances the wheel by one slot and fires every task in that slot whose
// remaining round count has reached zero. One-shot tasks are marked completed
// and unregistered; cyclic tasks are re-placed one Interval ahead. Each
// callback runs in its own goroutine after the wheel lock is released.
func (tw *TimeWheel) tick() {
	tw.mu.Lock()
	tw.currentSlot = (tw.currentSlot + 1) % tw.slotNum
	current := tw.slots[tw.currentSlot]

	var due []*Task
	for id, task := range current.tasks {
		if task.rounds > 0 {
			task.rounds--
			continue
		}
		// Deleting during range is allowed. Re-insertion is deferred to the
		// loop below so a cyclic task re-placed into this same slot is not
		// visited twice in one tick.
		delete(current.tasks, id)
		due = append(due, task)
	}

	for _, task := range due {
		if task.IsCyclic {
			// The task was removed from this slot above, so it lives in
			// exactly one slot after re-placement.
			tw.addTaskToSlot(task, task.Interval)
			continue
		}
		task.setStatus(TaskStatusCompleted)
		if tw.tasks[task.ID] == task {
			delete(tw.tasks, task.ID)
		}
	}
	tw.mu.Unlock()

	for _, task := range due {
		if task.Callback != nil {
			go task.Callback()
		}
	}
}

// AddTask schedules callback to run once, delay from now (rounded up to whole
// ticks, minimum one tick; negative delays are treated as zero). An existing
// task with the same ID is cancelled and replaced. The returned error is
// currently always nil.
func (tw *TimeWheel) AddTask(taskID string, delay time.Duration, callback func()) error {
	if delay < 0 {
		delay = 0
	}
	tw.schedule(&Task{
		ID:       taskID,
		Delay:    delay,
		Callback: callback,
		IsCyclic: false,
		Status:   TaskStatusPending,
	}, delay)
	return nil
}

// AddCyclicTask schedules callback to run every interval (rounded up to whole
// ticks) until the task is removed. A non-positive interval defaults to the
// wheel's tick interval. An existing task with the same ID is cancelled and
// replaced. The returned error is currently always nil.
func (tw *TimeWheel) AddCyclicTask(taskID string, interval time.Duration, callback func()) error {
	if interval <= 0 {
		interval = tw.tickInterval
	}
	tw.schedule(&Task{
		ID:       taskID,
		Interval: interval,
		Callback: callback,
		IsCyclic: true,
		Status:   TaskStatusPending,
	}, interval)
	return nil
}

// schedule registers task under its ID, cancelling any task it replaces, and
// places it delay from now — all in one critical section so a concurrent
// tick never sees a half-registered task.
func (tw *TimeWheel) schedule(task *Task, delay time.Duration) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if old, ok := tw.tasks[task.ID]; ok {
		tw.cancelLocked(old)
	}
	tw.tasks[task.ID] = task
	tw.addTaskToSlot(task, delay)
}

// addTaskToSlot places task so that it fires on the ceil(delay/tickInterval)-th
// tick after the most recent one (minimum one tick), recording the target
// slot and the number of full revolutions to skip. The caller must hold
// tw.mu for writing.
func (tw *TimeWheel) addTaskToSlot(task *Task, delay time.Duration) {
	ticks := int(delay / tw.tickInterval)
	if delay%tw.tickInterval != 0 {
		ticks++ // round up so a task is never placed earlier than requested
	}
	if ticks < 1 {
		ticks = 1
	}
	// The slot at offset ticks is next reached after ticks ticks when
	// ticks <= slotNum; each further slotNum ticks costs one extra round.
	task.slot = (tw.currentSlot + ticks) % tw.slotNum
	task.rounds = (ticks - 1) / tw.slotNum
	tw.slots[task.slot].tasks[task.ID] = task
}

// cancelLocked unregisters task, removes it from its slot and marks it
// cancelled. The caller must hold tw.mu for writing.
func (tw *TimeWheel) cancelLocked(task *Task) {
	delete(tw.tasks, task.ID)
	delete(tw.slots[task.slot].tasks, task.ID)
	task.setStatus(TaskStatusCancelled)
}

// RemoveTask cancels the task with the given ID. Unknown IDs are ignored.
func (tw *TimeWheel) RemoveTask(taskID string) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if task, ok := tw.tasks[taskID]; ok {
		tw.cancelLocked(task)
	}
}

// GetTaskStatus reports the status of a registered task. The boolean is false
// when no task with that ID is registered (never added, removed, or a one-shot
// task that has already fired); the status is then TaskStatusCompleted.
func (tw *TimeWheel) GetTaskStatus(taskID string) (TaskStatus, bool) {
	tw.mu.RLock()
	task, ok := tw.tasks[taskID]
	tw.mu.RUnlock()
	if !ok {
		return TaskStatusCompleted, false
	}
	return task.status(), true
}

// TaskCount returns the number of registered tasks.
func (tw *TimeWheel) TaskCount() int {
	tw.mu.RLock()
	defer tw.mu.RUnlock()
	return len(tw.tasks)
}