307 lines
6.6 KiB
Go
307 lines
6.6 KiB
Go
package scheduler
|
||
|
||
import (
	"fmt"
	"io"
	"log"
	"net/http"
	"sort"
	"sync"
	"time"

	"dd_fiber_api/pkg/snowflake"
	"dd_fiber_api/pkg/timewheel"
)
|
||
|
||
// Service 调度器服务
|
||
type Service struct {
|
||
timeWheel *timewheel.TimeWheel
|
||
taskInfos map[string]*TaskInfo // 存储任务信息(内存)
|
||
mu sync.RWMutex
|
||
}
|
||
|
||
// NewService 创建调度器服务
|
||
func NewService(tickInterval time.Duration, slotNum int) *Service {
|
||
service := &Service{
|
||
timeWheel: timewheel.New(tickInterval, slotNum),
|
||
taskInfos: make(map[string]*TaskInfo),
|
||
}
|
||
service.timeWheel.Start()
|
||
return service
|
||
}
|
||
|
||
// Stop 停止服务
|
||
func (s *Service) Stop() {
|
||
s.timeWheel.Stop()
|
||
}
|
||
|
||
// AddTask 添加任务
|
||
func (s *Service) AddTask(req *AddTaskRequest) (*AddTaskResponse, error) {
|
||
// 自动生成 task_id(如果未提供)
|
||
taskID := req.TaskID
|
||
if taskID == "" {
|
||
taskID = snowflake.GenerateID()
|
||
}
|
||
|
||
s.mu.Lock()
|
||
// 检查任务是否已存在
|
||
if _, exists := s.taskInfos[taskID]; exists {
|
||
s.mu.Unlock()
|
||
return &AddTaskResponse{
|
||
Success: false,
|
||
Message: "任务ID已存在",
|
||
TaskID: taskID,
|
||
BusinessKey: req.BusinessKey,
|
||
}, nil
|
||
}
|
||
s.mu.Unlock()
|
||
|
||
// 创建任务回调函数
|
||
callback := func() {
|
||
s.executeTask(req, taskID)
|
||
}
|
||
|
||
var err error
|
||
switch req.TaskType {
|
||
case TaskTypeOnce:
|
||
// 一次性任务
|
||
delay := time.Duration(req.DelayMs) * time.Millisecond
|
||
err = s.timeWheel.AddTask(taskID, delay, callback)
|
||
case TaskTypeCyclic:
|
||
// 循环任务
|
||
interval := time.Duration(req.IntervalMs) * time.Millisecond
|
||
err = s.timeWheel.AddCyclicTask(taskID, interval, callback)
|
||
default:
|
||
return &AddTaskResponse{
|
||
Success: false,
|
||
Message: "不支持的任务类型",
|
||
}, nil
|
||
}
|
||
|
||
if err != nil {
|
||
return &AddTaskResponse{
|
||
Success: false,
|
||
Message: fmt.Sprintf("添加任务失败: %v", err),
|
||
}, nil
|
||
}
|
||
|
||
// 保存任务信息到内存
|
||
s.mu.Lock()
|
||
s.taskInfos[taskID] = &TaskInfo{
|
||
TaskID: taskID,
|
||
TaskType: req.TaskType,
|
||
Status: TaskStatusPending,
|
||
DelayMs: req.DelayMs,
|
||
IntervalMs: req.IntervalMs,
|
||
Metadata: req.Metadata,
|
||
}
|
||
s.mu.Unlock()
|
||
|
||
return &AddTaskResponse{
|
||
Success: true,
|
||
Message: "任务添加成功",
|
||
TaskID: taskID,
|
||
BusinessKey: req.BusinessKey,
|
||
}, nil
|
||
}
|
||
|
||
// RemoveTask 删除任务
|
||
func (s *Service) RemoveTask(taskID string) (*RemoveTaskResponse, error) {
|
||
if taskID == "" {
|
||
return &RemoveTaskResponse{
|
||
Success: false,
|
||
Message: "task_id不能为空",
|
||
}, nil
|
||
}
|
||
|
||
s.mu.Lock()
|
||
_, exists := s.taskInfos[taskID]
|
||
if exists {
|
||
delete(s.taskInfos, taskID)
|
||
}
|
||
s.mu.Unlock()
|
||
|
||
if !exists {
|
||
return &RemoveTaskResponse{
|
||
Success: false,
|
||
Message: "任务不存在",
|
||
}, nil
|
||
}
|
||
|
||
// 从时间轮中删除
|
||
s.timeWheel.RemoveTask(taskID)
|
||
|
||
return &RemoveTaskResponse{
|
||
Success: true,
|
||
Message: "任务删除成功",
|
||
}, nil
|
||
}
|
||
|
||
// GetTaskStatus 查询任务状态
|
||
func (s *Service) GetTaskStatus(taskID string) (*GetTaskStatusResponse, error) {
|
||
if taskID == "" {
|
||
return &GetTaskStatusResponse{
|
||
Exists: false,
|
||
Message: "task_id不能为空",
|
||
}, nil
|
||
}
|
||
|
||
status, exists := s.timeWheel.GetTaskStatus(taskID)
|
||
if !exists {
|
||
return &GetTaskStatusResponse{
|
||
Exists: false,
|
||
Status: TaskStatusCompleted,
|
||
Message: "任务不存在或已完成",
|
||
}, nil
|
||
}
|
||
|
||
// 转换状态
|
||
var taskStatus TaskStatus
|
||
switch status {
|
||
case timewheel.TaskStatusPending:
|
||
taskStatus = TaskStatusPending
|
||
case timewheel.TaskStatusRunning:
|
||
taskStatus = TaskStatusRunning
|
||
case timewheel.TaskStatusCompleted:
|
||
taskStatus = TaskStatusCompleted
|
||
case timewheel.TaskStatusCancelled:
|
||
taskStatus = TaskStatusCancelled
|
||
default:
|
||
taskStatus = TaskStatusPending
|
||
}
|
||
|
||
return &GetTaskStatusResponse{
|
||
Exists: true,
|
||
Status: taskStatus,
|
||
Message: "查询成功",
|
||
}, nil
|
||
}
|
||
|
||
// GetTaskCount 获取任务数量
|
||
func (s *Service) GetTaskCount() *GetTaskCountResponse {
|
||
count := s.timeWheel.TaskCount()
|
||
return &GetTaskCountResponse{
|
||
Count: count,
|
||
}
|
||
}
|
||
|
||
// ListTasks 列出所有任务
|
||
func (s *Service) ListTasks(page, pageSize int) *ListTasksResponse {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
|
||
// 收集所有任务
|
||
allTasks := make([]TaskInfo, 0, len(s.taskInfos))
|
||
for _, taskInfo := range s.taskInfos {
|
||
// 更新状态
|
||
status, exists := s.timeWheel.GetTaskStatus(taskInfo.TaskID)
|
||
if exists {
|
||
switch status {
|
||
case timewheel.TaskStatusPending:
|
||
taskInfo.Status = TaskStatusPending
|
||
case timewheel.TaskStatusRunning:
|
||
taskInfo.Status = TaskStatusRunning
|
||
case timewheel.TaskStatusCompleted:
|
||
taskInfo.Status = TaskStatusCompleted
|
||
case timewheel.TaskStatusCancelled:
|
||
taskInfo.Status = TaskStatusCancelled
|
||
}
|
||
}
|
||
allTasks = append(allTasks, *taskInfo)
|
||
}
|
||
|
||
total := len(allTasks)
|
||
|
||
// 分页
|
||
if page < 1 {
|
||
page = 1
|
||
}
|
||
if pageSize < 1 {
|
||
pageSize = 10
|
||
}
|
||
if pageSize > 100 {
|
||
pageSize = 100
|
||
}
|
||
|
||
start := (page - 1) * pageSize
|
||
end := start + pageSize
|
||
|
||
if start >= total {
|
||
return &ListTasksResponse{
|
||
Tasks: []TaskInfo{},
|
||
Total: total,
|
||
Page: page,
|
||
PageSize: pageSize,
|
||
}
|
||
}
|
||
|
||
if end > total {
|
||
end = total
|
||
}
|
||
|
||
tasks := allTasks[start:end]
|
||
|
||
return &ListTasksResponse{
|
||
Tasks: tasks,
|
||
Total: total,
|
||
Page: page,
|
||
PageSize: pageSize,
|
||
}
|
||
}
|
||
|
||
// executeTask 执行任务
|
||
func (s *Service) executeTask(req *AddTaskRequest, taskID string) {
|
||
startTime := time.Now()
|
||
|
||
// 如果有回调URL,发送HTTP请求
|
||
if req.CallbackURL != "" {
|
||
s.sendHttpCallback(req, taskID, startTime)
|
||
} else {
|
||
log.Printf("✅ 任务执行完成(无回调URL): %s", taskID)
|
||
}
|
||
|
||
// 更新任务状态
|
||
s.mu.Lock()
|
||
if taskInfo, exists := s.taskInfos[taskID]; exists {
|
||
// 如果是一次性任务,执行后删除
|
||
if req.TaskType == TaskTypeOnce {
|
||
taskInfo.Status = TaskStatusCompleted
|
||
delete(s.taskInfos, taskID)
|
||
} else {
|
||
// 循环任务,保持运行状态
|
||
taskInfo.Status = TaskStatusRunning
|
||
}
|
||
}
|
||
s.mu.Unlock()
|
||
}
|
||
|
||
// sendHttpCallback 发送HTTP回调请求
|
||
func (s *Service) sendHttpCallback(req *AddTaskRequest, taskID string, startTime time.Time) {
|
||
log.Printf("🔔 开始HTTP回调: %s", req.CallbackURL)
|
||
|
||
// 创建HTTP客户端(设置超时)
|
||
client := &http.Client{
|
||
Timeout: 10 * time.Second,
|
||
}
|
||
|
||
// 发送GET请求
|
||
resp, err := client.Get(req.CallbackURL)
|
||
if err != nil {
|
||
log.Printf("❌ HTTP回调失败: %s, 错误: %v", req.CallbackURL, err)
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 读取响应内容
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
log.Printf("❌ 读取响应失败: %v", err)
|
||
return
|
||
}
|
||
|
||
// 打印响应结果
|
||
log.Printf("✅ HTTP回调成功!")
|
||
log.Printf(" URL: %s", req.CallbackURL)
|
||
log.Printf(" 状态码: %d", resp.StatusCode)
|
||
log.Printf(" 响应数据: %s", string(body))
|
||
}
|