From 11f90cf055914189b1f2909f205a0fd4ec13f922 Mon Sep 17 00:00:00 2001
From: Haiyuan Cao
Date: Thu, 29 Jan 2026 15:49:23 -0800
Subject: [PATCH 1/2] Refactor BigQuery write error handling and add timeout

Refactor error handling in BigQuery write operations and add a
30-second timeout around the new perform_write coroutine so that hung
streams are retried like other transient errors.
---
 .../bigquery_agent_analytics_plugin.py        | 88 ++++++++++++-------
 1 file changed, 54 insertions(+), 34 deletions(-)

diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
index 84fa66eb66..e5da176dd2 100644
--- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
+++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
@@ -867,41 +867,52 @@ async def _write_rows_with_retry(self, rows: list[dict[str, Any]]) -> None:
       async def requests_iter():
         yield req
 
-      responses = await self.write_client.append_rows(requests_iter())
-      async for response in responses:
-        error = getattr(response, "error", None)
-        error_code = getattr(error, "code", None)
-        if error_code and error_code != 0:
-          error_message = getattr(error, "message", "Unknown error")
-          logger.warning(
-              "BigQuery Write API returned error code %s: %s",
-              error_code,
-              error_message,
-          )
-          if error_code in [
-              _GRPC_DEADLINE_EXCEEDED,
-              _GRPC_INTERNAL,
-              _GRPC_UNAVAILABLE,
-          ]:  # Deadline, Internal, Unavailable
-            raise ServiceUnavailable(error_message)
-          else:
-            if "schema mismatch" in error_message.lower():
-              logger.error(
-                  "BigQuery Schema Mismatch: %s. This usually means the"
-                  " table schema does not match the expected schema.",
-                  error_message,
-              )
-            else:
-              logger.error("Non-retryable BigQuery error: %s", error_message)
-            row_errors = getattr(response, "row_errors", [])
-            if row_errors:
-              for row_error in row_errors:
-                logger.error("Row error details: %s", row_error)
-              logger.error("Row content causing error: %s", rows)
-            return
+      async def perform_write():
+        responses = await self.write_client.append_rows(requests_iter())
+        async for response in responses:
+          error = getattr(response, "error", None)
+          error_code = getattr(error, "code", None)
+          if error_code and error_code != 0:
+            error_message = getattr(error, "message", "Unknown error")
+            logger.warning(
+                "BigQuery Write API returned error code %s: %s",
+                error_code,
+                error_message,
+            )
+            if error_code in [
+                _GRPC_DEADLINE_EXCEEDED,
+                _GRPC_INTERNAL,
+                _GRPC_UNAVAILABLE,
+            ]:  # Deadline, Internal, Unavailable
+              raise ServiceUnavailable(error_message)
+            else:
+              if "schema mismatch" in error_message.lower():
+                logger.error(
+                    "BigQuery Schema Mismatch: %s. This usually means the"
+                    " table schema does not match the expected schema.",
+                    error_message,
+                )
+              else:
+                logger.error(
+                    "Non-retryable BigQuery error: %s", error_message
+                )
+              row_errors = getattr(response, "row_errors", [])
+              if row_errors:
+                for row_error in row_errors:
+                  logger.error("Row error details: %s", row_error)
+                logger.error("Row content causing error: %s", rows)
+              return
+        return
+
+      await asyncio.wait_for(perform_write(), timeout=30.0)
       return
-    except (ServiceUnavailable, TooManyRequests, InternalServerError) as e:
+    except (
+        ServiceUnavailable,
+        TooManyRequests,
+        InternalServerError,
+        asyncio.TimeoutError,
+    ) as e:
       attempt += 1
       if attempt > self.retry_config.max_retries:
         logger.error(
@@ -1625,8 +1636,17 @@ def get_credentials():
   @staticmethod
   def _atexit_cleanup(batch_processor: "BatchProcessor") -> None:
     """Clean up batch processor on script exit."""
-    # Check if the batch_processor object is still alive
-    if batch_processor and not batch_processor._shutdown:
+    try:
+      # Bail out if the weakref proxy is dead or shutdown already ran.
+      if not batch_processor or batch_processor._shutdown:
+        return
+    except ReferenceError:
+      # The proxied BatchProcessor was already garbage-collected.
+      return
+
+    # `if True:` preserves the indentation of the flush logic below.
+    if True:
+
       # Emergency Flush: Rescue any logs remaining in the queue
       remaining_items = []
       try:
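The core pattern of this first patch is to bound the streaming append with
asyncio.wait_for and treat expiry as just another retryable failure. Below is
a minimal, self-contained sketch of that pattern; append_rows_once,
WRITE_TIMEOUT_S, and MAX_RETRIES are illustrative stand-ins, not names from
the plugin:

import asyncio
import random

WRITE_TIMEOUT_S = 0.5  # the plugin hard-codes 30.0 seconds
MAX_RETRIES = 3        # the plugin reads this from retry_config.max_retries


async def append_rows_once() -> None:
  """Stand-in for the gRPC streaming append; sometimes hangs."""
  await asyncio.sleep(random.choice([0.01, 5.0]))


async def write_with_retry() -> None:
  attempt = 0
  while True:
    try:
      # Bound the whole streaming call so a hung stream cannot block
      # the batch processor forever.
      await asyncio.wait_for(append_rows_once(), timeout=WRITE_TIMEOUT_S)
      return
    except asyncio.TimeoutError:
      attempt += 1
      if attempt > MAX_RETRIES:
        raise
      # Back off before retrying, mirroring the transient-error path.
      await asyncio.sleep(0.1 * 2**attempt)


asyncio.run(write_with_retry())

Note that asyncio.wait_for cancels the timed-out coroutine, so each retry
opens a fresh request stream rather than resuming the hung one.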
From 5f55c59300cf104ab7c1691b77b4b9d0e75ec2ed Mon Sep 17 00:00:00 2001
From: Haiyuan Cao
Date: Fri, 30 Jan 2026 12:56:26 -0800
Subject: [PATCH 2/2] Refactor BigQuery client for multi-loop support

Refactor BigQuery client handling to support multiple event loops.
Introduce _LoopState to hold the write client and batch processor for
each loop, since grpc.aio clients are bound to the loop they were
created in, and tear down every loop's resources on shutdown.
---
 .../bigquery_agent_analytics_plugin.py        | 200 +++++++++++-------
 1 file changed, 126 insertions(+), 74 deletions(-)

diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
index e5da176dd2..2262648c65 100644
--- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
+++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
@@ -31,12 +31,17 @@
 import time
 from types import MappingProxyType
 from typing import Any
+from typing import Awaitable
 from typing import Callable
+from typing import Dict
 from typing import Optional
 from typing import TYPE_CHECKING
+from typing import TypeVar
 import uuid
 import weakref
 
+T = TypeVar("T")
+
 from google.api_core import client_options
 from google.api_core.exceptions import InternalServerError
 from google.api_core.exceptions import ServiceUnavailable
@@ -1469,6 +1474,14 @@ def _get_events_schema() -> list[bigquery.SchemaField]:
 _GLOBAL_CLIENT_LOCK = asyncio.Lock()
 
 
+@dataclass
+class _LoopState:
+  """Holds resources bound to a specific event loop."""
+
+  write_client: BigQueryWriteAsyncClient
+  batch_processor: BatchProcessor
+
+
 class BigQueryAgentAnalyticsPlugin(BasePlugin):
   """BigQuery Agent Analytics Plugin (v2.0 using Write API).
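The dataclass above exists because grpc.aio clients are generally loop-bound:
a client constructed while one event loop is running cannot safely be awaited
from another. A minimal sketch of the per-loop keying scheme, with an
illustrative _FakeState standing in for _LoopState:

import asyncio
from dataclasses import dataclass


@dataclass
class _FakeState:
  """Illustrative stand-in for _LoopState."""

  created_order: int


_states: dict[asyncio.AbstractEventLoop, _FakeState] = {}
_counter = 0


async def get_state() -> _FakeState:
  global _counter
  # Key resources by the loop running this coroutine, so a client
  # created under one event loop is never awaited under another.
  loop = asyncio.get_running_loop()
  if loop not in _states:
    _counter += 1
    _states[loop] = _FakeState(created_order=_counter)
  return _states[loop]


async def main() -> None:
  print((await get_state()).created_order)


# Two asyncio.run() calls run on two different loops, so this prints
# 1 and then 2: each loop gets its own state.
asyncio.run(main())
asyncio.run(main())

One caveat worth flagging in review: a plain dict keyed by loop keeps entries
alive after their loop closes. shutdown() below clears _loop_states, but a
weakref.WeakKeyDictionary may be worth considering as well.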
@@ -1501,16 +1514,17 @@ def __init__(
     self.table_id = table_id or self.config.table_id
     self.location = location
 
     self._started = False
     self._is_shutting_down = False
     self._setup_lock = None
     self.client = None
-    self.write_client = None
-    self.write_stream = None
-    self.batch_processor = None
+    self._loop_states: dict[asyncio.AbstractEventLoop, _LoopState] = {}
+    self._write_stream_name = None  # Resolved write stream name
     self._executor = None
     self.offloader: Optional[GCSOffloader] = None
     self.parser: Optional[HybridContentParser] = None
+    self._schema = None
+    self.arrow_schema = None
 
   def _format_content_safely(
       self, content: Optional[types.Content]
@@ -1534,10 +1550,87 @@ def _format_content_safely(
       logger.warning("Content formatter failed: %s", e)
       return "[FORMATTING FAILED]", False
 
+  async def _get_loop_state(self) -> _LoopState:
+    """Gets or creates the state for the current event loop."""
+    loop = asyncio.get_running_loop()
+    if loop in self._loop_states:
+      return self._loop_states[loop]
+
+    # Initialize new resources for this loop. The previous implementation
+    # shared a single _GLOBAL_WRITE_CLIENT, but grpc.aio clients are bound
+    # to the event loop they were created in and cannot be shared across
+    # loops. Creating one client per loop avoids that hazard.
+
+    def get_credentials():
+      creds, project_id = google.auth.default(
+          scopes=["https://www.googleapis.com/auth/cloud-platform"]
+      )
+      return creds, project_id
+
+    creds, project_id = await loop.run_in_executor(
+        self._executor, get_credentials
+    )
+    quota_project_id = getattr(creds, "quota_project_id", None) or project_id
+    options = (
+        client_options.ClientOptions(quota_project_id=quota_project_id)
+        if quota_project_id
+        else None
+    )
+    client_info = gapic_client_info.ClientInfo(
+        user_agent=f"google-adk-bq-logger/{__version__}"
+    )
+
+    write_client = BigQueryWriteAsyncClient(
+        credentials=creds,
+        client_info=client_info,
+        client_options=options,
+    )
+
+    if not self._write_stream_name:
+      # _lazy_setup normally resolves the stream name; project, dataset and
+      # table ids are guaranteed to be set by the time we get here.
+      self._write_stream_name = f"projects/{self.project_id}/datasets/{self.dataset_id}/tables/{self.table_id}/_default"
+
+    batch_processor = BatchProcessor(
+        write_client=write_client,
+        arrow_schema=self.arrow_schema,
+        write_stream=self._write_stream_name,
+        batch_size=self.config.batch_size,
+        flush_interval=self.config.batch_flush_interval,
+        retry_config=self.config.retry_config,
+        queue_max_size=self.config.queue_max_size,
+        shutdown_timeout=self.config.shutdown_timeout,
+    )
+    await batch_processor.start()
+
+    state = _LoopState(write_client, batch_processor)
+    self._loop_states[loop] = state
+
+    # Register cleanup so queued logs are flushed even if the user never
+    # calls shutdown(); weakref avoids keeping the processor alive.
+    atexit.register(self._atexit_cleanup, weakref.proxy(batch_processor))
+
+    return state
+
   async def flush(self) -> None:
-    """Flushes any pending events to BigQuery."""
-    if self.batch_processor:
-      await self.batch_processor.flush()
+    """Flushes any pending events to BigQuery.
+
+    Flushes the processor associated with the CURRENT loop.
+    """
+    try:
+      loop = asyncio.get_running_loop()
+      if loop in self._loop_states:
+        await self._loop_states[loop].batch_processor.flush()
+    except RuntimeError:
+      # No running event loop; nothing to flush.
+      pass
 
   async def _lazy_setup(self, **kwargs) -> None:
     """Performs lazy initialization of BigQuery clients and resources."""
       )
       self.full_table_id = f"{self.project_id}.{self.dataset_id}.{self.table_id}"
-    self._schema = _get_events_schema()
-    await loop.run_in_executor(self._executor, self._ensure_schema_exists)
-
-    if not self.write_client:
-      global _GLOBAL_WRITE_CLIENT
-      async with _GLOBAL_CLIENT_LOCK:
-        if _GLOBAL_WRITE_CLIENT is None:
-
-          def get_credentials():
-            creds, project_id = google.auth.default(
-                scopes=["https://www.googleapis.com/auth/cloud-platform"]
-            )
-            return creds, project_id
-
-          creds, project_id = await loop.run_in_executor(
-              self._executor, get_credentials
-          )
-          quota_project_id = (
-              getattr(creds, "quota_project_id", None) or project_id
-          )
-          options = (
-              client_options.ClientOptions(quota_project_id=quota_project_id)
-              if quota_project_id
-              else None
-          )
-          client_info = gapic_client_info.ClientInfo(
-              user_agent=f"google-adk-bq-logger/{__version__}"
-          )
-          # Initialize the async client in the current event loop, not in the
-          # executor.
-          _GLOBAL_WRITE_CLIENT = BigQueryWriteAsyncClient(
-              credentials=creds,
-              client_info=client_info,
-              client_options=options,
-          )
-      self.write_client = _GLOBAL_WRITE_CLIENT
-
-    self.write_stream = f"projects/{self.project_id}/datasets/{self.dataset_id}/tables/{self.table_id}/_default"
+    if not self._schema:
+      self._schema = _get_events_schema()
+      await loop.run_in_executor(self._executor, self._ensure_schema_exists)
 
-    if not self.batch_processor:
+    if not self.parser:
       self.arrow_schema = to_arrow_schema(self._schema)
       if not self.arrow_schema:
         raise RuntimeError("Failed to convert BigQuery schema to Arrow schema.")
           max_length=self.config.max_content_length,
           connection_id=self.config.connection_id,
       )
-      self.batch_processor = BatchProcessor(
-          write_client=self.write_client,
-          arrow_schema=self.arrow_schema,
-          write_stream=self.write_stream,
-          batch_size=self.config.batch_size,
-          flush_interval=self.config.batch_flush_interval,
-          retry_config=self.config.retry_config,
-          queue_max_size=self.config.queue_max_size,
-          shutdown_timeout=self.config.shutdown_timeout,
-      )
-      await self.batch_processor.start()
 
-      # Register cleanup to ensure logs are flushed if user forgets to close
-      # Use weakref to avoid circular references that prevent garbage collection
-      atexit.register(self._atexit_cleanup, weakref.proxy(self.batch_processor))
+    # Initialize state for this loop.
+    await self._get_loop_state()
 
   @staticmethod
   def _atexit_cleanup(batch_processor: "BatchProcessor") -> None:
     """Clean up batch processor on script exit."""
@@ -1734,22 +1780,29 @@ async def shutdown(self, timeout: float | None = None) -> None:
     t = timeout if timeout is not None else self.config.shutdown_timeout
     loop = asyncio.get_running_loop()
     try:
-      if self.batch_processor:
-        await self.batch_processor.shutdown(timeout=t)
-      if self.write_client and getattr(self.write_client, "transport", None):
-        # Only close the client if it's NOT the global one (unlikely with new logic,
-        # but good for safety if injected manually) or if we decide to handle global close differently.
-        # For now, we DO NOT close the global client to allow reuse.
- if self.write_client is not _GLOBAL_WRITE_CLIENT: - await self.write_client.transport.close() + # Correct Multi-Loop Shutdown: + # 1. Shutdown current loop's processor directly. + if loop in self._loop_states: + await self._loop_states[loop].batch_processor.shutdown(timeout=t) + + # 2. Close clients for all states + for state in self._loop_states.values(): + if state.write_client and getattr( + state.write_client, "transport", None + ): + try: + await state.write_client.transport.close() + except Exception: + pass + + self._loop_states.clear() + if self.client: if self._executor: executor = self._executor await loop.run_in_executor(None, lambda: executor.shutdown(wait=True)) self._executor = None - self.write_client = None self.client = None - self._is_shutting_down = False except Exception as e: logger.error("Error during shutdown: %s", e, exc_info=True) self._is_shutting_down = False @@ -1760,9 +1813,8 @@ def __getstate__(self): state = self.__dict__.copy() state["_setup_lock"] = None state["client"] = None - state["write_client"] = None - state["write_stream"] = None - state["batch_processor"] = None + state["_loop_states"] = {} + state["_write_stream_name"] = None state["_executor"] = None state["offloader"] = None state["parser"] = None @@ -1931,8 +1983,8 @@ async def _log_event( "is_truncated": is_truncated, } - if self.batch_processor: - await self.batch_processor.append(row) + state = await self._get_loop_state() + await state.batch_processor.append(row) # --- UPDATED CALLBACKS FOR V1 PARITY ---
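The __getstate__ change above follows the usual recipe for pickling an object
that owns loop-bound resources: drop anything tied to an event loop and let
lazy setup recreate it after unpickling. A minimal sketch under that
assumption; LoopBoundLogger is an illustrative name, not the plugin's class:

import asyncio
import pickle


class LoopBoundLogger:
  """Sketch of dropping loop-bound members before pickling."""

  def __init__(self) -> None:
    self._loop_states: dict[asyncio.AbstractEventLoop, object] = {}
    self._setup_lock = None  # created lazily as an asyncio.Lock

  def __getstate__(self):
    # Locks, gRPC channels, and queues are bound to one event loop and
    # cannot cross pickle boundaries; drop them and rebuild lazily.
    state = self.__dict__.copy()
    state["_loop_states"] = {}
    state["_setup_lock"] = None
    return state


clone = pickle.loads(pickle.dumps(LoopBoundLogger()))
assert clone._loop_states == {}

The same reasoning explains why _log_event calls _get_loop_state() on every
append: after unpickling, or under a new event loop, the per-loop state is
simply rebuilt on first use.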