From 844f3b59b73f1ea176f1b771baa8a0011a2ad4d0 Mon Sep 17 00:00:00 2001 From: Kais NEFFATI Date: Sat, 20 Apr 2024 09:22:02 +0200 Subject: [PATCH] [FEATURE] Support OpenAI Audio API --- README.md | 53 +++++++++ .../ai4j/openai4j/BytesConverterFactory.java | 48 ++++++++ .../ai4j/openai4j/DefaultOpenAiClient.java | 9 ++ .../java/dev/ai4j/openai4j/FilePersistor.java | 7 +- .../java/dev/ai4j/openai4j/OpenAiApi.java | 9 ++ .../java/dev/ai4j/openai4j/OpenAiClient.java | 4 + .../openai4j/PersistorConverterFactory.java | 12 ++ .../openai4j/audio/GenerateSpeechRequest.java | 109 ++++++++++++++++++ .../audio/GenerateSpeechResponse.java | 86 ++++++++++++++ .../dev/ai4j/openai4j/audio/SpeechModel.java | 68 +++++++++++ .../dev.ai4j/openai4j/reflect-config.json | 18 +++ .../dev/ai4j/openai4j/FilePersistorTest.java | 13 +++ .../openai4j/audio/SpeechGenerationTest.java | 65 +++++++++++ 13 files changed, 499 insertions(+), 2 deletions(-) create mode 100644 src/main/java/dev/ai4j/openai4j/BytesConverterFactory.java create mode 100644 src/main/java/dev/ai4j/openai4j/audio/GenerateSpeechRequest.java create mode 100644 src/main/java/dev/ai4j/openai4j/audio/GenerateSpeechResponse.java create mode 100644 src/main/java/dev/ai4j/openai4j/audio/SpeechModel.java create mode 100644 src/test/java/dev/ai4j/openai4j/audio/SpeechGenerationTest.java diff --git a/README.md b/README.md index e46f93c..9fde0ac 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,8 @@ This is an unofficial Java client library that helps to connect your Java applic - [synchronous](https://github.com/ai-for-java/openai4j#synchronously-3) - [asynchronous](https://github.com/ai-for-java/openai4j#asynchronously-3) - [Functions](https://github.com/ai-for-java/openai4j/blob/main/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionTest.java) +- [Audio](https://platform.openai.com/docs/api-reference/audio) + - [Speech](https://platform.openai.com/docs/api-reference/audio/createSpeech) ## Coming soon: @@ -380,6 +382,57 @@ 
Customizable way: String localImage = response.data().get(0).url(); ``` + +## Audio Generations +### Create speech + +Simple way: + +``` + OpenAiClient client = OpenAiClient + .builder() + .openAiApiKey(System.getenv("OPENAI_API_KEY")) + .build(); + + GenerateSpeechRequest request = GenerateSpeechRequest + .builder() + .model(SpeechModel.TTS_1) + .input("The quick brown fox jumped over the lazy dog.") + .voice(SpeechModel.Voice.ALLOY) + .build(); + + GenerateSpeechResponse response = client.speechGeneration(request).execute(); + + // Byte array audio speech generated + byte[] speechData = response.data(); +``` + +Customizable way: + +``` + OpenAiClient client = OpenAiClient + .builder() + .openAiApiKey(System.getenv("OPENAI_API_KEY")) + .logRequests() + .logResponses() + .withPersisting() + .build(); + + GenerateSpeechRequest request = GenerateSpeechRequest + .builder() + .model(SpeechModel.TTS_1) + .input("The quick brown fox jumped over the lazy dog.") + .voice(SpeechModel.Voice.ALLOY) + .responseFormat(SpeechModel.ResponseFormat.WAV) + .speed(2) + .build(); + + GenerateSpeechResponse response = client.speechGeneration(request).execute(); + + // your generated audio speech is here locally: + URI speechUrl = response.url(); +``` + # Useful materials - How to get best results form AI: https://www.deeplearning.ai/short-courses/chatgpt-prompt-engineering-for-developers/ diff --git a/src/main/java/dev/ai4j/openai4j/BytesConverterFactory.java b/src/main/java/dev/ai4j/openai4j/BytesConverterFactory.java new file mode 100644 index 0000000..d452fc2 --- /dev/null +++ b/src/main/java/dev/ai4j/openai4j/BytesConverterFactory.java @@ -0,0 +1,48 @@ +package dev.ai4j.openai4j; + +import dev.ai4j.openai4j.audio.GenerateSpeechResponse; +import okhttp3.ResponseBody; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import retrofit2.Converter; +import retrofit2.Retrofit; + +import java.io.IOException; +import java.lang.annotation.Annotation; +import java.lang.reflect.Type; + +/** + * A converter factory to handle the conversion of Retrofit
response bodies into + * {@link GenerateSpeechResponse} instances. + */ +public class BytesConverterFactory extends Converter.Factory { + + private static final Logger logger = LoggerFactory.getLogger(BytesConverterFactory.class); + + public BytesConverterFactory() { + // Constructor can be utilized for initializing if needed + } + + @Override + public Converter<ResponseBody, ?> responseBodyConverter(Type type, Annotation[] annotations, Retrofit retrofit) { + + logger.debug("Requesting conversion for type: {}", type.getTypeName()); + if (GenerateSpeechResponse.class.equals(type)) { + return responseBody -> { + try { + logger.debug("Converting response body to GenerateSpeechResponse"); + return GenerateSpeechResponse.builder() + .data(responseBody.bytes()) + .build(); + } catch (IOException e) { + logger.error("Failed to read bytes from response body", e); + throw new RuntimeException("Error reading response body", e); + } finally { + responseBody.close(); + } + }; + } + logger.debug("No converter found for type: {}", type.getTypeName()); + return null; + } +} diff --git a/src/main/java/dev/ai4j/openai4j/DefaultOpenAiClient.java b/src/main/java/dev/ai4j/openai4j/DefaultOpenAiClient.java index 3585090..2849c85 100644 --- a/src/main/java/dev/ai4j/openai4j/DefaultOpenAiClient.java +++ b/src/main/java/dev/ai4j/openai4j/DefaultOpenAiClient.java @@ -1,5 +1,7 @@ package dev.ai4j.openai4j; +import dev.ai4j.openai4j.audio.GenerateSpeechRequest; +import dev.ai4j.openai4j.audio.GenerateSpeechResponse; import dev.ai4j.openai4j.chat.ChatCompletionRequest; import dev.ai4j.openai4j.chat.ChatCompletionResponse; import dev.ai4j.openai4j.completion.CompletionRequest; @@ -96,6 +98,8 @@ private DefaultOpenAiClient(Builder serviceBuilder) { retrofitBuilder.addConverterFactory(new PersistorConverterFactory(serviceBuilder.persistTo)); } + retrofitBuilder.addConverterFactory(new BytesConverterFactory()); + retrofitBuilder.addConverterFactory(GsonConverterFactory.create(GSON)); this.openAiApi =
retrofitBuilder.build().create(OpenAiApi.class); @@ -224,6 +228,11 @@ public SyncOrAsync<GenerateImagesResponse> imagesGeneration(GenerateImagesReques return new RequestExecutor<>(openAiApi.imagesGenerations(request, apiVersion), r -> r); } + @Override + public SyncOrAsync<GenerateSpeechResponse> speechGeneration(GenerateSpeechRequest request) { + return new RequestExecutor<>(openAiApi.speechGenerations(request, apiVersion), r -> r); + } + private String formatUrl(String endpoint) { return baseUrl + endpoint + apiVersionQueryParam(); } diff --git a/src/main/java/dev/ai4j/openai4j/FilePersistor.java b/src/main/java/dev/ai4j/openai4j/FilePersistor.java index c354976..0dbba50 100644 --- a/src/main/java/dev/ai4j/openai4j/FilePersistor.java +++ b/src/main/java/dev/ai4j/openai4j/FilePersistor.java @@ -29,9 +29,12 @@ static Path persistFromUri(URI uri, Path destinationFolder) { public static Path persistFromBase64String(String base64EncodedString, Path destinationFolder) throws IOException { byte[] decodedBytes = Base64.getDecoder().decode(base64EncodedString); - Path destinationFile = destinationFolder.resolve(randomFileName()); + return persistFromByteArray(decodedBytes, destinationFolder); + } - Files.write(destinationFile, decodedBytes, StandardOpenOption.CREATE); + public static Path persistFromByteArray(byte[] bytes, Path destinationFolder) throws IOException { + Path destinationFile = destinationFolder.resolve(randomFileName()); + Files.write(destinationFile, bytes, StandardOpenOption.CREATE); return destinationFile; } diff --git a/src/main/java/dev/ai4j/openai4j/OpenAiApi.java b/src/main/java/dev/ai4j/openai4j/OpenAiApi.java index 950610e..d091e60 100644 --- a/src/main/java/dev/ai4j/openai4j/OpenAiApi.java +++ b/src/main/java/dev/ai4j/openai4j/OpenAiApi.java @@ -1,5 +1,7 @@ package dev.ai4j.openai4j; +import dev.ai4j.openai4j.audio.GenerateSpeechRequest; +import dev.ai4j.openai4j.audio.GenerateSpeechResponse; import dev.ai4j.openai4j.chat.ChatCompletionRequest; import
dev.ai4j.openai4j.chat.ChatCompletionResponse; import dev.ai4j.openai4j.completion.CompletionRequest; @@ -42,4 +44,11 @@ Call<GenerateImagesResponse> imagesGenerations( @Body GenerateImagesRequest request, @Query("api-version") String apiVersion ); + + @POST("audio/speech") + @Headers({ "Content-Type: application/json" }) + Call<GenerateSpeechResponse> speechGenerations( + @Body GenerateSpeechRequest request, + @Query("api-version") String apiVersion + ); } diff --git a/src/main/java/dev/ai4j/openai4j/OpenAiClient.java b/src/main/java/dev/ai4j/openai4j/OpenAiClient.java index 6033627..c72c527 100644 --- a/src/main/java/dev/ai4j/openai4j/OpenAiClient.java +++ b/src/main/java/dev/ai4j/openai4j/OpenAiClient.java @@ -7,6 +7,8 @@ import java.time.Duration; import java.util.List; +import dev.ai4j.openai4j.audio.GenerateSpeechRequest; +import dev.ai4j.openai4j.audio.GenerateSpeechResponse; import dev.ai4j.openai4j.chat.ChatCompletionRequest; import dev.ai4j.openai4j.chat.ChatCompletionResponse; import dev.ai4j.openai4j.completion.CompletionRequest; @@ -44,6 +46,8 @@ public abstract class OpenAiClient { public abstract SyncOrAsync<GenerateImagesResponse> imagesGeneration(GenerateImagesRequest request); + public abstract SyncOrAsync<GenerateSpeechResponse> speechGeneration(GenerateSpeechRequest request); + public abstract void shutdown(); @SuppressWarnings("rawtypes") diff --git a/src/main/java/dev/ai4j/openai4j/PersistorConverterFactory.java b/src/main/java/dev/ai4j/openai4j/PersistorConverterFactory.java index abc4b92..4037246 100644 --- a/src/main/java/dev/ai4j/openai4j/PersistorConverterFactory.java +++ b/src/main/java/dev/ai4j/openai4j/PersistorConverterFactory.java @@ -1,5 +1,6 @@ package dev.ai4j.openai4j; +import dev.ai4j.openai4j.audio.GenerateSpeechResponse; import dev.ai4j.openai4j.image.GenerateImagesResponse; import java.io.IOException; import java.lang.annotation.Annotation; @@ -49,6 +50,17 @@ public T convert(ResponseBody value) throws IOException { }); } + if (response instanceof GenerateSpeechResponse) { + try { + GenerateSpeechResponse
generateSpeechResponse = (GenerateSpeechResponse) response; + generateSpeechResponse.url( + FilePersistor.persistFromByteArray(generateSpeechResponse.data(), persistTo).toUri() + ); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + return response; } } diff --git a/src/main/java/dev/ai4j/openai4j/audio/GenerateSpeechRequest.java b/src/main/java/dev/ai4j/openai4j/audio/GenerateSpeechRequest.java new file mode 100644 index 0000000..8e6d92f --- /dev/null +++ b/src/main/java/dev/ai4j/openai4j/audio/GenerateSpeechRequest.java @@ -0,0 +1,109 @@ +package dev.ai4j.openai4j.audio; + +import java.util.Objects; + +/** + * This class represents a request to generate audio speech using specific parameters. + */ +public class GenerateSpeechRequest { + private final String model; + private final String input; + private final String voice; + private final String responseFormat; + private final double speed; + + private GenerateSpeechRequest(Builder builder) { + this.model = Objects.requireNonNull(builder.model, "Model cannot be null"); + this.input = Objects.requireNonNull(builder.input, "Input cannot be null"); + this.voice = Objects.requireNonNull(builder.voice, "Voice cannot be null"); + this.responseFormat = builder.responseFormat; + this.speed = builder.speed; + } + + // Implementing equals method to ensure correct behavior in collections and other use cases. 
+ @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + GenerateSpeechRequest that = (GenerateSpeechRequest) o; + return Double.compare(that.speed, speed) == 0 && + Objects.equals(model, that.model) && + Objects.equals(input, that.input) && + Objects.equals(voice, that.voice) && + Objects.equals(responseFormat, that.responseFormat); + } + + @Override + public int hashCode() { + return Objects.hash(model, input, voice, responseFormat, speed); + } + + @Override + public String toString() { + return String.format( + "GenerateSpeechRequest{model='%s', input='%s', voice='%s', responseFormat='%s', speed=%s}", + model, input, voice, responseFormat, speed + ); + } + + // Static factory method for the builder, improving code readability. + public static Builder builder() { + return new Builder(); + } + + // Builder class for GenerateSpeechRequest. + public static class Builder { + private String model; + private String input; + private String voice; + private String responseFormat = SpeechModel.ResponseFormat.MP3.toString(); // Default response format + private double speed = 1.0; // Default speed + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder model(SpeechModel model) { + this.model = model.toString(); + return this; + } + + public Builder input(String input) { + this.input = input; + return this; + } + + public Builder voice(String voice) { + this.voice = voice; + return this; + } + + public Builder voice(SpeechModel.Voice voice) { + this.voice = voice.toString(); + return this; + } + + public Builder responseFormat(String responseFormat) { + this.responseFormat = responseFormat; + return this; + } + + public Builder responseFormat(SpeechModel.ResponseFormat responseFormat) { + this.responseFormat = responseFormat.toString(); + return this; + } + + public Builder speed(double speed) { + if (speed <= 0) { + throw new
IllegalArgumentException("Speed must be positive"); + } + this.speed = speed; + return this; + } + + public GenerateSpeechRequest build() { + return new GenerateSpeechRequest(this); + } + } +} diff --git a/src/main/java/dev/ai4j/openai4j/audio/GenerateSpeechResponse.java b/src/main/java/dev/ai4j/openai4j/audio/GenerateSpeechResponse.java new file mode 100644 index 0000000..06c0cdd --- /dev/null +++ b/src/main/java/dev/ai4j/openai4j/audio/GenerateSpeechResponse.java @@ -0,0 +1,86 @@ +package dev.ai4j.openai4j.audio; + +import java.net.URI; +import java.util.Arrays; +import java.util.Objects; + +/** + * Represents the response from the OpenAI Speech API when generating audio. + * For a detailed description of parameters, see <a href="https://platform.openai.com/docs/api-reference/audio/createSpeech">OpenAI Speech Audio API</a>. + */ +public class GenerateSpeechResponse { + + private final byte[] data; + + private URI url; + + public GenerateSpeechResponse(Builder builder) { + this.data = builder.data; + this.url = builder.url; + } + + public static Builder builder() { + return new Builder(); + } + + public byte[] data() { + return data; + } + + public URI url() { + return url; + } + + public void url(URI url) { + this.url = url; + } + + @Override + public String toString() { + return ( + "GenerateSpeechResponse{" + + "url='" + + url + + '\'' + + ", data=" + + (data == null ? "null" : data.length + " bytes") + + '}' + ); + } + + @Override + public boolean equals(Object another) { + if (this == another) return true; + if (another == null || getClass() != another.getClass()) return false; + GenerateSpeechResponse anotherGenerateSpeechResponse = (GenerateSpeechResponse) another; + return Arrays.equals(data, anotherGenerateSpeechResponse.data) + && Objects.equals(url, anotherGenerateSpeechResponse.url); + } + + @Override + public int hashCode() { + int result = Objects.hash(url); + result = 31 * result + Arrays.hashCode(data); + return result; + } + + public static class Builder { + + private byte[] data; + private URI url; + + public Builder data(byte[] data) { + this.data =
data; + return this; + } + + public Builder url(URI url) { + this.url = url; + return this; + } + + public GenerateSpeechResponse build() { + return new GenerateSpeechResponse(this); + } + } +} diff --git a/src/main/java/dev/ai4j/openai4j/audio/SpeechModel.java b/src/main/java/dev/ai4j/openai4j/audio/SpeechModel.java new file mode 100644 index 0000000..aa96eb9 --- /dev/null +++ b/src/main/java/dev/ai4j/openai4j/audio/SpeechModel.java @@ -0,0 +1,68 @@ +package dev.ai4j.openai4j.audio; + +/** + * Enum representing different OpenAI audio speech models. + */ +public enum SpeechModel { + TTS_1("tts-1"), + TTS_1_HD("tts-1-hd"); + + private final String modelId; + + SpeechModel(String modelId) { + this.modelId = modelId; + } + + @Override + public String toString() { + return modelId; + } + + + /** + * Enum representing different voice types for TTS (Text To Speech). + */ + public enum Voice { + ALLOY("alloy"), + ECHO("echo"), + FABLE("fable"), + ONYX("onyx"), + NOVA("nova"), + SHIMMER("shimmer"); + + private final String voiceType; + + Voice(String voiceType) { + this.voiceType = voiceType; + } + + @Override + public String toString() { + return voiceType; + } + } + + /** + * Enum representing different audio response formats. 
+ */ + public enum ResponseFormat { + MP3("mp3"), + OPUS("opus"), + AAC("aac"), + FLAC("flac"), + WAV("wav"), + PCM("pcm"); + + private final String format; + + ResponseFormat(String format) { + this.format = format; + } + + @Override + public String toString() { + return format; + } + } +} + diff --git a/src/main/resources/META-INF/native-image/dev.ai4j/openai4j/reflect-config.json b/src/main/resources/META-INF/native-image/dev.ai4j/openai4j/reflect-config.json index 278cfcf..b98a2f0 100644 --- a/src/main/resources/META-INF/native-image/dev.ai4j/openai4j/reflect-config.json +++ b/src/main/resources/META-INF/native-image/dev.ai4j/openai4j/reflect-config.json @@ -305,6 +305,24 @@ "allDeclaredFields": true, "allPublicFields": true }, + { + "name": "dev.ai4j.openai4j.audio.GenerateSpeechRequest", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.ai4j.openai4j.audio.GenerateSpeechResponse", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, { "name": "dev.ai4j.openai4j.moderation.Categories", "allDeclaredConstructors": true, diff --git a/src/test/java/dev/ai4j/openai4j/FilePersistorTest.java b/src/test/java/dev/ai4j/openai4j/FilePersistorTest.java index f67cca8..3dc2cf9 100644 --- a/src/test/java/dev/ai4j/openai4j/FilePersistorTest.java +++ b/src/test/java/dev/ai4j/openai4j/FilePersistorTest.java @@ -6,6 +6,7 @@ import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; +import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import org.junit.jupiter.api.Test; @@ -52,4 +53,16 @@ void shouldPersistFromBase64String() throws IOException { assertThat(filePath).startsWith(TEMP_DIR); assertThat(filePath).exists().hasContent("Hello 
world!"); } + + @Test + void shouldPersistFromByteArray() throws IOException { + byte[] simpleByteArray = new byte[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; // Sample byte array + + Path filePath = FilePersistor.persistFromByteArray(simpleByteArray, TEMP_DIR); + + assertThat(filePath).isNotNull(); + assertThat(filePath).startsWith(TEMP_DIR); + assertThat(Files.exists(filePath)).isTrue(); + assertThat(Files.readAllBytes(filePath)).isEqualTo(simpleByteArray); + } } diff --git a/src/test/java/dev/ai4j/openai4j/audio/SpeechGenerationTest.java b/src/test/java/dev/ai4j/openai4j/audio/SpeechGenerationTest.java new file mode 100644 index 0000000..60382e5 --- /dev/null +++ b/src/test/java/dev/ai4j/openai4j/audio/SpeechGenerationTest.java @@ -0,0 +1,65 @@ +package dev.ai4j.openai4j.audio; + +import dev.ai4j.openai4j.OpenAiClient; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.net.URI; +import java.util.Arrays; + +import static dev.ai4j.openai4j.audio.SpeechModel.TTS_1; +import static org.assertj.core.api.Assertions.assertThat; + +public class SpeechGenerationTest { + + @Test + void generationShouldWork() { + OpenAiClient client = OpenAiClient + .builder() + .openAiApiKey(System.getenv("OPENAI_API_KEY")) + .logRequests() + .logResponses() + .build(); + + GenerateSpeechRequest request = GenerateSpeechRequest + .builder() + .model(TTS_1) + .voice(SpeechModel.Voice.ALLOY) + .input("Beautiful house on country side") + .build(); + + GenerateSpeechResponse response = client.speechGeneration(request).execute(); + + byte[] audioSpeechData = response.data(); + + System.out.println("Your audio speech is here: " + Arrays.toString(audioSpeechData)); + + assertThat(response.data()).isNotEmpty(); + } + + @Test + void generationWithDownloadShouldWork() { + OpenAiClient client = OpenAiClient + .builder() + .openAiApiKey(System.getenv("OPENAI_API_KEY")) + .logRequests() + .logResponses() + .withPersisting() + .build(); + + GenerateSpeechRequest request = 
GenerateSpeechRequest + .builder() + .model(TTS_1) + .voice(SpeechModel.Voice.ALLOY) + .input("Bird flying in the sky") + .build(); + + GenerateSpeechResponse response = client.speechGeneration(request).execute(); + + URI speechUrl = response.url(); + + System.out.println("Your audio speech url is here: " + speechUrl); + + assertThat(new File(speechUrl)).exists(); + } +}