duidui_fiber/internal/scheduler/service.go
2026-03-27 10:34:03 +08:00

307 lines
6.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package scheduler
import (
	"fmt"
	"io"
	"log"
	"net/http"
	"sort"
	"sync"
	"time"

	"dd_fiber_api/pkg/snowflake"
	"dd_fiber_api/pkg/timewheel"
)
// Service is the scheduler service. It pairs a time wheel, which runs the
// task callbacks, with an in-memory table of task metadata.
type Service struct {
// timeWheel is created and started in NewService and stopped in Stop.
timeWheel *timewheel.TimeWheel
taskInfos map[string]*TaskInfo // task metadata kept in memory, keyed by task ID
// mu is held by the methods of Service whenever they read or write taskInfos.
mu sync.RWMutex
}
// NewService builds a scheduler Service backed by a time wheel created from
// tickInterval and slotNum. The time wheel is started before the Service is
// returned, and the task table starts out empty.
func NewService(tickInterval time.Duration, slotNum int) *Service {
	wheel := timewheel.New(tickInterval, slotNum)
	wheel.Start()

	return &Service{
		timeWheel: wheel,
		taskInfos: make(map[string]*TaskInfo),
	}
}
// Stop shuts the service down by calling Stop on the underlying time wheel.
// The in-memory task table (taskInfos) is not modified here.
func (s *Service) Stop() {
s.timeWheel.Stop()
}
// AddTask registers a task and schedules it on the time wheel.
//
// When req.TaskID is empty an ID is generated with snowflake.GenerateID.
// TaskTypeOnce tasks are scheduled after req.DelayMs milliseconds and
// TaskTypeCyclic tasks every req.IntervalMs milliseconds; any other type is
// rejected.
//
// Business-level failures (nil request, duplicate ID, unsupported type,
// scheduling error) are reported through the response with Success=false;
// the returned error is always nil.
func (s *Service) AddTask(req *AddTaskRequest) (*AddTaskResponse, error) {
	if req == nil {
		return &AddTaskResponse{
			Success: false,
			Message: "请求不能为空",
		}, nil
	}

	// Generate a task ID when the caller did not supply one.
	taskID := req.TaskID
	if taskID == "" {
		taskID = snowflake.GenerateID()
	}

	info := &TaskInfo{
		TaskID:     taskID,
		TaskType:   req.TaskType,
		Status:     TaskStatusPending,
		DelayMs:    req.DelayMs,
		IntervalMs: req.IntervalMs,
		Metadata:   req.Metadata,
	}

	// Check for a duplicate and reserve the ID in the same critical section.
	// Doing the insert only after scheduling (in a second critical section)
	// would let two concurrent calls with the same ID both pass the check,
	// and would let a short-delay task fire before its entry exists, leaving
	// a stale "pending" entry behind that executeTask never cleans up.
	s.mu.Lock()
	if _, exists := s.taskInfos[taskID]; exists {
		s.mu.Unlock()
		return &AddTaskResponse{
			Success:     false,
			Message:     "任务ID已存在",
			TaskID:      taskID,
			BusinessKey: req.BusinessKey,
		}, nil
	}
	s.taskInfos[taskID] = info
	s.mu.Unlock()

	// release undoes the reservation when scheduling does not succeed. The
	// pointer comparison makes sure only this call's own entry is removed.
	release := func() {
		s.mu.Lock()
		if s.taskInfos[taskID] == info {
			delete(s.taskInfos, taskID)
		}
		s.mu.Unlock()
	}

	// Callback invoked by the time wheel when the task is due.
	callback := func() {
		s.executeTask(req, taskID)
	}

	var err error
	switch req.TaskType {
	case TaskTypeOnce:
		// One-shot task.
		delay := time.Duration(req.DelayMs) * time.Millisecond
		err = s.timeWheel.AddTask(taskID, delay, callback)
	case TaskTypeCyclic:
		// Repeating task.
		interval := time.Duration(req.IntervalMs) * time.Millisecond
		err = s.timeWheel.AddCyclicTask(taskID, interval, callback)
	default:
		release()
		return &AddTaskResponse{
			Success: false,
			Message: "不支持的任务类型",
		}, nil
	}

	if err != nil {
		release()
		return &AddTaskResponse{
			Success: false,
			Message: fmt.Sprintf("添加任务失败: %v", err),
		}, nil
	}

	return &AddTaskResponse{
		Success:     true,
		Message:     "任务添加成功",
		TaskID:      taskID,
		BusinessKey: req.BusinessKey,
	}, nil
}
// RemoveTask drops the task with the given ID from the in-memory table and
// then from the time wheel.
//
// An empty ID or an ID that is not in the table yields a response with
// Success=false; the returned error is always nil.
func (s *Service) RemoveTask(taskID string) (*RemoveTaskResponse, error) {
	if taskID == "" {
		return &RemoveTaskResponse{Success: false, Message: "task_id不能为空"}, nil
	}

	s.mu.Lock()
	_, found := s.taskInfos[taskID]
	// delete is a no-op for a missing key, so no extra guard is needed.
	delete(s.taskInfos, taskID)
	s.mu.Unlock()

	if !found {
		return &RemoveTaskResponse{Success: false, Message: "任务不存在"}, nil
	}

	// Only tasks that were present in the table are removed from the wheel.
	s.timeWheel.RemoveTask(taskID)

	return &RemoveTaskResponse{Success: true, Message: "任务删除成功"}, nil
}
// GetTaskStatus looks a task up in the time wheel and translates the wheel's
// status into the service-level TaskStatus.
//
// An empty ID yields Exists=false. A task the wheel does not know is
// reported with Exists=false and Status=TaskStatusCompleted. The returned
// error is always nil.
func (s *Service) GetTaskStatus(taskID string) (*GetTaskStatusResponse, error) {
	if taskID == "" {
		return &GetTaskStatusResponse{Exists: false, Message: "task_id不能为空"}, nil
	}

	wheelStatus, found := s.timeWheel.GetTaskStatus(taskID)
	if !found {
		return &GetTaskStatusResponse{
			Exists:  false,
			Status:  TaskStatusCompleted,
			Message: "任务不存在或已完成",
		}, nil
	}

	// Map the wheel status; pending and any unrecognised value both end up
	// as TaskStatusPending.
	var mapped TaskStatus
	switch wheelStatus {
	case timewheel.TaskStatusRunning:
		mapped = TaskStatusRunning
	case timewheel.TaskStatusCompleted:
		mapped = TaskStatusCompleted
	case timewheel.TaskStatusCancelled:
		mapped = TaskStatusCancelled
	default:
		mapped = TaskStatusPending
	}

	return &GetTaskStatusResponse{Exists: true, Status: mapped, Message: "查询成功"}, nil
}
// GetTaskCount returns the task count reported by the time wheel's
// TaskCount method; the in-memory task table is not consulted.
func (s *Service) GetTaskCount() *GetTaskCountResponse {
	return &GetTaskCountResponse{Count: s.timeWheel.TaskCount()}
}
// ListTasks returns one page of the tasks held in the in-memory table,
// ordered by TaskID, together with the total number of tasks.
//
// page is 1-based; values below 1 are treated as 1. pageSize defaults to 10
// when below 1 and is capped at 100. A page past the end yields an empty,
// non-nil Tasks slice. Each task's Status is refreshed from the time wheel
// when the wheel still knows the task, and the refreshed value is written
// back to the table.
func (s *Service) ListTasks(page, pageSize int) *ListTasksResponse {
	// The refresh loop below writes Status back into the shared entries, so
	// the exclusive lock is required: under a read lock two concurrent
	// listers would race on those writes.
	s.mu.Lock()
	defer s.mu.Unlock()

	// Snapshot every task, refreshing its status on the way.
	allTasks := make([]TaskInfo, 0, len(s.taskInfos))
	for _, taskInfo := range s.taskInfos {
		if status, exists := s.timeWheel.GetTaskStatus(taskInfo.TaskID); exists {
			switch status {
			case timewheel.TaskStatusPending:
				taskInfo.Status = TaskStatusPending
			case timewheel.TaskStatusRunning:
				taskInfo.Status = TaskStatusRunning
			case timewheel.TaskStatusCompleted:
				taskInfo.Status = TaskStatusCompleted
			case timewheel.TaskStatusCancelled:
				taskInfo.Status = TaskStatusCancelled
			}
		}
		allTasks = append(allTasks, *taskInfo)
	}

	// Map iteration order is random; without a stable order consecutive
	// pages could repeat or skip tasks.
	sort.Slice(allTasks, func(i, j int) bool {
		return allTasks[i].TaskID < allTasks[j].TaskID
	})

	total := len(allTasks)

	// Normalise the paging parameters.
	if page < 1 {
		page = 1
	}
	if pageSize < 1 {
		pageSize = 10
	}
	if pageSize > 100 {
		pageSize = 100
	}

	// Compare page with the page count before multiplying: for a huge page
	// number (page-1)*pageSize could overflow to a negative offset and make
	// the slice expression below panic.
	lastPage := (total + pageSize - 1) / pageSize
	if page > lastPage {
		return &ListTasksResponse{
			Tasks:    []TaskInfo{},
			Total:    total,
			Page:     page,
			PageSize: pageSize,
		}
	}

	start := (page - 1) * pageSize
	end := start + pageSize
	if end > total {
		end = total
	}

	return &ListTasksResponse{
		Tasks:    allTasks[start:end],
		Total:    total,
		Page:     page,
		PageSize: pageSize,
	}
}
// executeTask is the callback body run for a due task. It fires the HTTP
// callback when req.CallbackURL is set (otherwise it only logs), and then
// updates the task's entry in the in-memory table: a one-shot task is marked
// completed and removed, any other task is marked running. A task that is no
// longer in the table is left alone.
func (s *Service) executeTask(req *AddTaskRequest, taskID string) {
	startedAt := time.Now()

	if req.CallbackURL == "" {
		log.Printf("✅ 任务执行完成无回调URL: %s", taskID)
	} else {
		s.sendHttpCallback(req, taskID, startedAt)
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	info, tracked := s.taskInfos[taskID]
	if !tracked {
		return
	}

	if req.TaskType != TaskTypeOnce {
		// Repeating task: keep the entry and flag it as running.
		info.Status = TaskStatusRunning
		return
	}

	// One-shot task: mark it completed and drop it from the table.
	info.Status = TaskStatusCompleted
	delete(s.taskInfos, taskID)
}
// sendHttpCallback issues a GET request to req.CallbackURL and logs the
// outcome. It returns nothing: transport errors, read errors and non-2xx
// responses are all reported through the log only.
//
// taskID identifies the task in the failure log line; startTime is the
// moment the task started executing and is used to log the elapsed time.
//
// NOTE(review): the URL is requested as given. If callback URLs can come
// from untrusted callers, confirm they are validated upstream.
func (s *Service) sendHttpCallback(req *AddTaskRequest, taskID string, startTime time.Time) {
	// Upper bound on how much of the response body is read and logged, so an
	// oversized response cannot exhaust memory or flood the log.
	const maxLoggedBody = 64 << 10

	log.Printf("🔔 开始HTTP回调: %s", req.CallbackURL)

	// Client with a timeout so a hanging endpoint cannot block indefinitely.
	client := &http.Client{
		Timeout: 10 * time.Second,
	}

	resp, err := client.Get(req.CallbackURL)
	if err != nil {
		log.Printf("❌ HTTP回调失败: %s, 错误: %v", req.CallbackURL, err)
		return
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxLoggedBody))
	if err != nil {
		log.Printf("❌ 读取响应失败: %v", err)
		return
	}

	// A response outside the 2xx range is a failed callback, not a success.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		log.Printf("❌ HTTP回调失败: %s, 任务: %s, 状态码: %d, 响应数据: %s",
			req.CallbackURL, taskID, resp.StatusCode, string(body))
		return
	}

	log.Printf("✅ HTTP回调成功!")
	log.Printf(" URL: %s", req.CallbackURL)
	log.Printf(" 状态码: %d", resp.StatusCode)
	log.Printf(" 响应数据: %s", string(body))
	log.Printf(" 耗时: %v", time.Since(startTime))
}