diff --git a/samples/kumo_powered_personalized_shopping_chatbot/README.md b/samples/kumo_powered_personalized_shopping_chatbot/README.md
new file mode 100644
index 00000000..53bbf01a
--- /dev/null
+++ b/samples/kumo_powered_personalized_shopping_chatbot/README.md
@@ -0,0 +1,19 @@
+# Personalized Shopping Chatbot Powered by Kumo in Snowflake Native App
+
+This package contains the code used in the Kumo/Snowflake solution center “Personalized Shopping Chatbot Powered by Kumo in Snowflake Native App”
+
+## Description
+
+The scripts used in the quickstart are as follows:
+
+- `chatbot.py` - final code for rendering chatbot
+- `i2i.sql` - code for setting up item to item search service
+- `i2u.sql` - code for setting up item to user search service
+- `customer_info.sql` - code for formatting customer information table
+- `item_info.sql` - code for formatting item information table
+- `format_i2i.sql` - code for formatting item to item predictions table
+- `format_i2u.sql` - code for formatting item to user predictions table
+
+## More Information
+
+Please see the Kumo/Snowflake solution center quickstart for information on how to use these scripts.
\ No newline at end of file
diff --git a/samples/kumo_powered_personalized_shopping_chatbot/chatbot.py b/samples/kumo_powered_personalized_shopping_chatbot/chatbot.py
new file mode 100644
index 00000000..b7324ca3
--- /dev/null
+++ b/samples/kumo_powered_personalized_shopping_chatbot/chatbot.py
@@ -0,0 +1,367 @@
+import streamlit as st
+from snowflake.core import Root  # requires snowflake>=0.8.0
+from snowflake.snowpark.context import get_active_session
+
+# Cortex LLMs the user can pick from in the sidebar "Advanced options".
+MODELS = [
+    "mistral-large",
+    "snowflake-arctic",
+    "llama3-70b",
+    "llama3-8b",
+]
+
+def init_messages():
+    """
+    Initialize the session state for chat messages. If the session state indicates that the
+    conversation should be cleared or if the "messages" key is not in the session state,
+    initialize it as an empty list.
+    """
+    # "clear_conversation" is the sidebar button created in init_config_options();
+    # main() calls init_config_options() before this, so the key always exists here.
+    if st.session_state.clear_conversation or "messages" not in st.session_state:
+        st.session_state.messages = []
+
+def init_customer_id():
+    """
+    Initialize the customer ID in the session state (None = not logged in).
+    """
+    if "customer_id" not in st.session_state:
+        st.session_state.customer_id = None
+
+def init_service_metadata():
+    """
+    Initialize the session state for cortex search service metadata. Query the available
+    cortex search services from the Snowflake session and store their names and search
+    columns in the session state.
+    """
+    if "service_metadata" not in st.session_state:
+        services = session.sql("SHOW CORTEX SEARCH SERVICES;").collect()
+        service_metadata = []
+        if services:
+            for s in services:
+                svc_name = s["name"]
+                # DESC returns one row per service; "search_column" is the column
+                # the service indexes and returns in search results.
+                svc_search_col = session.sql(
+                    f"DESC CORTEX SEARCH SERVICE {svc_name};"
+                ).collect()[0]["search_column"]
+                service_metadata.append(
+                    {"name": svc_name, "search_column": svc_search_col}
+                )
+
+        st.session_state.service_metadata = service_metadata
+
+def init_config_options():
+    """
+    Initialize the configuration options in the Streamlit sidebar. Allow the user to
+    clear the conversation, toggle debug mode, and toggle the use of chat history.
+    Also provide advanced options to select a model, the number of context chunks,
+    and the number of chat messages to use in the chat history.
+    """
+    st.sidebar.button("Clear conversation", key="clear_conversation")
+    st.sidebar.toggle("Debug", key="debug", value=False)
+    st.sidebar.toggle("Use chat history", key="use_chat_history", value=True)
+
+    with st.sidebar.expander("Advanced options"):
+        st.selectbox("Select model:", MODELS, key="model_name")
+        st.number_input(
+            "Select number of context chunks",
+            value=5,
+            key="num_retrieved_chunks",
+            min_value=1,
+            max_value=10,
+        )
+        st.number_input(
+            "Select number of messages to use in chat history",
+            value=5,
+            key="num_chat_messages",
+            min_value=1,
+            max_value=10,
+        )
+
+    st.sidebar.expander("Session State").write(st.session_state)
+
+def add_login_feature():
+    """
+    Add a login feature to input and persist the customer ID in session state.
+    """
+    with st.sidebar:
+        st.header("Customer Login")
+        # Customer ID input; pre-fill with the current ID when already logged in.
+        customer_id = st.text_input(
+            "Enter Customer ID",
+            value=st.session_state.customer_id or "",
+            placeholder="input ID here"
+        )
+
+        # Update customer ID in session state
+        if st.button("Log In"):
+            if customer_id.strip():  # Ensure input is not empty
+                st.session_state.customer_id = customer_id.strip()
+                st.success(f"Logged in as Customer ID: {st.session_state.customer_id}")
+            else:
+                st.error("Customer ID cannot be empty.")
+
+        # Logout option
+        if st.session_state.customer_id:
+            if st.button("Log Out"):
+                st.session_state.customer_id = None
+                st.info("Logged out.")
+
+
+
+def query_cortex_search_service(query, selected_service_name):
+    """
+    Query the selected cortex search service with the given query and retrieve context documents.
+    Display the retrieved context documents in the sidebar if debug mode is enabled. Return the
+    context documents as a string.
+
+    Args:
+        query (str): The query to search the cortex search service with.
+        selected_service_name (str): Name of the cortex search service to query.
+
+    Returns:
+        str: The concatenated string of context documents.
+
+    Raises:
+        ValueError: If selected_service_name is not a known service.
+    """
+    db, schema = session.get_current_database(), session.get_current_schema()
+
+    cortex_search_service = (
+        root.databases[db]
+        .schemas[schema]
+        .cortex_search_services[selected_service_name]
+    )
+
+    # NOTE(review): columns=[] appears to rely on the service returning its
+    # search column by default — confirm against the Cortex Search API docs.
+    context_documents = cortex_search_service.search(
+        query, columns=[], limit=st.session_state.num_retrieved_chunks
+    )
+    results = context_documents.results
+
+    # Look up which column holds the searchable text for this service.
+    service_metadata = st.session_state.service_metadata
+    search_col = next(
+        (s["search_column"] for s in service_metadata if s["name"] == selected_service_name),
+        None
+    )
+    if not search_col:
+        raise ValueError(f"Invalid service name: {selected_service_name}")
+
+    context_str = ""
+    for i, r in enumerate(results):
+        context_str += f"Context document {i+1}: {r[search_col]} \n" + "\n"
+
+    if st.session_state.debug:
+        st.sidebar.text_area(f"Context documents from {selected_service_name}", context_str, height=500)
+
+    return context_str
+
+def get_chat_history():
+    """
+    Retrieve the chat history from the session state limited to the number of messages specified
+    by the user in the sidebar options. The most recent message (the question currently being
+    answered) is excluded.
+
+    Returns:
+        list: The list of chat messages from the session state.
+    """
+    start_index = max(
+        0, len(st.session_state.messages) - st.session_state.num_chat_messages
+    )
+    return st.session_state.messages[start_index : len(st.session_state.messages) - 1]
+
+def complete(model, prompt):
+    """
+    Generate a completion for the given prompt using the specified model.
+
+    Args:
+        model (str): The name of the model to use for completion.
+        prompt (str): The prompt to generate a completion for.
+
+    Returns:
+        str: The generated completion.
+    """
+    return session.sql("SELECT snowflake.cortex.complete(?,?)", (model, prompt)).collect()[0][0]
+
+def classify_query(model, query):
+    """
+    Generate a classification ("ITEM" or "CUSTOMER") for the given query.
+
+    Args:
+        model (str): The name of the model to use for classification.
+        query (str): The user query to classify.
+
+    Returns:
+        str: The generated classification.
+    """
+    classification_prompt = f"""Please classify the given user query as either item to item
+            or item to user. If classifying as item to item, print "ITEM" with
+            no additional text, and no whitespace. If classifying as item to user, print "CUSTOMER"
+            with no additional text, and no whitespace. If the prompt involves a customer id, it will be
+            an item to user prompt. If the prompt involves item ids, it will be an
+            item to item prompt.
+            Here is the user query: {query}
+            Please return the appropriate classification.
+
+            """
+    return complete(model, classification_prompt)
+
+
+
+
+def make_chat_history_summary(chat_history, question):
+    """
+    Generate a summary of the chat history combined with the current question to extend the query
+    context. Use the language model to generate this summary.
+
+    Args:
+        chat_history (str): The chat history to include in the summary.
+        question (str): The current user question to extend with the chat history.
+
+    Returns:
+        str: The generated summary of the chat history and question.
+    """
+    prompt = f"""
+        [INST]
+        Based on the chat history below and the question, generate a query that extend the question
+        with the chat history provided. The query should be in natural language.
+        Answer with only the query. Do not add any explanation.
+
+        
+        {chat_history}
+        
+        
+        {question}
+        
+        [/INST]
+    """
+
+    summary = complete(st.session_state.model_name, prompt)
+
+    if st.session_state.debug:
+        # Escape "$" so Streamlit markdown does not treat it as LaTeX.
+        st.sidebar.text_area(
+            "Chat history summary", summary.replace("$", "\\$"), height=150
+        )
+
+    return summary
+
+def create_prompt(user_question, selected_service_name):
+    """
+    Create a prompt for the language model by combining the user question with context retrieved
+    from the cortex search service and chat history (if enabled). Format the prompt according to
+    the expected input format of the model.
+
+    Args:
+        user_question (str): The user's question to generate a prompt for.
+        selected_service_name (str): Name of the cortex search service to retrieve context from.
+
+    Returns:
+        str: The generated prompt for the language model.
+    """
+    # Prefix the logged-in customer ID so retrieval and the LLM can use it.
+    customer_id = st.session_state.customer_id
+    if customer_id:
+        user_question = f"[Customer ID: {customer_id}]{user_question}"
+
+    if st.session_state.use_chat_history:
+        chat_history = get_chat_history()
+        if chat_history != []:
+            # Fold prior turns into the retrieval query so follow-ups resolve references.
+            question_summary = make_chat_history_summary(chat_history, user_question)
+            prompt_context = query_cortex_search_service(question_summary, selected_service_name)
+        else:
+            prompt_context = query_cortex_search_service(user_question, selected_service_name)
+    else:
+        prompt_context = query_cortex_search_service(user_question, selected_service_name)
+        chat_history = ""
+
+    prompt = f"""
+            [INST]
+            You are a helpful AI chat assistant with RAG capabilities. When a user asks you a question,
+            you will also be given context provided between  and  tags. Use that context
+            with the user's chat history provided in the between  and  tags
+            to provide a summary that addresses the user's question. Ensure the answer is coherent, concise,
+            and directly relevant to the user's question.
+
+            If the user asks a generic question which cannot be answered with the given context or chat_history,
+            just say "I don't know the answer to that question."
+
+            Don't say things like "according to the provided context".
+
+            
+            {chat_history}
+            
+            
+            {prompt_context}
+            
+            
+            {user_question}
+            
+            [/INST]
+            Answer:
+            """
+    return prompt
+
+def main():
+    st.title(":speech_balloon: Personalized Shopping Assistant Powered by Kumo")
+
+    init_customer_id()
+    add_login_feature()
+    init_service_metadata()
+    init_config_options()
+    init_messages()
+
+    icons = {"assistant": "❄️", "user": "👤"}
+
+    # Display chat messages from history on app rerun
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"], avatar=icons[message["role"]]):
+            st.markdown(message["content"])
+
+    disable_chat = (
+        "service_metadata" not in st.session_state
+        or len(st.session_state.service_metadata) == 0
+    )
+    if question := st.chat_input("Ask a question...", disabled=disable_chat):
+        # Add user message to chat history
+        st.session_state.messages.append({"role": "user", "content": question})
+        # Display user message in chat message container
+        with st.chat_message("user", avatar=icons["user"]):
+            # Escape "$" so Streamlit markdown does not treat it as LaTeX.
+            st.markdown(question.replace("$", "\\$"))
+
+        # Display assistant response in chat message container
+        with st.chat_message("assistant", avatar=icons["assistant"]):
+            message_placeholder = st.empty()
+            # Strip single quotes so the question can be embedded in SQL text safely.
+            question = question.replace("'", "")
+            with st.spinner("Thinking..."):
+                # Route the question to the matching search service based on the
+                # LLM's classification of the query.
+                classification = classify_query(st.session_state.model_name, question).strip()
+                service_by_classification = {
+                    "CUSTOMER": "ITEM_TO_USER_SERVICE",
+                    "ITEM": "ITEM_TO_ITEM_SERVICE",
+                }
+                selected_search = service_by_classification.get(classification)
+                if selected_search is None:
+                    # Covers any LLM output that is not exactly CUSTOMER/ITEM.
+                    # (Previously an unrecognized label fell through with an empty
+                    # service name and raised an uncaught ValueError downstream.)
+                    generated_response = "I could not classify the query. Please try again."
+                else:
+                    generated_response = complete(
+                        st.session_state.model_name, create_prompt(question, selected_search)
+                    )
+                message_placeholder.markdown(generated_response)
+
+        st.session_state.messages.append(
+            {"role": "assistant", "content": generated_response}
+        )
+
+if __name__ == "__main__":
+    session = get_active_session()
+    root = Root(session)
+    main()
\ No newline at end of file
diff --git a/samples/kumo_powered_personalized_shopping_chatbot/customer_info.sql b/samples/kumo_powered_personalized_shopping_chatbot/customer_info.sql
new file mode 100644
index 00000000..b2e5dd37
--- /dev/null
+++ b/samples/kumo_powered_personalized_shopping_chatbot/customer_info.sql
@@ -0,0 +1,23 @@
+CREATE OR REPLACE TABLE KUMO_CHATBOT_NAISHA.PUBLIC.CUSTOMER_FULL_INFO AS
+SELECT
+    cust.CUSTOMER_ID,
+    cust.FN,
+    cust.ACTIVE,
+    cust.CLUB_MEMBER_STATUS,
+    cust.FASHION_NEWS_FREQUENCY,
+    cust.AGE,
+    cust.POSTAL_CODE,
+    (
+        'Customer with ID ' || cust.CUSTOMER_ID ||
+        CASE
+            WHEN cust.ACTIVE = 1 THEN ' is active'
+            ELSE ' is not active'
+        END ||
+        ', holds a ' || COALESCE(cust.CLUB_MEMBER_STATUS, 'non-member') ||
+        ' status, and receives fashion news ' || COALESCE(cust.FASHION_NEWS_FREQUENCY, 'never') ||
+        -- NULL || '...' is NULL, so a missing age falls through to the COALESCE
+        -- default instead of rendering "of unknown age years old".
+        '. They are ' || COALESCE(TO_CHAR(cust.AGE) || ' years old', 'of unknown age') ||
+        ' and live in the postal code area ' || COALESCE(cust.POSTAL_CODE, 'unknown') ||
+        '.'
+    ) AS customer_full_text
+FROM
+    KUMO_CHATBOT_NAISHA.PUBLIC.CUSTOMERS cust;
\ No newline at end of file
diff --git a/samples/kumo_powered_personalized_shopping_chatbot/format_i2i.sql b/samples/kumo_powered_personalized_shopping_chatbot/format_i2i.sql
new file mode 100644
index 00000000..45549169
--- /dev/null
+++ b/samples/kumo_powered_personalized_shopping_chatbot/format_i2i.sql
@@ -0,0 +1,35 @@
+CREATE OR REPLACE TABLE KUMO_CHATBOT_NAISHA.PUBLIC.ITEM_TO_ITEM_PREDICTIONS AS
+WITH ranked_predictions AS (
+    SELECT
+        source_item_id AS item_id,
+        target_item_id,
+        prediction_score,
+        ROW_NUMBER() OVER (PARTITION BY source_item_id ORDER BY prediction_score DESC) AS rank
+    FROM
+        KUMO_CHATBOT_NAISHA.PUBLIC.ITEM_PREDICTIONS
+),
+formatted_predictions AS (
+    SELECT
+        item_id,
+        -- LISTAGG ... WITHIN GROUP guarantees sentences are concatenated in
+        -- rank order (STRING_AGG is not a Snowflake function and had no
+        -- deterministic ordering).
+        LISTAGG(
+            'The ' ||
+            CASE
+                WHEN rank % 100 IN (11, 12, 13) THEN rank || 'th'
+                WHEN rank % 10 = 1 THEN rank || 'st'
+                WHEN rank % 10 = 2 THEN rank || 'nd'
+                WHEN rank % 10 = 3 THEN rank || 'rd'
+                ELSE rank || 'th'
+            END ||
+            ' prediction for item_id ' || item_id ||
+            ' is item_id ' || target_item_id ||
+            ' with score ' || prediction_score || '.', ' '
+        ) WITHIN GROUP (ORDER BY rank) AS formatted_predictions
+    FROM
+        ranked_predictions
+    GROUP BY
+        item_id
+)
+SELECT
+    item_id,
+    formatted_predictions
+FROM
+    formatted_predictions;
diff --git a/samples/kumo_powered_personalized_shopping_chatbot/format_i2u.sql b/samples/kumo_powered_personalized_shopping_chatbot/format_i2u.sql
new file mode 100644
index 00000000..73a33394
--- /dev/null
+++ b/samples/kumo_powered_personalized_shopping_chatbot/format_i2u.sql
@@ -0,0 +1,35 @@
+CREATE OR REPLACE TABLE KUMO_CHATBOT_NAISHA.PUBLIC.ITEM_TO_USER_PREDICTIONS AS
+WITH ranked_predictions AS (
+    SELECT
+        customer_id,
+        item_id,
+        prediction_score,
+        ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY prediction_score DESC) AS rank
+    FROM
+        KUMO_DB.HNM.ITEM_TO_USER_PREDICTIONS
+),
+formatted_predictions AS (
+    SELECT
+        customer_id,
+        -- LISTAGG ... WITHIN GROUP guarantees sentences are concatenated in
+        -- rank order (STRING_AGG is not a Snowflake function and had no
+        -- deterministic ordering).
+        LISTAGG(
+            'The ' ||
+            CASE
+                WHEN rank % 100 IN (11, 12, 13) THEN rank || 'th'
+                WHEN rank % 10 = 1 THEN rank || 'st'
+                WHEN rank % 10 = 2 THEN rank || 'nd'
+                WHEN rank % 10 = 3 THEN rank || 'rd'
+                ELSE rank || 'th'
+            END ||
+            ' prediction for customer_id ' || customer_id ||
+            ' is item_id ' || item_id ||
+            ' with score ' || prediction_score || '.', ' '
+        ) WITHIN GROUP (ORDER BY rank) AS formatted_predictions
+    FROM
+        ranked_predictions
+    GROUP BY
+        customer_id
+)
+SELECT
+    customer_id,
+    formatted_predictions
+FROM
+    formatted_predictions;
diff --git a/samples/kumo_powered_personalized_shopping_chatbot/i2i.sql b/samples/kumo_powered_personalized_shopping_chatbot/i2i.sql
new file mode 100644
index 00000000..963fd5ed
--- /dev/null
+++ b/samples/kumo_powered_personalized_shopping_chatbot/i2i.sql
@@ -0,0 +1,23 @@
+CREATE OR REPLACE CORTEX SEARCH SERVICE KUMO_CHATBOT_NAISHA.PUBLIC.item_to_item_service
+ON search_text
+ATTRIBUTES ENTITY_TYPE, ITEM_ID
+WAREHOUSE = WAREHOUSE_SMALL
+TARGET_LAG = '1 hour'
+AS
+    -- Search predictions from ITEM_TO_ITEM_PREDICTIONS_ENRICHED
+    -- NOTE(review): format_i2i.sql creates ITEM_TO_ITEM_PREDICTIONS with column
+    -- FORMATTED_PREDICTIONS; confirm the *_LESS table referenced here exists.
+    SELECT
+        'PREDICTIONS' AS ENTITY_TYPE,
+        ITEM_ID,
+        FORMATTED_PREDICTIONS AS search_text
+    FROM
+        KUMO_CHATBOT_NAISHA.PUBLIC.ITEM_TO_ITEM_PREDICTIONS_LESS
+
+    UNION ALL
+
+    -- Search item information from ITEM_INFO table
+    SELECT
+        'ITEM_INFO' AS ENTITY_TYPE,
+        ITEM_ID,
+        ITEM_INFO AS search_text
+    FROM
+        KUMO_CHATBOT_NAISHA.PUBLIC.ITEM_INFO;
\ No newline at end of file
diff --git a/samples/kumo_powered_personalized_shopping_chatbot/i2u.sql b/samples/kumo_powered_personalized_shopping_chatbot/i2u.sql
new file mode 100644
index 00000000..40a0ee31
--- /dev/null
+++ b/samples/kumo_powered_personalized_shopping_chatbot/i2u.sql
@@ -0,0 +1,23 @@
+CREATE OR REPLACE CORTEX SEARCH SERVICE KUMO_CHATBOT_NAISHA.PUBLIC.item_to_user_service
+ON search_text
+ATTRIBUTES ENTITY_TYPE, CUSTOMER_ID
+WAREHOUSE = WAREHOUSE_SMALL
+TARGET_LAG = '1 hour'
+AS
+    -- Search predictions from ITEM_TO_USER_PREDICTIONS_ENRICHED
+    -- NOTE(review): format_i2u.sql creates ITEM_TO_USER_PREDICTIONS with column
+    -- FORMATTED_PREDICTIONS; confirm the *_LESS table and PREDICTIONS column
+    -- referenced here exist.
+    SELECT
+        'PREDICTIONS' AS ENTITY_TYPE,
+        CUSTOMER_ID,
+        PREDICTIONS AS search_text
+    FROM
+        KUMO_CHATBOT_NAISHA.PUBLIC.ITEM_TO_USER_PREDICTIONS_LESS
+
+    UNION ALL
+
+    -- Search customer info from CUSTOMER_INFO table
+    SELECT
+        'CUSTOMER_INFO' AS ENTITY_TYPE,
+        CUSTOMER_ID,
+        CUSTOMER_INFO AS search_text
+    FROM
+        KUMO_CHATBOT_NAISHA.PUBLIC.CUSTOMER_INFO;
\ No newline at end of file
diff --git a/samples/kumo_powered_personalized_shopping_chatbot/item_info.sql b/samples/kumo_powered_personalized_shopping_chatbot/item_info.sql
new file mode 100644
index 00000000..a29baa75
--- /dev/null
+++ b/samples/kumo_powered_personalized_shopping_chatbot/item_info.sql
@@ -0,0 +1,21 @@
+CREATE OR REPLACE TABLE KUMO_CHATBOT_NAISHA.PUBLIC.ITEM_INFO AS
+SELECT
+    article_id AS item_id,
+    (
+        'Item ' || article_id || ' is called ''' || prod_name ||
+        '''. It is a ' || product_type_name || ' (product type number ' || product_type_no ||
+        ') under the ' || product_group_name || ' product group. This item has a graphical appearance of ' ||
+        graphical_appearance_name || ' (appearance number ' || graphical_appearance_no ||
+        ') and belongs to the ' || colour_group_name || ' colour group (colour group code ' ||
+        colour_group_code || '). It is perceived as ' || perceived_colour_value_name ||
+        ' (perceived colour value ID ' || perceived_colour_value_id ||
+        ') and categorized under the ' || perceived_colour_master_name ||
+        ' master colour (master colour ID ' || perceived_colour_master_id ||
+        '). The item is in the ' || department_name || ' department (department number ' ||
+        department_no || ') of ' || index_name || ' (index code ' ||
+        index_code || '). Specifically, it is in the ' || section_name ||
+        ' section (section number ' || section_no || '). The garment group is ' ||
+        garment_group_name || ' (garment group number ' || garment_group_no || ').'
+    ) AS item_info
+FROM
+    KUMO_CHATBOT_NAISHA.PUBLIC.ARTICLES;
\ No newline at end of file