diff --git a/lib/AppInfo/Application.php b/lib/AppInfo/Application.php index f1731cc8..dad578b2 100644 --- a/lib/AppInfo/Application.php +++ b/lib/AppInfo/Application.php @@ -82,6 +82,10 @@ class Application extends App implements IBootstrap { public const AUDIO_TO_TEXT_LANGUAGES = [['en', 'English'], ['zh', '中文'], ['de', 'Deutsch'], ['es', 'Español'], ['ru', 'Русский'], ['ko', '한국어'], ['fr', 'Français'], ['ja', '日本語'], ['pt', 'Português'], ['tr', 'Türkçe'], ['pl', 'Polski'], ['ca', 'Català'], ['nl', 'Nederlands'], ['ar', 'العربية'], ['sv', 'Svenska'], ['it', 'Italiano'], ['id', 'Bahasa Indonesia'], ['hi', 'हिन्दी'], ['fi', 'Suomi'], ['vi', 'Tiếng Việt'], ['he', 'עברית'], ['uk', 'Українська'], ['el', 'Ελληνικά'], ['ms', 'Bahasa Melayu'], ['cs', 'Česky'], ['ro', 'Română'], ['da', 'Dansk'], ['hu', 'Magyar'], ['ta', 'தமிழ்'], ['no', 'Norsk (bokmål / riksmål)'], ['th', 'ไทย / Phasa Thai'], ['ur', 'اردو'], ['hr', 'Hrvatski'], ['bg', 'Български'], ['lt', 'Lietuvių'], ['la', 'Latina'], ['mi', 'Māori'], ['ml', 'മലയാളം'], ['cy', 'Cymraeg'], ['sk', 'Slovenčina'], ['te', 'తెలుగు'], ['fa', 'فارسی'], ['lv', 'Latviešu'], ['bn', 'বাংলা'], ['sr', 'Српски'], ['az', 'Azərbaycanca / آذربايجان'], ['sl', 'Slovenščina'], ['kn', 'ಕನ್ನಡ'], ['et', 'Eesti'], ['mk', 'Македонски'], ['br', 'Brezhoneg'], ['eu', 'Euskara'], ['is', 'Íslenska'], ['hy', 'Հայերեն'], ['ne', 'नेपाली'], ['mn', 'Монгол'], ['bs', 'Bosanski'], ['kk', 'Қазақша'], ['sq', 'Shqip'], ['sw', 'Kiswahili'], ['gl', 'Galego'], ['mr', 'मराठी'], ['pa', 'ਪੰਜਾਬੀ / पंजाबी / پنجابي'], ['si', 'සිංහල'], ['km', 'ភាសាខ្មែរ'], ['sn', 'chiShona'], ['yo', 'Yorùbá'], ['so', 'Soomaaliga'], ['af', 'Afrikaans'], ['oc', 'Occitan'], ['ka', 'ქართული'], ['be', 'Беларуская'], ['tg', 'Тоҷикӣ'], ['sd', 'सिनधि'], ['gu', 'ગુજરાતી'], ['am', 'አማርኛ'], ['yi', 'ייִדיש'], ['lo', 'ລາວ / Pha xa lao'], ['uz', 'Ўзбек'], ['fo', 'Føroyskt'], ['ht', 'Krèyol ayisyen'], ['ps', 'پښتو'], ['tk', 'Туркмен / تركمن'], ['nn', 'Norsk (nynorsk)'], ['mt', 'bil-Malti'], ['sa', 'संस्कृतम्'], ['lb', 'Lëtzebuergesch'], ['my', 'Myanmasa'], ['bo', 'བོད་ཡིག / Bod skad'], ['tl', 'Tagalog'], ['mg', 'Malagasy'], ['as', 'অসমীয়া'], ['tt', 'Tatarça'], ['haw', 'ʻŌlelo Hawaiʻi'], ['ln', 'Lingála'], ['ha', 'هَوُسَ'], ['ba', 'Башҡорт'], ['jw', 'ꦧꦱꦗꦮ'], ['su', 'Basa Sunda'], ['yue', '粤语']]; + public const SERVICE_TYPE_IMAGE = 'image'; + public const SERVICE_TYPE_STT = 'stt'; + public const SERVICE_TYPE_TTS = 'tts'; + private IAppConfig $appConfig; public function __construct(array $urlParams = []) { diff --git a/lib/Controller/ConfigController.php b/lib/Controller/ConfigController.php index dae11002..3c40d6a3 100644 --- a/lib/Controller/ConfigController.php +++ b/lib/Controller/ConfigController.php @@ -72,8 +72,11 @@ public function setSensitiveUserConfig(array $values): DataResponse { * @return DataResponse */ public function setAdminConfig(array $values): DataResponse { - if (isset($values['api_key']) || isset($values['basic_password']) || isset($values['basic_user']) || isset($values['url'])) { - return new DataResponse('', Http::STATUS_BAD_REQUEST); + $prefixes = ['', 'image_', 'tts_', 'stt_']; + foreach ($prefixes as $prefix) { + if (isset($values[$prefix . 'api_key']) || isset($values[$prefix . 'basic_password']) || isset($values[$prefix . 'basic_user']) || isset($values[$prefix . 'url'])) { + return new DataResponse('', Http::STATUS_BAD_REQUEST); + } } try { $this->openAiSettingsService->setAdminConfig($values); diff --git a/lib/Controller/OpenAiAPIController.php b/lib/Controller/OpenAiAPIController.php index df7e9114..c27fba4e 100644 --- a/lib/Controller/OpenAiAPIController.php +++ b/lib/Controller/OpenAiAPIController.php @@ -26,12 +26,13 @@ public function __construct( } /** + * @param string|null $serviceType * @return DataResponse */ #[NoAdminRequired] - public function getModels(): DataResponse { + public function getModels(?string $serviceType = null): DataResponse { try { - $response = $this->openAiAPIService->getModels($this->userId, true); + $response = $this->openAiAPIService->getModels($this->userId, true, $serviceType); return new DataResponse($response); } catch (Exception $e) { $code = $e->getCode() === 0 ? Http::STATUS_BAD_REQUEST : intval($e->getCode()); diff --git a/lib/Service/OpenAiAPIService.php b/lib/Service/OpenAiAPIService.php index a7652aec..a9d5a89b 100644 --- a/lib/Service/OpenAiAPIService.php +++ b/lib/Service/OpenAiAPIService.php @@ -38,7 +38,7 @@ */ class OpenAiAPIService { private IClient $client; - private ?array $modelsMemoryCache = null; + private array $modelsMemoryCache = []; public function __construct( private LoggerInterface $logger, @@ -66,21 +66,47 @@ public function createQuotaUsage(string $userId, int $type, int $usage) { } /** + * @param ?string $serviceType * @return bool */ - public function isUsingOpenAi(): bool { - $serviceUrl = $this->openAiSettingsService->getServiceUrl(); + public function isUsingOpenAi(?string $serviceType = null): bool { + $serviceUrl = ''; + if ($serviceType === Application::SERVICE_TYPE_IMAGE) { + $serviceUrl = $this->openAiSettingsService->getImageServiceUrl(); + } elseif ($serviceType === Application::SERVICE_TYPE_STT) { + $serviceUrl = $this->openAiSettingsService->getSttServiceUrl(); + } elseif ($serviceType === Application::SERVICE_TYPE_TTS) { + $serviceUrl = $this->openAiSettingsService->getTtsServiceUrl(); + } + if ($serviceUrl === '') { + $serviceUrl = $this->openAiSettingsService->getServiceUrl(); + } return $serviceUrl === '' || $serviceUrl === Application::OPENAI_API_BASE_URL; } /** + * @param ?string $serviceType + * * @return string */ - public function getServiceName(): string { - if ($this->isUsingOpenAi()) { + public function getServiceName(?string $serviceType = null): string { + if ($this->isUsingOpenAi($serviceType)) { + if ($serviceType === Application::SERVICE_TYPE_IMAGE) { + return $this->l10n->t('OpenAI\'s DALL-E 2'); + } + if ($serviceType === Application::SERVICE_TYPE_TTS) { + $this->l10n->t('OpenAI\'s Text to Speech'); + } return 'OpenAI'; } else { $serviceName = $this->openAiSettingsService->getServiceName(); + if ($serviceType === Application::SERVICE_TYPE_IMAGE && $this->openAiSettingsService->imageOverrideEnabled()) { + $serviceName = $this->openAiSettingsService->getImageServiceName(); + } elseif ($serviceType === Application::SERVICE_TYPE_STT && $this->openAiSettingsService->sttOverrideEnabled()) { + $serviceName = $this->openAiSettingsService->getSttServiceName(); + } elseif ($serviceType === Application::SERVICE_TYPE_TTS && $this->openAiSettingsService->ttsOverrideEnabled()) { + $serviceName = $this->openAiSettingsService->getTtsServiceName(); + } if ($serviceName === '') { return 'LocalAI'; } @@ -111,18 +137,21 @@ private function isModelListValid($models): bool { /** * @param ?string $userId * @param bool $refresh + * @param ?string $serviceType * @return array|string[] * @throws Exception */ - public function getModels(?string $userId, bool $refresh = false): array { + public function getModels(?string $userId, bool $refresh = false, ?string $serviceType = null): array { $cache = $this->cacheFactory->createDistributed(Application::APP_ID); - $userCacheKey = Application::MODELS_CACHE_KEY . '_' . ($userId ?? ''); - $adminCacheKey = Application::MODELS_CACHE_KEY . '-main'; + $userCacheKey = Application::MODELS_CACHE_KEY . '_' . ($userId ?? '') . '_' . ($serviceType ?? 'main'); + $adminCacheKey = Application::MODELS_CACHE_KEY . '-main' . '_' . ($serviceType ?? 'main'); + $dbCacheKey = $serviceType ? 'models' . '_' . $serviceType : 'models'; + $memoryCacheKey = $serviceType ?? 'default'; if (!$refresh) { - if ($this->modelsMemoryCache !== null) { + if (array_key_exists($memoryCacheKey, $this->modelsMemoryCache)) { $this->logger->debug('Getting OpenAI models from the memory cache'); - return $this->modelsMemoryCache; + return $this->modelsMemoryCache[$memoryCacheKey]; } // try to get models from the user cache first @@ -130,7 +159,7 @@ public function getModels(?string $userId, bool $refresh = false): array { $userCachedModels = $cache->get($userCacheKey); if ($userCachedModels) { $this->logger->debug('Getting OpenAI models from user cache for user ' . $userId); - $this->modelsMemoryCache = $userCachedModels; + $this->modelsMemoryCache[$memoryCacheKey] = $userCachedModels; return $userCachedModels; } } @@ -149,13 +178,13 @@ public function getModels(?string $userId, bool $refresh = false): array { // we try to get the models from the admin cache if ($adminCachedModels = $cache->get($adminCacheKey)) { $this->logger->debug('Getting OpenAI models from the main distributed cache'); - $this->modelsMemoryCache = $adminCachedModels; + $this->modelsMemoryCache[$memoryCacheKey] = $adminCachedModels; return $adminCachedModels; } } // if we don't need to refresh to model list and it's not been found in the cache, it is obtained from the DB - $modelsObjectString = $this->appConfig->getValueString(Application::APP_ID, 'models', '{"data":[],"object":"list"}'); + $modelsObjectString = $this->appConfig->getValueString(Application::APP_ID, $dbCacheKey, '{"data":[],"object":"list"}'); $fallbackModels = [ 'data' => [], 'object' => 'list', @@ -167,7 +196,7 @@ public function getModels(?string $userId, bool $refresh = false): array { $newCache = $fallbackModels; } $cache->set($userId !== null ? $userCacheKey : $adminCacheKey, $newCache, Application::MODELS_CACHE_TTL); - $this->modelsMemoryCache = $newCache; + $this->modelsMemoryCache[$memoryCacheKey] = $newCache; return $newCache; } @@ -177,7 +206,7 @@ public function getModels(?string $userId, bool $refresh = false): array { try { $this->logger->debug('Actually getting OpenAI models with a network request'); - $modelsResponse = $this->request($userId, 'models'); + $modelsResponse = $this->request($userId, 'models', serviceType: $serviceType); } catch (Exception $e) { $this->logger->warning('Error retrieving models (exc): ' . $e->getMessage()); throw $e; @@ -197,10 +226,10 @@ public function getModels(?string $userId, bool $refresh = false): array { } $cache->set($userId !== null ? $userCacheKey : $adminCacheKey, $modelsResponse, Application::MODELS_CACHE_TTL); - $this->modelsMemoryCache = $modelsResponse; + $this->modelsMemoryCache[$memoryCacheKey] = $modelsResponse; // we always store the model list after getting it $modelsObjectString = json_encode($modelsResponse); - $this->appConfig->setValueString(Application::APP_ID, 'models', $modelsObjectString); + $this->appConfig->setValueString(Application::APP_ID, $dbCacheKey, $modelsObjectString); return $modelsResponse; } @@ -223,9 +252,9 @@ private function hasOwnOpenAiApiKey(string $userId): bool { * @param string|null $userId * @return array */ - public function getModelEnumValues(?string $userId): array { + public function getModelEnumValues(?string $userId, ?string $serviceType = null): array { try { - $modelResponse = $this->getModels($userId); + $modelResponse = $this->getModels($userId, false, $serviceType); $modelEnumValues = array_map(function (array $model) { return new ShapeEnumValue($model['id'], $model['id']); }, $modelResponse['data'] ?? []); @@ -644,7 +673,8 @@ public function createChatCompletion( foreach ($response['choices'] as $choice) { // get tool calls only if this is the finish reason and it's defined and it's an array - if ($choice['finish_reason'] === 'tool_calls' + if ( + $choice['finish_reason'] === 'tool_calls' && isset($choice['message']['tool_calls']) && is_array($choice['message']['tool_calls']) ) { @@ -779,7 +809,7 @@ public function transcribe( $endpoint = $translate ? 'audio/translations' : 'audio/transcriptions'; $contentType = 'multipart/form-data'; - $response = $this->request($userId, $endpoint, $params, 'POST', $contentType); + $response = $this->request($userId, $endpoint, $params, 'POST', $contentType, serviceType: Application::SERVICE_TYPE_STT); if (!isset($response['text'])) { $this->logger->warning('Audio transcription error: ' . json_encode($response)); @@ -809,7 +839,11 @@ public function transcribe( * @throws Exception */ public function requestImageCreation( - ?string $userId, string $prompt, string $model, int $n = 1, string $size = Application::DEFAULT_DEFAULT_IMAGE_SIZE, + ?string $userId, + string $prompt, + string $model, + int $n = 1, + string $size = Application::DEFAULT_DEFAULT_IMAGE_SIZE, ): array { if ($this->isQuotaExceeded($userId, Application::QUOTA_TYPE_IMAGE)) { throw new Exception($this->l10n->t('Image generation quota exceeded'), Http::STATUS_TOO_MANY_REQUESTS); @@ -822,12 +856,11 @@ public function requestImageCreation( 'model' => $model === Application::DEFAULT_MODEL_ID ? Application::DEFAULT_IMAGE_MODEL_ID : $model, ]; - $apiResponse = $this->request($userId, 'images/generations', $params, 'POST'); + $apiResponse = $this->request($userId, 'images/generations', $params, 'POST', serviceType: Application::SERVICE_TYPE_IMAGE); if (!isset($apiResponse['data']) || !is_array($apiResponse['data'])) { $this->logger->warning('OpenAI image generation error', ['api_response' => $apiResponse]); throw new Exception($this->l10n->t('Unknown image generation error'), Http::STATUS_INTERNAL_SERVER_ERROR); - } else { try { $this->createQuotaUsage($userId ?? '', Application::QUOTA_TYPE_IMAGE, $n); @@ -877,7 +910,11 @@ public function getImageRequestOptions(?string $userId): array { * @throws Exception */ public function requestSpeechCreation( - ?string $userId, string $prompt, string $model, string $voice, float $speed = 1, + ?string $userId, + string $prompt, + string $model, + string $voice, + float $speed = 1, ): array { if ($this->isQuotaExceeded($userId, Application::QUOTA_TYPE_SPEECH)) { throw new Exception($this->l10n->t('Speech generation quota exceeded'), Http::STATUS_TOO_MANY_REQUESTS); @@ -891,7 +928,7 @@ public function requestSpeechCreation( 'speed' => $speed, ]; - $apiResponse = $this->request($userId, 'audio/speech', $params, 'POST'); + $apiResponse = $this->request($userId, 'audio/speech', $params, 'POST', serviceType: Application::SERVICE_TYPE_TTS); try { $charCount = mb_strlen($prompt); @@ -930,7 +967,7 @@ public function updateExpTextProcessingTime(int $runtime): void { * @return int */ public function getExpImgProcessingTime(): int { - return $this->isUsingOpenAi() + return $this->isUsingOpenAi(Application::SERVICE_TYPE_IMAGE) ? intval($this->appConfig->getValueString(Application::APP_ID, 'openai_image_generation_time', strval(Application::DEFAULT_OPENAI_IMAGE_GENERATION_TIME), lazy: true)) : intval($this->appConfig->getValueString(Application::APP_ID, 'localai_image_generation_time', strval(Application::DEFAULT_LOCALAI_IMAGE_GENERATION_TIME), lazy: true)); } @@ -943,7 +980,7 @@ public function updateExpImgProcessingTime(int $runtime): void { $oldTime = floatval($this->getExpImgProcessingTime()); $newTime = (1.0 - Application::EXPECTED_RUNTIME_LOWPASS_FACTOR) * $oldTime + Application::EXPECTED_RUNTIME_LOWPASS_FACTOR * floatval($runtime); - if ($this->isUsingOpenAi()) { + if ($this->isUsingOpenAi(Application::SERVICE_TYPE_IMAGE)) { $this->appConfig->setValueString(Application::APP_ID, 'openai_image_generation_time', strval(intval($newTime)), lazy: true); } else { $this->appConfig->setValueString(Application::APP_ID, 'localai_image_generation_time', strval(intval($newTime)), lazy: true); @@ -958,18 +995,53 @@ public function updateExpImgProcessingTime(int $runtime): void { * @param string $method HTTP query method * @param string|null $contentType * @param bool $logErrors if set to false error logs will be suppressed + * @param string|null $serviceType * @return array decoded request result or error * @throws Exception */ - public function request(?string $userId, string $endPoint, array $params = [], string $method = 'GET', ?string $contentType = null, bool $logErrors = true): array { + public function request(?string $userId, string $endPoint, array $params = [], string $method = 'GET', ?string $contentType = null, bool $logErrors = true, ?string $serviceType = null): array { try { - $serviceUrl = $this->openAiSettingsService->getServiceUrl(); - if ($serviceUrl === '') { - $serviceUrl = Application::OPENAI_API_BASE_URL; + $serviceUrl = ''; + $apiKey = ''; + $basicUser = ''; + $basicPassword = ''; + $useBasicAuth = false; + $timeout = 0; + + if ($serviceType === Application::SERVICE_TYPE_IMAGE && $this->openAiSettingsService->imageOverrideEnabled()) { + $serviceUrl = $this->openAiSettingsService->getImageServiceUrl(); + $apiKey = $this->openAiSettingsService->getAdminImageApiKey(); + $basicUser = $this->openAiSettingsService->getAdminImageBasicUser(); + $basicPassword = $this->openAiSettingsService->getAdminImageBasicPassword(); + $useBasicAuth = $this->openAiSettingsService->getAdminImageUseBasicAuth(); + $timeout = $this->openAiSettingsService->getImageRequestTimeout(); + } elseif ($serviceType === Application::SERVICE_TYPE_STT && $this->openAiSettingsService->sttOverrideEnabled()) { + $serviceUrl = $this->openAiSettingsService->getSttServiceUrl(); + $apiKey = $this->openAiSettingsService->getAdminSttApiKey(); + $basicUser = $this->openAiSettingsService->getAdminSttBasicUser(); + $basicPassword = $this->openAiSettingsService->getAdminSttBasicPassword(); + $useBasicAuth = $this->openAiSettingsService->getAdminSttUseBasicAuth(); + $timeout = $this->openAiSettingsService->getSttRequestTimeout(); + } elseif ($serviceType === Application::SERVICE_TYPE_TTS && $this->openAiSettingsService->ttsOverrideEnabled()) { + $serviceUrl = $this->openAiSettingsService->getTtsServiceUrl(); + $apiKey = $this->openAiSettingsService->getAdminTtsApiKey(); + $basicUser = $this->openAiSettingsService->getAdminTtsBasicUser(); + $basicPassword = $this->openAiSettingsService->getAdminTtsBasicPassword(); + $useBasicAuth = $this->openAiSettingsService->getAdminTtsUseBasicAuth(); + $timeout = $this->openAiSettingsService->getTtsRequestTimeout(); + } else { + // Currently only supporting user api keys for the default service + $serviceUrl = $this->openAiSettingsService->getServiceUrl(); + if ($serviceUrl === '') { + $serviceUrl = Application::OPENAI_API_BASE_URL; + } + $apiKey = $this->openAiSettingsService->getUserApiKey($userId, true); + $basicUser = $this->openAiSettingsService->getUserBasicUser($userId, true); + $basicPassword = $this->openAiSettingsService->getUserBasicPassword($userId, true); + $useBasicAuth = $this->openAiSettingsService->getUseBasicAuth(); + $timeout = $this->openAiSettingsService->getRequestTimeout(); } - $timeout = $this->openAiSettingsService->getRequestTimeout(); - $url = rtrim($serviceUrl, '/') . '/' . $endPoint; $options = [ 'timeout' => $timeout, @@ -978,20 +1050,11 @@ public function request(?string $userId, string $endPoint, array $params = [], s ], ]; - // an API key is mandatory when using OpenAI - $apiKey = $this->openAiSettingsService->getUserApiKey($userId, true); - - // We can also use basic authentication - $basicUser = $this->openAiSettingsService->getUserBasicUser($userId, true); - $basicPassword = $this->openAiSettingsService->getUserBasicPassword($userId, true); - if ($serviceUrl === Application::OPENAI_API_BASE_URL && $apiKey === '') { return ['error' => 'An API key is required for api.openai.com']; } - $useBasicAuth = $this->openAiSettingsService->getUseBasicAuth(); - - if ($this->isUsingOpenAi() || !$useBasicAuth) { + if ($this->isUsingOpenAi($serviceType) || !$useBasicAuth) { if ($apiKey !== '') { $options['headers']['Authorization'] = 'Bearer ' . $apiKey; } @@ -1001,7 +1064,7 @@ public function request(?string $userId, string $endPoint, array $params = [], s } } - if (!$this->isUsingOpenAi()) { + if (!$this->isUsingOpenAi($serviceType)) { $options['nextcloud']['allow_local_address'] = true; } @@ -1077,12 +1140,12 @@ public function request(?string $userId, string $endPoint, array $params = [], s throw new Exception( $this->l10n->t('API request error: ') . ( $e->getResponse()->getStatusCode() === 401 - ? $this->l10n->t('Invalid API Key/Basic Auth: ') - : '' + ? $this->l10n->t('Invalid API Key/Basic Auth: ') + : '' ) . ( isset($parsedResponseBody['error']) && isset($parsedResponseBody['error']['message']) - ? $parsedResponseBody['error']['message'] - : $e->getMessage() + ? $parsedResponseBody['error']['message'] + : $e->getMessage() ), intval($e->getCode()), ); @@ -1095,7 +1158,7 @@ public function request(?string $userId, string $endPoint, array $params = [], s * @return bool whether the T2I provider is available */ public function isT2IAvailable(): bool { - if ($this->isUsingOpenAi()) { + if ($this->openAiSettingsService->imageOverrideEnabled() || $this->isUsingOpenAi()) { return true; } try { @@ -1103,7 +1166,7 @@ public function isT2IAvailable(): bool { 'prompt' => 'a', 'model' => 'invalid-model', ]; - $this->request(null, 'images/generations', $params, 'POST', logErrors: false); + $this->request(null, 'images/generations', $params, 'POST', logErrors: false, serviceType: Application::SERVICE_TYPE_IMAGE); } catch (Exception $e) { return $e->getCode() !== Http::STATUS_NOT_FOUND && $e->getCode() !== Http::STATUS_UNAUTHORIZED; } @@ -1116,7 +1179,7 @@ public function isT2IAvailable(): bool { * @return bool whether the STT provider is available */ public function isSTTAvailable(): bool { - if ($this->isUsingOpenAi()) { + if ($this->openAiSettingsService->sttOverrideEnabled() || $this->isUsingOpenAi()) { return true; } try { @@ -1124,7 +1187,7 @@ public function isSTTAvailable(): bool { 'model' => 'invalid-model', 'file' => 'a', ]; - $this->request(null, 'audio/translations', $params, 'POST', 'multipart/form-data', logErrors: false); + $this->request(null, 'audio/translations', $params, 'POST', 'multipart/form-data', logErrors: false, serviceType: Application::SERVICE_TYPE_STT); } catch (Exception $e) { return $e->getCode() !== Http::STATUS_NOT_FOUND && $e->getCode() !== Http::STATUS_UNAUTHORIZED; } @@ -1137,7 +1200,7 @@ public function isSTTAvailable(): bool { * @return bool whether the TTS provider is available */ public function isTTSAvailable(): bool { - if ($this->isUsingOpenAi()) { + if ($this->openAiSettingsService->ttsOverrideEnabled() || $this->isUsingOpenAi()) { return true; } try { @@ -1148,7 +1211,7 @@ public function isTTSAvailable(): bool { 'response_format' => 'mp3', ]; - $this->request(null, 'audio/speech', $params, 'POST', logErrors: false); + $this->request(null, 'audio/speech', $params, 'POST', logErrors: false, serviceType: Application::SERVICE_TYPE_TTS); } catch (Exception $e) { return $e->getCode() !== Http::STATUS_NOT_FOUND && $e->getCode() !== Http::STATUS_UNAUTHORIZED; } diff --git a/lib/Service/OpenAiSettingsService.php b/lib/Service/OpenAiSettingsService.php index 1d2e1931..19003c5e 100644 --- a/lib/Service/OpenAiSettingsService.php +++ b/lib/Service/OpenAiSettingsService.php @@ -47,7 +47,31 @@ class OpenAiSettingsService { 'chat_endpoint_enabled' => 'boolean', 'basic_user' => 'string', 'basic_password' => 'string', - 'use_basic_auth' => 'boolean' + 'use_basic_auth' => 'boolean', + + 'image_url' => 'string', + 'image_service_name' => 'string', + 'image_api_key' => 'string', + 'image_basic_user' => 'string', + 'image_basic_password' => 'string', + 'image_use_basic_auth' => 'boolean', + 'image_request_timeout' => 'integer', + + 'stt_url' => 'string', + 'stt_service_name' => 'string', + 'stt_api_key' => 'string', + 'stt_basic_user' => 'string', + 'stt_basic_password' => 'string', + 'stt_use_basic_auth' => 'boolean', + 'stt_request_timeout' => 'integer', + + 'tts_url' => 'string', + 'tts_service_name' => 'string', + 'tts_api_key' => 'string', + 'tts_basic_user' => 'string', + 'tts_basic_password' => 'string', + 'tts_use_basic_auth' => 'boolean', + 'tts_request_timeout' => 'integer', ]; private const USER_CONFIG_TYPES = [ @@ -384,6 +408,153 @@ public function getUseBasicAuth(): bool { return $this->appConfig->getValueString(Application::APP_ID, 'use_basic_auth', '0', lazy: true) === '1'; } + /** + * @return string + */ + public function getImageServiceUrl(): string { + return $this->appConfig->getValueString(Application::APP_ID, 'image_url', '', lazy: true); + } + + /** + * @return string + */ + public function getImageServiceName(): string { + return $this->appConfig->getValueString(Application::APP_ID, 'image_service_name', '', lazy: true); + } + + /** + * @return string + */ + public function getAdminImageApiKey(): string { + return $this->appConfig->getValueString(Application::APP_ID, 'image_api_key', '', true); + } + + /** + * @return string + */ + public function getAdminImageBasicUser(): string { + return $this->appConfig->getValueString(Application::APP_ID, 'image_basic_user', '', lazy: true); + } + + /** + * @return string + */ + public function getAdminImageBasicPassword(): string { + return $this->appConfig->getValueString(Application::APP_ID, 'image_basic_password', '', true); + } + + /** + * @return bool + */ + public function getAdminImageUseBasicAuth(): bool { + return $this->appConfig->getValueString(Application::APP_ID, 'image_use_basic_auth', '0', lazy: true) === '1'; + } + + /** + * @return int + */ + public function getImageRequestTimeout(): int { + return intval($this->appConfig->getValueString(Application::APP_ID, 'image_request_timeout', strval(Application::OPENAI_DEFAULT_REQUEST_TIMEOUT), lazy: true)) ?: Application::OPENAI_DEFAULT_REQUEST_TIMEOUT; + } + + /** + * @return string + */ + public function getSttServiceUrl(): string { + return $this->appConfig->getValueString(Application::APP_ID, 'stt_url', '', lazy: true); + } + + /** + * @return string + */ + public function getSttServiceName(): string { + return $this->appConfig->getValueString(Application::APP_ID, 'stt_service_name', '', lazy: true); + } + + /** + * @return string + */ + public function getAdminSttApiKey(): string { + return $this->appConfig->getValueString(Application::APP_ID, 'stt_api_key', '', true); + } + + /** + * @return string + */ + public function getAdminSttBasicUser(): string { + return $this->appConfig->getValueString(Application::APP_ID, 'stt_basic_user', '', lazy: true); + } + + /** + * @return string + */ + public function getAdminSttBasicPassword(): string { + return $this->appConfig->getValueString(Application::APP_ID, 'stt_basic_password', '', true); + } + + /** + * @return bool + */ + public function getAdminSttUseBasicAuth(): bool { + return $this->appConfig->getValueString(Application::APP_ID, 'stt_use_basic_auth', '0', lazy: true) === '1'; + } + + /** + * @return int + */ + public function getSttRequestTimeout(): int { + return intval($this->appConfig->getValueString(Application::APP_ID, 'stt_request_timeout', strval(Application::OPENAI_DEFAULT_REQUEST_TIMEOUT), lazy: true)) ?: Application::OPENAI_DEFAULT_REQUEST_TIMEOUT; + } + + /** + * @return string + */ + public function getTtsServiceUrl(): string { + return $this->appConfig->getValueString(Application::APP_ID, 'tts_url', '', lazy: true); + } + + /** + * @return string + */ + public function getTtsServiceName(): string { + return $this->appConfig->getValueString(Application::APP_ID, 'tts_service_name', '', lazy: true); + } + + /** + * @return string + */ + public function getAdminTtsApiKey(): string { + return $this->appConfig->getValueString(Application::APP_ID, 'tts_api_key', '', true); + } + + /** + * @return string + */ + public function getAdminTtsBasicUser(): string { + return $this->appConfig->getValueString(Application::APP_ID, 'tts_basic_user', '', lazy: true); + } + + /** + * @return string + */ + public function getAdminTtsBasicPassword(): string { + return $this->appConfig->getValueString(Application::APP_ID, 'tts_basic_password', '', true); + } + + /** + * @return bool + */ + public function getAdminTtsUseBasicAuth(): bool { + return $this->appConfig->getValueString(Application::APP_ID, 'tts_use_basic_auth', '0', lazy: true) === '1'; + } + + /** + * @return int + */ + public function getTtsRequestTimeout(): int { + return intval($this->appConfig->getValueString(Application::APP_ID, 'tts_request_timeout', strval(Application::OPENAI_DEFAULT_REQUEST_TIMEOUT), lazy: true)) ?: Application::OPENAI_DEFAULT_REQUEST_TIMEOUT; + } + /** * Get the admin config for the settings page * @return mixed[] @@ -421,7 +592,31 @@ public function getAdminConfig(): array { 'chat_endpoint_enabled' => $this->getChatEndpointEnabled(), 'basic_user' => $this->getAdminBasicUser(), 'basic_password' => $this->getAdminBasicPassword(), - 'use_basic_auth' => $this->getUseBasicAuth() + 'use_basic_auth' => $this->getUseBasicAuth(), + // Get the service details for image, stt and tts + 'image_url' => $this->getImageServiceUrl(), + 'image_service_name' => $this->getImageServiceName(), + 'image_api_key' => $this->getAdminImageApiKey(), + 'image_basic_user' => $this->getAdminImageBasicUser(), + 'image_basic_password' => $this->getAdminImageBasicPassword(), + 'image_use_basic_auth' => $this->getAdminImageUseBasicAuth(), + 'image_request_timeout' => $this->getImageRequestTimeout(), + + 'stt_url' => $this->getSttServiceUrl(), + 'stt_service_name' => $this->getSttServiceName(), + 'stt_api_key' => $this->getAdminSttApiKey(), + 'stt_basic_user' => $this->getAdminSttBasicUser(), + 'stt_basic_password' => $this->getAdminSttBasicPassword(), + 'stt_use_basic_auth' => $this->getAdminSttUseBasicAuth(), + 'stt_request_timeout' => $this->getSttRequestTimeout(), + + 'tts_url' => $this->getTtsServiceUrl(), + 'tts_service_name' => $this->getTtsServiceName(), + 'tts_api_key' => $this->getAdminTtsApiKey(), + 'tts_basic_user' => $this->getAdminTtsBasicUser(), + 'tts_basic_password' => $this->getAdminTtsBasicPassword(), + 'tts_use_basic_auth' => $this->getAdminTtsUseBasicAuth(), + 'tts_request_timeout' => $this->getTtsRequestTimeout(), ]; } @@ -801,6 +996,191 @@ public function setAdminTtsVoices(array $voices): void { $this->invalidateModelsCache(); } + /** + * @param string $url + * @return void + * @throws Exception + */ + public function setImageServiceUrl(string $url): void { + if ($url !== '' && !filter_var($url, FILTER_VALIDATE_URL)) { + throw new Exception('Invalid image service URL'); + } + $this->appConfig->setValueString(Application::APP_ID, 'image_url', $url, lazy: true); + } + + /** + * @param string $name + * @return void + */ + public function setImageServiceName(string $name): void { + $this->appConfig->setValueString(Application::APP_ID, 'image_service_name', $name, lazy: true); + } + + /** + * @param string $apiKey + * @return void + */ + public function setAdminImageApiKey(string $apiKey): void { + $this->appConfig->setValueString(Application::APP_ID, 'image_api_key', $apiKey, true, true); + } + + /** + * @param string $user + * @return void + */ + public function setAdminImageBasicUser(string $user): void { + $this->appConfig->setValueString(Application::APP_ID, 'image_basic_user', $user, lazy: true); + } + + /** + * @param string $password + * @return void + */ + public function setAdminImageBasicPassword(string $password): void { + $this->appConfig->setValueString(Application::APP_ID, 'image_basic_password', $password, true, true); + } + + /** + * @param bool $use + * @return void + */ + public function setAdminImageUseBasicAuth(bool $use): void { + $this->appConfig->setValueString(Application::APP_ID, 'image_use_basic_auth', $use ? '1' : '0', lazy: true); + } + + /** + * @param int $requestTimeout + * @return void + */ + public function setImageRequestTimeout(int $requestTimeout): void { + // Validate input: + $requestTimeout = max(1, $requestTimeout); + $this->appConfig->setValueString(Application::APP_ID, 'image_request_timeout', strval($requestTimeout), lazy: true); + } + /** + * @param string $url + * @return void + * @throws Exception + */ + public function setSttServiceUrl(string $url): void { + if ($url !== '' && !filter_var($url, FILTER_VALIDATE_URL)) { + throw new Exception('Invalid STT service URL'); + } + $this->appConfig->setValueString(Application::APP_ID, 'stt_url', $url, lazy: true); + } + + /** + * @param string $name + * @return void + */ + public function setSttServiceName(string $name): void { + $this->appConfig->setValueString(Application::APP_ID, 'stt_service_name', $name, lazy: true); + } + + /** + * @param string $apiKey + * @return void + */ + public function setAdminSttApiKey(string $apiKey): void { + $this->appConfig->setValueString(Application::APP_ID, 'stt_api_key', $apiKey, true, true); + } + + /** + * @param string $user + * @return void + */ + public function setAdminSttBasicUser(string $user): void { + $this->appConfig->setValueString(Application::APP_ID, 'stt_basic_user', $user, lazy: true); + } + + /** + * @param string $password + * @return void + */ + public function setAdminSttBasicPassword(string $password): void { + $this->appConfig->setValueString(Application::APP_ID, 'stt_basic_password', $password, true, true); + } + + /** + * @param bool $use + * @return void + */ + public function setAdminSttUseBasicAuth(bool $use): void { + $this->appConfig->setValueString(Application::APP_ID, 'stt_use_basic_auth', $use ? '1' : '0', lazy: true); + } + + /** + * @param int $requestTimeout + * @return void + */ + public function setSttRequestTimeout(int $requestTimeout): void { + // Validate input: + $requestTimeout = max(1, $requestTimeout); + $this->appConfig->setValueString(Application::APP_ID, 'stt_request_timeout', strval($requestTimeout), lazy: true); + } + + /** + * @param string $url + * @return void + * @throws Exception + */ + public function setTtsServiceUrl(string $url): void { + if ($url !== '' && !filter_var($url, FILTER_VALIDATE_URL)) { + throw new Exception('Invalid TTS service URL'); + } + $this->appConfig->setValueString(Application::APP_ID, 'tts_url', $url, lazy: true); + } + + /** + * @param string $name + * @return void + */ + public function setTtsServiceName(string $name): void { + $this->appConfig->setValueString(Application::APP_ID, 'tts_service_name', $name, lazy: true); + } + + /** + * @param string $apiKey + * @return void + */ + public function setAdminTtsApiKey(string $apiKey): void { + $this->appConfig->setValueString(Application::APP_ID, 'tts_api_key', $apiKey, true, true); + } + + /** + * @param string $user + * @return void + */ + public function setAdminTtsBasicUser(string $user): void { + $this->appConfig->setValueString(Application::APP_ID, 'tts_basic_user', $user, lazy: true); + } + + /** + * @param string $password + * @return void + */ + public function setAdminTtsBasicPassword(string $password): void { + $this->appConfig->setValueString(Application::APP_ID, 'tts_basic_password', $password, true, true); + } + + /** + * @param bool $use + * @return void + */ + public function setAdminTtsUseBasicAuth(bool $use): void { + $this->appConfig->setValueString(Application::APP_ID, 'tts_use_basic_auth', $use ? '1' : '0', lazy: true); + } + + /** + * @param int $requestTimeout + * @return void + */ + public function setTtsRequestTimeout(int $requestTimeout): void { + // Validate input: + $requestTimeout = max(1, $requestTimeout); + $this->appConfig->setValueString(Application::APP_ID, 'tts_request_timeout', strval($requestTimeout), lazy: true); + } + /** * Set the admin config for the settings page * @param mixed[] $adminConfig @@ -823,10 +1203,7 @@ public function setAdminConfig(array $adminConfig): void { $this->setRequestTimeout($adminConfig['request_timeout']); } if (isset($adminConfig['url'])) { - if (str_ends_with($adminConfig['url'], '/')) { - $adminConfig['url'] = substr($adminConfig['url'], 0, -1) ?: $adminConfig['url']; - } - $this->setServiceUrl($adminConfig['url']); + $this->setServiceUrl(rtrim($adminConfig['url'], ' /')); } if (isset($adminConfig['service_name'])) { $this->setServiceName($adminConfig['service_name']); @@ -909,6 +1286,72 @@ public function setAdminConfig(array $adminConfig): void { if (isset($adminConfig['tts_voices'])) { $this->setAdminTtsVoices($adminConfig['tts_voices']); } + + if (isset($adminConfig['image_url'])) { + $this->setImageServiceUrl(rtrim($adminConfig['image_url'], ' /')); + } + if (isset($adminConfig['image_service_name'])) { + $this->setImageServiceName($adminConfig['image_service_name']); + } + if (isset($adminConfig['image_api_key'])) { + $this->setAdminImageApiKey($adminConfig['image_api_key']); + } + if (isset($adminConfig['image_basic_user'])) { + $this->setAdminImageBasicUser($adminConfig['image_basic_user']); + } + if (isset($adminConfig['image_basic_password'])) { + $this->setAdminImageBasicPassword($adminConfig['image_basic_password']); + } + if (isset($adminConfig['image_use_basic_auth'])) { + $this->setAdminImageUseBasicAuth($adminConfig['image_use_basic_auth']); + } + if (isset($adminConfig['image_request_timeout'])) { + $this->setImageRequestTimeout($adminConfig['image_request_timeout']); + } + + if (isset($adminConfig['stt_url'])) { + $this->setSttServiceUrl(rtrim($adminConfig['stt_url'], ' /')); + } + if (isset($adminConfig['stt_service_name'])) { + $this->setSttServiceName($adminConfig['stt_service_name']); + } + if (isset($adminConfig['stt_api_key'])) { + $this->setAdminSttApiKey($adminConfig['stt_api_key']); + } + if (isset($adminConfig['stt_basic_user'])) { + $this->setAdminSttBasicUser($adminConfig['stt_basic_user']); + } + if (isset($adminConfig['stt_basic_password'])) { + $this->setAdminSttBasicPassword($adminConfig['stt_basic_password']); + } + if (isset($adminConfig['stt_use_basic_auth'])) { + $this->setAdminSttUseBasicAuth($adminConfig['stt_use_basic_auth']); + } + if (isset($adminConfig['stt_request_timeout'])) { + $this->setSttRequestTimeout($adminConfig['stt_request_timeout']); + } + + if (isset($adminConfig['tts_url'])) { + $this->setTtsServiceUrl(rtrim($adminConfig['tts_url'], ' /')); + } + if (isset($adminConfig['tts_service_name'])) { + $this->setTtsServiceName($adminConfig['tts_service_name']); + } + if (isset($adminConfig['tts_api_key'])) { + $this->setAdminTtsApiKey($adminConfig['tts_api_key']); + } + if (isset($adminConfig['tts_basic_user'])) { + $this->setAdminTtsBasicUser($adminConfig['tts_basic_user']); + } + if (isset($adminConfig['tts_basic_password'])) { + $this->setAdminTtsBasicPassword($adminConfig['tts_basic_password']); + } + if (isset($adminConfig['tts_use_basic_auth'])) { + $this->setAdminTtsUseBasicAuth($adminConfig['tts_use_basic_auth']); + } + if (isset($adminConfig['tts_request_timeout'])) { + $this->setTtsRequestTimeout($adminConfig['tts_request_timeout']); + } } /** @@ -1010,4 +1453,25 @@ public function setAnalyzeImageProviderEnabled(bool $enabled): void { public function setChatEndpointEnabled(bool $enabled): void { $this->appConfig->setValueString(Application::APP_ID, 'chat_endpoint_enabled', $enabled ? '1' : '0', lazy: true); } + + /** + * @return bool + */ + public function imageOverrideEnabled(): bool { + return !empty($this->getImageServiceUrl()); + } + + /** + * @return bool + */ + public function sttOverrideEnabled(): bool { + return !empty($this->getSttServiceUrl()); + } + + /** + * @return bool + */ + public function ttsOverrideEnabled(): bool { + return !empty($this->getTtsServiceUrl()); + } } diff --git a/lib/Settings/Admin.php b/lib/Settings/Admin.php index 05c3442e..af3b8b2d 100644 --- a/lib/Settings/Admin.php +++ b/lib/Settings/Admin.php @@ -31,6 +31,12 @@ public function getForm(): TemplateResponse { $adminConfig = $this->openAiSettingsService->getAdminConfig(); $adminConfig['api_key'] = $adminConfig['api_key'] === '' ? '' : 'dummyApiKey'; $adminConfig['basic_password'] = $adminConfig['basic_password'] === '' ? '' : 'dummyPassword'; + $adminConfig['image_api_key'] = $adminConfig['image_api_key'] === '' ? '' : 'dummyApiKey'; + $adminConfig['image_basic_password'] = $adminConfig['image_basic_password'] === '' ? '' : 'dummyPassword'; + $adminConfig['stt_api_key'] = $adminConfig['stt_api_key'] === '' ? '' : 'dummyApiKey'; + $adminConfig['stt_basic_password'] = $adminConfig['stt_basic_password'] === '' ? '' : 'dummyPassword'; + $adminConfig['tts_api_key'] = $adminConfig['tts_api_key'] === '' ? '' : 'dummyApiKey'; + $adminConfig['tts_basic_password'] = $adminConfig['tts_basic_password'] === '' ? '' : 'dummyPassword'; $isAssistantEnabled = $this->appManager->isEnabledForUser('assistant'); $adminConfig['assistant_enabled'] = $isAssistantEnabled; $adminConfig['quota_start_date'] = $this->openAiSettingsService->getQuotaStart(); diff --git a/lib/TaskProcessing/AudioToTextProvider.php b/lib/TaskProcessing/AudioToTextProvider.php index 2f0d9fc9..c9446345 100644 --- a/lib/TaskProcessing/AudioToTextProvider.php +++ b/lib/TaskProcessing/AudioToTextProvider.php @@ -38,7 +38,7 @@ public function getId(): string { } public function getName(): string { - return $this->openAiAPIService->getServiceName(); + return $this->openAiAPIService->getServiceName(Application::SERVICE_TYPE_STT); } public function getTaskTypeId(): string { diff --git a/lib/TaskProcessing/TextToImageProvider.php b/lib/TaskProcessing/TextToImageProvider.php index 5b0e9129..44632f33 100644 --- a/lib/TaskProcessing/TextToImageProvider.php +++ b/lib/TaskProcessing/TextToImageProvider.php @@ -41,9 +41,7 @@ public function getId(): string { } public function getName(): string { - return $this->openAiAPIService->isUsingOpenAi() - ? $this->l->t('OpenAI\'s DALL-E 2') - : $this->openAiAPIService->getServiceName(); + return $this->openAiAPIService->getServiceName(Application::SERVICE_TYPE_IMAGE); } public function getTaskTypeId(): string { @@ -82,12 +80,12 @@ public function getOptionalInputShape(): array { public function getOptionalInputShapeEnumValues(): array { return [ - 'model' => $this->openAiAPIService->getModelEnumValues($this->userId), + 'model' => $this->openAiAPIService->getModelEnumValues($this->userId, serviceType: Application::SERVICE_TYPE_IMAGE), ]; } public function getOptionalInputShapeDefaults(): array { - $adminModel = $this->openAiAPIService->isUsingOpenAi() + $adminModel = $this->openAiAPIService->isUsingOpenAi(Application::SERVICE_TYPE_IMAGE) ? ($this->appConfig->getValueString(Application::APP_ID, 'default_image_model_id', Application::DEFAULT_MODEL_ID, lazy: true) ?: Application::DEFAULT_MODEL_ID) : $this->appConfig->getValueString(Application::APP_ID, 'default_image_model_id', lazy: true); return [ diff --git a/lib/TaskProcessing/TextToSpeechProvider.php b/lib/TaskProcessing/TextToSpeechProvider.php index b02d2e69..0f854ad5 100644 --- a/lib/TaskProcessing/TextToSpeechProvider.php +++ b/lib/TaskProcessing/TextToSpeechProvider.php @@ -38,9 +38,7 @@ public function getId(): string { } public function getName(): string { - return $this->openAiAPIService->isUsingOpenAi() - ? $this->l->t('OpenAI\'s Text to Speech') - : $this->openAiAPIService->getServiceName(); + return $this->openAiAPIService->getServiceName(Application::SERVICE_TYPE_TTS); } public function getTaskTypeId(): string { @@ -77,7 +75,7 @@ public function getOptionalInputShape(): array { ), 'speed' => new ShapeDescriptor( $this->l->t('Speed'), - $this->openAiAPIService->isUsingOpenAi() + $this->openAiAPIService->isUsingOpenAi(Application::SERVICE_TYPE_TTS) ? $this->l->t('Speech speed modifier (Valid values: 0.25-4)') : $this->l->t('Speech speed modifier'), EShapeType::Number @@ -89,7 +87,7 @@ public function getOptionalInputShapeEnumValues(): array { $voices = json_decode($this->appConfig->getValueString(Application::APP_ID, 'tts_voices', lazy: true)) ?: Application::DEFAULT_SPEECH_VOICES; return [ 'voice' => array_map(function ($v) { return new ShapeEnumValue($v, $v); }, $voices), - 'model' => $this->openAiAPIService->getModelEnumValues($this->userId), + 'model' => $this->openAiAPIService->getModelEnumValues($this->userId, Application::SERVICE_TYPE_TTS), ]; } @@ -143,7 +141,7 @@ public function process(?string $userId, array $input, callable $reportProgress, $speed = 1; if (isset($input['speed']) && is_numeric($input['speed'])) { $speed = $input['speed']; - if ($this->openAiAPIService->isUsingOpenAi()) { + if ($this->openAiAPIService->isUsingOpenAi(Application::SERVICE_TYPE_TTS)) { if ($speed > 4) { $speed = 4; } elseif ($speed < 0.25) { diff --git a/src/components/AdminSettings.vue b/src/components/AdminSettings.vue index da93ea1f..65d41177 100644 --- a/src/components/AdminSettings.vue +++ b/src/components/AdminSettings.vue @@ -232,7 +232,7 @@ v-model="selectedModel.text" class="model-select" :clearable="state.default_completion_model_id !== DEFAULT_MODEL_ITEM.id" - :options="formattedModels" + :options="formattedModels(models)" :input-label="t('integration_openai', 'Default completion model to use')" :no-wrap="true" input-id="openai-model-select" @@ -331,13 +331,23 @@

{{ t('integration_openai', 'Image generation') }}

-
+
{{ t('integration_openai', 'Audio transcription') }} -
+
{{ t('integration_openai', 'Text to speech') }} -
+
{ + }, + + mounted() { + if (this.configured) { + this.getAllModels(false) + } + this.loadQuotaInfo() + }, + + methods: { + formattedModels(models) { + if (models) { + return models.map(m => { return { id: m.id, value: m.id, @@ -726,16 +771,6 @@ export default { } return [] }, - }, - - mounted() { - if (this.configured) { - this.getModels(false) - } - this.loadQuotaInfo() - }, - - methods: { modelToNcSelectObject(model) { return { id: model.id, @@ -760,69 +795,85 @@ export default { console.error(error) }) }, + async getAllModels(shouldSave = true) { + const models = this.getModels() // Gets the default models. getModels returns a promise + console.debug(this.models) + const [imageModels, sttModels, ttsModels] = await Promise.all([ + this.state.image_url === '' ? models : this.getModels('image'), + this.state.stt_url === '' ? models : this.getModels('stt'), + this.state.tts_url === '' ? models : this.getModels('tts'), + ]) + this.models = await models + this.imageModels = imageModels + this.sttModels = sttModels + this.ttsModels = ttsModels - getModels(shouldSave = true) { - this.models = null - if (!this.configured) { - return - } - const url = generateUrl('/apps/integration_openai/models') - return axios.get(url) - .then((response) => { - this.models = response.data?.data ?? [] - if (this.isUsingOpenAI) { - this.models.unshift(DEFAULT_MODEL_ITEM) - } - const defaultCompletionModelId = this.state.default_completion_model_id || response.data?.default_completion_model_id - const completionModelToSelect = this.models.find(m => m.id === defaultCompletionModelId) - || this.models.find(m => m.id === 'gpt-4.1-mini') - || this.models[1] - || this.models[0] + const defaultCompletionModelId = this.state.default_completion_model_id + const completionModelToSelect = this.models.find(m => m.id === defaultCompletionModelId) + || this.models.find(m => m.id === 'gpt-4.1-mini') + || this.models[1] + || this.models[0] - const defaultImageModelId = this.state.default_image_model_id || response.data?.default_image_model_id - const imageModelToSelect = this.models.find(m => m.id === defaultImageModelId) - || this.models.find(m => m.id === 'dall-e-2') - || this.models[1] - || this.models[0] + const defaultImageModelId = this.state.default_image_model_id + const imageModelToSelect = this.imageModels.find(m => m.id === defaultImageModelId) + || this.imageModels.find(m => m.id === 'dall-e-2') + || this.imageModels[1] + || this.imageModels[0] - const defaultSttModelId = this.state.default_stt_model_id || response.data?.default_stt_model_id - const sttModelToSelect = this.models.find(m => m.id === defaultSttModelId) - || this.models.find(m => m.id.match(/whisper/i)) - || this.models[1] - || this.models[0] + const defaultSttModelId = this.state.default_stt_model_id + const sttModelToSelect = this.sttModels.find(m => m.id === defaultSttModelId) + || this.sttModels.find(m => m.id.match(/whisper/i)) + || this.sttModels[1] + || this.sttModels[0] - const defaultTtsModelId = this.state.default_tts_model_id || response.data?.default_tts_model_id - const ttsModelToSelect = this.models.find(m => m.id === defaultTtsModelId) - || this.models.find(m => m.id.match(/tts/i)) - || this.models[1] - || this.models[0] + const defaultTtsModelId = this.state.default_tts_model_id + const ttsModelToSelect = this.ttsModels.find(m => m.id === defaultTtsModelId) + || this.ttsModels.find(m => m.id.match(/tts/i)) + || this.ttsModels[1] + || this.ttsModels[0] - this.selectedModel.text = this.modelToNcSelectObject(completionModelToSelect) - this.selectedModel.image = this.modelToNcSelectObject(imageModelToSelect) - this.selectedModel.stt = this.modelToNcSelectObject(sttModelToSelect) - this.selectedModel.tts = this.modelToNcSelectObject(ttsModelToSelect) + this.selectedModel.text = this.modelToNcSelectObject(completionModelToSelect) + this.selectedModel.image = this.modelToNcSelectObject(imageModelToSelect) + this.selectedModel.stt = this.modelToNcSelectObject(sttModelToSelect) + this.selectedModel.tts = this.modelToNcSelectObject(ttsModelToSelect) - // save if url/credentials were changed OR if the values are not up-to-date in the stored settings - if (shouldSave - || this.state.default_completion_model_id !== this.selectedModel.text.id - || this.state.default_image_model_id !== this.selectedModel.image.id) { - this.saveOptions({ - default_completion_model_id: this.selectedModel.text.id, - default_image_model_id: this.selectedModel.image.id, - }, false) - } + // save if url/credentials were changed OR if the values are not up-to-date in the stored settings + if (shouldSave + || this.state.default_completion_model_id !== this.selectedModel.text.id + || this.state.default_image_model_id !== this.selectedModel.image.id + || this.state.default_stt_model_id !== this.selectedModel.stt.id + || this.state.default_tts_model_id !== this.selectedModel.tts.id) { + this.saveOptions({ + default_completion_model_id: this.selectedModel.text.id, + default_image_model_id: this.selectedModel.image.id, + default_stt_model_id: this.selectedModel.stt.id, + default_tts_model_id: this.selectedModel.tts.id, + }, false) + } - this.state.default_completion_model_id = completionModelToSelect.id - this.state.default_image_model_id = imageModelToSelect.id - }) - .catch((error) => { - showError( - t('integration_openai', 'Failed to load models') + this.state.default_completion_model_id = completionModelToSelect.id + this.state.default_image_model_id = imageModelToSelect.id + this.state.default_stt_model_id = sttModelToSelect.id + this.state.default_tts_model_id = ttsModelToSelect.id + }, + getModels(serviceType = '') { + const url = generateUrl('/apps/integration_openai/models') + return axios.get(url, { + params: { serviceType }, + }).then((response) => { + const result = response.data?.data ?? [] + if (this.isUsingOpenAI) { + result.unshift(DEFAULT_MODEL_ITEM) + } + return result + }).catch((error) => { + showError( + t('integration_openai', 'Failed to load models') + ': ' + this.reduceStars(error.response?.data?.error), - { timeout: 10000 }, - ) - console.error(error) - }) + { timeout: 10000 }, + ) + console.error(error) + }) }, onModelSelected(type, selected) { console.debug(`Selected model: ${type}: ${selected}`) @@ -878,17 +929,28 @@ export default { capitalizedWord(word) { return word.charAt(0).toUpperCase() + word.slice(1) }, + handleOverrideInput(values) { + Object.assign(this.state, values) + this.onInput() + }, + handleOverrideSensitiveInput(values, getModels = true) { + Object.assign(this.state, values) + this.onSensitiveInput(getModels) + }, async onCheckboxChanged(newValue, key, getModels = true, sensitive = false) { this.state[key] = newValue await this.saveOptions({ [key]: this.state[key] }, sensitive) if (getModels) { - this.getModels() + this.getAllModels() } }, onSensitiveInput: debounce(async function(getModels = true) { const values = { basic_user: (this.state.basic_user ?? '').trim(), url: (this.state.url ?? '').trim(), + image_url: (this.state.image_url ?? '').trim(), + stt_url: (this.state.stt_url ?? '').trim(), + tts_url: (this.state.tts_url ?? '').trim(), } if (this.state.api_key !== 'dummyApiKey') { values.api_key = (this.state.api_key ?? '').trim() @@ -896,9 +958,29 @@ export default { if (this.state.basic_password !== 'dummyPassword') { values.basic_password = (this.state.basic_password ?? '').trim() } + + if (this.state.image_api_key !== 'dummyApiKey') { + values.image_api_key = (this.state.image_api_key ?? '').trim() + } + if (this.state.image_basic_password !== 'dummyPassword') { + values.image_basic_password = (this.state.image_basic_password ?? '').trim() + } + if (this.state.stt_api_key !== 'dummyApiKey') { + values.stt_api_key = (this.state.stt_api_key ?? '').trim() + } + if (this.state.stt_basic_password !== 'dummyPassword') { + values.stt_basic_password = (this.state.stt_basic_password ?? '').trim() + } + if (this.state.tts_api_key !== 'dummyApiKey') { + values.tts_api_key = (this.state.tts_api_key ?? '').trim() + } + if (this.state.tts_basic_password !== 'dummyPassword') { + values.tts_basic_password = (this.state.tts_basic_password ?? '').trim() + } + await this.saveOptions(values, true) if (getModels) { - this.getModels() + this.getAllModels() } this.autoDetectFeatures() }, 2000), @@ -917,6 +999,15 @@ export default { tts_voices: this.state.tts_voices, default_tts_voice: this.state.default_tts_voice, usage_storage_time: this.state.usage_storage_time, + image_service_name: this.state.image_service_name, + image_request_timeout: this.state.image_request_timeout, + image_use_basic_auth: this.state.image_use_basic_auth, + stt_service_name: this.state.stt_service_name, + stt_request_timeout: this.state.stt_request_timeout, + stt_use_basic_auth: this.state.stt_use_basic_auth, + tts_service_name: this.state.tts_service_name, + tts_request_timeout: this.state.tts_request_timeout, + tts_use_basic_auth: this.state.tts_use_basic_auth, } await this.saveOptions(values, false) }, 2000), diff --git a/src/components/ServiceOverridePanel.vue b/src/components/ServiceOverridePanel.vue new file mode 100644 index 00000000..c744d00d --- /dev/null +++ b/src/components/ServiceOverridePanel.vue @@ -0,0 +1,341 @@ + + + + + + diff --git a/tests/unit/Service/ServiceOverrideTest.php b/tests/unit/Service/ServiceOverrideTest.php new file mode 100644 index 00000000..fe89c951 --- /dev/null +++ b/tests/unit/Service/ServiceOverrideTest.php @@ -0,0 +1,254 @@ +createUser(self::TEST_USER1, self::TEST_USER1); + \OCP\Server::get(\OCP\IUserManager::class)->registerBackend($backend); + } + + protected function setUp(): void { + parent::setUp(); + + $this->loginAsUser(self::TEST_USER1); + + $this->openAiSettingsService = \OCP\Server::get(OpenAiSettingsService::class); + + $this->chunkService = \OCP\Server::get(ChunkService::class); + + $this->quotaUsageMapper = \OCP\Server::get(QuotaUsageMapper::class); + + // We'll hijack the client service and subsequently iClient to return a mock response from the OpenAI API + $clientService = $this->createMock(IClientService::class); + $this->iClient = $this->createMock(IClient::class); + $clientService->method('newClient')->willReturn($this->iClient); + + $this->openAiApiService = new OpenAiAPIService( + \OCP\Server::get(\Psr\Log\LoggerInterface::class), + $this->createMock(\OCP\IL10N::class), + \OCP\Server::get(IAppConfig::class), + \OCP\Server::get(ICacheFactory::class), + \OCP\Server::get(QuotaUsageMapper::class), + $this->openAiSettingsService, + $this->createMock(\OCP\Notification\IManager::class), + \OCP\Server::get(QuotaRuleService::class), + $clientService, + ); + } + + public static function tearDownAfterClass(): void { + // Delete quota usage for test user + $quotaUsageMapper = \OCP\Server::get(QuotaUsageMapper::class); + try { + $quotaUsageMapper->deleteUserQuotaUsages(self::TEST_USER1); + } catch (\OCP\Db\Exception|\RuntimeException|\Exception|\Throwable $e) { + // Ignore + } + + $backend = new \Test\Util\User\Dummy(); + $backend->deleteUser(self::TEST_USER1); + \OCP\Server::get(\OCP\IUserManager::class)->removeBackend($backend); + + $openAiSettingsService = \OCP\Server::get(OpenAiSettingsService::class); + $openAiSettingsService->setImageServiceUrl(''); + $openAiSettingsService->setTtsServiceUrl(''); + $openAiSettingsService->setSttServiceUrl(''); + + parent::tearDownAfterClass(); + } + + public function testTextToSpeechProvider(): void { + $this->openAiSettingsService->setTtsServiceUrl(self::OVERRIDE_SPEECH_BASE); + $this->openAiSettingsService->setAdminTtsApiKey(self::APIKEY_SPEECH); + $this->openAiSettingsService->setTtsRequestTimeout(self::REQUEST_TIMEOUT_SPEECH); + + $TTSProvider = new TextToSpeechProvider( + $this->openAiApiService, + $l10n = $this->createMock(\OCP\IL10N::class), + $this->createMock(\Psr\Log\LoggerInterface::class), + \OCP\Server::get(IAppConfig::class), + self::TEST_USER1, + \OCP\Server::get(WatermarkingService::class), + ); + + $inputText = 'This is a test prompt'; + + $response = file_get_contents(__DIR__ . '/../../res/speech.mp3'); + + if (!$response) { + throw new \RuntimeException('Could not read test resourcce `speech.mp3`'); + } + + $url = self::OVERRIDE_SPEECH_BASE . 'audio/speech'; + + $options = ['timeout' => self::REQUEST_TIMEOUT_SPEECH, 'headers' => ['User-Agent' => Application::USER_AGENT, 'Authorization' => 'Bearer ' . self::APIKEY_SPEECH, 'Content-Type' => 'application/json'], 'nextcloud' => ['allow_local_address' => true]]; + $options['body'] = json_encode([ + 'input' => $inputText, + 'voice' => Application::DEFAULT_SPEECH_VOICE, + 'model' => Application::DEFAULT_SPEECH_MODEL_ID, + 'response_format' => 'mp3', + 'speed' => 1, + ]); + + $iResponse = $this->createMock(\OCP\Http\Client\IResponse::class); + $iResponse->method('getBody')->willReturn($response); + $iResponse->method('getStatusCode')->willReturn(200); + + $this->iClient->expects($this->once())->method('post')->with($url, $options)->willReturn($iResponse); + + $TTSProvider->process(self::TEST_USER1, ['input' => $inputText], fn () => null, includeWatermark: false); + } + + public function testTextToImageProvider(): void { + $this->openAiSettingsService->setImageServiceUrl(self::OVERRIDE_IMAGE_BASE); + $this->openAiSettingsService->setAdminImageApiKey(self::APIKEY_IMAGE); + $this->openAiSettingsService->setImageRequestTimeout(self::REQUEST_TIMEOUT_IMAGE); + + $TextToImageProvider = new TextToImageProvider( + $this->openAiApiService, + $this->createMock(\OCP\IL10N::class), + $this->createMock(\Psr\Log\LoggerInterface::class), + \OCP\Server::get(IClientService::class), + \OCP\Server::get(IAppConfig::class), + self::TEST_USER1, + \OCP\Server::get(WatermarkingService::class), + ); + + $inputText = 'This is a test prompt'; + + $responseImage = file_get_contents(__DIR__ . '/../../res/trees.jpg'); + + if (!$responseImage) { + throw new \RuntimeException('Could not read test resourcce `trees.jpg`'); + } + + $response = json_encode([ + 'data' => [ + [ + 'b64_json' => base64_encode($responseImage), + ] + ] + ]); + + $url = self::OVERRIDE_IMAGE_BASE . 'images/generations'; + + $options = ['timeout' => self::REQUEST_TIMEOUT_IMAGE, 'headers' => ['User-Agent' => Application::USER_AGENT, 'Authorization' => 'Bearer ' . self::APIKEY_IMAGE, 'Content-Type' => 'application/json'], 'nextcloud' => ['allow_local_address' => true]]; + $options['body'] = json_encode([ + 'prompt' => $inputText, + 'size' => '1024x1024', + 'n' => 1, + 'model' => Application::DEFAULT_IMAGE_MODEL_ID, + ]); + + $iResponse = $this->createMock(\OCP\Http\Client\IResponse::class); + $iResponse->method('getHeader')->with('Content-Type')->willReturn('application/json'); + $iResponse->method('getBody')->willReturn($response); + $iResponse->method('getStatusCode')->willReturn(200); + + $this->iClient->expects($this->once())->method('post')->with($url, $options)->willReturn($iResponse); + + $TextToImageProvider->process(self::TEST_USER1, ['input' => $inputText, 'numberOfImages' => 1], fn () => null); + } + + public function testAudioToTextProvider(): void { + $this->openAiSettingsService->setSttServiceUrl(self::OVERRIDE_TRANSCRIPTION_BASE); + $this->openAiSettingsService->setAdminSttApiKey(self::APIKEY_TRANSCRIPTION); + $this->openAiSettingsService->setSttRequestTimeout(self::REQUEST_TIMEOUT_TRANSCRIPTION); + + $audioToTextProvider = new AudioToTextProvider( + $this->openAiApiService, + $this->createMock(\Psr\Log\LoggerInterface::class), + \OCP\Server::get(IAppConfig::class), + $this->createMock(\OCP\IL10N::class), + ); + + $file = $this->createMock(\OCP\Files\File::class); + + + $inputSpeech = file_get_contents(__DIR__ . '/../../res/speech.mp3'); + + if (!$inputSpeech) { + throw new \RuntimeException('Could not read test resource `speech.mp3`'); + } + $file->method('isReadable')->willReturn(true); + $file->method('getContent')->willReturn($inputSpeech); + + $response = json_encode([ + 'text' => 'Transcribed text' + ]); + + $url = self::OVERRIDE_TRANSCRIPTION_BASE . 'audio/transcriptions'; + + $options = ['timeout' => self::REQUEST_TIMEOUT_TRANSCRIPTION, 'headers' => ['User-Agent' => Application::USER_AGENT, 'Authorization' => 'Bearer ' . self::APIKEY_TRANSCRIPTION], 'nextcloud' => ['allow_local_address' => true]]; + $options['multipart'] = [ + ['name' => 'model', 'contents' => Application::DEFAULT_TRANSCRIPTION_MODEL_ID], + ['name' => 'file', 'contents' => $inputSpeech, 'filename' => 'file.mp3'], + ['name' => 'response_format', 'contents' => 'verbose_json'], + ]; + $iResponse = $this->createMock(\OCP\Http\Client\IResponse::class); + $iResponse->method('getHeader')->with('Content-Type')->willReturn('application/json'); + $iResponse->method('getBody')->willReturn($response); + $iResponse->method('getStatusCode')->willReturn(200); + + $this->iClient->expects($this->once())->method('post')->with($url, $options)->willReturn($iResponse); + + $audioToTextProvider->process(self::TEST_USER1, ['input' => $file], fn () => null); + } + +}