From ad39557e7b1f299ffbccabce29d79710e71cfb10 Mon Sep 17 00:00:00 2001 From: Nikhil Rao Date: Wed, 24 Apr 2024 16:34:44 -0700 Subject: [PATCH 1/2] Upgrade for new radix --- custom_components/reflex_chat/components.py | 8 +++----- pyproject.toml | 8 +++++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/custom_components/reflex_chat/components.py b/custom_components/reflex_chat/components.py index c50204a..fa66b2c 100644 --- a/custom_components/reflex_chat/components.py +++ b/custom_components/reflex_chat/components.py @@ -43,11 +43,9 @@ def action_bar(ChatState) -> rx.Component: """The action bar to send a new message.""" return rx.form( rx.hstack( - rx.input.root( - rx.input.input( - placeholder="Type something...", - id=ChatState.__name__, - ), + rx.input( + placeholder="Type something...", + id=ChatState.__name__, width="100%", ), rx.spacer(), diff --git a/pyproject.toml b/pyproject.toml index df0c7e0..5234cb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,12 +4,12 @@ build-backend = "setuptools.build_meta" [project] name = "reflex-chat" -version = "0.0.1" +version = "0.0.2a1" description = "Reflex custom component chat" readme = "README.md" license = { text = "Apache-2.0" } requires-python = ">=3.8" -authors = [{ name = "", email = "nikhil@reflex.dev" }] +authors = [{ name = "Nikhil Rao", email = "nikhil@reflex.dev" }] keywords = ["reflex","reflex-custom-components"] dependencies = ["reflex>=0.4.7"] @@ -17,10 +17,12 @@ dependencies = ["reflex>=0.4.7"] classifiers = ["Development Status :: 4 - Beta"] [project.urls] -Homepage = "https://github.com/picklelo/reflex-chat" +homepage = "https://github.com/picklelo/reflex-chat" +source = "https://github.com/picklelo/reflex-chat" [project.optional-dependencies] dev = ["build", "twine"] + [tool.setuptools.packages.find] where = ["custom_components"] From 590f946f7c0aa6f502ff00007d73069bedc42288 Mon Sep 17 00:00:00 2001 From: Nikhil Rao Date: Tue, 17 Dec 2024 19:14:14 -0800 Subject: [PATCH 2/2] Update chat to follow block standards --- chat_demo/.gitignore | 3 +- chat_demo/chat_demo/chat_demo.py | 31 +++++++++-- custom_components/reflex_chat/api.py | 7 ++- custom_components/reflex_chat/chat.py | 74 +++++++++++++-------------- pyproject.toml | 7 +-- 5 files changed, 73 insertions(+), 49 deletions(-) diff --git a/chat_demo/.gitignore b/chat_demo/.gitignore index eab0d4b..84a3ecc 100644 --- a/chat_demo/.gitignore +++ b/chat_demo/.gitignore @@ -1,4 +1,5 @@ +assets/external/ *.db *.py[cod] .web -__pycache__/ \ No newline at end of file +__pycache__/ diff --git a/chat_demo/chat_demo/chat_demo.py b/chat_demo/chat_demo/chat_demo.py index 98239cf..953e104 100644 --- a/chat_demo/chat_demo/chat_demo.py +++ b/chat_demo/chat_demo/chat_demo.py @@ -1,10 +1,34 @@ +from reflex_chat import api, chat + import reflex as rx -from reflex_chat import chat, api -chat1 = chat(process=api.openai(model="gpt-3.5-turbo")) -chat2 = chat(process=api.openai(model="gpt-4")) +class State(rx.State): + messages1: list[dict[str, str]] = [ + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi! How can I help you today?"}, + ] + messages2: list[dict[str, str]] = [ + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi! How can I help you today?"}, + ] + + def clear_messages(self): + self.messages1 = [] + self.messages2 = [] + + def trigger_other_submit(self, id: str, message: str): + if id == chat1.State.get_id(): + yield chat2.State.process_message(message) + else: + yield chat1.State.process_message(message) +chat1 = chat( + process=api.openai(model="gpt-3.5-turbo"), + messages=State.messages1, + on_submit=State.trigger_other_submit, +) +chat2 = chat(process=api.openai(model="gpt-4"), messages=State.messages2, on_submit=State.trigger_other_submit) @rx.page() def index() -> rx.Component: @@ -12,6 +36,7 @@ def index() -> rx.Component: rx.hstack( chat1, chat2, + rx.button("Clear", on_click=State.clear_messages), height="100vh", ), size="4", diff --git a/custom_components/reflex_chat/api.py b/custom_components/reflex_chat/api.py index 53dd68d..964955d 100644 --- a/custom_components/reflex_chat/api.py +++ b/custom_components/reflex_chat/api.py @@ -27,18 +27,17 @@ def openai( from openai import OpenAI client = client or OpenAI(api_key=os.getenv("OPENAI_API_KEY")) - async def process(chat): + async def process(messages: list[dict[str, str]]): # Start a new session to answer the question. session = client.chat.completions.create( model=model, - messages=chat.get_messages(), + messages=messages, stream=True, ) # Stream the results, yielding after every word. for item in session: delta = item.choices[0].delta.content - chat.append_to_response(delta) - yield + yield delta return process diff --git a/custom_components/reflex_chat/chat.py b/custom_components/reflex_chat/chat.py index 4d14056..f3ab0c1 100644 --- a/custom_components/reflex_chat/chat.py +++ b/custom_components/reflex_chat/chat.py @@ -1,30 +1,32 @@ """A custom component for a chat interface.""" from __future__ import annotations -from typing import Callable, Generator, Type +from typing import Any, Callable, Generator, Type, ClassVar import reflex as rx + from .api import default_process -from .components import chat_bubble, action_bar +from .components import action_bar, chat_bubble +print("hi") class Chat(rx.ComponentState): """A chat component with state.""" # The full chat history, in the OpenAI format. - messages: list[dict[str, str]] = [] + messages: ClassVar[rx.Var[list[dict[str, str]]]] = [] # Whether we are processing a message. processing: bool = False - def get_messages(self) -> list[dict[str, str]]: + async def get_messages(self) -> list[dict[str, Any]]: """Return the chat history including the last submitted user message. Returns: The chat history as a list of dictionaries. """ # Convert to a list before sending. - return self.get_value(self.messages) + return await self.get_var_value(self.messages) @classmethod def get_id(cls): @@ -33,7 +35,8 @@ def get_id(cls): @classmethod def create( cls, - initial_messages: list[dict[str, str]] | None = None, + messages: rx.Var[list[dict[str, str]]], + on_submit: Callable | None = None, process: Callable[[Chat], Generator] | None = default_process, logo: rx.Component | None = rx.logo(), chat_bubble: Callable[[Type[Chat]], rx.Component] | None = chat_bubble, @@ -52,11 +55,9 @@ def create( Returns: A chat component. """ - # Set the initial value of the State var. - if initial_messages is not None: - # Update the pydantic model to use the initial value as default. - cls.__fields__["messages"].default = initial_messages return super().create( + messages=messages, + on_submit=on_submit, process=process, logo=logo, chat_bubble=chat_bubble, @@ -66,7 +67,9 @@ def create( @classmethod def get_component(cls, **props) -> rx.Component: - cls.process = props.pop("process", default_process) + cls.process = staticmethod(props.pop("process", default_process)) + cls.on_submit = props.pop("on_submit", lambda _: rx.fragment()) + cls.messages = props.pop("messages") chat_bubble = props.pop("chat_bubble", lambda message, i: rx.fragment()) action_bar = props.pop("action_bar", lambda _: rx.fragment()) @@ -110,16 +113,25 @@ def scroll_to_bottom(self): """ ) - async def process_message(self): - async for value in self.process(): - yield value + @rx.event(background=True) + async def process_message(self, message: str): + async with self: + self.processing = True + messages = await self.get_messages() + messages.append({"role": "user", "content": message}) + messages.append({"role": "assistant", "content": ""}) + yield + + async for value in self.process(messages): + await self.append_to_response(value) - self.processing = False + async with self: + self.processing = False # Scroll to the last message. - yield self.scroll_to_bottom() + # yield self.scroll_to_bottom() - def submit_message(self, form_data: dict[str, str]): + async def submit_message(self, form_data: dict[str, str]): # Get the message from the form message = form_data[self.__class__.__name__] @@ -128,33 +140,19 @@ def submit_message(self, form_data: dict[str, str]): return # Add the message to the list of messages. - self.messages.append({"role": "user", "content": message}) - self.messages.append({"role": "assistant", "content": ""}) - self.processing = True - yield self.scroll_to_bottom() - yield type(self).process_message + yield type(self).on_submit(self.get_id(), message) + # yield self.scroll_to_bottom() + yield type(self).process_message(message) - @rx.var - def last_user_message(self) -> str: - """Return the last submitted user message. - - Returns: - The last submitted user message. - """ - for message in reversed(self.messages): - if message["role"] == "user": - return message["content"] - return "" - - def append_to_response(self, answer: str): + async def append_to_response(self, answer: str): """Append to the last answer in the chat history. Args: answer: The answer to add to the chat history. """ - self.messages[-1]["content"] += answer or "" + async with self: + messages = await self.get_var_value(self.messages) + messages[-1]["content"] += answer or "" chat = Chat.create - -c1 = chat() diff --git a/pyproject.toml b/pyproject.toml index 5234cb7..49d0f7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,13 +4,13 @@ build-backend = "setuptools.build_meta" [project] name = "reflex-chat" -version = "0.0.2a1" -description = "Reflex custom component chat" +version = "0.0.2" +description = "A custom chat component that can be hooked up to any model." readme = "README.md" license = { text = "Apache-2.0" } requires-python = ">=3.8" authors = [{ name = "Nikhil Rao", email = "nikhil@reflex.dev" }] -keywords = ["reflex","reflex-custom-components"] +keywords = ["reflex", "reflex-custom-components", "chat", "llm"] dependencies = ["reflex>=0.4.7"] @@ -24,5 +24,6 @@ source = "https://github.com/picklelo/reflex-chat" dev = ["build", "twine"] + [tool.setuptools.packages.find] where = ["custom_components"]