diff --git a/api/db/services/connector_service.py b/api/db/services/connector_service.py
index db8c713e23b..1a724042e3b 100644
--- a/api/db/services/connector_service.py
+++ b/api/db/services/connector_service.py
@@ -194,7 +194,17 @@ def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_co
                 update_date=timestamp_to_date(current_timestamp())
             )\
             .where(cls.model.id == id).execute()
-
+
+    @classmethod
+    def increase_deleted_docs(cls, id, doc_num, err_msg="", error_count=0):
+        cls.model.update(
+            docs_removed_from_index=cls.model.docs_removed_from_index + doc_num,
+            total_docs_indexed=cls.model.total_docs_indexed - doc_num,
+            error_msg=cls.model.error_msg + err_msg,
+            error_count=cls.model.error_count + error_count,
+            update_time=current_timestamp(),
+            update_date=timestamp_to_date(current_timestamp())
+        ).where(cls.model.id == id).execute()

     @classmethod
     def duplicate_and_parse(cls, kb, docs, tenant_id, src, auto_parse=True):
         from api.db.services.file_service import FileService
diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index 514b3fd870d..312b0233de9 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -750,6 +750,22 @@ def get_all_kb_doc_count(cls):
         for row in rows:
             result[row.kb_id] = row.count
         return result
+
+    @classmethod
+    @DB.connection_context()
+    def list_document_ids_by_src(cls, tenant_id, kb, src):
+        fields = [cls.model.id]
+        docs = cls.model.select(*fields)\
+            .where(
+                (cls.model.kb_id == kb),
+                (cls.model.source_type == src)
+            )
+
+        res = []
+        for doc in docs:
+            res.append(doc.id)
+
+        return res

     @classmethod
     @DB.connection_context()
@@ -1028,3 +1044,4 @@ def embedding(doc_id, cnts, batch_size=16):
                               doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

     return [d["id"] for d, _ in files]
+
diff --git a/common/settings.py b/common/settings.py
index 9df0c0cd2d0..3921f4aa996 100644
--- a/common/settings.py
+++ b/common/settings.py
@@ -83,6 +83,8 @@

 # user registration switch
 REGISTER_ENABLED = 1

+ENABLE_SYNC_DELETED_CHANGE = os.getenv('ENABLE_SYNC_DELETED_CHANGE', 'false').lower() in ('1', 'true', 'yes')
+
 # sandbox-executor-manager
 SANDBOX_HOST = None
diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py
index b29ad15de53..4ea0b4fe5d1 100644
--- a/rag/svr/sync_data_source.py
+++ b/rag/svr/sync_data_source.py
@@ -35,6 +35,8 @@
 from api.db.services.connector_service import ConnectorService, SyncLogsService
 from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.document_service import DocumentService
+from api.db.services.file_service import FileService
 from common import settings
 from common.config_utils import show_configs
 from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector, MoodleConnector, JiraConnector
@@ -61,6 +63,13 @@ async def __call__(self, task: dict):
         SyncLogsService.start(task["id"], task["connector_id"])
         try:
             async with task_limiter:
+                synced_doc_ids = set()  # document ids seen at the source during this sync task
+                source_type = f"{self.SOURCE_NAME}/{task['connector_id']}"
+                existing_doc_ids = []
+                with trio.fail_after(task["timeout_secs"]):
+                    # documents already synced into the KB by previous runs of this connector
+                    existing_doc_ids = DocumentService.list_document_ids_by_src(task["tenant_id"], task["kb_id"], source_type)
+
                 with trio.fail_after(task["timeout_secs"]):
                     document_batch_generator = await self._generate(task)
                     doc_num = 0
@@ -92,6 +101,42 @@ async def __call__(self, task: dict):
                         SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
                         doc_num += len(docs)

+                    if settings.ENABLE_SYNC_DELETED_CHANGE:
+                        task_copy = dict(task)  # shallow copy so the original task keeps its poll window
+                        task_copy.pop("poll_range_start", None)  # enumerate everything at the source, not just recent changes
+                        document_batch_generator = await self._generate(task_copy)
+                        for document_batch in document_batch_generator:
+                            if not document_batch:
+                                continue
+                            docs = [
+                                {
+                                    "id": doc.id,
+                                    "connector_id": task["connector_id"],
+                                    "source": self.SOURCE_NAME,
+                                    "semantic_identifier": doc.semantic_identifier,
+                                    "extension": doc.extension,
+                                    "size_bytes": doc.size_bytes,
+                                    "doc_updated_at": doc.doc_updated_at,
+                                    "blob": doc.blob,
+                                }
+                                for doc in document_batch
+                            ]
+
+                            for doc in docs:
+                                synced_doc_ids.add(doc["id"])
+
+                        # delete documents that are still in the knowledge base but no longer exist at the source
+                        if existing_doc_ids:
+                            to_delete_ids = []
+                            for doc_id in existing_doc_ids:
+                                if doc_id not in synced_doc_ids:
+                                    to_delete_ids.append(doc_id)
+
+                            if to_delete_ids:
+                                FileService.delete_docs(to_delete_ids, task["tenant_id"])
+                                SyncLogsService.increase_deleted_docs(task["id"], len(to_delete_ids))
+                                logging.info(f"Deleted {len(to_delete_ids)} documents from knowledge base {task['kb_id']} for connector {task['connector_id']}")
+
                     prefix = "[Jira] " if self.SOURCE_NAME == FileSource.JIRA else ""
                     logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
                     SyncLogsService.done(task["id"], task["connector_id"])
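
Note: the deletion pass above is opt-in via the ENABLE_SYNC_DELETED_CHANGE environment variable and reduces to a set difference between the ids already stored for this connector and the ids seen when the source is re-enumerated without a poll window. A minimal sketch of that idea with hypothetical ids (the real code fetches existing_doc_ids via DocumentService.list_document_ids_by_src and removes orphans with FileService.delete_docs, as in the diff):

    # Hypothetical ids, for illustration only.
    existing_doc_ids = ["doc-a", "doc-b", "doc-c"]   # stored in the KB by previous syncs
    synced_doc_ids = {"doc-a", "doc-c"}              # still reported by the source in this run

    # Orphans are documents the KB still holds but the source no longer reports.
    to_delete_ids = [doc_id for doc_id in existing_doc_ids if doc_id not in synced_doc_ids]
    print(to_delete_ids)  # ['doc-b']

With the boolean parsing used in common/settings.py here, the feature is enabled by setting ENABLE_SYNC_DELETED_CHANGE=true (or 1/yes) in the environment of the sync service; it stays off by default.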