duidui_fiber/pkg/timewheel/timewheel.go
2026-03-27 10:34:03 +08:00

261 lines
5.1 KiB
Go

package timewheel
import (
"sync"
"time"
)
// TaskStatus is the lifecycle state of a Task.
type TaskStatus int

const (
	TaskStatusPending   TaskStatus = iota // waiting for its slot to come around
	TaskStatusRunning                     // being dispatched by the wheel
	TaskStatusCompleted                   // one-shot task that has been dispatched
	TaskStatusCancelled                   // removed via RemoveTask
)

// Task is a unit of work scheduled on a TimeWheel. Status is guarded by mu
// and is updated by the wheel's goroutine; read it through
// TimeWheel.GetTaskStatus rather than directly.
type Task struct {
	ID       string
	Delay    time.Duration // initial delay of a one-shot task
	Interval time.Duration // repeat period of a cyclic task
	Callback func()        // started in its own goroutine each time the task fires
	IsCyclic bool
	Status   TaskStatus
	mu       sync.RWMutex
}

// TimeWheel is a single-level hashed timing wheel. Each tick processes one
// of slotNum slots; a task whose delay exceeds one revolution of the wheel
// is parked with a round counter and skipped until that counter runs out.
type TimeWheel struct {
	tickInterval time.Duration    // wall-clock length of one slot
	slotNum      int              // number of slots in the wheel
	slots        []*slot          // the wheel itself
	currentSlot  int              // index of the slot processed on the next tick
	ticker       *time.Ticker     // created by Start; guarded by mu
	stopCh       chan struct{}    // closed (once) by Stop
	tasks        map[string]*Task // registry of live tasks by ID
	mu           sync.RWMutex     // guards currentSlot, ticker, tasks; taken before any slot.mu or Task.mu
}

// slot is one bucket of the wheel.
type slot struct {
	tasks map[string]*slotEntry // keyed by Task.ID
	mu    sync.RWMutex
}

// slotEntry is a task parked in a slot together with the number of further
// visits of the wheel to that slot to ignore before the task is due.
type slotEntry struct {
	task   *Task
	rounds int
}

// defaultTickInterval replaces a non-positive tick interval passed to New.
const defaultTickInterval = time.Second

// New returns a wheel with slotNum slots that advances one slot every
// tickInterval. The wheel does not run until Start is called.
//
// Degenerate arguments are normalized instead of being left to panic later:
// a non-positive tickInterval becomes defaultTickInterval (time.NewTicker and
// the delay-to-ticks division both reject zero), and a slotNum below 1
// becomes 1 (make rejects negative lengths and slot arithmetic is modulo
// slotNum).
func New(tickInterval time.Duration, slotNum int) *TimeWheel {
	if tickInterval <= 0 {
		tickInterval = defaultTickInterval
	}
	if slotNum < 1 {
		slotNum = 1
	}

	slots := make([]*slot, slotNum)
	for i := range slots {
		slots[i] = &slot{tasks: make(map[string]*slotEntry)}
	}

	return &TimeWheel{
		tickInterval: tickInterval,
		slotNum:      slotNum,
		slots:        slots,
		currentSlot:  0,
		stopCh:       make(chan struct{}),
		tasks:        make(map[string]*Task),
	}
}

// Start launches the goroutine that advances the wheel. It is a no-op if the
// wheel is already running or has already been stopped.
func (tw *TimeWheel) Start() {
	tw.mu.Lock()
	defer tw.mu.Unlock()

	if tw.ticker != nil {
		return // already started
	}
	select {
	case <-tw.stopCh:
		return // a stopped wheel cannot be restarted
	default:
	}

	tw.ticker = time.NewTicker(tw.tickInterval)
	go tw.run()
}

// Stop halts the wheel; tasks that have not fired yet will not fire. It is
// safe to call more than once and before Start.
func (tw *TimeWheel) Stop() {
	tw.mu.Lock()
	defer tw.mu.Unlock()

	// Only Stop closes stopCh and it always does so under mu, so this
	// check-then-close cannot race with another Stop.
	select {
	case <-tw.stopCh:
	default:
		close(tw.stopCh)
	}
	if tw.ticker != nil {
		tw.ticker.Stop()
	}
}

// run drives the wheel, one tick per ticker event, until stopCh is closed.
func (tw *TimeWheel) run() {
	// Start holds mu while it assigns the ticker, so this read sees it.
	tw.mu.RLock()
	ticker := tw.ticker
	tw.mu.RUnlock()

	for {
		select {
		case <-ticker.C:
			tw.tick()
		case <-tw.stopCh:
			return
		}
	}
}

// tick advances the wheel by one slot and dispatches every task in that slot
// whose round counter has reached zero; entries with rounds left are
// decremented and kept.
//
// The whole step runs under tw.mu. That has two effects: concurrent
// scheduling cannot slip an entry into the slot being processed, and a task
// unregistered before this tick is dropped rather than re-armed.
func (tw *TimeWheel) tick() {
	tw.mu.Lock()
	defer tw.mu.Unlock()

	processed := tw.currentSlot
	tw.currentSlot = (processed + 1) % tw.slotNum

	s := tw.slots[processed]
	s.mu.Lock()
	due := make([]*Task, 0, len(s.tasks))
	for id, e := range s.tasks {
		if e.rounds > 0 {
			e.rounds-- // due on a later revolution
			continue
		}
		// Deleting the current key during range is allowed by the Go spec.
		delete(s.tasks, id)
		due = append(due, e.task)
	}
	s.mu.Unlock()

	for _, task := range due {
		tw.fireLocked(task, processed)
	}
}

// fireLocked dispatches one due task taken from slot index processed. A
// cyclic task is then re-armed; a one-shot task is marked completed and
// unregistered. The caller must hold tw.mu for writing.
func (tw *TimeWheel) fireLocked(task *Task, processed int) {
	// The entry is stale if its task is no longer the registered one: it was
	// removed, or replaced by a later AddTask with the same ID. Drop it.
	if tw.tasks[task.ID] != task {
		return
	}

	task.mu.Lock()
	if task.Status != TaskStatusPending && task.Status != TaskStatusRunning {
		task.mu.Unlock()
		return
	}
	task.Status = TaskStatusRunning
	task.mu.Unlock()

	if task.Callback != nil {
		// Asynchronous, so a slow callback never stalls the wheel and a
		// callback that calls back into the wheel cannot deadlock on tw.mu.
		go task.Callback()
	}

	task.mu.Lock()
	if task.IsCyclic {
		task.Status = TaskStatusPending
	} else {
		task.Status = TaskStatusCompleted
	}
	task.mu.Unlock()

	if !task.IsCyclic {
		delete(tw.tasks, task.ID)
		return
	}

	// Re-arm relative to the slot just processed, so consecutive firings are
	// exactly k ticks apart: slot processed+k is next visited
	// ((k-1) mod slotNum)+1 ticks from now, after (k-1)/slotNum skipped rounds.
	k := tw.ticksFor(task.Interval)
	tw.insertLocked(task, (processed+k)%tw.slotNum, (k-1)/tw.slotNum)
}

// AddTask schedules callback to run once. With n = delay/tickInterval
// rounded down (negative delays count as zero, and n is at least 1), the
// callback fires on the wheel's (n+1)-th tick after the call, i.e. between
// n and n+1 tick intervals from now. Adding an ID that is already registered
// replaces the earlier task. The returned error is always nil.
func (tw *TimeWheel) AddTask(taskID string, delay time.Duration, callback func()) error {
	if delay < 0 {
		delay = 0
	}
	task := &Task{
		ID:       taskID,
		Delay:    delay,
		Callback: callback,
		IsCyclic: false,
		Status:   TaskStatusPending,
	}
	tw.register(task, tw.ticksFor(delay))
	return nil
}

// AddCyclicTask schedules callback to run repeatedly. The first run follows
// the same schedule as AddTask with delay = interval; after that it runs
// every interval, counted in whole ticks (rounded down, at least one). A
// non-positive interval means one tick. Adding an ID that is already
// registered replaces the earlier task. The returned error is always nil.
func (tw *TimeWheel) AddCyclicTask(taskID string, interval time.Duration, callback func()) error {
	if interval <= 0 {
		interval = tw.tickInterval
	}
	task := &Task{
		ID:       taskID,
		Interval: interval,
		Callback: callback,
		IsCyclic: true,
		Status:   TaskStatusPending,
	}
	tw.register(task, tw.ticksFor(interval))
	return nil
}

// register records task under its ID and parks it so that it fires on the
// (skip+1)-th upcoming tick. Both steps happen under one lock, so a
// concurrent tick or same-ID registration sees either neither or both.
func (tw *TimeWheel) register(task *Task, skip int) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	tw.tasks[task.ID] = task
	tw.scheduleLocked(task, skip)
}

// scheduleLocked parks task so that it fires on the (skip+1)-th upcoming
// tick. It uses the slot skip positions past currentSlot and skips
// skip/slotNum extra revolutions. The caller must hold tw.mu.
func (tw *TimeWheel) scheduleLocked(task *Task, skip int) {
	tw.insertLocked(task, (tw.currentSlot+skip)%tw.slotNum, skip/tw.slotNum)
}

// insertLocked stores task in slot idx with the given round count, replacing
// any entry with the same ID already in that slot. The caller must hold tw.mu.
func (tw *TimeWheel) insertLocked(task *Task, idx, rounds int) {
	s := tw.slots[idx]
	s.mu.Lock()
	s.tasks[task.ID] = &slotEntry{task: task, rounds: rounds}
	s.mu.Unlock()
}

// ticksFor converts d to a whole number of ticks, rounding down and clamping
// to at least 1.
func (tw *TimeWheel) ticksFor(d time.Duration) int {
	n := int(d / tw.tickInterval)
	if n < 1 {
		return 1
	}
	return n
}
// RemoveTask cancels the task registered under taskID. The task is dropped
// from the registry and from every slot, then marked TaskStatusCancelled.
// Unknown IDs are ignored.
func (tw *TimeWheel) RemoveTask(taskID string) {
	tw.mu.Lock()
	task, found := tw.tasks[taskID]
	delete(tw.tasks, taskID) // deleting a missing key is a no-op
	tw.mu.Unlock()

	if !found {
		return
	}

	// Which slot holds the task is not tracked, so sweep all of them.
	for i := range tw.slots {
		bucket := tw.slots[i]
		bucket.mu.Lock()
		delete(bucket.tasks, taskID)
		bucket.mu.Unlock()
	}

	task.mu.Lock()
	task.Status = TaskStatusCancelled
	task.mu.Unlock()
}
// GetTaskStatus reports the status of the task registered under taskID.
// The boolean is false when no such task is registered (never added,
// removed, or a one-shot task that has already fired). In that case the
// returned status is TaskStatusCompleted.
func (tw *TimeWheel) GetTaskStatus(taskID string) (TaskStatus, bool) {
	tw.mu.RLock()
	task, ok := tw.tasks[taskID]
	tw.mu.RUnlock()

	if !ok {
		return TaskStatusCompleted, false
	}

	task.mu.RLock()
	defer task.mu.RUnlock()
	return task.Status, true
}
// TaskCount returns the number of tasks currently in the registry. These are
// tasks that were added and have not yet been removed or, for one-shot
// tasks, fired.
func (tw *TimeWheel) TaskCount() int {
	tw.mu.RLock()
	n := len(tw.tasks)
	tw.mu.RUnlock()
	return n
}