From 3e752be25ab83f9b0b07d7eab265c3de2e5a4c52 Mon Sep 17 00:00:00 2001 From: Agustin Groh Date: Tue, 6 Jan 2026 11:17:19 -0300 Subject: [PATCH] chore:SP-3888 Implement multithreading support for the local vulnerability use case --- CHANGELOG.md | 2 + pkg/config/server_config.go | 4 +- pkg/models/versions.go | 19 ++-- pkg/models/versions_test.go | 25 ++--- pkg/models/vulns_purl.go | 21 ++--- pkg/models/vulns_purl_test.go | 28 ++---- pkg/usecase/OSV_use_case.go | 8 +- pkg/usecase/OSV_use_case_test.go | 2 +- pkg/usecase/local_use_case.go | 130 ++++++++++++++++---------- pkg/usecase/local_use_case_test.go | 20 ++-- pkg/usecase/vulnerability_use_case.go | 35 +++++-- 11 files changed, 166 insertions(+), 128 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 771d312..68e6b8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,8 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.8.0] - 2026/01/07 ### Added - Included Exploit Prediction Scoring System (EPSS) to vulnerability response +- Added configurable worker pool for local vulnerability processing (`VULN_SCANOSS_WORKERS`) ### Changed - Refactored OSV use case +- Refactored local vulnerability use case with multithreading support and context cancellation handling - Upgraded `scanoss/papi` to v0.28.0 ## [0.7.0] - 2025/11/13 diff --git a/pkg/config/server_config.go b/pkg/config/server_config.go index 15bbcf6..c3aa0f5 100644 --- a/pkg/config/server_config.go +++ b/pkg/config/server_config.go @@ -79,7 +79,8 @@ type ServerConfig struct { APIWorkers int `env:"VULN_OSV_API_WORKERS"` } SCANOSS struct { - Enabled bool `env:"VULN_SCANOSS_SOURCE_ENABLED"` + Enabled bool `env:"VULN_SCANOSS_SOURCE_ENABLED"` + MaxWorkers int `env:"VULN_SCANOSS_WORKERS"` } } } @@ -124,6 +125,7 @@ func setServerConfigDefaults(cfg *ServerConfig) { cfg.Source.OSV.Enabled = true cfg.Source.OSV.APIWorkers = 5 cfg.Source.SCANOSS.Enabled = true + cfg.Source.SCANOSS.MaxWorkers = 5 } func IsValidConfig(cfg *ServerConfig) error { diff --git a/pkg/models/versions.go b/pkg/models/versions.go index 7585e82..1795943 100644 --- a/pkg/models/versions.go +++ b/pkg/models/versions.go @@ -30,8 +30,7 @@ import ( ) type VersionModel struct { - ctx context.Context - conn *sqlx.Conn + db *sqlx.DB } type Version struct { @@ -48,18 +47,18 @@ type PurlVersion struct { // TODO add cache for versions already searched for? // NewVersionModel creates a new instance of the Version Model. -func NewVersionModel(ctx context.Context, conn *sqlx.Conn) *VersionModel { - return &VersionModel{ctx: ctx, conn: conn} +func NewVersionModel(db *sqlx.DB) *VersionModel { + return &VersionModel{db: db} } // GetVersionByName gets the given version from the versions table. -func (m *VersionModel) GetVersionByName(name string, create bool) (Version, error) { +func (m *VersionModel) GetVersionByName(ctx context.Context, name string, create bool) (Version, error) { if len(name) == 0 { zlog.S.Error("Please specify a valid Version Name to query") return Version{}, errors.New("please specify a valid Version Name to query") } var version Version - err := m.conn.QueryRowxContext(m.ctx, + err := m.db.QueryRowxContext(ctx, "SELECT id, version_name, semver FROM versions"+ " WHERE version_name = $1", name).StructScan(&version) @@ -68,28 +67,28 @@ func (m *VersionModel) GetVersionByName(name string, create bool) (Version, erro return Version{}, fmt.Errorf("failed to query the versions table: %v", err) } if create && len(version.VersionName) == 0 { // No version found and requested to create an entry - return m.saveVersion(name) + return m.saveVersion(ctx, name) } return version, nil } // saveVersion writes the given version name to the versions table. -func (m *VersionModel) saveVersion(name string) (Version, error) { +func (m *VersionModel) saveVersion(ctx context.Context, name string) (Version, error) { if len(name) == 0 { zlog.S.Error("Please specify a valid version Name to save") return Version{}, errors.New("please specify a valid Version Name to save") } zlog.S.Debugf("Attempting to save '%v' to the versions table...", name) var version Version - err := m.conn.QueryRowxContext(m.ctx, + err := m.db.QueryRowxContext(ctx, "INSERT INTO versions (version_name, semver) VALUES($1, $2)"+ " RETURNING id, version_name, semver", name, "", false, false, ).StructScan(&version) if err != nil { zlog.S.Errorf("Error: Failed to insert new version name into versions table for %v: %v", name, err) - return m.GetVersionByName(name, false) // Search one more time for it, just in case someone else added it + return m.GetVersionByName(ctx, name, false) // Search one more time for it, just in case someone else added it } return version, nil } diff --git a/pkg/models/versions_test.go b/pkg/models/versions_test.go index c710a26..6cdcab7 100644 --- a/pkg/models/versions_test.go +++ b/pkg/models/versions_test.go @@ -38,19 +38,14 @@ func TestVersionsSearch(t *testing.T) { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer CloseDB(db) - conn, err := db.Connx(ctx) // Get a connection from the pool - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - defer CloseConn(conn) - err = loadTestSQLDataFiles(db, ctx, conn, []string{"../models/tests/versions.sql"}) + err = loadTestSQLDataFiles(db, ctx, nil, []string{"../models/tests/versions.sql"}) if err != nil { t.Fatalf("failed to load SQL test data: %v", err) } - versionModel := NewVersionModel(ctx, conn) + versionModel := NewVersionModel(db) var name = "1.0.0" fmt.Printf("Searching for version: %v\n", name) - version, err := versionModel.GetVersionByName(name, false) + version, err := versionModel.GetVersionByName(ctx, name, false) if err != nil { t.Errorf("versions.GetVersionByName() error = %v", err) } @@ -61,7 +56,7 @@ func TestVersionsSearch(t *testing.T) { name = "" fmt.Printf("Searching for license: %v\n", name) - _, err = versionModel.GetVersionByName(name, false) + _, err = versionModel.GetVersionByName(ctx, name, false) if err == nil { t.Errorf("versions.GetVersionByName() error = did not get an error") } else { @@ -69,7 +64,7 @@ func TestVersionsSearch(t *testing.T) { } name = "" fmt.Printf("Saving for license: %v\n", name) - _, err = versionModel.saveVersion(name) + _, err = versionModel.saveVersion(ctx, name) if err == nil { t.Errorf("versions.saveVersion() error = did not get an error") } else { @@ -78,7 +73,7 @@ func TestVersionsSearch(t *testing.T) { name = "22.22.22" fmt.Printf("Searching for version: %v\n", name) - version, err = versionModel.GetVersionByName(name, true) + version, err = versionModel.GetVersionByName(ctx, name, true) if err != nil { t.Errorf("versions.GetVersionByName() error = %v", err) } @@ -89,7 +84,7 @@ func TestVersionsSearch(t *testing.T) { name = "22.22.22" fmt.Printf("Searching for version: %v\n", name) - version, err = versionModel.saveVersion(name) + version, err = versionModel.saveVersion(ctx, name) if err != nil { t.Errorf("versions.GetVersionByName() error = %v", err) } @@ -116,14 +111,14 @@ func TestVersionsSearchBadSql(t *testing.T) { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer CloseConn(conn) - versionModel := NewVersionModel(ctx, conn) - _, err = versionModel.GetVersionByName("rubbish", false) + versionModel := NewVersionModel(db) + _, err = versionModel.GetVersionByName(ctx, "rubbish", false) if err == nil { t.Errorf("versions.GetVersionByName() error = did not get an error") } else { fmt.Printf("Got expected error = %v\n", err) } - _, err = versionModel.saveVersion("rubbish") + _, err = versionModel.saveVersion(ctx, "rubbish") if err == nil { t.Errorf("versions.saveVersion() error = did not get an error") } else { diff --git a/pkg/models/vulns_purl.go b/pkg/models/vulns_purl.go index e13515b..3344685 100644 --- a/pkg/models/vulns_purl.go +++ b/pkg/models/vulns_purl.go @@ -30,8 +30,7 @@ import ( ) type VulnsForPurlModel struct { - ctx context.Context - conn *sqlx.Conn + db *sqlx.DB } type VulnsForPurl struct { @@ -48,12 +47,12 @@ type OnlyPurl struct { } // NewVulnsForPurlModel creates a new instance of the CPE Purl Model. -func NewVulnsForPurlModel(ctx context.Context, conn *sqlx.Conn) *VulnsForPurlModel { - return &VulnsForPurlModel{ctx: ctx, conn: conn} +func NewVulnsForPurlModel(db *sqlx.DB) *VulnsForPurlModel { + return &VulnsForPurlModel{db: db} } // GetVulnsByPurl gets vulnerabilities by purl. -func (m *VulnsForPurlModel) GetVulnsByPurl(purl string, version string) ([]VulnsForPurl, error) { +func (m *VulnsForPurlModel) GetVulnsByPurl(ctx context.Context, purl string, version string) ([]VulnsForPurl, error) { if len(purl) == 0 { zlog.S.Errorf("Please specify a valid Purl String to query") return []VulnsForPurl{}, errors.New("please specify a valid Purl String to query") @@ -68,13 +67,13 @@ func (m *VulnsForPurlModel) GetVulnsByPurl(purl string, version string) ([]Vulns purlName := utils.PurlRemoveFromVersionComponent(purl) // Remove everything after the component name if len(version) > 0 { - return m.GetVulnsByPurlVersion(purlName, version) + return m.GetVulnsByPurlVersion(ctx, purlName, version) } - return m.GetVulnsByPurlName(purlName) + return m.GetVulnsByPurlName(ctx, purlName) } // GetVulnsByPurlName searches for component details of the specified Purl Name/Type (and optional requirement). -func (m *VulnsForPurlModel) GetVulnsByPurlName(purlName string) ([]VulnsForPurl, error) { +func (m *VulnsForPurlModel) GetVulnsByPurlName(ctx context.Context, purlName string) ([]VulnsForPurl, error) { if len(purlName) == 0 { zlog.S.Errorf("Please specify a valid Purl Name to query") return []VulnsForPurl{}, errors.New("please specify a valid Purl Name to query") @@ -82,7 +81,7 @@ func (m *VulnsForPurlModel) GetVulnsByPurlName(purlName string) ([]VulnsForPurl, var vulns []VulnsForPurl purlName = strings.TrimSpace(purlName) - err := m.conn.SelectContext(m.ctx, &vulns, + err := m.db.SelectContext(ctx, &vulns, "SELECT c2.cve, c2.severity, c2.published, c2.modified, c2.summary "+ "FROM short_cpe_purl scp "+ "INNER JOIN cpes c ON scp.cpe_id = c.id "+ @@ -101,7 +100,7 @@ func (m *VulnsForPurlModel) GetVulnsByPurlName(purlName string) ([]VulnsForPurl, return vulns, nil } -func (m *VulnsForPurlModel) GetVulnsByPurlVersion(purlName string, purlVersion string) ([]VulnsForPurl, error) { +func (m *VulnsForPurlModel) GetVulnsByPurlVersion(ctx context.Context, purlName string, purlVersion string) ([]VulnsForPurl, error) { if len(purlName) == 0 { zlog.S.Errorf("Please specify a valid Purl Name to query") return []VulnsForPurl{}, errors.New("please specify a valid Purl Name to query") @@ -142,7 +141,7 @@ func (m *VulnsForPurlModel) GetVulnsByPurlVersion(purlName string, purlVersion s WHERE c2.match_criteria_ids && mc.criteria_ids ORDER BY c2.cve, c2.severity, c2.published, c2.modified, c2.summary;` - err := m.conn.SelectContext(m.ctx, &vulns, query, purlName, purlVersion) + err := m.db.SelectContext(ctx, &vulns, query, purlName, purlVersion) if err != nil { zlog.S.Errorf("Failed to query short_cpe for %s: %v", purlName, err) diff --git a/pkg/models/vulns_purl_test.go b/pkg/models/vulns_purl_test.go index d85a160..2fa548a 100644 --- a/pkg/models/vulns_purl_test.go +++ b/pkg/models/vulns_purl_test.go @@ -50,7 +50,7 @@ func TestGetVulnsByPurl(t *testing.T) { t.Fatalf("failed to load SQL test data: %v", err) } - cpeModel := NewVulnsForPurlModel(ctx, conn) + cpeModel := NewVulnsForPurlModel(db) type inputGetVulnsForPurl struct { purl string @@ -75,7 +75,7 @@ func TestGetVulnsByPurl(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := cpeModel.GetVulnsByPurl(tt.input.purl, tt.input.requirement) + got, err := cpeModel.GetVulnsByPurl(ctx, tt.input.purl, tt.input.requirement) if (err != nil) != tt.wantErr { t.Errorf("cpeModel.GetCpeByPurl() error = %v, wantErr %v", err, tt.wantErr) return @@ -102,26 +102,14 @@ func TestGetVulnsByPurlName(t *testing.T) { } db.SetMaxOpenConns(1) defer CloseDB(db) - - conn, err := db.Connx(ctx) // Get a connection from the pool - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - defer CloseConn(conn) - err = LoadTestSQLData(db, ctx, conn) + err = LoadTestSQLData(db, ctx, nil) if err != nil { t.Fatalf("failed to load SQL test data: %v", err) } - cpeModel := NewVulnsForPurlModel(ctx, conn) - - _, err = cpeModel.GetVulnsByPurlName("") - if err == nil { - t.Errorf("Error was expected because purl is empty in cpeModel.GetVulnsByPurlName()") - } + cpeModel := NewVulnsForPurlModel(db) - CloseConn(conn) - _, err = cpeModel.GetVulnsByPurlName("pkg:github/hapijs/call") + _, err = cpeModel.GetVulnsByPurlName(ctx, "") if err == nil { t.Errorf("Error was expected because purl is empty in cpeModel.GetVulnsByPurlName()") } @@ -151,15 +139,15 @@ func TestGetVulnsByPurlVersion(t *testing.T) { t.Fatalf("failed to load SQL test data: %v", err) } - cpeModel := NewVulnsForPurlModel(ctx, conn) + cpeModel := NewVulnsForPurlModel(db) - _, err = cpeModel.GetVulnsByPurlVersion("", "") + _, err = cpeModel.GetVulnsByPurlVersion(ctx, "", "") if err == nil { t.Errorf("Error was expected because purl is empty in cpeModel.GetVulnsByPurlVersion()") } CloseConn(conn) - _, err = cpeModel.GetVulnsByPurlVersion("pkg:github/hapijs/call", "1.0.0") + _, err = cpeModel.GetVulnsByPurlVersion(ctx, "pkg:github/hapijs/call", "1.0.0") if err == nil { t.Errorf("Error was expected because purl is empty in cpeModel.GetVulnsByPurlVersion()") } diff --git a/pkg/usecase/OSV_use_case.go b/pkg/usecase/OSV_use_case.go index 4cc90a7..187af40 100644 --- a/pkg/usecase/OSV_use_case.go +++ b/pkg/usecase/OSV_use_case.go @@ -81,15 +81,15 @@ func (us OSVUseCase) getOSVRequestsFromDTO(dto []dtos.ComponentDTO) []OSVRequest return osvRequests } -func (us OSVUseCase) Execute(dto []dtos.ComponentDTO) dtos.VulnerabilityOutput { +func (us OSVUseCase) Execute(ctx context.Context, dto []dtos.ComponentDTO) dtos.VulnerabilityOutput { osvRequests := us.getOSVRequestsFromDTO(dto) - return us.processRequests(osvRequests) + return us.processRequests(ctx, osvRequests) } -func (us OSVUseCase) processRequests(requests []OSVRequest) dtos.VulnerabilityOutput { +func (us OSVUseCase) processRequests(ctx context.Context, requests []OSVRequest) dtos.VulnerabilityOutput { numJobs := len(requests) jobs := make(chan OSVRequest, numJobs) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + ctx, cancel := context.WithTimeout(ctx, 3*time.Minute) defer cancel() results := make(chan dtos.VulnerabilityComponentOutput, numJobs) workers := min(us.MaxAPIWorkers, numJobs) diff --git a/pkg/usecase/OSV_use_case_test.go b/pkg/usecase/OSV_use_case_test.go index fde298d..0629ac7 100644 --- a/pkg/usecase/OSV_use_case_test.go +++ b/pkg/usecase/OSV_use_case_test.go @@ -61,7 +61,7 @@ func TestOSVUseCase(t *testing.T) { OSVUseCase := NewOSVUseCase(s, serverConfig) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - r := OSVUseCase.Execute(tc.input) + r := OSVUseCase.Execute(ctx, tc.input) if len(r.Components) == 0 { t.Errorf("Expected Purls to have elements, got empty slice") } diff --git a/pkg/usecase/local_use_case.go b/pkg/usecase/local_use_case.go index 58830f8..800ad29 100644 --- a/pkg/usecase/local_use_case.go +++ b/pkg/usecase/local_use_case.go @@ -18,76 +18,108 @@ package usecase import ( "context" - "errors" "fmt" + "scanoss.com/vulnerabilities/pkg/config" + + "go.uber.org/zap" + "github.com/jmoiron/sqlx" - myconfig "scanoss.com/vulnerabilities/pkg/config" "scanoss.com/vulnerabilities/pkg/dtos" "scanoss.com/vulnerabilities/pkg/models" - - zlog "github.com/scanoss/zap-logging-helper/pkg/logger" ) +// LocalVulnerabilityUseCase handles vulnerability lookups against a local database. +// It provides concurrent processing of component vulnerability queries using a worker pool pattern. type LocalVulnerabilityUseCase struct { ctx context.Context - conn *sqlx.Conn vulnsPurl *models.VulnsForPurlModel versionMod *models.VersionModel + s *zap.SugaredLogger + config *config.ServerConfig } // NewLocalVulnerabilitiesUseCase creates a new instance of the vulnerability Use Case. -func NewLocalVulnerabilitiesUseCase(ctx context.Context, conn *sqlx.Conn, config *myconfig.ServerConfig) *LocalVulnerabilityUseCase { - return &LocalVulnerabilityUseCase{ctx: ctx, conn: conn, - vulnsPurl: models.NewVulnsForPurlModel(ctx, conn), - versionMod: models.NewVersionModel(ctx, conn), +func NewLocalVulnerabilitiesUseCase(ctx context.Context, s *zap.SugaredLogger, config *config.ServerConfig, db *sqlx.DB) *LocalVulnerabilityUseCase { + return &LocalVulnerabilityUseCase{ + ctx: ctx, + vulnsPurl: models.NewVulnsForPurlModel(db), + versionMod: models.NewVersionModel(db), + s: s, + config: config, } } -func (d LocalVulnerabilityUseCase) GetVulnerabilities(components []dtos.ComponentDTO) (dtos.VulnerabilityOutput, error) { - var vulnOutputs []dtos.VulnerabilityComponentOutput - var problems = false - for _, c := range components { - if len(c.Purl) == 0 { - zlog.S.Infof("Empty Purl string supplied for: %v. Skipping", c) - continue - } - - // VulnerabilitiesOutput - var item dtos.VulnerabilityComponentOutput - item.Purl = c.Purl - item.Requirement = c.Requirement - item.Version = c.Version - - vulnPurls, err := d.vulnsPurl.GetVulnsByPurl(c.Purl, c.Version) +// vulnerabilityWorker is a worker goroutine that processes component vulnerability lookups. +// It reads components from the jobs channel, queries the local database for vulnerabilities, +// and sends the results to the results channel. The worker terminates when the jobs channel is closed. +func (d *LocalVulnerabilityUseCase) vulnerabilityWorker(ctx context.Context, jobs chan dtos.ComponentDTO, results chan dtos.VulnerabilityComponentOutput) { + for { + select { + case <-ctx.Done(): + d.s.Debugf("Vulnerability worker cancelled: %v", ctx.Err()) + return + case c, ok := <-jobs: + if !ok { + d.s.Debugf("Vulnerability worker channel closed. Exiting.") + return + } + if len(c.Purl) == 0 { + d.s.Debugf("Empty Purl string supplied for: %v. Skipping", c) + results <- dtos.VulnerabilityComponentOutput{} + continue + } + // VulnerabilitiesOutput + var item dtos.VulnerabilityComponentOutput + item.Purl = c.Purl + item.Requirement = c.Requirement + item.Version = c.Version + vulnPurls, err := d.vulnsPurl.GetVulnsByPurl(ctx, c.Purl, c.Version) + if err != nil { + d.s.Errorf("Problem encountered extracting vulnerabilities for: %v - %v.", c, err) + results <- item + continue + } + for _, cve := range vulnPurls { + var vulnerabilitiesForThisPurl dtos.VulnerabilitiesOutput + vulnerabilitiesForThisPurl.ID = cve.Cve + vulnerabilitiesForThisPurl.Cve = cve.Cve + vulnerabilitiesForThisPurl.Severity = cve.Severity + vulnerabilitiesForThisPurl.Modified = cve.Modified + vulnerabilitiesForThisPurl.Published = cve.Published + vulnerabilitiesForThisPurl.Summary = cve.Summary + vulnerabilitiesForThisPurl.URL = fmt.Sprintf("https://nvd.nist.gov/vuln/detail/%s", cve.Cve) - if err != nil { - zlog.S.Errorf("Problem encountered extracting vulnerabilities for: %v - %v.", c, err) - problems = true - continue + vulnerabilitiesForThisPurl.Source = "NVD" + item.Vulnerabilities = append(item.Vulnerabilities, vulnerabilitiesForThisPurl) + } + results <- item } - - for _, cve := range vulnPurls { - var vulnerabilitiesForThisPurl dtos.VulnerabilitiesOutput - vulnerabilitiesForThisPurl.ID = cve.Cve - vulnerabilitiesForThisPurl.Cve = cve.Cve - vulnerabilitiesForThisPurl.Severity = cve.Severity - vulnerabilitiesForThisPurl.Modified = cve.Modified - vulnerabilitiesForThisPurl.Published = cve.Published - vulnerabilitiesForThisPurl.Summary = cve.Summary - vulnerabilitiesForThisPurl.URL = fmt.Sprintf("https://nvd.nist.gov/vuln/detail/%s", cve.Cve) - - vulnerabilitiesForThisPurl.Source = "NVD" - item.Vulnerabilities = append(item.Vulnerabilities, vulnerabilitiesForThisPurl) - } - - vulnOutputs = append(vulnOutputs, item) } +} - if problems { - zlog.S.Errorf("Encountered issues while processing vulnerabilities: %v", components) - return dtos.VulnerabilityOutput{}, errors.New("encountered issues while processing vulnerabilities") +// GetVulnerabilities retrieves vulnerabilities for a list of components from the local database. +// It spawns a pool of workers (up to MaxWorkers) to process requests concurrently and returns +// aggregated vulnerability information for all components. +func (d *LocalVulnerabilityUseCase) GetVulnerabilities(ctx context.Context, components []dtos.ComponentDTO) (dtos.VulnerabilityOutput, error) { + numJobs := len(components) + jobs := make(chan dtos.ComponentDTO, numJobs) + results := make(chan dtos.VulnerabilityComponentOutput, numJobs) + numWorkers := min(d.config.Source.SCANOSS.MaxWorkers, numJobs) + for i := 0; i < numWorkers; i++ { + go d.vulnerabilityWorker(ctx, jobs, results) + } + for _, component := range components { + jobs <- component + } + close(jobs) + var vulnOutputs = make([]dtos.VulnerabilityComponentOutput, numJobs) + for i := 0; i < numJobs; i++ { + select { + case <-ctx.Done(): + return dtos.VulnerabilityOutput{Components: vulnOutputs}, ctx.Err() + case vulnOutputs[i] = <-results: + } } - return dtos.VulnerabilityOutput{Components: vulnOutputs}, nil } diff --git a/pkg/usecase/local_use_case_test.go b/pkg/usecase/local_use_case_test.go index f714612..165be79 100644 --- a/pkg/usecase/local_use_case_test.go +++ b/pkg/usecase/local_use_case_test.go @@ -21,8 +21,11 @@ import ( "fmt" "testing" + "scanoss.com/vulnerabilities/pkg/config" + + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + zlog "github.com/scanoss/zap-logging-helper/pkg/logger" - myconfig "scanoss.com/vulnerabilities/pkg/config" "scanoss.com/vulnerabilities/pkg/dtos" "github.com/jmoiron/sqlx" @@ -31,12 +34,13 @@ import ( ) func TestGetVulnerabilityUseCase(t *testing.T) { - ctx := context.Background() err := zlog.NewSugaredDevLogger() if err != nil { t.Fatalf("an error '%s' was not expected when opening a sugared logger", err) } defer zlog.SyncZap() + ctx := ctxzap.ToContext(context.Background(), zlog.L) + s := ctxzap.Extract(ctx).Sugar() db, err := sqlx.Connect("sqlite3", ":memory:") if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) @@ -51,6 +55,10 @@ func TestGetVulnerabilityUseCase(t *testing.T) { if err != nil { t.Fatalf("an error '%v' was not expected when loading test data", err) } + serverConfig, err := config.NewServerConfig(nil) + if err != nil { + t.Fatalf("failed to load Config: %v", err) + } components := []dtos.ComponentDTO{ { Purl: "pkg:github/tseliot/screen-resolution-extra", @@ -62,12 +70,8 @@ func TestGetVulnerabilityUseCase(t *testing.T) { Purl: "pkg:github/candlepin/candlepin", }, } - myConfig, err := myconfig.NewServerConfig(nil) - if err != nil { - t.Fatalf("failed to load Config: %v", err) - } - vulnUc := NewLocalVulnerabilitiesUseCase(ctx, conn, myConfig) - vulns, err := vulnUc.GetVulnerabilities(components) + vulnUc := NewLocalVulnerabilitiesUseCase(ctx, s, serverConfig, db) + vulns, err := vulnUc.GetVulnerabilities(ctx, components) if err != nil { t.Fatalf("an error '%s' was not expected when getting vulnerabilities", err) } diff --git a/pkg/usecase/vulnerability_use_case.go b/pkg/usecase/vulnerability_use_case.go index e06cf73..ed61386 100644 --- a/pkg/usecase/vulnerability_use_case.go +++ b/pkg/usecase/vulnerability_use_case.go @@ -19,6 +19,7 @@ package usecase import ( "context" "errors" + "sync" "github.com/jmoiron/sqlx" "github.com/scanoss/go-models/pkg/scanoss" @@ -78,24 +79,40 @@ func (us VulnerabilityUseCase) Execute(ctx context.Context, components []dtos.Co } } + wg := sync.WaitGroup{} // Gets OSV vulnerabilities only if enabled var osvVulnerabilities = dtos.VulnerabilityOutput{} if us.config.Source.OSV.Enabled { - us.s.Debugf("vulnerabilities: OSV enabled") - osvUseCase := NewOSVUseCase(us.s, us.config) - osvVulnerabilities = osvUseCase.Execute(components) + wg.Add(1) + go func() { + defer wg.Done() + us.s.Debugf("vulnerabilities: OSV enabled") + osvUseCase := NewOSVUseCase(us.s, us.config) + osvVulnerabilities = osvUseCase.Execute(ctx, components) + }() } // ************* OSV Use case end *************** / // Search the KB for information about each Vulnerability var localVulnerabilities = dtos.VulnerabilityOutput{} + var localErr error if us.config.Source.SCANOSS.Enabled { - localVulUc := NewLocalVulnerabilitiesUseCase(ctx, conn, us.config) - localVulnerabilities, err = localVulUc.GetVulnerabilities(components) - if err != nil { - us.s.Errorf("Failed to get Vulnerabilities: %v", err) - return dtos.VulnerabilityOutput{}, errors.New("problems encountered extracting vulnerability data") - } + wg.Add(1) + go func() { + defer wg.Done() + localVulUc := NewLocalVulnerabilitiesUseCase(ctx, us.s, us.config, us.db) + localVulnerabilities, err = localVulUc.GetVulnerabilities(ctx, components) + if err != nil { + us.s.Errorf("Failed to get Vulnerabilities: %v", err) + localErr = errors.New("problems encountered extracting vulnerability data") + return + } + }() } + wg.Wait() + if localErr != nil { + return dtos.VulnerabilityOutput{}, localErr + } + // Merge OSV and local vulnerabilities in one response. Avoids duplicated vulnerabilities := helpers.MergeOSVAndLocalVulnerabilities(localVulnerabilities, osvVulnerabilities) // Add EPSS data