diff --git a/build.gradle.kts b/build.gradle.kts index cf65aeae..f22f0ce7 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -66,6 +66,8 @@ dependencies { implementation("org.apache.httpcomponents.client5:httpclient5:5.2.1") implementation("org.apache.tika:tika-core:2.8.0") implementation("org.flywaydb:flyway-core") + implementation("com.squareup.okhttp3:okhttp:4.9.2") + implementation("com.fasterxml.jackson.module:jackson-module-kotlin:2.12.+") implementation("io.github.oshai:kotlin-logging-jvm:5.1.0") testImplementation("org.springframework.security:spring-security-test") testImplementation("org.springframework.boot:spring-boot-starter-test") { diff --git a/src/main/java/ch/uzh/ifi/access/config/ModelMapperConfig.java b/src/main/java/ch/uzh/ifi/access/config/ModelMapperConfig.java index 57d7fdb8..499d6407 100644 --- a/src/main/java/ch/uzh/ifi/access/config/ModelMapperConfig.java +++ b/src/main/java/ch/uzh/ifi/access/config/ModelMapperConfig.java @@ -5,16 +5,15 @@ import ch.uzh.ifi.access.model.Submission; import ch.uzh.ifi.access.model.Task; import ch.uzh.ifi.access.model.constants.Visibility; -import ch.uzh.ifi.access.model.dto.AssignmentDTO; -import ch.uzh.ifi.access.model.dto.CourseDTO; -import ch.uzh.ifi.access.model.dto.SubmissionDTO; -import ch.uzh.ifi.access.model.dto.TaskDTO; +import ch.uzh.ifi.access.model.dto.*; import org.apache.commons.lang3.ObjectUtils; import org.modelmapper.ModelMapper; import org.modelmapper.convention.MatchingStrategies; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import java.util.Objects; + @Configuration public class ModelMapperConfig { // TODO: convert to Kotlin. This is very tricky, because ModelMapper is very Java-specific. Maybe replace it. diff --git a/src/main/kotlin/ch/uzh/ifi/access/config/SecurityConfig.kt b/src/main/kotlin/ch/uzh/ifi/access/config/SecurityConfig.kt index 061f1faa..b94e15fe 100644 --- a/src/main/kotlin/ch/uzh/ifi/access/config/SecurityConfig.kt +++ b/src/main/kotlin/ch/uzh/ifi/access/config/SecurityConfig.kt @@ -122,6 +122,11 @@ class SecurityConfig(private val env: Environment) { return Path.of(env.getProperty("WORKING_DIR", "/workspace/data")) } + @Bean + fun assistantServerUrl(): String { + return env.getProperty("ASSISTANT_SERVER_URL", "http://localhost:4000") + } + @Bean fun accessRealm(): RealmResource { val keycloakClient = Keycloak.getInstance( diff --git a/src/main/kotlin/ch/uzh/ifi/access/controller/CourseController.kt b/src/main/kotlin/ch/uzh/ifi/access/controller/CourseController.kt index 32b107ad..63daa839 100644 --- a/src/main/kotlin/ch/uzh/ifi/access/controller/CourseController.kt +++ b/src/main/kotlin/ch/uzh/ifi/access/controller/CourseController.kt @@ -17,6 +17,7 @@ import java.time.LocalDateTime import java.util.concurrent.Semaphore + @RestController class CourseRootController( private val courseService: CourseService, diff --git a/src/main/kotlin/ch/uzh/ifi/access/model/Task.kt b/src/main/kotlin/ch/uzh/ifi/access/model/Task.kt index 079833b4..81269c33 100644 --- a/src/main/kotlin/ch/uzh/ifi/access/model/Task.kt +++ b/src/main/kotlin/ch/uzh/ifi/access/model/Task.kt @@ -71,6 +71,45 @@ class Task { val attemptRefill: Int? get() = if (Objects.nonNull(attemptWindow)) Math.toIntExact(attemptWindow!!.toSeconds()) else null + @Column(nullable = false) + var llmSubmission: String? = null + + @Column + var llmSolution: String? = null + + @Column + var llmRubrics: String? = null + + @Column(nullable = false) + var llmCot: Boolean = false + + @Column(nullable = false) + var llmVoting: Int = 1 + + @Column + var llmExamples: String? = null + + @Column + var llmPrompt: String? = null + + @Column + var llmPre: String? = null + + @Column + var llmPost: String? = null + + @Column + var llmTemperature: Double? = null + + @Column + var llmModel: String? = null + + @Column + var llmModelFamily: String? = null + + @Column + var llmMaxPoints: Double? = null + fun createFile(): TaskFile { val newTaskFile = TaskFile() files.add(newTaskFile) diff --git a/src/main/kotlin/ch/uzh/ifi/access/model/dto/AssistantDTO.kt b/src/main/kotlin/ch/uzh/ifi/access/model/dto/AssistantDTO.kt new file mode 100644 index 00000000..d0f2a5a7 --- /dev/null +++ b/src/main/kotlin/ch/uzh/ifi/access/model/dto/AssistantDTO.kt @@ -0,0 +1,76 @@ +package ch.uzh.ifi.access.model.dto + +import com.fasterxml.jackson.annotation.JsonProperty +import lombok.Data + +@Data +class RubricDTO ( + @JsonProperty("id") + var id: String, + + @JsonProperty("title") + var title: String, + + @JsonProperty("points") + var points: Double +) + +@Data +class FewShotExampleDTO ( + @JsonProperty("answer") + var answer: String, + + @JsonProperty("points") + var points: String +) + +@Data +class AssistantDTO( + @JsonProperty("question") + var question: String, + + @JsonProperty("answer") + var answer: String, + + @JsonProperty("rubrics") + var rubrics: List? = null, + + @JsonProperty("modelSolution") + var modelSolution: String? = null, + + @JsonProperty("maxPoints") + var maxPoints: Double? = 1.0, + + @JsonProperty("minPoints") + var minPoints: Double? = 0.0, + + @JsonProperty("pointStep") + var pointStep: Double? = 0.5, + + @JsonProperty("chainOfThought") + var chainOfThought: Boolean? = true, + + @JsonProperty("votingCount") + var votingCount: Int? = 1, + + @JsonProperty("temperature") + var temperature: Double? = 0.2, + + @JsonProperty("fewShotExamples") + var fewShotExamples: List? = null, + + @JsonProperty("prePrompt") + var prePrompt: String? = null, + + @JsonProperty("prompt") + var prompt: String? = null, + + @JsonProperty("postPrompt") + var postPrompt: String? = null, + + @JsonProperty("llmType") + var llmType: String? = null, + + @JsonProperty("llmModel") + var llmModel: String? = null +) \ No newline at end of file diff --git a/src/main/kotlin/ch/uzh/ifi/access/model/dto/AssistantResponseDTO.kt b/src/main/kotlin/ch/uzh/ifi/access/model/dto/AssistantResponseDTO.kt new file mode 100644 index 00000000..c091d4da --- /dev/null +++ b/src/main/kotlin/ch/uzh/ifi/access/model/dto/AssistantResponseDTO.kt @@ -0,0 +1,45 @@ +package ch.uzh.ifi.access.model.dto + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties +import com.fasterxml.jackson.annotation.JsonProperty +import lombok.Data + +enum class Status { + correct, + incorrect, + incomplete +} + +@JsonIgnoreProperties(ignoreUnknown = true) +@Data +class AssistantResponseDTO ( + @JsonProperty("status") + var status: Status, + + @JsonProperty("feedback") + var feedback: String, + + @JsonProperty("hint") + var hint: String? = null, + + @JsonProperty("points") + var points: Double, + + @JsonProperty("votingResult") + var votingResult: String +) + +@JsonIgnoreProperties(ignoreUnknown = true) +@Data +class AssistantEvaluationResponseDTO ( + @JsonProperty("status") + var status: String, + + @JsonProperty("result") + var result: AssistantResponseDTO? +) + +class TaskIdDTO( + @JsonProperty("jobId") + var jobId: String +) \ No newline at end of file diff --git a/src/main/kotlin/ch/uzh/ifi/access/model/dto/LLMConfigDTO.kt b/src/main/kotlin/ch/uzh/ifi/access/model/dto/LLMConfigDTO.kt new file mode 100644 index 00000000..cbde82ce --- /dev/null +++ b/src/main/kotlin/ch/uzh/ifi/access/model/dto/LLMConfigDTO.kt @@ -0,0 +1,17 @@ +package ch.uzh.ifi.access.model.dto + +data class LLMConfigDTO( + var submission: String? = null, + var solution: String? = null, + var rubrics: String? = null, + var cot: Boolean = false, + var voting: Int = 1, + var examples: String? = null, + var prompt: String? = null, + var pre: String? = null, + var post: String? = null, + var temperature: Double? = null, + var model: String? = null, + var modelFamily: String? = null, + var maxPoints: Double? = null +) \ No newline at end of file diff --git a/src/main/kotlin/ch/uzh/ifi/access/model/dto/TaskDTO.kt b/src/main/kotlin/ch/uzh/ifi/access/model/dto/TaskDTO.kt index 7afb6430..94e7cd73 100644 --- a/src/main/kotlin/ch/uzh/ifi/access/model/dto/TaskDTO.kt +++ b/src/main/kotlin/ch/uzh/ifi/access/model/dto/TaskDTO.kt @@ -11,5 +11,18 @@ class TaskDTO( var maxAttempts: Int? = null, var refill: Int? = null, var evaluator: TaskEvaluatorDTO? = null, - var files: TaskFilesDTO? = null + var files: TaskFilesDTO? = null, + var llmSubmission: String? = null, + var llmSolution: String? = null, + var llmRubrics: String? = null, + var llmCot: Boolean = false, + var llmVoting: Int = 1, + var llmExamples: String? = null, + var llmPrompt: String? = null, + var llmPre: String? = null, + var llmPost: String? = null, + var llmTemperature: Double? = null, + var llmModel: String? = null, + var llmModelFamily: String? = null, + var llmMaxPoints: Double? = null, ) diff --git a/src/main/kotlin/ch/uzh/ifi/access/service/CourseConfigImporter.kt b/src/main/kotlin/ch/uzh/ifi/access/service/CourseConfigImporter.kt index 96ffa878..4d750753 100644 --- a/src/main/kotlin/ch/uzh/ifi/access/service/CourseConfigImporter.kt +++ b/src/main/kotlin/ch/uzh/ifi/access/service/CourseConfigImporter.kt @@ -1,7 +1,9 @@ package ch.uzh.ifi.access.service import ch.uzh.ifi.access.model.dto.* +import com.fasterxml.jackson.core.type.TypeReference import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.databind.node.NullNode import com.fasterxml.jackson.dataformat.toml.TomlMapper import org.springframework.stereotype.Service @@ -12,7 +14,8 @@ import java.time.LocalDateTime @Service class CourseConfigImporter( private val tomlMapper: TomlMapper, - private val fileService: FileService + private val fileService: FileService, + private val objectMapper: ObjectMapper // Needed for JSON serialization ) { fun JsonNode?.asTextOrNull(): String? { @@ -62,7 +65,6 @@ class CourseConfigImporter( course.globalFiles = files return course - } fun readAssignmentConfig(path: Path): AssignmentDTO { @@ -86,7 +88,37 @@ class CourseConfigImporter( } return assignment + } + + private fun readRubricsFromToml(path: Path, rubricsFile: String): String? { + val rubricsPath = path.resolve(rubricsFile) + if (!Files.exists(rubricsPath)) return null + + val rubricsConfig: JsonNode = tomlMapper.readTree(Files.readString(rubricsPath)) + val rubricsList = rubricsConfig["rubrics"]?.map { rubric -> + RubricDTO( + id = rubric["id"].asText(), + title = rubric["title"].asText(), + points = rubric["points"].asDouble() + ) + } ?: emptyList() + + return objectMapper.writeValueAsString(rubricsList) // Convert to JSON string + } + + private fun readExamplesFromToml(path: Path, examplesFile: String): String? { + val examplesPath = path.resolve(examplesFile) + if (!Files.exists(examplesPath)) return null + val examplesConfig: JsonNode = tomlMapper.readTree(Files.readString(examplesPath)) + val examplesList = examplesConfig["examples"]?.map { example -> + FewShotExampleDTO( + answer = example["answer"].asText(), + points = objectMapper.convertValue(example["points"], object : TypeReference>() {}).toString() + ) + } ?: emptyList() + + return objectMapper.writeValueAsString(examplesList) // Convert to JSON string } fun readTaskConfig(path: Path): TaskDTO { @@ -130,15 +162,43 @@ class CourseConfigImporter( "solution" -> files.solution = filenames "persist" -> files.persist = filenames } + + val llmConfig = config["llm"] + if (llmConfig != null && !llmConfig.isNull) { + val submissionFileName = llmConfig["submission"]?.asTextOrNull() + val solutionContent = llmConfig["solution"]?.asTextOrNull()?.let { Files.readString(path.resolve(it)) } + val rubricsJson = llmConfig["rubrics"]?.asTextOrNull()?.let { readRubricsFromToml(path, it) } + val examplesJson = llmConfig["examples"]?.asTextOrNull()?.let { readExamplesFromToml(path, it) } + val promptContent = llmConfig["prompt"]?.asTextOrNull()?.let { Files.readString(path.resolve(it)) } + val preContent = llmConfig["pre"]?.asTextOrNull()?.let { Files.readString(path.resolve(it)) } + val postContent = llmConfig["post"]?.asTextOrNull()?.let { Files.readString(path.resolve(it)) } + val temperature = llmConfig["temperature"]?.asDouble() + val model = llmConfig["model"]?.asTextOrNull() + val modelFamily = llmConfig["model_family"]?.asTextOrNull() + + task.llmSubmission = submissionFileName + task.llmSolution = solutionContent + task.llmRubrics = rubricsJson + task.llmCot = llmConfig["cot"]?.asBoolean() ?: false + task.llmVoting = llmConfig["voting"]?.asInt() ?: 1 + task.llmExamples = examplesJson + task.llmPrompt = promptContent + task.llmPre = preContent + task.llmPost = postContent + task.llmTemperature = temperature + task.llmModel = model + task.llmModelFamily = modelFamily + task.llmMaxPoints = llmConfig["max_points"].asDouble() + } + } task.files = files return task - } - } class InvalidCourseException : Throwable() { + } \ No newline at end of file diff --git a/src/main/kotlin/ch/uzh/ifi/access/service/CourseService.kt b/src/main/kotlin/ch/uzh/ifi/access/service/CourseService.kt index 9adf5b0c..47e513c4 100644 --- a/src/main/kotlin/ch/uzh/ifi/access/service/CourseService.kt +++ b/src/main/kotlin/ch/uzh/ifi/access/service/CourseService.kt @@ -8,6 +8,7 @@ import ch.uzh.ifi.access.model.dao.Results import ch.uzh.ifi.access.model.dto.* import ch.uzh.ifi.access.projections.* import ch.uzh.ifi.access.repository.* +import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.databind.json.JsonMapper import com.github.dockerjava.api.DockerClient import com.github.dockerjava.api.command.PullImageResultCallback @@ -18,10 +19,17 @@ import com.github.dockerjava.api.model.HostConfig import io.github.oshai.kotlinlogging.KotlinLogging import jakarta.transaction.Transactional import jakarta.xml.bind.DatatypeConverter +import okhttp3.Request +import okhttp3.OkHttpClient +import okhttp3.MediaType.Companion.toMediaTypeOrNull +import okhttp3.RequestBody +import okhttp3.RequestBody.Companion.toRequestBody +import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper import org.apache.commons.collections4.ListUtils import org.apache.commons.io.FileUtils import org.apache.tika.Tika import org.keycloak.representations.idm.UserRepresentation +import org.keycloak.util.JsonSerialization.mapper import org.modelmapper.ModelMapper import org.springframework.cache.annotation.CacheEvict import org.springframework.cache.annotation.Cacheable @@ -45,6 +53,7 @@ import java.util.function.Consumer import java.util.stream.Stream import javax.crypto.Mac import javax.crypto.spec.SecretKeySpec +import kotlin.math.log @Service class CourseServiceForCaching( @@ -82,11 +91,13 @@ class CourseService( private val evaluationRepository: EvaluationRepository, private val dockerClient: DockerClient, private val modelMapper: ModelMapper, + private val objectMapper: ObjectMapper, private val jsonMapper: JsonMapper, private val courseLifecycle: CourseLifecycle, private val roleService: RoleService, private val fileService: FileService, - private val tika: Tika + private val tika: Tika, + private val assistantServerUrl: String ) { private val logger = KotlinLogging.logger {} @@ -113,27 +124,29 @@ class CourseService( } } + private val client = OkHttpClient() + private fun verifyUserId(@Nullable userId: String?): String { return userId ?: SecurityContextHolder.getContext().authentication.name } fun getCourseBySlug(courseSlug: String): Course { return courseRepository.getBySlug(courseSlug) ?: - throw ResponseStatusException(HttpStatus.NOT_FOUND, "No course found with the URL $courseSlug") + throw ResponseStatusException(HttpStatus.NOT_FOUND, "No course found with the URL $courseSlug") } fun getCourseWorkspaceBySlug(courseSlug: String): CourseWorkspace { return courseRepository.findBySlug(courseSlug) ?: - throw ResponseStatusException(HttpStatus.NOT_FOUND, "No course found with the URL $courseSlug") + throw ResponseStatusException(HttpStatus.NOT_FOUND, "No course found with the URL $courseSlug") } fun getTaskById(taskId: Long): Task { return taskRepository.findById(taskId).get() ?: - throw ResponseStatusException(HttpStatus.NOT_FOUND, "No task found with the ID $taskId") + throw ResponseStatusException(HttpStatus.NOT_FOUND, "No task found with the ID $taskId") } fun getTaskFileById(fileId: Long): TaskFile { return taskFileRepository.findById(fileId).get() ?: - throw ResponseStatusException(HttpStatus.NOT_FOUND, "No task file found with the ID $fileId") + throw ResponseStatusException(HttpStatus.NOT_FOUND, "No task file found with the ID $fileId") } fun getCoursesOverview(): List { //return courseRepository.findCoursesBy() @@ -147,7 +160,7 @@ class CourseService( fun getCourseSummary(courseSlug: String): CourseSummary { return courseRepository.findCourseBySlug(courseSlug) ?: - throw ResponseStatusException(HttpStatus.NOT_FOUND, "No course found with the URL $courseSlug") + throw ResponseStatusException(HttpStatus.NOT_FOUND, "No course found with the URL $courseSlug") } fun enabledTasksOnly(tasks: List): List { @@ -161,23 +174,23 @@ class CourseService( fun getAssignment(courseSlug: String?, assignmentSlug: String): AssignmentWorkspace { return assignmentRepository.findByCourse_SlugAndSlug(courseSlug, assignmentSlug) ?: - throw ResponseStatusException( HttpStatus.NOT_FOUND, - "No assignment found with the URL $assignmentSlug" ) + throw ResponseStatusException( HttpStatus.NOT_FOUND, + "No assignment found with the URL $assignmentSlug" ) } fun getAssignmentBySlug(courseSlug: String?, assignmentSlug: String): Assignment { return assignmentRepository.getByCourse_SlugAndSlug(courseSlug, assignmentSlug) ?: throw ResponseStatusException( - HttpStatus.NOT_FOUND, - "No assignment found with the URL $assignmentSlug" - ) + HttpStatus.NOT_FOUND, + "No assignment found with the URL $assignmentSlug" + ) } fun getTask(courseSlug: String?, assignmentSlug: String?, taskSlug: String?, userId: String?): TaskWorkspace { val workspace = taskRepository.findByAssignment_Course_SlugAndAssignment_SlugAndSlug(courseSlug, assignmentSlug, taskSlug) ?: throw ResponseStatusException( HttpStatus.NOT_FOUND, - "No task found with the URL: $courseSlug/$assignmentSlug/$taskSlug" ) + "No task found with the URL: $courseSlug/$assignmentSlug/$taskSlug" ) workspace.setUserId(userId) return workspace } @@ -322,8 +335,8 @@ class CourseService( assignmentSlug, taskSlug ) ?: throw ResponseStatusException( - HttpStatus.NOT_FOUND, "No task found with the URL $taskSlug" - ) + HttpStatus.NOT_FOUND, "No task found with the URL $taskSlug" + ) } fun getGradingFiles(taskId: Long?): List { @@ -426,6 +439,81 @@ class CourseService( } } + // Evaluates a submission using the assistant API + fun evaluateSubmissionWithAssistant(submission: AssistantDTO): AssistantResponseDTO? { + val url = "$assistantServerUrl/evaluate" + + val requestBodyJson = mapper.writeValueAsString(submission) + logger.debug { "Request body: $requestBodyJson" } + + val requestBody = requestBodyJson.toRequestBody("application/json".toMediaTypeOrNull()) + val request = Request.Builder() + .url(url) + .post(requestBody) + .build() + + client.newCall(request).execute().use { response -> + if (!response.isSuccessful) { + throw RuntimeException("Failed to get response from assistant backend: ${response.message}") + } + + val responseBody = response.body?.string() ?: throw RuntimeException("Empty response body") + val taskId = mapper.readValue(responseBody, TaskIdDTO::class.java).jobId + + if (taskId.isNullOrBlank()) { + throw RuntimeException("Invalid taskId received from assistant backend") + } + + // Polling loop + var attempts = 0 + val maxAttempts = 20 // Adjust as needed + val delayMillis = 2000L // 2 seconds delay per attempt + + while (attempts < maxAttempts) { + val statusUrl = "$assistantServerUrl/evaluate/$taskId" + val statusRequest = Request.Builder().url(statusUrl).get().build() + client.newCall(statusRequest).execute().use { statusResponse -> + if (!statusResponse.isSuccessful) { + throw RuntimeException("Failed to fetch evaluation status: ${statusResponse.message}") + } + + val statusBody = statusResponse.body?.string() + val statusResponseDTO = mapper.readValue(statusBody, AssistantEvaluationResponseDTO::class.java) + + when (statusResponseDTO.status) { + "completed" -> return statusResponseDTO.result // Return completed result + "not_found" -> throw RuntimeException("Task not found in assistant backend") + "delayed" -> { + // Keep polling + logger.debug { "Task $taskId delayed, retrying..." } + } + "active" -> { + // Keep polling + logger.debug { "Task $taskId still active, retrying..." } + } + else -> throw RuntimeException("Unexpected status: ${statusResponseDTO.status}") + } + } + Thread.sleep(delayMillis) // Wait before next attempt + attempts++ + } + throw RuntimeException("Evaluation timed out for taskId: $taskId") + } + } + + fun parseJsonOrEmpty(json: String?, objectMapper: ObjectMapper, clazz: Class>): List { + return try { + if (!json.isNullOrBlank()) { + objectMapper.readValue(json, clazz).toList() + } else { + emptyList() + } + } catch (e: Exception) { + emptyList() // Fallback to empty list if parsing fails + } + } + + @Caching(evict = [ CacheEvict(value = ["getStudent"], key = "#courseSlug + '-' + #submissionDTO.userId"), CacheEvict(value = ["studentWithPoints"], key = "#courseSlug + '-' + #submissionDTO.userId"), @@ -454,6 +542,7 @@ class CourseService( ) } val submission = submissionRepository.saveAndFlush(newSubmission) + //pruneSubmissions(evaluation) submissionDTO.files.stream().filter { fileDTO -> fileDTO.content != null } .forEach { fileDTO: SubmissionFileDTO -> createSubmissionFile(submission, fileDTO) } @@ -500,7 +589,7 @@ class CourseService( // TODO: make this size configurable in task config.toml? val resultFileSizeLimit = convertSizeToBytes("100K") val persistentFileCopyCommands = task.persistentResultFilePaths.joinToString("\n") { path -> -""" + """ # Check if results file exceeds permissible size limit if [[ -f "$path" ]]; then actual_size=${'$'}(stat -c%s "$path") @@ -514,7 +603,7 @@ cp "$path" "/submission/${'$'}file_dir" """ } val command = ( -""" + """ # copy submitted files to tmpfs /bin/cp -R /submission/* /workspace/; # run command (the cwd is set to /workspace already) @@ -533,7 +622,7 @@ fi $persistentFileCopyCommands exit ${'$'}exit_code; """ - ) + ) val container = containerCmd .withLabels(mapOf("userId" to submission.userId)).withWorkingDir("/workspace") .withCmd("/bin/bash", "-c", command) @@ -636,9 +725,73 @@ exit ${'$'}exit_code; FileUtils.deleteQuietly(submissionDir.toFile()) } } + + + // Only evaluate with assistant if the submission file was defined in the config + if(task.llmSubmission != null) { + // Evaluate with Assistant API + try { + // Collect the student's submitted code + val llmSubmissionFile = submission.files + .firstOrNull { file -> file.taskFile?.path == "/${task.llmSubmission}" } // Find the specific file + + submission.files.map { + logger.debug { "Submission file: ${it.taskFile?.path}" } + } + + var assistantResponse: AssistantResponseDTO? = null + + if (llmSubmissionFile?.content != null) { + assistantResponse = evaluateSubmissionWithAssistant( + AssistantDTO( + question = task.instructions ?: "No instructions provided", + answer = llmSubmissionFile.content ?: "No answer provided", + llmType = task.llmModelFamily, + chainOfThought = task.llmCot, + votingCount = task.llmVoting, + rubrics = parseJsonOrEmpty(task.llmRubrics, objectMapper, Array::class.java), + prePrompt = task.llmPre, + postPrompt = task.llmPost, + prompt = task.llmPrompt, + temperature = task.llmTemperature, + fewShotExamples = parseJsonOrEmpty( + task.llmExamples, + objectMapper, + Array::class.java + ), + maxPoints = task.llmMaxPoints, + modelSolution = task.llmSolution, + llmModel = task.llmModel, + ) + ) + } + + + + if (assistantResponse != null) { + logger.debug { "Assistant internal feedback: ${assistantResponse.feedback}" } + logger.debug { "Assistant scoring: ${assistantResponse.points}" } + } + + // Incorporate the assistant feedback into the submission + assistantResponse?.let { + if (it.hint != null && it.hint != "" && it.status !== Status.correct) { + newSubmission.logs += "\nHint: ${it.hint}" + } + val newPoints = newSubmission.points?.plus(it.points) + newSubmission.points = minOf(newPoints!!, newSubmission.maxPoints!!) + evaluation.update(newSubmission.points) + } + } catch (e: Exception) { + // print error + logger.error { "Failed to evaluate submission with assistant: ${e.message}" } + } + } + } catch (e: Exception) { newSubmission.output = "Uncaught ${e::class.simpleName}: ${e.message}. Please report this as a bug and provide as much detail as possible." } + submissionRepository.save(newSubmission) } diff --git a/src/main/resources/db/migration/V2_1__add_llm_fields.sql b/src/main/resources/db/migration/V2_1__add_llm_fields.sql new file mode 100644 index 00000000..405fb841 --- /dev/null +++ b/src/main/resources/db/migration/V2_1__add_llm_fields.sql @@ -0,0 +1,14 @@ +ALTER TABLE task + ADD COLUMN IF NOT EXISTS llm_submission TEXT, + ADD COLUMN IF NOT EXISTS llm_solution TEXT, + ADD COLUMN IF NOT EXISTS llm_rubrics TEXT, + ADD COLUMN IF NOT EXISTS llm_cot BOOLEAN, + ADD COLUMN IF NOT EXISTS llm_voting INTEGER, + ADD COLUMN IF NOT EXISTS llm_examples TEXT, + ADD COLUMN IF NOT EXISTS llm_prompt TEXT, + ADD COLUMN IF NOT EXISTS llm_pre TEXT, + ADD COLUMN IF NOT EXISTS llm_post TEXT, + ADD COLUMN IF NOT EXISTS llm_temperature DOUBLE PRECISION, + ADD COLUMN IF NOT EXISTS llm_model TEXT, + ADD COLUMN IF NOT EXISTS llm_model_family TEXT, + ADD COLUMN IF NOT EXISTS llm_max_points DOUBLE PRECISION;