tcp.go 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. package tcp
  2. import (
  3. "errors"
  4. "net"
  5. "server/utils/logger"
  6. "strings"
  7. "sync"
  8. "time"
  9. )
  10. func ListenTcp() {
  11. // 监听当前的tcp连接
  12. listen, err := net.Listen("tcp", "0.0.0.0:9200")
  13. if err != nil {
  14. logger.Logger.Errorf("listen failed, err:%v", err)
  15. return
  16. }
  17. tracker := NewConnectionTracker() //创建连接检测器
  18. for {
  19. conn, err := listen.Accept()
  20. if err != nil {
  21. logger.Logger.Errorf("Accept failed, err:%v", err)
  22. continue
  23. }
  24. err = CheckConn(conn, tracker)
  25. if err != nil {
  26. conn.Close() // 如果是恶意连接,则关闭连接
  27. continue
  28. }
  29. }
  30. }
  31. func CheckConn(conn net.Conn, tracker *ConnectionTracker) error {
  32. logger.Logger.Debugf("StartDevice addr:%s", conn.RemoteAddr().String())
  33. arr := strings.Split(conn.RemoteAddr().String(), ":")
  34. ip := arr[0]
  35. // 记录连接
  36. tracker.recordConnection(ip)
  37. // 检查是否为恶意连接
  38. if tracker.isMalicious(ip) {
  39. logger.Logger.Debugf("恶意连接检测到 ip: %s\n", ip)
  40. return errors.New("connection is Malicious")
  41. }
  42. device := Device{}
  43. device.Start(conn)
  44. return nil
  45. }
  46. type ConnectionTracker struct {
  47. mu sync.Mutex
  48. connections map[string][]time.Time // 存储每个 IP 的连接时间戳
  49. }
  50. func NewConnectionTracker() *ConnectionTracker {
  51. return &ConnectionTracker{
  52. connections: make(map[string][]time.Time),
  53. }
  54. }
  55. func (ct *ConnectionTracker) recordConnection(ip string) {
  56. ct.mu.Lock()
  57. defer ct.mu.Unlock()
  58. now := time.Now()
  59. ct.connections[ip] = append(ct.connections[ip], now)
  60. // 清理过期的连接记录
  61. ct.cleanUpExpired(ip, now)
  62. }
  63. func (ct *ConnectionTracker) cleanUpExpired(ip string, now time.Time) {
  64. threshold := now.Add(-3 * time.Minute)
  65. if timestamps, exists := ct.connections[ip]; exists {
  66. var filtered []time.Time
  67. for _, t := range timestamps {
  68. if t.After(threshold) { // 检查时间戳是否在三分钟内
  69. filtered = append(filtered, t) // 如果在范围内,保存到 filtered 列表
  70. }
  71. }
  72. ct.connections[ip] = filtered
  73. }
  74. }
  75. // 判断是否是恶意连接
  76. func (ct *ConnectionTracker) isMalicious(ip string) bool {
  77. ct.mu.Lock()
  78. defer ct.mu.Unlock()
  79. if timestamps, exists := ct.connections[ip]; exists {
  80. return len(timestamps) >= 10 // 定义恶意连接的阈值
  81. }
  82. return false
  83. }