diff --git a/backend/fastapi_generate_quiz.py b/backend/fastapi_generate_quiz.py index 510e9fa..90f80d8 100644 --- a/backend/fastapi_generate_quiz.py +++ b/backend/fastapi_generate_quiz.py @@ -24,32 +24,29 @@ @app.get("/GenerateQuiz") async def generate_quiz_endpoint(request: Request) -> JSONResponse: """ - FastAPI App to generate an image based on a provided prompt. + FastAPI endpoint to generate a quiz based on topic, difficulty, and model. - The function expects a 'prompt' parameter in the HTTP request query - or body. If a valid prompt is received, the function uses the - generate_image() function to create an image URL corresponding to - the prompt and returns it in the HTTP response. - - Parameters: - - request (Request): The FastAPI request object containing the client request. + Query Parameters: + - topic: The subject for the quiz (e.g., "UK History"). + - difficulty: The desired difficulty (e.g., "easy", "medium"). + - n_questions: (Optional) Number of questions to generate (defaults to 10). + - model: (Optional) The model to use. If not provided, the default from QuizGenerator is used. Returns: - - JSONResponse: The HTTP response object containing the generated quiz or - an appropriate error message. + - StreamingResponse: Streams quiz questions in SSE format. + - JSONResponse: Error message if required parameters are missing. """ - + # Retrieve query parameters topic = request.query_params.get("topic") difficulty = request.query_params.get("difficulty") n_questions = request.query_params.get("n_questions") + model = request.query_params.get("model") logging.info( - f"Python HTTP trigger function processed a request with {topic=} {difficulty=}, {n_questions=}." + f"Python HTTP trigger function processed a request with {topic=} {difficulty=}, {n_questions=}, model={model}." ) - # If either 'topic' or 'difficulty' is not provided in the request, - # the function will return an error message and a 400 status code. - # n_questions is optional + # If either 'topic' or 'difficulty' is missing, return an error. if not topic or not difficulty: error_message = "Please provide a topic and difficulty in the query string or in the request body to generate a quiz." logging.error(error_message) @@ -58,52 +55,53 @@ async def generate_quiz_endpoint(request: Request) -> JSONResponse: status_code=400, ) - # Set default value if not set + # Set default number of questions if not provided. if not n_questions: n_questions = 10 + else: + # Convert n_questions to an integer if provided as string. + try: + n_questions = int(n_questions) + except ValueError: + error_message = "n_questions must be an integer." + logging.error(error_message) + return JSONResponse( + content={"error": error_message}, + status_code=400, + ) logging.info( - f"Generating quiz for topic: {topic} with difficulty: {difficulty} with number of questions: {n_questions}" + f"Generating quiz with: {topic=}, {difficulty=}, {n_questions=}, {model=}." ) - # TODO: rename to quiz creator - # TODO: Fix - currently doesnt actually stream, but returns all items at once. - # Need to look into the azure functions streaming capability - # Or think about hosting the fastapi in another method e.g. ACI - quiz_generator = QuizGenerator() + # Create a QuizGenerator instance. + # TODO: rename to quiz creator ? + quiz_generator = QuizGenerator(model=model) generator = quiz_generator.generate_quiz(topic, difficulty, n_questions) + # Return the quiz as a streaming response in SSE format. 
return StreamingResponse(generator, media_type="text/event-stream") @app.get("/GenerateImage") async def generate_image_endpoint(request: Request) -> JSONResponse: """ - FastAPI App to generate an image based on a provided prompt. - - The function expects a 'prompt' parameter in the HTTP request query - or body. If a valid prompt is received, the function uses the - generate_image() function to create an image URL corresponding to - the prompt and returns it in the HTTP response. + FastAPI endpoint to generate an image based on a provided prompt. - Parameters: - - request (Request): The FastAPI request object containing the client request. + Query Parameters: + - prompt: The prompt for image generation. Returns: - - JSONResponse: The HTTP response object containing the image URL or - an appropriate error message. + - JSONResponse: Contains the generated image URL or an error message. """ - - logging.info("Python HTTP trigger function processed a request.") - + logging.info("Processing image generation request.") prompt = request.query_params.get("prompt") - if not prompt: error_message = "No prompt query param provided for image generation." logging.warning(error_message) return JSONResponse(content={"error": error_message}, status_code=400) - logging.info(f"Received prompt: {prompt}") + logging.info(f"Received image prompt: {prompt}") image_generator = ImageGenerator() image_url = image_generator.generate_image(prompt) @@ -112,8 +110,7 @@ async def generate_image_endpoint(request: Request) -> JSONResponse: logging.error(error_message) return JSONResponse(content={"error": error_message}, status_code=500) - # Return the image URL in the HTTP response - logging.info(f"Generated image for prompt {prompt}: {image_url}") + logging.info(f"Generated image for prompt '{prompt}': {image_url}") return JSONResponse(content={"image_url": image_url}, status_code=200) diff --git a/backend/generate_quiz.py b/backend/generate_quiz.py index 27c934b..7961724 100644 --- a/backend/generate_quiz.py +++ b/backend/generate_quiz.py @@ -1,8 +1,11 @@ from typing import Generator, Optional -from openai import OpenAI, Stream import logging import json import os +from response_stream_parser import ResponseStreamParser + +# Import the completion function from litellm (as shown in the docs example) +from litellm import completion # Set up logging logger = logging.getLogger(__name__) @@ -14,7 +17,19 @@ class QuizGenerator: - EXAMPLE_RESPONSE = json.dumps( + # Define the list of supported models. + SUPPORTED_MODELS = [ + "gpt-3.5-turbo", + "gpt-4-turbo", + "o1-mini", + "o3-mini", + "gemini/gemini-pro", + "gemini/gemini-2.0-flash", + "gemini/gemini-1.5-pro-latest", + "azure_ai/DeepSeek-R1", + ] + + example_question_1 = json.dumps( { "question_id": 1, "question": "Who was the first emperor of Rome?", @@ -23,199 +38,194 @@ class QuizGenerator: "C": "Constantine", "answer": "B", "explanation": ( - "Augustus, originally Octavian, " - "was the first to hold the title of Roman Emperor. " + "Augustus, originally Octavian, was the first to hold the title of Roman Emperor. " "Julius Caesar, while pivotal, never held the emperor title." ), - "wikipedia": r"https://en.wikipedia.org/wiki/Augustus", + "wikipedia": "https://en.wikipedia.org/wiki/Augustus", } ) - @classmethod - def get_api_key_from_env(cls) -> str: - """Retrieves the OpenAI API key from environment variables. 
+ example_question_2 = json.dumps( + { + "question_id": 2, + "question": ( + "Which Roman Emperor is known for issuing the Edict on Maximum Prices to curb inflation, " + "and is regarded as a pivotal figure in the transition from the Principate to the Dominate?" + ), + "A": "Nero", + "B": "Diocletian", + "C": "Marcus Aurelius", + "answer": "B", + "explanation": ( + "Diocletian, who reigned from 284 to 305 AD, issued the Edict on Maximum Prices in 301 AD in an effort " + "to control rampant inflation and economic instability. His reforms marked a significant shift in the " + "structure of Roman imperial governance." + ), + "wikipedia": "https://en.wikipedia.org/wiki/Diocletian", + } + ) - Returns: - str: The API key from the environment variable OPENAI_API_KEY. + EXAMPLE_RESPONSE = example_question_1 + "\n" + example_question_2 + + @classmethod + def check_api_key_from_env(cls) -> None: + """Retrieves the API keys from environment variables. Raises: ValueError: If the environment variable is not set or empty. """ - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - raise ValueError( - "Environment variable OPENAI_API_KEY is not set. " - "Please ensure it's set and try again." - ) - return api_key - def __init__(self, api_key: Optional[str] = None): + for key in [ + "OPENAI_API_KEY", + "GEMINI_API_KEY", + "DEEPSEEK_API_KEY", + "AZURE_AI_API_KEY", + "AZURE_AI_API_BASE", + ]: + api_key = os.getenv(key) + if not api_key: + raise ValueError( + f"Environment variable {key} is not set." + "Please ensure it's set and try again." + ) + + @staticmethod + def check_model_is_supported(model: str) -> str: + """ + Validate the requested model. If it is not supported, default to "gpt-4-turbo". + + Args: + model (str): The model name to validate. + + Returns: + str: A supported model name. + """ + if model not in QuizGenerator.SUPPORTED_MODELS: + logger.warning( + f"Model '{model}' is not supported. Defaulting to 'gpt-4-turbo'." + ) + return "gpt-4-turbo" + return model + + def __init__( + self, + api_key: Optional[str] = None, + model: str = "gpt-3.5-turbo", + ): """ - Initializes the QuizGenerator by setting up the OpenAI client with the API key. - If `api_key` is not provided, it is retrieved from the environment - using `get_api_key_from_env`. + Initializes the QuizGenerator. + If `api_key` is not provided, it is retrieved from the environment. + Also validates that the requested model is one of the supported models. + If the model is not supported, defaults to "gpt-4-turbo". Args: - api_key (str, optional): The OpenAI API key to use. Defaults to None. + api_key (str, optional): The API key to use. Defaults to None. + model (str, optional): The model name to use. Defaults to "gpt-3.5-turbo". """ - if api_key is None: - api_key = self.get_api_key_from_env() + self.check_api_key_from_env() + + # Validate and set the model. + self.model = QuizGenerator.check_model_is_supported(model) - self.client = OpenAI(api_key=api_key) + # Use the separate parser class to handle the stream. + self.parser = ResponseStreamParser() def generate_quiz( self, topic: str, difficulty: str, n_questions: int = 10 ) -> Generator[str, None, None]: """ - Generate a quiz based on the provided topic and difficulty using OpenAI API. + Generate a quiz based on the provided topic and difficulty using litellm. Parameters: - - topic (str): The subject for the quiz, e.g., 'Roman History'. - - difficulty (str): The desired difficulty of the quiz e.g., 'Easy', 'Medium'. - - n_questions (int, optional): Number of questions required. 
Defaults to 10. + topic (str): The subject for the quiz (e.g., 'Roman History'). + difficulty (str): The desired difficulty (e.g., 'Easy', 'Medium'). + n_questions (int, optional): Number of questions required. Defaults to 10. Returns: - - str: JSON-formatted quiz questions. If an error occurs, an empty string is returned. - - This method coordinates the creation of the role for the OpenAI API, - the generation of the response, and the cleaning of the response. + Generator[str, None, None]: A generator yielding JSON-formatted quiz questions as SSE strings. """ - role = self._create_role(topic, difficulty, n_questions) - logger.info(f"Role content for OpenAI API: {role}") - openai_stream = self._create_openai_stream(role) - response_generator = self._create_question_generator(openai_stream) - - return response_generator + prompt = self._create_role(topic, difficulty, n_questions) + logger.info(f"Prompt for LLM: {prompt}") + llm_stream = self._create_llm_stream(prompt) + # Use the separate parser class to handle the stream + return self.parser.parse_stream(llm_stream) def _create_role(self, topic: str, difficulty: str, n_questions: int) -> str: """ - Creates the role string that will be sent to the OpenAI API to generate the quiz. + Creates the prompt to be sent to the LLM. Parameters: - - topic (str): The subject for the quiz. - - difficulty (str): The desired difficulty of the quiz. - - n_questions (int): Number of questions required. + topic (str): The quiz subject. + difficulty (str): The quiz difficulty. + n_questions (int): Number of questions to generate. Returns: - - str: The role string to be sent to the OpenAI API. - - This method structures the prompt for the OpenAI API to ensure consistent and correct responses. + str: The prompt string. """ return ( - f"You are an AI to generate quiz questions. " - f"You will be given a topic e.g. Roman History with a difficulty of Normal. " - f"Give {str(n_questions)} responses in a json format such as: {self.EXAMPLE_RESPONSE}. " - f"Your task is to generate similar responses for {topic} " - f"with the difficulty of {difficulty}. " + f"You are an AI that generates quiz questions. " + f"You will be given a topic (e.g., Roman History) with a difficulty level. " + f"Provide {n_questions} responses in JSON format similar to this example: \n{self.EXAMPLE_RESPONSE}. " + f"Generate similar responses for the topic '{topic}' with a difficulty of '{difficulty}'. " f"ENSURE THESE ARE CORRECT. DO NOT INCLUDE INCORRECT ANSWERS! " f"DO NOT PREFIX THE RESPONSE WITH ANYTHING EXCEPT THE RAW JSON! " - f"Return each question on a new line. " + f"Return each question on a new line." ) - def _create_openai_stream(self, role: str) -> Stream: + def _create_llm_stream(self, prompt: str): """ - Creates the stream from the OpenAI API based on the given role. - Exceptions are not caught here so that errors are visible in tests. + Creates a streaming response from litellm based on the given prompt. Parameters: - - role (str): The role string to be sent to the OpenAI API. + prompt (str): The prompt string. Returns: - - str: The raw response from the OpenAI API. + Generator: A generator yielding streamed response chunks from the LLM. """ - return self.client.chat.completions.create( - model="gpt-4-turbo-preview", - messages=[{"role": "user", "content": role}], + # The completion function supports a stream flag. 
+ return completion( + model=self.model, + messages=[{"role": "user", "content": prompt}], stream=True, ) - def _create_question_generator( - self, openai_stream: Stream - ) -> Generator[str, None, None]: - """Parses streamed data chunks from OpenAI into complete JSON objects and yields them in SSE format. - - Accumulates data in a buffer and attempts to parse complete JSON objects. If successful, - the JSON object is yielded as a string and the buffer is cleared for the next object. - Ignores empty chunks and continues buffering if the JSON is incomplete. - - Similar-ish SSE Fast API blog: https://medium.com/@nandagopal05/server-sent-events-with-python-fastapi-f1960e0c8e4b - Helpful SO that says about the SSE format of data: {your-json}: https://stackoverflow.com/a/49486869/11902832 - - Args: - openai_stream (Stream): Stream from OpenAI's api - - Yields: - str: Complete JSON object of a quiz question in string representation. - """ - buffer = "" - for chunk in openai_stream: - chunk_contents = chunk.choices[0].delta.content - - # Ignore empty chunks. - if chunk_contents is None: - logger.debug("Chunk was empty!") - continue - - buffer += chunk_contents # Append new data to buffer - result = self.validate_and_parse_json(buffer) - - # If the JSON is incomplete, wait for more data. - if result is None: - logger.debug("JSON is incomplete, waiting for more data...") - continue - - # If the JSON is complete, yield it and clear the buffer. - yield self._format_sse(result) - buffer = "" # Clear buffer on successful parse. - - logger.info("Finished stream!") - @staticmethod - def _format_sse(json_obj: dict) -> str: - """ - Formats a JSON object as a Server-Sent Event (SSE) string. - """ - return f"data: {json.dumps(json_obj)}\n\n" - - @staticmethod - def validate_and_parse_json(s: str) -> Optional[dict]: + def print_quiz(generator: Generator[str, None, None]): """ - Helper method to validate and parse the provided string as JSON. - Returns the parsed dict if s is valid JSON, otherwise returns None if the JSON is incomplete. + Iterates through the generator and prints each quiz question. Parameters: - - s (str): The string to check. - - Returns: - - dict: The parsed JSON object, or None if the JSON is incomplete. - """ - try: - return json.loads(s) - except json.JSONDecodeError as e: - logger.debug(f"Incomplete JSON '{s}': {e.msg} at pos {e.pos}") - return None - - @staticmethod - def print_quiz(generator: Generator[str, None, None]): - """Helper function to iterate through and print the results from the question generator. - - Args: - generator (Generator[str, None, None]): Generator producing quiz questions as SSE formatted strings. + generator (Generator[str, None, None]): Generator producing quiz questions as SSE strings. """ + questions = [] try: for idx, question in enumerate(generator, start=1): logger.info(f"Item {idx}: {question}") + questions.append(question) + return questions except Exception as e: logger.error(f"Error during quiz generation: {e}") if __name__ == "__main__": - # Set logger level to DEBUG if running this file to test + # For detailed output during testing, set the logger level to DEBUG. 
logger.setLevel(logging.DEBUG) - quiz_generator = QuizGenerator() + suppported_models = [ + "gpt-3.5-turbo", + "gpt-4-turbo", + "o1-mini", + "o3-mini", + "gemini/gemini-pro", + "gemini/gemini-1.5-pro-latest", + "azure_ai/DeepSeek-R1", + ] + + quiz_generator = QuizGenerator(model="o1-mini") + topic = "Crested Gecko" difficulty = "Medium" - generator = quiz_generator.generate_quiz(topic, difficulty, 2) - logger.info(generator) - QuizGenerator.print_quiz(generator) + generator = quiz_generator.generate_quiz(topic, difficulty, n_questions=2) + logger.info("Starting quiz generation...") + quiz = QuizGenerator.print_quiz(generator) + logger.info(quiz) diff --git a/backend/requirements-dev.txt b/backend/requirements-dev.txt index 5570e34..4229188 100644 --- a/backend/requirements-dev.txt +++ b/backend/requirements-dev.txt @@ -1,4 +1,4 @@ -r requirements.txt ruff pytest -pytest-mock \ No newline at end of file +pytest-mock diff --git a/backend/requirements.txt b/backend/requirements.txt index b6dc488..ca1ada2 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,3 +1,4 @@ openai fastapi -uvicorn \ No newline at end of file +uvicorn +litellm diff --git a/backend/response_stream_parser.py b/backend/response_stream_parser.py new file mode 100644 index 0000000..1c53177 --- /dev/null +++ b/backend/response_stream_parser.py @@ -0,0 +1,187 @@ +import json +import logging +from typing import Generator, Optional + +logger = logging.getLogger(__name__) + + +class ResponseStreamParser: + """ + A class responsible for processing streaming responses from an LLM and parse into JSON objects. + + Parses streamed data chunks from the LLM into complete JSON objects and yields them as SSE strings. + + Accumulates data in a buffer and attempts to parse complete JSON objects. If successful, + the JSON object is yielded as a string and the buffer is cleared for the next object. + Ignores empty chunks and continues buffering if the JSON is incomplete. + + Similar-ish SSE Fast API blog: https://medium.com/@nandagopal05/server-sent-events-with-python-fastapi-f1960e0c8e4b + Helpful SO that says about the SSE format of data: {your-json}: https://stackoverflow.com/a/49486869/11902832 + + Methods: + - parse_stream(llm_stream): Processes an LLM stream and yields complete SSE-formatted JSON objects. + - _extract_chunk_content(chunk): Extracts text content from a single chunk. + - _split_buffer(): Splits the internal buffer on newline characters into complete lines and a remainder. + - _process_line(line): Parses a single line as JSON and formats it as an SSE string. + + Example: + Suppose the LLM returns chunks that, when combined, look like: + + '{"question_id": 1, "question": "Who was the first emperor of Rome?", ...}\n' + '{"question_id": 2, "question": "Which Roman Emperor issued the Edict on Maximum Prices?", ...}\n' + + The parser will: + - Accumulate these chunks into a buffer. + - Split the buffer on newlines. + - Parse each complete JSON line. + - Format each parsed JSON as: + + data: {"question_id": 1, ...}\n\n + + - Yield each formatted SSE string. + """ + + def __init__(self): + self.buffer = "" + + # Public Method + def parse_stream(self, llm_stream) -> Generator[str, None, None]: + """ + Processes the LLM stream and yields complete SSE-formatted JSON objects. + + For each chunk in the stream: + - The private method _extract_chunk_content is used to get text content. + - This content is appended to the internal buffer. 
+ - When the buffer contains one or more newline characters, the private method _split_buffer + splits it into complete lines and a remainder. + - Each complete line is processed by _process_line to parse it as JSON and format it as an SSE string. + - The formatted string is then yielded. + + After the stream ends, any remaining data in the buffer is processed similarly. + + Args: + llm_stream: An iterable or generator yielding chunks from the LLM. + + Yields: + SSE-formatted strings, each representing a complete JSON object. + """ + for chunk in llm_stream: + # Extract text from the chunk. + content = self._extract_chunk_content(chunk) + if content is None: + logger.debug("Received an empty or invalid chunk; skipping...") + continue + + # Append the new content to the buffer. + self.buffer += content + + # If the buffer contains a newline, process the complete lines. + if "\n" in self.buffer: + complete_lines, self.buffer = self._split_buffer() + for line in complete_lines: + sse_line = self._process_line(line) + if sse_line is not None: + yield sse_line + + # After processing all chunks, process any remaining data in the buffer. + if self.buffer.strip(): + logging.warning(f"Unprocessed data in the buffer! {self.buffer=}") + sse_line = self._process_line(self.buffer) + if sse_line is not None: + yield sse_line + + logger.info("Finished processing the stream!") + + def _extract_chunk_content(self, chunk) -> Optional[str]: + """ + Extracts text content from a given chunk. + + Expected chunk structure (example): + { + "choices": [ + { + "delta": { + "content": "some text..." + } + } + ] + } + + If the chunk does not follow the expected structure, a debug message is logged, + and None is returned. + + Args: + chunk: A single chunk from the LLM stream. + + Returns: + The extracted text (str) if available; otherwise, None. + """ + try: + return chunk.choices[0].delta.content + except (AttributeError, IndexError, KeyError): + logger.debug("Chunk format unexpected or chunk is empty!") + return None + + def _split_buffer(self) -> (list[str], str): + """ + Splits the internal buffer on newline characters. + + Since each complete JSON object is expected to end with a newline, + this function splits the buffer into complete lines and a remaining + (possibly incomplete) portion. + + Example: + If self.buffer is: + '{"question": "Who was ..."}\n"question": "What is ..."}\nincomplete' + Then: + complete_lines = ['{"question": "Who was ..."}', '{"question": "What is ..."}'] + remainder = "incomplete" + + Returns: + A tuple (list of complete JSON lines, remainder string). + """ + if "\n" not in self.buffer: + return [], self.buffer # No full lines, everything is remainder + + lines = self.buffer.split("\n") + + # The remainder is empty if the buffer ends with a newline. + if self.buffer.endswith("\n"): + return lines[:-1], "" + + # Otherwise, the last line is incomplete. + return lines[:-1], lines[-1] + + def _process_line(self, line: str) -> Optional[str]: + """ + Processes a single line by parsing it as JSON and formatting it as an SSE string. + + Steps: + 1. Strip any leading or trailing whitespace. + 2. If the line is empty, return None. + 3. Attempt to parse the line as JSON. + 4. If parsing is successful, format the JSON object as an SSE string: + + data: \n\n + + 5. If parsing fails, log a debug message and return None. 
+ + Example: + Input: '{"question_id": 1, "question": "Who was the first emperor of Rome?"}' + Output: 'data: {"question_id": 1, "question": "Who was the first emperor of Rome?"}\n\n' + + Args: + line: The line of text to process. + + Returns: + An SSE-formatted string if parsing is successful; otherwise, None. + """ + line = line.strip() + if not line: + return None + try: + json_obj = json.loads(line) + return f"data: {json.dumps(json_obj)}\n\n" + except json.JSONDecodeError as e: + logger.debug(f"Error parsing line '{line}': {e}") + return None diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 06c4975..989a10e 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,4 +1,8 @@ import sys import os -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +# Add backend directory explicitly +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) # Add tests directory +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +) # Add backend directory diff --git a/backend/tests/test_generate_quiz.py b/backend/tests/test_generate_quiz.py index d0f70fb..804e52c 100644 --- a/backend/tests/test_generate_quiz.py +++ b/backend/tests/test_generate_quiz.py @@ -1,7 +1,7 @@ import os -import json import pytest -from types import SimpleNamespace +import logging +from unittest.mock import patch, MagicMock from backend.generate_quiz import QuizGenerator """ @@ -16,30 +16,30 @@ """ -# Fixture to create an instance of QuizGenerator with a dummy API key. @pytest.fixture def quiz_generator(monkeypatch): - # Set a dummy API key in the environment so that the class can initialize without error. + """Fixture to create an instance of QuizGenerator with dummy API keys. + So that the class can initialize without error.""" monkeypatch.setenv("OPENAI_API_KEY", "dummy_key") + monkeypatch.setenv("GEMINI_API_KEY", "dummy_key") + monkeypatch.setenv("DEEPSEEK_API_KEY", "dummy_key") + monkeypatch.setenv("AZURE_AI_API_KEY", "dummy_key") + monkeypatch.setenv("AZURE_AI_API_BASE", "https://dummy.azure.com") + return QuizGenerator() -class TestQuizGeneratorUnit: - """ - Unit tests for the QuizGenerator class. - These tests use mocks to avoid making real API calls. - """ - - def test_get_api_key_from_env(self, monkeypatch): - """ - Test that get_api_key_from_env correctly retrieves the API key from the environment. +class TestQuizGenerator: + """Unit tests for the QuizGenerator class.""" - We set the environment variable and then call the class method to verify that it returns - the expected API key. - """ - monkeypatch.setenv("OPENAI_API_KEY", "test_key") - key = QuizGenerator.get_api_key_from_env() - assert key == "test_key" + def test_check_model_is_supported(self): + """Test that unsupported models default to 'gpt-4-turbo'.""" + assert ( + QuizGenerator.check_model_is_supported("unsupported-model") == "gpt-4-turbo" + ) + assert ( + QuizGenerator.check_model_is_supported("gpt-3.5-turbo") == "gpt-3.5-turbo" + ) def test_environment_variable_not_set(self, monkeypatch): """ @@ -70,127 +70,29 @@ def test_create_role(self, quiz_generator): assert str(n_questions) in role assert quiz_generator.EXAMPLE_RESPONSE in role - def test_create_openai_stream(self, mocker, quiz_generator): - """ - Test that _create_openai_stream calls the underlying OpenAI API with the correct parameters. 
- - We use method patching (with mocker.patch.object) to replace the actual API call with a dummy - value, then verify that the method was called with the correct parameters. - """ - dummy_role = "dummy role string" - dummy_stream = "dummy stream" - # Patch the client's chat.completions.create method so no actual API call is made. - patcher = mocker.patch.object( - quiz_generator.client.chat.completions, "create", return_value=dummy_stream - ) - result = quiz_generator._create_openai_stream(dummy_role) - # Verify the patched method was called once with the expected arguments. - patcher.assert_called_once_with( - model="gpt-4-turbo-preview", - messages=[{"role": "user", "content": dummy_role}], - stream=True, - ) - assert result == dummy_stream - - def test_create_question_generator(self, quiz_generator): - """ - Test the _create_question_generator method by simulating a stream that yields a single chunk - containing a complete JSON string. - - We use a fake chunk (wrapped in a SimpleNamespace) to simulate what the OpenAI API might return. - """ - # Use the EXAMPLE_RESPONSE as our fake complete JSON content. - fake_json = quiz_generator.EXAMPLE_RESPONSE - fake_chunk = SimpleNamespace( - choices=[SimpleNamespace(delta=SimpleNamespace(content=fake_json))] - ) - - def fake_stream(): - # Yield a single fake chunk. - yield fake_chunk - - # Call the generator method and check that it yields the correctly formatted SSE event. - gen = quiz_generator._create_question_generator(fake_stream()) - expected = "data: " + json.dumps(json.loads(fake_json)) + "\n\n" - result = next(gen) - assert result == expected + @patch("backend.generate_quiz.completion") + def test_generate_quiz(self, mock_completion, quiz_generator): + """Test generate_quiz to ensure it streams responses properly.""" + mock_stream = iter(['{"question": "What is 2+2?", "answer": "4"}\n']) + mock_completion.return_value = mock_stream - def test_empty_chunk_in_question_generator(self, quiz_generator, mocker): - """ - Test _create_question_generator when the stream yields an empty chunk (i.e., a chunk with None content) - before yielding a valid JSON chunk. - - This verifies that the method correctly logs the empty chunk and then proceeds once valid data is received. - """ - fake_json = quiz_generator.EXAMPLE_RESPONSE - # Create a chunk that simulates an empty response. - empty_chunk = SimpleNamespace( - choices=[SimpleNamespace(delta=SimpleNamespace(content=None))] + parser_mock = MagicMock() + parser_mock.parse_stream.return_value = iter( + ['data: {"question": "What is 2+2?", "answer": "4"}\n\n'] ) - # Then a chunk that contains valid JSON. - valid_chunk = SimpleNamespace( - choices=[SimpleNamespace(delta=SimpleNamespace(content=fake_json))] - ) - - def fake_stream(): - yield empty_chunk - yield valid_chunk - - # Patch logger.debug to capture log messages about empty chunks. - logger_debug = mocker.patch("backend.generate_quiz.logger.debug") - gen = quiz_generator._create_question_generator(fake_stream()) - result = next(gen) - # Verify that the empty chunk log was produced. - logger_debug.assert_any_call("Chunk was empty!") - expected = "data: " + json.dumps(json.loads(fake_json)) + "\n\n" - assert result == expected - - def test_format_sse(self): - """ - Test that _format_sse correctly formats a JSON object as an SSE (Server-Sent Event) string. - - This is a simple helper method that should return a string starting with "data:". 
- """ - sample_dict = {"key": "value"} - expected = "data: " + json.dumps(sample_dict) + "\n\n" - result = QuizGenerator._format_sse(sample_dict) - assert result == expected - def test_validate_and_parse_json_valid(self): - """ - Test validate_and_parse_json with a valid JSON string. - - The method should return the corresponding Python dictionary. - """ - valid_json_str = '{"foo": "bar"}' - result = QuizGenerator.validate_and_parse_json(valid_json_str) - assert result == {"foo": "bar"} - - def test_validate_and_parse_json_incomplete(self): - """ - Test validate_and_parse_json with an incomplete JSON string. + with patch.object(quiz_generator, "parser", parser_mock): + generator = quiz_generator.generate_quiz("Math", "Easy", n_questions=1) + result = list(generator) - Since the method is designed to return None if the JSON is incomplete (not fully formed), - we expect the result to be None. - """ - incomplete_json_str = '{"foo": "bar"' - result = QuizGenerator.validate_and_parse_json(incomplete_json_str) - assert result is None - - def test_print_quiz(self, mocker, quiz_generator): - """ - Test the static print_quiz method by passing in a dummy generator. + assert result == ['data: {"question": "What is 2+2?", "answer": "4"}\n\n'] - We patch logger.info to verify that the print_quiz method logs each quiz item correctly. - """ - dummy_generator = ( - s for s in ['data: {"quiz": "q1"}\n\n', 'data: {"quiz": "q2"}\n\n'] - ) - logger_info = mocker.patch("backend.generate_quiz.logger.info") - QuizGenerator.print_quiz(dummy_generator) - # Verify that logger.info was called with the expected messages. - logger_info.assert_any_call('Item 1: data: {"quiz": "q1"}\n\n') - logger_info.assert_any_call('Item 2: data: {"quiz": "q2"}\n\n') + def test_print_quiz(self, quiz_generator, caplog): + """Test that print_quiz correctly logs the generated questions.""" + caplog.set_level(logging.INFO) + test_generator = iter(['data: {"question": "What is 2+2?", "answer": "4"}\n\n']) + result = quiz_generator.print_quiz(test_generator) + assert 'data: {"question": "What is 2+2?", "answer": "4"}\n\n' in result class TestQuizGeneratorIntegration: diff --git a/backend/tests/test_response_stream_parser.py b/backend/tests/test_response_stream_parser.py new file mode 100644 index 0000000..4acc52d --- /dev/null +++ b/backend/tests/test_response_stream_parser.py @@ -0,0 +1,114 @@ +import pytest +from types import SimpleNamespace +from backend.response_stream_parser import ResponseStreamParser + +""" +Test file for ResponseStreamParser class. + +Grouped into: +1. **Unit Tests**: Tests class behavior using mocks (no real API calls). +2. **Integration Tests**: Processes real stream responses. + +This file uses fixtures, monkeypatching, and method patching to isolate +code under test and simulate various conditions. +""" + + +@pytest.fixture +def response_parser(): + return ResponseStreamParser() + + +class TestResponseStreamParser: + """ + Unit tests for the ResponseStreamParser class. + """ + + def test_extract_chunk_content_valid(self, response_parser): + """ + Test extracting content from a valid chunk. + """ + chunk = SimpleNamespace( + choices=[SimpleNamespace(delta=SimpleNamespace(content="test content"))] + ) + content = response_parser._extract_chunk_content(chunk) + assert content == "test content" + + def test_extract_chunk_content_invalid(self, response_parser): + """ + Test that an invalid chunk returns None. 
+ """ + chunk = SimpleNamespace(choices=[]) + content = response_parser._extract_chunk_content(chunk) + assert content is None + + def test_split_buffer(self, response_parser): + """ + Test splitting the buffer into complete lines and a remainder. + """ + response_parser.buffer = ( + '{"question": "Who was ..."}\n{"question": "What is ..."}\nincomplete' + ) + + expected_complete_lines = [ + '{"question": "Who was ..."}', + '{"question": "What is ..."}', + ] + expected_remainder = "incomplete" + + complete_lines, remainder = response_parser._split_buffer() + + print("Complete Lines:", complete_lines) + print("Remainder:", remainder) + + assert complete_lines == expected_complete_lines, ( + f"Expected complete lines '{expected_complete_lines}', but got {complete_lines}" + ) + + assert remainder == expected_remainder, ( + f"Expected remainder '{expected_remainder}', but got {remainder}" + ) + + def test_process_line_valid_json(self, response_parser): + """ + Test processing a valid JSON line. + """ + line = '{"question": "Who was the first emperor of Rome?"}' + result = response_parser._process_line(line) + assert result == 'data: {"question": "Who was the first emperor of Rome?"}\n\n' + + def test_process_line_invalid_json(self, response_parser): + """ + Test processing an invalid JSON line. + """ + line = '{"question": "Who was the first emperor of Rome?"' + result = response_parser._process_line(line) + assert result is None + + def test_parse_stream(self, response_parser): + """ + Test parsing a simulated LLM stream. + """ + fake_stream = iter( + [ + SimpleNamespace( + choices=[ + SimpleNamespace( + delta=SimpleNamespace(content='{"question": "First"}\n') + ) + ] + ), + SimpleNamespace( + choices=[ + SimpleNamespace( + delta=SimpleNamespace(content='{"question": "Second"}\n') + ) + ] + ), + ] + ) + results = list(response_parser.parse_stream(fake_stream)) + assert results == [ + 'data: {"question": "First"}\n\n', + 'data: {"question": "Second"}\n\n', + ] diff --git a/docker-compose.yml b/docker-compose.yml index e2cd39e..e6ec0a7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,6 +8,11 @@ services: - "8000:8000" environment: - OPENAI_API_KEY=${OPENAI_API_KEY} + - GEMINI_API_KEY=${GEMINI_API_KEY} + - AZURE_AI_API_BASE=${AZURE_AI_API_BASE} + - AZURE_AI_API_KEY=${AZURE_AI_API_KEY} + - DEEPSEEK_API_KEY=${DEEPSEEK_API_KEY} + # Frontend service for local testing of the static site frontend: diff --git a/frontend/index.html b/frontend/index.html index fd24bb5..7ebf229 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -31,6 +31,16 @@
        GPTeasers 🧠💡
+ diff --git a/frontend/scripts/app.js b/frontend/scripts/app.js index 27718ba..e115d01 100644 --- a/frontend/scripts/app.js +++ b/frontend/scripts/app.js @@ -49,6 +49,7 @@ class App { // Get the topic and difficulty from the input field const topic = this.ui.getTopic(); const difficulty = this.ui.getDifficulty(); + const model = this.ui.getModel(); // Check if topic is empty or contains only whitespace if (!topic.trim()) { @@ -70,7 +71,7 @@ class App { // use the onQuestionReceived callback to display each question individually as it is added to the Quiz object. // Arrow function: Shorter syntax for functions and keeps 'this' context from surrounding code // Set's up quiz only when the first question is received - await this.controller.callQuizAPI(topic, difficulty, () => { + await this.controller.callQuizAPI(topic, difficulty, model, () => { if(!firstQuestionReceived){ this.showQuestion(); // Question should've been added to quiz, so display it this.ui.hideLoading(); // Hide loading clues diff --git a/frontend/scripts/controller.js b/frontend/scripts/controller.js index d69c742..ddd5f1d 100644 --- a/frontend/scripts/controller.js +++ b/frontend/scripts/controller.js @@ -24,7 +24,7 @@ class Controller { this.baseURLQuiz = `${this.baseURL}/GenerateQuiz`; this.baseURLImage = `${this.baseURL}/GenerateImage`; this.quiz = quiz; // this will be initialized as a quiz object - this.numQuestions = this.quiz.numQuestions; + this.numQuestions = this.quiz.numQuestions; } /** @@ -38,14 +38,16 @@ class Controller { * @returns {Promise} * @throws {Error} When the network response is not ok. */ - callQuizAPI(topic, difficulty, onQuestionReceived) { + callQuizAPI(topic, difficulty, model, onQuestionReceived) { console.log("Generating quiz for topic:", topic); console.log("Generating quiz with difficulty:", difficulty); + console.log("Generating quiz with model:", model); const encodedTopic = encodeURIComponent(topic); const encodedDifficulty = encodeURIComponent(difficulty); + const encodedModel = encodeURIComponent(model); const numQuestions = encodeURIComponent(this.numQuestions); - const url = `${this.baseURLQuiz}?topic=${encodedTopic}&difficulty=${encodedDifficulty}&n_questions=${numQuestions}`; + const url = `${this.baseURLQuiz}?topic=${encodedTopic}&difficulty=${encodedDifficulty}&n_questions=${numQuestions}&model=${encodedModel}`; console.log(`Connecting to SSE endpoint: ${url}`); // Promises are used to handle asynchronous operations. They represent a value that may be available now, diff --git a/frontend/scripts/ui.js b/frontend/scripts/ui.js index aa43cfd..a32ac40 100644 --- a/frontend/scripts/ui.js +++ b/frontend/scripts/ui.js @@ -6,6 +6,7 @@ class UI { this.intro = document.getElementById("intro"); this.topicInput = document.getElementById("quizTopic"); this.quizDifficulty = document.getElementById("quizDifficulty"); + this.quizModel = document.getElementById("quizModel") this.button = document.querySelector("button"); //Image elements @@ -71,6 +72,10 @@ class UI { return this.quizDifficulty.value; } + getModel() { + return this.quizModel.value; + } + // Display question in ui elements // Example currentQuestion format: // { diff --git a/frontend/static/styles.css b/frontend/static/styles.css index 7751e41..247e1ef 100644 --- a/frontend/static/styles.css +++ b/frontend/static/styles.css @@ -20,6 +20,10 @@ input[type="text"] { height: 40px; } +#quizModel { + height: 40px; +} + #quiz-container { display: none; }
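
Illustrative note on the streaming contract introduced above: the /GenerateQuiz endpoint now accepts an optional "model" query parameter alongside "topic", "difficulty", and "n_questions", and it emits each question as a single "data: {...}" SSE event followed by a blank line. The snippet below is a minimal sketch of consuming that stream from Python, assuming the backend is running locally on port 8000 (the mapping used in docker-compose.yml) and that the requests package is installed — it is not listed in backend/requirements.txt and is used here purely for illustration.

    import json

    import requests  # assumption: not in backend/requirements.txt, illustration only

    params = {
        "topic": "Roman History",
        "difficulty": "Easy",
        "n_questions": 3,
        "model": "gpt-4-turbo",  # unsupported values fall back to gpt-4-turbo server-side
    }

    with requests.get(
        "http://localhost:8000/GenerateQuiz", params=params, stream=True, timeout=120
    ) as response:
        response.raise_for_status()
        for line in response.iter_lines(decode_unicode=True):
            # Each question arrives as one "data: {...}" line; blank lines separate events.
            if line and line.startswith("data: "):
                question = json.loads(line[len("data: "):])
                print(question["question_id"], question["question"], "->", question["answer"])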
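
The new ResponseStreamParser expects stream chunks shaped like OpenAI/litellm streaming deltas (chunk.choices[0].delta.content) and only emits a question once a newline completes the buffered JSON object. Below is a small sketch of that buffering behaviour using hand-built chunks in place of a real litellm stream; it assumes it is run from the backend/ directory so the module import resolves.

    from types import SimpleNamespace

    from response_stream_parser import ResponseStreamParser


    def make_chunk(text):
        """Build a fake stream chunk shaped like chunk.choices[0].delta.content."""
        return SimpleNamespace(
            choices=[SimpleNamespace(delta=SimpleNamespace(content=text))]
        )


    # The first question arrives split across two chunks; parse_stream buffers it
    # and only yields once the terminating newline shows up.
    fake_stream = [
        make_chunk('{"question_id": 1, '),
        make_chunk('"question": "Who was the first emperor of Rome?"}\n'),
        make_chunk('{"question_id": 2, "question": "Who issued the Edict on Maximum Prices?"}\n'),
    ]

    parser = ResponseStreamParser()
    for sse_event in parser.parse_stream(fake_stream):
        print(repr(sse_event))
    # One SSE event per completed JSON object, e.g.:
    # 'data: {"question_id": 1, "question": "Who was the first emperor of Rome?"}\n\n'
    # 'data: {"question_id": 2, "question": "Who issued the Edict on Maximum Prices?"}\n\n'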