Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion chat_demo/.gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
assets/external/
*.db
*.py[cod]
.web
__pycache__/
__pycache__/
31 changes: 28 additions & 3 deletions chat_demo/chat_demo/chat_demo.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,42 @@
from reflex_chat import api, chat

import reflex as rx
from reflex_chat import chat, api


chat1 = chat(process=api.openai(model="gpt-3.5-turbo"))
chat2 = chat(process=api.openai(model="gpt-4"))
class State(rx.State):
messages1: list[dict[str, str]] = [
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi! How can I help you today?"},
]
messages2: list[dict[str, str]] = [
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi! How can I help you today?"},
]

def clear_messages(self):
self.messages1 = []
self.messages2 = []

def trigger_other_submit(self, id: str, message: str):
if id == chat1.State.get_id():
yield chat2.State.process_message(message)
else:
yield chat1.State.process_message(message)

chat1 = chat(
process=api.openai(model="gpt-3.5-turbo"),
messages=State.messages1,
on_submit=State.trigger_other_submit,
)
chat2 = chat(process=api.openai(model="gpt-4"), messages=State.messages2, on_submit=State.trigger_other_submit)

@rx.page()
def index() -> rx.Component:
return rx.container(
rx.hstack(
chat1,
chat2,
rx.button("Clear", on_click=State.clear_messages),
height="100vh",
),
size="4",
Expand Down
7 changes: 3 additions & 4 deletions custom_components/reflex_chat/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,17 @@ def openai(
from openai import OpenAI
client = client or OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

async def process(chat):
async def process(messages: list[dict[str, str]]):
# Start a new session to answer the question.
session = client.chat.completions.create(
model=model,
messages=chat.get_messages(),
messages=messages,
stream=True,
)

# Stream the results, yielding after every word.
for item in session:
delta = item.choices[0].delta.content
chat.append_to_response(delta)
yield
yield delta

return process
74 changes: 36 additions & 38 deletions custom_components/reflex_chat/chat.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
"""A custom component for a chat interface."""
from __future__ import annotations

from typing import Callable, Generator, Type
from typing import Any, Callable, Generator, Type, ClassVar

import reflex as rx

from .api import default_process
from .components import chat_bubble, action_bar
from .components import action_bar, chat_bubble

print("hi")

class Chat(rx.ComponentState):
"""A chat component with state."""

# The full chat history, in the OpenAI format.
messages: list[dict[str, str]] = []
messages: ClassVar[rx.Var[list[dict[str, str]]]] = []

# Whether we are processing a message.
processing: bool = False

def get_messages(self) -> list[dict[str, str]]:
async def get_messages(self) -> list[dict[str, Any]]:
"""Return the chat history including the last submitted user message.

Returns:
The chat history as a list of dictionaries.
"""
# Convert to a list before sending.
return self.get_value(self.messages)
return await self.get_var_value(self.messages)

@classmethod
def get_id(cls):
Expand All @@ -33,7 +35,8 @@ def get_id(cls):
@classmethod
def create(
cls,
initial_messages: list[dict[str, str]] | None = None,
messages: rx.Var[list[dict[str, str]]],
on_submit: Callable | None = None,
process: Callable[[Chat], Generator] | None = default_process,
logo: rx.Component | None = rx.logo(),
chat_bubble: Callable[[Type[Chat]], rx.Component] | None = chat_bubble,
Expand All @@ -52,11 +55,9 @@ def create(
Returns:
A chat component.
"""
# Set the initial value of the State var.
if initial_messages is not None:
# Update the pydantic model to use the initial value as default.
cls.__fields__["messages"].default = initial_messages
return super().create(
messages=messages,
on_submit=on_submit,
process=process,
logo=logo,
chat_bubble=chat_bubble,
Expand All @@ -66,7 +67,9 @@ def create(

@classmethod
def get_component(cls, **props) -> rx.Component:
cls.process = props.pop("process", default_process)
cls.process = staticmethod(props.pop("process", default_process))
cls.on_submit = props.pop("on_submit", lambda _: rx.fragment())
cls.messages = props.pop("messages")
chat_bubble = props.pop("chat_bubble", lambda message, i: rx.fragment())
action_bar = props.pop("action_bar", lambda _: rx.fragment())

Expand Down Expand Up @@ -110,16 +113,25 @@ def scroll_to_bottom(self):
"""
)

async def process_message(self):
async for value in self.process():
yield value
@rx.event(background=True)
async def process_message(self, message: str):
async with self:
self.processing = True
messages = await self.get_messages()
messages.append({"role": "user", "content": message})
messages.append({"role": "assistant", "content": ""})
yield

async for value in self.process(messages):
await self.append_to_response(value)

self.processing = False
async with self:
self.processing = False

# Scroll to the last message.
yield self.scroll_to_bottom()
# yield self.scroll_to_bottom()

def submit_message(self, form_data: dict[str, str]):
async def submit_message(self, form_data: dict[str, str]):
# Get the message from the form
message = form_data[self.__class__.__name__]

Expand All @@ -128,33 +140,19 @@ def submit_message(self, form_data: dict[str, str]):
return

# Add the message to the list of messages.
self.messages.append({"role": "user", "content": message})
self.messages.append({"role": "assistant", "content": ""})
self.processing = True
yield self.scroll_to_bottom()
yield type(self).process_message
yield type(self).on_submit(self.get_id(), message)
# yield self.scroll_to_bottom()
yield type(self).process_message(message)

@rx.var
def last_user_message(self) -> str:
"""Return the last submitted user message.

Returns:
The last submitted user message.
"""
for message in reversed(self.messages):
if message["role"] == "user":
return message["content"]
return ""

def append_to_response(self, answer: str):
async def append_to_response(self, answer: str):
"""Append to the last answer in the chat history.

Args:
answer: The answer to add to the chat history.
"""
self.messages[-1]["content"] += answer or ""
async with self:
messages = await self.get_var_value(self.messages)
messages[-1]["content"] += answer or ""


chat = Chat.create

c1 = chat()
8 changes: 3 additions & 5 deletions custom_components/reflex_chat/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,9 @@ def action_bar(ChatState) -> rx.Component:
"""The action bar to send a new message."""
return rx.form(
rx.hstack(
rx.input.root(
rx.input.input(
placeholder="Type something...",
id=ChatState.__name__,
),
rx.input(
placeholder="Type something...",
id=ChatState.__name__,
width="100%",
),
rx.spacer(),
Expand Down
13 changes: 8 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,26 @@ build-backend = "setuptools.build_meta"

[project]
name = "reflex-chat"
version = "0.0.1"
description = "Reflex custom component chat"
version = "0.0.2"
description = "A custom chat component that can be hooked up to any model."
readme = "README.md"
license = { text = "Apache-2.0" }
requires-python = ">=3.8"
authors = [{ name = "", email = "nikhil@reflex.dev" }]
keywords = ["reflex","reflex-custom-components"]
authors = [{ name = "Nikhil Rao", email = "nikhil@reflex.dev" }]
keywords = ["reflex", "reflex-custom-components", "chat", "llm"]

dependencies = ["reflex>=0.4.7"]

classifiers = ["Development Status :: 4 - Beta"]

[project.urls]
Homepage = "https://github.com/picklelo/reflex-chat"
homepage = "https://github.com/picklelo/reflex-chat"
source = "https://github.com/picklelo/reflex-chat"

[project.optional-dependencies]
dev = ["build", "twine"]



[tool.setuptools.packages.find]
where = ["custom_components"]