diff --git a/base/src/main/java/com/tinyengine/it/common/exception/GlobalExceptionAdvice.java b/base/src/main/java/com/tinyengine/it/common/exception/GlobalExceptionAdvice.java index 1b07b66d..a4e8d514 100644 --- a/base/src/main/java/com/tinyengine/it/common/exception/GlobalExceptionAdvice.java +++ b/base/src/main/java/com/tinyengine/it/common/exception/GlobalExceptionAdvice.java @@ -69,7 +69,7 @@ public Result> handleNullPointerException(HttpServletRequest * @param e the e * @return the result */ - @ResponseStatus(HttpStatus.OK) + @ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR) @ExceptionHandler(ServiceException.class) public Result> handleServiceException(ServiceException e) { // 修改为 log.error,传递异常对象以打印堆栈信息 diff --git a/base/src/main/java/com/tinyengine/it/common/utils/SM4Utils.java b/base/src/main/java/com/tinyengine/it/common/utils/SM4Utils.java new file mode 100644 index 00000000..8c613b64 --- /dev/null +++ b/base/src/main/java/com/tinyengine/it/common/utils/SM4Utils.java @@ -0,0 +1,71 @@ +package com.tinyengine.it.common.utils; + +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import javax.crypto.Cipher; +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.security.SecureRandom; +import java.security.Security; +import java.util.Base64; + +public class SM4Utils { + + static { + Security.addProvider(new BouncyCastleProvider()); + } + + private static final String ALGORITHM = "SM4"; + private static final String TRANSFORMATION_ECB = "SM4/ECB/PKCS5Padding"; + private static final int KEY_SIZE = 128; + + /** + * 生成 SM4 密钥 + */ + public static String generateKeyBase64() throws Exception { + byte[] key = generateKey(); + return Base64.getEncoder().encodeToString(key); + } + + public static byte[] generateKey() throws Exception { + KeyGenerator kg = KeyGenerator.getInstance(ALGORITHM, "BC"); + kg.init(KEY_SIZE, new SecureRandom()); + SecretKey secretKey = kg.generateKey(); + return secretKey.getEncoded(); + } + + /** + * ECB 模式加密 - 只加密API密钥值 (Base64 结果) + */ + public static String encryptECB(String apiKey, String base64Key) throws Exception { + byte[] key = Base64.getDecoder().decode(base64Key); + byte[] encrypted = encryptECB(apiKey.getBytes("UTF-8"), key); + return Base64.getEncoder().encodeToString(encrypted); + } + + /** + * ECB 模式解密 - 直接返回API密钥 + */ + public static String decryptECB(String encryptedBase64, String base64Key) throws Exception { + byte[] key = Base64.getDecoder().decode(base64Key); + byte[] encrypted = Base64.getDecoder().decode(encryptedBase64); + byte[] decrypted = decryptECB(encrypted, key); + return new String(decrypted, "UTF-8"); + } + + // ECB 模式的底层方法保持不变 + private static byte[] encryptECB(byte[] data, byte[] key) throws Exception { + SecretKeySpec secretKeySpec = new SecretKeySpec(key, ALGORITHM); + Cipher cipher = Cipher.getInstance(TRANSFORMATION_ECB, "BC"); + cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec); + return cipher.doFinal(data); + } + + private static byte[] decryptECB(byte[] encryptedData, byte[] key) throws Exception { + SecretKeySpec secretKeySpec = new SecretKeySpec(key, ALGORITHM); + Cipher cipher = Cipher.getInstance(TRANSFORMATION_ECB, "BC"); + cipher.init(Cipher.DECRYPT_MODE, secretKeySpec); + return cipher.doFinal(encryptedData); + } + +} diff --git a/base/src/main/java/com/tinyengine/it/controller/AiChatController.java b/base/src/main/java/com/tinyengine/it/controller/AiChatController.java index 2b6eb3e9..e0086c78 100644 --- a/base/src/main/java/com/tinyengine/it/controller/AiChatController.java +++ b/base/src/main/java/com/tinyengine/it/controller/AiChatController.java @@ -12,7 +12,10 @@ package com.tinyengine.it.controller; +import com.tinyengine.it.common.base.Result; +import com.tinyengine.it.common.exception.ExceptionEnum; import com.tinyengine.it.common.log.SystemControllerLog; +import com.tinyengine.it.model.dto.AiToken; import com.tinyengine.it.model.dto.ChatRequest; import com.tinyengine.it.service.app.v1.AiChatV1Service; @@ -24,7 +27,6 @@ import io.swagger.v3.oas.annotations.tags.Tag; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.validation.annotation.Validated; @@ -50,6 +52,7 @@ public class AiChatController { */ @Autowired private AiChatV1Service aiChatV1Service; + /** * AI api * @@ -66,23 +69,29 @@ public class AiChatController { }) @SystemControllerLog(description = "AI chat") @PostMapping("/ai/chat") - public ResponseEntity aiChat(@RequestBody ChatRequest request) { - try { - Object response = aiChatV1Service.chatCompletion(request); + public ResponseEntity aiChat(@RequestBody ChatRequest request, + @RequestHeader(value = "Authorization", required = false) String authorization) throws Exception { - if (request.isStream()) { - return ResponseEntity.ok() - .contentType(MediaType.TEXT_EVENT_STREAM) - .body((StreamingResponseBody) response); - } else { - return ResponseEntity.ok(response); - } - } catch (Exception e) { - return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR) - .body(e.getMessage()); + if (authorization != null && authorization.startsWith("Bearer ")) { + String token = authorization.replace("Bearer ", ""); + request.setApiKey(token); } + + Object response = aiChatV1Service.chatCompletion(request); + + if (request.isStream()) { + return ResponseEntity.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .header("Cache-Control", "no-cache") + .header("X-Accel-Buffering", "no") // 禁用Nginx缓冲 + .body((StreamingResponseBody) response); + } else { + return ResponseEntity.ok(response); + } + } + /** * AI api v1 * @@ -100,24 +109,46 @@ public ResponseEntity aiChat(@RequestBody ChatRequest request) { @SystemControllerLog(description = "AI completions") @PostMapping("/chat/completions") public ResponseEntity completions(@RequestBody ChatRequest request, - @RequestHeader("Authorization") String authorization) { + @RequestHeader(value = "Authorization", required = false) String authorization) throws Exception { if (authorization != null && authorization.startsWith("Bearer ")) { String token = authorization.replace("Bearer ", ""); request.setApiKey(token); } - try { - Object response = aiChatV1Service.chatCompletion(request); - if (request.isStream()) { - return ResponseEntity.ok() - .contentType(MediaType.TEXT_EVENT_STREAM) - .body((StreamingResponseBody) response); - } else { - return ResponseEntity.ok(response); - } - } catch (Exception e) { - return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR) - .body(e.getMessage()); + Object response = aiChatV1Service.chatCompletion(request); + + if (request.isStream()) { + return ResponseEntity.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .header("Cache-Control", "no-cache") + .header("X-Accel-Buffering", "no") // 禁用Nginx缓冲 + .body((StreamingResponseBody) response); + } else { + return ResponseEntity.ok(response); + } + } + /** + * get token + * + * @param request the request + * @return ai回答信息 result + */ + @Operation(summary = "获取加密key信息", description = "获取加密key信息", + parameters = { + @Parameter(name = "request", description = "入参对象") + }, responses = { + @ApiResponse(responseCode = "200", description = "返回信息", + content = @Content(mediaType = "application/json", schema = @Schema())), + @ApiResponse(responseCode = "400", description = "请求失败") + }) + @SystemControllerLog(description = "get token") + @PostMapping("/encrypt-key") + public Result getToken(@RequestBody ChatRequest request) throws Exception { + String apiKey = request.getApiKey(); + if(apiKey == null || apiKey.isEmpty()) { + return Result.failed(ExceptionEnum.CM320); } + String token = aiChatV1Service.getToken(apiKey); + return Result.success(new AiToken(token)); } } diff --git a/base/src/main/java/com/tinyengine/it/model/dto/AiToken.java b/base/src/main/java/com/tinyengine/it/model/dto/AiToken.java new file mode 100644 index 00000000..7c335b27 --- /dev/null +++ b/base/src/main/java/com/tinyengine/it/model/dto/AiToken.java @@ -0,0 +1,33 @@ +/** + * Copyright (c) 2023 - present TinyEngine Authors. + * Copyright (c) 2023 - present Huawei Cloud Computing Technologies Co., Ltd. + * + * Use of this source code is governed by an MIT-style license. + * + * THE OPEN SOURCE SOFTWARE IN THIS PRODUCT IS DISTRIBUTED IN THE HOPE THAT IT WILL BE USEFUL, + * BUT WITHOUT ANY WARRANTY, WITHOUT EVEN THE IMPLIED WARRANTY OF MERCHANTABILITY OR FITNESS FOR + * A PARTICULAR PURPOSE. SEE THE APPLICABLE LICENSES FOR MORE DETAILS. + * + */ + +package com.tinyengine.it.model.dto; + +import lombok.Data; + +/** + * The type Ai token. + * + * @since 2025-11-27 + */ + +@Data +public class AiToken { + String token; + + public AiToken(String token) { + this.token = token; + } + public AiToken(){ + + } +} diff --git a/base/src/main/java/com/tinyengine/it/service/app/impl/v1/AiChatV1ServiceImpl.java b/base/src/main/java/com/tinyengine/it/service/app/impl/v1/AiChatV1ServiceImpl.java index 5e9c0574..c4d92827 100644 --- a/base/src/main/java/com/tinyengine/it/service/app/impl/v1/AiChatV1ServiceImpl.java +++ b/base/src/main/java/com/tinyengine/it/service/app/impl/v1/AiChatV1ServiceImpl.java @@ -13,11 +13,14 @@ package com.tinyengine.it.service.app.impl.v1; import com.fasterxml.jackson.databind.JsonNode; +import com.tinyengine.it.common.exception.ServiceException; import com.tinyengine.it.common.log.SystemServiceLog; import com.tinyengine.it.common.utils.JsonUtils; +import com.tinyengine.it.common.utils.SM4Utils; import com.tinyengine.it.config.OpenAIConfig; import com.tinyengine.it.model.dto.ChatRequest; import com.tinyengine.it.service.app.v1.AiChatV1Service; +import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody; @@ -37,6 +40,7 @@ * * @since 2025-08-06 */ +@Slf4j @Service public class AiChatV1ServiceImpl implements AiChatV1Service { private final OpenAIConfig config = new OpenAIConfig(); @@ -54,17 +58,18 @@ public class AiChatV1ServiceImpl implements AiChatV1Service { @SystemServiceLog(description = "chatCompletion") public Object chatCompletion(ChatRequest request) throws Exception { String requestBody = buildRequestBody(request); - String apiKey = request.getApiKey() != null ? request.getApiKey() : config.getApiKey(); + String encryptApiKey = request.getApiKey() != null ? request.getApiKey() : config.getApiKey(); + String apiKey = getApiKey(encryptApiKey); String baseUrl = request.getBaseUrl(); // 规范化URL处理 String normalizedUrl = normalizeApiUrl(baseUrl); HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() - .uri(URI.create(normalizedUrl)) - .header("Content-Type", "application/json") - .header("Authorization", "Bearer " + apiKey) - .POST(HttpRequest.BodyPublishers.ofString(requestBody)); + .uri(URI.create(normalizedUrl)) + .header("Content-Type", "application/json") + .header("Authorization", "Bearer " + apiKey) + .POST(HttpRequest.BodyPublishers.ofString(requestBody)); if (request.isStream()) { requestBuilder.header("Accept", "text/event-stream"); return processStreamResponse(requestBuilder); @@ -73,6 +78,19 @@ public Object chatCompletion(ChatRequest request) throws Exception { } } + /** + * get token. + * + * @param apiKey the apiKey + * @return token the token + */ + @Override + public String getToken(String apiKey) throws Exception { + String sm4Key = System.getenv("SM4KEY"); + String encrypt = SM4Utils.encryptECB(apiKey, sm4Key); + return "EKEY_"+ encrypt; + } + /** * 规范化API URL,兼容不同厂商 */ @@ -88,6 +106,9 @@ private String normalizeApiUrl(String baseUrl) { if (baseUrl.contains("v1")) { return ensureUrlProtocol(baseUrl) + "/chat/completions"; + } + if (baseUrl.endsWith("#")) { + return ensureUrlProtocol(baseUrl); } else { return ensureUrlProtocol(baseUrl) + "/v1/chat/completions"; } @@ -154,43 +175,71 @@ private String buildRequestBody(ChatRequest request) { return JsonUtils.encode(body); } - private JsonNode processStandardResponse(HttpRequest.Builder requestBuilder) - throws Exception { - HttpResponse response = httpClient.send( - requestBuilder.build(), HttpResponse.BodyHandlers.ofString()); - return JsonUtils.MAPPER.readTree(response.body()); + private JsonNode processStandardResponse(HttpRequest.Builder requestBuilder) { + HttpResponse response = null; + String code = null; + String message = null; + try { + response = httpClient.send( + requestBuilder.build(), HttpResponse.BodyHandlers.ofString()); + code = String.valueOf(response.statusCode()); + if (response.statusCode() != 200) { + String errorBody = response.body(); + + // 尝试解析错误JSON + JsonNode errorNode = JsonUtils.MAPPER.readTree(errorBody); + message = errorNode.get("error").get("message").asText(); + throw new ServiceException(code, message); + } + return JsonUtils.MAPPER.readTree(response.body()); + } catch (IOException | InterruptedException e) { + throw new ServiceException(code, message); + } + + } private StreamingResponseBody processStreamResponse(HttpRequest.Builder requestBuilder) { return outputStream -> { + HttpResponse response = null; try { - HttpClient client = HttpClient.newHttpClient(); - HttpResponse response = client.send( - requestBuilder.build(), - HttpResponse.BodyHandlers.ofInputStream() + response = httpClient.send( + requestBuilder.build(), HttpResponse.BodyHandlers.ofInputStream() ); - if (response.statusCode() != 200) { - String errorBody = new String(response.body().readAllBytes(), StandardCharsets.UTF_8); - throw new IOException("API请求失败: " + response.statusCode() + " - " + errorBody); - } - try (InputStream inputStream = response.body()) { - byte[] buffer = new byte[8192]; - int bytesRead; - while ((bytesRead = inputStream.read(buffer)) != -1) { - outputStream.write(buffer, 0, bytesRead); - outputStream.flush(); - } - } - } catch (Exception e) { - try { - String errorEvent = "data: {\"error\": \"" + e.getMessage() + "\"}\n\n"; - outputStream.write(errorEvent.getBytes(StandardCharsets.UTF_8)); + } catch (InterruptedException e) { + throw new ServiceException("500", e.getMessage()); + } + + log.info("收到AI API响应,状态码: {}", response.statusCode()); + + if (response.statusCode() != 200) { + String errorBody = new String(response.body().readAllBytes(), StandardCharsets.UTF_8); + + log.info("错误响应内容: {}", errorBody); + + JsonNode errorNode = JsonUtils.MAPPER.readTree(errorBody); + throw new ServiceException(String.valueOf(response.statusCode()), errorNode.get("error").get("message").asText()); + } + + // 正常流处理逻辑 + try (InputStream inputStream = response.body()) { + byte[] buffer = new byte[8192]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + outputStream.write(buffer, 0, bytesRead); outputStream.flush(); - } catch (IOException ioException) { - throw new IOException("API请求失败,且无法发送错误信息: " + e.getMessage() + - " (IO错误: " + ioException.getMessage() + ")", e); } } }; } + + private String getApiKey(String encryptApiKey) throws Exception { + String sm4Key = System.getenv("SM4KEY"); + + if (encryptApiKey.startsWith("EKEY_")) { + String encryptBase64ApiKey = encryptApiKey.substring(5); + return SM4Utils.decryptECB(encryptBase64ApiKey, sm4Key); + } + return encryptApiKey; + } } diff --git a/base/src/main/java/com/tinyengine/it/service/app/v1/AiChatV1Service.java b/base/src/main/java/com/tinyengine/it/service/app/v1/AiChatV1Service.java index d2bda179..6d3f2127 100644 --- a/base/src/main/java/com/tinyengine/it/service/app/v1/AiChatV1Service.java +++ b/base/src/main/java/com/tinyengine/it/service/app/v1/AiChatV1Service.java @@ -27,4 +27,12 @@ public interface AiChatV1Service { * @return Object the Object */ public Object chatCompletion(ChatRequest request) throws Exception; + + /** + * get token. + * + * @param apiKey the apiKey + * @return token the token + */ + public String getToken(String apiKey) throws Exception; } diff --git a/pom.xml b/pom.xml index 8379ac7f..c651a73c 100644 --- a/pom.xml +++ b/pom.xml @@ -34,6 +34,7 @@ 6.1.0 4.1.118.Final 1.18 + 1.79 0.12.3 1.5.0 1.5.0-beta11 @@ -56,6 +57,12 @@ spring-boot-starter-data-jpa + + org.bouncycastle + bcprov-jdk18on + ${bcprov-jdk18on.version} + + com.alibaba druid-spring-boot-starter @@ -176,12 +183,6 @@ ${mockito-inline.version} - - org.bouncycastle - bcprov-jdk18on - ${bcprov-jdk18on.version} - - io.netty netty-buffer