diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs index 81b7da3..7e91be9 100644 --- a/src/mcp/mod.rs +++ b/src/mcp/mod.rs @@ -3,7 +3,9 @@ use crate::core::models::SearchOptions; use crate::core::patent_search::{PatentSearch, PatentSearcher}; use rmcp::{ handler::server::{tool::ToolRouter, wrapper::Parameters}, - model::{Implementation, ProtocolVersion, ServerCapabilities, ServerInfo, ToolsCapability}, + model::{ + ErrorCode, Implementation, ProtocolVersion, ServerCapabilities, ServerInfo, ToolsCapability, + }, schemars::{self, JsonSchema}, service::{NotificationContext, RequestContext}, tool, tool_handler, tool_router, ErrorData, RoleServer, ServerHandler, ServiceExt, @@ -69,7 +71,7 @@ impl PatentHandler { pub async fn search_patents( &self, Parameters(request): Parameters, - ) -> String { + ) -> Result { let options = SearchOptions { query: request.query, assignee: request.assignee, @@ -81,10 +83,13 @@ impl PatentHandler { language: request.language, }; - match self.searcher.search(&options).await { - Ok(results) => serde_json::to_string_pretty(&results).unwrap_or_default(), - Err(e) => format!("Search failed: {}", e), - } + self.searcher + .search(&options) + .await + .map(|results| serde_json::to_string_pretty(&results).unwrap_or_default()) + .map_err(|e| { + ErrorData::new(ErrorCode::INTERNAL_ERROR, format!("Search failed: {}", e), None) + }) } /// Fetch details of a specific patent by ID @@ -92,13 +97,18 @@ impl PatentHandler { pub async fn fetch_patent( &self, Parameters(request): Parameters, - ) -> String { + ) -> Result { if request.raw { - match self.searcher.get_raw_html(&request.patent_id, request.language.as_deref()).await - { - Ok(html) => html, - Err(e) => format!("Failed to fetch raw HTML: {}", e), - } + self.searcher + .get_raw_html(&request.patent_id, request.language.as_deref()) + .await + .map_err(|e| { + ErrorData::new( + ErrorCode::INTERNAL_ERROR, + format!("Failed to fetch raw HTML: {}", e), + None, + ) + }) } else { let options = SearchOptions { query: None, @@ -112,10 +122,20 @@ impl PatentHandler { }; match self.searcher.search(&options).await { Ok(mut results) => results.patents.pop().map_or_else( - || format!("No patent found with ID: {}", request.patent_id), - |patent| serde_json::to_string_pretty(&patent).unwrap_or_default(), + || { + Err(ErrorData::new( + ErrorCode::INVALID_PARAMS, + format!("No patent found with ID: {}", request.patent_id), + None, + )) + }, + |patent| Ok(serde_json::to_string_pretty(&patent).unwrap_or_default()), ), - Err(e) => format!("Fetch failed: {}", e), + Err(e) => Err(ErrorData::new( + ErrorCode::INTERNAL_ERROR, + format!("Fetch failed: {}", e), + None, + )), } } } @@ -240,8 +260,10 @@ mod tests { language: None, }; let result = handler.search_patents(Parameters(request)).await; - assert!(result.contains("SEARCH1")); - assert!(result.contains("Search Result")); + assert!(result.is_ok()); + let result_str = result.unwrap(); + assert!(result_str.contains("SEARCH1")); + assert!(result_str.contains("Search Result")); } #[tokio::test] @@ -252,24 +274,30 @@ mod tests { let request = FetchPatentRequest { patent_id: "US123".to_string(), raw: false, language: None }; let result = handler.fetch_patent(Parameters(request)).await; - assert!(result.contains("US123")); + assert!(result.is_ok()); + assert!(result.unwrap().contains("US123")); // Raw HTML case let request = FetchPatentRequest { patent_id: "US123".to_string(), raw: true, language: None }; let result = handler.fetch_patent(Parameters(request)).await; - assert_eq!(result, "US123"); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "US123"); // Not found case let request = FetchPatentRequest { patent_id: "NONE".to_string(), raw: false, language: None }; let result = handler.fetch_patent(Parameters(request)).await; - assert!(result.contains("No patent found")); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.message.contains("No patent found")); // Error case let request = FetchPatentRequest { patent_id: "FAIL".to_string(), raw: false, language: None }; let result = handler.fetch_patent(Parameters(request)).await; - assert!(result.contains("Fetch failed")); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.message.contains("Fetch failed")); } }