From 14279df0d0c60ee7121d93b253a1d3e8d80111aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20Sch=C3=BCnemann?= Date: Fri, 14 Jan 2022 14:36:06 +0100 Subject: [PATCH] feat(dns_sd): allow multiple services of the same service type e.g. "_http._tcp" for one host --- dns_sd/dns_sd.c | 192 +++++++++++++++++++----------------------------- dns_sd/dns_sd.h | 11 ++- 2 files changed, 82 insertions(+), 121 deletions(-) diff --git a/dns_sd/dns_sd.c b/dns_sd/dns_sd.c index 5cec1992..db9fe6c7 100644 --- a/dns_sd/dns_sd.c +++ b/dns_sd/dns_sd.c @@ -203,54 +203,6 @@ MdnsState dnsSdGetState(DnsSdContext *context) } -/** - * @brief Set service instance name - * @param[in] context Pointer to the DNS-SD context - * @param[in] instanceName NULL-terminated string that contains the service - * instance name - * @return Error code - **/ - -error_t dnsSdSetInstanceName(DnsSdContext *context, const char_t *instanceName) -{ - NetInterface *interface; - - //Check parameters - if(context == NULL || instanceName == NULL) - return ERROR_INVALID_PARAMETER; - - //Get exclusive access - osAcquireMutex(&netMutex); - - //Point to the underlying network interface - interface = context->settings.interface; - - //Any registered services? - if(dnsSdGetNumServices(context) > 0) - { - //Check whether the link is up - if(interface->linkState) - { - //Send a goodbye packet - dnsSdSendGoodbye(context, NULL); - } - } - - //Set instance name - strSafeCopy(context->instanceName, instanceName, - DNS_SD_MAX_INSTANCE_NAME_LEN); - - //Restart probing process - dnsSdStartProbing(context); - - //Release exclusive access - osReleaseMutex(&netMutex); - - //Successful processing - return NO_ERROR; -} - - /** * @brief Register a DNS-SD service * @param[in] context Pointer to the DNS-SD context @@ -265,6 +217,7 @@ error_t dnsSdSetInstanceName(DnsSdContext *context, const char_t *instanceName) **/ error_t dnsSdRegisterService(DnsSdContext *context, const char_t *serviceName, + const char_t *serviceType, uint16_t priority, uint16_t weight, uint16_t port, const char_t *metadata) { error_t error; @@ -276,7 +229,7 @@ error_t dnsSdRegisterService(DnsSdContext *context, const char_t *serviceName, DnsSdService *firstFreeEntry; //Check parameters - if(context == NULL || serviceName == NULL || metadata == NULL) + if(context == NULL || serviceName == NULL || serviceType == NULL || metadata == NULL) return ERROR_INVALID_PARAMETER; //Get exclusive access @@ -292,10 +245,11 @@ error_t dnsSdRegisterService(DnsSdContext *context, const char_t *serviceName, entry = &context->serviceList[i]; //Check if the entry is currently in use - if(entry->name[0] != '\0') + if(entry->serviceName[0] != '\0') { //Check whether the specified service is already registered - if(!osStrcasecmp(entry->name, serviceName)) + if(!osStrcasecmp(entry->serviceName, serviceName) && + !osStrcasecmp(entry->serviceType, serviceType)) break; } else @@ -315,7 +269,8 @@ error_t dnsSdRegisterService(DnsSdContext *context, const char_t *serviceName, if(entry != NULL) { //Service name - strSafeCopy(entry->name, serviceName, DNS_SD_MAX_SERVICE_NAME_LEN); + strSafeCopy(entry->serviceName, serviceName, DNS_SD_MAX_INSTANCE_NAME_LEN); + strSafeCopy(entry->serviceType, serviceType, DNS_SD_MAX_SERVICE_NAME_LEN); //Priority field entry->priority = priority; @@ -405,13 +360,14 @@ error_t dnsSdRegisterService(DnsSdContext *context, const char_t *serviceName, * @return Error code **/ -error_t dnsSdUnregisterService(DnsSdContext *context, const char_t *serviceName) +error_t dnsSdUnregisterService(DnsSdContext *context, const char_t *serviceName, + const char_t *serviceType) { uint_t i; DnsSdService *entry; //Check parameters - if(context == NULL || serviceName == NULL) + if(context == NULL || serviceName == NULL || serviceType == NULL) return ERROR_INVALID_PARAMETER; //Get exclusive access @@ -424,12 +380,13 @@ error_t dnsSdUnregisterService(DnsSdContext *context, const char_t *serviceName) entry = &context->serviceList[i]; //Service name found? - if(!osStrcasecmp(entry->name, serviceName)) + if(!osStrcasecmp(entry->serviceName, serviceName) && + !osStrcasecmp(entry->serviceType, serviceType)) { //Send a goodbye packet dnsSdSendGoodbye(context, entry); //Remove the service from the list - entry->name[0] = '\0'; + entry->serviceName[0] = '\0'; } } @@ -458,17 +415,13 @@ uint_t dnsSdGetNumServices(DnsSdContext *context) //Check parameter if(context != NULL) { - //Valid instance name? - if(context->instanceName[0] != '\0') - { //Loop through the list of registered services for(i = 0; i < DNS_SD_SERVICE_LIST_SIZE; i++) { //Check if the entry is currently in use - if(context->serviceList[i].name[0] != '\0') + if(context->serviceList[i].serviceName[0] != '\0') n++; } - } } //Return the number of registered services @@ -547,8 +500,17 @@ void dnsSdTick(DnsSdContext *context) //Probing failed? if(context->conflict && context->retransmitCount > 0) { - //Programmatically change the service instance name - dnsSdChangeInstanceName(context); + //Programmatically change the service names + uint_t i; + for(i = 0; i < DNS_SD_SERVICE_LIST_SIZE; i++) + { + //Valid service? + if(context->serviceList[i].serviceName[0] != '\0') + { + dnsSdChangeServiceName(&(context->serviceList[i])); + } + } + //Probe again, and repeat as necessary until a unique name is found dnsSdChangeState(context, MDNS_STATE_PROBING, 0); } @@ -720,7 +682,7 @@ void dnsSdChangeState(DnsSdContext *context, * @param[in] context Pointer to the DNS-SD context **/ -void dnsSdChangeInstanceName(DnsSdContext *context) +void dnsSdChangeServiceName(DnsSdService *service) { size_t i; size_t m; @@ -729,7 +691,7 @@ void dnsSdChangeInstanceName(DnsSdContext *context) char_t s[16]; //Retrieve the length of the string - n = osStrlen(context->instanceName); + n = osStrlen(service->serviceName); //Parse the string backwards for(i = n; i > 0; i--) @@ -738,22 +700,22 @@ void dnsSdChangeInstanceName(DnsSdContext *context) if(i == n) { //Check whether the last character is a bracket - if(context->instanceName[i - 1] != ')') + if(service->serviceName[i - 1] != ')') break; } else { //Check whether the current character is a digit - if(!osIsdigit(context->instanceName[i - 1])) + if(!osIsdigit(service->serviceName[i - 1])) break; } } //Any number following the service instance name? - if(context->instanceName[i] != '\0') + if(service->serviceName[i] != '\0') { //Retrieve the number at the end of the name - index = atoi(context->instanceName + i); + index = atoi(service->serviceName + i); //Increment the value index++; @@ -761,15 +723,15 @@ void dnsSdChangeInstanceName(DnsSdContext *context) if(i >= 2) { //Discard any space and bracket that may precede the number - if(context->instanceName[i - 2] == ' ' && - context->instanceName[i - 1] == '(') + if(service->serviceName[i - 2] == ' ' && + service->serviceName[i - 1] == '(') { i -= 2; } } //Strip the digits - context->instanceName[i] = '\0'; + service->serviceName[i] = '\0'; } else { @@ -784,7 +746,7 @@ void dnsSdChangeInstanceName(DnsSdContext *context) if((i + m) <= DNS_SD_MAX_INSTANCE_NAME_LEN) { //Programmatically change the service instance name - osStrcat(context->instanceName, s); + osStrcat(service->serviceName, s); } } @@ -828,10 +790,10 @@ error_t dnsSdSendProbe(DnsSdContext *context) service = &context->serviceList[i]; //Valid service? - if(service->name[0] != '\0') + if(service->serviceName[0] != '\0') { //Encode the service name using DNS notation - message.length += mdnsEncodeName(context->instanceName, service->name, + message.length += mdnsEncodeName(service->serviceName, service->serviceType, ".local", (uint8_t *) message.dnsHeader + message.length); //Point to the corresponding question structure @@ -863,7 +825,7 @@ error_t dnsSdSendProbe(DnsSdContext *context) service = &context->serviceList[i]; //Valid service? - if(service->name[0] != '\0') + if(service->serviceName[0] != '\0') { //Format SRV resource record error = dnsSdAddSrvRecord(interface, &message, @@ -942,7 +904,7 @@ error_t dnsSdSendAnnouncement(DnsSdContext *context) service = &context->serviceList[i]; //Valid service? - if(service->name[0] != '\0') + if(service->serviceName[0] != '\0') { //Format PTR resource record (service type enumeration) error = dnsSdAddServiceEnumPtrRecord(interface, @@ -1024,7 +986,7 @@ error_t dnsSdSendGoodbye(DnsSdContext *context, const DnsSdService *service) entry = &context->serviceList[i]; //Valid service? - if(entry->name[0] != '\0') + if(entry->serviceName[0] != '\0') { if(service == entry || service == NULL) { @@ -1144,7 +1106,7 @@ error_t dnsSdParseQuestion(NetInterface *interface, const MdnsMessage *query, service = &context->serviceList[i]; //Valid service? - if(service->name[0] != '\0') + if(service->serviceName[0] != '\0') { //Check the class of the query if(qclass == DNS_RR_CLASS_IN || qclass == DNS_RR_CLASS_ANY) @@ -1168,7 +1130,7 @@ error_t dnsSdParseQuestion(NetInterface *interface, const MdnsMessage *query, } } else if(!mdnsCompareName(query->dnsHeader, query->length, - offset, "", service->name, ".local", 0)) + offset, "", service->serviceType, ".local", 0)) { //PTR query? if(qtype == DNS_RR_TYPE_PTR || qtype == DNS_RR_TYPE_ANY) @@ -1185,7 +1147,7 @@ error_t dnsSdParseQuestion(NetInterface *interface, const MdnsMessage *query, } } else if(!mdnsCompareName(query->dnsHeader, query->length, offset, - context->instanceName, service->name, ".local", 0)) + service->serviceName, service->serviceType, ".local", 0)) { //SRV query? if(qtype == DNS_RR_TYPE_SRV || qtype == DNS_RR_TYPE_ANY) @@ -1264,11 +1226,11 @@ void dnsSdParseNsRecord(NetInterface *interface, const MdnsMessage *query, service = &context->serviceList[i]; //Valid service? - if(service->name[0] != '\0') + if(service->serviceName[0] != '\0') { //Apply tie-breaking rules if(!mdnsCompareName(query->dnsHeader, query->length, offset, - context->instanceName, service->name, ".local", 0)) + service->serviceName, service->serviceType, ".local", 0)) { //Convert the class to host byte order rclass = ntohs(record->rclass); @@ -1309,7 +1271,7 @@ void dnsSdParseNsRecord(NetInterface *interface, const MdnsMessage *query, offset = srvRecord->target - (uint8_t *) query->dnsHeader; if(mdnsCompareName(query->dnsHeader, query->length, offset, - context->instanceName, "", ".local", 0) > 0) + service->serviceName, "", ".local", 0) > 0) { //The host has lost the tie-break context->tieBreakLost = TRUE; @@ -1358,11 +1320,11 @@ void dnsSdParseAnRecord(NetInterface *interface, const MdnsMessage *response, service = &context->serviceList[i]; //Valid service? - if(service->name[0] != '\0') + if(service->serviceName[0] != '\0') { //Check for conflicts if(!mdnsCompareName(response->dnsHeader, response->length, offset, - context->instanceName, service->name, ".local", 0)) + service->serviceName, service->serviceType, ".local", 0)) { //Convert the class to host byte order rclass = ntohs(record->rclass); @@ -1383,7 +1345,7 @@ void dnsSdParseAnRecord(NetInterface *interface, const MdnsMessage *response, //response message containing a record with the same name, rrtype //and rrclass, but inconsistent rdata if(mdnsCompareName(response->dnsHeader, response->length, offset, - context->instanceName, "", ".local", 0)) + service->serviceName, "", ".local", 0)) { //The service instance name is already in use by some other host context->conflict = TRUE; @@ -1492,7 +1454,7 @@ void dnsSdGenerateAdditionalRecords(NetInterface *interface, service = &context->serviceList[j]; //Valid service? - if(service->name[0] != '\0') + if(service->serviceName[0] != '\0') { //Check the class of the resource record if(rclass == DNS_RR_CLASS_IN) @@ -1502,7 +1464,7 @@ void dnsSdGenerateAdditionalRecords(NetInterface *interface, { //Compare service name if(!mdnsCompareName(response->dnsHeader, response->length, - offset, "", service->name, ".local", 0)) + offset, "", service->serviceType, ".local", 0)) { //Format SRV resource record error = dnsSdAddSrvRecord(interface, @@ -1524,7 +1486,7 @@ void dnsSdGenerateAdditionalRecords(NetInterface *interface, { //Compare service name if(!mdnsCompareName(response->dnsHeader, response->length, - offset, context->instanceName, service->name, ".local", 0)) + offset, service->serviceName, service->serviceType, ".local", 0)) { //Format TXT resource record error = dnsSdAddTxtRecord(interface, @@ -1595,14 +1557,14 @@ error_t dnsSdAddServiceEnumPtrRecord(NetInterface *interface, offset += sizeof(DnsResourceRecord); //The first pass calculates the length of the DNS encoded service name - n = mdnsEncodeName("", service->name, ".local", NULL); + n = mdnsEncodeName("", service->serviceType, ".local", NULL); //Check the length of the resulting mDNS message if((offset + n) > MDNS_MESSAGE_MAX_SIZE) return ERROR_MESSAGE_TOO_LONG; //The second pass encodes the service name using DNS notation - n = mdnsEncodeName("", service->name, + n = mdnsEncodeName("", service->serviceType, ".local", record->rdata); //Convert length field to network byte order @@ -1641,8 +1603,8 @@ error_t dnsSdAddPtrRecord(NetInterface *interface, //Check whether the resource record is already present in the Answer //Section of the message - duplicate = mdnsCheckDuplicateRecord(message, "", - service->name, ".local", DNS_RR_TYPE_PTR); + duplicate = mdnsCheckDuplicateRecord(message, service->serviceName, + service->serviceType, ".local", DNS_RR_TYPE_PTR); //The duplicates should be suppressed and the resource record should //appear only once in the list @@ -1652,14 +1614,14 @@ error_t dnsSdAddPtrRecord(NetInterface *interface, offset = message->length; //The first pass calculates the length of the DNS encoded service name - n = mdnsEncodeName("", service->name, ".local", NULL); + n = mdnsEncodeName("", service->serviceType, ".local", NULL); //Check the length of the resulting mDNS message if((offset + n) > MDNS_MESSAGE_MAX_SIZE) return ERROR_MESSAGE_TOO_LONG; //Encode the service name using the DNS name notation - offset += mdnsEncodeName("", service->name, + offset += mdnsEncodeName("", service->serviceType, ".local", (uint8_t *) message->dnsHeader + offset); //Consider the length of the resource record itself @@ -1678,15 +1640,15 @@ error_t dnsSdAddPtrRecord(NetInterface *interface, offset += sizeof(DnsResourceRecord); //The first pass calculates the length of the DNS encoded instance name - n = mdnsEncodeName(context->instanceName, service->name, ".local", NULL); + n = mdnsEncodeName(service->serviceName, service->serviceType, ".local", NULL); //Check the length of the resulting mDNS message if((offset + n) > MDNS_MESSAGE_MAX_SIZE) return ERROR_MESSAGE_TOO_LONG; //The second pass encodes the instance name using DNS notation - n = mdnsEncodeName(context->instanceName, - service->name, ".local", record->rdata); + n = mdnsEncodeName(service->serviceName, + service->serviceType, ".local", record->rdata); //Convert length field to network byte order record->rdlength = htons(n); @@ -1729,8 +1691,8 @@ error_t dnsSdAddSrvRecord(NetInterface *interface, MdnsMessage *message, //Check whether the resource record is already present in the Answer //Section of the message - duplicate = mdnsCheckDuplicateRecord(message, dnsSdContext->instanceName, - service->name, ".local", DNS_RR_TYPE_SRV); + duplicate = mdnsCheckDuplicateRecord(message, service->serviceName, + service->serviceType, ".local", DNS_RR_TYPE_SRV); //The duplicates should be suppressed and the resource record should //appear only once in the list @@ -1740,16 +1702,16 @@ error_t dnsSdAddSrvRecord(NetInterface *interface, MdnsMessage *message, offset = message->length; //The first pass calculates the length of the DNS encoded instance name - n = mdnsEncodeName(dnsSdContext->instanceName, - service->name, ".local", NULL); + n = mdnsEncodeName(service->serviceName, + service->serviceType, ".local", NULL); //Check the length of the resulting mDNS message if((offset + n) > MDNS_MESSAGE_MAX_SIZE) return ERROR_MESSAGE_TOO_LONG; //The second pass encodes the instance name using DNS notation - offset += mdnsEncodeName(dnsSdContext->instanceName, - service->name, ".local", (uint8_t *) message->dnsHeader + offset); + offset += mdnsEncodeName(service->serviceName, + service->serviceType, ".local", (uint8_t *) message->dnsHeader + offset); //Consider the length of the resource record itself if((offset + sizeof(DnsSrvResourceRecord)) > MDNS_MESSAGE_MAX_SIZE) @@ -1824,8 +1786,8 @@ error_t dnsSdAddTxtRecord(NetInterface *interface, MdnsMessage *message, //Check whether the resource record is already present in the Answer //Section of the message - duplicate = mdnsCheckDuplicateRecord(message, context->instanceName, - service->name, ".local", DNS_RR_TYPE_TXT); + duplicate = mdnsCheckDuplicateRecord(message, service->serviceName, + service->serviceType, ".local", DNS_RR_TYPE_TXT); //The duplicates should be suppressed and the resource record should //appear only once in the list @@ -1835,15 +1797,15 @@ error_t dnsSdAddTxtRecord(NetInterface *interface, MdnsMessage *message, offset = message->length; //The first pass calculates the length of the DNS encoded instance name - n = mdnsEncodeName(context->instanceName, service->name, ".local", NULL); + n = mdnsEncodeName(service->serviceName, service->serviceType, ".local", NULL); //Check the length of the resulting mDNS message if((offset + n) > MDNS_MESSAGE_MAX_SIZE) return ERROR_MESSAGE_TOO_LONG; //The second pass encodes the instance name using DNS notation - offset += mdnsEncodeName(context->instanceName, - service->name, ".local", (uint8_t *) message->dnsHeader + offset); + offset += mdnsEncodeName(service->serviceName, + service->serviceType, ".local", (uint8_t *) message->dnsHeader + offset); //Consider the length of the resource record itself if((offset + sizeof(DnsResourceRecord)) > MDNS_MESSAGE_MAX_SIZE) @@ -1909,8 +1871,8 @@ error_t dnsSdAddNsecRecord(NetInterface *interface, MdnsMessage *message, //Check whether the resource record is already present in the Answer //Section of the message - duplicate = mdnsCheckDuplicateRecord(message, context->instanceName, - service->name, ".local", DNS_RR_TYPE_NSEC); + duplicate = mdnsCheckDuplicateRecord(message, service->serviceName, + service->serviceType, ".local", DNS_RR_TYPE_NSEC); //The duplicates should be suppressed and the resource record should //appear only once in the list @@ -1936,14 +1898,14 @@ error_t dnsSdAddNsecRecord(NetInterface *interface, MdnsMessage *message, offset = message->length; //The first pass calculates the length of the DNS encoded instance name - n = mdnsEncodeName(context->instanceName, service->name, ".local", NULL); + n = mdnsEncodeName(service->serviceName, service->serviceType, ".local", NULL); //Check the length of the resulting mDNS message if((offset + n) > MDNS_MESSAGE_MAX_SIZE) return ERROR_MESSAGE_TOO_LONG; //The second pass encodes the instance name using the DNS name notation - offset += mdnsEncodeName(context->instanceName, service->name, + offset += mdnsEncodeName(service->serviceName, service->serviceType, ".local", (uint8_t *) message->dnsHeader + offset); //Consider the length of the resource record itself @@ -1970,7 +1932,7 @@ error_t dnsSdAddNsecRecord(NetInterface *interface, MdnsMessage *message, return ERROR_MESSAGE_TOO_LONG; //The Next Domain Name field contains the record's own name - mdnsEncodeName(context->instanceName, service->name, + mdnsEncodeName(service->serviceName, service->serviceType, ".local", record->rdata); //DNS NSEC record is limited to Window Block number zero diff --git a/dns_sd/dns_sd.h b/dns_sd/dns_sd.h index 94ac318f..9e2c4d0e 100644 --- a/dns_sd/dns_sd.h +++ b/dns_sd/dns_sd.h @@ -122,7 +122,8 @@ typedef struct typedef struct { - char_t name[DNS_SD_MAX_SERVICE_NAME_LEN + 1]; ///