package timer import ( "container/list" "sync" "time" ) // TimeWheel 时间轮 type TimeWheel struct { interval time.Duration ticker *time.Ticker buckets []*list.List size int cursor int tasks sync.Map // key: *Task, value: bucket index stop chan struct{} } type Task struct { delay time.Duration circle int callback func() } // NewTimeWheel 创建一个时间轮 func NewTimeWheel(interval time.Duration, size int) *TimeWheel { tw := &TimeWheel{ interval: interval, size: size, buckets: make([]*list.List, size), stop: make(chan struct{}), } for i := 0; i < size; i++ { tw.buckets[i] = list.New() } return tw } // Start 启动时间轮 func (tw *TimeWheel) Start() { tw.ticker = time.NewTicker(tw.interval) go func() { for { select { case <-tw.ticker.C: tw.tick() case <-tw.stop: tw.ticker.Stop() return } } }() } // Stop 停止时间轮 func (tw *TimeWheel) Stop() { close(tw.stop) } func (tw *TimeWheel) tick() { bucket := tw.buckets[tw.cursor] tw.execute(bucket) tw.cursor = (tw.cursor + 1) % tw.size } func (tw *TimeWheel) execute(bucket *list.List) { for e := bucket.Front(); e != nil; { task := e.Value.(*Task) if task.circle > 0 { task.circle-- e = e.Next() continue } go task.callback() next := e.Next() bucket.Remove(e) tw.tasks.Delete(task) e = next } } // AfterFunc 在延迟后执行函数 func (tw *TimeWheel) AfterFunc(delay time.Duration, callback func()) *Task { if delay < tw.interval { delay = tw.interval } task := &Task{ delay: delay, callback: callback, } tw.addTask(task) return task } func (tw *TimeWheel) addTask(task *Task) { pos, circle := tw.getPositionAndCircle(task.delay) task.circle = circle tw.buckets[pos].PushBack(task) tw.tasks.Store(task, pos) } func (tw *TimeWheel) getPositionAndCircle(delay time.Duration) (pos int, circle int) { steps := int(delay / tw.interval) pos = (tw.cursor + steps) % tw.size circle = (steps - 1) / tw.size if steps <= 0 { pos = tw.cursor circle = 0 } return }