Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,15 @@ internal class DefaultModelDownloaderClient(

downloadZipFile(url, zipFile).getOrElse { return@flow emit(FileDownloadState.OnFailed(it)) }

println("[AEROEDGE] downloaded ${zipFile.path}")

unzipModelFile(
zipFile,
destinationFile
).getOrElse { return@flow emit(FileDownloadState.OnFailed(it)) }

println("[AEROEDGE] unzipped into ${destinationFile.path}")

emit(FileDownloadState.OnSuccess)
}

Expand Down Expand Up @@ -78,6 +82,7 @@ internal class DefaultModelDownloaderClient(
zipFile: File,
destinationFile: File,
): Result<Unit> = runCatching<Unit> {
var found = false
ZipFile(zipFile).use { zip ->
zip.entries().asSequence().forEach { entry ->
if (entry.name != destinationFile.name) return@forEach
Expand All @@ -90,10 +95,15 @@ internal class DefaultModelDownloaderClient(
bos.write(bytesIn, 0, read)
}
bos.close()
found = true
}
}
}
zipFile.delete()

if(!found) {
throw IllegalStateException("No entry found with the name ${destinationFile.name} into ${zipFile.name}")
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ internal class ModelDownloadWorker(
return Result.failure()
}

println("[AEROEDGE] POSTING ${destinationFile.path}")

return Result.success(workDataOf(WORKER_LOCAL_MODEL_FILE_PATH_KEY to destinationFile.path))
}

Expand Down
14 changes: 13 additions & 1 deletion app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,33 @@ android {
}

dependencies {
implementation(fileTree(mapOf("dir" to "libs", "include" to listOf("*.aar"))))

implementation(project(":aeroedge"))

implementation("androidx.core:core-ktx:1.9.0")
implementation("androidx.core:core-ktx:1.12.0")
implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.2")
implementation("androidx.activity:activity-compose:1.8.0")
implementation("androidx.lifecycle:lifecycle-runtime-compose:2.6.2")
implementation(platform("androidx.compose:compose-bom:2023.03.00"))
implementation("androidx.compose.ui:ui")
implementation("androidx.compose.ui:ui-graphics")
implementation("androidx.compose.ui:ui-tooling-preview")
implementation("androidx.compose.material3:material3")
implementation("androidx.compose.material:material-icons-extended")
implementation("com.google.accompanist:accompanist-systemuicontroller:0.30.0")
implementation("org.tensorflow:tensorflow-lite:2.14.0")
implementation("io.insert-koin:koin-core:3.5.0")
implementation("io.insert-koin:koin-android:3.5.0")
implementation("io.insert-koin:koin-androidx-compose:3.5.0")

testImplementation("junit:junit:4.13.2")

androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
androidTestImplementation(platform("androidx.compose:compose-bom:2023.03.00"))
androidTestImplementation("androidx.compose.ui:ui-test-junit4")

debugImplementation("androidx.compose.ui:ui-tooling")
debugImplementation("androidx.compose.ui:ui-test-manifest")

Expand Down
Binary file added app/libs/tensorflow-lite-select-tf-ops.aar
Binary file not shown.
9 changes: 5 additions & 4 deletions app/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools">

<uses-permission android:name="android.permission.INTERNET" />

<application
android:name=".DemoApplication"
android:allowBackup="true"
android:dataExtractionRules="@xml/data_extraction_rules"
android:fullBackupContent="@xml/backup_rules"
Expand All @@ -13,10 +16,8 @@
android:theme="@style/Theme.AeroEdge"
tools:targetApi="31">
<activity
android:name=".MainActivity"
android:exported="true"
android:label="@string/app_name"
android:theme="@style/Theme.AeroEdge">
android:name=".ui.MainActivity"
android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />

Expand Down
21 changes: 21 additions & 0 deletions app/src/main/java/app/sarama/aeroedge/DemoApplication.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package app.sarama.aeroedge

import android.app.Application
import app.sarama.aeroedge.di.appModule
import app.sarama.aeroedge.di.viewModelModule
import org.koin.android.ext.koin.androidContext
import org.koin.android.ext.koin.androidLogger
import org.koin.core.context.startKoin

class DemoApplication: Application() {

override fun onCreate() {
super.onCreate()
startKoin {
androidLogger()
androidContext(this@DemoApplication)

modules(appModule, viewModelModule)
}
}
}
46 changes: 0 additions & 46 deletions app/src/main/java/app/sarama/aeroedge/MainActivity.kt

This file was deleted.

34 changes: 34 additions & 0 deletions app/src/main/java/app/sarama/aeroedge/di/appModule.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package app.sarama.aeroedge.di

import app.sarama.aeroedge.AeroEdge
import app.sarama.aeroedge.ModelEntity
import app.sarama.aeroedge.ModelServerClient
import app.sarama.aeroedge.service.autocomplete.AutoCompleteService
import app.sarama.aeroedge.service.autocomplete.AutoCompleteServiceImpl
import org.koin.dsl.module
import org.koin.android.ext.koin.androidContext


val appModule = module {
single<AutoCompleteService> {
AutoCompleteServiceImpl(
aeroEdge = get(),
)
}

single {
AeroEdge(
context = androidContext(),
client = object : ModelServerClient {
override suspend fun fetchRemoteModelInfo(modelName: String) = Result.success(
ModelEntity(
name = modelName,
version = 1,
url = "https://gitlab.com/melvin.biamont/test-aeroedge/-/raw/main/autocomplete_1.tflite.zip?ref_type=heads&inline=false",
fileExtension = "tflite"
)
)
}
)
}
}
9 changes: 9 additions & 0 deletions app/src/main/java/app/sarama/aeroedge/di/viewModelModule.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package app.sarama.aeroedge.di

import app.sarama.aeroedge.ui.screen.autocomplete.AutoCompleteViewModel
import org.koin.androidx.viewmodel.dsl.viewModel
import org.koin.dsl.module

val viewModelModule = module {
viewModel { AutoCompleteViewModel(get()) }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package app.sarama.aeroedge.service.autocomplete

data class AutoCompleteInputConfiguration(
// Minimum number of words to be taken from the end of the input text
val minWordCount: Int = 5,
// Maximum number of words to be taken from the end of the input text
val maxWordCount: Int = 50,
// Initially selected value for number of words to be taken from the end of the input text
val initialWordCount: Int = 20,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package app.sarama.aeroedge.service.autocomplete

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.StateFlow

interface AutoCompleteService {

val initializationStatus: InitializationStatus

val inputConfiguration: AutoCompleteInputConfiguration

suspend fun loadModel(scope: CoroutineScope): StateFlow<InitializationStatus>

suspend fun autocomplete(input: String, applyWindow: Boolean = false, windowSize: Int = 50): Result<List<String>>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package app.sarama.aeroedge.service.autocomplete

sealed class AutoCompleteServiceError(message: String): Throwable(message) {

data object ModelAlreadyLoading: AutoCompleteServiceError("Model already loading, no need to load it again.")
data object NoSuggestion: AutoCompleteServiceError("No autocomplete suggestion found.")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package app.sarama.aeroedge.service.autocomplete

import androidx.annotation.WorkerThread
import app.sarama.aeroedge.AeroEdge
import app.sarama.aeroedge.ModelStatusUpdate
import app.sarama.aeroedge.ModelType
import app.sarama.aeroedge.util.splitToWords
import app.sarama.aeroedge.util.trimToMaxWordCount
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.stateIn
import kotlinx.coroutines.withContext
import org.tensorflow.lite.Interpreter
import java.io.File
import java.io.FileInputStream
import java.nio.ByteBuffer
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel
import kotlin.math.min


class AutoCompleteServiceImpl(
private val aeroEdge: AeroEdge,
private val dispatcher: CoroutineDispatcher = Dispatchers.IO,
) : AutoCompleteService {

private val modelStatusFlow =
MutableStateFlow<InitializationStatus>(InitializationStatus.NotInitialized)
private var interpreter: Interpreter? = null
private val outputBuffer = ByteBuffer.allocateDirect(OutputBufferSize)
override val initializationStatus: InitializationStatus
get() = modelStatusFlow.value

override val inputConfiguration = AutoCompleteInputConfiguration(
// Minimum number of words to be taken from the end of the input text
minWordCount = 5,
// Maximum number of words to be taken from the end of the input text, limited by what the model allows
maxWordCount = min(50, MaxInputWordCount),
// Initially selected value for number of words to be taken from the end of the input text
initialWordCount = 20
)

override suspend fun loadModel(scope: CoroutineScope) = aeroEdge
.getModel(ModelName, ModelType.TensorFlowLite)
.onEach {
if(it is ModelStatusUpdate.OnCompleted) {
this.interpreter = Interpreter(it.model.localFile.fileChannel)
println("[AEROEDGE] Interpreter loaded!")
}
}
.map {
when (it) {
is ModelStatusUpdate.OnCompleted -> InitializationStatus.Initialized
is ModelStatusUpdate.InProgress -> InitializationStatus.Initializing(it.progress)
is ModelStatusUpdate.OnFailed -> InitializationStatus.Error(it.exception)
}
}
.stateIn(scope)


private val File.fileChannel: MappedByteBuffer
get() = FileInputStream(this).channel.map(FileChannel.MapMode.READ_ONLY, 0, length())

override suspend fun autocomplete(
input: String,
applyWindow: Boolean,
windowSize: Int,
): Result<List<String>> = withContext(dispatcher) {
val maxInputWordCount = if (applyWindow) windowSize else MaxInputWordCount
val trimmedInput = input.trimToMaxWordCount(maxInputWordCount)

val output = runInterpreterOn(trimmedInput)

if (output.length < trimmedInput.length) {
return@withContext Result.failure(AutoCompleteServiceError.NoSuggestion)
}

val newText = output.substring(output.indexOf(trimmedInput) + trimmedInput.length)
val words = newText.splitToWords()
if (words.isEmpty()) {
return@withContext Result.failure(AutoCompleteServiceError.NoSuggestion)
}

Result.success(words)
}

@WorkerThread
private fun runInterpreterOn(input: String): String {
outputBuffer.clear()

// Run interpreter, which will generate text into outputBuffer
interpreter?.run(input, outputBuffer)

// Set output buffer limit to current position & position to 0
outputBuffer.flip()

// Get bytes from output buffer
val bytes = ByteArray(outputBuffer.remaining())
outputBuffer.get(bytes)

outputBuffer.clear()

// Return bytes converted to String
return String(bytes, Charsets.UTF_8)
}

private companion object {
const val ModelName = "autocomplete"
const val MaxInputWordCount = 1024
const val OutputBufferSize = 800
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package app.sarama.aeroedge.service.autocomplete

sealed class InitializationStatus {

data object NotInitialized: InitializationStatus()
data class Initializing(val progress: Float): InitializationStatus()
data object Initialized: InitializationStatus()
data class Error(val exception: Throwable): InitializationStatus()
}
Loading