diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000..1cc393c
Binary files /dev/null and b/.DS_Store differ
diff --git a/.gitignore b/.gitignore
index 4ef4b7b..51adb0f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
-*.rdb
\ No newline at end of file
+*.rdb
+backend/.env
diff --git a/README.md b/README.md
index 4240c9f..da05334 100644
--- a/README.md
+++ b/README.md
@@ -1,20 +1,19 @@
-# TEP-LLM
+# FaultExplainer
 
-This project consists of a backend and a frontend.
+This project consists of a backend and a frontend. It requires a valid `OPENAI_API_KEY` to be placed in `backend/.env`.
 
 ## Backend
 
-The backend requires `redis-server` to be already installed. To run the backend, follow these steps:
+To run the backend, follow these steps:
 
-1. Start the `redis-server`.
+1. Create a virtual environment and install the required packages: `pip install -r backend/requirements.txt`. (To save time if `pygraphviz` fails to build with "Failed building wheel for pygraphviz", install the Graphviz system package first; OS-specific steps are at https://stackoverflow.com/questions/40266604/pip-install-pygraphviz-fails-failed-building-wheel-for-pygraphviz.)
 2. Change directory to the backend folder: `cd backend`.
-3. Run the backend application: `python app.py`.
+3. Run the backend application: `fastapi dev app.py`.
 
 ## Frontend
 
 To start the frontend, follow these steps:
 
-1. Change directory to the frontend folder: `cd frontend`.
-2. Install the required dependencies using `yarn`: `yarn`.
+1. Open a new terminal and change directory to the frontend folder: `cd frontend`.
+2. Install the required dependencies using yarn: `yarn`.
 3. Start the frontend development server: `yarn dev`.
-
diff --git a/backend/.env b/backend/.env
new file mode 100644
index 0000000..401c969
--- /dev/null
+++ b/backend/.env
@@ -0,0 +1 @@
+OPENAI_API_KEY='<your-openai-api-key>'
\ No newline at end of file
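For reference, the FastAPI app later in this diff calls `load_dotenv()` before constructing `OpenAI()`, so the key placed in `backend/.env` is read into the process environment at startup. A minimal sketch for checking that your `.env` is picked up; the file name `check_env.py` is hypothetical, and `python-dotenv`/`openai` are assumed to come from `backend/requirements.txt`:

```python
# check_env.py -- hypothetical helper, not part of the repo
import os
from dotenv import load_dotenv

# app.py calls load_dotenv() with no arguments from inside backend/,
# so point at the file explicitly when running from the repo root.
load_dotenv("backend/.env")

key = os.environ.get("OPENAI_API_KEY")
print("OPENAI_API_KEY loaded:", bool(key))  # True if the .env file was found
```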
diff --git a/backend/.gitignore b/backend/.gitignore
index ba0430d..efa407c 100644
--- a/backend/.gitignore
+++ b/backend/.gitignore
@@ -1 +1,162 @@
-__pycache__/
\ No newline at end of file
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
\ No newline at end of file
diff --git a/backend/app.py b/backend/app.py
index b777361..3e37beb 100644
--- a/backend/app.py
+++ b/backend/app.py
@@ -1,259 +1,655 @@
-from flask import Flask, jsonify, request
-from flask_socketio import SocketIO
-from flask_cors import CORS
-import redis
-import threading
-from csvSimulator import DataFrameSimulator
-from langchain.prompts import PromptTemplate
-from langchain.chat_models import ChatOpenAI
-from langchain.schema import StrOutputParser
-from langchain.schema.messages import HumanMessage, SystemMessage
-
[Several hundred added lines omitted here: a fully commented-out copy of the FastAPI implementation that follows. Its only substantive difference is that the commented-out /explain endpoint also attached the causal-graph image to the model message and to the streamed images.]
+# imports
+from openai import OpenAI
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import StreamingResponse
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+from dotenv import load_dotenv
 import json
-import datetime
-import io
+from asyncio import sleep
 import base64
 import matplotlib
-matplotlib.use('Agg')
+
+matplotlib.use("Agg")
 import matplotlib.pyplot as plt
-import matplotlib.dates as mdates
+import io
+import networkx as nx
+from networkx.drawing.nx_agraph import graphviz_layout
+
+
+cols = [
+    "A Feed",
+    "D Feed",
+    "E Feed",
+    "A and C Feed",
+    "Recycle Flow",
+    "Reactor Feed Rate",
+    "Reactor Pressure",
+    "Reactor Level",
+    "Reactor Temperature",
+    "Purge Rate",
+    "Product Sep Temp",
+    "Product Sep Level",
+    "Product Sep Pressure",
+    "Product Sep Underflow",
+    "Stripper Level",
+    "Stripper Pressure",
+    "Stripper Underflow",
+    "Stripper Temp",
+    "Stripper Steam Flow",
+    "Compressor Work",
+    "Reactor Coolant Temp",
+    "Separator Coolant Temp",
+    "D feed load",
+    "E feed load",
+    "A feed load",
+    "A and C feed load",
+    "Compressor recycle valve",
+    "Purge valve",
+    "Separator liquid load",
+    "Stripper liquid load",
+    "Stripper steam valve",
+    "Reactor coolant load",
+    "Condenser coolant load",
+]
+mapping = {f"x{idx+1}": c for idx, c in enumerate(cols)}
+
+G = nx.read_adjlist("./cg.adjlist", create_using=nx.DiGraph)
+# G = nx.reverse(G)
+G = nx.relabel_nodes(G, mapping)
+
+
+def get_full_graph(G):
+    return G  # Return the entire graph
+
+def get_subgraph(G, nodes):
+    subgraph_nodes = set(nodes)
+    for node in nodes:
+        subgraph_nodes.update(G.predecessors(node))
+        subgraph_nodes.update(G.successors(node))
+    return G.subgraph(subgraph_nodes)
+
+load_dotenv()
+
+
+INTRO_MESSAGE = """The process produces two products from four reactants. Also present are an inert and a byproduct, making a total of eight components:
+A, B, C, D, E, F, G, and H. The reactions are:
+
+A(g) + C(g) + D(g) -> G(liq): Product 1,
+
+A(g) + C(g) + E(g) -> H(liq): Product 2,
+
+A(g) + E(g) -> F(liq): Byproduct,
+
+3D(g) -> 2F(liq): Byproduct.
+
+All the reactions are irreversible and exothermic. The reaction rates are a function of temperature through an Arrhenius expression.
+The reaction to produce G has a higher activation energy, resulting in more sensitivity to temperature.
+Also, the reactions are approximately first-order with respect to the reactant concentrations.
+
+The process has five major unit operations: the reactor, the product condenser, a vapor-liquid separator, a recycle compressor and a product stripper.
+A figure showing a diagram of the process is attached.
+
+The gaseous reactants are fed to the reactor where they react to form liquid products. The gas phase reactions are catalyzed by a nonvolatile catalyst dissolved
+in the liquid phase. The reactor has an internal cooling bundle for removing the heat of reaction. The products leave the reactor as vapors along with the unreacted feeds.
+The catalyst remains in the reactor. The reactor product stream passes through a cooler for condensing the products and from there to a vapor-liquid separator.
+Noncondensed components recycle back through a centrifugal compressor to the reactor feed.
+Condensed components move to a product stripping column to remove remaining reactants by stripping with feed stream number 4.
+Products G and H exit the stripper base and are separated in a downstream refining section which is not included in this problem.
+The inert and byproduct are primarily purged from the system as a vapor from the vapor-liquid separator."""
+
+SYSTEM_MESSAGE = (
+    "You are a helpful AI chatbot trained to assist with "
+    "monitoring the Tennessee Eastman process. The Tennessee Eastman "
+    f"Process can be described as follows:\n{INTRO_MESSAGE}"
+    "\n\nYour purpose is to help the user identify and understand potential "
+    "explanations for any faults that occur during the process. You should "
+    "explain your reasoning using the graphs provided."
+)
+EXPLAIN_PROMPT = (
+    "You are provided with the general schematics of the Tennessee "
+    "Eastman process, causal graphs of different features and graphs showing "
+    "the values of the top contributing features for a recent fault. For every "
+    "contributing feature reason about the observation graphs (not all "
+    "contributing features might have sudden change around the fault) and "
+    "create hypotheses for these observations based on the causal graph. "
+    "Finally combine these hypotheses in order to generate an explanation as to "
+    "why this fault occurred and how it is propagating."
+)
-# Fault Datacetion model integration
-from model import FaultDetectionModel
-import pandas as pd
+client = OpenAI()
-train_data = pd.read_csv("./data/fault0.csv")
-fault_detector = FaultDetectionModel(alpha=0.001)
-fault_detector.fit(train_data.iloc[:, 1:])
+# Initialize FastAPI app
+app = FastAPI()
-fault_history = []
+origins = ["http://localhost", "http://localhost:5173", "*"]
-def top_feature_graphs(top_features, timestamp):
-    target_idx = fault_detector.data_buffer[fault_detector.data_buffer['timestamp'] == timestamp].index[0]
-    start_index = max(0, target_idx - 50)  # 30 points before the fault
-    graphs = []
-    for feature in top_features:
-        # Plot the feature's historical data
-        fig, ax = plt.subplots()
-        ax.plot(fault_detector.data_buffer['timestamp'][start_index:target_idx], fault_detector.data_buffer[feature][start_index:target_idx], label=feature)
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
-        ax.axvline(x=timestamp-datetime.timedelta(minutes=fault_detector.post_fault_threshold*3), color='r', linestyle='--', label='Fault Start')
+# Define request and response models
+class MessageRequest(BaseModel):
+    data: list[dict[str, str]]
+    id: str
-        # Format the x-axis to display dates correctly
-        ax.xaxis.set_major_locator(mdates.AutoDateLocator())
-        ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
-        # Rotate date labels for better readability
-        plt.xticks(rotation=45)
+class ExplainationRequest(BaseModel):
+    data: dict[str, list[float]]
+    id: str
-        # Set labels and title
-        ax.set_xlabel('Time')
-        ax.set_ylabel(feature)
-        ax.set_title(f'{feature} over Time around Fault')
-        # Ensure layout is neat
-        plt.tight_layout()
+class Image(BaseModel):
+    image: str
+    name: str
-        # Convert plot to a format that can be sent over WebSocket
-        img_bytes = io.BytesIO()
-        plt.savefig(img_bytes, format='png')
-        img_bytes.seek(0)
-        img_base64 = base64.b64encode(img_bytes.read()).decode()
-        # Send the image to the frontend
-        graphs.append({'feature': feature, 'image': img_base64})
+class MessageResponse(BaseModel):
+    content: str
+    images: list[Image] = []
+    index: int
+    id: str
-        # with open(f"debug_{feature}.txt", "w") as file:
-        #     file.write(img_base64)
-        plt.close(fig)
+def ChatModelCompletion(
+    messages: list[dict[str, str]], msg_id: str, images: list[str] = None
+):
+    # Send the message to OpenAI's GPT-4o
+    # print(messages)
+    response = client.chat.completions.create(
+        model="gpt-4o",
+        messages=messages,
+        stream=True,
+        temperature=0,
+    )
+    print("sending response")
+    index = 0
+    for chunk in response:
+        # print(chunk)
+        # Extract the response text
+        if chunk.choices[0].delta.content:
+            response_text = chunk.choices[0].delta.content
+            if index == 0 and images:
+                response_str = json.dumps(
+                    {
+                        "index": index,
+                        "content": response_text,
+                        "id": msg_id,
+                        "images": images,
+                    }
+                )
+            else:
+                response_str = json.dumps(
+                    {
+                        "index": index,
+                        "content": response_text,
+                        "id": msg_id,
+                        "images": [],
+                    }
+                )
+            index += 1
+            yield "data: " + response_str + "\n\n"
+    print(f"Sent {index} chunks")
+
+
+def plot_causal_subgraph(request: ExplainationRequest):
+    nodes_of_interest = request.data.keys()
+    full_graph = get_full_graph(G)
+    sub_graph = get_subgraph(G, nodes_of_interest)
+    # Visualize the subgraph
+    pos = graphviz_layout(sub_graph, prog="dot")
+    nx.draw(
+        sub_graph,
+        pos,
+        with_labels=True,
+        node_color="lightblue",
+        node_size=500,
+        font_size=10,
+        arrows=True,
+    )
+    nx.draw_networkx_nodes(
+        sub_graph, pos, nodelist=nodes_of_interest, node_color="red", node_size=600
+    )
+    plt.title("Causal graph of important features (highlighted in red)")
+    plt.axis("off")
+    img_bytes = io.BytesIO()
+    plt.savefig(img_bytes, format="png")
+    img_bytes.seek(0)
+    img_base64 = base64.b64encode(img_bytes.read()).decode()
+    # # DEBUG
+    # with open(f"./img/causal_graph.png", "wb") as f:
+    #     f.write(base64.b64decode(bytes(img_base64, "utf-8")))
+    plt.close()
+    return {"image": img_base64, "name": "Causal graph"}
+
+
+def plot_graphs_to_base64(request: ExplainationRequest):
+    graphs = []
+    for feature_name in request.data:
+        try:
+            # Plot the feature's historical data
+            fig, ax = plt.subplots()
+            ax.step(
+                range(len(request.data[feature_name])),
+                request.data[feature_name],
+                label=feature_name,
+                where="mid",
+            )
+            # ax.plot(request.data[feature_name], label=feature_name)
+
+            ax.axvline(
+                x=len(request.data[feature_name]) - 20,
+                color="r",
+                linestyle="--",
+                label="Fault Start",
+            )
+
+            ax.set_xlabel("Time")
+            ax.set_ylabel(feature_name)
+            ax.set_title(f"{feature_name} over Time around Fault")
+
+            # Ensure layout is neat
+            plt.tight_layout()
+
+            # Convert the plot to a base64 PNG for the streamed response
+            img_bytes = io.BytesIO()
+            plt.savefig(img_bytes, format="png")
+            img_bytes.seek(0)
+            img_base64 = base64.b64encode(img_bytes.read()).decode()
+            # # DEBUG
+            # with open(f"./img/{feature_name}.png", "wb") as f:
+            #     f.write(base64.b64decode(bytes(img_base64, "utf-8")))
+            # Send the image to the frontend
+            graphs.append({"name": feature_name, "image": img_base64})
+            plt.close(fig)
+        except Exception as e:
+            print(e)
     return graphs
-
-def handle_fault_detection(processed_data_point):
-    data_point = processed_data_point['data']
-    fault_id = processed_data_point['fault_id']
-    # Identify the top contributing features
-    top_features_contrib = fault_detector.t2_contrib(data_point.iloc[0]['timestamp'])
-    paired_list = sorted(zip(top_features_contrib, fault_detector.feature_names), reverse=True)[:6]
-    top_features = [s for _, s in paired_list]
-
-    graphs = top_feature_graphs(top_features, data_point.iloc[0]['timestamp'])
-
-    with open("./tep_flowsheet.png", "rb") as image_file:
-        schematic_img_base64 = base64.b64encode(image_file.read()).decode('utf-8')
-    # Add the schematic image
-    image_data = [
-        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{schematic_img_base64}"}}
-    ]
-    image_data.extend([
-        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{graph['image']}"}}
-        for graph in graphs
-    ])
-    ai_introduction = SystemMessage(
-        content=[
+@app.post("/explain", response_model=None)
+async def explain(request: ExplainationRequest):
+    try:
+        with open("./tep_flowsheet.png", "rb") as image_file:
+            schematic_img_base64 = base64.b64encode(image_file.read()).decode("utf-8")
+        graphs = plot_graphs_to_base64(request)
+        causal_graph = plot_causal_subgraph(request)
+        schema_image = {
+            "type": "image_url",
+            "image_url": {"url": f"data:image/png;base64,{schematic_img_base64}"},
+        }
+        emessages = [
+            {"role": "system", "content": SYSTEM_MESSAGE},
             {
-                "type": "text",
-                "text": "I am a helpful AI chatbot trained to assist with monitoring the Tennessee Eastman process. My purpose is to help you identify and understand potential explanations for any faults that occur during the process."
-            }
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": EXPLAIN_PROMPT},
+                    schema_image,
+                ]
+                + [
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/png;base64,{graph['image']}"},
+                    }
+                    for graph in graphs
+                ]
+            },
         ]
-    )
-    user_prompt = HumanMessage(
-        content=[
-            {
-                "type": "text",
-                "text": "You are provided with the general schematics of the Tennessee Eastman reactor and graphs showing the values of the top contributing features for a recent fault. Based on this information, please generate an explanation as to why this fault occurred and how it is propagating."
-            }
-        ] + image_data  # type: ignore
-    )
-    response = chain.invoke([ai_introduction, user_prompt])
-    explanation = response.content if response else "No response received."
-    # Convert the fault information into a readable message
-    fault_message = f"Fault detected at {data_point.iloc[0]['timestamp']-datetime.timedelta(minutes=fault_detector.post_fault_threshold*3)}\n {explanation}"
-    print("Fault explanation sent")
-    # Send this message to the frontend via WebSocket
-    socketio.emit('chat_reply', {'images': graphs, 'message': fault_message})
-
-    # import pdb; pdb.set_trace()
-    target_idx = fault_detector.data_buffer[fault_detector.data_buffer['timestamp'] == data_point.iloc[0]['timestamp']].index[0]
-    start_idx = max(0, target_idx-50)
-    data_df = fault_detector.data_buffer.loc[start_idx:target_idx, top_features+["timestamp"]]
-    df_melted = data_df.melt(id_vars='timestamp', var_name='key', value_name='value')
-    result = df_melted.groupby('key').apply(lambda x: x[['timestamp', 'value']].to_dict('records')).to_json()
-    fault_info = {
-        'start_time': int((data_point.iloc[0]['timestamp']-datetime.timedelta(minutes=fault_detector.post_fault_threshold*3)).timestamp()*1000),
-        'explanation': explanation,
-        'top_features': result
-    }
-    fault_history.append(fault_info)
-
-fault_detector.register_fault_callback(handle_fault_detection)
-
-
-# LLM Integration
-
-import os
-
-k = os.environ["OPENAI_API_KEY"]
-chain = ChatOpenAI(model='gpt-4-vision-preview', api_key=k, max_tokens=2048)
-prompt = PromptTemplate.from_template(
-"""
-{text}
-"""
-)
-runnable = prompt | chain | StrOutputParser()
-
-app = Flask(__name__)
-app.config['SECRET_KEY'] = 'secret!'
-CORS(app)  # Enable CORS for all routes
-socketio = SocketIO(app, cors_allowed_origins="*")
-
-import configparser
-config = configparser.ConfigParser()
-config.read('config.ini')
-
-redis_host = config['Redis']['host']
-redis_port = int(config['Redis']['port'])  # Port should be an integer
-redis_db = int(config['Redis']['db'])  # DB should be an integer
-
-
-redis_client = redis.Redis(host=redis_host, port=redis_port, db=redis_db)
-
-simulator = DataFrameSimulator()
-simulator.start()
-
-def data_listener():
-    TIME_DELTA = datetime.timedelta(minutes=3)
-    PREV_TIME = datetime.datetime.now() - TIME_DELTA
-    while True:
-        # Block until data is available
-        _, data = redis_client.blpop('simulator_data')  # type: ignore
-        data_point = json.loads(data.decode('utf-8'))
-        df_data_point = pd.DataFrame([data_point]).iloc[:, 1:]
-        df_data_point['timestamp'] = (PREV_TIME + TIME_DELTA)
-        PREV_TIME = (PREV_TIME + TIME_DELTA)
-
-        socketio.emit('data_update', {'data': df_data_point.iloc[0].to_json()})
-        # print(df_data_point.iloc[0].to_json())
-        # print(df_data_point.iloc[0]['timestamp'].timestamp())
-        t2_stat, anomaly = fault_detector.process_data_point(df_data_point)
-        socketio.emit('t2_update', {'t2_stat': t2_stat, 'anomaly': anomaly.item(), 'timestamp': int(df_data_point.iloc[0]['timestamp'].timestamp()*1000)})
-        # handle through callback
-        # if anomaly:
-        #     handle_fault_detection(df_data_point)
-
-
-
-# Start the data listener in a separate thread
-listener_thread = threading.Thread(target=data_listener, daemon=True)
-listener_thread.start()
-
-@app.route('/')
-def hello_world():
-    return "Hello, World!"
-
-@app.route('/set_rate', methods=['POST'])
-def set_rate():
-    data = request.get_json()
-    new_rate = float(data.get('rate', 1))  # Default to 1 if not specified
-    print(new_rate)
-    simulator.change_rate(new_rate)
-    return jsonify({'message': f'Rate updated to {new_rate}'})
-
-@app.route('/change_state', methods=['POST'])
-def change_state():
-    data = request.get_json()
-    new_state = int(data.get('state', 0))  # Default to 0 if not specified
-    simulator.induce_fault(new_state)
-    return jsonify({'message': f'State updated to {new_state}'})
-
-@app.route('/pause', methods=['POST'])
-def pause_sim():
-    data = request.get_json()
-    simulator.pause()
-    print(simulator.paused)
-    return jsonify({'message': f'Simulator status updated to Paused'})
-
-@app.route('/resume', methods=['POST'])
-def resume_sim():
-    data = request.get_json()
-    simulator.resume()
-    print(simulator.paused)
-    return jsonify({'message': f'Simulator status updated to Running'})
-
-
-@app.route('/get_state', methods=['GET'])
-def get_state():
-    return jsonify({'state': simulator.fault_id})
-
-@app.route('/get_rate', methods=['GET'])
-def get_rate():
-    return jsonify({'rate': float(simulator.rate)})
-
-@app.route('/get_pause_status', methods=['GET'])
-def get_pause_status():
-    return jsonify({'status': simulator.paused})
-
-@app.route('/fault_history')
-def get_fault_history():
-    return jsonify(fault_history)
-
-# @socketio.on('change_state')
-# def handle_change_state(message):
-#     print(f"Received state change request: {message}")
-#     fault_id = int(message.split()[-1]) if message.split()[-1].isdigit() else 0
-#     simulator.induce_fault(fault_id)
-
-@socketio.on('chat_message')
-def handle_chat_message(message):
-    print(f"Received chat message: {message}")
-    # Process the message and generate a reply
-    # reply = f"SERVER Reply to: {message}"  # Example reply
-    reply = runnable.invoke({"text": message})
-    socketio.emit('chat_reply', reply)  # Send reply back to the frontend
-
-
-if __name__ == '__main__':
-    socketio.run(app, host="0.0.0.0", port=5001)  # type: ignore
+        return StreamingResponse(
+            ChatModelCompletion(
+                messages=emessages,
+                msg_id=request.id,
+                images=graphs,
+            ),
+            media_type="text/event-stream",
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/send_message", response_model=MessageResponse)
+async def send_message(request: MessageRequest):
+    try:
+        return StreamingResponse(
+            ChatModelCompletion(messages=request.data, msg_id=f"reply-{request.id}"),
+            media_type="text/event-stream",
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
diff --git a/backend/cg.adjlist b/backend/cg.adjlist
new file mode 100644
index 0000000..4ce8644
--- /dev/null
+++ b/backend/cg.adjlist
@@ -0,0 +1,33 @@
+x1 x25
+x2 x23
+x3 x24
+x4 x26
+x5 x27
+x6 x9
+x7 x22
+x8 x9
+x9 x7
+x10 x28
+x11 x13 x14 x27
+x12 x11
+x13 x28
+x14 x29
+x15 x16 x17 x18
+x16
+x17 x30
+x18
+x19 x31
+x20
+x21 x32
+x22 x11 x33
+x23 x6
+x24 x6
+x25 x6
+x26 x15
+x27 x20
+x28
+x29 x15
+x30 x19
+x31
+x32 x9
+x33
\ No newline at end of file
diff --git a/backend/cg.png b/backend/cg.png
new file mode 100644
index 0000000..e0112fc
Binary files /dev/null and b/backend/cg.png differ
diff --git a/backend/config.ini b/backend/config.ini
deleted file mode 100644
index 994c820..0000000
--- a/backend/config.ini
+++ /dev/null
@@ -1,4 +0,0 @@
-[Redis]
-host = localhost
-port = 6379
-db = 0
diff --git a/backend/csvSimulator.py b/backend/csvSimulator.py
deleted file mode 100644
index 9cba70c..0000000
--- a/backend/csvSimulator.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import pandas as pd
-import threading
-import redis
-from simulator import BaseSimulator
-
-class DataFrameSimulator(BaseSimulator):
-    initial_file = './data/fault0.csv'
-    def __init__(self, redis_host='localhost', redis_port=6379, redis_db=0):
-        super().__init__()
-        self.data_frame = pd.read_csv(self.initial_file)
-        self.iterator = self.data_frame.iterrows()
-        self.redis_client = redis.Redis(host=redis_host, port=redis_port, db=redis_db)
-        self.fault_id = 0
-
-    def step(self):
-        try:
-            # Get the next row from the iterator
-            _, row = next(self.iterator)
-            # print(f"Thread ID: {threading.get_ident()}, Data: {row['time']}")
-        except StopIteration:
-            # Restart the iterator if we reach the end of the DataFrame
-            self.iterator = self.data_frame.iterrows()
-            _, row = next(self.iterator)
-            # print(f"Thread ID: {threading.get_ident()}, Data: {row['time']}")
-
-        # Convert the row to a string or a suitable format
-        data_str = row.to_json()
-        # Publish the data to Redis
-        self.redis_client.rpush('simulator_data', data_str)
-
-    def induce_fault(self, fault_id):
-        self.fault_id = fault_id
-        # Change the DataFrame based on the fault_id
-        print(f"Changing dataframe to fault{fault_id}.csv")
-        new_file = f"./data/fault{fault_id}.csv"
-        self.data_frame = pd.read_csv(new_file)
-        self.iterator = self.data_frame.iterrows()
-
-
-if __name__ == "__main__":
-    import time
-
-    # Create an instance of the Simulator
-    simulator = DataFrameSimulator()
-    print(f"Thread ID: {threading.get_ident()}> Starting Simulator")
-
-    # Start the simulator
-    simulator.start()
-    time.sleep(5)  # Let it run for a while
-
-    print(f"Thread ID: {threading.get_ident()}> Introducing Fault")
-    # Induce a fault by setting the counter to 3
-    simulator.induce_fault(3)
-
-    time.sleep(5)  # Let it run for a while after inducing the fault
-
-    # Pause, resume, and stop to demonstrate control
-    print(f"Thread ID: {threading.get_ident()}> Pausing Simulator")
-    simulator.pause()
-    time.sleep(2)  # Paused for 2 seconds
-    print(f"Thread ID: {threading.get_ident()}> Resuming Simulator")
-    simulator.resume()
-
-    time.sleep(5)  # Let it run for a while more
-    print(f"Thread ID: {threading.get_ident()}> Stopping Simulator")
-    simulator.stop()
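The `cg.adjlist` added earlier in this diff encodes the causal graph that `app.py` loads with `nx.read_adjlist` and trims with `get_subgraph` (a node of interest plus its direct predecessors and successors). A small sketch of that flow, runnable from `backend/`; the script name is hypothetical, and the node labels follow the `cols`/`mapping` tables in `app.py`:

```python
# subgraph_demo.py -- illustrative sketch, not part of the repo
import networkx as nx

G = nx.read_adjlist("cg.adjlist", create_using=nx.DiGraph)

# Same neighborhood rule as get_subgraph() in app.py:
# keep the nodes of interest plus their direct parents and children.
nodes = {"x9"}  # x9 maps to "Reactor Temperature" in app.py
keep = set(nodes)
for n in nodes:
    keep.update(G.predecessors(n))
    keep.update(G.successors(n))

sub = G.subgraph(keep)
print(sorted(sub.nodes))  # ['x32', 'x6', 'x7', 'x8', 'x9']
print(sorted(sub.edges))  # edges into x9 from x6, x8, x32; x9 -> x7
```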
diff --git a/backend/data/data_processing.ipynb b/backend/data/data_processing.ipynb
deleted file mode 100644
index 7ed320f..0000000
--- a/backend/data/data_processing.ipynb
+++ /dev/null
@@ -1,477 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import pyreadr"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 37,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "data = pyreadr.read_r(\"./TEP_Faulty_Training.RData\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 38,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "df = data['faulty_training']"
-   ]
-  },
[Remainder of the deleted notebook omitted: the next cell (execution_count 39) displayed `df`, and its output was an HTML preview of the `faulty_training` DataFrame, 5000000 rows x 55 columns (faultNumber, simulationRun, sample, xmeas_1..., xmv_1...), showing sample sensor readings for faults 1 and 20.]
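Finally, note the convention baked into `plot_graphs_to_base64` above: the "Fault Start" line is drawn 20 samples from the end of each series, so `/explain` payloads are expected to end with the 20 post-fault samples. A hedged sketch of shaping such a payload from one of the fault CSVs; it assumes the `backend/data/fault*.csv` files referenced by the removed simulator are still present and carry one column per feature name:

```python
# make_payload.py -- illustrative sketch, not part of the repo
import pandas as pd

df = pd.read_csv("backend/data/fault1.csv")  # hypothetical path and file name

features = ["Reactor Pressure", "Reactor Temperature"]  # must match the `cols` names in app.py
window = df[features].tail(50)  # e.g. 30 pre-fault samples followed by 20 post-fault samples

payload = {
    "data": {name: window[name].tolist() for name in features},
    "id": "fault-1-demo",
}
# POST `payload` to /explain as in the client sketch after the app.py diff above.
```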