From 4859645234854c3ddfd176332e2d265706036b0b Mon Sep 17 00:00:00 2001 From: helei Date: Tue, 20 Jan 2026 15:09:22 +0800 Subject: [PATCH 1/9] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E5=AD=98=E5=82=A8=20FileSaver=20interface?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: helei --- db/fileSaver/local/local.go | 6 +- file/filesaver/ftp/ftp.go | 187 ++++++++++++++++++++++++++++++ file/filesaver/ftp/types.go | 85 ++++++++++++++ file/filesaver/hdfs/hdfs.go | 142 +++++++++++++++++++++++ file/filesaver/hdfs/types.go | 41 +++++++ file/filesaver/localfile/local.go | 113 ++++++++++++++++++ file/filesaver/minio/minio.go | 116 ++++++++++++++++++ file/filesaver/minio/types.go | 46 ++++++++ file/filesaver/saver.go | 29 +++++ file/filesaver/sftp/sftp.go | 180 ++++++++++++++++++++++++++++ file/filesaver/sftp/types.go | 85 ++++++++++++++ file/filesaver/types.go | 43 +++++++ 12 files changed, 1071 insertions(+), 2 deletions(-) create mode 100644 file/filesaver/ftp/ftp.go create mode 100644 file/filesaver/ftp/types.go create mode 100644 file/filesaver/hdfs/hdfs.go create mode 100644 file/filesaver/hdfs/types.go create mode 100644 file/filesaver/localfile/local.go create mode 100644 file/filesaver/minio/minio.go create mode 100644 file/filesaver/minio/types.go create mode 100644 file/filesaver/saver.go create mode 100644 file/filesaver/sftp/sftp.go create mode 100644 file/filesaver/sftp/types.go create mode 100644 file/filesaver/types.go diff --git a/db/fileSaver/local/local.go b/db/fileSaver/local/local.go index 2362773..4ab276d 100644 --- a/db/fileSaver/local/local.go +++ b/db/fileSaver/local/local.go @@ -2,11 +2,12 @@ package local import ( "fmt" - "github.com/helays/utils/v2/close/vclose" - "github.com/helays/utils/v2/tools" "io" "os" "path" + + "github.com/helays/utils/v2/close/vclose" + "github.com/helays/utils/v2/tools" ) type Local struct{} @@ -29,6 +30,7 @@ func (this Local) Write(p string, src io.Reader, existIgnores ...bool) (int64, e return 0, fmt.Errorf("创建目录%s失败: %s", dir, err.Error()) } file, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + defer vclose.Close(file) if err != nil { return 0, fmt.Errorf("打开文件%s失败: %s", filePath, err.Error()) } diff --git a/file/filesaver/ftp/ftp.go b/file/filesaver/ftp/ftp.go new file mode 100644 index 0000000..3c7ea34 --- /dev/null +++ b/file/filesaver/ftp/ftp.go @@ -0,0 +1,187 @@ +package ftp + +import ( + "fmt" + "io" + + "path/filepath" + "strings" + + "github.com/helays/utils/v2/close/ftpClose" + "github.com/helays/utils/v2/dataType/customWriter" + "github.com/jlaffaye/ftp" +) + +type Saver struct { + opt *Config + client *ftp.ServerConn +} + +func New(cfg *Config) (*Saver, error) { + s := &Saver{opt: cfg} + if err := cfg.Valid(); err != nil { + return nil, err + } + if err := s.login(); err != nil { + return nil, err + } + return s, nil +} + +// Write 写入文件 +func (s *Saver) Write(p string, src io.Reader, existIgnores ...bool) (int64, error) { + path, err := s.setPath(p) + if err != nil { + return 0, err + } + if len(existIgnores) > 0 && existIgnores[0] { + if exist, err := s.exist(path); err != nil { + return 0, err + } else if exist { + return 0, nil + } + } + dir := filepath.Dir(path) + if err = s.mkdir(dir); err != nil { + return 0, err + } + counter := &customWriter.SizeCounter{} + teeReader := io.TeeReader(src, counter) + if err = s.client.Stor(path, teeReader); err != nil { + return 0, err + } + return counter.TotalSize, nil +} + +// Read 读取文件 +func (s *Saver) Read(p string) (io.ReadCloser, error) { + path, err := s.setPath(p) + if err != nil { + return nil, err + } + return s.client.Retr(path) +} + +func (s *Saver) ListFiles(dirPath string) ([]string, error) { + path, err := s.setPath(dirPath) + if err != nil { + return nil, err + } + entries, err := s.client.List(path) + if err != nil { + return nil, err + } + var fileNames []string + for _, entry := range entries { + if entry.Type == ftp.EntryTypeFile { + fileNames = append(fileNames, entry.Name) + } + } + return fileNames, nil +} + +func (s *Saver) Delete(p string) error { + path, err := s.setPath(p) + if err != nil { + return err + } + return s.client.Delete(path) +} + +func (s *Saver) DeleteAll(p string) error { + path, err := s.setPath(p) + if err != nil { + return err + } + return s.client.RemoveDirRecur(path) +} + +func (s *Saver) Close() error { + if s.client == nil { + return nil + } + ftpClose.CloseFtpClient(s.client) + s.client = nil + return nil +} + +// 设置当前文件全路径 +func (s *Saver) setPath(p string) (string, error) { + //if path.IsAbs(p) { + // return p, nil + //} + // 获取当前目录 + current, err := s.client.CurrentDir() + if err != nil { + return "", fmt.Errorf("获取当前目录失败:%s", err.Error()) + } + return filepath.Join(current, p), nil +} + +// 判断文件是否存在 +func (s *Saver) exist(p string) (bool, error) { + remotePath := filepath.Dir(p) + lst, err := s.client.List(remotePath) + if err != nil { + return false, fmt.Errorf("判断文件%s是否存在失败:%s", p, err.Error()) + } + remoteFileName := filepath.Base(p) + for _, v := range lst { + if v.Name == remoteFileName { + return true, nil + } + } + return false, nil +} + +func (s *Saver) mkdir(p string) error { + if ok, err := s.exist(p); err != nil { + return err + } else if ok { + return nil + } + return s.mkdirALL(p) +} + +func (s *Saver) mkdirALL(p string) error { + currentDir, err := s.client.CurrentDir() + if err != nil { + return fmt.Errorf("获取当前目录失败:%s", err.Error()) + } + path := strings.TrimPrefix(p, currentDir) + var currentPath string + for _, part := range strings.Split(path, "/") { + if part == "" { + continue // Skip empty parts which can happen with leading/trailing slashes or double slashes. + } + currentPath = fmt.Sprintf("%s/%s", currentPath, part) + err = s.client.ChangeDir(currentPath) + if err != nil { + // Directory does not exist, so create it. + if err = s.client.MakeDir(part); err != nil { + return err + } + // Change to the newly created directory. + if err = s.client.ChangeDir(part); err != nil { + return err + } + } + } + return s.client.ChangeDir(currentDir) +} + +// 登录 +func (s *Saver) login() error { + if s.client != nil { + return nil + } + var err error + s.client, err = ftp.Dial(s.opt.Host, ftp.DialWithDisabledEPSV(s.opt.Epsv == EpsvActive)) + if err != nil { + return fmt.Errorf("ftp连接失败:%s", err.Error()) + } + if err = s.client.Login(s.opt.User, s.opt.Pwd); err != nil { + return fmt.Errorf("ftp登录失败:%s", err.Error()) + } + return nil +} diff --git a/file/filesaver/ftp/types.go b/file/filesaver/ftp/types.go new file mode 100644 index 0000000..9eb4ad5 --- /dev/null +++ b/file/filesaver/ftp/types.go @@ -0,0 +1,85 @@ +package ftp + +import ( + "database/sql/driver" + "fmt" + + "github.com/helays/utils/v2/config" + "github.com/helays/utils/v2/dataType" + "github.com/helays/utils/v2/net/checkIp" + "gorm.io/gorm" + "gorm.io/gorm/schema" +) + +// Config ftp 配置 +// noinspection all +type Config struct { + Host string `json:"host" yaml:"host" ini:"host"` + User string `json:"user" yaml:"user" ini:"user"` + Pwd string `json:"pwd" yaml:"pwd" ini:"pwd"` + Epsv Epsv `ini:"epsv" yaml:"epsv" json:"epsv,omitempty"` // ftp连接模式 +} + +// Epsv ftp连接模式 +// noinspection all +type Epsv int + +// 0 被动模式 1 主动模式 +// noinspection all +const ( + EpsvPassive Epsv = 0 + EpsvActive Epsv = 1 +) + +// noinspection all +func (c Config) Value() (driver.Value, error) { + return dataType.DriverValueWithJson(c) +} + +// noinspection all +func (c *Config) Scan(val any) error { + return dataType.DriverScanWithJson(val, c) +} + +// noinspection all +func (c Config) GormDataType() string { + return "json" +} + +// noinspection all +func (Config) GormDBDataType(db *gorm.DB, field *schema.Field) string { + return dataType.JsonDbDataType(db, field) +} + +// noinspection all +func (c *Config) RemovePasswd() { + c.Pwd = "" +} + +// noinspection all +func (c *Config) Valid() error { + if _, port, err := checkIp.ParseIPAndPort(c.Host); err != nil { + return err + } else if port < 1 { + return fmt.Errorf("缺失端口号") + } + if c.Epsv != EpsvPassive && c.Epsv != EpsvActive { + return fmt.Errorf("无效的连接模式") + } + return nil +} + +// noinspection all +func (c *Config) SetInfo(args ...any) { + if len(args) != 2 { + return + } + switch args[0].(string) { + case config.ClientInfoHost: + c.Host = args[1].(string) + case config.ClientInfoUser: + c.User = args[1].(string) + case config.ClientInfoPasswd: + c.Pwd = args[1].(string) + } +} diff --git a/file/filesaver/hdfs/hdfs.go b/file/filesaver/hdfs/hdfs.go new file mode 100644 index 0000000..7782192 --- /dev/null +++ b/file/filesaver/hdfs/hdfs.go @@ -0,0 +1,142 @@ +package hdfs + +import ( + "fmt" + "io" + "os" + "path/filepath" + + "github.com/colinmarc/hdfs/v2" + "github.com/helays/utils/v2/close/vclose" +) + +type Saver struct { + opt *Config + client *hdfs.Client +} + +func New(cfg *Config) (*Saver, error) { + s := &Saver{opt: cfg} + if err := cfg.Valid(); err != nil { + return nil, err + } + if err := s.login(); err != nil { + return nil, err + } + return s, nil +} + +func (s *Saver) Close() error { + if s.client == nil { + return nil + } + vclose.Close(s.client) + s.client = nil + return nil +} + +func (s *Saver) Write(p string, src io.Reader, existIgnores ...bool) (int64, error) { + if !filepath.IsAbs(p) { + p = filepath.Join("/", p) + } + if ok, err := s.exist(p); ok { + if len(existIgnores) > 0 && existIgnores[0] { + return 0, nil + } + // 删除文件,重写 + if err = s.client.Remove(p); err != nil { + return 0, fmt.Errorf("删除文件%s失败: %s", p, err.Error()) + } + } else if err != nil { + return 0, err + } + dir := filepath.Dir(p) + if err := s.client.MkdirAll(dir, 0755); err != nil { + return 0, fmt.Errorf("创建目录%s失败: %s", dir, err.Error()) + } + remoteFile, err := s.client.Create(p) + defer vclose.Close(remoteFile) + if err != nil { + return 0, fmt.Errorf("创建文件%s失败: %s", p, err.Error()) + } + return io.Copy(remoteFile, src) +} + +func (s *Saver) Read(p string) (io.ReadCloser, error) { + if !filepath.IsAbs(p) { + p = filepath.Join("/", p) + } + remoteFile, err := s.client.Open(p) + if err != nil { + return nil, err + } + return remoteFile, nil +} + +// ListFiles 列出目录下的文件 +func (s *Saver) ListFiles(p string) ([]string, error) { + if !filepath.IsAbs(p) { + p = filepath.Join("/", p) + } + entries, err := s.client.ReadDir(p) + if err != nil { + return nil, err + } + var files []string + for _, entry := range entries { + if !entry.IsDir() { + files = append(files, entry.Name()) + } + } + return files, nil +} + +func (s *Saver) Delete(p string) error { + if !filepath.IsAbs(p) { + p = filepath.Join("/", p) + } + if ok, err := s.exist(p); !ok { + if err != nil { + return err + } + return nil + } + return s.client.Remove(p) +} + +func (s *Saver) DeleteAll(p string) error { + if !filepath.IsAbs(p) { + p = filepath.Join("/", p) + } + return s.client.RemoveAll(p) +} + +func (s *Saver) login() error { + if s.client != nil { + return nil + } + var err error + s.client, err = hdfs.NewClient(hdfs.ClientOptions{ + Addresses: s.opt.Addresses, // 指定要连接的 NameNode 地址列表。 + User: s.opt.User, // 指定客户端以哪个 HDFS 用户身份进行操作 + UseDatanodeHostname: s.opt.UseDatanodeHostname, // 指定客户端是否通过主机名(而不是 IP 地址)连接 DataNode。 + NamenodeDialFunc: nil, // 自定义连接 NameNode 的拨号函数。 + DatanodeDialFunc: nil, // 自定义连接 DataNode 的拨号函数。 + KerberosClient: nil, // 于连接启用了 Kerberos 认证的 HDFS 集群。 + KerberosServicePrincipleName: s.opt.KerberosServicePrincipleName, // 指定 NameNode 的 Kerberos 服务主体名称(SPN)。格式为 /,例如 nn/_HOST。 + DataTransferProtection: s.opt.DataTransferProtection, // 指定与 DataNode 通信时的数据保护级别。 + }) + if err != nil { + return fmt.Errorf("hdfs连接失败 %v", err) + } + return nil +} + +func (s *Saver) exist(p string) (bool, error) { + if _, err := s.client.Stat(p); err == nil { + return true, nil + } else if !os.IsNotExist(err) { + return false, err + } + return false, nil +} diff --git a/file/filesaver/hdfs/types.go b/file/filesaver/hdfs/types.go new file mode 100644 index 0000000..2fe46cc --- /dev/null +++ b/file/filesaver/hdfs/types.go @@ -0,0 +1,41 @@ +package hdfs + +import ( + "fmt" + + "github.com/helays/utils/v2/config" +) + +type Config struct { + Addresses []string `json:"addresses" yaml:"addresses" ini:"addresses,omitempty"` // 路径 + User string `json:"user" yaml:"user" ini:"user"` + // 指定客户端是否通过主机名(而不是 IP 地址)连接 DataNode。 + UseDatanodeHostname bool `json:"use_datanode_hostname" yaml:"use_datanode_hostname" ini:"use_datanode_hostname"` + // 指定 NameNode 的 Kerberos 服务主体名称(SPN)。格式为 /,例如 nn/_HOST。 + KerberosServicePrincipleName string `json:"kerberos_service_principle_name" yaml:"kerberos_service_principle_name" ini:"kerberos_service_principle_name"` + // 指定与 DataNode 通信时的数据保护级别。 + // authentication:仅认证; + // integrity: 认证 + 数据完整性校验 + // integrity+privacy: 认证 + 数据完整性校验 + 数据加密 + DataTransferProtection string `json:"data_transfer_protection" yaml:"data_transfer_protection" ini:"data_transfer_protection"` +} + +func (c *Config) Valid() error { + if len(c.Addresses) < 1 { + return fmt.Errorf("缺失地址") + } + + return nil +} + +func (c *Config) SetInfo(args ...any) { + if len(args) != 2 { + return + } + switch args[0].(string) { + case config.ClientInfoHost: + c.Addresses = args[1].([]string) + case config.ClientInfoUser: + c.User = args[1].(string) + } +} diff --git a/file/filesaver/localfile/local.go b/file/filesaver/localfile/local.go new file mode 100644 index 0000000..d577cf1 --- /dev/null +++ b/file/filesaver/localfile/local.go @@ -0,0 +1,113 @@ +package localfile + +import ( + "fmt" + "io" + "os" + "path/filepath" + + "github.com/helays/utils/v2/close/vclose" + "github.com/helays/utils/v2/tools" +) + +type Config struct { + Root string `json:"root" yaml:"root" ini:"root"` +} + +type Saver struct { + opt *Config +} + +func New(cfg *Config) (*Saver, error) { + return &Saver{opt: cfg}, nil +} + +func (s *Saver) Write(p string, src io.Reader, existIgnores ...bool) (int64, error) { + path, err := s.realPath(p) + if err != nil { + return 0, err + } + if len(existIgnores) > 0 && existIgnores[0] { + // 如果启用 文件存在就忽略,首先判断文件是否存在, + // 如果文件存在,就中断处理 + // 如果err有问题,判断是否因为文件不存在导致的。 + if _, err = os.Stat(path); err == nil { + return 0, nil + } else if !os.IsNotExist(err) { + return 0, err + } + } + dir := filepath.Dir(path) + if err = tools.Mkdir(dir); err != nil { + return 0, fmt.Errorf("创建目录[%s]失败:%s", dir, err) + } + file, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + defer vclose.Close(file) + if err != nil { + return 0, fmt.Errorf("创建文件[%s]失败:%s", path, err) + } + var n int64 + n, err = io.Copy(file, src) + if err != nil { + return n, fmt.Errorf("写入文件[%s]失败:%s", path, err) + } + return n, nil +} + +func (s *Saver) Read(p string) (io.ReadCloser, error) { + path, err := s.realPath(p) + if err != nil { + return nil, err + } + file, err := os.Open(path) + defer vclose.Close(file) + if err != nil { + return nil, fmt.Errorf("打开文件[%s]失败:%s", path, err) + } + return file, nil +} + +func (s *Saver) ListFiles(p string) ([]string, error) { + path, err := s.realPath(p) + if err != nil { + return nil, err + } + entries, err := os.ReadDir(path) + if err != nil { + return nil, fmt.Errorf("读取目录[%s]失败:%s", path, err) + } + var filePaths []string + for _, entry := range entries { + if !entry.IsDir() { + filePaths = append(filePaths, entry.Name()) + } + } + return filePaths, nil +} + +func (s *Saver) Delete(p string) error { + path, err := s.realPath(p) + if err != nil { + return err + } + return os.Remove(path) +} + +func (s *Saver) DeleteAll(p string) error { + path, err := s.realPath(p) + if err != nil { + return err + } + return os.RemoveAll(path) +} + +func (s *Saver) Close() error { + return nil +} + +func (s *Saver) realPath(p string) (string, error) { + if tools.ContainsDotDot(p) { + return "", fmt.Errorf("路径[%s]包含 '..'", p) + } + return tools.Fileabs(filepath.Join(s.opt.Root, p)), nil +} diff --git a/file/filesaver/minio/minio.go b/file/filesaver/minio/minio.go new file mode 100644 index 0000000..5591977 --- /dev/null +++ b/file/filesaver/minio/minio.go @@ -0,0 +1,116 @@ +package minio + +import ( + "context" + "fmt" + "io" + "path/filepath" + "strings" + + "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" +) + +type Saver struct { + opt *Config + client *minio.Client + + ctx context.Context + cancel context.CancelFunc +} + +func New(cfg *Config) (*Saver, error) { + s := &Saver{opt: cfg} + if err := s.opt.Valid(); err != nil { + return nil, err + } + if err := s.login(); err != nil { + return nil, err + } + return s, nil +} + +func (s *Saver) Write(p string, src io.Reader, existIgnores ...bool) (int64, error) { + if len(existIgnores) > 0 && existIgnores[0] { + if _, err := s.client.StatObject(s.ctx, s.opt.Options.Bucket, p, minio.StatObjectOptions{}); err == nil { + return 0, err + } else if _err := err.Error(); !strings.Contains(_err, "key does not exist") { + return 0, fmt.Errorf("文件已存在: %s", _err) + } + } + info, err := s.client.PutObject(s.ctx, s.opt.Options.Bucket, p, src, -1, minio.PutObjectOptions{}) + if err != nil { + return 0, err + } + return info.Size, nil +} + +func (s *Saver) Read(p string) (io.ReadCloser, error) { + return s.client.GetObject(s.ctx, s.opt.Options.Bucket, p, minio.GetObjectOptions{}) +} + +func (s *Saver) ListFiles(p string) ([]string, error) { + var files []string + opts := minio.ListObjectsOptions{ + Prefix: p, + Recursive: true, + } + for obj := range s.client.ListObjects(s.ctx, s.opt.Options.Bucket, opts) { + if obj.Err != nil { + return nil, obj.Err + } + files = append(files, filepath.Base(obj.Key)) + } + return files, nil +} + +func (s *Saver) Delete(p string) error { + return s.client.RemoveObject(s.ctx, s.opt.Options.Bucket, p, minio.RemoveObjectOptions{}) +} + +func (s *Saver) DeleteAll(p string) error { + return s.client.RemoveObject(s.ctx, s.opt.Options.Bucket, p, minio.RemoveObjectOptions{}) +} + +func (s *Saver) Close() error { + if s.client == nil { + return nil + } + s.cancel() + s.client = nil + return nil +} + +func (s *Saver) login() error { + if s.client != nil { + return nil + } + options := &minio.Options{ + Creds: credentials.NewStaticV4(s.opt.AccessKeyID, s.opt.SecretAccessKey, ""), + Secure: s.opt.UseSSL, + } + var err error + if s.client, err = minio.New(s.opt.Endpoint, options); err != nil { + return fmt.Errorf("连接MinIO节点失败: %s", err.Error()) + } + + s.ctx, s.cancel = context.WithCancel(context.Background()) + return s.createBucket() +} + +// 创建 bucket +func (s *Saver) createBucket() error { + if ok, err := s.client.BucketExists(s.ctx, s.opt.Options.Bucket); ok { + return nil + } else if err != nil { + return fmt.Errorf("查询bucket %s失败: %s", s.opt.Options.Bucket, err.Error()) + } + err := s.client.MakeBucket(s.ctx, s.opt.Options.Bucket, minio.MakeBucketOptions{ + Region: s.opt.Options.Region, + ObjectLocking: s.opt.Options.ObjectLocking, + }) + if err != nil { + return fmt.Errorf("创建bucket %s失败: %s", s.opt.Options.Bucket, err.Error()) + } + return nil +} diff --git a/file/filesaver/minio/types.go b/file/filesaver/minio/types.go new file mode 100644 index 0000000..fd99445 --- /dev/null +++ b/file/filesaver/minio/types.go @@ -0,0 +1,46 @@ +package minio + +import ( + "fmt" + + "github.com/helays/utils/v2/config" +) + +type Config struct { + Endpoint string `json:"endpoint" yaml:"endpoint" ini:"endpoint"` // MinIO 节点地址(单点或集群) + AccessKeyID string `json:"access_key_id" yaml:"access_key_id" ini:"access_key_id"` // 访问密钥 + SecretAccessKey string `json:"secret_access_key" yaml:"secret_access_key" ini:"secret_access_key"` // 秘密密钥 + UseSSL bool `json:"use_ssl" yaml:"use_ssl" ini:"use_ssl"` // 是否使用 HTTPS + Options Options `json:"options" yaml:"options" ini:"options"` // 配置项 +} + +type Options struct { + Bucket string `json:"bucket" yaml:"bucket" ini:"bucket"` // 存储桶名称 + Region string `json:"region" yaml:"region" ini:"region"` //指定 Bucket 所在的区域(Region)。MinIO 默认使用 us-east-1 作为区域 + ObjectLocking bool `json:"object_locking" yaml:"object_locking" ini:"object_locking"` //是否启用对象锁定(Object Locking)功能 +} + +func (c *Config) Valid() error { + if c.Endpoint == "" { + return fmt.Errorf("缺失地址") + } + return nil +} + +func (c *Config) RemovePasswd() { + c.SecretAccessKey = "" +} + +func (c *Config) SetInfo(args ...any) { + if len(args) != 2 { + return + } + switch args[0].(string) { + case config.ClientInfoHost: + c.Endpoint = args[1].(string) + case config.ClientInfoUser: + c.AccessKeyID = args[1].(string) + case config.ClientInfoPasswd: + c.SecretAccessKey = args[1].(string) + } +} diff --git a/file/filesaver/saver.go b/file/filesaver/saver.go new file mode 100644 index 0000000..075fca2 --- /dev/null +++ b/file/filesaver/saver.go @@ -0,0 +1,29 @@ +package filesaver + +import ( + "fmt" + + "github.com/helays/utils/v2/file/filesaver/ftp" + "github.com/helays/utils/v2/file/filesaver/hdfs" + "github.com/helays/utils/v2/file/filesaver/localfile" + "github.com/helays/utils/v2/file/filesaver/minio" + "github.com/helays/utils/v2/file/filesaver/sftp" +) + +func New(cfg *Config) (FileSaver, error) { + switch cfg.Driver { + case DriverLocal: + return localfile.New(&cfg.Local) + case DriverSftp: + return sftp.New(&cfg.SFTP) + case DriverFtp: + return ftp.New(&cfg.FTP) + case DriverHdfs: + return hdfs.New(&cfg.HDFS) + case DriverMinio: + return minio.New(&cfg.Minio) + //case DriverCeph: + default: + panic(fmt.Errorf("不支持的文件系统驱动 %s", cfg.Driver)) + } +} diff --git a/file/filesaver/sftp/sftp.go b/file/filesaver/sftp/sftp.go new file mode 100644 index 0000000..83f09e0 --- /dev/null +++ b/file/filesaver/sftp/sftp.go @@ -0,0 +1,180 @@ +package sftp + +import ( + "fmt" + "io" + "os" + "path/filepath" + + "github.com/helays/utils/v2/close/vclose" + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +type Saver struct { + opt *Config + sshClient *ssh.Client + sftpClient *sftp.Client +} + +func New(cfg *Config) (*Saver, error) { + s := &Saver{opt: cfg} + if err := cfg.Valid(); err != nil { + return nil, err + } + if err := s.login(); err != nil { + return nil, err + } + return s, nil +} + +func (s *Saver) Write(p string, src io.Reader, existIgnores ...bool) (int64, error) { + path, err := s.setPath(p) + if err != nil { + return 0, err + } + if len(existIgnores) > 0 && existIgnores[0] { + if ok, err := s.exist(path); err != nil { + return 0, err + } else if ok { + return 0, nil + } + } + dir := filepath.Dir(path) + if err = s.mkdir(dir); err != nil { + return 0, err + } + file, err := s.sftpClient.Create(path) + defer vclose.Close(file) + if err != nil { + return 0, fmt.Errorf("创建文件%s失败:%s", path, err.Error()) + } + return io.Copy(file, src) +} + +func (s *Saver) Read(p string) (io.ReadCloser, error) { + path, err := s.setPath(p) + if err != nil { + return nil, err + } + file, err := s.sftpClient.Open(path) + if err != nil { + return nil, fmt.Errorf("打开文件%s失败:%s", path, err.Error()) + } + return file, nil +} + +func (s *Saver) ListFiles(dirPath string) ([]string, error) { + path, err := s.setPath(dirPath) + if err != nil { + return nil, err + } + entries, err := s.sftpClient.ReadDir(path) + if err != nil { + return nil, fmt.Errorf("获取目录%s失败:%s", path, err.Error()) + } + var files []string + for _, entry := range entries { + if !entry.IsDir() { + files = append(files, entry.Name()) + } + } + return files, nil +} + +func (s *Saver) Delete(p string) error { + path, err := s.setPath(p) + if err != nil { + return err + } + return s.sftpClient.Remove(path) +} + +func (s *Saver) DeleteAll(p string) error { + path, err := s.setPath(p) + if err != nil { + return err + } + return s.sftpClient.RemoveAll(path) +} + +func (s *Saver) Close() error { + vclose.Close(s.sftpClient) + s.sftpClient = nil + vclose.Close(s.sshClient) + s.sshClient = nil + return nil +} + +func (s *Saver) login() error { + if err := s.loginSSH(); err != nil { + return err + } + if s.sftpClient == nil { + var err error + s.sftpClient, err = sftp.NewClient(s.sshClient) + if err != nil { + return fmt.Errorf("sftp连接失败:%s", err.Error()) + } + } + return nil +} + +func (s *Saver) loginSSH() error { + if s.sshClient != nil { + return nil + } + cfg := &ssh.ClientConfig{ + User: s.opt.User, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + var auth ssh.AuthMethod + if s.opt.Authentication == Password { + auth = ssh.Password(s.opt.Pwd) + } else { + signer, err := ssh.ParsePrivateKey([]byte(s.opt.Pwd)) + if err != nil { + return fmt.Errorf("ssh密钥解析失败:%s", err.Error()) + } + auth = ssh.PublicKeys(signer) + } + cfg.Auth = []ssh.AuthMethod{auth} + var err error + s.sshClient, err = ssh.Dial("tcp", s.opt.Host, cfg) + if err != nil { + return fmt.Errorf("ssh连接失败:%s", err.Error()) + } + return nil +} + +// setPath 设置当前 文件全路径 +// p 如果是绝对路径,那么直接返回p +// p 如果是相对路径,会跟上当前目录 +func (s *Saver) setPath(p string) (string, error) { + current, err := s.sftpClient.Getwd() + if err != nil { + return "", err + } + return s.sftpClient.Join(current, p), nil +} + +func (s *Saver) exist(p string) (bool, error) { + if _, err := s.sftpClient.Stat(p); err == nil { + return true, nil + } else if !os.IsNotExist(err) { + return false, err + } + return false, nil +} + +func (s *Saver) mkdir(p string) error { + if ok, err := s.exist(p); err != nil { + return err + } else if ok { + return nil + } + if err := s.sftpClient.MkdirAll(p); err != nil { + return fmt.Errorf("创建文件夹%s失败:%s", p, err.Error()) + } + return nil +} diff --git a/file/filesaver/sftp/types.go b/file/filesaver/sftp/types.go new file mode 100644 index 0000000..3d91b00 --- /dev/null +++ b/file/filesaver/sftp/types.go @@ -0,0 +1,85 @@ +package sftp + +import ( + "database/sql/driver" + "fmt" + + "github.com/helays/utils/v2/config" + "github.com/helays/utils/v2/dataType" + "github.com/helays/utils/v2/net/checkIp" + "gorm.io/gorm" + "gorm.io/gorm/schema" +) + +type Config struct { + Host string `json:"host" yaml:"host" ini:"host"` + User string `json:"user" yaml:"user" ini:"user"` + Pwd string `json:"pwd" yaml:"pwd" ini:"pwd"` + Authentication Authentication `json:"authentication" yaml:"authentication" ini:"authentication"` +} + +type Authentication string + +const ( + Password Authentication = "password" + PublicKey Authentication = "public_key" +) + +// noinspection all +func (c *Config) RemovePasswd() { + c.Pwd = "" +} + +// noinspection all +func (c *Config) Valid() error { + if _, port, err := checkIp.ParseIPAndPort(c.Host); err != nil { + return err + } else if port < 1 { + return fmt.Errorf("缺失端口号") + } + if c.User == "" { + return fmt.Errorf("缺失账号") + } + if c.Pwd == "" { + return fmt.Errorf("缺失密码") + } + if c.Authentication == "" { + c.Authentication = Password + } else if c.Authentication != Password && c.Authentication != PublicKey { + return fmt.Errorf("无效的认证方式") + } + return nil +} + +// noinspection all +func (c *Config) SetInfo(args ...any) { + if len(args) != 2 { + return + } + switch args[0].(string) { + case config.ClientInfoHost: + c.Host = args[1].(string) + case config.ClientInfoUser: + c.User = args[1].(string) + case config.ClientInfoPasswd: + c.Pwd = args[1].(string) + } +} + +// noinspection all +func (c Config) Value() (driver.Value, error) { + return dataType.DriverValueWithJson(c) +} + +// noinspection all +func (c *Config) Scan(val interface{}) error { + return dataType.DriverScanWithJson(val, c) +} + +func (c Config) GormDataType() string { + return "json" +} + +func (Config) GormDBDataType(db *gorm.DB, field *schema.Field) string { + return dataType.JsonDbDataType(db, field) +} diff --git a/file/filesaver/types.go b/file/filesaver/types.go new file mode 100644 index 0000000..099c901 --- /dev/null +++ b/file/filesaver/types.go @@ -0,0 +1,43 @@ +package filesaver + +import ( + "io" + + "github.com/helays/utils/v2/file/filesaver/ftp" + "github.com/helays/utils/v2/file/filesaver/hdfs" + "github.com/helays/utils/v2/file/filesaver/localfile" + "github.com/helays/utils/v2/file/filesaver/minio" + "github.com/helays/utils/v2/file/filesaver/sftp" +) + +type FileSaver interface { + Write(p string, src io.Reader, existIgnores ...bool) (int64, error) // 写入文件 + Read(p string) (io.ReadCloser, error) // 读取文件 + Delete(p string) error // 删除指定文件 + DeleteAll(p string) error // 删除文件夹 + ListFiles(p string) ([]string, error) // 列出指定目录下的所有文件 + Close() error // 关闭资源 +} + +type Driver string + +// noinspection all +const ( + DriverLocal Driver = "local" + DriverSftp Driver = "sftp" + DriverFtp Driver = "ftp" + DriverHdfs Driver = "hdfs" + DriverMinio Driver = "minio" + DriverCeph Driver = "ceph" +) + +// noinspection all +type Config struct { + Driver Driver `json:"driver" yaml:"driver" ini:"driver"` + + Local localfile.Config `json:"local" yaml:"local" ini:"local"` // 本地文件系统 + FTP ftp.Config `json:"ftp" yaml:"ftp" ini:"ftp"` // ftp + SFTP sftp.Config `json:"sftp" yaml:"sftp" ini:"sftp"` // sftp + HDFS hdfs.Config `json:"hdfs" yaml:"hdfs" ini:"hdfs"` // hdfs + Minio minio.Config `json:"minio" yaml:"minio" ini:"minio"` // minio +} From 0a9f2f0d35e2369797c71eeeb0d9c906da6aab05 Mon Sep 17 00:00:00 2001 From: helei Date: Thu, 22 Jan 2026 21:00:46 +0800 Subject: [PATCH 2/9] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E4=B8=80=E4=B8=AA?= =?UTF-8?q?=E9=80=9A=E7=94=A8=E7=9A=84=E7=BC=93=E5=AD=98=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=20=E7=9B=AE=E5=89=8D=E5=9F=BA=E4=BA=8E=E5=86=85=E5=AD=98?= =?UTF-8?q?=E7=9A=84=E7=BC=93=E5=AD=98=E5=AE=8C=E6=88=90=E4=BA=86=E3=80=82?= =?UTF-8?q?=20=E5=90=8E=E9=9D=A2=E5=9F=BA=E4=BA=8Eredis=E3=80=81=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E3=80=81rdbms=E7=AD=89=E9=A9=B1=E5=8A=A8=E7=9A=84?= =?UTF-8?q?=E5=BE=85=E5=AE=9E=E7=8E=B0=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: helei --- dataType/int.go | 61 ++++++++++++++++ dataType/sessionval.go | 58 +++++++++++++++ net/http/session/manager.go | 14 ++-- net/http/session/standard_manager.go | 2 +- .../session/storage/carrier_memory/memory.go | 2 +- net/http/session/types.go | 72 +------------------ net/ipmatch/build.go | 4 +- safe/cachemgr/rdbms/types.go | 38 ++++++++++ safe/cachemgr/types.go | 70 ++++++++++++++++++ safe/map.go | 15 ++-- safe/types.go | 12 +++- security/lockpolicy/cache.go | 4 +- 12 files changed, 258 insertions(+), 94 deletions(-) create mode 100644 dataType/sessionval.go create mode 100644 safe/cachemgr/rdbms/types.go create mode 100644 safe/cachemgr/types.go diff --git a/dataType/int.go b/dataType/int.go index 0170d9f..11a290b 100644 --- a/dataType/int.go +++ b/dataType/int.go @@ -53,3 +53,64 @@ func (Byte) GormDBDataType(db *gorm.DB, field *schema.Field) string { } return "int" } + +type Uint64 struct { + uint64 +} + +func NewUint64(v uint64) Uint64 { + return Uint64{uint64: v} +} + +func (u *Uint64) GetValue() uint64 { + return u.uint64 +} + +func (u *Uint64) SetValue(v uint64) { + u.uint64 = v +} + +func (u *Uint64) Equals(other Uint64) bool { + return u.uint64 == other.uint64 +} + +func (u *Uint64) EqualsInt(other int) bool { + return u.uint64 == uint64(other) +} + +func (u *Uint64) EqualsUint64(other uint64) bool { + return u.uint64 == other +} + +// noinspection all +func (u Uint64) Value() (driver.Value, error) { + return u.uint64, nil +} + +// noinspection all +func (u *Uint64) Scan(value any) error { + if value == nil { + return nil + } + v, err := tools.Any2Int[uint64](value) + if err != nil { + return err + } + u.uint64 = v + return nil +} + +// noinspection all +func (u Uint64) GormDBDataType(db *gorm.DB, field *schema.Field) string { + switch db.Dialector.Name() { + case config.DbTypeSqlite: + return "integer" + case config.DbTypeMysql: + return "BIGINT UNSIGNED" + case config.DbTypePostgres: + return "BIGINT" + case config.DbTypeSqlserver: + return "BIGINT" + } + return "bigint" +} diff --git a/dataType/sessionval.go b/dataType/sessionval.go new file mode 100644 index 0000000..762b16f --- /dev/null +++ b/dataType/sessionval.go @@ -0,0 +1,58 @@ +package dataType + +import ( + "bytes" + "database/sql/driver" + "encoding/gob" + + "github.com/helays/utils/v2/tools" + "gorm.io/gorm" + "gorm.io/gorm/schema" +) + +func NewSessionValue(val any) SessionValue { + return SessionValue{Val: val} +} + +// noinspection all +type SessionValue struct { + Val any +} + +// Value return blob value, implement driver.Valuer interface +// noinspection all +func (s SessionValue) Value() (driver.Value, error) { + var buf bytes.Buffer + err := gob.NewEncoder(&buf).Encode(s) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// noinspection all +func (s *SessionValue) Scan(val any) error { + if val == nil { + *s = SessionValue{} + return nil + } + + b, err := tools.Any2bytes(val) + if err != nil { + return err + } + + err = gob.NewDecoder(bytes.NewReader(b)).Decode(s) + + return err +} + +// GormDBDataType gorm db data type +// noinspection all +func (SessionValue) GormDBDataType(db *gorm.DB, field *schema.Field) string { + return BlobDbDataType(db, field) +} + +func (s SessionValue) GormDataType() string { + return "blob" +} diff --git a/net/http/session/manager.go b/net/http/session/manager.go index dd14e59..18717ee 100644 --- a/net/http/session/manager.go +++ b/net/http/session/manager.go @@ -19,12 +19,12 @@ const ( // Session session 数据结构 type Session struct { - Id string `json:"id" gorm:"primaryKey;autoIncrement:false;type:varchar(64);not null;index;comment:Session ID"` - Name string `json:"name" gorm:"primaryKey;autoIncrement:false;type:varchar(128);not null;index;comment:Session的名字"` - Values SessionValue `json:"values" gorm:"comment:session数据"` - CreateTime dataType.CustomTime `json:"create_time" gorm:"comment:session 创建时间"` - ExpireTime dataType.CustomTime `json:"expire_time" gorm:"not null;index;comment:session 过期时间"` - Duration time.Duration `json:"duration" gorm:"comment:session有效期"` + Id string `json:"id" gorm:"primaryKey;autoIncrement:false;type:varchar(64);not null;index;comment:Session ID"` + Name string `json:"name" gorm:"primaryKey;autoIncrement:false;type:varchar(128);not null;index;comment:Session的名字"` + Values dataType.SessionValue `json:"values" gorm:"comment:session数据"` + CreateTime dataType.CustomTime `json:"create_time" gorm:"comment:session 创建时间"` + ExpireTime dataType.CustomTime `json:"expire_time" gorm:"not null;index;comment:session 过期时间"` + Duration time.Duration `json:"duration" gorm:"comment:session有效期"` } type Manager struct { @@ -60,7 +60,7 @@ func New(ctx context.Context, storage StorageDriver, opt ...*Options) *Manager { options: options, storage: storage, } - gob.Register(SessionValue{}) + gob.Register(dataType.SessionValue{}) if !options.DisableGc { go manager.startGC(ctx) } diff --git a/net/http/session/standard_manager.go b/net/http/session/standard_manager.go index c754eae..2f502fc 100644 --- a/net/http/session/standard_manager.go +++ b/net/http/session/standard_manager.go @@ -303,7 +303,7 @@ func (m *Manager) SetVal(value *Value) error { sv := Session{ Id: value.SessionID, Name: value.Field, - Values: NewSessionValue(value.Value), + Values: dataType.NewSessionValue(value.Value), CreateTime: dataType.NewCustomTime(now), Duration: tools.Ternary(value.TTL > 0, value.TTL, ExpireTime), } diff --git a/net/http/session/storage/carrier_memory/memory.go b/net/http/session/storage/carrier_memory/memory.go index 77e228a..f1c7a42 100644 --- a/net/http/session/storage/carrier_memory/memory.go +++ b/net/http/session/storage/carrier_memory/memory.go @@ -15,7 +15,7 @@ type Instance struct { func New(ctx context.Context) *Instance { i := &Instance{} - i.storage = safe.NewMap[string, *session.Session](ctx, safe.StringHasher{}, safe.MapConfig{ + i.storage = safe.NewMap[string, *session.Session](ctx, safe.StringHasher{}, safe.CacheConfig{ EnableCleanup: true, ClearInterval: time.Minute / 2, TTL: time.Minute, diff --git a/net/http/session/types.go b/net/http/session/types.go index aaea9ca..a42f674 100644 --- a/net/http/session/types.go +++ b/net/http/session/types.go @@ -1,65 +1,12 @@ package session import ( - "bytes" - "database/sql/driver" - "encoding/gob" "errors" "time" "github.com/helays/utils/v2/dataType" - "github.com/helays/utils/v2/tools" - "gorm.io/gorm" - "gorm.io/gorm/schema" ) -func NewSessionValue(val any) SessionValue { - return SessionValue{Val: val} -} - -// noinspection all -type SessionValue struct { - Val any -} - -// Value return blob value, implement driver.Valuer interface -// noinspection all -func (s SessionValue) Value() (driver.Value, error) { - var buf bytes.Buffer - err := gob.NewEncoder(&buf).Encode(s) - if err != nil { - return nil, err - } - return buf.Bytes(), nil -} - -// noinspection all -func (s *SessionValue) Scan(val any) error { - if val == nil { - *s = SessionValue{} - return nil - } - - b, err := tools.Any2bytes(val) - if err != nil { - return err - } - - err = gob.NewDecoder(bytes.NewReader(b)).Decode(s) - - return err -} - -// GormDBDataType gorm db data type -// noinspection all -func (SessionValue) GormDBDataType(db *gorm.DB, field *schema.Field) string { - return dataType.BlobDbDataType(db, field) -} - -func (s SessionValue) GormDataType() string { - return "blob" -} - // 这个需要移除 上级Session 已经实现了二进制序列化 //func (s SessionValue) GobEncode() ([]byte, error) { // return msgpack.Marshal(s.val) @@ -81,22 +28,9 @@ const ( ) var ( - ErrUnSupport = errors.New("不支持的session载体") - ErrNotFound = errors.New("session不存在") - ErrNotPointer = errors.New("session变量目标必须是指针") -) - -type Engine string - -func (e Engine) String() string { - return string(e) -} - -const ( - EngineRedis Engine = "redis" - EngineRdbms Engine = "rdbms" - EngineMemory Engine = "memory" - EngineFile Engine = "file" + ErrUnSupport = errors.New("不支持的 session 载体") + ErrNotFound = errors.New("session 不存在") + ErrNotPointer = errors.New("session 变量目标必须是指针") ) // noinspection all diff --git a/net/ipmatch/build.go b/net/ipmatch/build.go index 50aa88d..d5184b3 100644 --- a/net/ipmatch/build.go +++ b/net/ipmatch/build.go @@ -44,12 +44,12 @@ func NewIPMatcher(ctx context.Context, config *Config) (*IPMatcher, error) { } ipv4CacheTTL := tools.AutoTimeDuration(m.config.IPv4CacheTTL, time.Second, 30*time.Second) ipv6CacheTTL := tools.AutoTimeDuration(m.config.IPv6CacheTTL, time.Second, 10*time.Second) - m.ipv4Cache = safe.NewMap[uint32, struct{}](ctx, safe.IntegerHasher[uint32]{}, safe.MapConfig{ + m.ipv4Cache = safe.NewMap[uint32, struct{}](ctx, safe.IntegerHasher[uint32]{}, safe.CacheConfig{ EnableCleanup: true, ClearInterval: ipv4CacheTTL / 2, TTL: ipv4CacheTTL, }) - m.ipv6Cache = safe.NewMap[[16]byte, struct{}](ctx, safe.Array16Hasher{}, safe.MapConfig{ + m.ipv6Cache = safe.NewMap[[16]byte, struct{}](ctx, safe.Array16Hasher{}, safe.CacheConfig{ EnableCleanup: true, ClearInterval: ipv6CacheTTL / 2, TTL: ipv6CacheTTL, diff --git a/safe/cachemgr/rdbms/types.go b/safe/cachemgr/rdbms/types.go new file mode 100644 index 0000000..b22780e --- /dev/null +++ b/safe/cachemgr/rdbms/types.go @@ -0,0 +1,38 @@ +package rdbms + +import ( + "github.com/helays/utils/v2/dataType" + "github.com/helays/utils/v2/db" +) + +// CacheFast 采用双hash模式,缓存数据结构 +// 效率足够高,但是在大规模数据下 可能存在hash冲突。 +type CacheFast struct { + InstanceHash dataType.Uint64 `json:"instance_hash" gorm:"primaryKey;autoIncrement:false;index;comment:缓存实例hash"` + KeyHash dataType.Uint64 `json:"key_hash" gorm:"primaryKey;autoIncrement:false;comment:缓存key hash"` + InstanceID string `json:"instance_id" gorm:"type:varchar(32);not null;comment:缓存实例标识"` + CacheKey string `json:"key" gorm:"type:varchar(512);not null;comment:缓存key"` + Value dataType.SessionValue `json:"value" gorm:"comment:缓存数据"` + ExpiresTime *dataType.CustomTime `json:"expires_time" gorm:"index;comment:过期时间"` + db.TableDefaultTimeField +} + +func (CacheFast) TableName() string { + return "cache_fast" +} + +// CacheSafe 采用单hash模式,缓存数据结构 +// 缓存实例标识,数量很少,采用 xxhash 足够安全 +// 缓存key,数量很多,原样保存 +type CacheSafe struct { + InstanceHash dataType.Uint64 `json:"instance_hash" gorm:"primaryKey;autoIncrement:false;index;comment:缓存实例hash"` + CacheKey string `json:"key" gorm:"primaryKey;type:varchar(512);not null;comment:缓存key"` + InstanceID string `json:"instance_id" gorm:"type:varchar(32);not null;comment:缓存实例标识"` + Value dataType.SessionValue `json:"value" gorm:"comment:缓存数据"` + ExpiresTime *dataType.CustomTime `json:"expires_time" gorm:"index;comment:过期时间"` + db.TableDefaultTimeField +} + +func (CacheSafe) TableName() string { + return "cache_safe" +} diff --git a/safe/cachemgr/types.go b/safe/cachemgr/types.go new file mode 100644 index 0000000..3a841f2 --- /dev/null +++ b/safe/cachemgr/types.go @@ -0,0 +1,70 @@ +package cachemgr + +import ( + "context" + "time" + + "github.com/helays/utils/v2/safe" +) + +// Cache 缓存接口 +type Cache[K comparable, V any] interface { + SetOnExpired(onExpired safe.OnExpired[K]) // 设置过期回调函数。 + Load(key K) (V, bool) // 获取键的值。 + LoadOrStore(key K, val V, duration ...time.Duration) (actual V, loaded bool) // 获取键的值,如果键不存在则存储键值对。 + LoadOrStoreFunc(key K, valueFunc func(k K) (V, error), duration ...time.Duration) (V, bool, error) // 获取键的值,如果键不存在则存储键值对。 + LoadAndDelete(key K) (V, bool) // 获取键的值并删除键。 + LoadAndDeleteIf(key K, condition func(value V) bool) (V, bool) // 获取键的值并删除键,如果满足条件。 + LoadAndRefresh(key K, duration ...time.Duration) (V, bool) // 获取键的值并刷新键的过期时间。 + LoadWithExpiry(key K) (V, time.Time, bool) // 获取键的值和过期时间。 + Refresh(key K, duration ...time.Duration) bool // 刷新键的过期时间。 + GetTTL(key K) (time.Duration, bool) // 获取键的剩余过期时间。 + IsExpired(key K) bool // 判断键是否已过期。 + GetHeartbeat(key K) (time.Time, bool) // 获取键的更新时间。 + + Store(key K, val V, duration ...time.Duration) // 存储键值对。 + Delete(key K) // 删除键。 + DeleteAndGetCount(keys ...K) int // 删除多个键并返回删除的键数量。 + DeleteAll() // 删除所有键。 + Range(f func(key K, value V) bool) // 遍历缓存中的键值对。 + DeletePrefix(prefix string) // 删除以指定前缀开头的键。 + DeleteSuffix(suffix string) // 删除以指定后缀结尾的键。 +} + +type Driver string + +// noinspection all +const ( + // 内存缓存 + DriverMemory Driver = "memory" + // 关系数据库缓存 + // 缓存数据都在一张表中,但是由于有多个缓存实例,所以需要定义一个缓存标识来区分。 + // 这是一个联合索引。identity+key + DriverRdbms Driver = "rdbms" + DriverFile Driver = "file" // 文件缓存 + DriverRedis Driver = "redis" // redis 缓存 +) + +type Config struct { + Driver Driver `json:"driver" yaml:"driver" ini:"driver"` // 缓存驱动 + Identity string `json:"identity" yaml:"identity" ini:"identity"` // 缓存标识,在非内存存储下,非常有用。用于隔离多个缓存实例里面的数据。 + safe.CacheConfig `json:"memory" yaml:"memory" ini:"memory"` +} + +func New[K comparable, V any](ctx context.Context, hasher safe.Hasher[K], cfg Config) Cache[K, V] { + var ( + driver Cache[K, V] + ) + switch cfg.Driver { + case DriverMemory: + driver = safe.NewMap[K, V](ctx, hasher, cfg.CacheConfig) + default: + return nil + } + + return NewWithDriver[K, V](driver) +} + +func NewWithDriver[K comparable, V any](driver Cache[K, V]) Cache[K, V] { + return driver +} diff --git a/safe/map.go b/safe/map.go index f738e5a..b7b0680 100644 --- a/safe/map.go +++ b/safe/map.go @@ -31,14 +31,6 @@ import ( //适合:仅限高性能计算场景 type ( - MapConfig struct { - EnableCleanup bool // 是否启用自动清理功能 - ClearInterval time.Duration // 清理间隔,推荐值是设置成 ttl/2 或者 ttl/3 - TTL time.Duration // 默认TTL为0,表示不过期 - ShardSize uint64 // 分片数量,默认为2的8次方 - UseKey bool // 是否使用 key 作为 hash 值 - } - value[K comparable, V any] struct { key K val V @@ -67,11 +59,11 @@ type ( // 清理间隔,推荐值是设置成 ttl/2 或者 ttl/3 // 当前启用自动清理后,clearInterval必须设置 clearInterval time.Duration - onExpired onExpired[K] // 再过期时刻触发的回调操作。 + onExpired OnExpired[K] // 再过期时刻触发的回调操作。 } ) -func NewMap[K comparable, V any](ctx context.Context, hasher Hasher[K], configs ...MapConfig) *Map[K, V] { +func NewMap[K comparable, V any](ctx context.Context, hasher Hasher[K], configs ...CacheConfig) *Map[K, V] { m := &Map[K, V]{ ctx: ctx, hasher: hasher, @@ -111,7 +103,8 @@ func NewMap[K comparable, V any](ctx context.Context, hasher Hasher[K], configs return m } -func (m *Map[K, V]) SetOnExpired(onExpired onExpired[K]) { +// SetOnExpired 设置过期回调函数。 +func (m *Map[K, V]) SetOnExpired(onExpired OnExpired[K]) { m.onExpired = onExpired } diff --git a/safe/types.go b/safe/types.go index f3be68f..34a002f 100644 --- a/safe/types.go +++ b/safe/types.go @@ -1,6 +1,7 @@ package safe import ( + "time" "unsafe" "github.com/cespare/xxhash/v2" @@ -12,7 +13,16 @@ const ( defaultCapacity = 1 << 6 // 默认缓存大小 ) -type onExpired[K comparable] func(key []K) // 过期回调 +type CacheConfig struct { + EnableCleanup bool `json:"enable_cleanup" yaml:"enable_cleanup" ini:"enable_cleanup"` // 是否启用自动清理功能 + ClearInterval time.Duration `json:"clear_interval" yaml:"clear_interval" ini:"clear_interval"` // 清理间隔,推荐值是设置成 ttl/2 或者 ttl/3 + TTL time.Duration `json:"ttl" yaml:"ttl" ini:"ttl"` // 默认TTL为0,表示不过期 + ShardSize uint64 `json:"shard_size" yaml:"shard_size" ini:"shard_size"` // 分片数量,默认为2的8次方 + UseKey bool `json:"use_key" yaml:"use_key" ini:"use_key"` // 是否使用 key 作为 hash 值 +} + +// OnExpired 过期回调 +type OnExpired[K comparable] func(key []K) // 过期回调 // Hasher 编译时确定的哈希函数 type Hasher[K comparable] interface { diff --git a/security/lockpolicy/cache.go b/security/lockpolicy/cache.go index 3c89c5c..7555cda 100644 --- a/security/lockpolicy/cache.go +++ b/security/lockpolicy/cache.go @@ -19,12 +19,12 @@ type targetCache struct { func newTargetCache(ctx context.Context, policy *Policy) *targetCache { return &targetCache{ policy: policy, - triggerCount: safe.NewMap[string, *safe.ResourceRWMutex[int]](ctx, safe.StringHasher{}, safe.MapConfig{ + triggerCount: safe.NewMap[string, *safe.ResourceRWMutex[int]](ctx, safe.StringHasher{}, safe.CacheConfig{ EnableCleanup: true, ClearInterval: time.Minute / 2, TTL: time.Minute, }), - isLocked: safe.NewMap[string, *safe.ResourceRWMutex[bool]](ctx, safe.StringHasher{}, safe.MapConfig{ + isLocked: safe.NewMap[string, *safe.ResourceRWMutex[bool]](ctx, safe.StringHasher{}, safe.CacheConfig{ EnableCleanup: true, ClearInterval: time.Minute / 2, TTL: time.Minute, From f5f988c332cc58bba86531e6a86697ce242e1ede Mon Sep 17 00:00:00 2001 From: vsclub <1217179982@qq.com> Date: Fri, 23 Jan 2026 23:00:16 +0800 Subject: [PATCH 3/9] =?UTF-8?q?=E8=87=AA=E5=AE=9A=E4=B9=89int=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=EF=BC=8C=E5=9C=A8=E5=86=99=E8=A1=A8=E6=97=B6=EF=BC=8C?= =?UTF-8?q?=E5=8F=AA=E6=94=AF=E6=8C=81int64=E7=B1=BB=E5=9E=8B=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dataType/int.go | 1 + 1 file changed, 1 insertion(+) diff --git a/dataType/int.go b/dataType/int.go index 11a290b..9b20415 100644 --- a/dataType/int.go +++ b/dataType/int.go @@ -13,6 +13,7 @@ type Byte byte // noinspection all func (b Byte) Value() (driver.Value, error) { + // 注意,这里只能接受int64类型 return int64(b), nil } From 6b6de6926e94ce18a0ac4115d2e5156a4f6f58cf Mon Sep 17 00:00:00 2001 From: vsclub <1217179982@qq.com> Date: Sat, 24 Jan 2026 14:12:13 +0800 Subject: [PATCH 4/9] =?UTF-8?q?=E8=87=AA=E5=AE=9A=E4=B9=89int=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=EF=BC=8C=E5=9C=A8=E5=86=99=E8=A1=A8=E6=97=B6=EF=BC=8C?= =?UTF-8?q?=E5=8F=AA=E6=94=AF=E6=8C=81int64=E7=B1=BB=E5=9E=8B=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dataType/int.go | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/dataType/int.go b/dataType/int.go index 9b20415..e8155fe 100644 --- a/dataType/int.go +++ b/dataType/int.go @@ -2,6 +2,7 @@ package dataType import ( "database/sql/driver" + "fmt" "github.com/helays/utils/v2/config" "github.com/helays/utils/v2/tools" @@ -23,20 +24,11 @@ func (b *Byte) Scan(value any) error { *b = 0 return nil } - switch t := value.(type) { - case byte: - *b = Byte(t) - case int8: - *b = Byte(t) - case int: - *b = Byte(t) - default: - v, err := tools.Any2Int[Byte](value) - if err != nil { - return err - } - *b = Byte(v) + v, err := tools.Any2Int[byte](value) + if err != nil { + return fmt.Errorf("Byte.Scan: unknown type %T", value) } + *b = Byte(v) return nil } From 4c1de387bd84ea602a71d1adc1fdc9097ea39dec Mon Sep 17 00:00:00 2001 From: vsclub <1217179982@qq.com> Date: Sun, 25 Jan 2026 13:40:04 +0800 Subject: [PATCH 5/9] =?UTF-8?q?=E9=80=9A=E7=94=A8list=20to=20tree=20?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E5=BC=BA=E5=8C=96=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=A4=9A=E9=A1=B6=E7=BA=A7=E8=8A=82=E7=82=B9=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/tree/treeconv/listotree.go | 99 ++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 44 deletions(-) diff --git a/tools/tree/treeconv/listotree.go b/tools/tree/treeconv/listotree.go index e64c9a5..843d7b7 100644 --- a/tools/tree/treeconv/listotree.go +++ b/tools/tree/treeconv/listotree.go @@ -37,62 +37,73 @@ func ListToTree[L ListItem[T, S, L], T comparable, S TreeNode[L, S]](src []L) [] children[parentID] = append(children[parentID], id) idToSrc[id] = item } - var zeroID T + var ( + zeroIDs []T + zeroMap = make(map[T]struct{}) + ) + + // 筛选所有顶级节点。 for parentID, _ := range children { if _, ok := idToSrc[parentID]; !ok { - zeroID = parentID - break + if _, ok = zeroMap[parentID]; !ok { + zeroIDs = append(zeroIDs, parentID) + zeroMap[parentID] = struct{}{} + } } } var tempNode S - //对每个根节点分别构建数 - for _, rootId := range children[zeroID] { - level := 1 - rootData := idToSrc[rootId] - rootNode := tempNode.ToTreeNode(rootData, level) - //rootNode := rootData.ToTreeNode(level) - // 使用栈进行深度优先遍历 - type stackItem struct { - node S - id T - level int - } - stack := []stackItem{{node: rootNode, id: rootId, level: level}} - for len(stack) > 0 { - // 弹出栈顶 - top := stack[len(stack)-1] - stack = stack[:len(stack)-1] - // 获取子节点 ID - childIds := children[top.id] - if len(childIds) == 0 { - continue + // 分别对每个顶级节点处理 + for _, zeroID := range zeroIDs { + //对每个根节点分别构建数 + for _, rootId := range children[zeroID] { + level := 1 + rootData := idToSrc[rootId] + rootNode := tempNode.ToTreeNode(rootData, level) + //rootNode := rootData.ToTreeNode(level) + // 使用栈进行深度优先遍历 + type stackItem struct { + node S + id T + level int } - // 初始化子节点切片 - top.node.PrepareChildren(len(childIds)) - // 遍历子节点 - for _, childId := range childIds { - childData := idToSrc[childId] - // 创建子节点 - childNode := tempNode.ToTreeNode(childData, top.level+1, top.node) - //childNode := childData.ToTreeNode(top.level+1, top.node) - // 添加父节点 - top.node.AddChild(childNode) - // 由于上面的SetChildren 里面会解引用,所以这里不能直接用childNode - lastChild, ok := top.node.GetLastChild() - // 确保 GetLastChild() 不为 nil - if !ok { + stack := []stackItem{{node: rootNode, id: rootId, level: level}} + for len(stack) > 0 { + // 弹出栈顶 + top := stack[len(stack)-1] + stack = stack[:len(stack)-1] + // 获取子节点 ID + childIds := children[top.id] + if len(childIds) == 0 { continue } - stack = append(stack, stackItem{ - node: lastChild, - id: childId, - level: top.level + 1, - }) + // 初始化子节点切片 + top.node.PrepareChildren(len(childIds)) + // 遍历子节点 + for _, childId := range childIds { + childData := idToSrc[childId] + // 创建子节点 + childNode := tempNode.ToTreeNode(childData, top.level+1, top.node) + //childNode := childData.ToTreeNode(top.level+1, top.node) + // 添加父节点 + top.node.AddChild(childNode) + // 由于上面的SetChildren 里面会解引用,所以这里不能直接用childNode + lastChild, ok := top.node.GetLastChild() + // 确保 GetLastChild() 不为 nil + if !ok { + continue + } + stack = append(stack, stackItem{ + node: lastChild, + id: childId, + level: top.level + 1, + }) + } } + + result = append(result, rootNode) } - result = append(result, rootNode) } return result From d196024bca669d205456e478bdead1593f00d249 Mon Sep 17 00:00:00 2001 From: helei Date: Mon, 26 Jan 2026 19:37:16 +0800 Subject: [PATCH 6/9] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E4=B8=80=E4=B8=AA?= =?UTF-8?q?=E9=80=9A=E7=94=A8=E7=9A=84=E7=BC=93=E5=AD=98=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=20=E7=9B=AE=E5=89=8D=E5=9F=BA=E4=BA=8E=E5=86=85=E5=AD=98?= =?UTF-8?q?=E7=9A=84=E7=BC=93=E5=AD=98=E5=AE=8C=E6=88=90=E4=BA=86=E3=80=82?= =?UTF-8?q?=20=E5=90=8E=E9=9D=A2=E5=9F=BA=E4=BA=8Eredis=E3=80=81=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E3=80=81rdbms=E7=AD=89=E9=A9=B1=E5=8A=A8=E7=9A=84?= =?UTF-8?q?=E5=BE=85=E5=AE=9E=E7=8E=B0=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: helei --- net/tlsconfig/tls.go | 78 +++++++++++++++++++++++++++++--------------- 1 file changed, 51 insertions(+), 27 deletions(-) diff --git a/net/tlsconfig/tls.go b/net/tlsconfig/tls.go index 8afaf69..19afb83 100644 --- a/net/tlsconfig/tls.go +++ b/net/tlsconfig/tls.go @@ -10,6 +10,14 @@ import ( "github.com/helays/utils/v2/tools" ) +var CurvePreferencesMap = map[string]tls.CurveID{ + "CurveP256": tls.CurveP256, // 标准椭圆曲线 (传统加密) + "CurveP384": tls.CurveP384, // 标准椭圆曲线 (高安全性) + "CurveP521": tls.CurveP521, // 标准椭圆曲线 (高安全性) + "X25519": tls.X25519, // 密码学曲线 最快,比P256快3-4倍 + "X25519MLKEM768": tls.X25519MLKEM768, // 后量子混合曲线 +} + // CipherSuiteMapping 密码套件映射 var CipherSuiteMapping = map[string]uint16{ // TLS 1.3 密码套件 @@ -42,38 +50,42 @@ var CipherSuiteMapping = map[string]uint16{ // 预定义的密码套件组合 var ( ModernCipherSuites = []string{ - // TLS 1.3 密码套件 - "TLS_AES_128_GCM_SHA256", - "TLS_AES_256_GCM_SHA384", - "TLS_CHACHA20_POLY1305_SHA256", - - // 安全的 TLS 1.2 密码套件 - "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", - "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", - "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", - "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", - "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", - "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", + // TLS 1.2 密码套件(按性能排序,从最快到最慢) + + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", // 1. ECDSA证书 + ChaCha20-Poly1305(移动设备/无AES-NI时最快) + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", // 2. RSA证书 + ChaCha20-Poly1305(兼容移动设备) + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", // 3. ECDSA证书 + AES-128-GCM(服务器有AES-NI时最快) + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", // 4. RSA证书 + AES-128-GCM(广泛兼容且快) + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", // 5. ECDSA证书 + AES-256-GCM(更高安全性) + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", // 6. RSA证书 + AES-256-GCM(高安全性兼容) + + // TLS 1.3 密码套件(Go会自动处理,放最后) + "TLS_AES_128_GCM_SHA256", // TLS 1.3最快 + "TLS_CHACHA20_POLY1305_SHA256", // TLS 1.3移动友好 + "TLS_AES_256_GCM_SHA384", // TLS 1.3高安全 } // 现代密码套件 (TLS 1.3 + 安全的 TLS 1.2) CompatibleCipherSuites = []string{ - // TLS 1.3 + // TLS 1.2 密码套件(按性能排序) + // 第1梯队:现代快速套件 + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", // 最快(ECDSA+ChaCha20) + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", // 快(RSA+ChaCha20) + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", // 快(ECDSA+AES-128) + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", // 广泛兼容且快 + + // 第2梯队:高安全性套件 + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", // 高安全 ECDSA + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", // 高安全 RSA + + // 第3梯队:兼容性套件(性能较差,必要时使用) + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", // CBC模式,有AES-NI时还行 + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", // CBC+ECDSA + "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", // CBC+RSA + + // TLS 1.3 密码套件(放最后,Go会忽略) "TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384", "TLS_CHACHA20_POLY1305_SHA256", - - // TLS 1.2 - "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", - "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", - "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", - "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", - "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", - "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", - - // 更多兼容性套件 - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", - "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", } // 兼容性密码套件 (包含更多旧版本支持) ) @@ -167,6 +179,7 @@ type TLSConfig struct { ClientCAFile string `json:"client_ca_file" yaml:"client_ca_file"` // 客户端CA文件 InsecureSkipVerify bool `json:"insecure_skip_verify" yaml:"insecure_skip_verify"` // 跳过验证 CipherSuites []string `json:"cipher_suites" yaml:"cipher_suites"` // 密钥套件 + CurvePreferences []string `json:"curve_preferences" yaml:"curve_preferences"` // 曲线偏好 SessionTicketsDisabled bool `json:"session_tickets_disabled" yaml:"session_tickets_disabled"` // 禁用会话密钥 MinVersion string `json:"min_version" yaml:"min_version"` // 最低TLS版本 MaxVersion string `json:"max_version" yaml:"max_version"` // 最高TLS版本 @@ -194,6 +207,17 @@ func (t *TLSConfig) ToTLSConfig() (*tls.Config, error) { Renegotiation: t.Renegotiation, } + // 自定义曲线 + if len(t.CurvePreferences) > 0 { + for _, curveName := range t.CurvePreferences { + curveID, exists := CurvePreferencesMap[curveName] + if !exists { + return nil, fmt.Errorf("未知的曲线: %s", curveName) + } + config.CurvePreferences = append(config.CurvePreferences, curveID) + } + } + // 解析 TLS 版本 if t.MinVersion != "" { minVersion, err := ParseTLSVersion(t.MinVersion) @@ -224,7 +248,7 @@ func (t *TLSConfig) ToTLSConfig() (*tls.Config, error) { config.CipherSuites = cipherSuites } else { // 默认使用现代密码套件 - cipherSuites, err = ParseCipherSuites(ModernCipherSuites) + cipherSuites, err = ParseCipherSuites(CompatibleCipherSuites) if err != nil { return nil, fmt.Errorf("解析默认密码套件失败: %v", err) } From 25348d6c30b8026c100389146020191dbef7e0e4 Mon Sep 17 00:00:00 2001 From: helei Date: Tue, 27 Jan 2026 09:51:46 +0800 Subject: [PATCH 7/9] =?UTF-8?q?tls=E9=83=A8=E5=88=86=E4=BC=98=E5=8C=96?= =?UTF-8?q?=EF=BC=8C=E5=AF=B9=E4=BA=8Eca=E8=AF=81=E4=B9=A6=EF=BC=8C?= =?UTF-8?q?=E4=B9=9F=E8=83=BD=E5=90=8C=E6=97=B6=E8=AE=BE=E7=BD=AE=E5=A4=9A?= =?UTF-8?q?=E5=A5=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: helei --- net/tlsconfig/tls.go | 122 +++++++++++++++++++++++++------------------ 1 file changed, 71 insertions(+), 51 deletions(-) diff --git a/net/tlsconfig/tls.go b/net/tlsconfig/tls.go index 19afb83..d121c8f 100644 --- a/net/tlsconfig/tls.go +++ b/net/tlsconfig/tls.go @@ -166,7 +166,7 @@ var ( type TLSConfig struct { Enable bool `json:"enable" yaml:"enable"` Certificates []Certificate `json:"certificates" yaml:"certificates"` // 证书 - RootCAFile string `json:"root_ca_file" yaml:"root_ca_file"` // 根CA文件 + RootCAFile []string `json:"root_ca_file" yaml:"root_ca_file"` // 根CA文件 NextProtos []string `json:"next_protos" yaml:"next_protos"` // 支持的协议 ServerName string `json:"server_name" yaml:"server_name"` // 服务器名称 @@ -176,7 +176,7 @@ type TLSConfig struct { // 3 如果提供客户端证书则验证 // 4 要求并验证客户端证书 ClientAuth tls.ClientAuthType `json:"client_auth" yaml:"client_auth"` // 客户端验证 - ClientCAFile string `json:"client_ca_file" yaml:"client_ca_file"` // 客户端CA文件 + ClientCAFile []string `json:"client_ca_file" yaml:"client_ca_file"` // 客户端CA文件 InsecureSkipVerify bool `json:"insecure_skip_verify" yaml:"insecure_skip_verify"` // 跳过验证 CipherSuites []string `json:"cipher_suites" yaml:"cipher_suites"` // 密钥套件 CurvePreferences []string `json:"curve_preferences" yaml:"curve_preferences"` // 曲线偏好 @@ -205,6 +205,7 @@ func (t *TLSConfig) ToTLSConfig() (*tls.Config, error) { InsecureSkipVerify: t.InsecureSkipVerify, DynamicRecordSizingDisabled: t.DynamicRecordSizingDisabled, Renegotiation: t.Renegotiation, + ClientAuth: t.ClientAuth, } // 自定义曲线 @@ -219,40 +220,44 @@ func (t *TLSConfig) ToTLSConfig() (*tls.Config, error) { } // 解析 TLS 版本 - if t.MinVersion != "" { - minVersion, err := ParseTLSVersion(t.MinVersion) - if err != nil { - return nil, fmt.Errorf("解析最低TLS版本失败: %v", err) + { + if t.MinVersion != "" { + minVersion, err := ParseTLSVersion(t.MinVersion) + if err != nil { + return nil, fmt.Errorf("解析最低TLS版本失败: %v", err) + } + config.MinVersion = minVersion + } else { + // 默认使用 TLS 1.2 + config.MinVersion = tls.VersionTLS12 } - config.MinVersion = minVersion - } else { - // 默认使用 TLS 1.2 - config.MinVersion = tls.VersionTLS12 - } - if t.MaxVersion != "" { - maxVersion, err := ParseTLSVersion(t.MaxVersion) - if err != nil { - return nil, fmt.Errorf("解析最高TLS版本失败: %v", err) + if t.MaxVersion != "" { + maxVersion, err := ParseTLSVersion(t.MaxVersion) + if err != nil { + return nil, fmt.Errorf("解析最高TLS版本失败: %v", err) + } + config.MaxVersion = maxVersion } - config.MaxVersion = maxVersion } // 解析密码套件 - var cipherSuites []uint16 - var err error - if len(t.CipherSuites) > 0 { - cipherSuites, err = ParseCipherSuites(t.CipherSuites) - if err != nil { - return nil, fmt.Errorf("解析密码套件失败: %v", err) - } - config.CipherSuites = cipherSuites - } else { - // 默认使用现代密码套件 - cipherSuites, err = ParseCipherSuites(CompatibleCipherSuites) - if err != nil { - return nil, fmt.Errorf("解析默认密码套件失败: %v", err) + { + var cipherSuites []uint16 + var err error + if len(t.CipherSuites) > 0 { + cipherSuites, err = ParseCipherSuites(t.CipherSuites) + if err != nil { + return nil, fmt.Errorf("解析密码套件失败: %v", err) + } + config.CipherSuites = cipherSuites + } else { + // 默认使用现代密码套件 + cipherSuites, err = ParseCipherSuites(CompatibleCipherSuites) + if err != nil { + return nil, fmt.Errorf("解析默认密码套件失败: %v", err) + } + config.CipherSuites = cipherSuites } - config.CipherSuites = cipherSuites } // 加载服务器证书 @@ -276,30 +281,20 @@ func (t *TLSConfig) ToTLSConfig() (*tls.Config, error) { } // 加载根 CA 证书 - if t.RootCAFile != "" { - rootCAFile := tools.Fileabs(t.RootCAFile) - rootCACert, err := os.ReadFile(rootCAFile) - if err != nil { - return nil, fmt.Errorf("读取根证书文件失败: %v", err) - } - rootCertPool := x509.NewCertPool() - if !rootCertPool.AppendCertsFromPEM(rootCACert) { - return nil, fmt.Errorf("解析根证书失败") - } - config.RootCAs = rootCertPool + // noinspection all + if pool, err := t.loadCaKit(t.RootCAFile...); err != nil { + return nil, fmt.Errorf("服务端证书添加失败 %v", err) + } else { + config.RootCAs = pool } - // 加载客户端 CA 证书 - if t.ClientCAFile != "" { - clientCACert, err := os.ReadFile(tools.Fileabs(t.ClientCAFile)) + + // 加载客户端证书 + if len(t.ClientCAFile) > 0 { + pool, err := t.loadCaKit(t.ClientCAFile...) if err != nil { - return nil, fmt.Errorf("读取客户端CA文件失败: %v", err) - } - clientCertPool := x509.NewCertPool() - if !clientCertPool.AppendCertsFromPEM(clientCACert) { - return nil, fmt.Errorf("解析客户端CA证书失败") + return nil, fmt.Errorf("客户端证书添加失败 %v", err) } - config.ClientCAs = clientCertPool - config.ClientAuth = t.ClientAuth + config.ClientCAs = pool } else if t.ClientAuth != tls.NoClientCert { // 如果设置了客户端认证但没有提供 CA 文件,返回错误 return nil, fmt.Errorf("客户端认证需要提供 client_ca_file") @@ -307,6 +302,31 @@ func (t *TLSConfig) ToTLSConfig() (*tls.Config, error) { return config, nil } +// 载入 ca 证书工具方法 +// 支持pem和der两种格式的证书。 +func (t *TLSConfig) loadCaKit(cas ...string) (*x509.CertPool, error) { + if len(cas) < 1 { + return nil, nil + } + pool := x509.NewCertPool() + for _, ca := range cas { + caFile := tools.Fileabs(ca) + caCert, err := os.ReadFile(caFile) + if err != nil { + return nil, fmt.Errorf("读取证书[%s]内容失败:%v", ca, err) + } + if pool.AppendCertsFromPEM(caCert) { + continue + } + cert, err := x509.ParseCertificate(caCert) + if err != nil { + return nil, fmt.Errorf("解析证书[%s]失败:%v(既不是PEM也不是DER格式)", ca, err) + } + pool.AddCert(cert) + } + return pool, nil +} + // DefaultTLSConfig 返回一个安全的默认 TLS 配置 func DefaultTLSConfig() *TLSConfig { return &TLSConfig{ From dc5c45027e8562bafcf877a975f095eb21c2cf47 Mon Sep 17 00:00:00 2001 From: helei Date: Tue, 27 Jan 2026 11:50:18 +0800 Subject: [PATCH 8/9] =?UTF-8?q?=E6=97=A5=E5=BF=97=E6=96=B0=E5=A2=9E=20trac?= =?UTF-8?q?e=E7=BA=A7=E5=88=AB=20http=E6=97=A5=E5=BF=97=EF=BC=8C=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=E5=B0=86=E6=AD=A3=E5=B8=B8=E6=97=A5=E5=BF=97=E8=BE=93?= =?UTF-8?q?=E5=87=BA=E5=88=B0=20debug=E7=BA=A7=E5=88=AB=E4=B8=8B=E7=9A=84?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E5=99=A8=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: helei --- config/parseCmd/parse.go | 3 + logger/ulogs/log.go | 104 ++++++++++++++-------------- net/http/route/middleware/logger.go | 48 +++++++++++-- 3 files changed, 99 insertions(+), 56 deletions(-) diff --git a/config/parseCmd/parse.go b/config/parseCmd/parse.go index 8c36a8a..7cb0e5c 100644 --- a/config/parseCmd/parse.go +++ b/config/parseCmd/parse.go @@ -43,6 +43,8 @@ func Parseparams(f ...func()) { } // 控制日志等级 switch logLevel { + case "trace": + ulogs.Level = ulogs.LogLevelTrace case "debug": ulogs.Level = ulogs.LogLevelDebug case "info": @@ -55,6 +57,7 @@ func Parseparams(f ...func()) { ulogs.Level = ulogs.LogLevelFatal } + // noinspection all if config.EnableParseParamsLog { fmt.Println("日志级别", logLevel, ulogs.Level) ulogs.Log("运行参数解析完成...") diff --git a/logger/ulogs/log.go b/logger/ulogs/log.go index 54c17dd..ac73010 100644 --- a/logger/ulogs/log.go +++ b/logger/ulogs/log.go @@ -6,14 +6,24 @@ import ( ) const ( - LogLevelDebug = iota + LogLevelTrace = iota + LogLevelDebug LogLevelInfo LogLevelWarn LogLevelError LogLevelFatal ) -var Level = 1 +var ( + traceLogger = log.New(os.Stdout, "【TRACE】", log.LstdFlags) + debugLogger = log.New(os.Stdout, "【DEBUG】", log.LstdFlags) + infoLogger = log.New(os.Stdout, "【INFO】", log.LstdFlags) + warnLogger = log.New(os.Stdout, "【WARN】", log.LstdFlags) + errorLogger = log.New(os.Stderr, "【ERROR", log.LstdFlags) + fatalLogger = log.New(os.Stderr, "【FATAL】", log.LstdFlags) +) + +var Level = LogLevelInfo // Recover 捕获系统异常 func Recover() { @@ -23,103 +33,93 @@ func Recover() { } // Log 打印正确日志,Info的别名 +// Deprecated: 弃用,请使用 Info func Log(i ...interface{}) { Info(i...) } +func Trace(i ...any) { + if Level <= LogLevelTrace { + traceLogger.Println(i...) + } +} + +// noinspection all +func Tracef(format string, a ...any) { + if Level <= LogLevelTrace { + traceLogger.Printf(format, a...) + } +} + // Debug 用于记录调试信息 func Debug(i ...any) { - if Level > LogLevelDebug { - return + if Level <= LogLevelDebug { + debugLogger.Println(i...) } - log.SetPrefix("【DEBUG】") - log.SetOutput(os.Stdout) - log.Println(i...) } +// Debugf +// noinspection all func Debugf(format string, a ...any) { - if Level > LogLevelDebug { - return + if Level <= LogLevelDebug { + debugLogger.Printf(format, a...) } - log.SetPrefix("【DEBUG】") - log.SetOutput(os.Stdout) - log.Printf(format, a...) } // Info 用于记录信息 func Info(i ...interface{}) { - if Level > LogLevelInfo { - return + if Level <= LogLevelInfo { + infoLogger.Println(i...) } - log.SetPrefix("【INFO】") - log.SetOutput(os.Stdout) - log.Println(i...) } +// noinspection all func Infof(format string, a ...any) { - if Level > LogLevelInfo { - return + if Level <= LogLevelInfo { + infoLogger.Printf(format, a...) } - log.SetPrefix("【INFO】") - log.SetOutput(os.Stdout) - log.Printf(format, a...) + } // Warn 用于记录警告信息 func Warn(i ...interface{}) { - if Level > LogLevelWarn { - return + if Level <= LogLevelWarn { + warnLogger.Println(i...) } - log.SetPrefix("【WARN】") - log.SetOutput(os.Stdout) - log.Println(i...) } +// noinspection all func Warnf(format string, a ...any) { - if Level > LogLevelWarn { - return + if Level <= LogLevelWarn { + warnLogger.Printf(format, a...) } - log.SetPrefix("【WARN】") - log.SetOutput(os.Stdout) - log.Printf(format, a...) } // Error 用于记录错误信息 func Error(i ...interface{}) { - if Level > LogLevelError { - return + if Level <= LogLevelError { + errorLogger.Println(i...) } - log.SetPrefix("【ERROR】") - log.SetOutput(os.Stderr) - log.Println(i...) } func Errorf(format string, a ...any) { - if Level > LogLevelError { - return + if Level <= LogLevelError { + errorLogger.Printf(format, a...) } - log.SetPrefix("【ERROR】") - log.SetOutput(os.Stderr) - log.Printf(format, a...) } // Fatal 用于记录致命错误信息 func Fatal(i ...interface{}) { - if Level > LogLevelFatal { - return + if Level <= LogLevelFatal { + fatalLogger.Println(i...) } - log.SetPrefix("【FATAL】") - log.SetOutput(os.Stderr) - log.Println(i...) } +// noinspection all func Fatalf(format string, a ...any) { - if Level > LogLevelFatal { - return + if Level <= LogLevelFatal { + fatalLogger.Printf(format, a...) } - log.SetPrefix("【FATAL】") - log.SetOutput(os.Stderr) - log.Printf(format, a...) } // Checkerr 检查错误 diff --git a/net/http/route/middleware/logger.go b/net/http/route/middleware/logger.go index 8ffe1c6..6d4ccd5 100644 --- a/net/http/route/middleware/logger.go +++ b/net/http/route/middleware/logger.go @@ -30,8 +30,13 @@ func (c *ResponseProcessor) metrics(w *writer) { } -// StdLogger 日志标准输出器 +// ================= StdLogger 日志标准输出器 ================ +// StdLogger 日志标准输出器 +// http code >= 500 输出到Stderr,日志等级Fatal +// http code >= 400 输出的Stderr,日志等级Error +// http code >= 300 输出到Stdout,日志等级Warn +// 否则输出到 Stdout,日志等级 Info type StdLogger struct{} func NewStdLogger() *StdLogger { @@ -39,13 +44,44 @@ func NewStdLogger() *StdLogger { } func (s *StdLogger) Write(l *Logs) { - if l.Status >= http.StatusBadRequest { + if l.Status >= http.StatusInternalServerError { + ulogs.Fatalf("[%s] %s %s %d %d %s [%s]", l.Ip, l.Method, l.Uri, l.Status, l.ContentSize, l.UserAgent, l.Elapsed) + } else if l.Status >= http.StatusBadRequest { ulogs.Errorf("[%s] %s %s %d %d %s [%s]", l.Ip, l.Method, l.Uri, l.Status, l.ContentSize, l.UserAgent, l.Elapsed) + } else if l.Status >= http.StatusMultipleChoices { + ulogs.Warnf("[%s] %s %s %d %d %s [%s]", l.Ip, l.Method, l.Uri, l.Status, l.ContentSize, l.UserAgent, l.Elapsed) } else { ulogs.Infof("[%s] %s %s %d %d %s [%s]", l.Ip, l.Method, l.Uri, l.Status, l.ContentSize, l.UserAgent, l.Elapsed) } } +// ================= DebugStdLogger 日志标准输出器 ================ + +// DebugStdLogger 调试日志输出器(将正常日志输出到debug级别) +// http code >= 500 输出到Stderr,日志等级Fatal +// http code >= 400 输出的Stderr,日志等级Error +// http code >= 300 输出到Stdout,日志等级Warn +// 否则输出到 Stdout,日志等级 Debug(正常日志使用debug级别) +type DebugStdLogger struct{} + +func NewDebugStdLogger() *DebugStdLogger { + return &DebugStdLogger{} +} + +func (d *DebugStdLogger) Write(l *Logs) { + if l.Status >= http.StatusInternalServerError { + ulogs.Fatalf("[%s] %s %s %d %d %s [%s]", l.Ip, l.Method, l.Uri, l.Status, l.ContentSize, l.UserAgent, l.Elapsed) + } else if l.Status >= http.StatusBadRequest { + ulogs.Errorf("[%s] %s %s %d %d %s [%s]", l.Ip, l.Method, l.Uri, l.Status, l.ContentSize, l.UserAgent, l.Elapsed) + } else if l.Status >= http.StatusMultipleChoices { + ulogs.Warnf("[%s] %s %s %d %d %s [%s]", l.Ip, l.Method, l.Uri, l.Status, l.ContentSize, l.UserAgent, l.Elapsed) + } else { + ulogs.Debugf("[%s] %s %s %d %d %s [%s]", l.Ip, l.Method, l.Uri, l.Status, l.ContentSize, l.UserAgent, l.Elapsed) + } +} + +// ================= ZapLogger 日志输出器 ================ + type ZapLogger struct { logger *zaploger.Logger } @@ -66,9 +102,13 @@ func (z *ZapLogger) Write(l *Logs) { zap.String("http_user_agent", l.UserAgent), zap.Any("elapsed", l.Elapsed), } + if l.Status >= http.StatusBadRequest { z.logger.Error(context.Background(), l.Ip, msg...) - return + } else if l.Status >= http.StatusMultipleChoices { + z.logger.Warn(context.Background(), l.Ip, msg...) + } else { + z.logger.Debug(context.Background(), l.Ip, msg...) } - z.logger.Debug(context.Background(), l.Ip, msg...) + } From 61233f669cf030ec9252492cea73c3c0bb30b2ec Mon Sep 17 00:00:00 2001 From: helei Date: Tue, 27 Jan 2026 13:48:47 +0800 Subject: [PATCH 9/9] =?UTF-8?q?=E6=97=A5=E5=BF=97=E6=96=B0=E5=A2=9E=20trac?= =?UTF-8?q?e=E7=BA=A7=E5=88=AB=20http=E6=97=A5=E5=BF=97=EF=BC=8C=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=E5=B0=86=E6=AD=A3=E5=B8=B8=E6=97=A5=E5=BF=97=E8=BE=93?= =?UTF-8?q?=E5=87=BA=E5=88=B0=20debug=E7=BA=A7=E5=88=AB=E4=B8=8B=E7=9A=84?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E5=99=A8=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: helei --- config/parseCmd/parse.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/config/parseCmd/parse.go b/config/parseCmd/parse.go index 7cb0e5c..b01acc4 100644 --- a/config/parseCmd/parse.go +++ b/config/parseCmd/parse.go @@ -22,11 +22,9 @@ func Parseparams(f ...func()) { flag.BoolVar(&config.Dbg, "debug", false, "Debug 模式") flag.StringVar(&logLevel, "log-level", "info", "日志级别:\ndebug info warn error fatal") flag.BoolVar(&vers, "version", false, "查看版本") - if len(f) > 0 { - for _, v := range f { - if v != nil { - v() - } + for _, v := range f { + if v != nil { + v() } } flag.Parse()