Skip to content
Open
2 changes: 2 additions & 0 deletions api/adscert.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ message RequestInfo {
bytes url_hash = 2;
bytes body_hash = 3;
repeated SignatureInfo signature_info = 4;
// useful if 1 signatory is managing multiple origin domains such as in resellers case.
string origin_domain = 5;
}

// SignatureInfo captures the signature generated for the signing request. It
Expand Down
34 changes: 26 additions & 8 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"flag"
"strings"
"time"

"github.com/IABTechLab/adscert/internal/server"
Expand All @@ -17,22 +18,40 @@ var (
origin = flag.String("origin", utils.GetEnvVarString("ORIGIN", ""), "ads.cert Call Sign domain name for this party's Signatory service deployment")
domainCheckInterval = flag.Duration("domain_check_interval", time.Duration(utils.GetEnvVarInt("DOMAIN_CHECK_INTERVAL", 30))*time.Second, "interval for checking domain records")
domainRenewalInterval = flag.Duration("domain_renewal_interval", time.Duration(utils.GetEnvVarInt("DOMAIN_RENEWAL_INTERVAL", 300))*time.Second, "interval before considering domain records for renewal")
privateKey = flag.String("private_key", utils.GetEnvVarString("PRIVATE_KEY", ""), "base-64 encoded private key")
)

type privateKeyFlags []string

func (i *privateKeyFlags) String() string {
return strings.Join(*i, ",")
}

func (i *privateKeyFlags) Set(value string) error {
if value != "" {
for _, v := range strings.Split(value, ",") {
*i = append(*i, v)
}
}
return nil
}

func main() {
var privateKeys privateKeyFlags
flag.Var(&privateKeys, "private_key", "base-64 encoded private key")

if value := utils.GetEnvVarString("PRIVATE_KEY", ""); value != "" {
for _, k := range strings.Split(value, ",") {
privateKeys = append(privateKeys, k)
}
}

flag.Parse()

parsedLogLevel := logger.GetLevelFromString(*logLevel)
logger.SetLevel(parsedLogLevel)
logger.Infof("Log Level: %s, parsed as iota %v", *logLevel, parsedLogLevel)

if *origin == "" {
logger.Fatalf("Origin ads.cert Call Sign domain name is required")
}

if *privateKey == "" {
if len(privateKeys) == 0 {
logger.Fatalf("Private key is required")
}

Expand All @@ -45,11 +64,10 @@ func main() {
}()

logger.Infof("Starting AdsCert API server")
logger.Infof("Origin ads.cert Call Sign domain: %v", *origin)
logger.Infof("Port: %v", *serverPort)

grpcServer := grpc.NewServer()
server.SetUpAdsCertSignatoryServer(grpcServer, *origin, *domainCheckInterval, *domainRenewalInterval, []string{*privateKey})
server.SetUpAdsCertSignatoryServer(grpcServer, *origin, *domainCheckInterval, *domainRenewalInterval, privateKeys)
if err := server.StartServingRequests(grpcServer, *serverPort); err != nil {
logger.Fatalf("gRPC server failure: %v", err)
}
Expand Down
4 changes: 4 additions & 0 deletions examples/signer-client/signer-client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

var (
serverAddress = flag.String("server_address", "localhost:3000", "address of grpc server")
originDomain = flag.String("origin_domain", "", "Origin domain")
destinationURL = flag.String("url", "https://google.com/gen_204", "URL to invoke")
body = flag.String("body", "", "POST request body")
signingTimeout = flag.Duration("signing_timeout", 5*time.Millisecond, "Specifies how long this client will wait for signing to finish before abandoning.")
Expand Down Expand Up @@ -49,6 +50,9 @@ func main() {
// destination URL and body, setting these value on the RequestInfo message.
reqInfo := &api.RequestInfo{}
signatory.SetRequestInfo(reqInfo, *destinationURL, []byte(*body))
if originDomain != nil {
reqInfo.OriginDomain = *originDomain
}

// Request the signature.
logger.Infof("signing request for url: %v", *destinationURL)
Expand Down
4 changes: 4 additions & 0 deletions internal/server/server_reference_implementation.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import (
"fmt"
"net"
"net/http"
"strings"
"time"

"github.com/IABTechLab/adscert/pkg/adscert/api"
"github.com/IABTechLab/adscert/pkg/adscert/discovery"
"github.com/IABTechLab/adscert/pkg/adscert/logger"
"github.com/IABTechLab/adscert/pkg/adscert/metrics"
"github.com/IABTechLab/adscert/pkg/adscert/server"
"github.com/IABTechLab/adscert/pkg/adscert/signatory"
Expand All @@ -31,6 +33,8 @@ func SetUpAdsCertSignatoryServer(grpcServer *grpc.Server, adscertCallSign string
domainRenewalInterval,
privateKeys)

logger.Debugf("Origin ads.cert Call Sign domains: %v", strings.Join(signatoryApi.GetOriginCallsigns(), ","))

handler := &server.AdsCertSignatoryServer{
SignatoryAPI: signatoryApi,
}
Expand Down
305 changes: 158 additions & 147 deletions pkg/adscert/api/adscert.pb.go

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pkg/adscert/api/adscert_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pkg/adscert/discovery/domain_indexer_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ import "time"
type DomainIndexer interface {
LookupIdentitiesForDomain(domain string) ([]DomainInfo, error)
GetLastRun() time.Time
GetOriginCallsigns() []string
}
50 changes: 31 additions & 19 deletions pkg/adscert/discovery/domain_indexer_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func NewDefaultDomainIndexer(dnsResolver DNSResolver, domainStore DomainStore, d
domainRenewalInterval: domainRenewalInterval,
dnsResolver: dnsResolver,
domainStore: domainStore,
currentPrivateKey: make(map[string]keyAlias),
}

myPrivateKeys, err := privateKeysToKeyMap(base64PrivateKeys)
Expand All @@ -39,12 +40,14 @@ func NewDefaultDomainIndexer(dnsResolver DNSResolver, domainStore DomainStore, d
}
di.myPrivateKeys = myPrivateKeys

for _, privateKey := range di.myPrivateKeys {
// since iterating over a map is non-deterministic, we can make sure to set the key
// either if it is not already set or it is alphabetically less than current key at the index when
// iterating over the private keys map.
if di.currentPrivateKey == "" || di.currentPrivateKey < privateKey.alias {
di.currentPrivateKey = privateKey.alias
for originCallsign := range di.myPrivateKeys {
for _, privateKey := range di.myPrivateKeys[originCallsign] {
// since iterating over a map is non-deterministic, we can make sure to set the key
// either if it is not already set or it is alphabetically less than current key at the index when
// iterating over the private keys map.
if di.currentPrivateKey[originCallsign] == "" || di.currentPrivateKey[originCallsign] < privateKey.alias {
di.currentPrivateKey[originCallsign] = privateKey.alias
}
}
}

Expand All @@ -62,13 +65,21 @@ type defaultDomainIndexer struct {
lastRun time.Time
lastRunLock sync.RWMutex

myPrivateKeys keyMap
currentPrivateKey keyAlias
myPrivateKeys map[string]keyMap
currentPrivateKey map[string]keyAlias

dnsResolver DNSResolver
domainStore DomainStore
}

func (di *defaultDomainIndexer) GetOriginCallsigns() []string {
var originCallsigns []string
for oc := range di.myPrivateKeys {
originCallsigns = append(originCallsigns, oc)
}
return originCallsigns
}

func (di *defaultDomainIndexer) GetLastRun() time.Time {
di.lastRunLock.RLock()
t := di.lastRun
Expand Down Expand Up @@ -227,20 +238,21 @@ func (di *defaultDomainIndexer) checkDomainForKeyRecords(ctx context.Context, cu
}

// create shared secrets for each private key + public key combination
for _, myKey := range di.myPrivateKeys {
for _, theirKey := range currentDomainInfo.allPublicKeys {
keyPairAlias := newKeyPairAlias(myKey.alias, theirKey.alias)
if currentDomainInfo.allSharedSecrets[keyPairAlias] == nil {
currentDomainInfo.allSharedSecrets[keyPairAlias], err = calculateSharedSecret(myKey, theirKey)
if err != nil {
logger.Warningf("error calculating shared secret for record %s: %v", currentDomainInfo.Domain, err)
currentDomainInfo.domainStatus = DomainStatusErrorOnSharedSecretCalculation
for originCallsign := range di.myPrivateKeys {
for _, myKey := range di.myPrivateKeys[originCallsign] {
for _, theirKey := range currentDomainInfo.allPublicKeys {
keyPairAlias := newKeyPairAlias(myKey.alias, theirKey.alias)
if currentDomainInfo.allSharedSecrets[keyPairAlias] == nil {
currentDomainInfo.allSharedSecrets[keyPairAlias], err = calculateSharedSecret(myKey, theirKey)
if err != nil {
logger.Warningf("error calculating shared secret for record %s: %v", currentDomainInfo.Domain, err)
currentDomainInfo.domainStatus = DomainStatusErrorOnSharedSecretCalculation
}
}
}
}
currentDomainInfo.currentSharedSecretId[originCallsign] = newKeyPairAlias(di.currentPrivateKey[originCallsign], currentDomainInfo.currentPublicKeyId)
}

currentDomainInfo.currentSharedSecretId = newKeyPairAlias(di.currentPrivateKey, currentDomainInfo.currentPublicKeyId)
currentDomainInfo.lastUpdateTime = time.Now()
}

Expand Down Expand Up @@ -312,7 +324,7 @@ func initializeDomainInfo(domain string) DomainInfo {
Domain: domain,
IdentityDomains: []string{},
currentPublicKeyId: "",
currentSharedSecretId: keyPairAlias{},
currentSharedSecretId: map[string]keyPairAlias{},
allPublicKeys: map[keyAlias]*x25519Key{},
allSharedSecrets: keyPairMap{},
domainStatus: DomainStatusNotYetChecked,
Expand Down
6 changes: 3 additions & 3 deletions pkg/adscert/discovery/domain_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ type DomainInfo struct {
Domain string // root domain for this record, can be invoking or identity domain
IdentityDomains []string // used to map from invoking domain to parent identity domains
currentPublicKeyId keyAlias
currentSharedSecretId keyPairAlias
currentSharedSecretId map[string]keyPairAlias
allPublicKeys keyMap
allSharedSecrets keyPairMap

Expand All @@ -30,7 +30,7 @@ func (c *DomainInfo) GetStatus() DomainStatus {
return c.domainStatus
}

func (c *DomainInfo) GetSharedSecret() (SharedSecret, bool) {
sharedSecret, ok := c.allSharedSecrets[c.currentSharedSecretId]
func (c *DomainInfo) GetSharedSecret(originDomain string) (SharedSecret, bool) {
sharedSecret, ok := c.allSharedSecrets[c.currentSharedSecretId[originDomain]]
return sharedSecret, ok
}
23 changes: 17 additions & 6 deletions pkg/adscert/discovery/internal_base_key.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package discovery

import (
"errors"
"fmt"
"strings"

"github.com/IABTechLab/adscert/internal/formats"
"github.com/IABTechLab/adscert/pkg/adscert/logger"
Expand Down Expand Up @@ -76,11 +78,14 @@ func calculateSharedSecret(originPrivateKey *x25519Key, remotePublicKey *x25519K
return result, err
}

func privateKeysToKeyMap(privateKeys []string) (keyMap, error) {
result := keyMap{}

func privateKeysToKeyMap(privateKeys []string) (map[string]keyMap, error) {
results := map[string]keyMap{}
for _, privateKeyBase64 := range privateKeys {
privateKey, err := parseKeyFromString(privateKeyBase64)
sp := strings.SplitN(privateKeyBase64, "=", 2)
if len(sp) < 2 {
return nil, errors.New("missing origin callsign")
}
privateKey, err := parseKeyFromString(sp[1])
if err != nil {
return nil, err
}
Expand All @@ -90,10 +95,16 @@ func privateKeysToKeyMap(privateKeys []string) (keyMap, error) {

keyAlias := keyAlias(formats.ExtractKeyAliasFromPublicKeyBase64(formats.EncodeKeyBase64(publicBytes[:])))
privateKey.alias = keyAlias
result[keyAlias] = privateKey

km := results[sp[0]]
if km == nil {
km = keyMap{}
}
km[keyAlias] = privateKey
results[sp[0]] = km
}

return result, nil
return results, nil
}

func parseKeyFromString(base64EncodedKey string) (*x25519Key, error) {
Expand Down
Loading