limit_ip.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. package middleware
  2. import (
  3. "context"
  4. "errors"
  5. "github.com/sirupsen/logrus"
  6. "net/http"
  7. "time"
  8. "go.uber.org/zap"
  9. "github.com/gin-gonic/gin"
  10. "lc-base-frame/global"
  11. "lc-base-frame/model/common/response"
  12. )
  13. type LimitConfig struct {
  14. // GenerationKey 根据业务生成key 下面CheckOrMark查询生成
  15. GenerationKey func(c *gin.Context) string
  16. // 检查函数,用户可修改具体逻辑,更加灵活
  17. CheckOrMark func(key string, expire int, limit int) error
  18. // Expire key 过期时间
  19. Expire int
  20. // Limit 周期时间
  21. Limit int
  22. }
  23. func (l LimitConfig) LimitWithTime() gin.HandlerFunc {
  24. return func(c *gin.Context) {
  25. if err := l.CheckOrMark(l.GenerationKey(c), l.Expire, l.Limit); err != nil {
  26. c.JSON(http.StatusOK, gin.H{"code": response.ERROR, "msg": err})
  27. c.Abort()
  28. return
  29. } else {
  30. c.Next()
  31. }
  32. }
  33. }
  34. // DefaultGenerationKey 默认生成key
  35. func DefaultGenerationKey(c *gin.Context) string {
  36. return "GVA_Limit" + c.ClientIP()
  37. }
  38. func DefaultCheckOrMark(key string, expire int, limit int) (err error) {
  39. // 判断是否开启redis
  40. if global.GVA_REDIS == nil {
  41. return err
  42. }
  43. if err = SetLimitWithTime(key, limit, time.Duration(expire)*time.Second); err != nil {
  44. logrus.Error("limit", zap.Error(err))
  45. }
  46. return err
  47. }
  48. func DefaultLimit() gin.HandlerFunc {
  49. return LimitConfig{
  50. GenerationKey: DefaultGenerationKey,
  51. CheckOrMark: DefaultCheckOrMark,
  52. Expire: global.Config.System.LimitTimeIP,
  53. Limit: global.Config.System.LimitCountIP,
  54. }.LimitWithTime()
  55. }
  56. // SetLimitWithTime 设置访问次数
  57. func SetLimitWithTime(key string, limit int, expiration time.Duration) error {
  58. count, err := global.GVA_REDIS.Exists(context.Background(), key).Result()
  59. if err != nil {
  60. return err
  61. }
  62. if count == 0 {
  63. pipe := global.GVA_REDIS.TxPipeline()
  64. pipe.Incr(context.Background(), key)
  65. pipe.Expire(context.Background(), key, expiration)
  66. _, err = pipe.Exec(context.Background())
  67. return err
  68. } else {
  69. // 次数
  70. if times, err := global.GVA_REDIS.Get(context.Background(), key).Int(); err != nil {
  71. return err
  72. } else {
  73. if times >= limit {
  74. if t, err := global.GVA_REDIS.PTTL(context.Background(), key).Result(); err != nil {
  75. return errors.New("请求太过频繁,请稍后再试")
  76. } else {
  77. return errors.New("请求太过频繁, 请 " + t.String() + " 秒后尝试")
  78. }
  79. } else {
  80. return global.GVA_REDIS.Incr(context.Background(), key).Err()
  81. }
  82. }
  83. }
  84. }