diff --git a/.gitignore b/.gitignore index 139597f..ab9464c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,8 @@ - - +/dist/ +/src/build/ +/src/dist/ +/src/setup.cfg +__pycache__/ +*.egg-info/ +.coverage +/bazel-* diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..be7798e --- /dev/null +++ b/.pylintrc @@ -0,0 +1,429 @@ +# This Pylint rcfile contains a best-effort configuration to uphold the +# best-practices and style described in the Google Python style guide: +# https://google.github.io/styleguide/pyguide.html +# +# Its canonical open-source location is: +# https://google.github.io/styleguide/pylintrc + +[MASTER] + +# Files or directories to be skipped. They should be base names, not paths. +ignore=third_party + +# Files or directories matching the regex patterns are skipped. The regex +# matches against base names, not paths. +ignore-patterns= + +# Pickle collected data for later comparisons. +persistent=no + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=4 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +#enable= + +# Disable the message, report, category or checker with the given id(s). 
You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=abstract-method, + apply-builtin, + arguments-differ, + attribute-defined-outside-init, + backtick, + bad-option-value, + basestring-builtin, + buffer-builtin, + c-extension-no-member, + consider-using-enumerate, + cmp-builtin, + cmp-method, + coerce-builtin, + coerce-method, + delslice-method, + div-method, + duplicate-code, + eq-without-hash, + execfile-builtin, + file-builtin, + filter-builtin-not-iterating, + fixme, + getslice-method, + global-statement, + hex-method, + idiv-method, + implicit-str-concat, + import-error, + import-self, + import-star-module-level, + inconsistent-return-statements, + input-builtin, + intern-builtin, + invalid-str-codec, + locally-disabled, + long-builtin, + long-suffix, + map-builtin-not-iterating, + misplaced-comparison-constant, + missing-function-docstring, + metaclass-assignment, + next-method-called, + next-method-defined, + no-absolute-import, + no-else-break, + no-else-continue, + no-else-raise, + no-else-return, + no-init, # added + no-member, + no-name-in-module, + no-self-use, + nonzero-method, + oct-method, + old-division, + old-ne-operator, + old-octal-literal, + old-raise-syntax, + parameter-unpacking, + print-statement, + raising-string, + range-builtin-not-iterating, + raw_input-builtin, + rdiv-method, + reduce-builtin, + relative-import, + reload-builtin, + round-builtin, + setslice-method, + signature-differs, + standarderror-builtin, + 
suppressed-message, + sys-max-int, + too-few-public-methods, + too-many-ancestors, + too-many-arguments, + too-many-boolean-expressions, + too-many-branches, + too-many-instance-attributes, + too-many-locals, + too-many-nested-blocks, + too-many-public-methods, + too-many-return-statements, + too-many-statements, + trailing-newlines, + unichr-builtin, + unicode-builtin, + unnecessary-pass, + unpacking-in-except, + useless-else-on-loop, + useless-object-inheritance, + useless-suppression, + using-cmp-argument, + wrong-import-order, + xrange-builtin, + zip-builtin-not-iterating, + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[BASIC] + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# List of decorators that produce properties, such as abc.abstractproperty. 
Add +# to this list to register other decorators that produce valid properties. +property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl + +# Regular expression matching correct function names +function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression matching correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct constant names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression matching correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression matching correct module names +module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ + +# Regular expression matching correct method names +method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=10 + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. 
Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=80 + +# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt +# lines made too long by directives to pytype. + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x)( + ^\s*(\#\ )??$| + ^\s*(from\s+\S+\s+)?import\s+.+$) + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=yes + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. The internal Google style guide mandates 2 +# spaces. Google's externaly-published style guide says 4, consistent with +# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google +# projects (like TensorFlow). 
+indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=TODO + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=yes + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging,absl.logging,tensorflow.io.logging + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. 
+spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub, + TERMIOS, + Bastion, + rexec, + sets + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant, absl + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls, + class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. 
Defaults to +# "Exception" +overgeneral-exceptions=StandardError, + Exception, + BaseException diff --git a/.style.yapf b/.style.yapf new file mode 100644 index 0000000..fdd0723 --- /dev/null +++ b/.style.yapf @@ -0,0 +1,2 @@ +[style] +based_on_style = yapf diff --git a/BUILD b/BUILD new file mode 100644 index 0000000..ae821f1 --- /dev/null +++ b/BUILD @@ -0,0 +1,2 @@ +package(default_visibility = ["//visibility:public"]) + diff --git a/LICENSE b/LICENSE index b5d5055..497d805 100644 --- a/LICENSE +++ b/LICENSE @@ -187,7 +187,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2015 Google Inc. + Copyright [yyyy] [name of copyright owner Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md index 7e1362f..9f171b4 100644 --- a/README.md +++ b/README.md @@ -1,47 +1,48 @@ -# Python Cloud Debugger +# Python Snapshot Debugger Agent -Google [Cloud Debugger](https://cloud.google.com/tools/cloud-debugger/) for -Python 2.7. +[Snapshot debugger](https://github.com/GoogleCloudPlatform/snapshot-debugger/) +agent for Python 3.6, Python 3.7, Python 3.8, Python 3.9, and Python 3.10. -## Overview - -The Cloud Debugger lets you inspect the state of an application at any code -location without stopping or slowing it down. The debugger makes it easier to -view the application state without adding logging statements. - -You can use the Cloud Debugger on both production and staging instances of your -application. The debugger never pauses the application for more than a few -milliseconds. In most cases, this is not noticeable by users. The Cloud Debugger -gives a read-only experience. Application variables can't be changed through the -debugger. - -The Cloud Debugger attaches to all instances of the application. The call stack -and the variables come from the first instance to take the snapshot. 
- -The Python Cloud Debugger is only supported on Linux at the moment. It was tested -on Debian Linux, but it should work on other distributions as well. -The Cloud Debugger consists of 3 primary components: +## Project Status: Archived -1. The debugger agent. This repo implements one for Python 2.7. -2. Cloud Debugger backend that stores the list of snapshots for each debuggee. - You can explore the API using the - [APIs Explorer](https://developers.google.com/apis-explorer/#p/clouddebugger/v2/). -3. User interface for the debugger implemented using the Cloud Debugger API. - Currently the only option for Python is the - [Google Developers Console](https://console.developers.google.com). The - UI requires that the source code is submitted to - [Google Cloud Repo](https://cloud.google.com/tools/repo/cloud-repositories/). - More options (including browsing local source files) are coming soon. +This project has been archived and is no longer supported. There will be no +further bug fixes or security patches. The repository can be forked by users +if they want to maintain it going forward. -This document only focuses on the Python debugger agent. Please see the -this [page](https://cloud.google.com/tools/cloud-debugger/debugging) for -explanation how to debug an application with the Cloud Debugger. -## Options for Getting Help +## Overview -1. StackOverflow: http://stackoverflow.com/questions/tagged/google-cloud-debugger -2. Google Group: cdbg-feedback@google.com +Snapshot Debugger lets you inspect the state +of a running cloud application, at any code location, without stopping or +slowing it down. It is not your traditional process debugger but rather an +always on, whole app debugger taking snapshots from any instance of the app. + +Snapshot Debugger is safe for use with production apps or during development. The +Python debugger agent only few milliseconds to the request latency when a debug +snapshot is captured. In most cases, this is not noticeable to users. 
+Furthermore, the Python debugger agent does not allow modification of +application state in any way, and has close to zero impact on the app instances. + +Snapshot Debugger attaches to all instances of the app providing the ability to +take debug snapshots and add logpoints. A snapshot captures the call-stack and +variables from any one instance that executes the snapshot location. A logpoint +writes a formatted message to the application log whenever any instance of the +app executes the logpoint location. + +The Python debugger agent is only supported on Linux at the moment. It was +tested on Debian Linux, but it should work on other distributions as well. + +Snapshot Debugger consists of 3 primary components: + +1. The Python debugger agent (this repo implements one for CPython 3.6, + 3.7, 3.8, 3.9, and 3.10). +2. A Firebase Realtime Database for storing and managing snapshots/logpoints. + Explore the + [schema](https://github.com/GoogleCloudPlatform/snapshot-debugger/blob/main/docs/SCHEMA.md). +3. 
User interface, including a command line interface + [`snapshot-dbg-cli`](https://pypi.org/project/snapshot-dbg-cli/) and a + [VSCode extension](https://github.com/GoogleCloudPlatform/snapshot-debugger/tree/main/snapshot_dbg_extension) ## Installation @@ -51,108 +52,236 @@ The easiest way to install the Python Cloud Debugger is with PyPI: pip install google-python-cloud-debugger ``` -Alternatively, download the *egg* package from -[Releases](https://github.com/GoogleCloudPlatform/cloud-debug-python/releases) -and install the debugger agent with: +You can also build the agent from source code: ```shell -easy_install google_python_cloud_debugger-py2.7-linux-x86_64.egg +git clone https://github.com/GoogleCloudPlatform/cloud-debug-python.git +cd cloud-debug-python/src/ +./build.sh +pip install dist/google_python_cloud_debugger-*.whl ``` -You can also build the agent from source code (OS dependencies are listed in -[build.sh](https://github.com/GoogleCloudPlatform/cloud-debug-python/blob/master/src/build.sh) -script): +Note that the build script assumes some dependencies. To install these +dependencies on Debian, run this command: ```shell -git clone https://github.com/GoogleCloudPlatform/cloud-debug-python.git -cd cloud-debug-python/src/ -./build.sh -easy_install dist/google_python_cloud_debugger-*.egg +sudo apt-get -y -q --no-install-recommends install \ + curl ca-certificates gcc build-essential cmake \ + python3 python3-dev python3-pip ``` +If the desired target version of Python is not the default version of +the 'python3' command on your system, run the build script as `PYTHON=python3.x +./build.sh`. + +### Alpine Linux + +The Python agent is not regularly tested on Alpine Linux, and support will be on +a best effort basis. The [Dockerfile](alpine/Dockerfile) shows how to build a +minimal image with the agent installed. + ## Setup -### Google Compute Engine +### Google Cloud Platform -1. First, make sure that you created the VM with this option enabled: +1. 
First, make sure that the VM has the + [required scopes](https://github.com/GoogleCloudPlatform/snapshot-debugger/blob/main/docs/configuration.md#access-scopes). - > Allow API access to all Google Cloud services in the same project. +2. Install the Python debugger agent as explained in the + [Installation](#installation) section. - This option lets the debugger agent authenticate with the machine account - of the Virtual Machine. +3. Enable the debugger in your application: - It is possible to use Python Cloud Debugger without it. Please see the - [next section](#Service_Account) for details. + ```python + # Attach Python Cloud Debugger + try: + import googleclouddebugger + googleclouddebugger.enable(module='[MODULE]', version='[VERSION]') + except ImportError: + pass + ``` + + Where: + + * `[MODULE]` is the name of your app. This, along with the version, is + used to identify the debug target in the UI.
+ Example values: `MyApp`, `Backend`, or `Frontend`. + + * `[VERSION]` is the app version (for example, the build ID). The UI + displays the running version as `[MODULE] - [VERSION]`.
+ Example values: `v1.0`, `build_147`, or `v20170714`. + +### Outside Google Cloud Platform + +To use the Python debugger agent on machines not hosted by Google Cloud +Platform, you must set up credentials to authenticate with Google Cloud APIs. By +default, the debugger agent tries to find the [Application Default +Credentials](https://cloud.google.com/docs/authentication/production) on the +system. This can either be from your personal account or a dedicated service +account. -1. Install the debugger agent as explained in the [Installation](#Installation) - section. +#### Personal Account -2. Enable the debugger in your application using one of the two options: +1. Set up Application Default Credentials through + [gcloud](https://cloud.google.com/sdk/gcloud/reference/auth/application-default/login). - _Option A_: add this code to the beginning of your `main()` function: + ```shell + gcloud auth application-default login + ``` + +2. Follow the rest of the steps in the [GCP](#google-cloud-platform) section. + +#### Service Account + +1. Use the Google Cloud Console Service Accounts + [page](https://console.cloud.google.com/iam-admin/serviceaccounts/project) + to create a credentials file for an existing or new service account. The + service account must have at least the `roles/firebasedatabase.admin` role. + +2. Once you have the service account credentials JSON file, deploy it alongside + the Python debugger agent. + +3. Set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable. + + ```shell + export GOOGLE_APPLICATION_CREDENTIALS=/path/to/credentials.json + ``` + + Alternatively, you can provide the path to the credentials file directly to + the debugger agent. ```python # Attach Python Cloud Debugger try: import googleclouddebugger - googleclouddebugger.AttachDebugger() + googleclouddebugger.enable( + module='[MODULE]', + version='[VERSION]', + service_account_json_file='/path/to/credentials.json') except ImportError: pass ``` +4. 
Follow the rest of the steps in the [GCP](#google-cloud-platform) section. - _Option B_: run the debugger agent as a module: +### Django Web Framework -
-    python -m googleclouddebugger -- myapp.py
-    
+You can use the Cloud Debugger to debug Django web framework applications. -### Service Account +The best way to enable the Cloud Debugger with Django is to add the following +code fragment to your `manage.py` file: -Service account authentication lets you run the debugger agent on any Linux -machine, including outside of [Google Cloud Platform](https://cloud.google.com). -The debugger agent authenticates against the backend with the service account -created in [Google Developers Console](https://console.developers.google.com). -If your application runs on Google Compute Engine, -[metadata service authentication](#Google_Compute_Engine) is an easier option. +```python +# Attach the Python Cloud debugger (only the main server process). +if os.environ.get('RUN_MAIN') or '--noreload' in sys.argv: + try: + import googleclouddebugger + googleclouddebugger.enable(module='[MODULE]', version='[VERSION]') + except ImportError: + pass +``` -The first step for this setup is to create the service account in .p12 format. -Please see this [page](https://cloud.google.com/storage/docs/authentication?hl=en#generating-a-private-key) -for detailed instructions. If you don't have a Google Cloud Platform project, -you can create one for free on [Google Developers Console](https://console.developers.google.com). +Alternatively, you can pass the `--noreload` flag when running the Django +`manage.py` and use any one of the option A and B listed earlier. Note that +using the `--noreload` flag disables the autoreload feature in Django, which +means local changes to files will not be automatically picked up by Django. -Once you have the service account, please note the service account e-mail, -[project ID and project number](https://developers.google.com/console/help/new/#projectnumber). -Then copy the .p12 file to all the machines that run your application. 
+## Historical note -Then, enable the debugger agent in a similary way as described in -the [previous](#Google_Compute_Engine) section: +Version 3.x of this agent supported both the now shutdown Cloud Debugger service +(by default) and the +[Snapshot Debugger](https://github.com/GoogleCloudPlatform/snapshot-debugger/) +(Firebase RTDB backend) by setting the `use_firebase` flag to true. Version 4.0 +removed support for the Cloud Debugger service, making the Snapshot Debugger the +default. To note the `use_firebase` flag is now obsolete, but still present for +backward compatibility. -_Option A_: add this code to the beginning of your `main()` function: +## Flag Reference + +The agent offers various flags to configure its behavior. Flags can be specified +as keyword arguments: ```python -# Attach Python Cloud Debugger -try: - import googleclouddebugger - googleclouddebugger.AttachDebugger( - enable_service_account=True, - project_id='my-gcp-project-id', - project_number='123456789', - service_account_email='123@developer.gserviceaccount.com', - service_account_p12_file='/opt/cdbg/gcp.p12') -except ImportError: - pass +googleclouddebugger.enable(flag_name='flag_value') +``` + +or as command line arguments when running the agent as a module: + +```shell +python -m googleclouddebugger --flag_name=flag_value -- myapp.py +``` + +The following flags are available: + +`module`: A name for your app. This, along with the version, is used to identify +the debug target in the UI.
+Example values: `MyApp`, `Backend`, or `Frontend`. + +`version`: A version for your app. The UI displays the running version as +`[MODULE] - [VERSION]`.
+If not provided, the UI will display the generated debuggee ID instead.
+Example values: `v1.0`, `build_147`, or `v20170714`. + +`service_account_json_file`: Path to JSON credentials of a [service +account](https://cloud.google.com/iam/docs/service-accounts) to use for +authentication. If not provided, the agent will fall back to [Application +Default Credentials](https://cloud.google.com/docs/authentication/production) +which are automatically available on machines hosted on GCP, or can be set via +`gcloud auth application-default login` or the `GOOGLE_APPLICATION_CREDENTIALS` +environment variable. + +`firebase_db_url`: Url pointing to a configured Firebase Realtime Database for +the agent to use to store snapshot data. +https://**PROJECT_ID**-cdbg.firebaseio.com will be used if not provided. where +**PROJECT_ID** is your project ID. + +## Development + +The following instructions are intended to help with modifying the codebase. + +### Testing + +#### Unit tests + +Run the `build_and_test.sh` script from the root of the repository to build and +run the unit tests using the locally installed version of Python. + +Run `bazel test tests/cpp:all` from the root of the repository to run unit +tests against the C++ portion of the codebase. + +#### Local development + +You may want to run an agent with local changes in an application in order to +validate functionality in a way that unit tests don't fully cover. To do this, +you will need to build the agent: +``` +cd src +./build.sh +cd .. ``` -_Option B_: run the debugger agent as a module: - -
-python \
-    -m googleclouddebugger \
-    --enable_service_account_auth=1 \
-    --project_id=my-gcp-project-id \
-    --project_number=123456789 \
-    --service_account_email=123@developer.gserviceaccount.com \
-    --service_account_p12_file=/opt/cdbg/gcp.p12 \
-    -- \
-    myapp.py
-
+The built agent will be available in the `src/dist` directory. You can now +force the installation of the agent using: +``` +pip3 install src/dist/* --force-reinstall +``` + +You can now run your test application using the development build of the agent +in whatever way you desire. + +It is recommended that you do this within a +[virtual environment](https://docs.python.org/3/library/venv.html). + +### Build & Release (for project owners) + +Before performing a release, be sure to update the version number in +`src/googleclouddebugger/version.py`. Tag the commit that increments the +version number (eg. `v3.1`) and create a Github release. + +Run the `build-dist.sh` script from the root of the repository to build, +test, and generate the distribution whls. You may need to use `sudo` +depending on your system's docker setup. + +Build artifacts will be placed in `/dist` and can be pushed to pypi by running: +``` +twine upload dist/*.whl +``` diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 0000000..55013f2 --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,50 @@ +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "bazel_skylib", + sha256 = "74d544d96f4a5bb630d465ca8bbcfe231e3594e5aae57e1edbf17a6eb3ca2506", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz", + "https://github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz", + ], +) +load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") +bazel_skylib_workspace() + +http_archive( + name = "com_github_gflags_gflags", + sha256 = "34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf", + strip_prefix = "gflags-2.2.2", + urls = ["https://github.com/gflags/gflags/archive/v2.2.2.tar.gz"], +) + +http_archive( + name = "com_github_google_glog", + sha256 = "21bc744fb7f2fa701ee8db339ded7dce4f975d0d55837a97be7d46e8382dea5a", + strip_prefix = "glog-0.5.0", + 
urls = ["https://github.com/google/glog/archive/v0.5.0.zip"], +) + +# Pinning to 1.12.1, the last release that supports C++11 +http_archive( + name = "com_google_googletest", + urls = ["https://github.com/google/googletest/archive/58d77fa8070e8cec2dc1ed015d66b454c8d78850.tar.gz"], + strip_prefix = "googletest-58d77fa8070e8cec2dc1ed015d66b454c8d78850", +) + +# Used to build against Python.h +http_archive( + name = "pybind11_bazel", + strip_prefix = "pybind11_bazel-faf56fb3df11287f26dbc66fdedf60a2fc2c6631", + urls = ["https://github.com/pybind/pybind11_bazel/archive/faf56fb3df11287f26dbc66fdedf60a2fc2c6631.zip"], +) + +http_archive( + name = "pybind11", + build_file = "@pybind11_bazel//:pybind11.BUILD", + strip_prefix = "pybind11-2.9.2", + urls = ["https://github.com/pybind/pybind11/archive/v2.9.2.tar.gz"], +) +load("@pybind11_bazel//:python_configure.bzl", "python_configure") +python_configure(name = "local_config_python")#, python_interpreter_target = interpreter) + diff --git a/alpine/Dockerfile b/alpine/Dockerfile new file mode 100644 index 0000000..888e448 --- /dev/null +++ b/alpine/Dockerfile @@ -0,0 +1,36 @@ +# WARNING: Stackdriver Debugger is not regularly tested on the Alpine Linux +# platform and support will be on a best effort basis. +# Sample Alpine Linux image including Python and the Stackdriver Debugger agent. +# To build: +# docker build . # Python 2.7 +# docker build --build-arg PYTHON_VERSION=3 . # Python 3.6 +# The final image size should be around 50-60 MiB. + +# Stage 1: Build the agent. +FROM alpine:latest + +ARG PYTHON_VERSION=2 +ENV PYTHON_VERSION=$PYTHON_VERSION +ENV PYTHON=python${PYTHON_VERSION} + +RUN apk update +RUN apk add bash git curl gcc g++ make cmake ${PYTHON}-dev +RUN if [ $PYTHON_VERSION == "2" ]; then apk add py-setuptools; fi + +RUN git clone https://github.com/GoogleCloudPlatform/cloud-debug-python +RUN PYTHON=$PYTHON bash cloud-debug-python/src/build.sh + + +# Stage 2: Create minimal image with just Python and the debugger. 
+FROM alpine:latest + +ARG PYTHON_VERSION=2 +ENV PYTHON_VERSION=$PYTHON_VERSION +ENV PYTHON=python${PYTHON_VERSION} + +RUN apk --no-cache add $PYTHON libstdc++ +RUN if [ $PYTHON_VERSION == "2" ]; then apk add --no-cache py-setuptools; fi + +COPY --from=0 /cloud-debug-python/src/dist/*.egg . +RUN $PYTHON -m easy_install *.egg +RUN rm *.egg diff --git a/build-dist.sh b/build-dist.sh new file mode 100755 index 0000000..fffe140 --- /dev/null +++ b/build-dist.sh @@ -0,0 +1,4 @@ +DOCKER_IMAGE='quay.io/pypa/manylinux2014_x86_64' + +docker pull "$DOCKER_IMAGE" +docker container run -t --rm -v "$(pwd)":/io "$DOCKER_IMAGE" /io/src/build-wheels.sh diff --git a/build_and_test.sh b/build_and_test.sh new file mode 100755 index 0000000..4ce82b9 --- /dev/null +++ b/build_and_test.sh @@ -0,0 +1,18 @@ +#!/bin/bash -e + +# Clean up any previous generated test files. +rm -rf tests/py/__pycache__ + +cd src +./build.sh +cd .. + +python3 -m venv /tmp/cdbg-venv +source /tmp/cdbg-venv/bin/activate +pip3 install -r requirements_dev.txt +pip3 install src/dist/* --force-reinstall +python3 -m pytest tests/py +deactivate + +# Clean up any generated test files. +rm -rf tests/py/__pycache__ diff --git a/firebase-sample/app.py b/firebase-sample/app.py new file mode 100644 index 0000000..0916e7c --- /dev/null +++ b/firebase-sample/app.py @@ -0,0 +1,12 @@ +import googleclouddebugger + +googleclouddebugger.enable(use_firebase=True) + +from flask import Flask + +app = Flask(__name__) + + +@app.route("/") +def hello_world(): + return "
Hello World!
" diff --git a/firebase-sample/build-and-run.sh b/firebase-sample/build-and-run.sh new file mode 100755 index 0000000..a0cc7b1 --- /dev/null +++ b/firebase-sample/build-and-run.sh @@ -0,0 +1,20 @@ +#!/bin/bash -e + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd "${SCRIPT_DIR}/.." + +cd src +./build.sh +cd .. + +python3 -m venv /tmp/cdbg-venv +source /tmp/cdbg-venv/bin/activate +pip3 install -r requirements.txt +pip3 install src/dist/* --force-reinstall + +cd firebase-sample +pip3 install -r requirements.txt +python3 -m flask run +cd .. + +deactivate diff --git a/firebase-sample/requirements.txt b/firebase-sample/requirements.txt new file mode 100644 index 0000000..7e10602 --- /dev/null +++ b/firebase-sample/requirements.txt @@ -0,0 +1 @@ +flask diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..784eb7e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +firebase_admin>=5.3.0 +pyyaml diff --git a/requirements_dev.txt b/requirements_dev.txt new file mode 100644 index 0000000..89aa308 --- /dev/null +++ b/requirements_dev.txt @@ -0,0 +1,4 @@ +-r requirements.txt +absl-py +pytest +requests-mock diff --git a/src/build-wheels.sh b/src/build-wheels.sh new file mode 100755 index 0000000..8477d84 --- /dev/null +++ b/src/build-wheels.sh @@ -0,0 +1,94 @@ +#!/bin/bash -e + +GFLAGS_URL=https://github.com/gflags/gflags/archive/v2.2.2.tar.gz +GLOG_URL=https://github.com/google/glog/archive/v0.4.0.tar.gz + +SUPPORTED_VERSIONS=(cp36-cp36m cp37-cp37m cp38-cp38 cp39-cp39 cp310-cp310) + +ROOT=$(cd $(dirname "${BASH_SOURCE[0]}") >/dev/null; /bin/pwd -P) + +# Parallelize the build over N threads where N is the number of cores * 1.5. +PARALLEL_BUILD_OPTION="-j $(($(nproc 2> /dev/null || echo 4)*3/2))" + +# Clean up any previous build/test files. 
+rm -rf \ + ${ROOT}/build \ + ${ROOT}/dist \ + ${ROOT}/setup.cfg \ + ${ROOT}/google_python_cloud_debugger.egg-info \ + /io/dist \ + /io/tests/py/__pycache__ + +# Create directory for third-party libraries. +mkdir -p ${ROOT}/build/third_party + +# Build and install gflags to build/third_party. +pushd ${ROOT}/build/third_party +curl -Lk ${GFLAGS_URL} -o gflags.tar.gz +tar xzvf gflags.tar.gz +cd gflags-* +mkdir build +cd build +cmake -DCMAKE_CXX_FLAGS=-fpic \ + -DGFLAGS_NAMESPACE=google \ + -DCMAKE_INSTALL_PREFIX:PATH=${ROOT}/build/third_party \ + .. +make ${PARALLEL_BUILD_OPTION} +make install +popd + +# Build and install glog to build/third_party. +pushd ${ROOT}/build/third_party +curl -L ${GLOG_URL} -o glog.tar.gz +tar xzvf glog.tar.gz +cd glog-* +mkdir build +cd build +cmake -DCMAKE_CXX_FLAGS=-fpic \ + -DCMAKE_PREFIX_PATH=${ROOT}/build/third_party \ + -DCMAKE_INSTALL_PREFIX:PATH=${ROOT}/build/third_party \ + .. +make ${PARALLEL_BUILD_OPTION} +make install +popd + +# Extract build version from version.py +grep "^ *__version__ *=" "/io/src/googleclouddebugger/version.py" | grep -Eo "[0-9.]+" > "version.txt" +AGENT_VERSION=$(cat "version.txt") +echo "Building distribution packages for python agent version ${AGENT_VERSION}" + +# Create setup.cfg file and point to the third_party libraries we just build. +echo "[global] +verbose=1 + +[build_ext] +include_dirs=${ROOT}/build/third_party/include +library_dirs=${ROOT}/build/third_party/lib:${ROOT}/build/third_party/lib64" > ${ROOT}/setup.cfg + +# Build the Python Cloud Debugger agent. 
+pushd ${ROOT} + +for PY_VERSION in ${SUPPORTED_VERSIONS[@]}; do + echo "Building the ${PY_VERSION} agent" + "/opt/python/${PY_VERSION}/bin/pip" install -r /io/requirements_dev.txt + "/opt/python/${PY_VERSION}/bin/pip" wheel /io/src --no-deps -w /tmp/dist/ + PACKAGE_NAME="google_python_cloud_debugger-${AGENT_VERSION}" + WHL_FILENAME="${PACKAGE_NAME}-${PY_VERSION}-linux_x86_64.whl" + auditwheel repair "/tmp/dist/${WHL_FILENAME}" -w /io/dist/ + + echo "Running tests" + "/opt/python/${PY_VERSION}/bin/pip" install google-python-cloud-debugger --no-index -f /io/dist + "/opt/python/${PY_VERSION}/bin/pytest" /io/tests/py +done + +popd + +# Clean up temporary directories. +rm -rf \ + ${ROOT}/build \ + ${ROOT}/setup.cfg \ + ${ROOT}/google_python_cloud_debugger.egg-info \ + /io/tests/py/__pycache__ + +echo "Build artifacts are in the dist directory" + diff --git a/src/build.sh b/src/build.sh old mode 100644 new mode 100755 index 2b3b16c..f61ef2f --- a/src/build.sh +++ b/src/build.sh @@ -20,9 +20,10 @@ # debugger is currently only supported on Linux. # # The build script assumes Python, cmake, curl and gcc are installed. -# To install those on Debian, run this commandd: -# sudo apt-get install curl ca-certificates gcc build-essential cmake \ -# python python-dev libpython2.7 python-setuptools +# To install these dependencies on Debian, run this commandd: +# sudo apt-get -y -q --no-install-recommends install \ +# curl ca-certificates gcc build-essential cmake \ +# python python-dev libpython2.7 python-setuptools # # The Python Cloud Debugger agent uses glog and gflags libraries. We build them # first. Then we use setuptools to build the debugger agent. 
The entire @@ -32,8 +33,8 @@ # Home page of glog: https://github.com/google/glog # -GFLAGS_URL=https://github.com/gflags/gflags/archive/v2.1.2.tar.gz -GLOG_URL=https://github.com/google/glog/archive/v0.3.4.tar.gz +GFLAGS_URL=https://github.com/gflags/gflags/archive/v2.2.2.tar.gz +GLOG_URL=https://github.com/google/glog/archive/v0.4.0.tar.gz ROOT=$(cd $(dirname "${BASH_SOURCE[0]}") >/dev/null; /bin/pwd -P) @@ -41,7 +42,11 @@ ROOT=$(cd $(dirname "${BASH_SOURCE[0]}") >/dev/null; /bin/pwd -P) PARALLEL_BUILD_OPTION="-j $(($(nproc 2> /dev/null || echo 4)*3/2))" # Clean up any previous build files. -rm -rf ${ROOT}/build ${ROOT}/dist ${ROOT}/setup.cfg +rm -rf \ + ${ROOT}/build \ + ${ROOT}/dist \ + ${ROOT}/setup.cfg \ + ${ROOT}/google_python_cloud_debugger.egg-info # Create directory for third-party libraries. mkdir -p ${ROOT}/build/third_party @@ -66,9 +71,12 @@ pushd ${ROOT}/build/third_party curl -L ${GLOG_URL} -o glog.tar.gz tar xzvf glog.tar.gz cd glog-* -./configure --with-pic \ - --prefix=${ROOT}/build/third_party \ - --with-gflags=${ROOT}/build/third_party +mkdir build +cd build +cmake -DCMAKE_CXX_FLAGS=-fpic \ + -DCMAKE_PREFIX_PATH=${ROOT}/build/third_party \ + -DCMAKE_INSTALL_PREFIX:PATH=${ROOT}/build/third_party \ + .. make ${PARALLEL_BUILD_OPTION} make install popd @@ -79,10 +87,16 @@ verbose=1 [build_ext] include_dirs=${ROOT}/build/third_party/include -library_dirs=${ROOT}/build/third_party/lib" > ${ROOT}/setup.cfg +library_dirs=${ROOT}/build/third_party/lib:${ROOT}/build/third_party/lib64" > ${ROOT}/setup.cfg # Build the Python Cloud Debugger agent. pushd ${ROOT} -python setup.py bdist_egg +# Use custom python command if variable is set +"${PYTHON:-python3}" -m pip wheel . --no-deps -w dist popd +# Clean up temporary directories. 
+rm -rf \ + ${ROOT}/build \ + ${ROOT}/setup.cfg \ + ${ROOT}/google_python_cloud_debugger.egg-info diff --git a/src/googleclouddebugger/BUILD b/src/googleclouddebugger/BUILD new file mode 100644 index 0000000..c0d6ae7 --- /dev/null +++ b/src/googleclouddebugger/BUILD @@ -0,0 +1,103 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "common", + hdrs = ["common.h"], + deps = [ + "@com_github_google_glog//:glog", + "@local_config_python//:python_headers", + ], +) + +cc_library( + name = "nullable", + hdrs = ["nullable.h"], + deps = [ + ":common", + ], +) + +cc_library( + name = "python_util", + srcs = ["python_util.cc"], + hdrs = ["python_util.h"], + deps = [ + ":common", + ":nullable", + "//src/third_party:pylinetable", + ], +) + + +cc_library( + name = "python_callback", + srcs = ["python_callback.cc"], + hdrs = ["python_callback.h"], + deps = [ + ":common", + ":python_util", + ], +) + +cc_library( + name = "leaky_bucket", + srcs = ["leaky_bucket.cc"], + hdrs = ["leaky_bucket.h"], + deps = [ + ":common", + ], +) + +cc_library( + name = "rate_limit", + srcs = ["rate_limit.cc"], + hdrs = ["rate_limit.h"], + deps = [ + ":common", + ":leaky_bucket", + ], +) + +cc_library( + name = "bytecode_manipulator", + srcs = ["bytecode_manipulator.cc"], + hdrs = ["bytecode_manipulator.h"], + deps = [ + ":common", + ], +) + +cc_library( + name = "bytecode_breakpoint", + srcs = ["bytecode_breakpoint.cc"], + hdrs = ["bytecode_breakpoint.h"], + deps = [ + ":bytecode_manipulator", + ":common", + ":python_callback", + ":python_util", + ], +) + +cc_library( + name = "immutability_tracer", + srcs = ["immutability_tracer.cc"], + hdrs = ["immutability_tracer.h"], + deps = [ + ":common", + ":python_util", + ], +) + +cc_library( + name = "conditional_breakpoint", + srcs = ["conditional_breakpoint.cc"], + hdrs = ["conditional_breakpoint.h"], + deps = [ + ":common", + ":immutability_tracer", + ":python_util", + ":rate_limit", + ":leaky_bucket", + ], +) diff --git 
a/src/googleclouddebugger/__init__.py b/src/googleclouddebugger/__init__.py index ccbfc7a..378f6a7 100644 --- a/src/googleclouddebugger/__init__.py +++ b/src/googleclouddebugger/__init__.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Main module for Python Cloud Debugger. The debugger is enabled in a very similar way to enabling pdb. @@ -26,52 +25,70 @@ import os import sys -import appengine_pretty_printers -import breakpoints_manager -import capture_collector -import cdbg_native -import gcp_hub_client +from . import appengine_pretty_printers +from . import breakpoints_manager +from . import collector +from . import error_data_visibility_policy +from . import firebase_client +from . import glob_data_visibility_policy +from . import yaml_data_visibility_config_reader +from . import cdbg_native +from . import version -# Versioning scheme: MAJOR.MINOR -# The major version should only change on breaking changes. Minor version -# changes go between regular updates. Instances running debuggers with -# different major versions will show up as two different debuggees. 
-__version__ = '1.9' +__version__ = version.__version__ _flags = None +_backend_client = None +_breakpoints_manager = None def _StartDebugger(): - global _hub_client + """Configures and starts the debugger.""" + global _backend_client global _breakpoints_manager cdbg_native.InitializeModule(_flags) + cdbg_native.LogInfo( + f'Initializing Cloud Debugger Python agent version: {__version__}') + + _backend_client = firebase_client.FirebaseClient() + _backend_client.SetupAuth( + _flags.get('project_id'), _flags.get('service_account_json_file'), + _flags.get('firebase_db_url')) - _hub_client = gcp_hub_client.GcpHubClient() - _breakpoints_manager = breakpoints_manager.BreakpointsManager(_hub_client) + visibility_policy = _GetVisibilityPolicy() + + _breakpoints_manager = breakpoints_manager.BreakpointsManager( + _backend_client, visibility_policy) # Set up loggers for logpoints. - capture_collector.log_info_message = logging.info - capture_collector.log_warning_message = logging.warning - capture_collector.log_error_message = logging.error + collector.SetLogger(logging.getLogger()) - """Configures and starts the debugger.""" - capture_collector.CaptureCollector.pretty_printers.append( + collector.CaptureCollector.pretty_printers.append( appengine_pretty_printers.PrettyPrinter) - _hub_client.on_active_breakpoints_changed = ( + _backend_client.on_active_breakpoints_changed = ( _breakpoints_manager.SetActiveBreakpoints) - _hub_client.on_idle = _breakpoints_manager.CheckBreakpointsExpiration - if _flags.get('enable_service_account_auth') in ('1', 'true', True): - _hub_client.EnableServiceAccountAuth( - _flags['project_id'], - _flags['project_number'], - _flags['service_account_email'], - _flags['service_account_p12_file']) - else: - _hub_client.EnableGceAuth() - _hub_client.InitializeDebuggeeLabels(_flags) - _hub_client.Start() + _backend_client.on_idle = _breakpoints_manager.CheckBreakpointsExpiration + + _backend_client.InitializeDebuggeeLabels(_flags) + 
_backend_client.Start() + + +def _GetVisibilityPolicy(): + """If a debugger configuration is found, create a visibility policy.""" + try: + visibility_config = yaml_data_visibility_config_reader.OpenAndRead() + except yaml_data_visibility_config_reader.Error as err: + return error_data_visibility_policy.ErrorDataVisibilityPolicy( + f'Could not process debugger config: {err}') + + if visibility_config: + return glob_data_visibility_policy.GlobDataVisibilityPolicy( + visibility_config.blacklist_patterns, + visibility_config.whitelist_patterns) + + return None def _DebuggerMain(): @@ -100,25 +117,30 @@ def _DebuggerMain(): sys.path[0] = os.path.dirname(app_path) - import __main__ # pylint: disable=g-import-not-at-top + import __main__ # pylint: disable=import-outside-toplevel __main__.__dict__.clear() - __main__.__dict__.update({'__name__': '__main__', - '__file__': app_path, - '__builtins__': __builtins__}) + __main__.__dict__.update({ + '__name__': '__main__', + '__file__': app_path, + '__builtins__': __builtins__ + }) locals = globals = __main__.__dict__ # pylint: disable=redefined-builtin sys.modules['__main__'] = __main__ - exec 'execfile(%r)' % app_path in globals, locals # pylint: disable=exec-used + with open(app_path, encoding='utf-8') as f: + code = compile(f.read(), app_path, 'exec') + exec(code, globals, locals) # pylint: disable=exec-used -def AttachDebugger(**kwargs): +# pylint: disable=invalid-name +def enable(**kwargs): """Starts the debugger for already running application. This function should only be called once. Args: - flags: debugger configuration. + **kwargs: debugger configuration flags. Raises: RuntimeError: if called more than once. @@ -132,3 +154,6 @@ def AttachDebugger(**kwargs): _flags = kwargs _StartDebugger() + +# AttachDebugger is an alias for enable, preserved for compatibility. 
+AttachDebugger = enable diff --git a/src/googleclouddebugger/__main__.py b/src/googleclouddebugger/__main__.py index b6ef937..edfe6c0 100644 --- a/src/googleclouddebugger/__main__.py +++ b/src/googleclouddebugger/__main__.py @@ -11,11 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Entry point for Python Cloud Debugger.""" # pylint: disable=invalid-name - if __name__ == '__main__': import googleclouddebugger googleclouddebugger._DebuggerMain() - diff --git a/src/googleclouddebugger/appengine_pretty_printers.py b/src/googleclouddebugger/appengine_pretty_printers.py index 0abed1c..3908990 100644 --- a/src/googleclouddebugger/appengine_pretty_printers.py +++ b/src/googleclouddebugger/appengine_pretty_printers.py @@ -11,9 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """Formatters for well known objects that don't show up nicely by default.""" +try: + from protorpc import messages # pylint: disable=g-import-not-at-top +except ImportError: + messages = None + try: from google.appengine.ext import ndb # pylint: disable=g-import-not-at-top except ImportError: @@ -24,6 +28,9 @@ def PrettyPrinter(obj): """Pretty printers for AppEngine objects.""" if ndb and isinstance(obj, ndb.Model): - return obj.to_dict().iteritems(), 'ndb.Model(%s)' % type(obj).__name__ + return obj.to_dict().items(), 'ndb.Model(%s)' % type(obj).__name__ + + if messages and isinstance(obj, messages.Enum): + return [('name', obj.name), ('number', obj.number)], type(obj).__name__ return None diff --git a/src/googleclouddebugger/application_info.py b/src/googleclouddebugger/application_info.py new file mode 100644 index 0000000..c920cce --- /dev/null +++ b/src/googleclouddebugger/application_info.py @@ -0,0 +1,79 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module to fetch information regarding the current application. + +Some examples of the information the methods in this module fetch are platform +and region of the application. +""" + +import enum +import os +import requests + +# These environment variables will be set automatically by cloud functions +# depending on the runtime. If one of these values is set, we can infer that +# the current environment is GCF. 
Reference: +# https://cloud.google.com/functions/docs/env-var#runtime_environment_variables_set_automatically +_GCF_EXISTENCE_ENV_VARIABLES = ['FUNCTION_NAME', 'FUNCTION_TARGET'] +_GCF_REGION_ENV_VARIABLE = 'FUNCTION_REGION' + +_GCP_METADATA_REGION_URL = 'http://metadata/computeMetadata/v1/instance/region' +_GCP_METADATA_HEADER = {'Metadata-Flavor': 'Google'} + + +class PlatformType(enum.Enum): + """The type of platform the application is running on. + + TODO: Define this enum in a common format for all agents to + share. This enum needs to be maintained between the labels code generator + and other agents, until there is a unified way to generate it. + """ + CLOUD_FUNCTION = 'cloud_function' + DEFAULT = 'default' + + +def GetPlatform(): + """Returns PlatformType for the current application.""" + + # Check if it's a cloud function. + for name in _GCF_EXISTENCE_ENV_VARIABLES: + if name in os.environ: + return PlatformType.CLOUD_FUNCTION + + # If we weren't able to identify the platform, fall back to default value. + return PlatformType.DEFAULT + + +def GetRegion(): + """Returns region of the current application.""" + + # If it's running cloud function with an old runtime. + if _GCF_REGION_ENV_VARIABLE in os.environ: + return os.environ.get(_GCF_REGION_ENV_VARIABLE) + + # Otherwise try fetching it from the metadata server. + try: + response = requests.get( + _GCP_METADATA_REGION_URL, headers=_GCP_METADATA_HEADER) + response.raise_for_status() + # Example of response text: projects/id/regions/us-central1. So we strip + # everything before the last /. + region = response.text.split('/')[-1] + if region == 'html>': + # Sometimes we get an html response! 
+ return None + + return region + except requests.exceptions.RequestException: + return None diff --git a/src/googleclouddebugger/backoff.py b/src/googleclouddebugger/backoff.py index edc024f..f12237d 100644 --- a/src/googleclouddebugger/backoff.py +++ b/src/googleclouddebugger/backoff.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Implements exponential backoff for retry timeouts.""" diff --git a/src/googleclouddebugger/breakpoints_manager.py b/src/googleclouddebugger/breakpoints_manager.py index 807349c..e3f0421 100644 --- a/src/googleclouddebugger/breakpoints_manager.py +++ b/src/googleclouddebugger/breakpoints_manager.py @@ -11,13 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Manages lifetime of individual breakpoint objects.""" from datetime import datetime from threading import RLock -import python_breakpoint +from . import python_breakpoint class BreakpointsManager(object): @@ -34,10 +33,13 @@ class BreakpointsManager(object): Args: hub_client: queries active breakpoints from the backend and sends breakpoint updates back to the backend. + data_visibility_policy: An object used to determine the visibiliy + of a captured variable. May be None if no policy is available. """ - def __init__(self, hub_client): + def __init__(self, hub_client, data_visibility_policy): self._hub_client = hub_client + self.data_visibility_policy = data_visibility_policy # Lock to synchronize access to data across multiple threads. self._lock = RLock() @@ -65,15 +67,17 @@ def SetActiveBreakpoints(self, breakpoints_data): ids = set([x['id'] for x in breakpoints_data]) # Clear breakpoints that no longer show up in active breakpoints list. 
- for breakpoint_id in self._active.viewkeys() - ids: + for breakpoint_id in self._active.keys() - ids: self._active.pop(breakpoint_id).Clear() # Create new breakpoints. self._active.update([ (x['id'], - python_breakpoint.PythonBreakpoint(x, self._hub_client, self)) + python_breakpoint.PythonBreakpoint(x, self._hub_client, self, + self.data_visibility_policy)) for x in breakpoints_data - if x['id'] in ids - self._active.viewkeys() - self._completed]) + if x['id'] in ids - self._active.keys() - self._completed + ]) # Remove entries from completed_breakpoints_ that weren't listed in # breakpoints_data vector. These are confirmed to have been removed by the @@ -102,13 +106,13 @@ def CompleteBreakpoint(self, breakpoint_id): def CheckBreakpointsExpiration(self): """Completes all breakpoints that have been active for too long.""" with self._lock: - current_time = BreakpointsManager._GetCurrentTime() + current_time = BreakpointsManager.GetCurrentTime() if self._next_expiration > current_time: return expired_breakpoints = [] self._next_expiration = datetime.max - for breakpoint in self._active.itervalues(): + for breakpoint in self._active.values(): expiration_time = breakpoint.GetExpirationTime() if expiration_time <= current_time: expired_breakpoints.append(breakpoint) @@ -119,7 +123,7 @@ def CheckBreakpointsExpiration(self): breakpoint.ExpireBreakpoint() @staticmethod - def _GetCurrentTime(): + def GetCurrentTime(): """Wrapper around datetime.now() function. 
The datetime class is a built-in one and therefore not patchable by unit diff --git a/src/googleclouddebugger/bytecode_breakpoint.cc b/src/googleclouddebugger/bytecode_breakpoint.cc index 96700f7..dd1af6e 100644 --- a/src/googleclouddebugger/bytecode_breakpoint.cc +++ b/src/googleclouddebugger/bytecode_breakpoint.cc @@ -19,6 +19,8 @@ #include "bytecode_breakpoint.h" +#include + #include "bytecode_manipulator.h" #include "python_callback.h" #include "python_util.h" @@ -48,7 +50,7 @@ void BytecodeBreakpoint::Detach() { it->second->breakpoints.clear(); PatchCodeObject(it->second); - // TODO(vlif): assert zombie_refs.empty() after garbage collection + // TODO: assert zombie_refs.empty() after garbage collection // for zombie refs is implemented. delete it->second; @@ -64,7 +66,7 @@ void BytecodeBreakpoint::Detach() { } -int BytecodeBreakpoint::SetBreakpoint( +int BytecodeBreakpoint::CreateBreakpoint( PyCodeObject* code_object, int line, std::function hit_callback, @@ -80,7 +82,7 @@ int BytecodeBreakpoint::SetBreakpoint( // table in case "code_object" is already patched with another breakpoint. 
CodeObjectLinesEnumerator lines_enumerator( code_object->co_firstlineno, - code_object_breakpoints->original_lnotab.get()); + code_object_breakpoints->original_linedata.get()); while (lines_enumerator.line_number() != line) { if (!lines_enumerator.Next()) { LOG(ERROR) << "Line " << line << " not found in " @@ -95,10 +97,12 @@ int BytecodeBreakpoint::SetBreakpoint( std::unique_ptr breakpoint(new Breakpoint); breakpoint->code_object = ScopedPyCodeObject::NewReference(code_object); + breakpoint->line = line; breakpoint->offset = lines_enumerator.offset(); breakpoint->hit_callable = PythonCallback::Wrap(hit_callback); breakpoint->error_callback = error_callback; breakpoint->cookie = cookie; + breakpoint->status = BreakpointStatus::kInactive; code_object_breakpoints->breakpoints.insert( std::make_pair(breakpoint->offset, breakpoint.get())); @@ -106,15 +110,44 @@ int BytecodeBreakpoint::SetBreakpoint( DCHECK(cookie_map_[cookie] == nullptr); cookie_map_[cookie] = breakpoint.release(); - PatchCodeObject(code_object_breakpoints); - return cookie; } +void BytecodeBreakpoint::ActivateBreakpoint(int cookie) { + if (cookie == -1) return; // no-op if invalid cookie. + + auto it_breakpoint = cookie_map_.find(cookie); + if (it_breakpoint == cookie_map_.end()) { + LOG(WARNING) << "Trying to activate a breakpoint with an unknown cookie: " + << cookie; + return; // No breakpoint with this cookie. + } + + auto it_code = patches_.find(it_breakpoint->second->code_object); + if (it_code != patches_.end()) { + CodeObjectBreakpoints* code = it_code->second; + // Ensure that there is a new breakpoint that was added. + if (it_breakpoint->second->status == BreakpointStatus::kInactive) { + // Set breakpoint to active. + it_breakpoint->second->status = BreakpointStatus::kActive; + // Patch code. 
+ PatchCodeObject(code); + } else { + LOG(WARNING) << "Breakpoint with cookie: " << cookie + << " has already been activated"; + } + } else { + LOG(DFATAL) << "Missing code object"; + } +} void BytecodeBreakpoint::ClearBreakpoint(int cookie) { + if (cookie == -1) return; // no-op if invalid cookie + auto it_breakpoint = cookie_map_.find(cookie); if (it_breakpoint == cookie_map_.end()) { + LOG(WARNING) << "Trying to clear a breakpoint with an unknown cookie: " + << cookie; return; // No breakpoint with this cookie. } @@ -138,6 +171,9 @@ void BytecodeBreakpoint::ClearBreakpoint(int cookie) { DCHECK_EQ(1, erase_count); + // Set breakpoint as done, as it was removed from code->breakpoints map. + it_breakpoint->second->status = BreakpointStatus::kDone; + PatchCodeObject(code); if (code->breakpoints.empty() && code->zombie_refs.empty()) { @@ -145,13 +181,22 @@ void BytecodeBreakpoint::ClearBreakpoint(int cookie) { patches_.erase(it_code); } } else { - DCHECK(false) << "Missing code object"; + LOG(DFATAL) << "Missing code object"; } delete it_breakpoint->second; cookie_map_.erase(it_breakpoint); } +BreakpointStatus BytecodeBreakpoint::GetBreakpointStatus(int cookie) { + auto it_breakpoint = cookie_map_.find(cookie); + if (it_breakpoint == cookie_map_.end()) { + // No breakpoint with this cookie. + return BreakpointStatus::kUnknown; + } + + return it_breakpoint->second->status; +} BytecodeBreakpoint::CodeObjectBreakpoints* BytecodeBreakpoint::PreparePatchCodeObject( @@ -187,13 +232,19 @@ BytecodeBreakpoint::PreparePatchCodeObject( data->original_code = ScopedPyObject::NewReference(code_object.get()->co_code); if ((data->original_code == nullptr) || - !PyString_CheckExact(data->original_code.get())) { + !PyBytes_CheckExact(data->original_code.get())) { LOG(ERROR) << "Code object has no code"; return nullptr; // Probably a built-in method or uninitialized code object. } - data->original_lnotab = + // Store the original (unmodified) line data. 
+#if PY_VERSION_HEX < 0x030A0000 + data->original_linedata = ScopedPyObject::NewReference(code_object.get()->co_lnotab); +#else + data->original_linedata = + ScopedPyObject::NewReference(code_object.get()->co_linetable); +#endif patches_[code_object] = data.get(); return data.release(); @@ -217,29 +268,38 @@ void BytecodeBreakpoint::PatchCodeObject(CodeObjectBreakpoints* code) { << " from patched " << code->zombie_refs.back().get(); Py_INCREF(code_object->co_code); + // Restore the original line data to the code object. +#if PY_VERSION_HEX < 0x030A0000 if (code_object->co_lnotab != nullptr) { code->zombie_refs.push_back(ScopedPyObject(code_object->co_lnotab)); } - code_object->co_lnotab = code->original_lnotab.get(); + code_object->co_lnotab = code->original_linedata.get(); Py_INCREF(code_object->co_lnotab); +#else + if (code_object->co_linetable != nullptr) { + code->zombie_refs.push_back(ScopedPyObject(code_object->co_linetable)); + } + code_object->co_linetable = code->original_linedata.get(); + Py_INCREF(code_object->co_linetable); +#endif return; } - std::vector bytecode = PyStringToByteArray(code->original_code.get()); + std::vector bytecode = PyBytesToByteArray(code->original_code.get()); - bool has_lnotab = false; - std::vector lnotab; - if (!code->original_lnotab.is_null() && - PyString_CheckExact(code->original_lnotab.get())) { - has_lnotab = true; - lnotab = PyStringToByteArray(code->original_lnotab.get()); + bool has_linedata = false; + std::vector linedata; + if (!code->original_linedata.is_null() && + PyBytes_CheckExact(code->original_linedata.get())) { + has_linedata = true; + linedata = PyBytesToByteArray(code->original_linedata.get()); } BytecodeManipulator bytecode_manipulator( std::move(bytecode), - has_lnotab, - std::move(lnotab)); + has_linedata, + std::move(linedata)); // Add callbacks to code object constants and patch the bytecode. 
std::vector callbacks; @@ -251,16 +311,43 @@ void BytecodeBreakpoint::PatchCodeObject(CodeObjectBreakpoints* code) { for (auto it_entry = code->breakpoints.begin(); it_entry != code->breakpoints.end(); ++it_entry, ++const_index) { - const int offset = it_entry->first; + // Skip breakpoint if it still hasn't been activated. + if (it_entry->second->status == BreakpointStatus::kInactive) continue; + + int offset = it_entry->first; + bool offset_found = true; const Breakpoint& breakpoint = *it_entry->second; DCHECK_EQ(offset, breakpoint.offset); callbacks.push_back(breakpoint.hit_callable.get()); - if (!bytecode_manipulator.InjectMethodCall(offset, const_index)) { + // In Python 3, since we allow upgrading of instructions to use + // EXTENDED_ARG, the offsets for lines originally calculated might not be + // accurate, so we need to recalculate them each insertion. + offset_found = false; + if (bytecode_manipulator.has_linedata()) { + ScopedPyObject linedata(PyBytes_FromStringAndSize( + reinterpret_cast(bytecode_manipulator.linedata().data()), + bytecode_manipulator.linedata().size())); + CodeObjectLinesEnumerator lines_enumerator(code_object->co_firstlineno, + linedata.release()); + while (lines_enumerator.line_number() != breakpoint.line) { + if (!lines_enumerator.Next()) { + break; + } + offset = lines_enumerator.offset(); + } + offset_found = lines_enumerator.line_number() == breakpoint.line; + } + + if (!offset_found || + !bytecode_manipulator.InjectMethodCall(offset, const_index)) { LOG(WARNING) << "Failed to insert bytecode for breakpoint " - << breakpoint.cookie; + << breakpoint.cookie << " at line " << breakpoint.line; errors.push_back(breakpoint.error_callback); + it_entry->second->status = BreakpointStatus::kError; + } else { + it_entry->second->status = BreakpointStatus::kActive; } } @@ -272,7 +359,7 @@ void BytecodeBreakpoint::PatchCodeObject(CodeObjectBreakpoints* code) { code_object->co_stacksize = code->original_stacksize + 1; 
code->zombie_refs.push_back(ScopedPyObject(code_object->co_code)); - ScopedPyObject bytecode_string(PyString_FromStringAndSize( + ScopedPyObject bytecode_string(PyBytes_FromStringAndSize( reinterpret_cast(bytecode_manipulator.bytecode().data()), bytecode_manipulator.bytecode().size())); DCHECK(!bytecode_string.is_null()); @@ -281,14 +368,26 @@ void BytecodeBreakpoint::PatchCodeObject(CodeObjectBreakpoints* code) { << " reassigned to " << code_object->co_code << ", original was " << code->original_code.get(); - if (has_lnotab) { + // Update the line data in the code object. +#if PY_VERSION_HEX < 0x030A0000 + if (has_linedata) { code->zombie_refs.push_back(ScopedPyObject(code_object->co_lnotab)); - ScopedPyObject lnotab_string(PyString_FromStringAndSize( - reinterpret_cast(bytecode_manipulator.lnotab().data()), - bytecode_manipulator.lnotab().size())); + ScopedPyObject lnotab_string(PyBytes_FromStringAndSize( + reinterpret_cast(bytecode_manipulator.linedata().data()), + bytecode_manipulator.linedata().size())); DCHECK(!lnotab_string.is_null()); code_object->co_lnotab = lnotab_string.release(); } +#else + if (has_linedata) { + code->zombie_refs.push_back(ScopedPyObject(code_object->co_linetable)); + ScopedPyObject linetable_string(PyBytes_FromStringAndSize( + reinterpret_cast(bytecode_manipulator.linedata().data()), + bytecode_manipulator.linedata().size())); + DCHECK(!linetable_string.is_null()); + code_object->co_linetable = linetable_string.release(); + } +#endif // Invoke error callback after everything else is done. The callback may // decide to remove the breakpoint, which will change "code". 
@@ -299,5 +398,3 @@ void BytecodeBreakpoint::PatchCodeObject(CodeObjectBreakpoints* code) { } // namespace cdbg } // namespace devtools - - diff --git a/src/googleclouddebugger/bytecode_breakpoint.h b/src/googleclouddebugger/bytecode_breakpoint.h index 5256125..5eaa893 100644 --- a/src/googleclouddebugger/bytecode_breakpoint.h +++ b/src/googleclouddebugger/bytecode_breakpoint.h @@ -18,18 +18,56 @@ #define DEVTOOLS_CDBG_DEBUGLETS_PYTHON_BYTECODE_BREAKPOINT_H_ #include -#include #include +#include + #include "common.h" #include "python_util.h" namespace devtools { namespace cdbg { +// Enum representing the status of a breakpoint. State tracking is helpful +// for testing and debugging the bytecode breakpoints. +// ======================================================================= +// State transition map: +// +// (start) kUnknown +// |- [CreateBreakpoint] +// | +// | +// | [ActivateBreakpoint] [PatchCodeObject] +// v | | +// kInactive ----> kActive <---> kError +// | | | +// |-------| | |-------| +// | | | +// |- |- |- [ClearBreakpoint] +// v v v +// kDone +// +// ======================================================================= +enum class BreakpointStatus { + // Unknown status for the breakpoint + kUnknown = 0, + + // Breakpoint is created and is patched in the bytecode. + kActive, + + // Breakpoint is created but is currently not patched in the bytecode. + kInactive, + + // Breakpoint has been cleared. + kDone, + + // Breakpoint is created but failed to be activated (patched in the bytecode). + kError +}; + // Sets breakpoints in Python code with zero runtime overhead. // BytecodeBreakpoint rewrites Python bytecode to insert a breakpoint. The // implementation is specific to CPython 2.7. -// TODO(vlif): rename to BreakpointsEmulator when the original implementation +// TODO: rename to BreakpointsEmulator when the original implementation // of BreakpointsEmulator goes away. 
class BytecodeBreakpoint { public: @@ -40,27 +78,45 @@ class BytecodeBreakpoint { // Clears all the set breakpoints. void Detach(); - // Sets a new breakpoint in the specified code object. More than one - // breakpoint can be set at the same source location. When the breakpoint - // hits, the "callback" parameter is invoked. Every time this class fails to - // install the breakpoint, "error_callback" is invoked. Returns cookie used - // to clear the breakpoint. - int SetBreakpoint( - PyCodeObject* code_object, - int line, - std::function hit_callback, - std::function error_callback); - - // Removes a previously set breakpoint. If the cookie is invalid, this - // function does nothing. + // Creates a new breakpoint in the specified code object. More than one + // breakpoint can be created at the same source location. When the breakpoint + // hits, the "callback" parameter is invoked. Every time this method fails to + // create the breakpoint, "error_callback" is invoked and a cookie value of + // -1 is returned. If it succeeds in creating the breakpoint, returns the + // unique cookie used to activate and clear the breakpoint. Note this method + // only creates the breakpoint, to activate it you must call + // "ActivateBreakpoint". + int CreateBreakpoint(PyCodeObject* code_object, int line, + std::function hit_callback, + std::function error_callback); + + // Activates a previously created breakpoint. If it fails to set any + // breakpoint, the error callback will be invoked. This method is kept + // separate from "CreateBreakpoint" to ensure that the cookie is available + // before the "error_callback" is invoked. Calling this method with a cookie + // value of -1 is a no-op. Note that any breakpoints in the same function that + // previously failed to activate will retry to activate during this call. + // TODO: Provide a method "ActivateAllBreakpoints" to optimize + // the code and patch the code once, instead of multiple times. 
+ void ActivateBreakpoint(int cookie); + + // Removes a previously set breakpoint. Calling this method with a cookie + // value of -1 is a no-op. Note that any breakpoints in the same function that + // previously failed to activate will retry to activate during this call. void ClearBreakpoint(int cookie); + // Get the status of a breakpoint. + BreakpointStatus GetBreakpointStatus(int cookie); + private: // Information about the breakpoint. struct Breakpoint { // Method in which the breakpoint is set. ScopedPyCodeObject code_object; + // Line number on which the breakpoint is set. + int line; + // Offset to the instruction on which the breakpoint is set. int offset; @@ -73,6 +129,9 @@ class BytecodeBreakpoint { // Breakpoint ID used to clear the breakpoint. int cookie; + + // Status of the breakpoint. + BreakpointStatus status; }; // Set of breakpoints in a particular code object and original data of @@ -91,7 +150,7 @@ class BytecodeBreakpoint { // constants. Instead we store these references in a special zombie pool. // Then once we know that no Python thread is executing the code object, // we can release all of them. - // TODO(vlif): implement garbage collection for zombie refs. + // TODO: implement garbage collection for zombie refs. std::vector zombie_refs; // Original value of PyCodeObject::co_stacksize before patching. @@ -103,9 +162,10 @@ class BytecodeBreakpoint { // Original value of PyCodeObject::co_code before patching. ScopedPyObject original_code; - // Original value of PythonCode::co_lnotab before patching. - // "lnotab" stands for "line numbers table" in CPython lingo. - ScopedPyObject original_lnotab; + // Original value of PythonCode::co_lnotab or PythonCode::co_linetable + // before patching. This is the line numbers table in CPython <= 3.9 and + // CPython >= 3.10 respectively + ScopedPyObject original_linedata; }; // Loads code object into "patches_" if not there yet. 
Returns nullptr if diff --git a/src/googleclouddebugger/bytecode_manipulator.cc b/src/googleclouddebugger/bytecode_manipulator.cc index 21030a8..44cef74 100644 --- a/src/googleclouddebugger/bytecode_manipulator.cc +++ b/src/googleclouddebugger/bytecode_manipulator.cc @@ -19,6 +19,9 @@ #include "bytecode_manipulator.h" +#include +#include + namespace devtools { namespace cdbg { @@ -31,65 +34,68 @@ enum PythonOpcodeType { YIELD_OPCODE }; -// Single Python instruction. There are 3 types of instructions: -// 1. Instruction without arguments (takes 1 byte). -// 2. Instruction with a single 16 bit argument (takes 3 bytes). -// 3. Instruction with a 32 bit argument (very uncommon; takes 6 bytes). +// Single Python instruction. +// +// In Python 3.6, there are 4 types of instructions: +// 1. Instructions without arguments, or a 8 bit argument (takes 2 bytes). +// 2. Instructions with a 16 bit argument (takes 4 bytes). +// 3. Instructions with a 24 bit argument (takes 6 bytes). +// 4. Instructions with a 32 bit argument (takes 8 bytes). +// +// To handle 16-32 bit arguments in Python 3, +// a special instruction with an opcode of EXTENDED_ARG is prepended to the +// actual instruction. The argument of the EXTENDED_ARG instruction is combined +// with the argument of the next instruction to form the full argument. struct PythonInstruction { - uint8 opcode; - uint32 argument; - bool is_extended; + uint8_t opcode; + uint32_t argument; + int size; }; // Special pseudo-instruction to indicate failures. -static const PythonInstruction kInvalidInstruction { 0xFF, 0xFFFFFFFF, false }; +static const PythonInstruction kInvalidInstruction { 0xFF, 0xFFFFFFFF, 0 }; // Creates an instance of PythonInstruction for instruction with no arguments. 
-static PythonInstruction PythonInstructionNoArg(uint8 opcode) { +static PythonInstruction PythonInstructionNoArg(uint8_t opcode) { DCHECK(!HAS_ARG(opcode)); PythonInstruction instruction; instruction.opcode = opcode; instruction.argument = 0; - instruction.is_extended = false; + + instruction.size = 2; return instruction; } - // Creates an instance of PythonInstruction for instruction with an argument. -static PythonInstruction PythonInstructionArg(uint8 opcode, uint32 argument) { +static PythonInstruction PythonInstructionArg(uint8_t opcode, + uint32_t argument) { DCHECK(HAS_ARG(opcode)); PythonInstruction instruction; instruction.opcode = opcode; instruction.argument = argument; - instruction.is_extended = (argument > 0xFFFF); - - return instruction; -} - -// Calculates the number of bytes that an instruction occupies. -static int GetInstructionSize(const PythonInstruction& instruction) { - if (instruction.is_extended) { - return 6; // Extended instruction with a 32 bit argument. - } - - if (HAS_ARG(instruction.opcode)) { - return 3; // Instruction with a single 16 bit argument. + if (argument <= 0xFF) { + instruction.size = 2; + } else if (argument <= 0xFFFF) { + instruction.size = 4; + } else if (argument <= 0xFFFFFF) { + instruction.size = 6; + } else { + instruction.size = 8; } - return 1; // Instruction without argument. + return instruction; } - // Calculates the size of a set of instructions. static int GetInstructionsSize( const std::vector& instructions) { int size = 0; for (auto it = instructions.begin(); it != instructions.end(); ++it) { - size += GetInstructionSize(*it); + size += it->size; } return size; @@ -97,17 +103,25 @@ static int GetInstructionsSize( // Classification of an opcode. 
-static PythonOpcodeType GetOpcodeType(uint8 opcode) { +static PythonOpcodeType GetOpcodeType(uint8_t opcode) { switch (opcode) { case YIELD_VALUE: + case YIELD_FROM: return YIELD_OPCODE; case FOR_ITER: case JUMP_FORWARD: +#if PY_VERSION_HEX < 0x03080000 + // Removed in Python 3.8. case SETUP_LOOP: case SETUP_EXCEPT: +#endif case SETUP_FINALLY: case SETUP_WITH: +#if PY_VERSION_HEX >= 0x03080000 && PY_VERSION_HEX < 0x03090000 + // Added in Python 3.8 and removed in 3.9 + case CALL_FINALLY: +#endif return BRANCH_DELTA_OPCODE; case JUMP_IF_FALSE_OR_POP: @@ -115,7 +129,13 @@ static PythonOpcodeType GetOpcodeType(uint8 opcode) { case JUMP_ABSOLUTE: case POP_JUMP_IF_FALSE: case POP_JUMP_IF_TRUE: +#if PY_VERSION_HEX < 0x03080000 + // Removed in Python 3.8. case CONTINUE_LOOP: +#endif +#if PY_VERSION_HEX >= 0x03090000 + case JUMP_IF_NOT_EXC_MATCH: +#endif return BRANCH_ABSOLUTE_OPCODE; default: @@ -123,19 +143,22 @@ static PythonOpcodeType GetOpcodeType(uint8 opcode) { } } - // Gets the target offset of a branch instruction. static int GetBranchTarget(int offset, PythonInstruction instruction) { - const int argument_value = instruction.is_extended - ? static_cast(instruction.argument) - : static_cast(instruction.argument); - switch (GetOpcodeType(instruction.opcode)) { case BRANCH_DELTA_OPCODE: - return offset + GetInstructionSize(instruction) + argument_value; +#if PY_VERSION_HEX < 0x030A0000 + return offset + instruction.size + instruction.argument; +#else + return offset + instruction.size + instruction.argument * 2; +#endif case BRANCH_ABSOLUTE_OPCODE: - return argument_value; +#if PY_VERSION_HEX < 0x030A0000 + return instruction.argument; +#else + return instruction.argument * 2; +#endif default: DCHECK(false) << "Not a branch instruction"; @@ -144,106 +167,66 @@ static int GetBranchTarget(int offset, PythonInstruction instruction) { } -// Reads 16 bit value according to Python bytecode encoding. 
-static uint16 ReadPythonBytecodeUInt16(std::vector::const_iterator it) { - return it[0] | (static_cast(it[1]) << 8); -} - - -// Writes 16 bit value according to Python bytecode encoding. -static void WritePythonBytecodeUInt16( - std::vector::iterator it, - uint16 data) { - it[0] = static_cast(data); - it[1] = data >> 8; -} - - // Read instruction at the specified offset. Returns kInvalidInstruction // buffer underflow. static PythonInstruction ReadInstruction( - const std::vector& bytecode, - std::vector::const_iterator it) { - PythonInstruction instruction { 0, 0, false }; + const std::vector& bytecode, + std::vector::const_iterator it) { + PythonInstruction instruction { 0, 0, 0 }; - if (it == bytecode.end()) { + if (bytecode.end() - it < 2) { LOG(ERROR) << "Buffer underflow"; return kInvalidInstruction; } - instruction.opcode = it[0]; - - auto it_arg = it + 1; - if (instruction.opcode == EXTENDED_ARG) { - if (bytecode.end() - it < 6) { - LOG(ERROR) << "Buffer underflow"; - return kInvalidInstruction; - } - - instruction.opcode = it[3]; - - auto it_ext = it + 4; - instruction.argument = - (static_cast(ReadPythonBytecodeUInt16(it_arg)) << 16) | - ReadPythonBytecodeUInt16(it_ext); - instruction.is_extended = true; - } else if (HAS_ARG(instruction.opcode)) { - if (bytecode.end() - it < 3) { + while (it[0] == EXTENDED_ARG) { + instruction.argument = instruction.argument << 8 | it[1]; + it += 2; + instruction.size += 2; + if (bytecode.end() - it < 2) { LOG(ERROR) << "Buffer underflow"; return kInvalidInstruction; } - - instruction.argument = ReadPythonBytecodeUInt16(it_arg); } + instruction.opcode = it[0]; + instruction.argument = instruction.argument << 8 | it[1]; + instruction.size += 2; + return instruction; } - // Writes instruction to the specified destination. The caller is responsible // to make sure the target vector has enough space. Returns size of an // instruction. 
-static int WriteInstruction( - std::vector::iterator it, - const PythonInstruction& instruction) { - if (instruction.is_extended) { - it[0] = EXTENDED_ARG; - WritePythonBytecodeUInt16(it + 1, instruction.argument >> 16); - it[3] = instruction.opcode; - WritePythonBytecodeUInt16( - it + 4, - static_cast(instruction.argument)); - return 6; - } else { - it[0] = instruction.opcode; - - if (HAS_ARG(instruction.opcode)) { - DCHECK_LE(instruction.argument, 0xFFFFU); - WritePythonBytecodeUInt16( - it + 1, - static_cast(instruction.argument)); - return 3; - } - - return 1; +static int WriteInstruction(std::vector::iterator it, + const PythonInstruction& instruction) { + uint32_t arg = instruction.argument; + int size_written = 0; + // Start writing backwards from the real instruction, followed by any + // EXTENDED_ARG instructions if needed. + for (int i = instruction.size - 2; i >= 0; i -= 2) { + it[i] = size_written == 0 ? instruction.opcode : EXTENDED_ARG; + it[i + 1] = static_cast(arg); + arg = arg >> 8; + size_written += 2; } + return size_written; } - // Write set of instructions to the specified destination. static void WriteInstructions( - std::vector::iterator it, + std::vector::iterator it, const std::vector& instructions) { for (auto it_instruction = instructions.begin(); it_instruction != instructions.end(); ++it_instruction) { const int instruction_size = WriteInstruction(it, *it_instruction); - DCHECK_EQ(instruction_size, GetInstructionSize(*it_instruction)); + DCHECK_EQ(instruction_size, it_instruction->size); it += instruction_size; } } - // Returns set of instructions to invoke a method with no arguments. The // method is assumed to be defined in the specified item of a constants tuple. 
static std::vector BuildMethodCall(int const_index) { @@ -255,14 +238,12 @@ static std::vector BuildMethodCall(int const_index) { return instructions; } - -BytecodeManipulator::BytecodeManipulator( - std::vector bytecode, - const bool has_lnotab, - std::vector lnotab) - : has_lnotab_(has_lnotab) { +BytecodeManipulator::BytecodeManipulator(std::vector bytecode, + const bool has_linedata, + std::vector linedata) + : has_linedata_(has_linedata) { data_.bytecode = std::move(bytecode); - data_.lnotab = std::move(lnotab); + data_.linedata = std::move(linedata); strategy_ = STRATEGY_INSERT; // Default strategy. for (auto it = data_.bytecode.begin(); it < data_.bytecode.end(); ) { @@ -272,16 +253,15 @@ BytecodeManipulator::BytecodeManipulator( break; } - if (instruction.opcode == YIELD_VALUE) { + if (GetOpcodeType(instruction.opcode) == YIELD_OPCODE) { strategy_ = STRATEGY_APPEND; break; } - it += GetInstructionSize(instruction); + it += instruction.size; } } - bool BytecodeManipulator::InjectMethodCall( int offset, int callable_const_index) { @@ -308,91 +288,237 @@ bool BytecodeManipulator::InjectMethodCall( } +// Represents a branch instruction in the original bytecode that may need to +// have its offsets fixed and/or upgraded to use EXTENDED_ARG. +struct UpdatedInstruction { + PythonInstruction instruction; + int original_size; + int current_offset; +}; + + +// Represents space that needs to be reserved for an insertion operation. +struct Insertion { + int size; + int current_offset; +}; + +// Max number of outer loop iterations to do before failing in +// InsertAndUpdateBranchInstructions. +static const int kMaxInsertionIterations = 10; + +#if PY_VERSION_HEX < 0x030A0000 +// Updates the line number table for an insertion in the bytecode. 
+// Example for inserting 2 bytes at offset 2: +// lnotab: [{2, 1}, {4, 1}] // {offset_delta, line_delta} +// updated: [{2, 1}, {6, 1}] +static void InsertAndUpdateLineData(int offset, int size, + std::vector* lnotab) { + int current_offset = 0; + for (auto it = lnotab->begin(); it != lnotab->end(); it += 2) { + current_offset += it[0]; + + if (current_offset > offset) { + int remaining_size = it[0] + size; + int remaining_lines = it[1]; + it = lnotab->erase(it, it + 2); + while (remaining_size > 0xFF) { + it = lnotab->insert(it, 0xFF) + 1; + it = lnotab->insert(it, 0) + 1; + remaining_size -= 0xFF; + } + it = lnotab->insert(it, remaining_size) + 1; + it = lnotab->insert(it, remaining_lines) + 1; + return; + } + } +} +#else +// Updates the line number table for an insertion in the bytecode. +// Example for inserting 2 bytes at offset 2: +// linetable: [{2, 1}, {4, 1}] // {address_end_delta, line_delta} +// updated: [{2, 1}, {6, 1}] +// +// For more information on the linetable format in Python 3.10, see: +// https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt +static void InsertAndUpdateLineData(int offset, int size, + std::vector* linetable) { + int current_offset = 0; + for (auto it = linetable->begin(); it != linetable->end(); it += 2) { + current_offset += it[0]; + + if (current_offset > offset) { + int remaining_size = it[0] + size; + int remaining_lines = it[1]; + it = linetable->erase(it, it + 2); + while (remaining_size > 0xFE) { // Max address delta is listed as 254. + it = linetable->insert(it, 0xFE) + 1; + it = linetable->insert(it, 0) + 1; + remaining_size -= 0xFE; + } + it = linetable->insert(it, remaining_size) + 1; + it = linetable->insert(it, remaining_lines) + 1; + return; + } + } +} +#endif + +// Reserves space for instructions to be inserted into the bytecode, and +// calculates the new offsets and arguments of branch instructions. 
+// Returns true if the calculation was successful, and false if too many +// iterations were needed. +// +// When inserting some space for the method call bytecode, branch instructions +// may need to have their offsets updated. Some cases might require branch +// instructions to be 'upgraded' to use EXTENDED_ARG if the new argument crosses +// the argument value limit for its current size.. This in turn will require +// another insertion and possibly further updates. +// +// It won't be manageable to update the bytecode in place in such cases, as when +// performing an insertion we might need to perform more insertions and quickly +// lose our place. +// +// Instead, we perform process insertion operations one at a time, starting from +// the original argument. While processing an operation, if an instruction needs +// to be upgraded to use EXTENDED_ARG, then another insertion operation is +// pushed on the stack to be processed later. +// +// Example: +// Suppose we need to reserve space for 6 bytes at offset 40. We have a +// JUMP_ABSOLUTE 250 instruction at offset 0, and a JUMP_FORWARD 2 instruction +// at offset 40. +// insertions: [{6, 40}] +// instructions: [{JUMP_ABSOLUTE 250, 0}, {JUMP_FORWARD 2, 40}] +// +// The JUMP_ABSOLUTE argument needs to be moved forward to 256, since the +// insertion occurs before the target. This requires an EXTENDED_ARG, so another +// insertion operation with size=2 at offset=0 is pushed. +// The JUMP_FORWARD instruction will be after the space reserved, so we need to +// update its current offset to now be 46. The argument does not need to be +// changed, as the insertion is not between its offset and target. +// insertions: [{2, 0}] +// instructions: [{JUMP_ABSOLUTE 256, 0}, {JUMP_FORWARD 2, 46}] +// +// For the next insertion, The JUMP_ABSOLUTE instruction's offset does not +// change, since it has the same offset as the insertion, signaling that the +// insertion is for the instruction itself. 
The argument gets updated to 258 to +// account for the additional space. The JUMP_FORWARD instruction's offset needs +// to be updated, but not its argument, for the same reason as before. +// insertions: [] +// instructions: [{JUMP_ABSOLUTE 258, 0}, {JUMP_FORWARD 2, 48}] +// +// There are no more insertions so we are done. +static bool InsertAndUpdateBranchInstructions( + Insertion insertion, std::vector& instructions) { + std::vector insertions { insertion }; + + int iterations = 0; + while (insertions.size() && iterations < kMaxInsertionIterations) { + insertion = insertions.back(); + insertions.pop_back(); + + // Update the offsets of all insertions after. + for (auto it = insertions.begin(); it < insertions.end(); it++) { + if (it->current_offset >= insertion.current_offset) { + it->current_offset += insertion.size; + } + } + + // Update the offsets and arguments of the branches. + for (auto it = instructions.begin(); + it < instructions.end(); it++) { + PythonInstruction instruction = it->instruction; + int32_t arg = static_cast(instruction.argument); + bool need_to_update = false; + PythonOpcodeType opcode_type = GetOpcodeType(instruction.opcode); + if (opcode_type == BRANCH_DELTA_OPCODE) { + // For relative branches, the argument needs to be updated if the + // insertion is between the instruction and the target. + // The Python compiler sometimes prematurely adds EXTENDED_ARG with an + // argument of 0 even when it is not required. This needs to be taken + // into account when calculating the target of a branch instruction. 
+ int inst_size = std::max(instruction.size, it->original_size); +#if PY_VERSION_HEX < 0x030A0000 + int32_t target = it->current_offset + inst_size + arg; +#else + int32_t target = it->current_offset + inst_size + arg * 2; +#endif + need_to_update = it->current_offset < insertion.current_offset && + insertion.current_offset < target; + } else if (opcode_type == BRANCH_ABSOLUTE_OPCODE) { + // For absolute branches, the argument needs to be updated if the + // insertion before the target. +#if PY_VERSION_HEX < 0x030A0000 + need_to_update = insertion.current_offset < arg; +#else + need_to_update = insertion.current_offset < arg * 2; +#endif + } + + // If we are inserting the original method call instructions, we want to + // update the current_offset of any instructions at or after. If we are + // doing an EXTENDED_ARG insertion, we don't want to update the offset of + // instructions right at the offset, because that is the original + // instruction that the EXTENDED_ARG is for. + int offset_diff = it->current_offset - insertion.current_offset; + if ((iterations == 0 && offset_diff >= 0) || (offset_diff > 0)) { + it->current_offset += insertion.size; + } + + if (need_to_update) { +#if PY_VERSION_HEX < 0x030A0000 + int delta = insertion.size; +#else + // Changed in version 3.10: The argument of jump, exception handling + // and loop instructions is now the instruction offset rather than the + // byte offset. 
+ int delta = insertion.size / 2; +#endif + PythonInstruction new_instruction = + PythonInstructionArg(instruction.opcode, arg + delta); + int size_diff = new_instruction.size - instruction.size; + if (size_diff > 0) { + insertions.push_back(Insertion { size_diff, it->current_offset }); + } + it->instruction = new_instruction; + } + } + iterations++; + } + + return insertions.size() == 0; +} + + bool BytecodeManipulator::InsertMethodCall( BytecodeManipulator::Data* data, int offset, int const_index) const { - const std::vector method_call_instructions = - BuildMethodCall(const_index); - int size = GetInstructionsSize(method_call_instructions); - + std::vector updated_instructions; bool offset_valid = false; - for (auto it = data->bytecode.begin(); it < data->bytecode.end(); ) { - const int current_offset = it - data->bytecode.begin(); + + // Gather all branch instructions. + for (auto it = data->bytecode.begin(); it < data->bytecode.end();) { + int current_offset = it - data->bytecode.begin(); if (current_offset == offset) { DCHECK(!offset_valid) << "Each offset should be visited only once"; offset_valid = true; } - int current_fixed_offset = current_offset; - if (current_fixed_offset >= offset) { - current_fixed_offset += size; - } - PythonInstruction instruction = ReadInstruction(data->bytecode, it); if (instruction.opcode == kInvalidInstruction.opcode) { return false; } - const int instruction_size = GetInstructionSize(instruction); - - // Fix targets in branch instructions. - switch (instruction.opcode) { - // Delta target argument. - case FOR_ITER: - case JUMP_FORWARD: - case SETUP_LOOP: - case SETUP_EXCEPT: - case SETUP_FINALLY: - case SETUP_WITH: { - int32 delta = instruction.is_extended - ? 
static_cast(instruction.argument) - : static_cast(instruction.argument); - - int32 target = current_offset + instruction_size + delta; - if (target > offset) { - target += size; - } - - int32 fixed_delta = target - current_fixed_offset - instruction_size; - if (delta != fixed_delta) { - if (instruction.is_extended) { - instruction.argument = static_cast(fixed_delta); - } else { - if (static_cast(delta) != delta) { - LOG(ERROR) << "Upgrading instruction to extended not supported"; - return false; - } - - instruction.argument = static_cast(fixed_delta); - } - - WriteInstruction(it, instruction); - } - - break; - } - - // Absolute target argument. - case JUMP_IF_FALSE_OR_POP: - case JUMP_IF_TRUE_OR_POP: - case JUMP_ABSOLUTE: - case POP_JUMP_IF_FALSE: - case POP_JUMP_IF_TRUE: - case CONTINUE_LOOP: - if (static_cast(instruction.argument) > offset) { - instruction.argument += size; - if (!instruction.is_extended && (instruction.argument > 0xFFFF)) { - LOG(ERROR) << "Upgrading instruction to extended not supported"; - return false; - } - - WriteInstruction(it, instruction); - } - break; + PythonOpcodeType opcode_type = GetOpcodeType(instruction.opcode); + if (opcode_type == BRANCH_DELTA_OPCODE || + opcode_type == BRANCH_ABSOLUTE_OPCODE) { + updated_instructions.push_back( + UpdatedInstruction { instruction, instruction.size, current_offset }); } - it += instruction_size; + it += instruction.size; } if (!offset_valid) { @@ -400,28 +526,42 @@ bool BytecodeManipulator::InsertMethodCall( return false; } - // Insert the bytecode to invoke the callable. - data->bytecode.insert(data->bytecode.begin() + offset, size, STOP_CODE); - WriteInstructions(data->bytecode.begin() + offset, method_call_instructions); + // Calculate new branch instructions. 
+ const std::vector method_call_instructions = + BuildMethodCall(const_index); + int method_call_size = GetInstructionsSize(method_call_instructions); + if (!InsertAndUpdateBranchInstructions({ method_call_size, offset }, + updated_instructions)) { + LOG(ERROR) << "Too many instruction argument upgrades required"; + return false; + } - // Insert a new entry into line table to account for the new bytecode. - if (has_lnotab_) { - int current_offset = 0; - for (auto it = data->lnotab.begin(); it != data->lnotab.end(); it += 2) { - current_offset += it[0]; - - if (current_offset >= offset) { - int remaining_size = size; - while (remaining_size > 0) { - const int current_size = std::min(remaining_size, 0xFF); - it = data->lnotab.insert(it, static_cast(current_size)) + 1; - it = data->lnotab.insert(it, 0) + 1; - remaining_size -= current_size; - } + // Insert the method call. + data->bytecode.insert(data->bytecode.begin() + offset, method_call_size, NOP); + WriteInstructions(data->bytecode.begin() + offset, method_call_instructions); + if (has_linedata_) { + InsertAndUpdateLineData(offset, method_call_size, &data->linedata); + } - break; + // Write new branch instructions. + // We can use current_offset directly since all insertions before would have + // been done by the time we reach the current instruction. + for (auto it = updated_instructions.begin(); + it < updated_instructions.end(); it++) { + int size_diff = it->instruction.size - it->original_size; + int offset = it->current_offset; + if (size_diff > 0) { + data->bytecode.insert(data->bytecode.begin() + offset, size_diff, NOP); + if (has_linedata_) { + InsertAndUpdateLineData(it->current_offset, size_diff, &data->linedata); } + } else if (size_diff < 0) { + // The Python compiler sometimes prematurely adds EXTENDED_ARG with an + // argument of 0 even when it is not required. Just leave it there, but + // start writing the instruction after them. 
+ offset -= size_diff; } + WriteInstruction(data->bytecode.begin() + offset, it->instruction); } return true; @@ -436,17 +576,13 @@ bool BytecodeManipulator::AppendMethodCall( BytecodeManipulator::Data* data, int offset, int const_index) const { - PythonInstruction trampoline; - trampoline.opcode = JUMP_ABSOLUTE; - trampoline.is_extended = false; - trampoline.argument = data->bytecode.size(); - - const int trampoline_size = GetInstructionSize(trampoline); + PythonInstruction trampoline = + PythonInstructionArg(JUMP_ABSOLUTE, data->bytecode.size()); std::vector relocated_instructions; int relocated_size = 0; for (auto it = data->bytecode.begin() + offset; - relocated_size < trampoline_size; ) { + relocated_size < trampoline.size; ) { if (it >= data->bytecode.end()) { LOG(ERROR) << "Not enough instructions"; return false; @@ -464,7 +600,7 @@ bool BytecodeManipulator::AppendMethodCall( // block. Unfortunately not all instructions can be moved: // 1. Instructions with relative offset can't be moved forward, because // the offset can't be negative. - // TODO(vlif): FORWARD_JUMP can be replaced with ABSOLUTE_JUMP. + // TODO: FORWARD_JUMP can be replaced with ABSOLUTE_JUMP. // 2. YIELD_VALUE can't be moved because generator object keeps the frame // object in between "yield" calls. If the breakpoint is added or // removed, subsequent calls into the generator will jump into invalid @@ -476,8 +612,8 @@ bool BytecodeManipulator::AppendMethodCall( } relocated_instructions.push_back(instruction); - relocated_size += GetInstructionSize(instruction); - it += GetInstructionSize(instruction); + relocated_size += instruction.size; + it += instruction.size; } for (auto it = data->bytecode.begin(); it < data->bytecode.end(); ) { @@ -501,7 +637,7 @@ bool BytecodeManipulator::AppendMethodCall( // Suppose we insert breakpoint into offset 1. The new bytecode will be: // 0 LOAD_CONST 6 // 1 JUMP_ABSOLUTE 100 - // 4 STOP_CODE + // 4 NOP // 5 ... // ... 
// 100 NOP # First relocated instruction. @@ -522,7 +658,7 @@ bool BytecodeManipulator::AppendMethodCall( } } - it += GetInstructionSize(instruction); + it += instruction.size; } std::vector appendix = BuildMethodCall(const_index); @@ -542,14 +678,12 @@ bool BytecodeManipulator::AppendMethodCall( // Insert jump to trampoline. WriteInstruction(data->bytecode.begin() + offset, trampoline); std::fill( - data->bytecode.begin() + offset + trampoline_size, + data->bytecode.begin() + offset + trampoline.size, data->bytecode.begin() + offset + relocated_size, - STOP_CODE); + NOP); return true; } } // namespace cdbg } // namespace devtools - - diff --git a/src/googleclouddebugger/bytecode_manipulator.h b/src/googleclouddebugger/bytecode_manipulator.h index 7d88a15..31a5e46 100644 --- a/src/googleclouddebugger/bytecode_manipulator.h +++ b/src/googleclouddebugger/bytecode_manipulator.h @@ -17,7 +17,9 @@ #ifndef DEVTOOLS_CDBG_DEBUGLETS_PYTHON_BYTECODE_MANIPULATOR_H_ #define DEVTOOLS_CDBG_DEBUGLETS_PYTHON_BYTECODE_MANIPULATOR_H_ +#include #include + #include "common.h" namespace devtools { @@ -50,7 +52,7 @@ namespace cdbg { // For example consider this Python code: // def test(): // yield 'hello' -// It's bytecode without any breakpoints is: +// Its bytecode without any breakpoints is: // 0 LOAD_CONST 1 ('hello') // 3 YIELD_VALUE // 4 POP_TOP @@ -61,27 +63,25 @@ namespace cdbg { // 3 YIELD_VALUE // 4 POP_TOP // 5 LOAD_CONST 0 (None) -// 9 LOAD_CONST 2 (cdbg_native._Callback) // 8 RETURN_VALUE +// 9 LOAD_CONST 2 (cdbg_native._Callback) // 12 CALL_FUNCTION 0 // 15 POP_TOP // 16 LOAD_CONST 1 ('hello') // 19 JUMP_ABSOLUTE 3 class BytecodeManipulator { public: - BytecodeManipulator( - std::vector bytecode, - const bool has_lnotab, - std::vector lnotab); + BytecodeManipulator(std::vector bytecode, const bool has_linedata, + std::vector linedata); // Gets the transformed method bytecode. 
- const std::vector& bytecode() const { return data_.bytecode; } + const std::vector& bytecode() const { return data_.bytecode; } // Returns true if this class was initialized with line numbers table. - bool has_lnotab() const { return has_lnotab_; } + bool has_linedata() const { return has_linedata_; } // Gets the method line numbers table or empty vector if not available. - const std::vector& lnotab() const { return data_.lnotab; } + const std::vector& linedata() const { return data_.linedata; } // Rewrites the method bytecode to invoke callable at the specified offset. // Return false if the method call could not be inserted. The bytecode @@ -107,10 +107,10 @@ class BytecodeManipulator { struct Data { // Bytecode of a transformed method. - std::vector bytecode; + std::vector bytecode; - // Method line numbers table or empty vector if "has_lnotab_" is false. - std::vector lnotab; + // Method line numbers table or empty vector if "has_linedata_" is false. + std::vector linedata; }; // Insert space into the bytecode. This space is later used to add new @@ -130,7 +130,7 @@ class BytecodeManipulator { Data data_; // True if the method has line number table. - const bool has_lnotab_; + const bool has_linedata_; // Algorithm to insert breakpoint callback into method bytecode. Strategy strategy_; diff --git a/src/googleclouddebugger/capture_collector.py b/src/googleclouddebugger/capture_collector.py deleted file mode 100644 index 6354890..0000000 --- a/src/googleclouddebugger/capture_collector.py +++ /dev/null @@ -1,632 +0,0 @@ -# Copyright 2015 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS-IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Captures application state on a breakpoint hit.""" - -# TODO(vlif): rename this file to collector.py. - -import copy -import datetime -import inspect -import os -import re -import sys -import types - -import cdbg_native as native - -# Externally defined functions to actually log a message. If these variables -# are not initialized, the log action for breakpoints is invalid. -log_info_message = None -log_warning_message = None -log_error_message = None - -_PRIMITIVE_TYPES = (int, long, float, complex, str, unicode, bool, - types.NoneType) -_DATE_TYPES = (datetime.date, datetime.time, datetime.timedelta) -_VECTOR_TYPES = (types.TupleType, types.ListType, types.SliceType, set) - -# TODO(vlif): move to messages.py module. -EMPTY_DICTIONARY = 'Empty dictionary' -EMPTY_COLLECTION = 'Empty collection' -OBJECT_HAS_NO_FIELDS = 'Object has no fields' -LOG_ACTION_NOT_SUPPORTED = 'Log action on a breakpoint not supported' -INVALID_EXPRESSION_INDEX = '' - - -class CaptureCollector(object): - """Captures application state snapshot. - - Captures call stack, local variables and referenced objects. Then formats the - result to be sent back to the user. - - The performance of this class is important. Once the breakpoint hits, the - completion of the user request will be delayed until the collection is over. - It might make sense to implement this logic in C++. - - Attributes: - breakpoint: breakpoint definition augmented with captured call stack, - local variables, arguments and referenced objects. - """ - - # Additional type-specific printers. 
Each pretty printer is a callable - # that returns None if it doesn't recognize the object or returns a tuple - # with iterable enumerating object fields (name-value tuple) and object type - # string. - pretty_printers = [] - - def __init__(self, definition): - """Class constructor. - - Args: - definition: breakpoint definition that this class will augment with - captured data. - """ - self.breakpoint = copy.deepcopy(definition) - - self.breakpoint['stackFrames'] = [] - self.breakpoint['evaluatedExpressions'] = [] - self.breakpoint['variableTable'] = [{ - 'status': { - 'isError': True, - 'refersTo': 'VARIABLE_VALUE', - 'description': {'format': 'Buffer full'}}}] - - # Shortcut to variables table in the breakpoint message. - self._var_table = self.breakpoint['variableTable'] - - # Maps object ID to its index in variables table. - self._var_table_index = {} - - # Total size of data collected so far. Limited by max_size. - self._total_size = 0 - - # Maximum number of stack frame to capture. The limit is aimed to reduce - # the overall collection time. - self.max_frames = 20 - - # Only collect locals and arguments on the few top frames. For the rest of - # the frames we only collect the source location. - self.max_expand_frames = 5 - - # Maximum amount of data to capture. The application will usually have a - # lot of objects and we need to stop somewhere to keep the delay - # reasonable. - # This constant only counts the collected payload. Overhead due to key - # names is not counted. - self.max_size = 32768 # 32 KB - - # Maximum number of character to allow for a single value. Longer strings - # are truncated. - self.max_value_len = 256 - - # Maximum number of items in a list to capture. - self.max_list_items = 25 - - # Maximum depth of dictionaries to capture. - self.max_depth = 5 - - def Collect(self, top_frame): - """Collects call stack, local variables and objects. - - Starts collection from the specified frame. 
We don't start from the top - frame to exclude the frames due to debugger. Updates the content of - self.breakpoint. - - Args: - top_frame: top frame to start data collection. - """ - # Evaluate call stack. - frame = top_frame - breakpoint_frames = self.breakpoint['stackFrames'] - while frame and (len(breakpoint_frames) < self.max_frames): - code = frame.f_code - if len(breakpoint_frames) < self.max_expand_frames: - frame_arguments, frame_locals = self.CaptureFrameLocals(frame) - else: - frame_arguments = [] - frame_locals = [] - - breakpoint_frames.append({ - 'function': code.co_name, - 'location': { - 'path': CaptureCollector._NormalizePath(code.co_filename), - 'line': frame.f_lineno}, - 'arguments': frame_arguments, - 'locals': frame_locals}) - frame = frame.f_back - - # Evaluate watched expressions. - if 'expressions' in self.breakpoint: - self.breakpoint['evaluatedExpressions'] = [ - self._CaptureExpression(top_frame, expression) for expression - in self.breakpoint['expressions']] - - # Explore variables table in BFS fashion. The variables table will grow - # inside CaptureVariable as we encounter new references. - i = 1 - while (i < len(self._var_table)) and (self._total_size < self.max_size): - self._var_table[i] = self.CaptureVariable(self._var_table[i], 0, False) - i += 1 - - # Trim variables table and change make all references to variables that - # didn't make it point to var_index of 0 ("buffer full") - self.TrimVariableTable(i) - - def CaptureFrameLocals(self, frame): - """Captures local variables and arguments of the specified frame. - - Args: - frame: frame to capture locals and arguments. - - Returns: - (arguments, locals) tuple. - """ - # Capture all local variables (including method arguments). - variables = {n: self.CaptureNamedVariable(n, v) - for n, v in frame.f_locals.viewitems()} - - # Split between locals and arguments (keeping arguments in the right order). 
- nargs = frame.f_code.co_argcount - if frame.f_code.co_flags & inspect.CO_VARARGS: nargs += 1 - if frame.f_code.co_flags & inspect.CO_VARKEYWORDS: nargs += 1 - - frame_arguments = [] - for argname in frame.f_code.co_varnames[:nargs]: - if argname in variables: frame_arguments.append(variables.pop(argname)) - - return (frame_arguments, list(variables.viewvalues())) - - def CaptureNamedVariable(self, name, value, depth=1): - """Appends name to the product of CaptureVariable. - - Args: - name: name of the variable. - value: data to capture - depth: nested depth of dictionaries and vectors so far. - - Returns: - Formatted captured data as per Variable proto with name. - """ - if not hasattr(name, '__dict__'): - name = str(name) - else: # TODO(vlif): call str(name) with immutability verifier here. - name = str(id(name)) - self._total_size += len(name) - - v = self.CaptureVariable(value, depth) - v['name'] = name - return v - - def CaptureVariablesList(self, items, depth, empty_message): - """Captures list of named items. - - Args: - items: iterable of (name, value) tuples. - depth: nested depth of dictionaries and vectors for items. - empty_message: info status message to set if items is empty. - - Returns: - List of formatted variable objects. - """ - v = [] - for name, value in items: - if (self._total_size >= self.max_size) or (len(v) >= self.max_list_items): - v.append({ - 'status': { - 'refers_to': 'VARIABLE_VALUE', - 'description': { - 'format': 'Only first $0 items were captured', - 'parameters': [str(len(v))]}}}) - break - v.append(self.CaptureNamedVariable(name, value, depth)) - - if not v: - return [{'status': { - 'is_error': False, - 'refers_to': 'VARIABLE_NAME', - 'description': {'format': empty_message}}}] - - return v - - def CaptureVariable(self, value, depth=1, can_enqueue=True): - """Captures a single nameless object into Variable message. - - TODO(vlif): safely evaluate iterable types. 
- TODO(vlif): safely call str(value) - - Args: - value: data to capture - depth: nested depth of dictionaries and vectors so far. - can_enqueue: allows referencing the object in variables table. - - Returns: - Formatted captured data as per Variable proto. - """ - if depth == self.max_depth: - return {'varTableIndex': 0} # Buffer full. - - if value is None: - self._total_size += 4 - return {'value': 'None'} - - if isinstance(value, _PRIMITIVE_TYPES): - r = _TrimString(repr(value), # Primitive type, always immutable. - self.max_value_len) - self._total_size += len(r) - return {'value': r, 'type': type(value).__name__} - - if isinstance(value, _DATE_TYPES): - r = str(value) # Safe to call str(). - self._total_size += len(r) - return {'value': r, 'type': 'datetime.'+ type(value).__name__} - - if isinstance(value, dict): - return {'members': self.CaptureVariablesList(value.iteritems(), - depth + 1, - EMPTY_DICTIONARY), - 'type': 'dict'} - - if isinstance(value, _VECTOR_TYPES): - fields = self.CaptureVariablesList( - (('[%d]' % i, x) for i, x in enumerate(value)), - depth + 1, - EMPTY_COLLECTION) - return {'members': fields, 'type': type(value).__name__} - - if isinstance(value, types.FunctionType): - self._total_size += len(value.func_name) - # TODO(vlif): set value to func_name and type to 'function' - return {'value': 'function ' + value.func_name} - - if can_enqueue: - index = self._var_table_index.get(id(value)) - if index is None: - index = len(self._var_table) - self._var_table_index[id(value)] = index - self._var_table.append(value) - self._total_size += 4 # number of characters to accomodate a number. 
- return {'varTableIndex': index} - - for pretty_printer in CaptureCollector.pretty_printers: - pretty_value = pretty_printer(value) - if not pretty_value: - continue - - fields, object_type = pretty_value - return {'members': self.CaptureVariablesList(fields, - depth + 1, - OBJECT_HAS_NO_FIELDS), - 'type': object_type} - - if not hasattr(value, '__dict__'): - # TODO(vlif): keep "value" empty and populate the "type" field instead. - r = str(type(value)) - self._total_size += len(r) - return {'value': r} - - if value.__dict__: - v = self.CaptureVariable(value.__dict__, depth + 1) - else: - v = {'members': - [ - {'status': { - 'is_error': False, - 'refers_to': 'VARIABLE_NAME', - 'description': {'format': OBJECT_HAS_NO_FIELDS}}} - ]} - - object_type = type(value) - if hasattr(object_type, '__name__'): - type_string = getattr(object_type, '__module__', '') - if type_string: - type_string += '.' - type_string += object_type.__name__ - v['type'] = type_string - - return v - - def _CaptureExpression(self, frame, expression): - """Evalutes the expression and captures it into a Variable object. - - Args: - frame: evaluation context. - expression: watched expression to compile and evaluate. - - Returns: - Variable object (which will have error status if the expression fails - to evaluate). - """ - rc, value = _EvaluateExpression(frame, expression) - if not rc: - return {'name': expression, 'status': value} - - return self.CaptureNamedVariable(expression, value) - - def TrimVariableTable(self, new_size): - """Trims the variable table in the formatted breakpoint message. - - Removes trailing entries in variables table. Then scans the entire - breakpoint message and replaces references to the trimmed variables to - point to var_index of 0 ("buffer full"). - - Args: - new_size: desired size of variables table. 
- """ - - def ProcessBufferFull(variables): - for variable in variables: - var_index = variable.get('varTableIndex') - if var_index is not None and (var_index >= new_size): - variable['varTableIndex'] = 0 # Buffer full. - members = variable.get('members') - if members is not None: - ProcessBufferFull(members) - - del self._var_table[new_size:] - for stack_frame in self.breakpoint['stackFrames']: - ProcessBufferFull(stack_frame['arguments']) - ProcessBufferFull(stack_frame['locals']) - ProcessBufferFull(self._var_table) - ProcessBufferFull(self.breakpoint['evaluatedExpressions']) - - @staticmethod - def _NormalizePath(path): - """Converts an absolute path to a relative one. - - Python keeps almost all paths absolute. This is not what we actually - want to return. This loops through system paths (directories in which - Python will load modules). If "path" is relative to one of them, the - directory prefix is removed. - - Args: - path: absolute path to normalize (relative paths will not be altered) - - Returns: - Relative path if "path" is within one of the sys.path directories or - the input otherwise. - """ - path = os.path.normpath(path) - - for sys_path in sys.path: - if not sys_path: - continue - - # Append '/' at the end of the path if it's not there already. - sys_path = os.path.join(sys_path, '') - - if path.startswith(sys_path): - return path[len(sys_path):] - - return path - - -class LogCollector(object): - """Captures minimal application snapshot and logs it to application log. - - This is similar to CaptureCollector, but we don't need to capture local - variables, arguments and the objects tree. All we need to do is to format a - log message. We still need to evaluate watched expressions. - - The actual log functions are defined globally outside of this module. - """ - - def __init__(self, definition): - """Class constructor. - - Args: - definition: breakpoint definition indicating log level, message, etc. 
- """ - self._definition = definition - - # Maximum number of character to allow for a single value. Longer strings - # are truncated. - self.max_value_len = 256 - - # Maximum number of items in a list to capture. - self.max_list_items = 10 - - # Select log function. - level = self._definition.get('logLevel') - if not level or level == 'INFO': - self._log_message = log_info_message - elif level == 'WARNING': - self._log_message = log_warning_message - elif level == 'ERROR': - self._log_message = log_error_message - else: - self._log_message = None - - def Log(self, frame): - """Captures the minimal application states, formats it and logs the message. - - Args: - frame: Python stack frame of breakpoint hit. - - Returns: - None on success or status message on error. - """ - # Return error if log methods were not configured globally. - if not self._log_message: - return {'isError': True, - 'description': {'format': LOG_ACTION_NOT_SUPPORTED}} - - # Evaluate watched expressions. - message = _FormatMessage( - self._definition.get('logMessageFormat', ''), - self._EvaluateExpressions(frame)) - - self._log_message(message) - return None - - def _EvaluateExpressions(self, frame): - """Evaluates watched expressions into a string form. - - If expression evaluation fails, the error message is used as evaluated - expression string. - - Args: - frame: Python stack frame of breakpoint hit. - - Returns: - Array of strings where each string corresponds to the breakpoint - expression with the same index. - """ - return [self._FormatExpression(frame, expression) for expression in - self._definition.get('expressions') or []] - - def _FormatExpression(self, frame, expression): - """Evaluates a single watched expression and formats it into a string form. - - If expression evaluation fails, returns error message string. - - Args: - frame: Python stack frame in which the expression is evaluated. - expression: string expression to evaluate. 
- - Returns: - Formatted expression value that can be used in the log message. - """ - rc, value = _EvaluateExpression(frame, expression) - if not rc: - message = _FormatMessage(value['description']['format'], - value['description'].get('parameters')) - return '<' + message + '>' - - return self._FormatValue(value) - - def _FormatValue(self, value, level=0): - """Pretty-prints an object for a logger. - - This function is very similar to the standard pprint. The main difference - is that it enforces limits to make sure we never produce an extremely long - string or take too much time. - - Args: - value: Python object to print. - level: current recursion level. - - Returns: - Formatted string. - """ - - def FormatDictItem(key_value): - """Formats single dictionary item.""" - key, value = key_value - return (self._FormatValue(key, level + 1) + - ': ' + - self._FormatValue(value, level + 1)) - - def LimitedEnumerate(items, formatter): - """Returns items in the specified enumerable enforcing threshold.""" - count = 0 - for item in items: - if count == self.max_list_items: - yield '...' - break - - yield formatter(item) - count += 1 - - def FormatList(items, formatter): - """Formats a list using a custom item formatter enforcing threshold.""" - return ', '.join(LimitedEnumerate(items, formatter)) - - if isinstance(value, _PRIMITIVE_TYPES): - return _TrimString(repr(value), # Primitive type, always immutable. 
- self.max_value_len) - - if isinstance(value, _DATE_TYPES): - return str(value) - - if level > 0: - return str(type(value)) - - if isinstance(value, dict): - return '{' + FormatList(value.iteritems(), FormatDictItem) + '}' - - if isinstance(value, _VECTOR_TYPES): - return FormatList(value, lambda item: self._FormatValue(item, level + 1)) - - if isinstance(value, types.FunctionType): - return 'function ' + value.func_name - - if hasattr(value, '__dict__') and value.__dict__: - return self._FormatValue(value.__dict__, level) - - return str(type(value)) - - -def _EvaluateExpression(frame, expression): - """Compiles and evaluates watched expression. - - Args: - frame: evaluation context. - expression: watched expression to compile and evaluate. - - Returns: - (False, status) on error or (True, value) on success. - """ - try: - code = compile(expression, '', 'eval') - except TypeError as e: # condition string contains null bytes. - return (False, { - 'isError': True, - 'refersTo': 'VARIABLE_NAME', - 'description': { - 'format': 'Invalid expression', - 'parameters': [str(e)]}}) - except SyntaxError as e: - return (False, { - 'isError': True, - 'refersTo': 'VARIABLE_NAME', - 'description': { - 'format': 'Expression could not be compiled: $0', - 'parameters': [e.msg]}}) - - try: - return (True, native.CallImmutable(frame, code)) - except BaseException as e: - return (False, { - 'isError': True, - 'refersTo': 'VARIABLE_VALUE', - 'description': { - 'format': 'Exception occurred: $0', - 'parameters': [e.message]}}) - - -def _FormatMessage(template, parameters): - """Formats the message. - - Args: - template: message template (e.g. 'a = $0, b = $1'). - parameters: substitution parameters for the format. - - Returns: - Formatted message with parameters embedded in template placeholders. 
- """ - def GetParameter(m): - try: - return parameters[int(m.group(0)[1:])] - except IndexError: - return INVALID_EXPRESSION_INDEX - - return re.sub(r'\$\d+', GetParameter, template) - - -def _TrimString(s, max_len): - """Trims the string if it exceeds max_len.""" - if len(s) <= max_len: - return s - return s[:max_len+1] + '...' diff --git a/src/googleclouddebugger/collector.py b/src/googleclouddebugger/collector.py new file mode 100644 index 0000000..82916ab --- /dev/null +++ b/src/googleclouddebugger/collector.py @@ -0,0 +1,983 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Captures application state on a breakpoint hit.""" + +import copy +import datetime +import inspect +import itertools +import logging +import os +import re +import sys +import time +import types + +from . import cdbg_native as native +from . import labels + +# Externally defined functions to actually log a message. If these variables +# are not initialized, the log action for breakpoints is invalid. +log_info_message = None +log_warning_message = None +log_error_message = None + +# Externally defined function to collect the request log id. +request_log_id_collector = None + +# Externally defined function to collect the end user id. +user_id_collector = lambda: (None, None) + +# Externally defined function to collect the end user id. 
+breakpoint_labels_collector = lambda: {} + +_PRIMITIVE_TYPES = (type(None), float, complex, bool, slice, bytearray, str, + bytes, int) +_DATE_TYPES = (datetime.date, datetime.time, datetime.timedelta) +_VECTOR_TYPES = (tuple, list, set) + +# TODO: move to messages.py module. +EMPTY_DICTIONARY = 'Empty dictionary' +EMPTY_COLLECTION = 'Empty collection' +OBJECT_HAS_NO_FIELDS = 'Object has no fields' +LOG_ACTION_NOT_SUPPORTED = 'Log action on a breakpoint not supported' +INVALID_EXPRESSION_INDEX = '' +DYNAMIC_LOG_OUT_OF_QUOTA = ( + 'LOGPOINT: Logpoint is paused due to high log rate until log ' + 'quota is restored') + + +def _ListTypeFormatString(value): + """Returns the appropriate format string for formatting a list object.""" + + if isinstance(value, tuple): + return '({0})' + if isinstance(value, set): + return '{{{0}}}' + return '[{0}]' + + +def NormalizePath(path): + """Removes any Python system path prefix from the given path. + + Python keeps almost all paths absolute. This is not what we actually + want to return. This loops through system paths (directories in which + Python will load modules). If "path" is relative to one of them, the + directory prefix is removed. + + Args: + path: absolute path to normalize (relative paths will not be altered) + + Returns: + Relative path if "path" is within one of the sys.path directories or + the input otherwise. + """ + path = os.path.normpath(path) + + for sys_path in sys.path: + if not sys_path: + continue + + # Append '/' at the end of the path if it's not there already. + sys_path = os.path.join(sys_path, '') + + if path.startswith(sys_path): + return path[len(sys_path):] + + return path + + +def DetermineType(value): + """Determines the type of val, returning a "full path" string. + + For example: + DetermineType(5) -> __builtin__.int + DetermineType(Foo()) -> com.google.bar.Foo + + Args: + value: Any value, the value is irrelevant as only the type metadata + is checked + + Returns: + Type path string. 
None if type cannot be determined. + """ + + object_type = type(value) + if not hasattr(object_type, '__name__'): + return None + + type_string = getattr(object_type, '__module__', '') + if type_string: + type_string += '.' + + type_string += object_type.__name__ + return type_string + + +class LineNoFilter(logging.Filter): + """Enables overriding the path and line number in a logging record. + + The "extra" parameter in logging cannot override existing fields in log + record, so we can't use it to directly set pathname and lineno. Instead, + we add this filter to the default logger, and it looks for "cdbg_pathname" + and "cdbg_lineno", moving them to the pathname and lineno fields accordingly. + """ + + def filter(self, record): + # This method gets invoked for user-generated logging, so verify that this + # particular invocation came from our logging code. + if record.pathname != inspect.currentframe().f_code.co_filename: + return True + pathname, lineno, func_name = GetLoggingLocation() + if pathname: + record.pathname = pathname + record.filename = os.path.basename(pathname) + record.lineno = lineno + record.funcName = func_name + return True + + +def GetLoggingLocation(): + """Search for and return the file and line number from the log collector. + + Returns: + (pathname, lineno, func_name) The full path, line number, and function name + for the logpoint location. 
+ """ + frame = inspect.currentframe() + this_file = frame.f_code.co_filename + frame = frame.f_back + while frame: + if this_file == frame.f_code.co_filename: + if 'cdbg_logging_location' in frame.f_locals: + ret = frame.f_locals['cdbg_logging_location'] + if len(ret) != 3: + return (None, None, None) + return ret + frame = frame.f_back + return (None, None, None) + + +def SetLogger(logger): + """Sets the logger object to use for all 'LOG' breakpoint actions.""" + global log_info_message + global log_warning_message + global log_error_message + log_info_message = logger.info + log_warning_message = logger.warning + log_error_message = logger.error + logger.addFilter(LineNoFilter()) + + +class _CaptureLimits(object): + """Limits for variable capture. + + Args: + max_value_len: Maximum number of character to allow for a single string + value. Longer strings are truncated. + max_list_items: Maximum number of items in a list to capture. + max_depth: Maximum depth of dictionaries to capture. + """ + + def __init__(self, max_value_len=256, max_list_items=25, max_depth=5): + self.max_value_len = max_value_len + self.max_list_items = max_list_items + self.max_depth = max_depth + + +class CaptureCollector(object): + """Captures application state snapshot. + + Captures call stack, local variables and referenced objects. Then formats the + result to be sent back to the user. + + The performance of this class is important. Once the breakpoint hits, the + completion of the user request will be delayed until the collection is over. + It might make sense to implement this logic in C++. + + Attributes: + breakpoint: breakpoint definition augmented with captured call stack, + local variables, arguments and referenced objects. + """ + + # Additional type-specific printers. Each pretty printer is a callable + # that returns None if it doesn't recognize the object or returns a tuple + # with iterable enumerating object fields (name-value tuple) and object type + # string. 
  pretty_printers = []

  def __init__(self, definition, data_visibility_policy):
    """Class constructor.

    Args:
      definition: breakpoint definition that this class will augment with
        captured data.
      data_visibility_policy: An object used to determine the visibility
        of a captured variable. May be None if no policy is available.
    """
    self.data_visibility_policy = data_visibility_policy

    # Work on a deep copy so the caller's definition is never mutated.
    self.breakpoint = copy.deepcopy(definition)

    self.breakpoint['stackFrames'] = []
    self.breakpoint['evaluatedExpressions'] = []
    # Index 0 of the variable table is reserved for the shared "buffer full"
    # status; trimmed references are redirected here.
    self.breakpoint['variableTable'] = [{
        'status': {
            'isError': True,
            'refersTo': 'VARIABLE_VALUE',
            'description': {
                'format': 'Buffer full. Use an expression to see more data'
            }
        }
    }]

    # Shortcut to variables table in the breakpoint message.
    self._var_table = self.breakpoint['variableTable']

    # Maps object ID to its index in variables table.
    self._var_table_index = {}

    # Total size of data collected so far. Limited by max_size.
    self._total_size = 0

    # Maximum number of stack frames to capture. The limit is aimed to reduce
    # the overall collection time.
    self.max_frames = 20

    # Only collect locals and arguments on the few top frames. For the rest of
    # the frames we only collect the source location.
    self.max_expand_frames = 5

    # Maximum amount of data to capture. The application will usually have a
    # lot of objects and we need to stop somewhere to keep the delay
    # reasonable.
    # This constant only counts the collected payload. Overhead due to key
    # names is not counted.
    self.max_size = 32768  # 32 KB

    self.default_capture_limits = _CaptureLimits()

    # When the user provides an expression, they've indicated that they're
    # interested in some specific data. Use higher per-object capture limits
    # for expressions. We don't want to globally increase capture limits,
    # because in the case where the user has not indicated a preference, we
    # don't want a single large object on the stack to use the entire max_size
    # quota and hide the rest of the data.
    self.expression_capture_limits = _CaptureLimits(
        max_value_len=32768, max_list_items=32768)

  def Collect(self, top_frame):
    """Collects call stack, local variables and objects.

    Starts collection from the specified frame. We don't start from the top
    frame to exclude the frames due to debugger. Updates the content of
    self.breakpoint.

    Args:
      top_frame: top frame to start data collection.
    """
    # Evaluate call stack.
    frame = top_frame
    top_line = self.breakpoint['location']['line']
    breakpoint_frames = self.breakpoint['stackFrames']
    try:
      # Evaluate watched expressions first, before frame capture can consume
      # the size quota.
      if 'expressions' in self.breakpoint:
        self.breakpoint['evaluatedExpressions'] = [
            self._CaptureExpression(top_frame, expression)
            for expression in self.breakpoint['expressions']
        ]

      while frame and (len(breakpoint_frames) < self.max_frames):
        # The top frame reports the breakpoint line rather than f_lineno.
        line = top_line if frame == top_frame else frame.f_lineno
        code = frame.f_code
        if len(breakpoint_frames) < self.max_expand_frames:
          frame_arguments, frame_locals = self.CaptureFrameLocals(frame)
        else:
          frame_arguments = []
          frame_locals = []

        breakpoint_frames.append({
            'function': _GetFrameCodeObjectName(frame),
            'location': {
                'path': NormalizePath(code.co_filename),
                'line': line
            },
            'arguments': frame_arguments,
            'locals': frame_locals
        })
        frame = frame.f_back

    except BaseException as e:  # pylint: disable=broad-except
      # The variable table will get serialized even though there was a failure.
      # The results can be useful for diagnosing the internal error.
      self.breakpoint['status'] = {
          'isError': True,
          'description': {
              'format': ('INTERNAL ERROR: Failed while capturing locals '
                         'of frame $0: $1'),
              'parameters': [str(len(breakpoint_frames)),
                             str(e)]
          }
      }

    # Number of entries in _var_table. Starts at 1 (index 0 is the 'buffer
    # full' status value).
    num_vars = 1

    # Explore variables table in BFS fashion. The variables table will grow
    # inside CaptureVariable as we encounter new references.
    while (num_vars < len(self._var_table)) and (self._total_size <
                                                 self.max_size):
      self._var_table[num_vars] = self.CaptureVariable(
          self._var_table[num_vars],
          0,
          self.default_capture_limits,
          can_enqueue=False)

      # Move on to the next entry in the variable table.
      num_vars += 1

    # Trim the variables table and make all references to trimmed variables
    # point to var_index of 0 ("buffer full").
    self.TrimVariableTable(num_vars)

    self._CaptureEnvironmentLabels()
    self._CaptureRequestLogId()
    self._CaptureUserId()

  def CaptureFrameLocals(self, frame):
    """Captures local variables and arguments of the specified frame.

    Args:
      frame: frame to capture locals and arguments.

    Returns:
      (arguments, locals) tuple.
    """
    # Capture all local variables (including method arguments).
    variables = {
        n: self.CaptureNamedVariable(n, v, 1, self.default_capture_limits)
        for n, v in frame.f_locals.items()
    }

    # Split between locals and arguments (keeping arguments in the right
    # order). *args and **kwargs count as extra arguments when present.
    nargs = frame.f_code.co_argcount
    if frame.f_code.co_flags & inspect.CO_VARARGS:
      nargs += 1
    if frame.f_code.co_flags & inspect.CO_VARKEYWORDS:
      nargs += 1

    frame_arguments = []
    for argname in frame.f_code.co_varnames[:nargs]:
      if argname in variables:
        frame_arguments.append(variables.pop(argname))

    return (frame_arguments, list(variables.values()))

  def CaptureNamedVariable(self, name, value, depth, limits):
    """Appends name to the product of CaptureVariable.

    Args:
      name: name of the variable.
      value: data to capture
      depth: nested depth of dictionaries and vectors so far.
      limits: Per-object limits for capturing variable data.

    Returns:
      Formatted captured data as per Variable proto with name.
    """
    # Objects with a __dict__ may have an untrusted __str__; use their id
    # instead of stringifying them.
    if not hasattr(name, '__dict__'):
      name = str(name)
    else:  # TODO: call str(name) with immutability verifier here.
      name = str(id(name))
    self._total_size += len(name)

    # A visibility veto takes precedence over the captured value.
    v = (
        self.CheckDataVisibility(value) or
        self.CaptureVariable(value, depth, limits))
    v['name'] = name
    return v

  def CheckDataVisibility(self, value):
    """Returns a status object if the given name is not visible.

    Args:
      value: The value to check. The actual value here is not important but
        the value's metadata (e.g. package and type) will be checked.

    Returns:
      None if the value is visible. A variable structure with an error status
      if the value should not be visible.
    """
    if not self.data_visibility_policy:
      return None

    visible, reason = self.data_visibility_policy.IsDataVisible(
        DetermineType(value))

    if visible:
      return None

    return {
        'status': {
            'isError': True,
            'refersTo': 'VARIABLE_NAME',
            'description': {
                'format': reason
            }
        }
    }

  def CaptureVariablesList(self, items, depth, empty_message, limits):
    """Captures list of named items.

    Args:
      items: iterable of (name, value) tuples.
      depth: nested depth of dictionaries and vectors for items.
      empty_message: info status message to set if items is empty.
      limits: Per-object limits for capturing variable data.

    Returns:
      List of formatted variable objects.
    """
    v = []
    for name, value in items:
      if (self._total_size >= self.max_size) or (len(v) >=
                                                 limits.max_list_items):
        v.append({
            'status': {
                'refersTo': 'VARIABLE_VALUE',
                'description': {
                    'format': ('Only first $0 items were captured.
Use in an ' + 'expression to see all items.'), + 'parameters': [str(len(v))] + } + } + }) + break + v.append(self.CaptureNamedVariable(name, value, depth, limits)) + + if not v: + return [{ + 'status': { + 'refersTo': 'VARIABLE_NAME', + 'description': { + 'format': empty_message + } + } + }] + + return v + + def CaptureVariable(self, value, depth, limits, can_enqueue=True): + """Try-Except wrapped version of CaptureVariableInternal.""" + try: + return self.CaptureVariableInternal(value, depth, limits, can_enqueue) + except BaseException as e: # pylint: disable=broad-except + return { + 'status': { + 'isError': True, + 'refersTo': 'VARIABLE_VALUE', + 'description': { + 'format': ('Failed to capture variable: $0'), + 'parameters': [str(e)] + } + } + } + + def CaptureVariableInternal(self, value, depth, limits, can_enqueue=True): + """Captures a single nameless object into Variable message. + + TODO: safely evaluate iterable types. + TODO: safely call str(value) + + Args: + value: data to capture + depth: nested depth of dictionaries and vectors so far. + limits: Per-object limits for capturing variable data. + can_enqueue: allows referencing the object in variables table. + + Returns: + Formatted captured data as per Variable proto. + """ + if depth == limits.max_depth: + return {'varTableIndex': 0} # Buffer full. + + if value is None: + self._total_size += 4 + return {'value': 'None'} + + if isinstance(value, _PRIMITIVE_TYPES): + r = _TrimString( + repr(value), # Primitive type, always immutable. + min(limits.max_value_len, self.max_size - self._total_size)) + self._total_size += len(r) + return {'value': r, 'type': type(value).__name__} + + if isinstance(value, _DATE_TYPES): + r = str(value) # Safe to call str(). + self._total_size += len(r) + return {'value': r, 'type': 'datetime.' + type(value).__name__} + + if isinstance(value, dict): + # Do not use iteritems() here. 
If GC happens during iteration (which it + # often can for dictionaries containing large variables), you will get a + # RunTimeError exception. + items = [(repr(k), v) for (k, v) in value.items()] + return { + 'members': + self.CaptureVariablesList(items, depth + 1, EMPTY_DICTIONARY, + limits), + 'type': + 'dict' + } + + if isinstance(value, _VECTOR_TYPES): + fields = self.CaptureVariablesList( + (('[%d]' % i, x) for i, x in enumerate(value)), depth + 1, + EMPTY_COLLECTION, limits) + return {'members': fields, 'type': type(value).__name__} + + if isinstance(value, types.FunctionType): + self._total_size += len(value.__name__) + # TODO: set value to func_name and type to 'function' + return {'value': 'function ' + value.__name__} + + if isinstance(value, Exception): + fields = self.CaptureVariablesList( + (('[%d]' % i, x) for i, x in enumerate(value.args)), depth + 1, + EMPTY_COLLECTION, limits) + return {'members': fields, 'type': type(value).__name__} + + if can_enqueue: + index = self._var_table_index.get(id(value)) + if index is None: + index = len(self._var_table) + self._var_table_index[id(value)] = index + self._var_table.append(value) + self._total_size += 4 # number of characters to accommodate a number. + return {'varTableIndex': index} + + for pretty_printer in CaptureCollector.pretty_printers: + pretty_value = pretty_printer(value) + if not pretty_value: + continue + + fields, object_type = pretty_value + return { + 'members': + self.CaptureVariablesList(fields, depth + 1, OBJECT_HAS_NO_FIELDS, + limits), + 'type': + object_type + } + + if not hasattr(value, '__dict__'): + # TODO: keep "value" empty and populate the "type" field instead. + r = str(type(value)) + self._total_size += len(r) + return {'value': r} + + # Add an additional depth for the object itself + items = value.__dict__.items() + # Make a list of the iterator in Python 3, to avoid 'dict changed size + # during iteration' errors from GC happening in the middle. 
+ # Only limits.max_list_items + 1 items are copied, anything past that will + # get ignored by CaptureVariablesList(). + items = list(itertools.islice(items, limits.max_list_items + 1)) + members = self.CaptureVariablesList(items, depth + 2, OBJECT_HAS_NO_FIELDS, + limits) + v = {'members': members} + + type_string = DetermineType(value) + if type_string: + v['type'] = type_string + + return v + + def _CaptureExpression(self, frame, expression): + """Evaluates the expression and captures it into a Variable object. + + Args: + frame: evaluation context. + expression: watched expression to compile and evaluate. + + Returns: + Variable object (which will have error status if the expression fails + to evaluate). + """ + rc, value = _EvaluateExpression(frame, expression) + if not rc: + return {'name': expression, 'status': value} + + return self.CaptureNamedVariable(expression, value, 0, + self.expression_capture_limits) + + def TrimVariableTable(self, new_size): + """Trims the variable table in the formatted breakpoint message. + + Removes trailing entries in variables table. Then scans the entire + breakpoint message and replaces references to the trimmed variables to + point to var_index of 0 ("buffer full"). + + Args: + new_size: desired size of variables table. + """ + + def ProcessBufferFull(variables): + for variable in variables: + var_index = variable.get('varTableIndex') + if var_index is not None and (var_index >= new_size): + variable['varTableIndex'] = 0 # Buffer full. 
+ members = variable.get('members') + if members is not None: + ProcessBufferFull(members) + + del self._var_table[new_size:] + ProcessBufferFull(self.breakpoint['evaluatedExpressions']) + for stack_frame in self.breakpoint['stackFrames']: + ProcessBufferFull(stack_frame['arguments']) + ProcessBufferFull(stack_frame['locals']) + ProcessBufferFull(self._var_table) + + def _CaptureEnvironmentLabels(self): + """Captures information about the environment, if possible.""" + if 'labels' not in self.breakpoint: + self.breakpoint['labels'] = {} + + if callable(breakpoint_labels_collector): + for (key, value) in breakpoint_labels_collector().items(): + self._StoreLabel(key, value) + + def _CaptureRequestLogId(self): + """Captures the request log id if possible. + + The request log id is stored inside the breakpoint labels. + """ + # pylint: disable=not-callable + if callable(request_log_id_collector): + request_log_id = request_log_id_collector() + if request_log_id: + # We have a request_log_id, save it into the breakpoint labels + self._StoreLabel(labels.Breakpoint.REQUEST_LOG_ID, request_log_id) + + def _CaptureUserId(self): + """Captures the user id of the end user, if possible.""" + user_kind, user_id = user_id_collector() + if user_kind and user_id: + self.breakpoint['evaluatedUserId'] = {'kind': user_kind, 'id': user_id} + + def _StoreLabel(self, name, value): + """Stores the specified label in the breakpoint's labels. + + In the event of a duplicate label, favour the pre-existing labels. This + generally should not be an issue as the pre-existing client label names are + chosen with care and there should be no conflicts. + + Args: + name: The name of the label to be stored. + value: The value of the label to be stored. + """ + if name not in self.breakpoint['labels']: + self.breakpoint['labels'][name] = value + + +class LogCollector(object): + """Captures minimal application snapshot and logs it to application log. 
+ + This is similar to CaptureCollector, but we don't need to capture local + variables, arguments and the objects tree. All we need to do is to format a + log message. We still need to evaluate watched expressions. + + The actual log functions are defined globally outside of this module. + """ + + def __init__(self, definition): + """Class constructor. + + Args: + definition: breakpoint definition indicating log level, message, etc. + """ + self._definition = definition + + # Maximum number of character to allow for a single value. Longer strings + # are truncated. + self.max_value_len = 256 + + # Maximum recursion depth. + self.max_depth = 2 + + # Maximum number of items in a list to capture at the top level. + self.max_list_items = 10 + + # When capturing recursively, limit on the size of sublists. + self.max_sublist_items = 5 + + # Time to pause after dynamic log quota has run out. + self.quota_recovery_ms = 500 + + # The time when we first entered the quota period + self._quota_recovery_start_time = None + + # Select log function. + level = self._definition.get('logLevel') + if not level or level == 'INFO': + self._log_message = log_info_message + elif level == 'WARNING': + self._log_message = log_warning_message + elif level == 'ERROR': + self._log_message = log_error_message + else: + self._log_message = None + + def Log(self, frame): + """Captures the minimal application states, formats it and logs the message. + + Args: + frame: Python stack frame of breakpoint hit. + + Returns: + None on success or status message on error. + """ + # Return error if log methods were not configured globally. 
+ if not self._log_message: + return { + 'isError': True, + 'description': { + 'format': LOG_ACTION_NOT_SUPPORTED + } + } + + if self._quota_recovery_start_time: + ms_elapsed = (time.time() - self._quota_recovery_start_time) * 1000 + if ms_elapsed > self.quota_recovery_ms: + # We are out of the recovery period, clear the time and continue + self._quota_recovery_start_time = None + else: + # We are in the recovery period, exit + return + + # Evaluate watched expressions. + message = 'LOGPOINT: ' + _FormatMessage( + self._definition.get('logMessageFormat', ''), + self._EvaluateExpressions(frame)) + + line = self._definition['location']['line'] + cdbg_logging_location = (NormalizePath(frame.f_code.co_filename), line, + _GetFrameCodeObjectName(frame)) + + if native.ApplyDynamicLogsQuota(len(message)): + self._log_message(message) + else: + self._quota_recovery_start_time = time.time() + self._log_message(DYNAMIC_LOG_OUT_OF_QUOTA) + del cdbg_logging_location + return None + + def _EvaluateExpressions(self, frame): + """Evaluates watched expressions into a string form. + + If expression evaluation fails, the error message is used as evaluated + expression string. + + Args: + frame: Python stack frame of breakpoint hit. + + Returns: + Array of strings where each string corresponds to the breakpoint + expression with the same index. + """ + return [ + self._FormatExpression(frame, expression) + for expression in self._definition.get('expressions') or [] + ] + + def _FormatExpression(self, frame, expression): + """Evaluates a single watched expression and formats it into a string form. + + If expression evaluation fails, returns error message string. + + Args: + frame: Python stack frame in which the expression is evaluated. + expression: string expression to evaluate. + + Returns: + Formatted expression value that can be used in the log message. 
+ """ + rc, value = _EvaluateExpression(frame, expression) + if not rc: + message = _FormatMessage(value['description']['format'], + value['description'].get('parameters')) + return '<' + message + '>' + + return self._FormatValue(value) + + def _FormatValue(self, value, level=0): + """Pretty-prints an object for a logger. + + This function is very similar to the standard pprint. The main difference + is that it enforces limits to make sure we never produce an extremely long + string or take too much time. + + Args: + value: Python object to print. + level: current recursion level. + + Returns: + Formatted string. + """ + + def FormatDictItem(key_value): + """Formats single dictionary item.""" + key, value = key_value + return (self._FormatValue(key, level + 1) + ': ' + + self._FormatValue(value, level + 1)) + + def LimitedEnumerate(items, formatter, level=0): + """Returns items in the specified enumerable enforcing threshold.""" + count = 0 + limit = self.max_sublist_items if level > 0 else self.max_list_items + for item in items: + if count == limit: + yield '...' + break + + yield formatter(item) + count += 1 + + def FormatList(items, formatter, level=0): + """Formats a list using a custom item formatter enforcing threshold.""" + return ', '.join(LimitedEnumerate(items, formatter, level=level)) + + if isinstance(value, _PRIMITIVE_TYPES): + return _TrimString( + repr(value), # Primitive type, always immutable. 
+ self.max_value_len) + + if isinstance(value, _DATE_TYPES): + return str(value) + + if level > self.max_depth: + return str(type(value)) + + if isinstance(value, dict): + return '{' + FormatList(value.items(), FormatDictItem) + '}' + + if isinstance(value, _VECTOR_TYPES): + return _ListTypeFormatString(value).format( + FormatList( + value, + lambda item: self._FormatValue(item, level + 1), + level=level)) + + if isinstance(value, types.FunctionType): + return 'function ' + value.__name__ + + if hasattr(value, '__dict__') and value.__dict__: + return self._FormatValue(value.__dict__, level) + + return str(type(value)) + + +def _EvaluateExpression(frame, expression): + """Compiles and evaluates watched expression. + + Args: + frame: evaluation context. + expression: watched expression to compile and evaluate. + + Returns: + (False, status) on error or (True, value) on success. + """ + try: + code = compile(expression, '', 'eval') + except (TypeError, ValueError) as e: + # expression string contains null bytes. + return (False, { + 'isError': True, + 'refersTo': 'VARIABLE_NAME', + 'description': { + 'format': 'Invalid expression', + 'parameters': [str(e)] + } + }) + except SyntaxError as e: + return (False, { + 'isError': True, + 'refersTo': 'VARIABLE_NAME', + 'description': { + 'format': 'Expression could not be compiled: $0', + 'parameters': [e.msg] + } + }) + + try: + return (True, native.CallImmutable(frame, code)) + except BaseException as e: # pylint: disable=broad-except + return (False, { + 'isError': True, + 'refersTo': 'VARIABLE_VALUE', + 'description': { + 'format': 'Exception occurred: $0', + 'parameters': [str(e)] + } + }) + + +def _GetFrameCodeObjectName(frame): + """Gets the code object name for the frame. + + Args: + frame: the frame to get the name from + + Returns: + The function name if the code is a static function or the class name with + the method name if it is an member function. 
+ """ + # This functions under the assumption that member functions will name their + # first parameter argument 'self' but has some edge-cases. + if frame.f_code.co_argcount >= 1 and 'self' == frame.f_code.co_varnames[0]: + return (frame.f_locals['self'].__class__.__name__ + '.' + + frame.f_code.co_name) + else: + return frame.f_code.co_name + + +def _FormatMessage(template, parameters): + """Formats the message. Unescapes '$$' with '$'. + + Args: + template: message template (e.g. 'a = $0, b = $1'). + parameters: substitution parameters for the format. + + Returns: + Formatted message with parameters embedded in template placeholders. + """ + + def GetParameter(m): + try: + return parameters[int(m.group(0)[1:])] + except IndexError: + return INVALID_EXPRESSION_INDEX + + parts = template.split('$$') + return '$'.join(re.sub(r'\$\d+', GetParameter, part) for part in parts) + + +def _TrimString(s, max_len): + """Trims the string if it exceeds max_len.""" + if len(s) <= max_len: + return s + return s[:max_len + 1] + '...' diff --git a/src/googleclouddebugger/common.h b/src/googleclouddebugger/common.h index 2cd1ed5..59d5255 100644 --- a/src/googleclouddebugger/common.h +++ b/src/googleclouddebugger/common.h @@ -17,8 +17,6 @@ #ifndef DEVTOOLS_CDBG_DEBUGLETS_PYTHON_COMMON_H_ #define DEVTOOLS_CDBG_DEBUGLETS_PYTHON_COMMON_H_ - -// // Open source includes and definition of common constants. // @@ -31,6 +29,7 @@ #include #include +#include #include #include "glog/logging.h" @@ -60,5 +59,51 @@ using google::LogSeverity; using google::AddLogSink; using google::RemoveLogSink; +// The open source build uses gflags, which uses the traditional (v1) flags APIs +// to define/declare/access command line flags. The internal build has upgraded +// to use v2 flags API (DEFINE_FLAG/DECLARE_FLAG/GetFlag/SetFlag), which is not +// supported by gflags yet (and absl is not released to open source yet). +// Here, we use simple, dummy v2 flags wrappers around v1 flags implementation. 
+// This allows us to use the same flags APIs both internally and externally. + +#define ABSL_FLAG(type, name, default_value, help) \ + DEFINE_##type(name, default_value, help) + +#define ABSL_DECLARE_FLAG(type, name) DECLARE_##type(name) + +namespace absl { +// Return the value of an old-style flag. Not thread-safe. +inline bool GetFlag(bool flag) { return flag; } +inline int32 GetFlag(int32 flag) { return flag; } +inline int64 GetFlag(int64 flag) { return flag; } +inline uint64 GetFlag(uint64 flag) { return flag; } +inline double GetFlag(double flag) { return flag; } +inline string GetFlag(const string& flag) { return flag; } + +// Change the value of an old-style flag. Not thread-safe. +inline void SetFlag(bool* f, bool v) { *f = v; } +inline void SetFlag(int32* f, int32 v) { *f = v; } +inline void SetFlag(int64* f, int64 v) { *f = v; } +inline void SetFlag(uint64* f, uint64 v) { *f = v; } +inline void SetFlag(double* f, double v) { *f = v; } +inline void SetFlag(string* f, const string& v) { *f = v; } +} // namespace absl + +// Python 3 compatibility +#if PY_MAJOR_VERSION >= 3 +// Python 2 has both an 'int' and a 'long' type, and Python 3 only has an 'int' +// type which is the equivalent of Python 2's 'long'. +// PyInt* functions will refer to 'int' in Python 2 and 3. + #define PyInt_FromLong PyLong_FromLong + #define PyInt_AsLong PyLong_AsLong + #define PyInt_CheckExact PyLong_CheckExact + +// Python 3's 'bytes' type is the equivalent of Python 2's 'str' type, which are +// byte arrays. Python 3's 'str' type represents a unicode string. +// In this codebase: +// PyString* functions will refer to 'str' in Python 2 and 3. +// PyBytes* functions will refer to 'str' in Python 2 and 'bytes' in Python 3. 
+ #define PyString_AsString PyUnicode_AsUTF8 +#endif #endif // DEVTOOLS_CDBG_DEBUGLETS_PYTHON_COMMON_H_ diff --git a/src/googleclouddebugger/conditional_breakpoint.cc b/src/googleclouddebugger/conditional_breakpoint.cc index b7e9bec..9d66474 100644 --- a/src/googleclouddebugger/conditional_breakpoint.cc +++ b/src/googleclouddebugger/conditional_breakpoint.cc @@ -19,6 +19,8 @@ #include "conditional_breakpoint.h" +#include + #include "immutability_tracer.h" #include "rate_limit.h" @@ -64,19 +66,23 @@ bool ConditionalBreakpoint::EvaluateCondition(PyFrameObject* frame) { ScopedPyObject result; bool is_mutable_code_detected = false; - int32 line_count = 0; + int32_t line_count = 0; { ScopedImmutabilityTracer immutability_tracer; result.reset(PyEval_EvalCode( +#if PY_MAJOR_VERSION >= 3 + reinterpret_cast(condition_.get()), +#else condition_.get(), +#endif frame->f_globals, frame->f_locals)); is_mutable_code_detected = immutability_tracer.IsMutableCodeDetected(); line_count = immutability_tracer.GetLineCount(); } - // TODO(vlif): clear breakpoint if condition evaluation failed due to + // TODO: clear breakpoint if condition evaluation failed due to // mutable code or timeout. auto eval_exception = ClearPythonException(); diff --git a/src/googleclouddebugger/deferred_modules.py b/src/googleclouddebugger/deferred_modules.py deleted file mode 100644 index 816f7f5..0000000 --- a/src/googleclouddebugger/deferred_modules.py +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright 2015 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS-IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Support for breakpoints on modules that haven't been loaded yet.""" - -import imp -import os -import sys # Must be imported, otherwise import hooks don't work. -import time - -import cdbg_native as native - -# Maximum number of directories that IsValidSourcePath will scan. -_DIRECTORY_LOOKUP_QUOTA = 250 - -# Callbacks to invoke when a module is imported. -_import_callbacks = {} - -# Original __import__ function if import hook is installed or None otherwise. -_real_import = None - - -def IsValidSourcePath(source_path): - """Checks availability of a Python module. - - This function checks if it is possible that a module will match the specified - path. We only use the file name and we ignore the directory. - - There is no absolutely correct way to do this. The application may just - import a module from a string, or dynamically change sys.path. This function - implements heuristics that should cover all reasonable cases with a good - performance. - - There can be some edge cases when this code is going to scan a huge number - of directories. This can be very expensive. To mitigate it, we limit the - number of directories that can be scanned. If this threshold is reached, - false negatives are possible. - - Args: - source_path: source path as specified in the breakpoint. - - Returns: - True if it is possible that a module matching source_path will ever be - loaded or false otherwise. - """ - - def IsPackage(path): - """Checks if the specified directory is a valid Python package.""" - init_base_path = os.path.join(path, '__init__.py') - return (os.path.isfile(init_base_path) or - os.path.isfile(init_base_path + 'c') or - os.path.isfile(init_base_path + 'o')) - - def SubPackages(path): - """Gets a list of all the directories of subpackages of path.""" - if os.path.isdir(path): - for name in os.listdir(path): - if '.' 
in name: - continue # This is definitely a file, package names can't have dots. - - if directory_lookups[0] >= _DIRECTORY_LOOKUP_QUOTA: - break - - directory_lookups[0] += 1 - - subpath = os.path.join(path, name) - if IsPackage(subpath): - yield subpath - - start_time = time.time() - directory_lookups = [0] - - file_name = _GetModuleName(source_path) - if not file_name: - return False - - # Recursively discover all the subpackages in all the Python paths. - paths = set() - pending = set(sys.path) - while pending: - path = pending.pop() - paths.add(path) - pending |= frozenset(SubPackages(path)) - paths - - # Append all directories where some modules have already been loaded. There - # is a good chance that the file we are looking for will be there. This is - # only useful if a module got somehow loaded outside of sys.path. We don't - # include these paths in the recursive discovery of subpackages because it - # takes a lot of time in some edge cases and not worth it. - default_path = sys.path[0] - for unused_module_name, module in sys.modules.copy().iteritems(): - file_path = getattr(module, '__file__', None) - path, unused_name = os.path.split(file_path) if file_path else (None, None) - paths.add(path or default_path) - - try: - imp.find_module(file_name, list(paths)) - rc = True - except ImportError: - rc = False - - native.LogInfo( - ('Look up for %s completed in %d directories, ' - 'scanned %d directories (quota: %d), ' - 'result: %r, total time: %f ms') % ( - file_name, - len(paths), - directory_lookups[0], - _DIRECTORY_LOOKUP_QUOTA, - rc, - (time.time() - start_time) * 1000)) - return rc - - -def AddImportCallback(source_path, callback): - """Register import hook. - - This function overrides the default import process. Then whenever a module - corresponding to source_path is imported, the callback will be invoked. - - A module may be imported multiple times. Import event only means that the - Python code contained an "import" statement. 
The actual loading and - initialization of a new module normally happens only once. After that the - module is just fetched from the cache. This function doesn't care whether a - module was loaded or fetched from cache. The callback will be triggered - all the same. - - Args: - source_path: source file path identifying the monitored module name. If - the file is __init__.py, this function will monitor package import. - Otherwise it will monitor module import. - callback: callable to invoke upon module import. - - Returns: - Function object to invoke to remove the installed callback. - """ - - def RemoveCallback(): - # Atomic operations, no need to lock. - callbacks = _import_callbacks.get(module_name) - if callbacks: - callbacks.remove(callback) - - module_name = _GetModuleName(source_path) - if not module_name: - return None - - # Atomic operations, no need to lock. - _import_callbacks.setdefault(module_name, set()).add(callback) - _InstallImportHook() - - return RemoveCallback - - -def _GetModuleName(source_path): - """Gets the name of the module that corresponds to source_path. - - Args: - source_path: file path to resolve into a module. - - Returns: - If the source file is __init__.py, this function will return the name - of the package (last directory before file name). Otherwise this function - return file name without extension. - """ - directory, name = os.path.split(source_path) - if name == '__init__.py': - if not directory.strip(os.sep): - return None # '__init__.py' is way too generic. We can't match it. 
- - directory, file_name = os.path.split(directory) - else: - file_name, ext = os.path.splitext(name) - if ext != '.py': - return None # ".py" extension is expected - - return file_name - - -def _InstallImportHook(): - """Lazily installs import hook.""" - - global _real_import - - if _real_import: - return # Import hook already installed - - builtin = sys.modules['__builtin__'] - - _real_import = getattr(builtin, '__import__') - assert _real_import - - builtin.__import__ = _ImportHook - - -# pylint: disable=redefined-builtin, g-doc-args, g-doc-return-or-yield -def _ImportHook(name, globals=None, locals=None, fromlist=None, level=-1): - """Callback when a module is being imported by Python interpreter. - - Argument names have to exactly match those of __import__. Otherwise calls - to __import__ that use keyword syntax will fail: __import('a', fromlist=[]). - """ - - module = _real_import(name, globals, locals, fromlist, level) - - # Invoke callbacks for the imported module. No need to lock, since all - # operations are atomic. - pos = name.rfind('.') + 1 - _InvokeImportCallback(name[pos:]) - - if fromlist: - for module_name in fromlist: - _InvokeImportCallback(module_name) - - return module - - -def _InvokeImportCallback(module_name): - """Invokes import callbacks for the specified module.""" - callbacks = _import_callbacks.get(module_name) - if not callbacks: - return # Common code path. - - # Clone the callbacks set, since it can change during enumeration. - for callback in callbacks.copy(): - callback(module_name) - diff --git a/src/googleclouddebugger/error_data_visibility_policy.py b/src/googleclouddebugger/error_data_visibility_policy.py new file mode 100644 index 0000000..0a04c36 --- /dev/null +++ b/src/googleclouddebugger/error_data_visibility_policy.py @@ -0,0 +1,31 @@ +# Copyright 2017 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Always returns the provided error on visibility requests. + +Example Usage: + + policy = ErrorDataVisibilityPolicy('An error message') + + policy.IsDataVisible('org.foo.bar') -> (False, 'An error message') +""" + + +class ErrorDataVisibilityPolicy(object): + """Visibility policy that always returns an error to the caller.""" + + def __init__(self, error_message): + self.error_message = error_message + + def IsDataVisible(self, unused_path): + return (False, self.error_message) diff --git a/src/googleclouddebugger/firebase_client.py b/src/googleclouddebugger/firebase_client.py new file mode 100644 index 0000000..8dbe30a --- /dev/null +++ b/src/googleclouddebugger/firebase_client.py @@ -0,0 +1,699 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Communicates with Firebase RTDB backend."""

from collections import deque
import copy
import hashlib
import json
import os
import platform
import sys
import threading
import time
import traceback

import firebase_admin
import firebase_admin.credentials
import firebase_admin.db
import firebase_admin.exceptions
import requests

from . import backoff
from . import cdbg_native as native
from . import labels
from . import uniquifier_computer
from . import application_info
from . import version

# This module catches all exceptions. This is safe because it runs in
# a daemon thread (so we are not blocking Ctrl+C). We need to catch all
# exceptions because the HTTP client is unpredictable as far as every
# exception it can throw.
# pylint: disable=broad-except

# Set of all known debuggee labels (passed down as flags). The value in
# the map is a list of environment variables that may contain the label
# value, checked in order (flags still take precedence).
_DEBUGGEE_LABELS = {
    labels.Debuggee.MODULE: [
        'GAE_SERVICE', 'GAE_MODULE_NAME', 'K_SERVICE', 'FUNCTION_NAME'
    ],
    labels.Debuggee.VERSION: [
        'GAE_VERSION', 'GAE_MODULE_VERSION', 'K_REVISION',
        'X_GOOGLE_FUNCTION_VERSION'
    ],
    labels.Debuggee.MINOR_VERSION: ['GAE_DEPLOYMENT_ID', 'GAE_MINOR_VERSION']
}

# Debuggee labels used to format debuggee description (ordered). The minor
# version is excluded for the sake of consistency with AppEngine UX.
_DESCRIPTION_LABELS = [
    labels.Debuggee.PROJECT_ID, labels.Debuggee.MODULE, labels.Debuggee.VERSION
]

_METADATA_SERVER_URL = 'http://metadata.google.internal/computeMetadata/v1'

# Firebase error codes that are worth retrying; anything else is treated as
# a permanent application-level failure.
_TRANSIENT_ERROR_CODES = ('UNKNOWN', 'INTERNAL', 'N/A', 'UNAVAILABLE',
                          'DEADLINE_EXCEEDED', 'RESOURCE_EXHAUSTED',
                          'UNAUTHENTICATED', 'PERMISSION_DENIED')


class NoProjectIdError(Exception):
  """Used to indicate the project id cannot be determined."""


class FirebaseClient(object):
  """Firebase RTDB Backend client.

  Registers the debuggee, subscribes for active breakpoints and sends breakpoint
  updates to the backend.

  This class supports two types of authentication: application default
  credentials or a manually provided JSON credentials file for a service
  account.

  FirebaseClient creates a worker thread that communicates with the backend. The
  thread can be stopped with a Stop function, but it is optional since the
  worker thread is marked as daemon.
  """

  def __init__(self):
    self.on_active_breakpoints_changed = lambda x: None
    self.on_idle = lambda: None
    self._debuggee_labels = {}
    self._credentials = None
    self._project_id = None
    self._database_url = None
    self._debuggee_id = None
    self._canary_mode = None
    self._breakpoints = {}
    self._main_thread = None
    self._transmission_thread = None
    self._transmission_thread_startup_lock = threading.Lock()
    self._transmission_queue = deque(maxlen=100)
    self._mark_active_timer = None
    self._mark_active_interval_sec = 60 * 60  # 1 hour in seconds
    self._new_updates = threading.Event()
    self._breakpoint_subscription = None
    self._firebase_app = None
    # Initialized here (not only in Start) so that Stop() is safe to call
    # even if Start() was never invoked.
    self._shutdown = False

    # Events for unit testing.
    self.connection_complete = threading.Event()
    self.registration_complete = threading.Event()
    self.subscription_complete = threading.Event()

    #
    # Configuration options (constants only modified by unit test)
    #

    # Delay before retrying failed request.
    self.connect_backoff = backoff.Backoff()  # Connect to the DB.
    self.register_backoff = backoff.Backoff()  # Register debuggee.
    self.subscribe_backoff = backoff.Backoff()  # Subscribe to updates.
    self.update_backoff = backoff.Backoff()  # Update breakpoint.

    # Maximum number of times that the message is re-transmitted before it
    # is assumed to be poisonous and discarded
    self.max_transmit_attempts = 10

  def InitializeDebuggeeLabels(self, flags):
    """Initialize debuggee labels from environment variables and flags.

    The caller passes all the flags that the debuglet got. This function
    will only use the flags used to label the debuggee. Flags take precedence
    over environment variables.

    Debuggee description is formatted from available flags.

    Args:
      flags: dictionary of debuglet command line flags.
    """
    self._debuggee_labels = {}

    for (label, var_names) in _DEBUGGEE_LABELS.items():
      # var_names is a list of possible environment variables that may contain
      # the label value. Find the first one that is set.
      for name in var_names:
        value = os.environ.get(name)
        if value:
          # Special case for module. We omit the "default" module
          # to stay consistent with AppEngine.
          if label == labels.Debuggee.MODULE and value == 'default':
            break
          self._debuggee_labels[label] = value
          break

    # Special case when FUNCTION_NAME is set and X_GOOGLE_FUNCTION_VERSION
    # isn't set. We set the version to 'unversioned' to be consistent with other
    # agents.
    # TODO: Stop assigning 'unversioned' to a GCF and find the
    # actual version.
    if ('FUNCTION_NAME' in os.environ and
        labels.Debuggee.VERSION not in self._debuggee_labels):
      self._debuggee_labels[labels.Debuggee.VERSION] = 'unversioned'

    if flags:
      self._debuggee_labels.update({
          name: value
          for (name, value) in flags.items()
          if name in _DEBUGGEE_LABELS
      })

    self._debuggee_labels[labels.Debuggee.PROJECT_ID] = self._project_id

    platform_enum = application_info.GetPlatform()
    self._debuggee_labels[labels.Debuggee.PLATFORM] = platform_enum.value

    if platform_enum == application_info.PlatformType.CLOUD_FUNCTION:
      region = application_info.GetRegion()
      if region:
        self._debuggee_labels[labels.Debuggee.REGION] = region

  def SetupAuth(self,
                project_id=None,
                service_account_json_file=None,
                database_url=None):
    """Sets up authentication with Google APIs.

    This will use the credentials from service_account_json_file if provided,
    falling back to application default credentials.
    See https://cloud.google.com/docs/authentication/production.

    Args:
      project_id: GCP project ID (e.g. myproject). If not provided, will attempt
        to retrieve it from the credentials.
      service_account_json_file: JSON file to use for credentials. If not
        provided, will default to application default credentials.
      database_url: Firebase realtime database URL to be used. If not
        provided, connect attempts to the following DBs will be made, in
        order:
          https://{project_id}-cdbg.firebaseio.com
          https://{project_id}-default-rtdb.firebaseio.com
    Raises:
      NoProjectIdError: If the project id cannot be determined.
    """
    if service_account_json_file:
      self._credentials = firebase_admin.credentials.Certificate(
          service_account_json_file)
      if not project_id:
        with open(service_account_json_file, encoding='utf-8') as f:
          project_id = json.load(f).get('project_id')
    else:
      if not project_id:
        try:
          r = requests.get(
              f'{_METADATA_SERVER_URL}/project/project-id',
              headers={'Metadata-Flavor': 'Google'},
              timeout=1)
          project_id = r.text
        except requests.exceptions.RequestException:
          native.LogInfo('Metadata server not available')

    if not project_id:
      raise NoProjectIdError(
          'Unable to determine the project id from the API credentials. '
          'Please specify the project id using the --project_id flag.')

    self._project_id = project_id
    self._database_url = database_url

  def Start(self):
    """Starts the worker thread."""
    self._shutdown = False

    # Spin up the main thread which will create the other necessary threads.
    self._main_thread = threading.Thread(target=self._MainThreadProc)
    self._main_thread.name = 'Cloud Debugger main worker thread'
    self._main_thread.daemon = True
    self._main_thread.start()

  def Stop(self):
    """Signals the worker threads to shut down and waits until it exits."""
    self._shutdown = True
    self._new_updates.set()  # Wake up the transmission thread.

    if self._main_thread is not None:
      self._main_thread.join()
      self._main_thread = None

    if self._transmission_thread is not None:
      self._transmission_thread.join()
      self._transmission_thread = None

    if self._mark_active_timer is not None:
      self._mark_active_timer.cancel()
      self._mark_active_timer = None

    if self._breakpoint_subscription is not None:
      self._breakpoint_subscription.close()
      self._breakpoint_subscription = None

  def EnqueueBreakpointUpdate(self, breakpoint_data):
    """Asynchronously updates the specified breakpoint on the backend.

    This function returns immediately. The worker thread is actually doing
    all the work. The worker thread is responsible to retry the transmission
    in case of transient errors.

    The assumption is that the breakpoint is moving from Active to Final state.

    Args:
      breakpoint_data: breakpoint in either final or non-final state.
    """
    with self._transmission_thread_startup_lock:
      if self._transmission_thread is None:
        self._transmission_thread = threading.Thread(
            target=self._TransmissionThreadProc)
        self._transmission_thread.name = 'Cloud Debugger transmission thread'
        self._transmission_thread.daemon = True
        self._transmission_thread.start()

    self._transmission_queue.append((breakpoint_data, 0))
    self._new_updates.set()  # Wake up the worker thread to send immediately.

  def _MainThreadProc(self):
    """Entry point for the worker thread.

    This thread only serves to register and kick off the firebase subscription
    which will run in its own thread. That thread will be owned by
    self._breakpoint_subscription.
    """
    connection_required, delay = True, 0
    while connection_required:
      if self._shutdown:
        return  # Let Stop() complete even if the backend is unreachable.
      time.sleep(delay)
      connection_required, delay = self._ConnectToDb()
    self.connection_complete.set()

    registration_required, delay = True, 0
    while registration_required:
      if self._shutdown:
        return
      time.sleep(delay)
      registration_required, delay = self._RegisterDebuggee()
    self.registration_complete.set()

    subscription_required, delay = True, 0
    while subscription_required:
      if self._shutdown:
        return
      time.sleep(delay)
      subscription_required, delay = self._SubscribeToBreakpoints()
    self.subscription_complete.set()

    self._StartMarkActiveTimer()

    while not self._shutdown:
      if self.on_idle is not None:
        self.on_idle()

      time.sleep(1)

  def _TransmissionThreadProc(self):
    """Entry point for the transmission worker thread."""

    while not self._shutdown:
      self._new_updates.clear()

      delay = self._TransmitBreakpointUpdates()

      self._new_updates.wait(delay)

  def _MarkActiveTimerFunc(self):
    """Entry point for the mark active timer."""

    try:
      self._MarkDebuggeeActive()
    except:
      native.LogInfo(
          f'Failed to mark debuggee as active: {traceback.format_exc()}')
    finally:
      self._StartMarkActiveTimer()

  def _StartMarkActiveTimer(self):
    """Schedules the next periodic keep-alive update of the debuggee."""
    self._mark_active_timer = threading.Timer(self._mark_active_interval_sec,
                                              self._MarkActiveTimerFunc)
    self._mark_active_timer.start()

  def _ConnectToDb(self):
    """Single attempt to locate and initialize the debugger database.

    Tries the explicitly-configured URL if any, otherwise the conventional
    '-cdbg' DB followed by the project's default RTDB.

    Returns:
      (connection_required, delay) tuple; connection_required is False once
      a DB with the debugger schema was found.
    """
    urls = [self._database_url] if self._database_url is not None else \
        [f'https://{self._project_id}-cdbg.firebaseio.com',
         f'https://{self._project_id}-default-rtdb.firebaseio.com']

    for url in urls:
      native.LogInfo(f'Attempting to connect to DB with url: {url}')

      status, firebase_app = self._TryInitializeDbForUrl(url)
      if status:
        native.LogInfo(f'Successfully connected to DB with url: {url}')
        self._database_url = url
        self._firebase_app = firebase_app
        self.connect_backoff.Succeeded()
        return (False, 0)  # Proceed immediately to registering the debuggee.

    return (True, self.connect_backoff.Failed())

  def _TryInitializeDbForUrl(self, database_url):
    """Initializes a firebase app and verifies the debugger schema exists.

    Args:
      database_url: RTDB URL to probe.

    Returns:
      (success, app) tuple; app is None on failure (the partially-created
      app handle is cleaned up before returning).
    """
    # Note: if self._credentials is None, default app credentials will be used.
    app = None
    try:
      app = firebase_admin.initialize_app(
          self._credentials, {'databaseURL': database_url}, name='cdbg')

      if self._CheckSchemaVersionPresence(app):
        return True, app

    except ValueError:
      native.LogWarning(
          f'Failed to initialize firebase: {traceback.format_exc()}')

    # This is the failure path, if we hit here we must cleanup the app handle
    if app is not None:
      firebase_admin.delete_app(app)
      app = None

    return False, app

  def _RegisterDebuggee(self):
    """Single attempt to register the debuggee.

    If the registration succeeds, sets self._debuggee_id to the registered
    debuggee ID.

    Returns:
      (registration_required, delay) tuple
    """
    debuggee = None
    try:
      debuggee = self._GetDebuggee()
      self._debuggee_id = debuggee['id']
    except BaseException:
      native.LogWarning(
          f'Debuggee information not available: {traceback.format_exc()}')
      return (True, self.register_backoff.Failed())

    try:
      present = self._CheckDebuggeePresence()
      if present:
        # Already registered (e.g. by another instance); just refresh the
        # keep-alive timestamp.
        self._MarkDebuggeeActive()
      else:
        debuggee_path = f'cdbg/debuggees/{self._debuggee_id}'
        native.LogInfo(
            f'Registering at {self._database_url}, path: {debuggee_path}')
        debuggee_data = copy.deepcopy(debuggee)
        # {'.sv': 'timestamp'} is the RTDB magic value for server-side time.
        debuggee_data['registrationTimeUnixMsec'] = {'.sv': 'timestamp'}
        debuggee_data['lastUpdateTimeUnixMsec'] = {'.sv': 'timestamp'}
        firebase_admin.db.reference(debuggee_path,
                                    self._firebase_app).set(debuggee_data)

      native.LogInfo(
          f'Debuggee registered successfully, ID: {self._debuggee_id}')

      self.register_backoff.Succeeded()
      return (False, 0)  # Proceed immediately to subscribing to breakpoints.
    except BaseException as e:
      # There is no significant benefit to handing different exceptions
      # in different ways; we will log and retry regardless.
      native.LogInfo(f'Failed to register debuggee: {repr(e)}')
      return (True, self.register_backoff.Failed())

  def _CheckSchemaVersionPresence(self, firebase_app):
    """Returns True if the debugger schema marker exists in the given DB."""
    path = 'cdbg/schema_version'
    try:
      snapshot = firebase_admin.db.reference(path, firebase_app).get()
      # The value doesn't matter; just return true if there's any value.
      return snapshot is not None
    except BaseException as e:
      native.LogInfo(
          f'Failed to check schema version presence at {path}: {repr(e)}')
      return False

  def _CheckDebuggeePresence(self):
    """Returns True if this debuggee is already registered in the DB."""
    path = f'cdbg/debuggees/{self._debuggee_id}/registrationTimeUnixMsec'
    try:
      snapshot = firebase_admin.db.reference(path, self._firebase_app).get()
      # The value doesn't matter; just return true if there's any value.
      return snapshot is not None
    except BaseException as e:
      native.LogInfo(f'Failed to check debuggee presence at {path}: {repr(e)}')
      return False

  def _MarkDebuggeeActive(self):
    """Refreshes the debuggee's keep-alive timestamp (server-side time)."""
    active_path = f'cdbg/debuggees/{self._debuggee_id}/lastUpdateTimeUnixMsec'
    try:
      server_time = {'.sv': 'timestamp'}
      firebase_admin.db.reference(active_path,
                                  self._firebase_app).set(server_time)
    except BaseException:
      native.LogInfo(
          f'Failed to mark debuggee active: {traceback.format_exc()}')

  def _SubscribeToBreakpoints(self):
    """Single attempt to subscribe to active breakpoint updates.

    Returns:
      (subscription_required, delay) tuple.
    """
    # Kill any previous subscriptions first.
    if self._breakpoint_subscription is not None:
      self._breakpoint_subscription.close()
      self._breakpoint_subscription = None

    path = f'cdbg/breakpoints/{self._debuggee_id}/active'
    native.LogInfo(f'Subscribing to breakpoint updates at {path}')
    ref = firebase_admin.db.reference(path, self._firebase_app)
    try:
      self._breakpoint_subscription = ref.listen(self._ActiveBreakpointCallback)
      return (False, 0)
    except firebase_admin.exceptions.FirebaseError:
      native.LogInfo(
          f'Failed to subscribe to breakpoints: {traceback.format_exc()}')
      return (True, self.subscribe_backoff.Failed())

  def _ActiveBreakpointCallback(self, event):
    """Handles a Firebase listen event on the active-breakpoints node.

    'put' events replace either the whole breakpoint set (path '/') or a
    single breakpoint; a None payload means deletion. 'patch' events merge
    in one or more breakpoints. Any other event type is logged and ignored.

    Args:
      event: firebase_admin.db.Event with event_type, path and data.
    """
    if event.event_type == 'put':
      if event.data is None:
        # Either deleting a breakpoint or initializing with no breakpoints.
        # Initializing with no breakpoints is a no-op.
        # If deleting, event.path will be /{breakpointid}
        if event.path != '/':
          breakpoint_id = event.path[1:]
          # Breakpoint may have already been deleted, so pop for possible no-op.
          self._breakpoints.pop(breakpoint_id, None)
      else:
        if event.path == '/':
          # New set of breakpoints.
          self._breakpoints = {}
          for (key, value) in event.data.items():
            self._AddBreakpoint(key, value)
        else:
          # New breakpoint.
          breakpoint_id = event.path[1:]
          self._AddBreakpoint(breakpoint_id, event.data)

    elif event.event_type == 'patch':
      # New breakpoint or breakpoints.
      for (key, value) in event.data.items():
        self._AddBreakpoint(key, value)
    else:
      native.LogWarning('Unexpected event from Firebase: '
                        f'{event.event_type} {event.path} {event.data}')
      return

    native.LogInfo(f'Breakpoints list changed, {len(self._breakpoints)} active')
    self.on_active_breakpoints_changed(list(self._breakpoints.values()))

  def _AddBreakpoint(self, breakpoint_id, breakpoint_data):
    """Records a breakpoint in the local map, stamping its id into the data."""
    breakpoint_data['id'] = breakpoint_id
    self._breakpoints[breakpoint_id] = breakpoint_data

  def _TransmitBreakpointUpdates(self):
    """Tries to send pending breakpoint updates to the backend.

    Sends all the pending breakpoint updates. In case of transient failures,
    the breakpoint is inserted back to the top of the queue. Application
    failures are not retried (for example updating breakpoint in a final
    state).

    Each pending breakpoint maintains a retry counter. After repeated transient
    failures the breakpoint is discarded and dropped from the queue.

    Returns:
      None if all pending breakpoints were sent successfully. Otherwise
      returns time interval in seconds to stall before retrying.
    """
    retry_list = []

    # There is only one consumer, so two step pop is safe.
    while self._transmission_queue:
      breakpoint_data, retry_count = self._transmission_queue.popleft()

      bp_id = breakpoint_data['id']

      try:
        # Something has changed on the breakpoint.
        # It should be going from active to final, but let's make sure.
        if not breakpoint_data.get('isFinalState', False):
          # Caught by the broad except below; raising BaseException directly
          # would be wrong (it is reserved for exits/interrupts).
          raise RuntimeError(
              f'Unexpected breakpoint update requested: {breakpoint_data}')

        # If action is missing, it should be set to 'CAPTURE'
        is_logpoint = breakpoint_data.get('action') == 'LOG'
        is_snapshot = not is_logpoint
        if is_snapshot:
          breakpoint_data['action'] = 'CAPTURE'

        # Set the completion time on the server side using a magic value.
        breakpoint_data['finalTimeUnixMsec'] = {'.sv': 'timestamp'}

        # First, remove from the active breakpoints.
        bp_ref = firebase_admin.db.reference(
            f'cdbg/breakpoints/{self._debuggee_id}/active/{bp_id}',
            self._firebase_app)
        bp_ref.delete()

        summary_data = breakpoint_data
        # Save snapshot data for snapshots only.
        if is_snapshot:
          # Note that there may not be snapshot data.
          bp_ref = firebase_admin.db.reference(
              f'cdbg/breakpoints/{self._debuggee_id}/snapshot/{bp_id}',
              self._firebase_app)
          bp_ref.set(breakpoint_data)

          # Now strip potential snapshot data.
          summary_data = copy.deepcopy(breakpoint_data)
          summary_data.pop('evaluatedExpressions', None)
          summary_data.pop('stackFrames', None)
          summary_data.pop('variableTable', None)

        # Then add it to the list of final breakpoints.
        bp_ref = firebase_admin.db.reference(
            f'cdbg/breakpoints/{self._debuggee_id}/final/{bp_id}',
            self._firebase_app)
        bp_ref.set(summary_data)

        native.LogInfo(f'Breakpoint {bp_id} update transmitted successfully')

      except firebase_admin.exceptions.FirebaseError as err:
        if err.code in _TRANSIENT_ERROR_CODES:
          if retry_count < self.max_transmit_attempts - 1:
            native.LogInfo(f'Failed to send breakpoint {bp_id} update: '
                           f'{traceback.format_exc()}')
            retry_list.append((breakpoint_data, retry_count + 1))
          else:
            native.LogWarning(
                f'Breakpoint {bp_id} retry count exceeded maximum')
        else:
          # This is very common if multiple instances are sending final update
          # simultaneously.
          native.LogInfo(f'{err}, breakpoint: {bp_id}')

      except BaseException:
        native.LogWarning(f'Fatal error sending breakpoint {bp_id} update: '
                          f'{traceback.format_exc()}')

    self._transmission_queue.extend(retry_list)

    if not self._transmission_queue:
      self.update_backoff.Succeeded()
      # Nothing to send, wait until next breakpoint update.
      return None
    else:
      return self.update_backoff.Failed()

  def _GetDebuggee(self):
    """Builds the debuggee structure."""
    major_version = version.__version__.split('.', maxsplit=1)[0]
    python_version = ''.join(platform.python_version().split('.')[:2])
    agent_version = f'google.com/python{python_version}-gcp/v{major_version}'

    debuggee = {
        'description': self._GetDebuggeeDescription(),
        'labels': self._debuggee_labels,
        'agentVersion': agent_version,
    }

    source_context = self._ReadAppJsonFile('source-context.json')
    if source_context:
      debuggee['sourceContexts'] = [source_context]

    debuggee['uniquifier'] = self._ComputeUniquifier(debuggee)

    debuggee['id'] = self._ComputeDebuggeeId(debuggee)

    return debuggee

  def _ComputeDebuggeeId(self, debuggee):
    """Computes a debuggee ID.

    The debuggee ID has to be identical on all instances. Therefore the
    ID should not include any random elements or elements that may be
    different on different instances.

    Args:
      debuggee: complete debuggee message (including uniquifier)

    Returns:
      Debuggee ID meeting the criteria described above.
    """
    fullhash = hashlib.sha1(json.dumps(debuggee,
                                       sort_keys=True).encode()).hexdigest()
    return f'd-{fullhash[:8]}'

  def _GetDebuggeeDescription(self):
    """Formats debuggee description based on debuggee labels."""
    return '-'.join(self._debuggee_labels[label]
                    for label in _DESCRIPTION_LABELS
                    if label in self._debuggee_labels)

  def _ComputeUniquifier(self, debuggee):
    """Computes debuggee uniquifier.

    The debuggee uniquifier has to be identical on all instances. Therefore the
    uniquifier should not include any random numbers and should only be based
    on inputs that are guaranteed to be the same on all instances.

    Args:
      debuggee: complete debuggee message without the uniquifier

    Returns:
      Hex string of SHA1 hash of project information, debuggee labels and
      debuglet version.
    """
    uniquifier = hashlib.sha1()

    # Compute hash of application files if we don't have source context. This
    # way we can still distinguish between different deployments.
    if ('minorversion' not in debuggee.get('labels', []) and
        'sourceContexts' not in debuggee):
      uniquifier_computer.ComputeApplicationUniquifier(uniquifier)

    return uniquifier.hexdigest()

  def _ReadAppJsonFile(self, relative_path):
    """Reads JSON file from an application directory.

    Args:
      relative_path: file name relative to application root directory.

    Returns:
      Parsed JSON data or None if the file does not exist, can't be read or
      not a valid JSON file.
    """
    try:
      with open(
          os.path.join(sys.path[0], relative_path), 'r', encoding='utf-8') as f:
        return json.load(f)
    except (IOError, ValueError):
      return None
- -"""Communicates with Cloud Debugger backend over HTTP.""" - -from collections import deque -import copy -import hashlib -import inspect -import json -import logging -import os -import sys -import threading -import time -import traceback - - - -import apiclient -from apiclient import discovery # pylint: disable=unused-import -from backoff import Backoff -import httplib2 -import oauth2client -from oauth2client.contrib.gce import AppAssertionCredentials - -import cdbg_native as native -import uniquifier_computer -import googleclouddebugger - -# This module catches all exception. This is safe because it runs in -# a daemon thread (so we are not blocking Ctrl+C). We need to catch all -# the exception because HTTP client is unpredictable as far as every -# exception it can throw. -# pylint: disable=broad-except - -# API scope we are requesting when service account authentication is enabled. -_CLOUD_PLATFORM_SCOPE = 'https://www.googleapis.com/auth/cloud-platform' - -# Base URL for metadata service. Specific attributes are appended to this URL. -_LOCAL_METADATA_SERVICE_PROJECT_URL = ('http://metadata.google.internal/' - 'computeMetadata/v1/project/') - -# Set of all known debuggee labels (passed down as flags). The value of -# a map is optional environment variable that can be used to set the flag -# (flags still take precedence). -_DEBUGGEE_LABELS = { - 'module': 'GAE_MODULE_NAME', - 'version': 'GAE_MODULE_VERSION', - 'minorversion': 'GAE_MINOR_VERSION'} - -# Debuggee labels used to format debuggee description (ordered). The minor -# version is excluded for the sake of consistency with AppEngine UX. -_DESCRIPTION_LABELS = ['projectid', 'module', 'version'] - - -class GcpHubClient(object): - """Controller API client. - - Registers the debuggee, queries the active breakpoints and sends breakpoint - updates to the backend. - - This class supports two types of authentication: metadata service and service - account. 
The mode is selected by calling EnableServiceAccountAuth or - EnableGceAuth method. - - GcpHubClient creates a worker thread that communicates with the backend. The - thread can be stopped with a Stop function, but it is optional since the - worker thread is marked as daemon. - """ - - def __init__(self): - self.on_active_breakpoints_changed = lambda x: None - self.on_idle = lambda: None - self._debuggee_labels = {} - self._service_account_auth = False - self._debuggee_id = None - self._wait_token = 'init' - self._breakpoints = [] - self._main_thread = None - self._transmission_thread = None - self._transmission_thread_startup_lock = threading.Lock() - self._transmission_queue = deque(maxlen=100) - self._new_updates = threading.Event(False) - - # Disable logging in the discovery API to avoid excessive logging. - class _ChildLogFilter(logging.Filter): - """Filter to eliminate info-level logging when called from this module.""" - - def __init__(self, filter_levels=None): - super(_ChildLogFilter, self).__init__() - self._filter_levels = filter_levels or set(logging.INFO) - # Get name without extension to avoid .py vs .pyc issues - self._my_filename = os.path.splitext( - inspect.getmodule(_ChildLogFilter).__file__)[0] - - def filter(self, record): - if record.levelno not in self._filter_levels: - return True - callerframes = inspect.getouterframes(inspect.currentframe()) - for f in callerframes: - if os.path.splitext(f[1])[0] == self._my_filename: - return False - return True - self._log_filter = _ChildLogFilter({logging.INFO}) - discovery.logger.addFilter(self._log_filter) - - # - # Configuration options (constants only modified by unit test) - # - - # Delay before retrying failed request. - self.register_backoff = Backoff() # Register debuggee. - self.list_backoff = Backoff() # Query active breakpoints. - self.update_backoff = Backoff() # Update breakpoint. 
- - # Maximum number of times that the message is re-transmitted before it - # is assumed to be poisonous and discarded - self.max_transmit_attempts = 10 - - def InitializeDebuggeeLabels(self, flags): - """Initialize debuggee labels from environment variables and flags. - - The caller passes all the flags that the the debuglet got. This function - will only use the flags used to label the debuggee. Flags take precedence - over environment variables. - - Debuggee description is formatted from available flags. - - Project ID is not set here. It is obtained from metadata service or - specified as a parameter to EnableServiceAccountAuth. - - Args: - flags: dictionary of debuglet command line flags. - """ - self._debuggee_labels = {} - - for (label, env) in _DEBUGGEE_LABELS.iteritems(): - if env and env in os.environ: - # Special case for GAE_MODULE_NAME. We omit the "default" module - # to stay consistent with AppEngine. - if env == 'GAE_MODULE_NAME' and os.environ[env] == 'default': - continue - self._debuggee_labels[label] = os.environ[env] - - if flags: - self._debuggee_labels.update( - {name: value for (name, value) in flags.iteritems() - if name in _DEBUGGEE_LABELS}) - - self._debuggee_labels['projectid'] = self._project_id() - - def EnableServiceAccountAuth(self, project_id, project_number, - email, p12_file): - """Selects to use the service account authentication. - - Args: - project_id: GCP project ID (e.g. myproject). - project_number: numberic GCP project ID (e.g. 72386324623). - email: service account identifier (...@developer.gserviceaccount.com). - p12_file: path to the file with the private key. - """ - with open(p12_file, 'rb') as f: - self._credentials = oauth2client.client.SignedJwtAssertionCredentials( - email, f.read(), scope=_CLOUD_PLATFORM_SCOPE) - self._project_id = lambda: project_id - self._project_number = lambda: project_number - - def EnableGceAuth(self): - """Selects to use local metadata service for authentication. 
- - The project ID and project number are also retrieved from the metadata - service. It is done lazily from the worker thread. The motivation is to - speed up initialization and be able to recover from failures. - """ - self._credentials = AppAssertionCredentials(_CLOUD_PLATFORM_SCOPE) - self._project_id = lambda: self._QueryGcpProject('project-id') - self._project_number = lambda: self._QueryGcpProject('numeric-project-id') - - def Start(self): - """Starts the worker thread.""" - self._shutdown = False - - self._main_thread = threading.Thread(target=self._MainThreadProc) - self._main_thread.name = 'Cloud Debugger main worker thread' - self._main_thread.daemon = True - self._main_thread.start() - - def Stop(self): - """Signals the worker threads to shut down and waits until it exits.""" - self._shutdown = True - self._new_updates.set() # Wake up the transmission thread. - - if self._main_thread is not None: - self._main_thread.join() - self._main_thread = None - - if self._transmission_thread is not None: - self._transmission_thread.join() - self._transmission_thread = None - - def EnqueueBreakpointUpdate(self, breakpoint): - """Asynchronously updates the specified breakpoint on the backend. - - This function returns immediately. The worker thread is actually doing - all the work. The worker thread is responsible to retry the transmission - in case of transient errors. - - Args: - breakpoint: breakpoint in either final or non-final state. - """ - with self._transmission_thread_startup_lock: - if self._transmission_thread is None: - self._transmission_thread = threading.Thread( - target=self._TransmissionThreadProc) - self._transmission_thread.name = 'Cloud Debugger transmission thread' - self._transmission_thread.daemon = True - self._transmission_thread.start() - - self._transmission_queue.append((breakpoint, 0)) - self._new_updates.set() # Wake up the worker thread to send immediately. 
- - def _BuildService(self): - http = httplib2.Http() - http = self._credentials.authorize(http) - - api = apiclient.discovery.build('clouddebugger', 'v2', http=http) - return api.controller() - - def _MainThreadProc(self): - """Entry point for the worker thread.""" - registration_required = True - while not self._shutdown: - if registration_required: - service = self._BuildService() - registration_required, delay = self._RegisterDebuggee(service) - - if not registration_required: - registration_required, delay = self._ListActiveBreakpoints(service) - - if self.on_idle is not None: - self.on_idle() - - if not self._shutdown: - time.sleep(delay) - - def _TransmissionThreadProc(self): - """Entry point for the transmission worker thread.""" - reconnect = True - - while not self._shutdown: - self._new_updates.clear() - - if reconnect: - service = self._BuildService() - reconnect = False - - reconnect, delay = self._TransmitBreakpointUpdates(service) - - self._new_updates.wait(delay) - - def _RegisterDebuggee(self, service): - """Single attempt to register the debuggee. - - If the registration succeeds, sets self._debuggee_id to the registered - debuggee ID. - - Args: - service: client to use for API calls - - Returns: - (registration_required, delay) tuple - """ - try: - request = {'debuggee': self._GetDebuggee()} - - try: - response = service.debuggees().register(body=request).execute() - - self._debuggee_id = response['debuggee']['id'] - native.LogInfo('Debuggee registered successfully, ID: %s' % ( - self._debuggee_id)) - self.register_backoff.Succeeded() - return (False, 0) # Proceed immediately to list active breakpoints. 
- except BaseException: - native.LogInfo('Failed to register debuggee: %s, %s' % - (request, traceback.format_exc())) - except BaseException: - native.LogWarning('Debuggee information not available: ' + - traceback.format_exc()) - - return (True, self.register_backoff.Failed()) - - def _ListActiveBreakpoints(self, service): - """Single attempt query the list of active breakpoints. - - Must not be called before the debuggee has been registered. If the request - fails, this function resets self._debuggee_id, which triggers repeated - debuggee registration. - - Args: - service: client to use for API calls - - Returns: - (registration_required, delay) tuple - """ - try: - response = service.debuggees().breakpoints().list( - debuggeeId=self._debuggee_id, waitToken=self._wait_token).execute() - breakpoints = response.get('breakpoints') or [] - self._wait_token = response.get('nextWaitToken') - if cmp(self._breakpoints, breakpoints) != 0: - self._breakpoints = breakpoints - native.LogInfo( - 'Breakpoints list changed, %d active, wait token: %s' % ( - len(self._breakpoints), self._wait_token)) - self.on_active_breakpoints_changed(copy.deepcopy(self._breakpoints)) - except Exception as e: - if not isinstance(e, apiclient.errors.HttpError) or e.resp.status != 409: - native.LogInfo('Failed to query active breakpoints: ' + - traceback.format_exc()) - - # Forget debuggee ID to trigger repeated debuggee registration. Once the - # registration succeeds, the worker thread will retry this query - self._debuggee_id = None - - return (True, self.list_backoff.Failed()) - - self.list_backoff.Succeeded() - return (False, 0) - - def _TransmitBreakpointUpdates(self, service): - """Tries to send pending breakpoint updates to the backend. - - Sends all the pending breakpoint updates. In case of transient failures, - the breakpoint is inserted back to the top of the queue. Application - failures are not retried (for example updating breakpoint in a final - state). 
- - Each pending breakpoint maintains a retry counter. After repeated transient - failures the breakpoint is discarded and dropped from the queue. - - Args: - service: client to use for API calls - - Returns: - (reconnect, timeout) tuple. The first element ("reconnect") is set to - true on unexpected HTTP responses. The caller should discard the HTTP - connection and create a new one. The second element ("timeout") is - set to None if all pending breakpoints were sent successfully. Otherwise - returns time interval in seconds to stall before retrying. - """ - reconnect = False - retry_list = [] - - # There is only one consumer, so two step pop is safe. - while self._transmission_queue: - breakpoint, retry_count = self._transmission_queue.popleft() - - try: - service.debuggees().breakpoints().update( - debuggeeId=self._debuggee_id, id=breakpoint['id'], - body={'breakpoint': breakpoint}).execute() - - native.LogInfo('Breakpoint %s update transmitted successfully' % ( - breakpoint['id'])) - except apiclient.errors.HttpError as err: - # Treat 400 error codes (except timeout) as application error that will - # not be retried. All other errors are assumed to be transient. - status = err.resp.status - is_transient = ((status >= 500) or (status == 408)) - if is_transient and retry_count < self.max_transmit_attempts - 1: - native.LogInfo('Failed to send breakpoint %s update: %s' % ( - breakpoint['id'], traceback.format_exc())) - retry_list.append((breakpoint, retry_count + 1)) - elif is_transient: - native.LogWarning( - 'Breakpoint %s retry count exceeded maximum' % breakpoint['id']) - else: - # This is very common if multiple instances are sending final update - # simultaneously. 
- native.LogInfo('%s, breakpoint: %s' % (err, breakpoint['id'])) - except Exception: - native.LogWarning( - 'Fatal error sending breakpoint %s update: %s' % ( - breakpoint['id'], traceback.format_exc())) - reconnect = True - - self._transmission_queue.extend(retry_list) - - if not self._transmission_queue: - self.update_backoff.Succeeded() - # Nothing to send, wait until next breakpoint update. - return (reconnect, None) - else: - return (reconnect, self.update_backoff.Failed()) - - def _QueryGcpProject(self, resource): - """Queries project resource on a local metadata service.""" - url = _LOCAL_METADATA_SERVICE_PROJECT_URL + resource - http = httplib2.Http() - response, content = http.request( - url, headers={'Metadata-Flavor': 'Google'}) - if response['status'] != '200': - raise RuntimeError( - 'HTTP error %s %s when querying local metadata service at %s' % - (response['status'], content, url)) - - return content - - def _GetDebuggee(self): - """Builds the debuggee structure.""" - version = googleclouddebugger.__version__ - major_version = version.split('.')[0] - - debuggee = { - 'project': self._project_number(), - 'description': self._GetDebuggeeDescription(), - 'labels': self._debuggee_labels, - 'agentVersion': 'google.com/python2.7-' + major_version - } - - source_context = self._ReadAppJsonFile('source-context.json') - if source_context: - debuggee['sourceContexts'] = [source_context] - - source_contexts = self._ReadAppJsonFile('source-contexts.json') - if source_contexts: - debuggee['extSourceContexts'] = source_contexts - elif source_context: - debuggee['extSourceContexts'] = [{'context': source_context}] - - debuggee['uniquifier'] = self._ComputeUniquifier(debuggee) - - return debuggee - - def _GetDebuggeeDescription(self): - """Formats debuggee description based on debuggee labels.""" - return '-'.join(self._debuggee_labels[label] - for label in _DESCRIPTION_LABELS - if label in self._debuggee_labels) - - def _ComputeUniquifier(self, debuggee): - 
"""Computes debuggee uniquifier. - - The debuggee uniquifier has to be identical on all instances. Therefore the - uniquifier should not include any random numbers and should only be based - on inputs that are guaranteed to be the same on all instances. - - Args: - debuggee: complete debuggee message without the uniquifier - - Returns: - Hex string of SHA1 hash of project information, debuggee labels and - debuglet version. - """ - uniquifier = hashlib.sha1() - - # Project information. - uniquifier.update(self._project_id()) - uniquifier.update(self._project_number()) - - # Debuggee information. - uniquifier.update(str(debuggee)) - - # Compute hash of application files if we don't have source context. This - # way we can still distinguish between different deployments. - if ('minorversion' not in debuggee.get('labels', []) and - 'sourceContexts' not in debuggee and - 'extSourceContexts' not in debuggee): - uniquifier_computer.ComputeApplicationUniquifier(uniquifier) - - return uniquifier.hexdigest() - - def _ReadAppJsonFile(self, relative_path): - """Reads JSON file from an application directory. - - Args: - relative_path: file name relative to application root directory. - - Returns: - Parsed JSON data or None if the file does not exist, can't be read or - not a valid JSON file. - """ - try: - with open(os.path.join(sys.path[0], relative_path), 'r') as f: - return json.load(f) - except (IOError, ValueError): - return None diff --git a/src/googleclouddebugger/glob_data_visibility_policy.py b/src/googleclouddebugger/glob_data_visibility_policy.py new file mode 100644 index 0000000..275e69a --- /dev/null +++ b/src/googleclouddebugger/glob_data_visibility_policy.py @@ -0,0 +1,86 @@ +# Copyright 2017 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Determines the visibility of python data and symbols. + +Example Usage: + + blacklist_patterns = ( + 'com.private.*' + 'com.foo.bar' + ) + whitelist_patterns = ( + 'com.*' + ) + policy = GlobDataVisibilityPolicy(blacklist_patterns, whitelist_patterns) + + policy.IsDataVisible('org.foo.bar') -> (False, 'not whitelisted by config') + policy.IsDataVisible('com.foo.bar') -> (False, 'blacklisted by config') + policy.IsDataVisible('com.private.foo') -> (False, 'blacklisted by config') + policy.IsDataVisible('com.foo') -> (True, 'visible') +""" + +import fnmatch + +# Possible visibility responses +RESPONSES = { + 'UNKNOWN_TYPE': 'could not determine type', + 'BLACKLISTED': 'blacklisted by config', + 'NOT_WHITELISTED': 'not whitelisted by config', + 'VISIBLE': 'visible', +} + + +class GlobDataVisibilityPolicy(object): + """Policy provides visibility policy details to the caller.""" + + def __init__(self, blacklist_patterns, whitelist_patterns): + self.blacklist_patterns = blacklist_patterns + self.whitelist_patterns = whitelist_patterns + + def IsDataVisible(self, path): + """Returns a tuple (visible, reason) stating if the data should be visible. + + Args: + path: A dot separated path that represents a package, class, method or + variable. The format is identical to pythons "import" statement. + + Returns: + (visible, reason) where visible is a boolean that is True if the data + should be visible. Reason is a string reason that can be displayed + to the user and indicates why data is visible or not visible. 
+ """ + if path is None: + return (False, RESPONSES['UNKNOWN_TYPE']) + + if _Matches(path, self.blacklist_patterns): + return (False, RESPONSES['BLACKLISTED']) + + if not _Matches(path, self.whitelist_patterns): + return (False, RESPONSES['NOT_WHITELISTED']) + + return (True, RESPONSES['VISIBLE']) + + +def _Matches(path, pattern_list): + """Returns true if path matches any patten found in pattern_list. + + Args: + path: A dot separated path to a package, class, method or variable + pattern_list: A list of wildcard patterns + + Returns: + True if path matches any wildcard found in pattern_list. + """ + # Note: This code does not scale to large pattern_list sizes. + return any(fnmatch.fnmatchcase(path, pattern) for pattern in pattern_list) diff --git a/src/googleclouddebugger/immutability_tracer.cc b/src/googleclouddebugger/immutability_tracer.cc index 22d1d1f..c05d407 100644 --- a/src/googleclouddebugger/immutability_tracer.cc +++ b/src/googleclouddebugger/immutability_tracer.cc @@ -19,12 +19,12 @@ #include "immutability_tracer.h" +#include + #include "python_util.h" -DEFINE_int32( - max_expression_lines, - 10000, - "maximum number of Python lines to allow in a single expression"); +ABSL_FLAG(int32, max_expression_lines, 10000, + "maximum number of Python lines to allow in a single expression"); namespace devtools { namespace cdbg { @@ -192,7 +192,7 @@ int ImmutabilityTracer::OnTraceCallbackInternal( break; } - if (line_count_ > FLAGS_max_expression_lines) { + if (line_count_ > absl::GetFlag(FLAGS_max_expression_lines)) { LOG(INFO) << "Expression evaluation exceeded quota"; mutable_code_detected_ = true; } @@ -251,9 +251,9 @@ void ImmutabilityTracer::VerifyCodeObject(ScopedPyCodeObject code_object) { void ImmutabilityTracer::ProcessCodeLine( PyCodeObject* code_object, int line_number) { - int size = PyString_Size(code_object->co_code); - const uint8* opcodes = - reinterpret_cast(PyString_AsString(code_object->co_code)); + int size = 
PyBytes_Size(code_object->co_code); + const uint8_t* opcodes = + reinterpret_cast(PyBytes_AsString(code_object->co_code)); DCHECK(opcodes != nullptr); @@ -263,6 +263,7 @@ void ImmutabilityTracer::ProcessCodeLine( do { if (start_offset != -1) { ProcessCodeRange( + opcodes, opcodes + start_offset, enumerator.offset() - start_offset); start_offset = -1; @@ -274,199 +275,267 @@ void ImmutabilityTracer::ProcessCodeLine( } while (enumerator.Next()); if (start_offset != -1) { - ProcessCodeRange(opcodes + start_offset, size - start_offset); + ProcessCodeRange(opcodes, opcodes + start_offset, size - start_offset); } } +enum OpcodeMutableStatus { + OPCODE_MUTABLE, + OPCODE_NOT_MUTABLE, + OPCODE_MAYBE_MUTABLE +}; + +static OpcodeMutableStatus IsOpcodeMutable(const uint8_t opcode) { + // Notes: + // * We allow changing local variables (i.e. STORE_FAST). Expression + // evaluation doesn't let changing local variables of the top frame + // because we use "Py_eval_input" when compiling the expression. Methods + // invoked by an expression can freely change local variables as it + // doesn't change the state of the program once the method exits. + // * We let opcodes calling methods like "PyObject_Repr". These will either + // be completely executed inside Python interpreter (with no side + // effects), or call object method (e.g. "__repr__"). In this case the + // tracer will kick in and will verify that the method has no side + // effects. 
+ switch (opcode) { + case POP_TOP: + case ROT_TWO: + case ROT_THREE: + case DUP_TOP: + case NOP: + case UNARY_POSITIVE: + case UNARY_NEGATIVE: + case UNARY_INVERT: + case BINARY_POWER: + case BINARY_MULTIPLY: + case BINARY_MODULO: + case BINARY_ADD: + case BINARY_SUBTRACT: + case BINARY_SUBSCR: + case BINARY_FLOOR_DIVIDE: + case BINARY_TRUE_DIVIDE: + case INPLACE_FLOOR_DIVIDE: + case INPLACE_TRUE_DIVIDE: + case INPLACE_ADD: + case INPLACE_SUBTRACT: + case INPLACE_MULTIPLY: + case INPLACE_MODULO: + case BINARY_LSHIFT: + case BINARY_RSHIFT: + case BINARY_AND: + case BINARY_XOR: + case INPLACE_POWER: + case GET_ITER: + case INPLACE_LSHIFT: + case INPLACE_RSHIFT: + case INPLACE_AND: + case INPLACE_XOR: + case INPLACE_OR: + case RETURN_VALUE: + case YIELD_VALUE: + case POP_BLOCK: + case UNPACK_SEQUENCE: + case FOR_ITER: + case LOAD_CONST: + case LOAD_NAME: + case BUILD_TUPLE: + case BUILD_LIST: + case BUILD_SET: + case BUILD_MAP: + case LOAD_ATTR: + case COMPARE_OP: + case JUMP_FORWARD: + case JUMP_IF_FALSE_OR_POP: + case JUMP_IF_TRUE_OR_POP: + case POP_JUMP_IF_TRUE: + case POP_JUMP_IF_FALSE: + case LOAD_GLOBAL: + case LOAD_FAST: + case STORE_FAST: + case DELETE_FAST: + case CALL_FUNCTION: + case MAKE_FUNCTION: + case BUILD_SLICE: + case LOAD_DEREF: + case CALL_FUNCTION_KW: + case EXTENDED_ARG: +#if PY_VERSION_HEX < 0x03080000 + // These were all removed in Python 3.8. + case BREAK_LOOP: + case CONTINUE_LOOP: + case SETUP_LOOP: +#endif + case DUP_TOP_TWO: + case BINARY_MATRIX_MULTIPLY: + case INPLACE_MATRIX_MULTIPLY: + case GET_YIELD_FROM_ITER: + case YIELD_FROM: + case UNPACK_EX: + case CALL_FUNCTION_EX: + case LOAD_CLASSDEREF: +#if PY_VERSION_HEX < 0x03090000 + // Removed in Python 3.9. + case BUILD_LIST_UNPACK: + case BUILD_MAP_UNPACK: + case BUILD_MAP_UNPACK_WITH_CALL: + case BUILD_TUPLE_UNPACK: + case BUILD_TUPLE_UNPACK_WITH_CALL: + case BUILD_SET_UNPACK: +#endif +#if PY_VERSION_HEX > 0x03090000 + // Added in Python 3.9. 
+ case LIST_TO_TUPLE: + case IS_OP: + case CONTAINS_OP: + case JUMP_IF_NOT_EXC_MATCH: +#endif + case FORMAT_VALUE: + case BUILD_CONST_KEY_MAP: + case BUILD_STRING: +#if PY_VERSION_HEX >= 0x03070000 + // Added in Python 3.7. + case LOAD_METHOD: + case CALL_METHOD: +#endif +#if PY_VERSION_HEX >= 0x03080000 + // Added back in Python 3.8 (was in 2.7 as well) + case ROT_FOUR: +#endif +#if PY_VERSION_HEX >= 0x030A0000 + // Added in Python 3.10 + case COPY_DICT_WITHOUT_KEYS: + case GET_LEN: + case MATCH_MAPPING: + case MATCH_SEQUENCE: + case MATCH_KEYS: + case MATCH_CLASS: + case ROT_N: +#endif + return OPCODE_NOT_MUTABLE; + + case PRINT_EXPR: + case STORE_GLOBAL: + case DELETE_GLOBAL: + case IMPORT_STAR: + case IMPORT_NAME: + case IMPORT_FROM: + case SETUP_FINALLY: + // TODO: allow changing fields of locally created objects/lists. + case STORE_SUBSCR: + case DELETE_SUBSCR: + case STORE_NAME: + case DELETE_NAME: + case STORE_ATTR: + case DELETE_ATTR: + case LIST_APPEND: + case SET_ADD: + case MAP_ADD: + case STORE_DEREF: + // TODO: allow exception handling + case RAISE_VARARGS: + case SETUP_WITH: + // TODO: allow closures + case LOAD_CLOSURE: +#if PY_VERSION_HEX < 0x03080000 + // Removed in Python 3.8. + case SETUP_EXCEPT: +#endif + case GET_AITER: + case GET_ANEXT: + case BEFORE_ASYNC_WITH: + case LOAD_BUILD_CLASS: + case GET_AWAITABLE: +#if PY_VERSION_HEX < 0x03090000 + // Removed in 3.9. + case WITH_CLEANUP_START: + case WITH_CLEANUP_FINISH: + case END_FINALLY: +#endif + case SETUP_ANNOTATIONS: + case POP_EXCEPT: +#if PY_VERSION_HEX < 0x03070000 + // Removed in Python 3.7. + case STORE_ANNOTATION: +#endif + case DELETE_DEREF: + case SETUP_ASYNC_WITH: +#if PY_VERSION_HEX >= 0x03080000 + // Added in Python 3.8. + case END_ASYNC_FOR: +#endif +#if PY_VERSION_HEX >= 0x03080000 && PY_VERSION_HEX < 0x03090000 + // Added in Python 3.8 and removed in 3.9 + case BEGIN_FINALLY: + case CALL_FINALLY: + case POP_FINALLY: +#endif +#if PY_VERSION_HEX >= 0x03090000 + // Added in 3.9. 
+ case DICT_MERGE: + case DICT_UPDATE: + case LIST_EXTEND: + case SET_UPDATE: + case RERAISE: + case WITH_EXCEPT_START: + case LOAD_ASSERTION_ERROR: +#endif +#if PY_VERSION_HEX >= 0x030A0000 + // Added in Python 3.10 + case GEN_START: +#endif + return OPCODE_MUTABLE; + + default: + return OPCODE_MAYBE_MUTABLE; + } +} -void ImmutabilityTracer::ProcessCodeRange(const uint8* opcodes, int size) { - const uint8* end = opcodes + size; +void ImmutabilityTracer::ProcessCodeRange(const uint8_t* code_start, + const uint8_t* opcodes, int size) { + const uint8_t* end = opcodes + size; while (opcodes < end) { // Read opcode. - const uint8 opcode = *opcodes; - ++opcodes; - - if (HAS_ARG(opcode)) { - DCHECK_LE(opcodes + 2, end); - opcodes += 2; - - // Opcode argument is: - // (static_cast(opcodes[1]) << 8) | opcodes[0]; - // and can extend to 32 bit if EXTENDED_ARG is used. - } - - // Notes: - // * We allow changing local variables (i.e. STORE_FAST). Expression - // evaluation doesn't let changing local variables of the top frame - // because we use "Py_eval_input" when compiling the expression. Methods - // invoked by an expression can freely change local variables as it - // doesn't change the state of the program once the method exits. - // * We let opcodes calling methods like "PyObject_Repr". These will either - // be completely executed inside Python interpreter (with no side - // effects), or call object method (e.g. "__repr__"). In this case the - // tracer will kick in and will verify that the method has no side - // effects. 
- switch (opcode) { - case NOP: - case LOAD_FAST: - case LOAD_CONST: - case STORE_FAST: - case POP_TOP: - case ROT_TWO: - case ROT_THREE: - case ROT_FOUR: - case DUP_TOP: - case DUP_TOPX: - case UNARY_POSITIVE: - case UNARY_NEGATIVE: - case UNARY_NOT: - case UNARY_CONVERT: - case UNARY_INVERT: - case BINARY_POWER: - case BINARY_MULTIPLY: - case BINARY_DIVIDE: - case BINARY_TRUE_DIVIDE: - case BINARY_FLOOR_DIVIDE: - case BINARY_MODULO: - case BINARY_ADD: - case BINARY_SUBTRACT: - case BINARY_SUBSCR: - case BINARY_LSHIFT: - case BINARY_RSHIFT: - case BINARY_AND: - case BINARY_XOR: - case BINARY_OR: - case INPLACE_POWER: - case INPLACE_MULTIPLY: - case INPLACE_DIVIDE: - case INPLACE_TRUE_DIVIDE: - case INPLACE_FLOOR_DIVIDE: - case INPLACE_MODULO: - case INPLACE_ADD: - case INPLACE_SUBTRACT: - case INPLACE_LSHIFT: - case INPLACE_RSHIFT: - case INPLACE_AND: - case INPLACE_XOR: - case INPLACE_OR: - case SLICE+0: - case SLICE+1: - case SLICE+2: - case SLICE+3: - case LOAD_LOCALS: - case RETURN_VALUE: - case YIELD_VALUE: - case EXEC_STMT: - case UNPACK_SEQUENCE: - case LOAD_NAME: - case LOAD_GLOBAL: - case DELETE_FAST: - case LOAD_DEREF: - case BUILD_TUPLE: - case BUILD_LIST: - case BUILD_SET: - case BUILD_MAP: - case LOAD_ATTR: - case COMPARE_OP: - case JUMP_FORWARD: - case POP_JUMP_IF_FALSE: - case POP_JUMP_IF_TRUE: - case JUMP_IF_FALSE_OR_POP: - case JUMP_IF_TRUE_OR_POP: - case JUMP_ABSOLUTE: - case GET_ITER: - case FOR_ITER: - case BREAK_LOOP: - case CONTINUE_LOOP: - case SETUP_LOOP: - case CALL_FUNCTION: - case CALL_FUNCTION_VAR: - case CALL_FUNCTION_KW: - case CALL_FUNCTION_VAR_KW: - case MAKE_FUNCTION: - case MAKE_CLOSURE: - case BUILD_SLICE: - case POP_BLOCK: - break; - - case EXTENDED_ARG: - // Go to the next opcode. The argument is going to be incorrect, - // but we don't really care. + const uint8_t opcode = *opcodes; + switch (IsOpcodeMutable(opcode)) { + case OPCODE_NOT_MUTABLE: + // We don't worry about the sizes of instructions with EXTENDED_ARG. 
+ // The argument does not really matter and so EXTENDED_ARGs can be + // treated as just another instruction with an opcode. + opcodes += 2; + DCHECK_LE(opcodes, end); break; - // TODO(vlif): allow changing fields of locally created objects/lists. - case LIST_APPEND: - case SET_ADD: - case STORE_SLICE+0: - case STORE_SLICE+1: - case STORE_SLICE+2: - case STORE_SLICE+3: - case DELETE_SLICE+0: - case DELETE_SLICE+1: - case DELETE_SLICE+2: - case DELETE_SLICE+3: - case STORE_SUBSCR: - case DELETE_SUBSCR: - case STORE_NAME: - case DELETE_NAME: - case STORE_ATTR: - case DELETE_ATTR: - case STORE_DEREF: - case STORE_MAP: - case MAP_ADD: - mutable_code_detected_ = true; - return; - - case STORE_GLOBAL: - case DELETE_GLOBAL: + case OPCODE_MAYBE_MUTABLE: + if (opcode == JUMP_ABSOLUTE) { + // Check for a jump to itself, which happens in "while True: pass". + // The tracer won't call our tracing function unless there is a jump + // backwards, or we reached a new line. In this case neither of those + // ever happens, so we can't rely on our tracing function to detect + // infinite loops. + // In this case EXTENDED_ARG doesn't matter either because if this + // instruction had one it would jump backwards and be caught tracing. + if (opcodes - code_start == opcodes[1]) { + mutable_code_detected_ = true; + return; + } + opcodes += 2; + DCHECK_LE(opcodes, end); + break; + } + LOG(WARNING) << "Unknown opcode " << static_cast(opcode); mutable_code_detected_ = true; return; - case PRINT_EXPR: - case PRINT_ITEM_TO: - case PRINT_ITEM: - case PRINT_NEWLINE_TO: - case PRINT_NEWLINE: - mutable_code_detected_ = true; - return; - - case BUILD_CLASS: - mutable_code_detected_ = true; - return; - - case IMPORT_NAME: - case IMPORT_STAR: - case IMPORT_FROM: - case SETUP_EXCEPT: - case SETUP_FINALLY: - case WITH_CLEANUP: - mutable_code_detected_ = true; - return; - - // TODO(vlif): allow exception handling. 
- case RAISE_VARARGS: - case END_FINALLY: - case SETUP_WITH: - mutable_code_detected_ = true; - return; - - // TODO(vlif): allow closures. - case LOAD_CLOSURE: - mutable_code_detected_ = true; - return; - - default: - LOG(WARNING) << "Unknown opcode " << static_cast(opcode); + case OPCODE_MUTABLE: mutable_code_detected_ = true; return; } } } - void ImmutabilityTracer::ProcessCCall(PyObject* function) { if (PyCFunction_Check(function)) { - // TODO(vlif): the application code can define its own "str" function + // TODO: the application code can define its own "str" function // that will do some evil things. Application can also override builtin // "str" method. If we want to protect against it, we should load pointers // to native functions when debugger initializes (which happens before @@ -477,7 +546,7 @@ void ImmutabilityTracer::ProcessCCall(PyObject* function) { auto c_function = reinterpret_cast(function); const char* name = c_function->m_ml->ml_name; - for (uint32 i = 0; i < arraysize(kWhitelistedCFunctions); ++i) { + for (uint32_t i = 0; i < arraysize(kWhitelistedCFunctions); ++i) { if (!strcmp(name, kWhitelistedCFunctions[i])) { return; } @@ -496,7 +565,7 @@ void ImmutabilityTracer::ProcessCCall(PyObject* function) { void ImmutabilityTracer::SetMutableCodeException() { - // TODO(vlif): use custom type for this exception. This way we can provide + // TODO: use custom type for this exception. This way we can provide // a more detailed error message. 
PyErr_SetString( PyExc_SystemError, @@ -505,4 +574,3 @@ void ImmutabilityTracer::SetMutableCodeException() { } // namespace cdbg } // namespace devtools - diff --git a/src/googleclouddebugger/immutability_tracer.h b/src/googleclouddebugger/immutability_tracer.h index 0035d94..e0cbd4d 100644 --- a/src/googleclouddebugger/immutability_tracer.h +++ b/src/googleclouddebugger/immutability_tracer.h @@ -17,7 +17,9 @@ #ifndef DEVTOOLS_CDBG_DEBUGLETS_PYTHON_IMMUTABILITY_TRACER_H_ #define DEVTOOLS_CDBG_DEBUGLETS_PYTHON_IMMUTABILITY_TRACER_H_ +#include #include + #include "common.h" #include "python_util.h" @@ -55,7 +57,7 @@ class ImmutabilityTracer { // Gets the number of lines executed while the tracer was enabled. Native // functions calls are counted as a single line. - int32 GetLineCount() const { return line_count_; } + int32_t GetLineCount() const { return line_count_; } private: // Python tracer callback function. @@ -78,7 +80,8 @@ class ImmutabilityTracer { void ProcessCodeLine(PyCodeObject* code_object, int line_number); // Verifies immutability of block of opcodes. - void ProcessCodeRange(const uint8* opcodes, int size); + void ProcessCodeRange(const uint8_t* code_start, const uint8_t* opcodes, + int size); // Verifies that the called C function is whitelisted. void ProcessCCall(PyObject* function); @@ -105,11 +108,11 @@ class ImmutabilityTracer { // Original value of PyThreadState::tracing. We revert it to 0 to enforce // trace callback on this thread, even if the whole thing was executed from // within another trace callback (that caught the breakpoint). - int32 original_thread_state_tracing_; + int32_t original_thread_state_tracing_; // Counts the number of lines executed while the tracer was enabled. Native // functions calls are counted as a single line. - int32 line_count_; + int32_t line_count_; // Set to true after immutable statement is detected. When it happens we // want to stop execution of the entire construct entirely. 
@@ -142,7 +145,7 @@ class ScopedImmutabilityTracer { // Gets the number of lines executed while the tracer was enabled. Native // functions calls are counted as a single line. - int32 GetLineCount() const { return Instance()->GetLineCount(); } + int32_t GetLineCount() const { return Instance()->GetLineCount(); } private: ImmutabilityTracer* Instance() { diff --git a/src/googleclouddebugger/imphook.py b/src/googleclouddebugger/imphook.py new file mode 100644 index 0000000..2e80648 --- /dev/null +++ b/src/googleclouddebugger/imphook.py @@ -0,0 +1,436 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Support for breakpoints on modules that haven't been loaded yet. + +This is the module import hook which: + 1. Takes a partial path of the module file excluding the file extension as + input (can be as short as 'foo' or longer such as 'sys/path/pkg/foo'). + 2. At each (top-level-only) import statement: + a. Generates an estimate of the modules that might be loaded as a result + of this import (and all chained imports) using the arguments of the + import hook. The estimate is best-effort, it may contain extra entries + that are not of interest to us (e.g., outer packages that were already + loaded before this import), or may be missing some module names (not + all intricacies of Python module importer are handled). + b. 
      Checks sys.modules if any of these modules have a file that matches the
      given path, using suffix match.

"""

import importlib
import itertools
import os
import sys  # Must be imported, otherwise import hooks don't work.
import threading

import builtins

from . import module_utils

# Callbacks to invoke when a module is imported.
# Maps module path suffix -> set of callables registered for that suffix.
_import_callbacks = {}
_import_callbacks_lock = threading.Lock()

# Per thread data holding information about the import call nest level.
_import_local = threading.local()

# Original __import__ function if import hook is installed or None otherwise.
_real_import = None

# Original importlib.import_module function if import hook is installed or None
# otherwise.
_real_import_module = None


def AddImportCallbackBySuffix(path, callback):
  """Register import hook.

  This function overrides the default import process. Then whenever a module
  whose suffix matches path is imported, the callback will be invoked.

  A module may be imported multiple times. Import event only means that the
  Python code contained an "import" statement. The actual loading and
  initialization of a new module normally happens only once, at which time
  the callback will be invoked. This function does not validate the
  existence of such a module; that is the responsibility of the caller.

  TODO: handle module reload.

  Args:
    path: python module file path. It may be missing the directories for the
        outer packages, and therefore, requires suffix comparison to match
        against loaded modules. If it contains all outer packages, it may
        contain the sys.path as well.
        It might contain an incorrect file extension (e.g., py vs. pyc).
    callback: callable to invoke upon module load.

  Returns:
    Function object to invoke to remove the installed callback.
  """

  def RemoveCallback():
    # This is a read-if-del operation on _import_callbacks. Lock to prevent
    # callbacks from being inserted just before the key is deleted. Thus, it
    # must be locked also when inserting a new entry below. On the other hand
    # read only access, in the import hook, does not require a lock.
    with _import_callbacks_lock:
      callbacks = _import_callbacks.get(path)
      if callbacks:
        callbacks.remove(callback)
        if not callbacks:
          del _import_callbacks[path]

  with _import_callbacks_lock:
    _import_callbacks.setdefault(path, set()).add(callback)
    _InstallImportHookBySuffix()

  return RemoveCallback


def _InstallImportHookBySuffix():
  """Lazily installs import hook."""
  global _real_import

  if _real_import:
    return  # Import hook already installed

  _real_import = getattr(builtins, '__import__')
  assert _real_import
  builtins.__import__ = _ImportHookBySuffix

  # importlib.import_module and __import__ are separate in Python 3 so both
  # need to be overwritten.
  global _real_import_module
  _real_import_module = importlib.import_module
  assert _real_import_module
  importlib.import_module = _ImportModuleHookBySuffix


def _IncrementNestLevel():
  """Increments the per thread nest level of imports."""
  # This is the top call to import (no nesting), init the per-thread nest
  # level and names set.
  if getattr(_import_local, 'nest_level', None) is None:
    _import_local.nest_level = 0

  if _import_local.nest_level == 0:
    # Re-initialize names set at each top-level import to prevent any
    # accidental unforeseen memory leak.
    _import_local.names = set()

  _import_local.nest_level += 1


# pylint: disable=redefined-builtin
def _ProcessImportBySuffix(name, fromlist, globals):
  """Processes an import.

  Calculates the possible names generated from an import and invokes
  registered callbacks if needed.

  Args:
    name: Argument as passed to the importer.
    fromlist: Argument as passed to the importer.
    globals: Argument as passed to the importer.
  """
  # Mirrors the increment done in _IncrementNestLevel; this function runs on
  # the way out of each (possibly nested) import call.
  _import_local.nest_level -= 1

  # To improve common code path performance, compute the loaded modules only
  # if there are any import callbacks.
  if _import_callbacks:
    # Collect the names of all modules that might be newly loaded as a result
    # of this import. Add them in a thread-local list.
    _import_local.names |= _GenerateNames(name, fromlist, globals)

    # Invoke the callbacks only on the top-level import call.
    if _import_local.nest_level == 0:
      _InvokeImportCallbackBySuffix(_import_local.names)

  # To be safe, we clear the names set every time we exit a top level import.
  if _import_local.nest_level == 0:
    _import_local.names.clear()


# pylint: disable=redefined-builtin, g-doc-args, g-doc-return-or-yield
def _ImportHookBySuffix(name,
                        globals=None,
                        locals=None,
                        fromlist=None,
                        level=None):
  """Callback when an import statement is executed by the Python interpreter.

  Argument names have to exactly match those of __import__. Otherwise calls
  to __import__ that use keyword syntax will fail: __import('a', fromlist=[]).
  """
  _IncrementNestLevel()

  if level is None:
    # A level of 0 means absolute import, positive values means relative
    # imports.
    # https://docs.python.org/3/library/functions.html#__import__
    level = 0

  try:
    # Really import modules.
    module = _real_import(name, globals, locals, fromlist, level)
  finally:
    # This _real_import call may raise an exception (e.g., ImportError).
    # However, there might be several modules already loaded before the
    # exception was raised. For instance:
    #   a.py
    #     import b  # success
    #     import c  # ImportError exception.
    # In this case, an 'import a' statement would have the side effect of
    # importing module 'b'. This should trigger the import hooks for module
    # 'b'. To achieve this, we always search/invoke import callbacks (i.e.,
    # even when an exception is raised).
    #
    # Important Note: Do not use 'return' inside the finally block. It will
    # cause any pending exception to be discarded.
    _ProcessImportBySuffix(name, fromlist, globals)

  return module


def _ResolveRelativeImport(name, package):
  """Resolves a relative import into an absolute path.

  This is mostly an adapted version of the logic found in the backported
  version of import_module in Python 2.7.
  https://github.com/python/cpython/blob/2.7/Lib/importlib/__init__.py

  Args:
    name: relative name imported, such as '.a' or '..b.c'
    package: absolute package path, such as 'a.b.c.d.e'

  Returns:
    The absolute path of the name to be imported, or None if it is invalid.
    Examples:
      _ResolveRelativeImport('.c', 'a.b') -> 'a.b.c'
      _ResolveRelativeImport('..c', 'a.b') -> 'a.c'
      _ResolveRelativeImport('...c', 'a.c') -> None
  """
  # Number of leading dots: 1 dot means current package, each extra dot goes
  # one package level up.
  level = sum(1 for c in itertools.takewhile(lambda c: c == '.', name))
  if level == 1:
    return package + name
  else:
    parts = package.split('.')[:-(level - 1)]
    if not parts:
      return None
    parts.append(name[level:])
    return '.'.join(parts)


def _ImportModuleHookBySuffix(name, package=None):
  """Callback when a module is imported through importlib.import_module."""
  _IncrementNestLevel()

  try:
    # Really import modules.
    module = _real_import_module(name, package)
  finally:
    # See the note in _ImportHookBySuffix on why callbacks run (and 'return'
    # is avoided) even when the real import raises.
    if name.startswith('.'):
      if package:
        name = _ResolveRelativeImport(name, package)
      else:
        # Should not happen. Relative imports require the package argument.
        name = None
    if name:
      _ProcessImportBySuffix(name, None, None)

  return module


def _GenerateNames(name, fromlist, globals):
  """Generates the names of modules that might be loaded via this import.

  Args:
    name: Argument as passed to the importer.
    fromlist: Argument as passed to the importer.
    globals: Argument as passed to the importer.

  Returns:
    A set that contains the names of all modules that are loaded by the
    currently executing import statement, as they would show up in sys.modules.
+ The returned set may contain module names that were already loaded before + the execution of this import statement. + The returned set may contain names that are not real modules. + """ + + def GetCurrentPackage(globals): + """Finds the name of the package for the currently executing module.""" + if not globals: + return None + + # Get the name of the module/package that the current import is being + # executed in. + current = globals.get('__name__') + if not current: + return None + + # Check if the current module is really a module, or a package. + current_file = globals.get('__file__') + if not current_file: + return None + + root = os.path.splitext(os.path.basename(current_file))[0] + if root == '__init__': + # The current import happened from a package. Return the package. + return current + else: + # The current import happened from a module. Return the package that + # contains the module. + return current.rpartition('.')[0] + + # A Python module can be addressed in two ways: + # 1. Using a path relative to the currently executing module's path. For + # instance, module p1/p2/m3.py imports p1/p2/p3/m4.py using 'import p3.m4'. + # 2. Using a path relative to sys.path. For instance, module p1/p2/m3.py + # imports p1/p2/p3/m4.py using 'import p1.p2.p3.m4'. + # + # The Python importer uses the 'globals' argument to identify the module that + # the current import is being performed in. The actual logic is very + # complicated, and we only approximate it here to limit the performance + # overhead (See import.c in the interpreter for details). Here, we only use + # the value of the globals['__name__'] for this purpose. + # + # Note: The Python importer prioritizes the current package over sys.path. For + # instance, if 'p1.p2.m3' imports 'm4', then 'p1.p2.m4' is a better match than + # the top level 'm4'. However, the debugger does not have to implement this, + # because breakpoint paths are not described relative to some other file. 
They + # are always assumed to be relative to the sys.path directories. If the user + # sets breakpoint inside 'm4.py', then we can map it to either the top level + # 'm4' or 'p1.p2.m4', i.e., both are valid matches. + curpkg = GetCurrentPackage(globals) + + names = set() + + # A Python module can be imported using two syntaxes: + # 1. import p1.p2.m3 + # 2. from p1.p2 import m3 + # + # When the regular 'import p1.p2.m3' syntax is used, the name of the module + # being imported is passed in the 'name' argument (e.g., name='p1.p2.m3', + # fromlist=None). + # + # When the from-import syntax is used, then fromlist contains the leaf names + # of the modules, and name contains the containing package. For instance, if + # name='a.b', fromlist=['c', 'd'], then we add ['a.b.c', 'a.b.d']. + # + # Corner cases: + # 1. The fromlist syntax can be used to import a function from a module. + # For instance, 'from p1.p2.m3 import func'. + # 2. Sometimes, the importer is passed a dummy fromlist=['__doc__'] (see + # import.c in the interpreter for details). + # Due to these corner cases, the returned set may contain entries that are not + # names of real modules. + for from_entry in fromlist or []: + # Name relative to sys.path. + # For relative imports such as 'from . import x', name will be the empty + # string. Thus we should not prepend a '.' to the entry. + entry = (name + '.' + from_entry) if name else from_entry + names.add(entry) + # Name relative to the currently executing module's package. + if curpkg: + names.add(curpkg + '.' + entry) + + # Generate all names from name. For instance, if name='a.b.c', then + # we need to add ['a.b.c', 'a.b', 'a']. + while name: + # Name relative to sys.path. + names.add(name) + # Name relative to currently executing module's package. + if curpkg: + names.add(curpkg + '.' + name) + name = name.rpartition('.')[0] + + return names + + +def _InvokeImportCallbackBySuffix(names): + """Invokes import callbacks for newly loaded modules. 
+ + Uses a path suffix match to identify whether a loaded module matches the + file path provided by the user. + + Args: + names: A set of names for modules that are loaded by the current import. + The set may contain some superfluous entries that were already + loaded before this import, or some entries that do not correspond + to a module. The list is expected to be much smaller than the exact + sys.modules so that a linear search is not as costly. + """ + + def GetModuleFromName(name, path): + """Returns the loaded module for this name/path, or None if not found. + + Args: + name: A string that may represent the name of a loaded Python module. + path: If 'name' ends with '.*', then the last path component in 'path' is + used to identify what the wildcard may map to. Does not contain file + extension. + + Returns: + The loaded module for the given name and path, or None if a loaded module + was not found. + """ + # The from-import syntax can be used as 'from p1.p2 import *'. In this case, + # we cannot know what modules will match the wildcard. However, we know that + # the wildcard can only be used to import leaf modules. So, we guess that + # the leaf module will have the same name as the leaf file name the user + # provided. For instance, + # User input path = 'foo.py' + # Currently executing import: + # from pkg1.pkg2 import * + # Then, we combine: + # 1. 'pkg1.pkg2' from import's outer package and + # 2. Add 'foo' as our guess for the leaf module name. + # So, we will search for modules with name similar to 'pkg1.pkg2.foo'. + if name.endswith('.*'): + # Replace the final '*' with the name of the module we are looking for. + name = name.rpartition('.')[0] + '.' + path.split('/')[-1] + + # Check if the module was loaded. + return sys.modules.get(name) + + # _import_callbacks might change during iteration because RemoveCallback() + # might delete items. Iterate over a copy to avoid a + # 'dictionary changed size during iteration' error. 
+ for path, callbacks in list(_import_callbacks.items()): + root = os.path.splitext(path)[0] + + nonempty_names = (n for n in names if n) + modules = (GetModuleFromName(name, root) for name in nonempty_names) + nonempty_modules = (m for m in modules if m) + + for module in nonempty_modules: + # TODO: Write unit test to cover None case. + mod_file = getattr(module, '__file__', None) + if not mod_file: + continue + if not isinstance(mod_file, str): + continue + + mod_root = os.path.splitext(mod_file)[0] + + # If the module is relative, add the curdir prefix to convert it to + # absolute path. Note that we don't use os.path.abspath because it + # also normalizes the path (which has side effects we don't want). + if not os.path.isabs(mod_root): + mod_root = os.path.join(os.curdir, mod_root) + + if module_utils.IsPathSuffix(mod_root, root): + for callback in callbacks.copy(): + callback(module) + break diff --git a/src/googleclouddebugger/labels.py b/src/googleclouddebugger/labels.py new file mode 100644 index 0000000..1bca819 --- /dev/null +++ b/src/googleclouddebugger/labels.py @@ -0,0 +1,47 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines the keys of the well known labels used by the cloud debugger. + +TODO: Define these strings in a common format for all agents to +share. This file needs to be maintained with the code generator file +being used in the UI, until the labels are unified. 
+""" + + +class Breakpoint(object): + REQUEST_LOG_ID = 'requestlogid' + + SET_ALL = frozenset([ + 'requestlogid', + ]) + + +class Debuggee(object): + DOMAIN = 'domain' + PROJECT_ID = 'projectid' + MODULE = 'module' + VERSION = 'version' + MINOR_VERSION = 'minorversion' + PLATFORM = 'platform' + REGION = 'region' + + SET_ALL = frozenset([ + 'domain', + 'projectid', + 'module', + 'version', + 'minorversion', + 'platform', + 'region', + ]) diff --git a/src/googleclouddebugger/leaky_bucket.cc b/src/googleclouddebugger/leaky_bucket.cc index 83caa68..aa18ef0 100644 --- a/src/googleclouddebugger/leaky_bucket.cc +++ b/src/googleclouddebugger/leaky_bucket.cc @@ -19,31 +19,20 @@ #include "leaky_bucket.h" -#ifndef NACL_BUILD -#include -#include -#else // NACL_BUILD -#include "third_party/apphosting/nacl/chromium/base/time.h" -#endif // NACL_BUILD - #include +#include #include namespace devtools { namespace cdbg { -static int64 NowInNanoseconds() { -#ifndef NACL_BUILD +static int64_t NowInNanoseconds() { timespec time; clock_gettime(CLOCK_MONOTONIC, &time); return 1000000000LL * time.tv_sec + time.tv_nsec; -#else // NACL_BUILD - return (base::Time::Now() - base::Time::UnixEpoch()).InMicroseconds() * 1000; -#endif // NACL_BUILD } - -LeakyBucket::LeakyBucket(int64 capacity, int64 fill_rate) +LeakyBucket::LeakyBucket(int64_t capacity, int64_t fill_rate) : capacity_(capacity), fractional_tokens_(0.0), fill_rate_(fill_rate), @@ -51,20 +40,19 @@ LeakyBucket::LeakyBucket(int64 capacity, int64 fill_rate) tokens_ = capacity; } - -bool LeakyBucket::RequestTokensSlow(int64 requested_tokens) { +bool LeakyBucket::RequestTokensSlow(int64_t requested_tokens) { // Getting the time outside the lock is significantly faster (reduces // contention, etc.). 
- const int64 current_time_ns = NowInNanoseconds(); + const int64_t current_time_ns = NowInNanoseconds(); std::lock_guard lock(mu_); - const int64 cur_tokens = AtomicLoadTokens(); + const int64_t cur_tokens = AtomicLoadTokens(); if (cur_tokens >= 0) { return true; } - const int64 available_tokens = + const int64_t available_tokens = RefillBucket(requested_tokens + cur_tokens, current_time_ns); if (available_tokens >= 0) { return true; @@ -77,17 +65,15 @@ bool LeakyBucket::RequestTokensSlow(int64 requested_tokens) { return false; } - -int64 LeakyBucket::RefillBucket( - int64 available_tokens, - int64 current_time_ns) { +int64_t LeakyBucket::RefillBucket(int64_t available_tokens, + int64_t current_time_ns) { if (current_time_ns <= fill_time_ns_) { // We check to see if the bucket has been refilled after we checked the // current time but before we grabbed mu_. If it has there's nothing to do. return AtomicLoadTokens(); } - const int64 elapsed_ns = current_time_ns - fill_time_ns_; + const int64_t elapsed_ns = current_time_ns - fill_time_ns_; fill_time_ns_ = current_time_ns; // Calculate the number of tokens we can add. Note elapsed is in ns while @@ -96,10 +82,10 @@ int64 LeakyBucket::RefillBucket( // don't add more than the capacity of leaky bucket. 
fractional_tokens_ += std::min(elapsed_ns * (fill_rate_ / 1e9), static_cast(capacity_)); - const int64 ideal_tokens_to_add = fractional_tokens_; + const int64_t ideal_tokens_to_add = fractional_tokens_; - const int64 max_tokens_to_add = capacity_ - available_tokens; - int64 real_tokens_to_add; + const int64_t max_tokens_to_add = capacity_ - available_tokens; + int64_t real_tokens_to_add; if (max_tokens_to_add < ideal_tokens_to_add) { fractional_tokens_ = 0.0; real_tokens_to_add = max_tokens_to_add; @@ -111,16 +97,15 @@ int64 LeakyBucket::RefillBucket( return AtomicIncrementTokens(real_tokens_to_add); } - -void LeakyBucket::TakeTokens(int64 tokens) { - const int64 remaining = AtomicIncrementTokens(-tokens); +void LeakyBucket::TakeTokens(int64_t tokens) { + const int64_t remaining = AtomicIncrementTokens(-tokens); if (remaining < 0) { // (Try to) refill the bucket. If we don't do this, we could just // keep decreasing forever without refilling. We need to be // refilling at least as frequently as every capacity_ / // fill_rate_ seconds. Otherwise, we waste tokens. - const int64 current_time_ns = NowInNanoseconds(); + const int64_t current_time_ns = NowInNanoseconds(); std::lock_guard lock(mu_); RefillBucket(remaining, current_time_ns); diff --git a/src/googleclouddebugger/leaky_bucket.h b/src/googleclouddebugger/leaky_bucket.h index 0547539..4dd8d27 100644 --- a/src/googleclouddebugger/leaky_bucket.h +++ b/src/googleclouddebugger/leaky_bucket.h @@ -18,6 +18,7 @@ #define DEVTOOLS_CDBG_COMMON_LEAKY_BUCKET_H_ #include +#include #include // NOLINT #include "common.h" @@ -32,7 +33,7 @@ class LeakyBucket { public: // "capacity": The max number of tokens the bucket can hold at any point. // "fill_rate": The rate which the bucket fills in tokens per second. - LeakyBucket(int64 capacity, int64 fill_rate); + LeakyBucket(int64_t capacity, int64_t fill_rate); ~LeakyBucket() {} @@ -46,32 +47,31 @@ class LeakyBucket { // tokens are being acquired. 
Suddenly, infinite demand arrives. // At most "capacity_" tokens will be granted immediately. Subsequent // requests will only be admitted based on the fill rate. - inline bool RequestTokens(int64 requested_tokens); + inline bool RequestTokens(int64_t requested_tokens); // Takes tokens from bucket, possibly sending the number of tokens in the // bucket negative. - void TakeTokens(int64 tokens); + void TakeTokens(int64_t tokens); private: // The slow path of RequestTokens. Grabs a lock and may refill tokens_ // using the fill rate and time passed since last fill. - bool RequestTokensSlow(int64 requested_tokens); + bool RequestTokensSlow(int64_t requested_tokens); // Refills the bucket with newly added tokens since last update and returns // the current amount of tokens in the bucket. 'available_tokens' indicates // the number of tokens in the bucket before refilling. 'current_time_ns' // indicates the current time in nanoseconds. - int64 RefillBucket(int64 available_tokens, int64 current_time_ns); - + int64_t RefillBucket(int64_t available_tokens, int64_t current_time_ns); // Atomically increment "tokens_". - inline int64 AtomicIncrementTokens(int64 increment) { - return tokens_ += increment; + inline int64_t AtomicIncrementTokens(int64_t increment) { + return tokens_.fetch_add(increment, std::memory_order_relaxed) + increment; } // Atomically load the value of "tokens_". - inline int64 AtomicLoadTokens() const { - return tokens_; + inline int64_t AtomicLoadTokens() const { + return tokens_.load(std::memory_order_relaxed); } private: @@ -85,33 +85,33 @@ class LeakyBucket { // // Tokens can be momentarily negative, either via TakeTokens or // during a normal RequestTokens that was not satisfied. - std::atomic tokens_; + std::atomic tokens_; // Capacity of the bucket. - const int64 capacity_; + const int64_t capacity_; // Although the main token count is an integer we also track fractional tokens // for increased precision. 
double fractional_tokens_; // Fill rate in tokens per second. - const int64 fill_rate_; + const int64_t fill_rate_; // Time in nanoseconds of the last refill. - int64 fill_time_ns_; + int64_t fill_time_ns_; DISALLOW_COPY_AND_ASSIGN(LeakyBucket); }; // Inline fast-path. -inline bool LeakyBucket::RequestTokens(int64 requested_tokens) { +inline bool LeakyBucket::RequestTokens(int64_t requested_tokens) { if (requested_tokens > capacity_) { return false; } // Try and grab some tokens. remaining is how many tokens are // left after subtracting out requested tokens. - int64 remaining = AtomicIncrementTokens(-requested_tokens); + int64_t remaining = AtomicIncrementTokens(-requested_tokens); if (remaining >= 0) { // We had at least as much as we needed. return true; diff --git a/src/googleclouddebugger/module_explorer.py b/src/googleclouddebugger/module_explorer.py index ef12312..ac62ce4 100644 --- a/src/googleclouddebugger/module_explorer.py +++ b/src/googleclouddebugger/module_explorer.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Finds all the code objects defined by a module.""" import gc @@ -19,8 +18,6 @@ import sys import types -import cdbg_native as native - # Maximum traversal depth when looking for all the code objects referenced by # a module or another code object. _MAX_REFERENTS_BFS_DEPTH = 15 @@ -29,11 +26,14 @@ # objects implemented in a module. _MAX_VISIT_OBJECTS = 100000 +# Maximum referents an object can have before it is skipped in the BFS +# traversal. This is to prevent things like long objects or dictionaries that +# probably do not contain code objects from using the _MAX_VISIT_OBJECTS quota. +_MAX_OBJECT_REFERENTS = 1000 + # Object types to ignore when looking for the code objects. 
-_BFS_IGNORE_TYPES = (types.ModuleType, types.NoneType, types.BooleanType, - types.IntType, types.LongType, types.FloatType, - types.StringType, types.UnicodeType, - types.BuiltinFunctionType, types.BuiltinMethodType) +_BFS_IGNORE_TYPES = (types.ModuleType, type(None), bool, float, bytes, str, int, + types.BuiltinFunctionType, types.BuiltinMethodType, list) def GetCodeObjectAtLine(module, line): @@ -44,16 +44,61 @@ def GetCodeObjectAtLine(module, line): line: 1-based line number of the statement. Returns: - Code object or None if not found. + (True, Code object) on success or (False, (prev_line, next_line)) on + failure, where prev_line and next_line are the closest lines with code above + and below the specified line, or None if they do not exist. """ if not hasattr(module, '__file__'): - return None + return (False, (None, None)) + + prev_line = 0 + next_line = sys.maxsize for code_object in _GetModuleCodeObjects(module): - if native.HasSourceLine(code_object, line): - return code_object + for co_line_number in _GetLineNumbers(code_object): + if co_line_number == line: + return (True, code_object) + elif co_line_number < line: + prev_line = max(prev_line, co_line_number) + elif co_line_number > line: + next_line = min(next_line, co_line_number) + # Continue because line numbers may not be sequential. + + prev_line = None if prev_line == 0 else prev_line + next_line = None if next_line == sys.maxsize else next_line + return (False, (prev_line, next_line)) + + +def _GetLineNumbers(code_object): + """Generator for getting the line numbers of a code object. + + Args: + code_object: the code object. + + Yields: + The next line number in the code object. + """ - return None + if sys.version_info.minor < 10: + # Get the line number deltas, which are the odd number entries, from the + # lnotab. See + # https://svn.python.org/projects/python/branches/pep-0384/Objects/lnotab_notes.txt + # In Python 3, prior to 3.10, this is just a byte array. 
+ line_incrs = code_object.co_lnotab[1::2] + current_line = code_object.co_firstlineno + for line_incr in line_incrs: + if line_incr >= 0x80: + # line_incrs is an array of 8-bit signed integers + line_incr -= 0x100 + current_line += line_incr + yield current_line + else: + # Get the line numbers directly, which are the third entry in the tuples. + # https://peps.python.org/pep-0626/#the-new-co-lines-method-of-code-objects + line_numbers = [entry[2] for entry in code_object.co_lines()] + for line_number in line_numbers: + if line_number is not None: + yield line_number def _GetModuleCodeObjects(module): @@ -113,8 +158,12 @@ def _FindCodeObjectsReferents(module, start_objects, visit_recorder): Returns: List of code objects. """ + def CheckIgnoreCodeObject(code_object): - """Checks if the code object originated from "module". + """Checks if the code object can be ignored. + + Code objects that are not implemented in the module, or are from a lambda or + generator expression can be ignored. If the module was precompiled, the code object may point to .py file, while the module says that it originated from .pyc file. We just strip extension @@ -124,8 +173,11 @@ def CheckIgnoreCodeObject(code_object): code_object: code object that we want to check against module. Returns: - False if code_object was implemented in module or True otherwise. + True if the code object can be ignored, False otherwise. """ + if code_object.co_name in ('', ''): + return True + code_object_file = os.path.splitext(code_object.co_filename)[0] module_file = os.path.splitext(module.__file__)[0] @@ -133,7 +185,6 @@ def CheckIgnoreCodeObject(code_object): if code_object_file == module_file: return False - return True def CheckIgnoreClass(cls): @@ -142,31 +193,39 @@ def CheckIgnoreClass(cls): if not cls_module: return False # We can't tell for sure, so explore this class. 
- return ( - cls_module is not module and - getattr(cls_module, '__file__', None) != module.__file__) + return (cls_module is not module and + getattr(cls_module, '__file__', None) != module.__file__) code_objects = set() current = start_objects + for obj in current: + visit_recorder.Record(obj) + depth = 0 while current and depth < _MAX_REFERENTS_BFS_DEPTH: - referents = gc.get_referents(*current) - current = [] - for obj in referents: - if isinstance(obj, _BFS_IGNORE_TYPES) or not visit_recorder.Record(obj): + new_current = [] + for current_obj in current: + referents = gc.get_referents(current_obj) + if (current_obj is not module.__dict__ and + len(referents) > _MAX_OBJECT_REFERENTS): continue - if isinstance(obj, types.CodeType) and CheckIgnoreCodeObject(obj): - continue + for obj in referents: + if isinstance(obj, _BFS_IGNORE_TYPES) or not visit_recorder.Record(obj): + continue - if isinstance(obj, types.ClassType) and CheckIgnoreClass(obj): - continue + if isinstance(obj, types.CodeType) and CheckIgnoreCodeObject(obj): + continue - if isinstance(obj, types.CodeType): - code_objects.add(obj) - else: - current.append(obj) + if isinstance(obj, type) and CheckIgnoreClass(obj): + continue + if isinstance(obj, types.CodeType): + code_objects.add(obj) + else: + new_current.append(obj) + + current = new_current depth += 1 return code_objects @@ -203,4 +262,3 @@ def Record(self, obj): self._visit_recorder_objects[obj_id] = obj return True - diff --git a/src/googleclouddebugger/module_lookup.py b/src/googleclouddebugger/module_lookup.py deleted file mode 100644 index 039fc4f..0000000 --- a/src/googleclouddebugger/module_lookup.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2015 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS-IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Finds the loaded module by source path. - -The lookup is a fuzzy one, the source path coming from a breakpoint might -be a subpath of module path or may be longer than the module path. -""" - -import os -import sys - - -def FindModule(source_path): - """Find the loaded module by source path. - - If there are multiple possible matches, chooses the best match. - - Args: - source_path: source file path as specified in the breakpoint. - - Returns: - Module object that best matches the source_path or None if no match found. - """ - file_name, ext = os.path.splitext(os.path.basename(source_path)) - if ext != '.py': - return None # ".py" extension is expected - - candidates = _GetModulesByFileName(file_name) - if not candidates: - return None - - if len(candidates) == 1: - return candidates[0] - - return candidates[_Disambiguate( - os.path.split(source_path)[0], - [os.path.split(module.__file__)[0] for module in candidates])] - - -def _GetModulesByFileName(lookup_file_name): - """Gets list of all the loaded modules by file name (ignores directory).""" - matches = [] - - # Clone modules dictionaries to allow new modules to load during iteration. - for unused_name, module in sys.modules.copy().iteritems(): - if not hasattr(module, '__file__'): - continue # This is a built-in module. 
- - file_name, ext = os.path.splitext(os.path.basename(module.__file__)) - if file_name == lookup_file_name and (ext == '.py' or ext == '.pyc'): - matches.append(module) - - return matches - - -def _Disambiguate(lookup_path, paths): - """Disambiguate multiple candidates based on the longest suffix. - - Example when this disambiguation is needed: - Breakpoint at: 'myproject/app/db/common.py' - Candidate modules: ['/home/root/fe/common.py', '/home/root/db/common.py'] - - In this example the input to this function will be: - lookup_path = 'myproject/app/db' - paths = ['/home/root/fe', '/home/root/db'] - - The second path is clearly the best match, so this function will return 1. - - Args: - lookup_path: the source path of the searched module (without file name - and extension). - paths: candidate paths (each without file name and extension). - - Returns: - Index of the best match or arbitrary index if this function can't - discriminate. - """ - best_index = 0 - best_len = 0 - for i in range(len(paths)): - current_len = _CommonSuffix(lookup_path, paths[i]) - if current_len > best_len: - best_index = i - best_len = current_len - - return best_index - - -def _CommonSuffix(path1, path2): - """Computes the number of common directory names at the tail of the paths. - - Examples: - * _CommonSuffix('a/x/y', 'b/x/y') = 2 - * _CommonSuffix('a/b/c', 'd/e/f') = 0 - * _CommonSuffix('a/b/c', 'a/b/x') = 0 - - Args: - path1: first directory path (should not have file name). - path2: second directory path (should not have file name). - - Returns: - Number of common consecutive directory segments from right. 
- """ - - # Normalize the paths just to be on the safe side - path1 = path1.strip(os.sep) - path2 = path2.strip(os.sep) - - counter = 0 - while path1 and path2: - path1, cur1 = os.path.split(path1) - path2, cur2 = os.path.split(path2) - - if cur1 != cur2 or not cur1: - break - - counter += 1 - - return counter diff --git a/src/googleclouddebugger/module_search.py b/src/googleclouddebugger/module_search.py new file mode 100644 index 0000000..e8d29f3 --- /dev/null +++ b/src/googleclouddebugger/module_search.py @@ -0,0 +1,105 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inclusive search for module files.""" + +import os +import sys + + +def Search(path): + """Search sys.path to find a source file that matches path. + + The provided input path may have an unknown number of irrelevant outer + directories (e.g., /garbage1/garbage2/real1/real2/x.py'). This function + does multiple search iterations until an actual Python module file that + matches the input path is found. At each iteration, it strips one leading + directory from the path and searches the directories at sys.path + for a match. 
+ + Examples: + sys.path: ['/x1/x2', '/y1/y2'] + Search order: [.pyo|.pyc|.py] + /x1/x2/a/b/c + /x1/x2/b/c + /x1/x2/c + /y1/y2/a/b/c + /y1/y2/b/c + /y1/y2/c + Filesystem: ['/y1/y2/a/b/c.pyc'] + + 1) Search('a/b/c.py') + Returns '/y1/y2/a/b/c.pyc' + 2) Search('q/w/a/b/c.py') + Returns '/y1/y2/a/b/c.pyc' + 3) Search('q/w/c.py') + Returns 'q/w/c.py' + + The provided input path may also be relative to an unknown directory. + The path may include some or all outer package names. + + Examples (continued): + + 4) Search('c.py') + Returns 'c.py' + 5) Search('b/c.py') + Returns 'b/c.py' + + Args: + path: Path that describes a source file. Must contain .py file extension. + Must not contain any leading os.sep character. + + Returns: + Full path to the matched source file, if a match is found. Otherwise, + returns the input path. + + Raises: + AssertionError: if the provided path is an absolute path, or if it does not + have a .py extension. + """ + + def SearchCandidates(p): + """Generates all candidates for the fuzzy search of p.""" + while p: + yield p + (_, _, p) = p.partition(os.sep) + + # Verify that the os.sep is already stripped from the input. + assert not path.startswith(os.sep) + + # Strip the file extension, it will not be needed. + src_root, src_ext = os.path.splitext(path) + assert src_ext == '.py' + + # Search longer suffixes first. Move to shorter suffixes only if longer + # suffixes do not result in any matches. + for src_part in SearchCandidates(src_root): + # Search is done in sys.path order, which gives higher priority to earlier + # entries in sys.path list. + for sys_path in sys.path: + f = os.path.join(sys_path, src_part) + # The order in which we search the extensions does not matter. + for ext in ('.pyo', '.pyc', '.py'): + # The os.path.exists check internally follows symlinks and flattens + # relative paths, so we don't have to deal with it. 
+ fext = f + ext + if os.path.exists(fext): + # Once we identify a matching file in the filesystem, we should + # preserve the (1) potentially-symlinked and (2) + # potentially-non-flattened file path (f+ext), because that's exactly + # how we expect it to appear in sys.modules when we search the file + # there. + return fext + + # A matching file was not found in sys.path directories. + return path diff --git a/src/googleclouddebugger/module_utils.py b/src/googleclouddebugger/module_utils.py new file mode 100644 index 0000000..53f2e37 --- /dev/null +++ b/src/googleclouddebugger/module_utils.py @@ -0,0 +1,100 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Provides utility functions for module path processing.""" + +import os +import sys + +def NormalizePath(path): + """Normalizes a path. + + E.g. One example is it will convert "/a/b/./c" -> "/a/b/c" + """ + # TODO: Calling os.path.normpath "may change the meaning of a + # path that contains symbolic links" (e.g., "A/foo/../B" != "A/B" if foo is a + # symlink). This might cause trouble when matching against loaded module + # paths. We should try to avoid using it. + # Example: + # > import symlink.a + # > symlink.a.__file__ + # symlink/a.py + # > import target.a + # > starget.a.__file__ + # target/a.py + # Python interpreter treats these as two separate modules. So, we also need to + # handle them the same way. 
+ return os.path.normpath(path) + + +def IsPathSuffix(mod_path, path): + """Checks whether path is a full path suffix of mod_path. + + Args: + mod_path: Must be an absolute path to a source file. Must not have + file extension. + path: A relative path. Must not have file extension. + + Returns: + True if path is a full path suffix of mod_path. False otherwise. + """ + return (mod_path.endswith(path) and (len(mod_path) == len(path) or + mod_path[:-len(path)].endswith(os.sep))) + + +def GetLoadedModuleBySuffix(path): + """Searches sys.modules to find a module with the given file path. + + Args: + path: Path to the source file. It can be relative or absolute, as suffix + match can handle both. If absolute, it must have already been + sanitized. + + Algorithm: + The given path must be a full suffix of a loaded module to be a valid match. + File extensions are ignored when performing suffix match. + + Example: + path: 'a/b/c.py' + modules: {'a': 'a.py', 'a.b': 'a/b.py', 'a.b.c': 'a/b/c.pyc'] + returns: module('a.b.c') + + Returns: + The module that corresponds to path, or None if such module was not + found. + """ + root = os.path.splitext(path)[0] + for module in sys.modules.values(): + mod_root = os.path.splitext(getattr(module, '__file__', None) or '')[0] + + if not mod_root: + continue + + # While mod_root can contain symlinks, we cannot eliminate them. This is + # because, we must perform exactly the same transformations on mod_root and + # path, yet path can be relative to an unknown directory which prevents + # identifying and eliminating symbolic links. + # + # Therefore, we only convert relative to absolute path. + if not os.path.isabs(mod_root): + mod_root = os.path.join(os.getcwd(), mod_root) + + # In the following invocation 'python3 ./main.py' (using the ./), the + # mod_root variable will '/base/path/./main'. In order to correctly compare + # it with the root variable, it needs to be '/base/path/main'. 
+ mod_root = NormalizePath(mod_root) + + if IsPathSuffix(mod_root, root): + return module + + return None diff --git a/src/googleclouddebugger/native_module.cc b/src/googleclouddebugger/native_module.cc index 63995c7..60a9a8a 100644 --- a/src/googleclouddebugger/native_module.cc +++ b/src/googleclouddebugger/native_module.cc @@ -17,11 +17,14 @@ // Ensure that Python.h is included before any other header. #include "common.h" +#include "native_module.h" + +#include + #include "bytecode_breakpoint.h" #include "common.h" #include "conditional_breakpoint.h" #include "immutability_tracer.h" -#include "native_module.h" #include "python_callback.h" #include "python_util.h" #include "rate_limit.h" @@ -37,52 +40,22 @@ const LogSeverity LOG_SEVERITY_ERROR = ::google::ERROR; struct INTEGER_CONSTANT { const char* name; - int32 value; + int32_t value; }; static const INTEGER_CONSTANT kIntegerConstants[] = { - { - "BREAKPOINT_EVENT_HIT", - static_cast(BreakpointEvent::Hit) - }, - { - "BREAKPOINT_EVENT_ERROR", - static_cast(BreakpointEvent::Error) - }, - { - "BREAKPOINT_EVENT_GLOBAL_CONDITION_QUOTA_EXCEEDED", - static_cast(BreakpointEvent::GlobalConditionQuotaExceeded) - }, - { - "BREAKPOINT_EVENT_BREAKPOINT_CONDITION_QUOTA_EXCEEDED", - static_cast(BreakpointEvent::BreakpointConditionQuotaExceeded) - }, - { - "BREAKPOINT_EVENT_CONDITION_EXPRESSION_MUTABLE", - static_cast(BreakpointEvent::ConditionExpressionMutable) - } -}; + {"BREAKPOINT_EVENT_HIT", static_cast(BreakpointEvent::Hit)}, + {"BREAKPOINT_EVENT_ERROR", static_cast(BreakpointEvent::Error)}, + {"BREAKPOINT_EVENT_GLOBAL_CONDITION_QUOTA_EXCEEDED", + static_cast(BreakpointEvent::GlobalConditionQuotaExceeded)}, + {"BREAKPOINT_EVENT_BREAKPOINT_CONDITION_QUOTA_EXCEEDED", + static_cast(BreakpointEvent::BreakpointConditionQuotaExceeded)}, + {"BREAKPOINT_EVENT_CONDITION_EXPRESSION_MUTABLE", + static_cast(BreakpointEvent::ConditionExpressionMutable)}}; // Class to set zero overhead breakpoints. 
static BytecodeBreakpoint g_bytecode_breakpoint; -// Condition and dynamic logging rate limits are defined as the maximum -// amount of time in nanoseconds to spend on particular processing per -// second. These rate are enforced as following: -// 1. If a single breakpoint contributes to half the maximum rate, that -// breakpoint will be deactivated. -// 2. If all breakpoints combined hit the maximum rate, any breakpoint to -// exceed the limit gets disabled. -// -// The first rule ensures that in vast majority of scenarios expensive -// breakpoints will get deactivated. The second rule guarantees that in edge -// case scenarios the total amount of time spent in condition evaluation will -// not exceed the alotted limit. -// -// All limits ignore the number of CPUs since Python is inherently single -// threaded. -static std::unique_ptr g_global_condition_quota_; - // Initializes C++ flags and logging. // // This function should be called exactly once during debugger bootstrap. It @@ -148,7 +121,6 @@ static PyObject* InitializeModule(PyObject* self, PyObject* py_args) { Py_RETURN_NONE; } - // Common code for LogXXX functions. // // The source file name and the source line are obtained automatically by @@ -204,42 +176,7 @@ static PyObject* LogError(PyObject* self, PyObject* py_args) { } -// Searches for a statement with the specified line number in the specified -// code object. -// -// Args: -// code_object: Python code object to analyze. -// line: 1-based line number to search. -// -// Returns: -// True if code_object includes a statement that maps to the specified -// source line or False otherwise. 
-static PyObject* HasSourceLine(PyObject* self, PyObject* py_args) { - PyCodeObject* code_object = nullptr; - int line = -1; - if (!PyArg_ParseTuple(py_args, "Oi", &code_object, &line)) { - return nullptr; - } - - if ((code_object == nullptr) || !PyCode_Check(code_object)) { - PyErr_SetString( - PyExc_TypeError, - "code_object must be a code object"); - return nullptr; - } - - CodeObjectLinesEnumerator enumerator(code_object); - do { - if (enumerator.line_number() == line) { - Py_RETURN_TRUE; - } - } while (enumerator.Next()); - - Py_RETURN_FALSE; -} - - -// Sets a new breakpoint in Python code. The breakpoint may have an optional +// Creates a new breakpoint in Python code. The breakpoint may have an optional // condition to evaluate. When the breakpoint hits (and the condition matches) // a callable object will be invoked from that thread. // @@ -259,7 +196,8 @@ static PyObject* HasSourceLine(PyObject* self, PyObject* py_args) { // Returns: // Integer cookie identifying this breakpoint. It needs to be specified when // clearing the breakpoint. -static PyObject* SetConditionalBreakpoint(PyObject* self, PyObject* py_args) { +static PyObject* CreateConditionalBreakpoint(PyObject* self, + PyObject* py_args) { PyCodeObject* code_object = nullptr; int line = -1; PyCodeObject* condition = nullptr; @@ -301,7 +239,7 @@ static PyObject* SetConditionalBreakpoint(PyObject* self, PyObject* py_args) { int cookie = -1; - cookie = g_bytecode_breakpoint.SetBreakpoint( + cookie = g_bytecode_breakpoint.CreateBreakpoint( code_object, line, std::bind( @@ -318,11 +256,11 @@ static PyObject* SetConditionalBreakpoint(PyObject* self, PyObject* py_args) { } -// Clears the breakpoint previously set by "SetConditionalBreakpoint". Must be -// called exactly once per each call to "SetConditionalBreakpoint". +// Clears a breakpoint previously created by "CreateConditionalBreakpoint". Must +// be called exactly once per each call to "CreateConditionalBreakpoint". 
// // Args: -// cookie: breakpoint identifier returned by "SetConditionalBreakpoint". +// cookie: breakpoint identifier returned by "CreateConditionalBreakpoint". static PyObject* ClearConditionalBreakpoint(PyObject* self, PyObject* py_args) { int cookie = -1; if (!PyArg_ParseTuple(py_args, "i", &cookie)) { @@ -334,6 +272,24 @@ static PyObject* ClearConditionalBreakpoint(PyObject* self, PyObject* py_args) { Py_RETURN_NONE; } +// Activates a previously created breakpoint by "CreateConditionalBreakpoint" +// and that haven't been cleared yet using "ClearConditionalBreakpoint". +// TODO: Optimize breakpoint activation by having one method +// "ActivateAllConditionalBreakpoints" for all previously created breakpoints. +// +// Args: +// cookie: breakpoint identifier returned by "CreateConditionalBreakpoint". +static PyObject* ActivateConditionalBreakpoint(PyObject* self, + PyObject* py_args) { + int cookie = -1; + if (!PyArg_ParseTuple(py_args, "i", &cookie)) { + return nullptr; + } + + g_bytecode_breakpoint.ActivateBreakpoint(cookie); + + Py_RETURN_NONE; +} // Invokes a Python callable object with immutability tracer. // @@ -368,21 +324,50 @@ static PyObject* CallImmutable(PyObject* self, PyObject* py_args) { } PyFrameObject* frame = reinterpret_cast(obj_frame); - PyCodeObject* code = reinterpret_cast(obj_code); PyFrame_FastToLocals(frame); ScopedImmutabilityTracer immutability_tracer; - return PyEval_EvalCode(code, frame->f_globals, frame->f_locals); +#if PY_MAJOR_VERSION >= 3 + return PyEval_EvalCode(obj_code, frame->f_globals, frame->f_locals); +#else + return PyEval_EvalCode(reinterpret_cast(obj_code), + frame->f_globals, frame->f_locals); +#endif } +// Applies the dynamic logs quota, which is limited by both total messages and +// total bytes. This should be called before doing the actual logging call. +// +// Args: +// num_bytes: number of bytes in the message to log. +// Returns: +// True if there is quota available, False otherwise. 
+static PyObject* ApplyDynamicLogsQuota(PyObject* self, PyObject* py_args) { + LazyInitializeRateLimit(); + int num_bytes = -1; + if (!PyArg_ParseTuple(py_args, "i", &num_bytes) || num_bytes < 1) { + Py_RETURN_FALSE; + } + + LeakyBucket* global_dynamic_log_limiter = GetGlobalDynamicLogQuota(); + LeakyBucket* global_dynamic_log_bytes_limiter = + GetGlobalDynamicLogBytesQuota(); + + if (global_dynamic_log_limiter->RequestTokens(1) && + global_dynamic_log_bytes_limiter->RequestTokens(num_bytes)) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} static PyMethodDef g_module_functions[] = { { - "InitializeModule", - InitializeModule, - METH_VARARGS, - "Initialize C++ flags and logging." + "InitializeModule", + InitializeModule, + METH_VARARGS, + "Initialize C++ flags and logging." }, { "LogInfo", @@ -403,23 +388,22 @@ static PyMethodDef g_module_functions[] = { "ERROR level logging from Python code." }, { - "HasSourceLine", - HasSourceLine, + "CreateConditionalBreakpoint", + CreateConditionalBreakpoint, METH_VARARGS, - "Checks whether Python code object includes the specified source " - "line number." + "Creates a new breakpoint in Python code." }, { - "SetConditionalBreakpoint", - SetConditionalBreakpoint, + "ActivateConditionalBreakpoint", + ActivateConditionalBreakpoint, METH_VARARGS, - "Sets a new breakpoint in Python code." + "Activates previously created breakpoint in Python code." }, { "ClearConditionalBreakpoint", ClearConditionalBreakpoint, METH_VARARGS, - "Clears previously set breakpoint in Python code." + "Clears previously created breakpoint in Python code." }, { "CallImmutable", @@ -427,42 +411,76 @@ static PyMethodDef g_module_functions[] = { METH_VARARGS, "Invokes a Python callable object with immutability tracer." 
}, + { + "ApplyDynamicLogsQuota", + ApplyDynamicLogsQuota, + METH_VARARGS, + "Applies the dynamic log quota" + }, { nullptr, nullptr, 0, nullptr } // sentinel }; -void InitDebuggerNativeModule() { +#if PY_MAJOR_VERSION >= 3 +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, /** m_base */ + CDBG_MODULE_NAME, /** m_name */ + "Native module for Python Cloud Debugger", /** m_doc */ + -1, /** m_size */ + g_module_functions, /** m_methods */ + NULL, /** m_slots */ + NULL, /** m_traverse */ + NULL, /** m_clear */ + NULL /** m_free */ +}; + +PyObject* InitDebuggerNativeModuleInternal() { + PyObject* module = PyModule_Create(&moduledef); +#else +PyObject* InitDebuggerNativeModuleInternal() { PyObject* module = Py_InitModule3( CDBG_MODULE_NAME, g_module_functions, "Native module for Python Cloud Debugger"); +#endif SetDebugletModule(module); if (!RegisterPythonType() || !RegisterPythonType()) { - return; + return nullptr; } // Add constants we want to share with the Python code. - for (uint32 i = 0; i < arraysize(kIntegerConstants); ++i) { + for (uint32_t i = 0; i < arraysize(kIntegerConstants); ++i) { if (PyModule_AddObject( module, kIntegerConstants[i].name, PyInt_FromLong(kIntegerConstants[i].value))) { LOG(ERROR) << "Failed to constant " << kIntegerConstants[i].name << " to native module"; - return; + return nullptr; } } + + return module; +} + +void InitDebuggerNativeModule() { + InitDebuggerNativeModuleInternal(); } } // namespace cdbg } // namespace devtools - // This function is called to initialize the module. 
+#if PY_MAJOR_VERSION >= 3 +PyMODINIT_FUNC PyInit_cdbg_native() { + return devtools::cdbg::InitDebuggerNativeModuleInternal(); +} +#else PyMODINIT_FUNC initcdbg_native() { devtools::cdbg::InitDebuggerNativeModule(); } +#endif diff --git a/src/googleclouddebugger/nullable.h b/src/googleclouddebugger/nullable.h index a70ddb7..88703c3 100644 --- a/src/googleclouddebugger/nullable.h +++ b/src/googleclouddebugger/nullable.h @@ -17,7 +17,6 @@ #ifndef DEVTOOLS_CDBG_DEBUGLETS_PYTHON_NULLABLE_H_ #define DEVTOOLS_CDBG_DEBUGLETS_PYTHON_NULLABLE_H_ - #include "common.h" namespace devtools { diff --git a/src/googleclouddebugger/python_breakpoint.py b/src/googleclouddebugger/python_breakpoint.py index 9651dd2..62f2512 100644 --- a/src/googleclouddebugger/python_breakpoint.py +++ b/src/googleclouddebugger/python_breakpoint.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Handles a single Python breakpoint.""" from datetime import datetime @@ -19,50 +18,123 @@ import os from threading import Lock -import capture_collector -import cdbg_native as native -import deferred_modules -import module_explorer -import module_lookup - -# TODO(vlif): move to messages.py module. -BREAKPOINT_ONLY_SUPPORTS_PY_FILES = ( - 'Only files with .py or .pyc extension are supported') -MODULE_NOT_FOUND = ( - 'Python module not found') -NO_CODE_FOUND_AT_LINE = ( - 'No code found at line $0') -GLOBAL_CONDITION_QUOTA_EXCEEDED = ( +from . import collector +from . import cdbg_native as native +from . import imphook +from . import module_explorer +from . import module_search +from . import module_utils + +# TODO: move to messages.py module. 
+# Use the following schema to define breakpoint error message constant: +# ERROR___ +ERROR_LOCATION_FILE_EXTENSION_0 = ( + 'Only files with .py extension are supported') +ERROR_LOCATION_MODULE_NOT_FOUND_0 = ( + 'Python module not found. Please ensure this file is present in the ' + 'version of the service you are trying to debug.') +ERROR_LOCATION_MULTIPLE_MODULES_1 = ( + 'Multiple modules matching $0. Please specify the module path.') +ERROR_LOCATION_MULTIPLE_MODULES_3 = ('Multiple modules matching $0 ($1, $2)') +ERROR_LOCATION_MULTIPLE_MODULES_4 = ( + 'Multiple modules matching $0 ($1, $2, and $3 more)') +ERROR_LOCATION_NO_CODE_FOUND_AT_LINE_2 = 'No code found at line $0 in $1' +ERROR_LOCATION_NO_CODE_FOUND_AT_LINE_3 = ( + 'No code found at line $0 in $1. Try line $2.') +ERROR_LOCATION_NO_CODE_FOUND_AT_LINE_4 = ( + 'No code found at line $0 in $1. Try lines $2 or $3.') +ERROR_CONDITION_GLOBAL_QUOTA_EXCEEDED_0 = ( 'Snapshot cancelled. The condition evaluation cost for all active ' 'snapshots might affect the application performance.') -BREAKPOINT_CONDITION_QUOTA_EXCEEDED = ( +ERROR_CONDITION_BREAKPOINT_QUOTA_EXCEEDED_0 = ( 'Snapshot cancelled. The condition evaluation at this location might ' 'affect application performance. Please simplify the condition or move ' 'the snapshot to a less frequently called statement.') -MUTABLE_CONDITION = ( +ERROR_CONDITION_MUTABLE_0 = ( 'Only immutable expressions can be used in snapshot conditions') -BREAKPOINT_EXPIRED = ( - 'The snapshot has expired') -INTERNAL_ERROR = ( - 'Internal error occurred') +ERROR_AGE_SNAPSHOT_EXPIRED_0 = ('The snapshot has expired') +ERROR_AGE_LOGPOINT_EXPIRED_0 = ('The logpoint has expired') +ERROR_UNSPECIFIED_INTERNAL_ERROR = ('Internal error occurred') # Status messages for different breakpoint events (except of "hit"). 
-_BREAKPOINT_EVENT_STATUS = dict( - [(native.BREAKPOINT_EVENT_ERROR, - {'isError': True, - 'description': {'format': INTERNAL_ERROR}}), - (native.BREAKPOINT_EVENT_GLOBAL_CONDITION_QUOTA_EXCEEDED, - {'isError': True, - 'refersTo': 'BREAKPOINT_CONDITION', - 'description': {'format': GLOBAL_CONDITION_QUOTA_EXCEEDED}}), - (native.BREAKPOINT_EVENT_BREAKPOINT_CONDITION_QUOTA_EXCEEDED, - {'isError': True, - 'refersTo': 'BREAKPOINT_CONDITION', - 'description': {'format': BREAKPOINT_CONDITION_QUOTA_EXCEEDED}}), - (native.BREAKPOINT_EVENT_CONDITION_EXPRESSION_MUTABLE, - {'isError': True, - 'refersTo': 'BREAKPOINT_CONDITION', - 'description': {'format': MUTABLE_CONDITION}})]) +_BREAKPOINT_EVENT_STATUS = dict([ + (native.BREAKPOINT_EVENT_ERROR, { + 'isError': True, + 'description': { + 'format': ERROR_UNSPECIFIED_INTERNAL_ERROR + } + }), + (native.BREAKPOINT_EVENT_GLOBAL_CONDITION_QUOTA_EXCEEDED, { + 'isError': True, + 'refersTo': 'BREAKPOINT_CONDITION', + 'description': { + 'format': ERROR_CONDITION_GLOBAL_QUOTA_EXCEEDED_0 + } + }), + (native.BREAKPOINT_EVENT_BREAKPOINT_CONDITION_QUOTA_EXCEEDED, { + 'isError': True, + 'refersTo': 'BREAKPOINT_CONDITION', + 'description': { + 'format': ERROR_CONDITION_BREAKPOINT_QUOTA_EXCEEDED_0 + } + }), + (native.BREAKPOINT_EVENT_CONDITION_EXPRESSION_MUTABLE, { + 'isError': True, + 'refersTo': 'BREAKPOINT_CONDITION', + 'description': { + 'format': ERROR_CONDITION_MUTABLE_0 + } + }) +]) + +# The implementation of datetime.strptime imports an undocumented module called +# _strptime. If it happens at the wrong time, we can get an exception about +# trying to import while another thread holds the import lock. This dummy call +# to strptime ensures that the module is loaded at startup. +# See http://bugs.python.org/issue7980 for discussion of the Python bug. 
+datetime.strptime('2017-01-01', '%Y-%m-%d') + + +def _IsRootInitPy(path): + return path.lstrip(os.sep) == '__init__.py' + + +def _StripCommonPathPrefix(paths): + """Removes path common prefix from a list of path strings.""" + # Find the longest common prefix in terms of characters. + common_prefix = os.path.commonprefix(paths) + # Truncate at last segment boundary. E.g. '/aa/bb1/x.py' and '/a/bb2/x.py' + # have '/aa/bb' as the common prefix, but we should strip '/aa/' instead. + # If there's no '/' found, returns -1+1=0. + common_prefix_len = common_prefix.rfind('/') + 1 + return [path[common_prefix_len:] for path in paths] + + +def _MultipleModulesFoundError(path, candidates): + """Generates an error message to be used when multiple matches are found. + + Args: + path: The breakpoint location path that the user provided. + candidates: List of paths that match the user provided path. Must + contain at least 2 entries (throws AssertionError otherwise). + + Returns: + A (format, parameters) tuple that should be used in the description + field of the breakpoint error status. + """ + assert len(candidates) > 1 + params = [path] + _StripCommonPathPrefix(candidates[:2]) + if len(candidates) == 2: + fmt = ERROR_LOCATION_MULTIPLE_MODULES_3 + else: + fmt = ERROR_LOCATION_MULTIPLE_MODULES_4 + params.append(str(len(candidates) - 2)) + return fmt, params + + +def _NormalizePath(path): + """Removes surrounding whitespace, leading separator and normalize.""" + return module_utils.NormalizePath(path.strip().lstrip(os.sep)) class PythonBreakpoint(object): @@ -74,23 +146,31 @@ class PythonBreakpoint(object): to log a statement. """ - def __init__(self, definition, hub_client, breakpoints_manager): + def __init__(self, definition, hub_client, breakpoints_manager, + data_visibility_policy): """Class constructor. Tries to set the breakpoint. If the source location is invalid, the breakpoint is completed with an error message. 
If the source location is - valid, but the module hasn't been loaded yet, the breakpoint is initialized - as deferred. + valid, but the module hasn't been loaded yet, the breakpoint is deferred. Args: definition: breakpoint definition as it came from the backend. hub_client: asynchronously sends breakpoint updates to the backend. breakpoints_manager: parent object managing active breakpoints. + data_visibility_policy: An object used to determine the visibility + of a captured variable. May be None if no policy is available. """ self.definition = definition + self.data_visibility_policy = data_visibility_policy + # Breakpoint expiration time. self.expiration_period = timedelta(hours=24) + if self.definition.get('expires_in'): + self.expiration_period = min( + timedelta(definition.get('expires_in').get('seconds', 0)), + self.expiration_period) self._hub_client = hub_client self._breakpoints_manager = breakpoints_manager @@ -100,8 +180,46 @@ def __init__(self, definition, hub_client, breakpoints_manager): self._lock = Lock() self._completed = False - if not self._TryActivateBreakpoint() and not self._completed: - self._DeferBreakpoint() + if self.definition.get('action') == 'LOG': + self._collector = collector.LogCollector(self.definition) + + path = _NormalizePath(self.definition['location']['path']) + + # Only accept .py extension. + if os.path.splitext(path)[1] != '.py': + self._CompleteBreakpoint({ + 'status': { + 'isError': True, + 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', + 'description': { + 'format': ERROR_LOCATION_FILE_EXTENSION_0 + } + } + }) + return + + # A flat init file is too generic; path must include package name. 
+ if path == '__init__.py': + self._CompleteBreakpoint({ + 'status': { + 'isError': True, + 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', + 'description': { + 'format': ERROR_LOCATION_MULTIPLE_MODULES_1, + 'parameters': [path] + } + } + }) + return + + new_path = module_search.Search(path) + new_module = module_utils.GetLoadedModuleBySuffix(new_path) + + if new_module: + self._ActivateBreakpoint(new_module) + else: + self._import_hook_cleanup = imphook.AddImportCallbackBySuffix( + new_path, self._ActivateBreakpoint) def Clear(self): """Clears the breakpoint and releases all breakpoint resources. @@ -121,11 +239,40 @@ def GetBreakpointId(self): return self.definition['id'] def GetExpirationTime(self): - """Computes the timestamp at which this breakpoint will expire.""" - create_datetime = datetime.strptime( - self.definition['createTime'].replace('Z', 'UTC'), - '%Y-%m-%dT%H:%M:%S.%f%Z') - return create_datetime + self.expiration_period + """Computes the timestamp at which this breakpoint will expire. + + If no creation time can be found an expiration time in the past will be + used. + """ + return self.GetCreateTime() + self.expiration_period + + def GetCreateTime(self): + """Retrieves the creation time of this breakpoint. + + If no creation time can be found a creation time in the past will be used. + """ + if 'createTime' in self.definition: + return self.GetTimeFromRfc3339Str(self.definition['createTime']) + else: + return self.GetTimeFromUnixMsec( + self.definition.get('createTimeUnixMsec', 0)) + + def GetTimeFromRfc3339Str(self, rfc3339_str): + if '.' 
not in rfc3339_str: + fmt = '%Y-%m-%dT%H:%M:%S%Z' + else: + fmt = '%Y-%m-%dT%H:%M:%S.%f%Z' + + return datetime.strptime(rfc3339_str.replace('Z', 'UTC'), fmt) + + def GetTimeFromUnixMsec(self, unix_msec): + try: + return datetime.fromtimestamp(unix_msec / 1000) + except (TypeError, ValueError, OSError, OverflowError) as e: + native.LogWarning( + 'Unexpected error (%s) occured processing unix_msec %s, breakpoint: %s' + % (repr(e), str(unix_msec), self.GetBreakpointId())) + return datetime.fromtimestamp(0) def ExpireBreakpoint(self): """Expires this breakpoint.""" @@ -133,49 +280,81 @@ def ExpireBreakpoint(self): if not self._SetCompleted(): return + if self.definition.get('action') == 'LOG': + message = ERROR_AGE_LOGPOINT_EXPIRED_0 + else: + message = ERROR_AGE_SNAPSHOT_EXPIRED_0 self._CompleteBreakpoint({ 'status': { 'isError': True, - 'refersTo': 'UNSPECIFIED', - 'description': {'format': BREAKPOINT_EXPIRED}}}) - - def _TryActivateBreakpoint(self): - """Sets the breakpoint if the module has already been loaded. + 'refersTo': 'BREAKPOINT_AGE', + 'description': { + 'format': message + } + } + }) - This function will complete the breakpoint with error if breakpoint - definition is incorrect. Examples: invalid line or bad condition. + def _ActivateBreakpoint(self, module): + """Sets the breakpoint in the loaded module, or complete with error.""" - If the code object corresponding to the source path can't be found, - this function returns False. In this case, the breakpoint is not - completed, since the breakpoint may be deferred. + # First remove the import hook (if installed). + self._RemoveImportHook() - Returns: - True if breakpoint was set or false otherwise. False can be returned - for potentially deferred breakpoints or in case of a bad breakpoint - definition. The self._completed flag distinguishes between the two cases. - """ + line = self.definition['location']['line'] # Find the code object in which the breakpoint is being set. 
- code_object = self._FindCodeObject() - if not code_object: - return False + status, codeobj = module_explorer.GetCodeObjectAtLine(module, line) + if not status: + # First two parameters are common: the line of the breakpoint and the + # module we are trying to insert the breakpoint in. + # TODO: Do not display the entire path of the file. Either + # strip some prefix, or display the path in the breakpoint. + params = [str(line), os.path.splitext(module.__file__)[0] + '.py'] + + # The next 0, 1, or 2 parameters are the alternative lines to set the + # breakpoint at, displayed for the user's convenience. + alt_lines = (str(l) for l in codeobj if l is not None) + params += alt_lines + + if len(params) == 4: + fmt = ERROR_LOCATION_NO_CODE_FOUND_AT_LINE_4 + elif len(params) == 3: + fmt = ERROR_LOCATION_NO_CODE_FOUND_AT_LINE_3 + else: + fmt = ERROR_LOCATION_NO_CODE_FOUND_AT_LINE_2 + + self._CompleteBreakpoint({ + 'status': { + 'isError': True, + 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', + 'description': { + 'format': fmt, + 'parameters': params + } + } + }) + return # Compile the breakpoint condition. condition = None if self.definition.get('condition'): try: - condition = compile(self.definition.get('condition'), - '', - 'eval') - except TypeError as e: # condition string contains null bytes. + condition = compile( + self.definition.get('condition'), '', 'eval') + except (TypeError, ValueError) as e: + # condition string contains null bytes. 
self._CompleteBreakpoint({ 'status': { 'isError': True, 'refersTo': 'BREAKPOINT_CONDITION', 'description': { 'format': 'Invalid expression', - 'parameters': [str(e)]}}}) - return False + 'parameters': [str(e)] + } + } + }) + return + except SyntaxError as e: self._CompleteBreakpoint({ 'status': { @@ -183,90 +362,19 @@ def _TryActivateBreakpoint(self): 'refersTo': 'BREAKPOINT_CONDITION', 'description': { 'format': 'Expression could not be compiled: $0', - 'parameters': [e.msg]}}}) - return False - - line = self.definition['location']['line'] - - native.LogInfo('Creating new Python breakpoint %s in %s, line %d' % ( - self.GetBreakpointId(), code_object, line)) - - self._cookie = native.SetConditionalBreakpoint( - code_object, - line, - condition, - self._BreakpointEvent) - - return True - - def _FindCodeObject(self): - """Finds the target code object for the breakpoint. - - This function completes breakpoint with error if the module was found, - but the line number is invalid. When code object is not found for the - breakpoint source location, this function just returns None. It does not - assume error, because it might be a deferred breakpoint. - - Returns: - Python code object object in which the breakpoint will be set or None if - module not found or if there is no code at the specified line. - """ - path = self.definition['location']['path'] - line = self.definition['location']['line'] - - module = module_lookup.FindModule(path) - if not module: - return None - - code_object = module_explorer.GetCodeObjectAtLine(module, line) - if code_object is None: - self._CompleteBreakpoint({ - 'status': { - 'isError': True, - 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', - 'description': { - 'format': NO_CODE_FOUND_AT_LINE, - 'parameters': [str(line)]}}}) - return None - - return code_object - - # Enables deferred breakpoints. - def _DeferBreakpoint(self): - """Defers breakpoint activation until the module has been loaded. 
- - This function first verifies that a module corresponding to breakpoint - location exists. This way if the user sets breakpoint in a file that - doesn't even exist, the debugger will not be waiting forever. If there - is definitely no module that matches this breakpoint, this function - completes the breakpoint with error status. - - Otherwise the debugger assumes that the module corresponding to breakpoint - location hasn't been loaded yet. The debugger will then start waiting for - the module to get loaded. Once the module is loaded, the debugger - will automatically try to activate the breakpoint. - """ - path = self.definition['location']['path'] + 'parameters': [e.msg] + } + } + }) + return - if os.path.splitext(path)[1] != '.py': - self._CompleteBreakpoint({ - 'status': { - 'isError': True, - 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', - 'description': {'format': BREAKPOINT_ONLY_SUPPORTS_PY_FILES}}}) - return + native.LogInfo('Creating new Python breakpoint %s in %s, line %d' % + (self.GetBreakpointId(), codeobj, line)) - if not deferred_modules.IsValidSourcePath(path): - self._CompleteBreakpoint({ - 'status': { - 'isError': True, - 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', - 'description': {'format': MODULE_NOT_FOUND}}}) + self._cookie = native.CreateConditionalBreakpoint(codeobj, line, condition, + self._BreakpointEvent) - assert not self._import_hook_cleanup - self._import_hook_cleanup = deferred_modules.AddImportCallback( - self.definition['location']['path'], - lambda unused_module_name: self._TryActivateBreakpoint()) + native.ActivateConditionalBreakpoint(self._cookie) def _RemoveImportHook(self): """Removes the import hook if one was installed.""" @@ -309,8 +417,7 @@ def _BreakpointEvent(self, event, frame): if event != native.BREAKPOINT_EVENT_HIT: error_status = _BREAKPOINT_EVENT_STATUS[event] elif self.definition.get('action') == 'LOG': - collector = capture_collector.LogCollector(self.definition) - error_status = collector.Log(frame) + error_status 
= self._collector.Log(frame) if not error_status: return # Log action successful, no need to clear the breakpoint. @@ -324,7 +431,32 @@ def _BreakpointEvent(self, event, frame): self._CompleteBreakpoint({'status': error_status}) return - collector = capture_collector.CaptureCollector(self.definition) - collector.Collect(frame) + capture_collector = collector.CaptureCollector(self.definition, + self.data_visibility_policy) + + # TODO: This is a temporary try/except. All exceptions should be + # caught inside Collect and converted into breakpoint error messages. + try: + capture_collector.Collect(frame) + except BaseException as e: # pylint: disable=broad-except + native.LogInfo('Internal error during data capture: %s' % repr(e)) + error_status = { + 'isError': True, + 'description': { + 'format': ('Internal error while capturing data: %s' % repr(e)) + } + } + self._CompleteBreakpoint({'status': error_status}) + return + except: # pylint: disable=bare-except + native.LogInfo('Unknown exception raised') + error_status = { + 'isError': True, + 'description': { + 'format': 'Unknown internal error' + } + } + self._CompleteBreakpoint({'status': error_status}) + return - self._CompleteBreakpoint(collector.breakpoint, is_incremental=False) + self._CompleteBreakpoint(capture_collector.breakpoint, is_incremental=False) diff --git a/src/googleclouddebugger/python_callback.h b/src/googleclouddebugger/python_callback.h index 9c86fb6..2e258f3 100644 --- a/src/googleclouddebugger/python_callback.h +++ b/src/googleclouddebugger/python_callback.h @@ -18,6 +18,7 @@ #define DEVTOOLS_CDBG_DEBUGLETS_PYTHON_PYTHON_CALLBACK_H_ #include + #include "common.h" #include "python_util.h" diff --git a/src/googleclouddebugger/python_util.cc b/src/googleclouddebugger/python_util.cc index 5ba1b2b..bc03bfc 100644 --- a/src/googleclouddebugger/python_util.cc +++ b/src/googleclouddebugger/python_util.cc @@ -19,6 +19,15 @@ #include "python_util.h" +#include + +#include + +#if PY_VERSION_HEX >= 
0x030A0000 +#include "../third_party/pylinetable.h" +#endif // PY_VERSION_HEX >= 0x030A0000 + + namespace devtools { namespace cdbg { @@ -28,25 +37,29 @@ static PyObject* g_debuglet_module = nullptr; CodeObjectLinesEnumerator::CodeObjectLinesEnumerator( PyCodeObject* code_object) { +#if PY_VERSION_HEX < 0x030A0000 Initialize(code_object->co_firstlineno, code_object->co_lnotab); +#else + Initialize(code_object->co_firstlineno, code_object->co_linetable); +#endif // PY_VERSION_HEX < 0x030A0000 } CodeObjectLinesEnumerator::CodeObjectLinesEnumerator( int firstlineno, - PyObject* lnotab) { - Initialize(firstlineno, lnotab); + PyObject* linedata) { + Initialize(firstlineno, linedata); } +#if PY_VERSION_HEX < 0x030A0000 void CodeObjectLinesEnumerator::Initialize( int firstlineno, PyObject* lnotab) { offset_ = 0; line_number_ = firstlineno; - remaining_entries_ = PyString_Size(lnotab) / 2; - next_entry_ = - reinterpret_cast(PyString_AsString(lnotab)); + remaining_entries_ = PyBytes_Size(lnotab) / 2; + next_entry_ = reinterpret_cast(PyBytes_AsString(lnotab)); // If the line table starts with offset 0, the first line is not // "code_object->co_firstlineno", but the following line. 
@@ -66,7 +79,7 @@ bool CodeObjectLinesEnumerator::Next() { while (true) { offset_ += next_entry_[0]; - line_number_ += next_entry_[1]; + line_number_ += static_cast(next_entry_[1]); bool stop = ((next_entry_[0] != 0xFF) || (next_entry_[1] != 0)) && ((next_entry_[0] != 0) || (next_entry_[1] != 0xFF)); @@ -83,7 +96,26 @@ bool CodeObjectLinesEnumerator::Next() { } } } +#else +void CodeObjectLinesEnumerator::Initialize( + int firstlineno, + PyObject* linetable) { + Py_ssize_t length = PyBytes_Size(linetable); + _PyLineTable_InitAddressRange(PyBytes_AsString(linetable), length, firstlineno, &range_); +} + +bool CodeObjectLinesEnumerator::Next() { + while (_PyLineTable_NextAddressRange(&range_)) { + if (range_.ar_line >= 0) { + line_number_ = range_.ar_line; + offset_ = range_.ar_start; + return true; + } + } + return false; +} +#endif // PY_VERSION_HEX < 0x030A0000 PyObject* GetDebugletModule() { DCHECK(g_debuglet_module != nullptr); @@ -100,8 +132,12 @@ void SetDebugletModule(PyObject* module) { PyTypeObject DefaultTypeDefinition(const char* type_name) { return { +#if PY_MAJOR_VERSION >= 3 + PyVarObject_HEAD_INIT(nullptr, /* ob_size */ 0) +#else PyObject_HEAD_INIT(nullptr) 0, /* ob_size */ +#endif type_name, /* tp_name */ 0, /* tp_basicsize */ 0, /* tp_itemsize */ @@ -168,29 +204,36 @@ bool RegisterPythonType(PyTypeObject* type) { return true; } - -Nullable ClearPythonException() { +Nullable ClearPythonException() { PyObject* exception_obj = PyErr_Occurred(); if (exception_obj == nullptr) { - return Nullable(); // return nullptr. + return Nullable(); // return nullptr. } - // TODO(vlif): call str(exception_obj) with a verification of immutability + // TODO: call str(exception_obj) with a verification of immutability // that the object state is not being altered. 
auto exception_type = reinterpret_cast(exception_obj->ob_type); - string msg = exception_type->tp_name; + std::string msg = exception_type->tp_name; #ifndef NDEBUG PyErr_Print(); +#else + static constexpr time_t EXCEPTION_THROTTLE_SECONDS = 30; + static time_t last_exception_reported = 0; + + time_t current_time = time(nullptr); + if (current_time - last_exception_reported >= EXCEPTION_THROTTLE_SECONDS) { + last_exception_reported = current_time; + PyErr_Print(); + } #endif // NDEBUG PyErr_Clear(); - return Nullable(msg); + return Nullable(msg); } - PyObject* GetDebugletModuleObject(const char* key) { PyObject* module_dict = PyModule_GetDict(GetDebugletModule()); if (module_dict == nullptr) { @@ -207,8 +250,7 @@ PyObject* GetDebugletModuleObject(const char* key) { return object; } - -string CodeObjectDebugString(PyCodeObject* code_object) { +std::string CodeObjectDebugString(PyCodeObject* code_object) { if (code_object == nullptr) { return ""; } @@ -217,38 +259,36 @@ string CodeObjectDebugString(PyCodeObject* code_object) { return ""; } - string str; + std::string str; if ((code_object->co_name != nullptr) && - PyString_CheckExact(code_object->co_name)) { - str += PyString_AS_STRING(code_object->co_name); + PyBytes_CheckExact(code_object->co_name)) { + str += PyBytes_AS_STRING(code_object->co_name); } else { str += ""; } str += ':'; - str += std::to_string(static_cast(code_object->co_firstlineno)); + str += std::to_string(static_cast(code_object->co_firstlineno)); if ((code_object->co_filename != nullptr) && - PyString_CheckExact(code_object->co_filename)) { + PyBytes_CheckExact(code_object->co_filename)) { str += " at "; - str += PyString_AS_STRING(code_object->co_filename); + str += PyBytes_AS_STRING(code_object->co_filename); } return str; } +std::vector PyBytesToByteArray(PyObject* obj) { + DCHECK(PyBytes_CheckExact(obj)); -std::vector PyStringToByteArray(PyObject* obj) { - DCHECK(PyString_CheckExact(obj)); - - const size_t bytecode_size = 
PyString_GET_SIZE(obj); - const uint8* const bytecode_data = - reinterpret_cast(PyString_AS_STRING(obj)); - return std::vector(bytecode_data, bytecode_data + bytecode_size); + const size_t bytecode_size = PyBytes_GET_SIZE(obj); + const uint8_t* const bytecode_data = + reinterpret_cast(PyBytes_AS_STRING(obj)); + return std::vector(bytecode_data, bytecode_data + bytecode_size); } - // Creates a new tuple by appending "items" to elements in "tuple". ScopedPyObject AppendTuple( PyObject* tuple, diff --git a/src/googleclouddebugger/python_util.h b/src/googleclouddebugger/python_util.h index 1275194..10116be 100644 --- a/src/googleclouddebugger/python_util.h +++ b/src/googleclouddebugger/python_util.h @@ -17,8 +17,10 @@ #ifndef DEVTOOLS_CDBG_DEBUGLETS_PYTHON_PYTHON_UTIL_H_ #define DEVTOOLS_CDBG_DEBUGLETS_PYTHON_PYTHON_UTIL_H_ +#include #include #include + #include "common.h" #include "nullable.h" @@ -69,7 +71,13 @@ class ScopedPyObjectT { } ~ScopedPyObjectT() { - reset(nullptr); + // Only do anything if Python is running. If not, we get might get a + // segfault when we try to decrement the reference count of the underlying + // object when this destructor is run after Python itself has cleaned up. + // https://bugs.python.org/issue17703 + if (Py_IsInitialized()) { + reset(nullptr); + } } static ScopedPyObjectT NewReference(TPointer* obj) { @@ -170,36 +178,43 @@ class CodeObjectLinesEnumerator { explicit CodeObjectLinesEnumerator(PyCodeObject* code_object); // Uses explicitly provided line table. - CodeObjectLinesEnumerator(int firstlineno, PyObject* lnotab); + CodeObjectLinesEnumerator(int firstlineno, PyObject* linedata); // Moves over to the next entry in code object line table. bool Next(); // Gets the bytecode offset of the current line. - int32 offset() const { return offset_; } + int32_t offset() const { return offset_; } // Gets the current source code line number. 
- int32 line_number() const { return line_number_; } + int32_t line_number() const { return line_number_; } private: - void Initialize(int firstlineno, PyObject* lnotab); + void Initialize(int firstlineno, PyObject* linedata); private: + // Bytecode offset of the current line. + int32_t offset_; + + // Current source code line number + int32_t line_number_; + +#if PY_VERSION_HEX < 0x030A0000 // Number of remaining entries in line table. int remaining_entries_; // Pointer to the next entry of line table. - const uint8* next_entry_; + const uint8_t* next_entry_; - // Bytecode offset of the current line. - int32 offset_; - - // Current source code line number - int32 line_number_; +#else + // Current address range in the linetable data. + PyCodeAddressRange range_; +#endif DISALLOW_COPY_AND_ASSIGN(CodeObjectLinesEnumerator); }; + template bool operator== (TPointer* ref1, const ScopedPyObjectT& ref2) { return ref2 == ref1; @@ -296,18 +311,18 @@ ScopedPyObject NewNativePythonObject() { // Checks whether the previous call generated an exception. If not, returns // nullptr. Otherwise formats the exception to string. -Nullable ClearPythonException(); +Nullable ClearPythonException(); // Gets Python object from dictionary of a native module. Returns nullptr if not // found. In case of success returns borrowed reference. PyObject* GetDebugletModuleObject(const char* key); // Formats the name and the origin of the code object for logging. -string CodeObjectDebugString(PyCodeObject* code_object); +std::string CodeObjectDebugString(PyCodeObject* code_object); // Reads Python string as a byte array. The function does not verify that // "obj" is of a string type. -std::vector PyStringToByteArray(PyObject* obj); +std::vector PyBytesToByteArray(PyObject* obj); // Creates a new tuple by appending "items" to elements in "tuple". 
ScopedPyObject AppendTuple( diff --git a/src/googleclouddebugger/rate_limit.cc b/src/googleclouddebugger/rate_limit.cc index babae25..80b7c47 100644 --- a/src/googleclouddebugger/rate_limit.cc +++ b/src/googleclouddebugger/rate_limit.cc @@ -19,21 +19,27 @@ #include "rate_limit.h" -DEFINE_int64( - max_trace_rate, - 25000, - "maximum number of Python trace callbacks per second before all " - "breakpoints are disabled"); - -DEFINE_int32( - max_condition_lines_rate, - 5000, +#include + +ABSL_FLAG( + int32, max_condition_lines_rate, 5000, "maximum number of Python lines/sec to spend on condition evaluation"); +ABSL_FLAG( + int32, max_dynamic_log_rate, + 50, // maximum of 50 log entries per second on average + "maximum rate of dynamic log entries in this process; short bursts are " + "allowed to exceed this limit"); + +ABSL_FLAG(int32, max_dynamic_log_bytes_rate, + 20480, // maximum of 20K bytes per second on average + "maximum rate of dynamic log bytes in this process; short bursts are " + "allowed to exceed this limit"); + namespace devtools { namespace cdbg { -// Define capacity of "trace_quota_" leaky bucket: +// Define capacity of leaky bucket: // capacity = fill_rate * capacity_factor // // The capacity is conceptually unrelated to fill rate, but we don't want to @@ -44,41 +50,41 @@ namespace cdbg { // debugger wil not impact the service throughput. Longer values will allow the // burst, and will only disable the breakpoint if CPU consumption due to // debugger is continuous for a prolonged period of time. 
-static const double kMaxTraceRateCapacityFactor = 10; static const double kConditionCostCapacityFactor = 0.1; +static const double kDynamicLogCapacityFactor = 5; +static const double kDynamicLogBytesCapacityFactor = 2; -static std::unique_ptr g_trace_quota; static std::unique_ptr g_global_condition_quota; +static std::unique_ptr g_global_dynamic_log_quota; +static std::unique_ptr g_global_dynamic_log_bytes_quota; - -static int64 GetBaseConditionQuotaCapacity() { - return FLAGS_max_condition_lines_rate * kConditionCostCapacityFactor; +static int64_t GetBaseConditionQuotaCapacity() { + return absl::GetFlag(FLAGS_max_condition_lines_rate) * + kConditionCostCapacityFactor; } - void LazyInitializeRateLimit() { - if (g_trace_quota == nullptr) { - g_trace_quota.reset(new LeakyBucket( - FLAGS_max_trace_rate * kMaxTraceRateCapacityFactor, - FLAGS_max_trace_rate)); - } - if (g_global_condition_quota == nullptr) { - g_global_condition_quota.reset(new LeakyBucket( - GetBaseConditionQuotaCapacity(), - FLAGS_max_condition_lines_rate)); + g_global_condition_quota.reset( + new LeakyBucket(GetBaseConditionQuotaCapacity(), + absl::GetFlag(FLAGS_max_condition_lines_rate))); + + g_global_dynamic_log_quota.reset(new LeakyBucket( + absl::GetFlag(FLAGS_max_dynamic_log_rate) * kDynamicLogCapacityFactor, + absl::GetFlag(FLAGS_max_dynamic_log_rate))); + + g_global_dynamic_log_bytes_quota.reset( + new LeakyBucket(absl::GetFlag(FLAGS_max_dynamic_log_bytes_rate) * + kDynamicLogBytesCapacityFactor, + absl::GetFlag(FLAGS_max_dynamic_log_bytes_rate))); } } void CleanupRateLimit() { - g_trace_quota = nullptr; g_global_condition_quota = nullptr; -} - - -LeakyBucket* GetTraceQuota() { - return g_trace_quota.get(); + g_global_dynamic_log_quota = nullptr; + g_global_dynamic_log_bytes_quota = nullptr; } @@ -86,11 +92,18 @@ LeakyBucket* GetGlobalConditionQuota() { return g_global_condition_quota.get(); } +LeakyBucket* GetGlobalDynamicLogQuota() { + return g_global_dynamic_log_quota.get(); +} + 
+LeakyBucket* GetGlobalDynamicLogBytesQuota() { + return g_global_dynamic_log_bytes_quota.get(); +} std::unique_ptr CreatePerBreakpointConditionQuota() { - return std::unique_ptr(new LeakyBucket( - GetBaseConditionQuotaCapacity() / 2, - FLAGS_max_condition_lines_rate / 2)); + return std::unique_ptr( + new LeakyBucket(GetBaseConditionQuotaCapacity() / 2, + absl::GetFlag(FLAGS_max_condition_lines_rate) / 2)); } } // namespace cdbg diff --git a/src/googleclouddebugger/rate_limit.h b/src/googleclouddebugger/rate_limit.h index ebe2737..a7cf976 100644 --- a/src/googleclouddebugger/rate_limit.h +++ b/src/googleclouddebugger/rate_limit.h @@ -18,6 +18,7 @@ #define DEVTOOLS_CDBG_DEBUGLETS_PYTHON_RATE_LIMIT_H_ #include + #include "leaky_bucket.h" #include "common.h" @@ -30,18 +31,6 @@ void LazyInitializeRateLimit(); // Release quota objects. void CleanupRateLimit(); -// Gets the global quota on number of trace calls per second. Once the quota is -// exceeded we disable all the breakpoints. This is because the overhead is -// due to having trace callback and a specific breakpoint doesn't impact -// much. -// We don't measure total time, because: -// 1. There is an overhead of calling the trace function in CPython. We -// can't measure it. -// 2. Most of these callbacks are too fast to reliably measure. -// The quota is not a function of number of CPUs because Python is inherently -// single threaded. -LeakyBucket* GetTraceQuota(); - // Condition and dynamic logging rate limits are defined as the maximum // number of lines of Python code per second to execute. These rate are enforced // as following: @@ -60,7 +49,8 @@ LeakyBucket* GetTraceQuota(); // single threaded. 
LeakyBucket* GetGlobalConditionQuota(); std::unique_ptr CreatePerBreakpointConditionQuota(); - +LeakyBucket* GetGlobalDynamicLogQuota(); +LeakyBucket* GetGlobalDynamicLogBytesQuota(); } // namespace cdbg } // namespace devtools diff --git a/src/googleclouddebugger/uniquifier_computer.py b/src/googleclouddebugger/uniquifier_computer.py index 2b307f1..8395f33 100644 --- a/src/googleclouddebugger/uniquifier_computer.py +++ b/src/googleclouddebugger/uniquifier_computer.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Computes a unique identifier of the deployed application. When the application runs under AppEngine, the deployment is uniquely @@ -28,7 +27,6 @@ import os import sys - # Maximum recursion depth to follow when traversing the file system. This limit # will prevent stack overflow in case of a loop created by symbolic links. _MAX_DEPTH = 10 @@ -93,8 +91,7 @@ def ProcessDirectory(path, relative_path, depth=1): modules.add(file_name) ProcessApplicationFile(current_path, os.path.join(relative_path, name)) elif IsPackage(current_path): - ProcessDirectory(current_path, - os.path.join(relative_path, name), + ProcessDirectory(current_path, os.path.join(relative_path, name), depth + 1) def IsPackage(path): @@ -106,12 +103,12 @@ def IsPackage(path): def ProcessApplicationFile(path, relative_path): """Updates the hash with the specified application file.""" - hash_obj.update(relative_path) - hash_obj.update(':') + hash_obj.update(relative_path.encode()) + hash_obj.update(':'.encode()) try: - hash_obj.update(str(os.stat(path).st_size)) + hash_obj.update(str(os.stat(path).st_size).encode()) except BaseException: pass - hash_obj.update('\n') + hash_obj.update('\n'.encode()) ProcessDirectory(sys.path[0], '') diff --git a/src/googleclouddebugger/version.py b/src/googleclouddebugger/version.py new file mode 100644 index 
0000000..3b0f00f --- /dev/null +++ b/src/googleclouddebugger/version.py @@ -0,0 +1,7 @@ +"""Version of the Google Python Cloud Debugger.""" + +# Versioning scheme: MAJOR.MINOR +# The major version should only change on breaking changes. Minor version +# changes go between regular updates. Instances running debuggers with +# different major versions will show up as two different debuggees. +__version__ = '4.1' diff --git a/src/googleclouddebugger/yaml_data_visibility_config_reader.py b/src/googleclouddebugger/yaml_data_visibility_config_reader.py new file mode 100644 index 0000000..dc75673 --- /dev/null +++ b/src/googleclouddebugger/yaml_data_visibility_config_reader.py @@ -0,0 +1,144 @@ +# Copyright 2017 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Reads a YAML configuration file to determine visibility policy. + +Example Usage: + try: + config = yaml_data_visibility_config_reader.OpenAndRead(filename) + except yaml_data_visibility_config_reader.Error as e: + ... 
+ + visibility_policy = GlobDataVisibilityPolicy( + config.blacklist_patterns, + config.whitelist_patterns) +""" + +import os +import sys +import yaml + + +class Error(Exception): + """Generic error class that other errors in this module inherit from.""" + pass + + +class YAMLLoadError(Error): + """Thrown when reading an opened file fails.""" + pass + + +class ParseError(Error): + """Thrown when there is a problem with the YAML structure.""" + pass + + +class UnknownConfigKeyError(Error): + """Thrown when the YAML contains an unsupported keyword.""" + pass + + +class NotAListError(Error): + """Thrown when a YAML key does not reference a list.""" + pass + + +class ElementNotAStringError(Error): + """Thrown when a YAML list element is not a string.""" + pass + + +class Config(object): + """Configuration object that Read() returns to the caller.""" + + def __init__(self, blacklist_patterns, whitelist_patterns): + self.blacklist_patterns = blacklist_patterns + self.whitelist_patterns = whitelist_patterns + + +def OpenAndRead(relative_path='debugger-blacklist.yaml'): + """Attempts to find the yaml configuration file, then read it. + + Args: + relative_path: Optional relative path override. + + Returns: + A Config object if the open and read were successful, None if the file + does not exist (which is not considered an error). + + Raises: + Error (some subclass): As thrown by the called Read() function. + """ + + # Note: This logic follows the convention established by source-context.json + try: + with open(os.path.join(sys.path[0], relative_path), 'r') as f: + return Read(f) + except IOError: + return None + + +def Read(f): + """Reads and returns Config data from a yaml file. + + Args: + f: Yaml file to parse. + + Returns: + Config object as defined in this file. + + Raises: + Error (some subclass): If there is a problem loading or parsing the file. 
+ """ + try: + yaml_data = yaml.safe_load(f) + except yaml.YAMLError as e: + raise ParseError('%s' % e) + except IOError as e: + raise YAMLLoadError('%s' % e) + + _CheckData(yaml_data) + + try: + return Config( + yaml_data.get('blacklist', ()), yaml_data.get('whitelist', ('*'))) + except UnicodeDecodeError as e: + raise YAMLLoadError('%s' % e) + + +def _CheckData(yaml_data): + """Checks data for illegal keys and formatting.""" + legal_keys = set(('blacklist', 'whitelist')) + unknown_keys = set(yaml_data) - legal_keys + if unknown_keys: + raise UnknownConfigKeyError('Unknown keys in configuration: %s' % + unknown_keys) + + for key, data in yaml_data.items(): + _AssertDataIsList(key, data) + + +def _AssertDataIsList(key, lst): + """Assert that lst contains list data and is not structured.""" + + # list and tuple are supported. Not supported are direct strings + # and dictionary; these indicate too much or two little structure. + if not isinstance(lst, list) and not isinstance(lst, tuple): + raise NotAListError('%s must be a list' % key) + + # each list entry must be a string + for element in lst: + if not isinstance(element, str): + raise ElementNotAStringError('Unsupported list element %s found in %s', + (element, lst)) diff --git a/src/setup.py b/src/setup.py index f747b8a..25f6095 100644 --- a/src/setup.py +++ b/src/setup.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """Python Cloud Debugger build and packaging script.""" -import ConfigParser +from configparser import ConfigParser from glob import glob import os import re @@ -34,7 +33,7 @@ def RemovePrefixes(optlist, bad_prefixes): def ReadConfig(section, value, default): try: - config = ConfigParser.ConfigParser() + config = ConfigParser() config.read('setup.cfg') return config.get(section, value) except: # pylint: disable=bare-except @@ -49,8 +48,7 @@ def ReadConfig(section, value, default): 'For more details please see ' 'https://github.com/GoogleCloudPlatform/cloud-debug-python\n') -lib_dirs = ReadConfig('build_ext', - 'library_dirs', +lib_dirs = ReadConfig('build_ext', 'library_dirs', sysconfig.get_config_var('LIBDIR')).split(':') extra_compile_args = ReadConfig('cc_options', 'extra_compile_args', '').split() extra_link_args = ReadConfig('cc_options', 'extra_link_args', '').split() @@ -65,18 +63,19 @@ def ReadConfig(section, value, default): assert len(static_libs) == len(deps), (static_libs, deps, lib_dirs) cvars = sysconfig.get_config_vars() -cvars['OPT'] = str.join(' ', RemovePrefixes( - cvars.get('OPT').split(), - ['-g', '-O', '-Wstrict-prototypes'])) +cvars['OPT'] = str.join( + ' ', + RemovePrefixes( + cvars.get('OPT').split(), ['-g', '-O', '-Wstrict-prototypes'])) # Determine the current version of the package. The easiest way would be to # import "googleclouddebugger" and read its __version__ attribute. # Unfortunately we can't do that because "googleclouddebugger" depends on # "cdbg_native" that hasn't been built yet. 
version = None -with open('googleclouddebugger/__init__.py', 'r') as init_file: +with open('googleclouddebugger/version.py', 'r') as version_file: version_pattern = re.compile(r"^\s*__version__\s*=\s*'([0-9.]*)'") - for line in init_file: + for line in version_file: match = version_pattern.match(line) if match: version = match.groups()[0] @@ -101,13 +100,20 @@ def ReadConfig(section, value, default): url='https://github.com/GoogleCloudPlatform/cloud-debug-python', author='Google Inc.', version=version, - install_requires=['google-api-python-client'], + install_requires=[ + 'firebase-admin>=5.3.0', + 'pyyaml', + ], packages=['googleclouddebugger'], ext_modules=[cdbg_native_module], license='Apache License, Version 2.0', keywords='google cloud debugger', classifiers=[ - 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers' + 'Intended Audience :: Developers', ]) diff --git a/src/third_party/BUILD b/src/third_party/BUILD new file mode 100644 index 0000000..bcce1e2 --- /dev/null +++ b/src/third_party/BUILD @@ -0,0 +1,7 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "pylinetable", + hdrs = ["pylinetable.h"], +) + diff --git a/src/third_party/pylinetable.h b/src/third_party/pylinetable.h new file mode 100644 index 0000000..ea44c64 --- /dev/null +++ b/src/third_party/pylinetable.h @@ -0,0 +1,210 @@ +/** + * Copyright (c) 2001-2023 Python Software Foundation; All Rights Reserved + * + * You may obtain a copy of the PSF License at + * + * https://docs.python.org/3/license.html + */ + +#ifndef DEVTOOLS_CDBG_DEBUGLETS_PYTHON_PYLINETABLE_H_ +#define DEVTOOLS_CDBG_DEBUGLETS_PYTHON_PYLINETABLE_H_ + +/* Python Linetable helper methods. + * They are not part of the cpython api. 
+ * This code has been extracted from: + * https://github.com/python/cpython/blob/main/Objects/codeobject.c + * + * See https://peps.python.org/pep-0626/#out-of-process-debuggers-and-profilers + * for more information about this code and its usage. + */ + +#if PY_VERSION_HEX >= 0x030B0000 +// Things are different in 3.11 than 3.10. +// See https://github.com/python/cpython/blob/main/Objects/locations.md + +typedef enum _PyCodeLocationInfoKind { + /* short forms are 0 to 9 */ + PY_CODE_LOCATION_INFO_SHORT0 = 0, + /* one lineforms are 10 to 12 */ + PY_CODE_LOCATION_INFO_ONE_LINE0 = 10, + PY_CODE_LOCATION_INFO_ONE_LINE1 = 11, + PY_CODE_LOCATION_INFO_ONE_LINE2 = 12, + + PY_CODE_LOCATION_INFO_NO_COLUMNS = 13, + PY_CODE_LOCATION_INFO_LONG = 14, + PY_CODE_LOCATION_INFO_NONE = 15 +} _PyCodeLocationInfoKind; + +/** Out of process API for initializing the location table. */ +extern void _PyLineTable_InitAddressRange( + const char *linetable, + Py_ssize_t length, + int firstlineno, + PyCodeAddressRange *range); + +/** API for traversing the line number table. 
*/ +extern int _PyLineTable_NextAddressRange(PyCodeAddressRange *range); + + +void _PyLineTable_InitAddressRange(const char *linetable, Py_ssize_t length, int firstlineno, PyCodeAddressRange *range) { + range->opaque.lo_next = linetable; + range->opaque.limit = range->opaque.lo_next + length; + range->ar_start = -1; + range->ar_end = 0; + range->opaque.computed_line = firstlineno; + range->ar_line = -1; +} + +static int +scan_varint(const uint8_t *ptr) +{ + unsigned int read = *ptr++; + unsigned int val = read & 63; + unsigned int shift = 0; + while (read & 64) { + read = *ptr++; + shift += 6; + val |= (read & 63) << shift; + } + return val; +} + +static int +scan_signed_varint(const uint8_t *ptr) +{ + unsigned int uval = scan_varint(ptr); + if (uval & 1) { + return -(int)(uval >> 1); + } + else { + return uval >> 1; + } +} + +static int +get_line_delta(const uint8_t *ptr) +{ + int code = ((*ptr) >> 3) & 15; + switch (code) { + case PY_CODE_LOCATION_INFO_NONE: + return 0; + case PY_CODE_LOCATION_INFO_NO_COLUMNS: + case PY_CODE_LOCATION_INFO_LONG: + return scan_signed_varint(ptr+1); + case PY_CODE_LOCATION_INFO_ONE_LINE0: + return 0; + case PY_CODE_LOCATION_INFO_ONE_LINE1: + return 1; + case PY_CODE_LOCATION_INFO_ONE_LINE2: + return 2; + default: + /* Same line */ + return 0; + } +} + +static int +is_no_line_marker(uint8_t b) +{ + return (b >> 3) == 0x1f; +} + + +#define ASSERT_VALID_BOUNDS(bounds) \ + assert(bounds->opaque.lo_next <= bounds->opaque.limit && \ + (bounds->ar_line == -1 || bounds->ar_line == bounds->opaque.computed_line) && \ + (bounds->opaque.lo_next == bounds->opaque.limit || \ + (*bounds->opaque.lo_next) & 128)) + +static int +next_code_delta(PyCodeAddressRange *bounds) +{ + assert((*bounds->opaque.lo_next) & 128); + return (((*bounds->opaque.lo_next) & 7) + 1) * sizeof(_Py_CODEUNIT); +} + +static void +advance(PyCodeAddressRange *bounds) +{ + ASSERT_VALID_BOUNDS(bounds); + bounds->opaque.computed_line += 
get_line_delta(reinterpret_cast(bounds->opaque.lo_next)); + if (is_no_line_marker(*bounds->opaque.lo_next)) { + bounds->ar_line = -1; + } + else { + bounds->ar_line = bounds->opaque.computed_line; + } + bounds->ar_start = bounds->ar_end; + bounds->ar_end += next_code_delta(bounds); + do { + bounds->opaque.lo_next++; + } while (bounds->opaque.lo_next < bounds->opaque.limit && + ((*bounds->opaque.lo_next) & 128) == 0); + ASSERT_VALID_BOUNDS(bounds); +} + +static inline int +at_end(PyCodeAddressRange *bounds) { + return bounds->opaque.lo_next >= bounds->opaque.limit; +} + +int +_PyLineTable_NextAddressRange(PyCodeAddressRange *range) +{ + if (at_end(range)) { + return 0; + } + advance(range); + assert(range->ar_end > range->ar_start); + return 1; +} +#elif PY_VERSION_HEX >= 0x030A0000 +void +_PyLineTable_InitAddressRange(const char *linetable, Py_ssize_t length, int firstlineno, PyCodeAddressRange *range) +{ + range->opaque.lo_next = linetable; + range->opaque.limit = range->opaque.lo_next + length; + range->ar_start = -1; + range->ar_end = 0; + range->opaque.computed_line = firstlineno; + range->ar_line = -1; +} + +static void +advance(PyCodeAddressRange *bounds) +{ + bounds->ar_start = bounds->ar_end; + int delta = ((unsigned char *)bounds->opaque.lo_next)[0]; + bounds->ar_end += delta; + int ldelta = ((signed char *)bounds->opaque.lo_next)[1]; + bounds->opaque.lo_next += 2; + if (ldelta == -128) { + bounds->ar_line = -1; + } + else { + bounds->opaque.computed_line += ldelta; + bounds->ar_line = bounds->opaque.computed_line; + } +} + +static inline int +at_end(PyCodeAddressRange *bounds) { + return bounds->opaque.lo_next >= bounds->opaque.limit; +} + +int +_PyLineTable_NextAddressRange(PyCodeAddressRange *range) +{ + if (at_end(range)) { + return 0; + } + advance(range); + while (range->ar_start == range->ar_end) { + assert(!at_end(range)); + advance(range); + } + return 1; +} +#endif + +#endif // DEVTOOLS_CDBG_DEBUGLETS_PYTHON_PYLINETABLE_H_ diff --git 
a/tests/cpp/BUILD b/tests/cpp/BUILD new file mode 100644 index 0000000..7536fe9 --- /dev/null +++ b/tests/cpp/BUILD @@ -0,0 +1,10 @@ +package(default_visibility = ["//visibility:public"]) + +cc_test( + name = "bytecode_manipulator_test", + srcs = ["bytecode_manipulator_test.cc"], + deps = [ + "//src/googleclouddebugger:bytecode_manipulator", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tests/cpp/bytecode_manipulator_test.cc b/tests/cpp/bytecode_manipulator_test.cc new file mode 100644 index 0000000..934dfef --- /dev/null +++ b/tests/cpp/bytecode_manipulator_test.cc @@ -0,0 +1,1059 @@ +/** + * Copyright 2023 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "src/googleclouddebugger/bytecode_manipulator.h" + +#include +#include + +namespace devtools { +namespace cdbg { + +static std::string FormatOpcode(uint8_t opcode) { + switch (opcode) { + case POP_TOP: return "POP_TOP"; + case ROT_TWO: return "ROT_TWO"; + case ROT_THREE: return "ROT_THREE"; + case DUP_TOP: return "DUP_TOP"; + case NOP: return "NOP"; + case UNARY_POSITIVE: return "UNARY_POSITIVE"; + case UNARY_NEGATIVE: return "UNARY_NEGATIVE"; + case UNARY_NOT: return "UNARY_NOT"; + case UNARY_INVERT: return "UNARY_INVERT"; + case BINARY_POWER: return "BINARY_POWER"; + case BINARY_MULTIPLY: return "BINARY_MULTIPLY"; + case BINARY_MODULO: return "BINARY_MODULO"; + case BINARY_ADD: return "BINARY_ADD"; + case BINARY_SUBTRACT: return "BINARY_SUBTRACT"; + case BINARY_SUBSCR: return "BINARY_SUBSCR"; + case BINARY_FLOOR_DIVIDE: return "BINARY_FLOOR_DIVIDE"; + case BINARY_TRUE_DIVIDE: return "BINARY_TRUE_DIVIDE"; + case INPLACE_FLOOR_DIVIDE: return "INPLACE_FLOOR_DIVIDE"; + case INPLACE_TRUE_DIVIDE: return "INPLACE_TRUE_DIVIDE"; + case INPLACE_ADD: return "INPLACE_ADD"; + case INPLACE_SUBTRACT: return "INPLACE_SUBTRACT"; + case INPLACE_MULTIPLY: return "INPLACE_MULTIPLY"; + case INPLACE_MODULO: return "INPLACE_MODULO"; + case STORE_SUBSCR: return "STORE_SUBSCR"; + case DELETE_SUBSCR: return "DELETE_SUBSCR"; + case BINARY_LSHIFT: return "BINARY_LSHIFT"; + case BINARY_RSHIFT: return "BINARY_RSHIFT"; + case BINARY_AND: return "BINARY_AND"; + case BINARY_XOR: return "BINARY_XOR"; + case BINARY_OR: return "BINARY_OR"; + case INPLACE_POWER: return "INPLACE_POWER"; + case GET_ITER: return "GET_ITER"; + case PRINT_EXPR: return "PRINT_EXPR"; + case INPLACE_LSHIFT: return "INPLACE_LSHIFT"; + case INPLACE_RSHIFT: return "INPLACE_RSHIFT"; + case INPLACE_AND: return "INPLACE_AND"; + case INPLACE_XOR: return "INPLACE_XOR"; + case INPLACE_OR: return "INPLACE_OR"; + case RETURN_VALUE: return "RETURN_VALUE"; + case IMPORT_STAR: return "IMPORT_STAR"; + case YIELD_VALUE: 
return "YIELD_VALUE"; + case POP_BLOCK: return "POP_BLOCK"; +#if PY_VERSION_HEX <= 0x03080000 + case END_FINALLY: return "END_FINALLY"; +#endif + case STORE_NAME: return "STORE_NAME"; + case DELETE_NAME: return "DELETE_NAME"; + case UNPACK_SEQUENCE: return "UNPACK_SEQUENCE"; + case FOR_ITER: return "FOR_ITER"; + case LIST_APPEND: return "LIST_APPEND"; + case STORE_ATTR: return "STORE_ATTR"; + case DELETE_ATTR: return "DELETE_ATTR"; + case STORE_GLOBAL: return "STORE_GLOBAL"; + case DELETE_GLOBAL: return "DELETE_GLOBAL"; + case LOAD_CONST: return "LOAD_CONST"; + case LOAD_NAME: return "LOAD_NAME"; + case BUILD_TUPLE: return "BUILD_TUPLE"; + case BUILD_LIST: return "BUILD_LIST"; + case BUILD_SET: return "BUILD_SET"; + case BUILD_MAP: return "BUILD_MAP"; + case LOAD_ATTR: return "LOAD_ATTR"; + case COMPARE_OP: return "COMPARE_OP"; + case IMPORT_NAME: return "IMPORT_NAME"; + case IMPORT_FROM: return "IMPORT_FROM"; + case JUMP_FORWARD: return "JUMP_FORWARD"; + case JUMP_IF_FALSE_OR_POP: return "JUMP_IF_FALSE_OR_POP"; + case JUMP_IF_TRUE_OR_POP: return "JUMP_IF_TRUE_OR_POP"; + case JUMP_ABSOLUTE: return "JUMP_ABSOLUTE"; + case POP_JUMP_IF_FALSE: return "POP_JUMP_IF_FALSE"; + case POP_JUMP_IF_TRUE: return "POP_JUMP_IF_TRUE"; + case LOAD_GLOBAL: return "LOAD_GLOBAL"; + case SETUP_FINALLY: return "SETUP_FINALLY"; + case LOAD_FAST: return "LOAD_FAST"; + case STORE_FAST: return "STORE_FAST"; + case DELETE_FAST: return "DELETE_FAST"; + case RAISE_VARARGS: return "RAISE_VARARGS"; + case CALL_FUNCTION: return "CALL_FUNCTION"; + case MAKE_FUNCTION: return "MAKE_FUNCTION"; + case BUILD_SLICE: return "BUILD_SLICE"; + case LOAD_CLOSURE: return "LOAD_CLOSURE"; + case LOAD_DEREF: return "LOAD_DEREF"; + case STORE_DEREF: return "STORE_DEREF"; + case CALL_FUNCTION_KW: return "CALL_FUNCTION_KW"; + case SETUP_WITH: return "SETUP_WITH"; + case EXTENDED_ARG: return "EXTENDED_ARG"; + case SET_ADD: return "SET_ADD"; + case MAP_ADD: return "MAP_ADD"; +#if PY_VERSION_HEX < 0x03080000 + case 
BREAK_LOOP: return "BREAK_LOOP"; + case CONTINUE_LOOP: return "CONTINUE_LOOP"; + case SETUP_LOOP: return "SETUP_LOOP"; + case SETUP_EXCEPT: return "SETUP_EXCEPT"; +#endif + case DUP_TOP_TWO: return "DUP_TOP_TWO"; + case BINARY_MATRIX_MULTIPLY: return "BINARY_MATRIX_MULTIPLY"; + case INPLACE_MATRIX_MULTIPLY: return "INPLACE_MATRIX_MULTIPLY"; + case GET_AITER: return "GET_AITER"; + case GET_ANEXT: return "GET_ANEXT"; + case BEFORE_ASYNC_WITH: return "BEFORE_ASYNC_WITH"; + case GET_YIELD_FROM_ITER: return "GET_YIELD_FROM_ITER"; + case LOAD_BUILD_CLASS: return "LOAD_BUILD_CLASS"; + case YIELD_FROM: return "YIELD_FROM"; + case GET_AWAITABLE: return "GET_AWAITABLE"; +#if PY_VERSION_HEX <= 0x03080000 + case WITH_CLEANUP_START: return "WITH_CLEANUP_START"; + case WITH_CLEANUP_FINISH: return "WITH_CLEANUP_FINISH"; +#endif + case SETUP_ANNOTATIONS: return "SETUP_ANNOTATIONS"; + case POP_EXCEPT: return "POP_EXCEPT"; + case UNPACK_EX: return "UNPACK_EX"; +#if PY_VERSION_HEX < 0x03070000 + case STORE_ANNOTATION: return "STORE_ANNOTATION"; +#endif + case CALL_FUNCTION_EX: return "CALL_FUNCTION_EX"; + case LOAD_CLASSDEREF: return "LOAD_CLASSDEREF"; +#if PY_VERSION_HEX <= 0x03080000 + case BUILD_LIST_UNPACK: return "BUILD_LIST_UNPACK"; + case BUILD_MAP_UNPACK: return "BUILD_MAP_UNPACK"; + case BUILD_MAP_UNPACK_WITH_CALL: return "BUILD_MAP_UNPACK_WITH_CALL"; + case BUILD_TUPLE_UNPACK: return "BUILD_TUPLE_UNPACK"; + case BUILD_SET_UNPACK: return "BUILD_SET_UNPACK"; +#endif + case SETUP_ASYNC_WITH: return "SETUP_ASYNC_WITH"; + case FORMAT_VALUE: return "FORMAT_VALUE"; + case BUILD_CONST_KEY_MAP: return "BUILD_CONST_KEY_MAP"; + case BUILD_STRING: return "BUILD_STRING"; +#if PY_VERSION_HEX <= 0x03080000 + case BUILD_TUPLE_UNPACK_WITH_CALL: return "BUILD_TUPLE_UNPACK_WITH_CALL"; +#endif +#if PY_VERSION_HEX >= 0x03070000 + case LOAD_METHOD: return "LOAD_METHOD"; + case CALL_METHOD: return "CALL_METHOD"; +#endif +#if PY_VERSION_HEX >= 0x03080000 && PY_VERSION_HEX < 0x03090000 + case 
BEGIN_FINALLY: return "BEGIN_FINALLY"; + case POP_FINALLY: return "POP_FINALLY"; +#endif +#if PY_VERSION_HEX >= 0x03080000 + case ROT_FOUR: return "ROT_FOUR"; + case END_ASYNC_FOR: return "END_ASYNC_FOR"; +#endif +#if PY_VERSION_HEX >= 0x03080000 && PY_VERSION_HEX < 0x03090000 + // Added in Python 3.8 and removed in 3.9 + case CALL_FINALLY: return "CALL_FINALLY"; +#endif +#if PY_VERSION_HEX >= 0x03090000 + case RERAISE: return "RERAISE"; + case WITH_EXCEPT_START: return "WITH_EXCEPT_START"; + case LOAD_ASSERTION_ERROR: return "LOAD_ASSERTION_ERROR"; + case LIST_TO_TUPLE: return "LIST_TO_TUPLE"; + case IS_OP: return "IS_OP"; + case CONTAINS_OP: return "CONTAINS_OP"; + case JUMP_IF_NOT_EXC_MATCH: return "JUMP_IF_NOT_EXC_MATCH"; + case LIST_EXTEND: return "LIST_EXTEND"; + case SET_UPDATE: return "SET_UPDATE"; + case DICT_MERGE: return "DICT_MERGE"; + case DICT_UPDATE: return "DICT_UPDATE"; +#endif + + default: return std::to_string(static_cast(opcode)); + } +} + +static std::string FormatBytecode(const std::vector& bytecode, + int indent) { + std::string rc; + int remaining_argument_bytes = 0; + for (auto it = bytecode.begin(); it != bytecode.end(); ++it) { + std::string line; + if (remaining_argument_bytes == 0) { + line = FormatOpcode(*it); + remaining_argument_bytes = 1; + } else { + line = std::to_string(static_cast(*it)); + --remaining_argument_bytes; + } + + if (it < bytecode.end() - 1) { + line += ','; + } + + line.resize(20, ' '); + line += "// offset "; + line += std::to_string(it - bytecode.begin()); + line += '.'; + + rc += std::string(indent, ' '); + rc += line; + + if (it < bytecode.end() - 1) { + rc += '\n'; + } + } + + return rc; +} + +static void VerifyBytecode(const BytecodeManipulator& bytecode_manipulator, + std::vector expected_bytecode) { + EXPECT_EQ(expected_bytecode, bytecode_manipulator.bytecode()) + << "Actual bytecode:\n" + << " {\n" + << FormatBytecode(bytecode_manipulator.bytecode(), 10) << "\n" + << " }"; +} + +static void 
VerifyLineNumbersTable( + const BytecodeManipulator& bytecode_manipulator, + std::vector expected_linedata) { + // Convert to integers to better logging by EXPECT_EQ. + std::vector expected(expected_linedata.begin(), expected_linedata.end()); + std::vector actual( + bytecode_manipulator.linedata().begin(), + bytecode_manipulator.linedata().end()); + + EXPECT_EQ(expected, actual); +} + +TEST(BytecodeManipulatorTest, EmptyBytecode) { + BytecodeManipulator instance({}, false, {}); + EXPECT_FALSE(instance.InjectMethodCall(0, 0)); +} + + +TEST(BytecodeManipulatorTest, HasLineNumbersTable) { + BytecodeManipulator instance1({}, false, {}); + EXPECT_FALSE(instance1.has_linedata()); + + BytecodeManipulator instance2({}, true, {}); + EXPECT_TRUE(instance2.has_linedata()); +} + + + + +TEST(BytecodeManipulatorTest, InsertionSimple) { + BytecodeManipulator instance({ NOP, 0, RETURN_VALUE, 0 }, false, {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 47)); + + VerifyBytecode( + instance, + { + NOP, // offset 0. + 0, // offset 1. + LOAD_CONST, // offset 4. + 47, // offset 5. + CALL_FUNCTION, // offset 6. + 0, // offset 7. + POP_TOP, // offset 8. + 0, // offset 9. + RETURN_VALUE, // offset 10. + 0 // offset 11. + }); +} + + +TEST(BytecodeManipulatorTest, InsertionExtended) { + BytecodeManipulator instance({ NOP, 0, RETURN_VALUE, 0 }, false, {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 0x12345678)); + + VerifyBytecode( + instance, + { + NOP, // offset 0. + 0, // offset 1. + EXTENDED_ARG, // offset 2. + 0x12, // offset 3. + EXTENDED_ARG, // offset 2. + 0x34, // offset 3. + EXTENDED_ARG, // offset 2. + 0x56, // offset 3. + LOAD_CONST, // offset 4. + 0x78, // offset 5. + CALL_FUNCTION, // offset 6. + 0, // offset 7. + POP_TOP, // offset 8. + 0, // offset 9. + RETURN_VALUE, // offset 10. + 0 // offset 11. 
+ }); +} + + +TEST(BytecodeManipulatorTest, InsertionBeginning) { + BytecodeManipulator instance({ NOP, 0, RETURN_VALUE, 0 }, false, {}); + ASSERT_TRUE(instance.InjectMethodCall(0, 47)); + + VerifyBytecode( + instance, + { + LOAD_CONST, // offset 0. + 47, // offset 1. + CALL_FUNCTION, // offset 2. + 0, // offset 3. + POP_TOP, // offset 4. + 0, // offset 5. + NOP, // offset 6. + 0, // offset 7. + RETURN_VALUE, // offset 8. + 0 // offset 9. + }); +} + + +TEST(BytecodeManipulatorTest, InsertionOffsetUpdates) { + BytecodeManipulator instance( + { + JUMP_FORWARD, + 12, + NOP, + 0, + JUMP_ABSOLUTE, + 34, + }, + false, + {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 47)); + +#if PY_VERSION_HEX >= 0x030A0000 + // Jump offsets are instruction offsets, not byte offsets. + VerifyBytecode( + instance, + { + JUMP_FORWARD, // offset 0. + 12 + 3, // offset 1. + LOAD_CONST, // offset 2. + 47, // offset 3. + CALL_FUNCTION, // offset 4. + 0, // offset 5. + POP_TOP, // offset 6. + 0, // offset 7. + NOP, // offset 8. + 0, // offset 9. + JUMP_ABSOLUTE, // offset 10. + 34 + 3 // offset 11. + }); +#else + VerifyBytecode( + instance, + { + JUMP_FORWARD, // offset 0. + 12 + 6, // offset 1. + LOAD_CONST, // offset 2. + 47, // offset 3. + CALL_FUNCTION, // offset 4. + 0, // offset 5. + POP_TOP, // offset 6. + 0, // offset 7. + NOP, // offset 8. + 0, // offset 9. + JUMP_ABSOLUTE, // offset 10. + 34 + 6 // offset 11. + }); +#endif +} + + +TEST(BytecodeManipulatorTest, InsertionExtendedOffsetUpdates) { + BytecodeManipulator instance( + { + EXTENDED_ARG, + 12, + EXTENDED_ARG, + 34, + EXTENDED_ARG, + 56, + JUMP_FORWARD, + 78, + NOP, + 0, + EXTENDED_ARG, + 98, + EXTENDED_ARG, + 76, + EXTENDED_ARG, + 54, + JUMP_ABSOLUTE, + 32 + }, + false, + {}); + ASSERT_TRUE(instance.InjectMethodCall(8, 11)); + +#if PY_VERSION_HEX >= 0x030A0000 + // Jump offsets are instruction offsets, not byte offsets. + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 12, // offset 1. 
+ EXTENDED_ARG, // offset 2. + 34, // offset 3. + EXTENDED_ARG, // offset 4. + 56, // offset 5. + JUMP_FORWARD, // offset 6. + 78 + 3, // offset 7. + LOAD_CONST, // offset 8. + 11, // offset 9. + CALL_FUNCTION, // offset 10. + 0, // offset 11. + POP_TOP, // offset 12. + 0, // offset 13. + NOP, // offset 14. + 0, // offset 15. + EXTENDED_ARG, // offset 16. + 98, // offset 17. + EXTENDED_ARG, // offset 18. + 76, // offset 19. + EXTENDED_ARG, // offset 20. + 54, // offset 21. + JUMP_ABSOLUTE, // offset 22. + 32 + 3 // offset 23. + }); +#else + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 12, // offset 1. + EXTENDED_ARG, // offset 2. + 34, // offset 3. + EXTENDED_ARG, // offset 4. + 56, // offset 5. + JUMP_FORWARD, // offset 6. + 78 + 6, // offset 7. + LOAD_CONST, // offset 8. + 11, // offset 9. + CALL_FUNCTION, // offset 10. + 0, // offset 11. + POP_TOP, // offset 12. + 0, // offset 13. + NOP, // offset 14. + 0, // offset 15. + EXTENDED_ARG, // offset 16. + 98, // offset 17. + EXTENDED_ARG, // offset 18. + 76, // offset 19. + EXTENDED_ARG, // offset 20. + 54, // offset 21. + JUMP_ABSOLUTE, // offset 22. + 32 + 6 // offset 23. + }); +#endif +} + + +TEST(BytecodeManipulatorTest, InsertionDeltaOffsetNoUpdate) { + BytecodeManipulator instance( + { + JUMP_FORWARD, + 2, + NOP, + 0, + RETURN_VALUE, + 0, + JUMP_FORWARD, + 2, + }, + false, {}); + ASSERT_TRUE(instance.InjectMethodCall(4, 99)); + + VerifyBytecode( + instance, + { + JUMP_FORWARD, // offset 0. + 2, // offset 1. + NOP, // offset 2. + 0, // offset 3. + LOAD_CONST, // offset 4. + 99, // offset 5. + CALL_FUNCTION, // offset 6. + 0, // offset 7. + POP_TOP, // offset 8. + 0, // offset 9. + RETURN_VALUE, // offset 10. + 0, // offset 11. + JUMP_FORWARD, // offset 12. + 2 // offset 13. 
+ }); +} + + +TEST(BytecodeManipulatorTest, InsertionAbsoluteOffsetNoUpdate) { + BytecodeManipulator instance( + { + JUMP_ABSOLUTE, + 2, + RETURN_VALUE, + 0 + }, + false, + {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 99)); + + VerifyBytecode( + instance, + { + JUMP_ABSOLUTE, // offset 0. + 2, // offset 1. + LOAD_CONST, // offset 2. + 99, // offset 3. + CALL_FUNCTION, // offset 4. + 0, // offset 5. + POP_TOP, // offset 6. + 0, // offset 7. + RETURN_VALUE, // offset 8. + 0 // offset 9. + }); +} + + +TEST(BytecodeManipulatorTest, InsertionOffsetUneededExtended) { + BytecodeManipulator instance( + { EXTENDED_ARG, 0, JUMP_FORWARD, 2, NOP, 0 }, + false, + {}); + ASSERT_TRUE(instance.InjectMethodCall(4, 11)); + +#if PY_VERSION_HEX >= 0x030A0000 + // Jump offsets are instruction offsets, not byte offsets. + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 0, // offset 1. + JUMP_FORWARD, // offset 2. + 2 + 3, // offset 3. + LOAD_CONST, // offset 4. + 11, // offset 5. + CALL_FUNCTION, // offset 6. + 0, // offset 7. + POP_TOP, // offset 8. + 0, // offset 9. + NOP, // offset 10. + 0 // offset 11. + }); +#else + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 0, // offset 1. + JUMP_FORWARD, // offset 2. + 2 + 6, // offset 3. + LOAD_CONST, // offset 4. + 11, // offset 5. + CALL_FUNCTION, // offset 6. + 0, // offset 7. + POP_TOP, // offset 8. + 0, // offset 9. + NOP, // offset 10. + 0 // offset 11. + }); +#endif +} + + +TEST(BytecodeManipulatorTest, InsertionOffsetUpgradeExtended) { + BytecodeManipulator instance({ JUMP_ABSOLUTE, 254 , NOP, 0 }, false, {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 11)); + +#if PY_VERSION_HEX >= 0x030A0000 + // Jump offsets are instruction offsets, not byte offsets. + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 1, // offset 1. + JUMP_ABSOLUTE, // offset 2. + 2, // offset 3. + LOAD_CONST, // offset 4. + 11, // offset 5. + CALL_FUNCTION, // offset 6. + 0, // offset 7. + POP_TOP, // offset 8. 
+ 0, // offset 9. + NOP, // offset 10. + 0 // offset 11. + }); +#else + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 1, // offset 1. + JUMP_ABSOLUTE, // offset 2. + 6, // offset 3. + LOAD_CONST, // offset 4. + 11, // offset 5. + CALL_FUNCTION, // offset 6. + 0, // offset 7. + POP_TOP, // offset 8. + 0, // offset 9. + NOP, // offset 10. + 0 // offset 11. + }); +#endif +} + + +TEST(BytecodeManipulatorTest, InsertionOffsetUpgradeExtendedTwice) { + BytecodeManipulator instance( + { JUMP_ABSOLUTE, 252, JUMP_ABSOLUTE, 254, NOP, 0 }, + false, + {}); + ASSERT_TRUE(instance.InjectMethodCall(4, 12)); + +#if PY_VERSION_HEX >= 0x030A0000 + // Jump offsets are instruction offsets, not byte offsets. + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 1, // offset 1. + JUMP_ABSOLUTE, // offset 2. + 1, // offset 3. + EXTENDED_ARG, // offset 4. + 1, // offset 5. + JUMP_ABSOLUTE, // offset 6. + 3, // offset 7. + LOAD_CONST, // offset 8. + 12, // offset 9. + CALL_FUNCTION, // offset 10. + 0, // offset 11. + POP_TOP, // offset 12. + 0, // offset 13. + NOP, // offset 14. + 0 // offset 15. + }); +#else + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 1, // offset 1. + JUMP_ABSOLUTE, // offset 2. + 6, // offset 3. + EXTENDED_ARG, // offset 4. + 1, // offset 5. + JUMP_ABSOLUTE, // offset 6. + 8, // offset 7. + LOAD_CONST, // offset 8. + 12, // offset 9. + CALL_FUNCTION, // offset 10. + 0, // offset 11. + POP_TOP, // offset 12. + 0, // offset 13. + NOP, // offset 14. + 0 // offset 15. 
+ }); +#endif +} + + +TEST(BytecodeManipulatorTest, InsertionBadInstruction) { + BytecodeManipulator instance( + { NOP, 0, NOP, 0, LOAD_CONST }, + false, + {}); + EXPECT_FALSE(instance.InjectMethodCall(2, 0)); +} + + +TEST(BytecodeManipulatorTest, InsertionNegativeOffset) { + BytecodeManipulator instance({ NOP, 0, RETURN_VALUE, 0 }, false, {}); + EXPECT_FALSE(instance.InjectMethodCall(-1, 0)); +} + + +TEST(BytecodeManipulatorTest, InsertionOutOfRangeOffset) { + BytecodeManipulator instance({ NOP, 0, RETURN_VALUE, 0 }, false, {}); + EXPECT_FALSE(instance.InjectMethodCall(4, 0)); +} + + +TEST(BytecodeManipulatorTest, InsertionMidInstruction) { + BytecodeManipulator instance( + { NOP, 0, LOAD_CONST, 0, NOP, 0 }, + false, + {}); + + EXPECT_FALSE(instance.InjectMethodCall(1, 0)); + EXPECT_FALSE(instance.InjectMethodCall(3, 0)); + EXPECT_FALSE(instance.InjectMethodCall(5, 0)); +} + + +TEST(BytecodeManipulatorTest, InsertionTooManyUpgrades) { + BytecodeManipulator instance( + { + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + NOP, 0 + }, + false, + {}); + EXPECT_FALSE(instance.InjectMethodCall(20, 0)); +} + + +TEST(BytecodeManipulatorTest, IncompleteBytecodeInsert) { + BytecodeManipulator instance({ NOP, 0, LOAD_CONST }, false, {}); + EXPECT_FALSE(instance.InjectMethodCall(2, 0)); +} + + +TEST(BytecodeManipulatorTest, IncompleteBytecodeAppend) { + BytecodeManipulator instance( + { YIELD_VALUE, 0, NOP, 0, LOAD_CONST }, + false, {}); + EXPECT_FALSE(instance.InjectMethodCall(4, 0)); +} + + +TEST(BytecodeManipulatorTest, LineNumbersTableUpdateBeginning) { + BytecodeManipulator instance( + { NOP, 0, RETURN_VALUE, 0 }, + true, + { 2, 1, 2, 1 }); + ASSERT_TRUE(instance.InjectMethodCall(0, 99)); + + VerifyLineNumbersTable(instance, { 8, 1, 2, 1 }); +} + + +TEST(BytecodeManipulatorTest, 
LineNumbersTableUpdateLineBoundary) { + BytecodeManipulator instance( + { NOP, 0, RETURN_VALUE, 0 }, + true, + { 0, 1, 2, 1, 2, 1 }); + ASSERT_TRUE(instance.InjectMethodCall(2, 99)); + + VerifyLineNumbersTable(instance, { 0, 1, 2, 1, 8, 1 }); +} + + +TEST(BytecodeManipulatorTest, LineNumbersTableUpdateMidLine) { + BytecodeManipulator instance( + { NOP, 0, NOP, 0, RETURN_VALUE, 0 }, + true, + { 0, 1, 4, 1 }); + ASSERT_TRUE(instance.InjectMethodCall(2, 99)); + + VerifyLineNumbersTable(instance, { 0, 1, 10, 1 }); +} + + +TEST(BytecodeManipulatorTest, LineNumbersTablePastEnd) { + BytecodeManipulator instance( + { NOP, 0, NOP, 0, NOP, 0, RETURN_VALUE, 0 }, + true, + { 0, 1 }); + ASSERT_TRUE(instance.InjectMethodCall(6, 99)); + + VerifyLineNumbersTable(instance, { 0, 1 }); +} + + +TEST(BytecodeManipulatorTest, LineNumbersTableUpgradeExtended) { + BytecodeManipulator instance( + { JUMP_ABSOLUTE, 254, RETURN_VALUE, 0 }, + true, + { 2, 1, 2, 1 }); + ASSERT_TRUE(instance.InjectMethodCall(2, 99)); + + VerifyLineNumbersTable(instance, { 4, 1, 8, 1 }); +} + + +TEST(BytecodeManipulatorTest, LineNumbersTableOverflow) { + std::vector bytecode(300, 0); + BytecodeManipulator instance( + bytecode, + true, + { 254, 1 }); + ASSERT_TRUE(instance.InjectMethodCall(2, 99)); + +#if PY_VERSION_HEX >= 0x030A0000 + VerifyLineNumbersTable(instance, { 254, 0, 6, 1 }); +#else + VerifyLineNumbersTable(instance, { 255, 0, 5, 1 }); +#endif +} + + +TEST(BytecodeManipulatorTest, SuccessAppend) { + BytecodeManipulator instance( + { YIELD_VALUE, 0, LOAD_CONST, 0, NOP, 0 }, + false, + {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 57)); + + VerifyBytecode( + instance, + { + YIELD_VALUE, // offset 0. + 0, // offset 1. + JUMP_ABSOLUTE, // offset 2. + 6, // offset 3. + NOP, // offset 4. + 0, // offset 5. + LOAD_CONST, // offset 6. + 57, // offset 7. + CALL_FUNCTION, // offset 8. + 0, // offset 9. + POP_TOP, // offset 10. + 0, // offset 11. + LOAD_CONST, // offset 12. + 0, // offset 13. 
+ JUMP_ABSOLUTE, // offset 14. + 4 // offset 15. + }); +} + + +TEST(BytecodeManipulatorTest, SuccessAppendYieldFrom) { + BytecodeManipulator instance( + { YIELD_FROM, 0, LOAD_CONST, 0, NOP, 0 }, + false, + {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 57)); + + VerifyBytecode( + instance, + { + YIELD_FROM, // offset 0. + 0, // offset 1. + JUMP_ABSOLUTE, // offset 2. + 6, // offset 3. + NOP, // offset 4. + 0, // offset 5. + LOAD_CONST, // offset 6. + 57, // offset 7. + CALL_FUNCTION, // offset 8. + 0, // offset 9. + POP_TOP, // offset 10. + 0, // offset 11. + LOAD_CONST, // offset 12. + 0, // offset 13. + JUMP_ABSOLUTE, // offset 14. + 4 // offset 15. + }); +} + + +TEST(BytecodeManipulatorTest, AppendExtraPadding) { + BytecodeManipulator instance( + { + YIELD_VALUE, + 0, + EXTENDED_ARG, + 15, + EXTENDED_ARG, + 16, + EXTENDED_ARG, + 17, + LOAD_CONST, + 18, + RETURN_VALUE, + 0 + }, + false, {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 0x7273)); + + VerifyBytecode( + instance, + { + YIELD_VALUE, // offset 0. + 0, // offset 1. + JUMP_ABSOLUTE, // offset 2. + 12, // offset 3. + NOP, // offset 4. Args for NOP do not matter. + 9, // offset 5. + NOP, // offset 6. + 9, // offset 7. + NOP, // offset 8. + 9, // offset 9. + RETURN_VALUE, // offset 10. + 0, // offset 11. + EXTENDED_ARG, // offset 12. + 0x72, // offset 13. + LOAD_CONST, // offset 14. + 0x73, // offset 15. + CALL_FUNCTION, // offset 16. + 0, // offset 17. + POP_TOP, // offset 18. + 0, // offset 19. + EXTENDED_ARG, // offset 20. + 15, // offset 21. + EXTENDED_ARG, // offset 22. + 16, // offset 23. + EXTENDED_ARG, // offset 24. + 17, // offset 25. + LOAD_CONST, // offset 26. + 18, // offset 27. + JUMP_ABSOLUTE, // offset 28. + 10 // offset 29. + }); +} + + +TEST(BytecodeManipulatorTest, AppendToEnd) { + std::vector bytecode = {YIELD_VALUE, 0}; + // Case where trampoline requires 4 bytes to write. 
+ bytecode.resize(300); + BytecodeManipulator instance(bytecode, false, {}); + + // This scenario could be supported in theory, but it's not. The purpose of + // this test case is to verify there are no crashes or corruption. + ASSERT_FALSE(instance.InjectMethodCall(298, 0x12)); +} + + +TEST(BytecodeManipulatorTest, NoSpaceForTrampoline) { + const std::vector test_cases[] = { + {YIELD_VALUE, 0, YIELD_VALUE, 0, NOP, 0}, + {YIELD_VALUE, 0, FOR_ITER, 0, NOP, 0}, + {YIELD_VALUE, 0, JUMP_FORWARD, 0, NOP, 0}, +#if PY_VERSION_HEX < 0x03080000 + {YIELD_VALUE, 0, SETUP_LOOP, 0, NOP, 0}, +#endif + {YIELD_VALUE, 0, SETUP_FINALLY, 0, NOP, 0}, +#if PY_VERSION_HEX < 0x03080000 + {YIELD_VALUE, 0, SETUP_LOOP, 0, NOP, 0}, + {YIELD_VALUE, 0, SETUP_EXCEPT, 0, NOP, 0}, +#endif +#if PY_VERSION_HEX >= 0x03080000 && PY_VERSION_HEX < 0x03090000 + {YIELD_VALUE, 0, CALL_FINALLY, 0, NOP, 0}, +#endif + }; + + for (const auto& test_case : test_cases) { + BytecodeManipulator instance(test_case, false, {}); + EXPECT_FALSE(instance.InjectMethodCall(2, 0)) + << "Input:\n" + << FormatBytecode(test_case, 4) << "\n" + << "Unexpected output:\n" + << FormatBytecode(instance.bytecode(), 4); + } + + // Case where trampoline requires 4 bytes to write. + std::vector bytecode(300, 0); + bytecode[0] = YIELD_VALUE; + bytecode[2] = NOP; + bytecode[4] = YIELD_VALUE; + BytecodeManipulator instance(bytecode, false, {}); + ASSERT_FALSE(instance.InjectMethodCall(2, 0x12)); +} + +// Tests that we don't allow jumping into the middle of the space reserved for +// the trampoline. See the comments in AppendMethodCall() in +// bytecode_manipulator.cc. 
+TEST(BytecodeManipulatorTest, JumpMidRelocatedInstructions) { + std::vector test_cases[] = { + {YIELD_VALUE, 0, FOR_ITER, 2, LOAD_CONST, 0}, + {YIELD_VALUE, 0, JUMP_FORWARD, 2, LOAD_CONST, 0}, + {YIELD_VALUE, 0, SETUP_FINALLY, 2, LOAD_CONST, 0}, + {YIELD_VALUE, 0, SETUP_WITH, 2, LOAD_CONST, 0}, + {YIELD_VALUE, 0, SETUP_FINALLY, 2, LOAD_CONST, 0}, + {YIELD_VALUE, 0, JUMP_IF_FALSE_OR_POP, 6, LOAD_CONST, 0}, + {YIELD_VALUE, 0, JUMP_IF_TRUE_OR_POP, 6, LOAD_CONST, 0}, + {YIELD_VALUE, 0, JUMP_ABSOLUTE, 6, LOAD_CONST, 0}, + {YIELD_VALUE, 0, POP_JUMP_IF_FALSE, 6, LOAD_CONST, 0}, + {YIELD_VALUE, 0, POP_JUMP_IF_TRUE, 6, LOAD_CONST, 0}, +#if PY_VERSION_HEX < 0x03080000 + {YIELD_VALUE, 0, SETUP_LOOP, 2, LOAD_CONST, 0}, + {YIELD_VALUE, 0, CONTINUE_LOOP, 6, LOAD_CONST, 0}, +#endif + }; + + for (auto& test_case : test_cases) { + // Case where trampoline requires 4 bytes to write. + test_case.resize(300); + BytecodeManipulator instance(test_case, false, {}); + EXPECT_FALSE(instance.InjectMethodCall(4, 0)) + << "Input:\n" + << FormatBytecode(test_case, 4) << "\n" + << "Unexpected output:\n" + << FormatBytecode(instance.bytecode(), 4); + } +} + + +// Test that we allow jumping to the start of the space reserved for the +// trampoline. +TEST(BytecodeManipulatorTest, JumpStartOfRelocatedInstructions) { + const std::vector test_cases[] = { + {YIELD_VALUE, 0, FOR_ITER, 0, LOAD_CONST, 0}, + {YIELD_VALUE, 0, SETUP_WITH, 0, LOAD_CONST, 0}, + {YIELD_VALUE, 0, JUMP_ABSOLUTE, 4, LOAD_CONST, 0}}; + + for (const auto& test_case : test_cases) { + BytecodeManipulator instance(test_case, false, {}); + EXPECT_TRUE(instance.InjectMethodCall(4, 0)) + << "Input:\n" << FormatBytecode(test_case, 4); + } +} + + +// Test that we allow jumping after the space reserved for the trampoline. 
+TEST(BytecodeManipulatorTest, JumpAfterRelocatedInstructions) { + const std::vector test_cases[] = { + {YIELD_VALUE, 0, FOR_ITER, 2, LOAD_CONST, 0, NOP, 0}, + {YIELD_VALUE, 0, SETUP_WITH, 2, LOAD_CONST, 0, NOP, 0}, + {YIELD_VALUE, 0, JUMP_ABSOLUTE, 6, LOAD_CONST, 0, NOP, 0}}; + + for (const auto& test_case : test_cases) { + BytecodeManipulator instance(test_case, false, {}); + EXPECT_TRUE(instance.InjectMethodCall(4, 0)) + << "Input:\n" << FormatBytecode(test_case, 4); + } +} + + +TEST(BytecodeManipulatorTest, InsertionRevertOnFailure) { + const std::vector input{JUMP_FORWARD, 0, NOP, 0, JUMP_ABSOLUTE, 2}; + + BytecodeManipulator instance(input, false, {}); + ASSERT_FALSE(instance.InjectMethodCall(1, 47)); + + VerifyBytecode(instance, input); +} + + +} // namespace cdbg +} // namespace devtools diff --git a/tests/py/application_info_test.py b/tests/py/application_info_test.py new file mode 100644 index 0000000..51f7427 --- /dev/null +++ b/tests/py/application_info_test.py @@ -0,0 +1,72 @@ +"""Tests for application_info.""" + +import os +from unittest import mock + +import requests + +from googleclouddebugger import application_info +from absl.testing import absltest + + +class ApplicationInfoTest(absltest.TestCase): + + def test_get_platform_default(self): + """Returns default platform when no platform is detected.""" + self.assertEqual(application_info.PlatformType.DEFAULT, + application_info.GetPlatform()) + + def test_get_platform_gcf_name(self): + """Returns cloud_function when the FUNCTION_NAME env variable is set.""" + try: + os.environ['FUNCTION_NAME'] = 'function-name' + self.assertEqual(application_info.PlatformType.CLOUD_FUNCTION, + application_info.GetPlatform()) + finally: + del os.environ['FUNCTION_NAME'] + + def test_get_platform_gcf_target(self): + """Returns cloud_function when the FUNCTION_TARGET env variable is set.""" + try: + os.environ['FUNCTION_TARGET'] = 'function-target' + self.assertEqual(application_info.PlatformType.CLOUD_FUNCTION, + 
application_info.GetPlatform()) + finally: + del os.environ['FUNCTION_TARGET'] + + def test_get_region_none(self): + """Returns None when no region is detected.""" + self.assertIsNone(application_info.GetRegion()) + + def test_get_region_gcf(self): + """Returns correct region when the FUNCTION_REGION env variable is set.""" + try: + os.environ['FUNCTION_REGION'] = 'function-region' + self.assertEqual('function-region', application_info.GetRegion()) + finally: + del os.environ['FUNCTION_REGION'] + + @mock.patch('requests.get') + def test_get_region_metadata_server(self, mock_requests_get): + """Returns correct region if found in metadata server.""" + success_response = mock.Mock(requests.Response) + success_response.status_code = 200 + success_response.text = 'a/b/function-region' + mock_requests_get.return_value = success_response + + self.assertEqual('function-region', application_info.GetRegion()) + + @mock.patch('requests.get') + def test_get_region_metadata_server_fail(self, mock_requests_get): + """Returns None if region not found in metadata server.""" + exception = requests.exceptions.HTTPError() + failed_response = mock.Mock(requests.Response) + failed_response.status_code = 400 + failed_response.raise_for_status.side_effect = exception + mock_requests_get.return_value = failed_response + + self.assertIsNone(application_info.GetRegion()) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/backoff_test.py b/tests/py/backoff_test.py new file mode 100644 index 0000000..262976c --- /dev/null +++ b/tests/py/backoff_test.py @@ -0,0 +1,35 @@ +"""Unit test for backoff module.""" + +from absl.testing import absltest + +from googleclouddebugger import backoff + + +class BackoffTest(absltest.TestCase): + """Unit test for backoff module.""" + + def setUp(self): + self._backoff = backoff.Backoff(10, 100, 1.5) + + def testInitial(self): + self.assertEqual(10, self._backoff.Failed()) + + def testIncrease(self): + self._backoff.Failed() + 
self.assertEqual(15, self._backoff.Failed()) + + def testMaximum(self): + for _ in range(100): + self._backoff.Failed() + + self.assertEqual(100, self._backoff.Failed()) + + def testResetOnSuccess(self): + for _ in range(4): + self._backoff.Failed() + self._backoff.Succeeded() + self.assertEqual(10, self._backoff.Failed()) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/breakpoints_manager_test.py b/tests/py/breakpoints_manager_test.py new file mode 100644 index 0000000..269931f --- /dev/null +++ b/tests/py/breakpoints_manager_test.py @@ -0,0 +1,229 @@ +"""Unit test for breakpoints_manager module.""" + +from datetime import datetime +from datetime import timedelta +from unittest import mock + +from absl.testing import absltest + +from googleclouddebugger import breakpoints_manager + + +class BreakpointsManagerTest(absltest.TestCase): + """Unit test for breakpoints_manager module.""" + + def setUp(self): + self._breakpoints_manager = breakpoints_manager.BreakpointsManager( + self, None) + + path = 'googleclouddebugger.breakpoints_manager.' 
+ breakpoint_class = path + 'python_breakpoint.PythonBreakpoint' + + patcher = mock.patch(breakpoint_class) + self._mock_breakpoint = patcher.start() + self.addCleanup(patcher.stop) + + def testEmpty(self): + self.assertEmpty(self._breakpoints_manager._active) + + def testSetSingle(self): + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._mock_breakpoint.assert_has_calls( + [mock.call({'id': 'ID1'}, self, self._breakpoints_manager, None)]) + self.assertLen(self._breakpoints_manager._active, 1) + + def testSetDouble(self): + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._mock_breakpoint.assert_has_calls( + [mock.call({'id': 'ID1'}, self, self._breakpoints_manager, None)]) + self.assertLen(self._breakpoints_manager._active, 1) + + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID1' + }, { + 'id': 'ID2' + }]) + self._mock_breakpoint.assert_has_calls([ + mock.call({'id': 'ID1'}, self, self._breakpoints_manager, None), + mock.call({'id': 'ID2'}, self, self._breakpoints_manager, None) + ]) + self.assertLen(self._breakpoints_manager._active, 2) + + def testSetRepeated(self): + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self.assertEqual(1, self._mock_breakpoint.call_count) + + def testClear(self): + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._breakpoints_manager.SetActiveBreakpoints([]) + self.assertEqual(1, self._mock_breakpoint.return_value.Clear.call_count) + self.assertEmpty(self._breakpoints_manager._active) + + def testCompleteInvalidId(self): + self._breakpoints_manager.CompleteBreakpoint('ID_INVALID') + + def testComplete(self): + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._breakpoints_manager.CompleteBreakpoint('ID1') + self.assertEqual(1, self._mock_breakpoint.return_value.Clear.call_count) + + def testSetCompleted(self): + 
self._breakpoints_manager.CompleteBreakpoint('ID1') + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self.assertEqual(0, self._mock_breakpoint.call_count) + + def testCompletedCleanup(self): + self._breakpoints_manager.CompleteBreakpoint('ID1') + self._breakpoints_manager.SetActiveBreakpoints([]) + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self.assertEqual(1, self._mock_breakpoint.call_count) + + def testMultipleSetDelete(self): + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID1' + }, { + 'id': 'ID2' + }, { + 'id': 'ID3' + }, { + 'id': 'ID4' + }]) + self.assertLen(self._breakpoints_manager._active, 4) + + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID1' + }, { + 'id': 'ID2' + }, { + 'id': 'ID3' + }, { + 'id': 'ID4' + }]) + self.assertLen(self._breakpoints_manager._active, 4) + + self._breakpoints_manager.SetActiveBreakpoints([]) + self.assertEmpty(self._breakpoints_manager._active) + + def testCombination(self): + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID1' + }, { + 'id': 'ID2' + }, { + 'id': 'ID3' + }]) + self.assertLen(self._breakpoints_manager._active, 3) + + self._breakpoints_manager.CompleteBreakpoint('ID2') + self.assertEqual(1, self._mock_breakpoint.return_value.Clear.call_count) + self.assertLen(self._breakpoints_manager._active, 2) + + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID2' + }, { + 'id': 'ID3' + }, { + 'id': 'ID4' + }]) + self.assertEqual(2, self._mock_breakpoint.return_value.Clear.call_count) + self.assertLen(self._breakpoints_manager._active, 2) + + self._breakpoints_manager.CompleteBreakpoint('ID2') + self.assertEqual(2, self._mock_breakpoint.return_value.Clear.call_count) + self.assertLen(self._breakpoints_manager._active, 2) + + self._breakpoints_manager.SetActiveBreakpoints([]) + self.assertEqual(4, self._mock_breakpoint.return_value.Clear.call_count) + self.assertEmpty(self._breakpoints_manager._active) + + def 
testCheckExpirationNoBreakpoints(self): + self._breakpoints_manager.CheckBreakpointsExpiration() + + def testCheckNotExpired(self): + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID1' + }, { + 'id': 'ID2' + }]) + self._mock_breakpoint.return_value.GetExpirationTime.return_value = ( + datetime.utcnow() + timedelta(minutes=1)) + self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 0, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + def testCheckExpired(self): + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID1' + }, { + 'id': 'ID2' + }]) + self._mock_breakpoint.return_value.GetExpirationTime.return_value = ( + datetime.utcnow() - timedelta(minutes=1)) + self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 2, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + def testCheckExpirationReset(self): + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._mock_breakpoint.return_value.GetExpirationTime.return_value = ( + datetime.utcnow() + timedelta(minutes=1)) + self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 0, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID1' + }, { + 'id': 'ID2' + }]) + self._mock_breakpoint.return_value.GetExpirationTime.return_value = ( + datetime.utcnow() - timedelta(minutes=1)) + self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 2, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + def testCheckExpirationCacheNegative(self): + base = datetime(2015, 1, 1) + + with mock.patch.object(breakpoints_manager.BreakpointsManager, + 'GetCurrentTime') as mock_time: + mock_time.return_value = base + + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._mock_breakpoint.return_value.GetExpirationTime.return_value = ( + base + timedelta(minutes=1)) + + 
self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 0, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + # The nearest expiration time is cached, so this should have no effect. + self._mock_breakpoint.return_value.GetExpirationTime.return_value = ( + base - timedelta(minutes=1)) + self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 0, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + def testCheckExpirationCachePositive(self): + base = datetime(2015, 1, 1) + + with mock.patch.object(breakpoints_manager.BreakpointsManager, + 'GetCurrentTime') as mock_time: + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._mock_breakpoint.return_value.GetExpirationTime.return_value = ( + base + timedelta(minutes=1)) + + mock_time.return_value = base + self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 0, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + mock_time.return_value = base + timedelta(minutes=2) + self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 1, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/collector_test.py b/tests/py/collector_test.py new file mode 100644 index 0000000..abc39b2 --- /dev/null +++ b/tests/py/collector_test.py @@ -0,0 +1,1778 @@ +"""Unit test for collector module.""" + +import copy +import datetime +import inspect +import logging +import os +import sys +import time +from unittest import mock + +from absl.testing import absltest + +from googleclouddebugger import collector +from googleclouddebugger import labels + +LOGPOINT_PAUSE_MSG = ( + 'LOGPOINT: Logpoint is paused due to high log rate until log ' + 'quota is restored') + + +def CaptureCollectorWithDefaultLocation(definition, + data_visibility_policy=None): + """Makes a LogCollector with a default location. 
+ + Args: + definition: the rest of the breakpoint definition + data_visibility_policy: optional visibility policy + + Returns: + A LogCollector + """ + definition['location'] = {'path': 'collector_test.py', 'line': 10} + return collector.CaptureCollector(definition, data_visibility_policy) + + +def LogCollectorWithDefaultLocation(definition): + """Makes a LogCollector with a default location. + + Args: + definition: the rest of the breakpoint definition + + Returns: + A LogCollector + """ + definition['location'] = {'path': 'collector_test.py', 'line': 10} + return collector.LogCollector(definition) + + +class CaptureCollectorTest(absltest.TestCase): + """Unit test for capture collector.""" + + def tearDown(self): + collector.CaptureCollector.pretty_printers = [] + + def testCallStackUnlimitedFrames(self): + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.max_frames = 1000 + self._collector.Collect(inspect.currentframe()) + + self.assertGreater(len(self._collector.breakpoint['stackFrames']), 1) + self.assertLess(len(self._collector.breakpoint['stackFrames']), 100) + + def testCallStackLimitedFrames(self): + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.max_frames = 2 + self._collector.Collect(inspect.currentframe()) + + self.assertLen(self._collector.breakpoint['stackFrames'], 2) + + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertEqual('CaptureCollectorTest.testCallStackLimitedFrames', + top_frame['function']) + self.assertIn('collector_test.py', top_frame['location']['path']) + self.assertGreater(top_frame['location']['line'], 1) + + frame_below = self._collector.breakpoint['stackFrames'][1] + frame_below_line = inspect.currentframe().f_back.f_lineno + self.assertEqual(frame_below_line, frame_below['location']['line']) + + def testCallStackLimitedExpandedFrames(self): + + def CountLocals(frame): + return len(frame['arguments']) + len(frame['locals']) + + 
self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.max_frames = 3 + self._collector.max_expand_frames = 2 + self._collector.Collect(inspect.currentframe()) + + frames = self._collector.breakpoint['stackFrames'] + self.assertLen(frames, 3) + self.assertGreater(CountLocals(frames[0]), 0) + self.assertGreater(CountLocals(frames[1]), 1) + self.assertEqual(0, CountLocals(frames[2])) + + def testSimpleArguments(self): + + def Method(unused_a, unused_b): + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertListEqual([{ + 'name': 'unused_a', + 'value': '158', + 'type': 'int' + }, { + 'name': 'unused_b', + 'value': "'hello'", + 'type': 'str' + }], top_frame['arguments']) + self.assertEqual('Method', top_frame['function']) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + Method(158, 'hello') + + def testMethodWithFirstArgumentNamedSelf(self): + this = self + + def Method(self, unused_a, unused_b): # pylint: disable=unused-argument + this._collector.Collect(inspect.currentframe()) + top_frame = this._collector.breakpoint['stackFrames'][0] + this.assertListEqual([{ + 'name': 'self', + 'value': "'world'", + 'type': 'str' + }, { + 'name': 'unused_a', + 'value': '158', + 'type': 'int' + }, { + 'name': 'unused_b', + 'value': "'hello'", + 'type': 'str' + }], top_frame['arguments']) + # This is the incorrect function name, but we are validating that no + # exceptions are thrown here. 
+ this.assertEqual('str.Method', top_frame['function']) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + Method('world', 158, 'hello') + + def testMethodWithArgumentNamedSelf(self): + this = self + + def Method(unused_a, unused_b, self): # pylint: disable=unused-argument + this._collector.Collect(inspect.currentframe()) + top_frame = this._collector.breakpoint['stackFrames'][0] + this.assertListEqual([{ + 'name': 'unused_a', + 'value': '158', + 'type': 'int' + }, { + 'name': 'unused_b', + 'value': "'hello'", + 'type': 'str' + }, { + 'name': 'self', + 'value': "'world'", + 'type': 'str' + }], top_frame['arguments']) + this.assertEqual('Method', top_frame['function']) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + Method(158, 'hello', 'world') + + def testClassMethod(self): + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertListEqual([{ + 'name': 'self', + 'varTableIndex': 1 + }], top_frame['arguments']) + self.assertEqual('CaptureCollectorTest.testClassMethod', + top_frame['function']) + + def testClassMethodWithOptionalArguments(self): + + def Method(unused_a, unused_optional='notneeded'): + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertListEqual([{ + 'name': 'unused_a', + 'varTableIndex': 1 + }, { + 'name': 'unused_optional', + 'value': "'notneeded'", + 'type': 'str' + }], top_frame['arguments']) + self.assertEqual('Method', top_frame['function']) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + Method(self) + + def testClassMethodWithPositionalArguments(self): + + def Method(*unused_pos): + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertListEqual([{ + 'name': 'unused_pos', + 'type': 'tuple', + 
'members': [{ + 'name': '[0]', + 'value': '1', + 'type': 'int' + }] + }], top_frame['arguments']) + self.assertEqual('Method', top_frame['function']) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + Method(1) + + def testClassMethodWithKeywords(self): + + def Method(**unused_kwd): + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertCountEqual([{ + 'name': "'first'", + 'value': '1', + 'type': 'int' + }, { + 'name': "'second'", + 'value': '2', + 'type': 'int' + }], top_frame['arguments'][0]['members']) + self.assertEqual('Method', top_frame['function']) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + Method(first=1, second=2) + + def testNoLocalVariables(self): + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertEmpty(top_frame['locals']) + self.assertEqual('CaptureCollectorTest.testNoLocalVariables', + top_frame['function']) + + def testRuntimeError(self): + + class BadDict(dict): + + def __init__(self, d): + d['foo'] = 'bar' + super(BadDict, self).__init__(d) + + def __getattribute__(self, attr): + raise RuntimeError('Bogus error') + + class BadType(object): + + def __init__(self): + self.__dict__ = BadDict(self.__dict__) + + unused_a = BadType() + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + var_a = self._Pack(self._LocalByName('unused_a')) + self.assertDictEqual( + { + 'name': 'unused_a', + 'status': { + 'isError': True, + 'refersTo': 'VARIABLE_VALUE', + 'description': { + 'format': 'Failed to capture variable: $0', + 'parameters': ['Bogus error'] + }, + } + }, var_a) + + def testBadDictionary(self): + + class BadDict(dict): + + def items(self): + raise AttributeError('attribute error') + + class BadType(object): + + 
def __init__(self): + self.good = 1 + self.bad = BadDict() + + unused_a = BadType() + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + var_a = self._Pack(self._LocalByName('unused_a')) + members = var_a['members'] + self.assertLen(members, 2) + self.assertIn({'name': 'good', 'value': '1', 'type': 'int'}, members) + self.assertIn( + { + 'name': 'bad', + 'status': { + 'isError': True, + 'refersTo': 'VARIABLE_VALUE', + 'description': { + 'format': 'Failed to capture variable: $0', + 'parameters': ['attribute error'] + }, + } + }, members) + + def testLocalVariables(self): + unused_a = 8 + unused_b = True + unused_nothing = None + unused_s = 'hippo' + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertLen(top_frame['arguments'], 1) # just self. + self.assertCountEqual([{ + 'name': 'unused_a', + 'value': '8', + 'type': 'int' + }, { + 'name': 'unused_b', + 'value': 'True', + 'type': 'bool' + }, { + 'name': 'unused_nothing', + 'value': 'None' + }, { + 'name': 'unused_s', + 'value': "'hippo'", + 'type': 'str' + }], top_frame['locals']) + + def testLocalVariablesWithBlacklist(self): + unused_a = collector.LineNoFilter() + unused_b = 5 + + # Side effect logic for the mock data visibility object + def IsDataVisible(name): + path_prefix = 'googleclouddebugger.collector.' 
+ if name == path_prefix + 'LineNoFilter': + return (False, 'data blocked') + return (True, None) + + mock_policy = mock.MagicMock() + mock_policy.IsDataVisible.side_effect = IsDataVisible + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}, + mock_policy) + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + # Should be blocked + self.assertIn( + { + 'name': 'unused_a', + 'status': { + 'description': { + 'format': 'data blocked' + }, + 'refersTo': 'VARIABLE_NAME', + 'isError': True + } + }, top_frame['locals']) + # Should not be blocked + self.assertIn({ + 'name': 'unused_b', + 'value': '5', + 'type': 'int' + }, top_frame['locals']) + + def testWatchedExpressionsBlacklisted(self): + + class TestClass(object): + + def __init__(self): + self.a = 5 + + unused_a = TestClass() + + # Side effect logic for the mock data visibility object + def IsDataVisible(name): + if name == 'collector_test.TestClass': + return (False, 'data blocked') + return (True, None) + + mock_policy = mock.MagicMock() + mock_policy.IsDataVisible.side_effect = IsDataVisible + + self._collector = CaptureCollectorWithDefaultLocation( + { + 'id': 'BP_ID', + 'expressions': ['unused_a', 'unused_a.a'] + }, mock_policy) + self._collector.Collect(inspect.currentframe()) + # Class should be blocked + self.assertIn( + { + 'name': 'unused_a', + 'status': { + 'description': { + 'format': 'data blocked' + }, + 'refersTo': 'VARIABLE_NAME', + 'isError': True + } + }, self._collector.breakpoint['evaluatedExpressions']) + # TODO: Explicit member SHOULD also be blocked but this is + # currently not implemented. After fixing the implementation, change + # the test below to assert that it's blocked too. 
+ self.assertIn({ + 'name': 'unused_a.a', + 'type': 'int', + 'value': '5' + }, self._collector.breakpoint['evaluatedExpressions']) + + def testLocalsNonTopFrame(self): + + def Method(): + self._collector.Collect(inspect.currentframe()) + self.assertListEqual([{ + 'name': 'self', + 'varTableIndex': 1 + }], self._collector.breakpoint['stackFrames'][1]['arguments']) + self.assertCountEqual([{ + 'name': 'unused_a', + 'value': '47', + 'type': 'int' + }, { + 'name': 'Method', + 'value': 'function Method' + }], self._collector.breakpoint['stackFrames'][1]['locals']) + + unused_a = 47 + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + Method() + + def testDictionaryMaxDepth(self): + d = {} + t = d + for _ in range(10): + t['inner'] = {} + t = t['inner'] + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.default_capture_limits.max_depth = 3 + self._collector.Collect(inspect.currentframe()) + self.assertDictEqual( + { + 'name': + 'd', + 'type': + 'dict', + 'members': [{ + 'name': "'inner'", + 'type': 'dict', + 'members': [{ + 'name': "'inner'", + 'varTableIndex': 0 + }] + }] + }, self._LocalByName('d')) + + def testVectorMaxDepth(self): + l = [] + t = l + for _ in range(10): + t.append([]) + t = t[0] + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.default_capture_limits.max_depth = 3 + self._collector.Collect(inspect.currentframe()) + self.assertDictEqual( + { + 'name': + 'l', + 'type': + 'list', + 'members': [{ + 'name': '[0]', + 'type': 'list', + 'members': [{ + 'name': '[0]', + 'varTableIndex': 0 + }] + }] + }, self._LocalByName('l')) + + def testStringTrimming(self): + unused_s = '123456789' + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.default_capture_limits.max_value_len = 8 + self._collector.Collect(inspect.currentframe()) + self.assertListEqual([{ + 'name': 'unused_s', + 'value': "'12345678...", + 'type': 'str' + 
}], self._collector.breakpoint['stackFrames'][0]['locals']) + + def testBytearrayTrimming(self): + unused_bytes = bytearray(range(20)) + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.default_capture_limits.max_value_len = 20 + self._collector.Collect(inspect.currentframe()) + self.assertListEqual([{ + 'name': 'unused_bytes', + 'value': r"bytearray(b'\x00\x01\...", + 'type': 'bytearray' + }], self._collector.breakpoint['stackFrames'][0]['locals']) + + def testObject(self): + + class MyClass(object): + + def __init__(self): + self.a = 1 + self.b = 2 + + unused_my = MyClass() + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + var_index = self._LocalByName('unused_my')['varTableIndex'] + self.assertEqual( + __name__ + '.MyClass', + self._collector.breakpoint['variableTable'][var_index]['type']) + self.assertCountEqual([{ + 'name': 'a', + 'value': '1', + 'type': 'int' + }, { + 'name': 'b', + 'value': '2', + 'type': 'int' + }], self._collector.breakpoint['variableTable'][var_index]['members']) + + def testBufferFullLocalRef(self): + + class MyClass(object): + + def __init__(self, data): + self.data = data + + def Method(): + unused_m1 = MyClass('1' * 10000) + unused_m2 = MyClass('2' * 10000) + unused_m3 = MyClass('3' * 10000) + unused_m4 = MyClass('4' * 10000) + unused_m5 = MyClass('5' * 10000) + unused_m6 = MyClass('6' * 10000) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.max_frames = 1 + self._collector.max_size = 48000 + self._collector.default_capture_limits.max_value_len = 10009 + self._collector.Collect(inspect.currentframe()) + + # Verify that 5 locals fit and 1 is out of buffer. 
+ count = {True: 0, False: 0} # captured, not captured + for local in self._collector.breakpoint['stackFrames'][0]['locals']: + var_index = local['varTableIndex'] + self.assertLess(var_index, + len(self._collector.breakpoint['variableTable'])) + if local['name'].startswith('unused_m'): + count[var_index != 0] += 1 + self.assertDictEqual({True: 5, False: 1}, count) + + Method() + + def testBufferFullDictionaryRef(self): + + class MyClass(object): + + def __init__(self, data): + self.data = data + + def Method(): + unused_d1 = {'a': MyClass('1' * 10000)} + unused_d2 = {'b': MyClass('2' * 10000)} + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.max_frames = 1 + self._collector.max_size = 9000 + self._collector.default_capture_limits.max_value_len = 10009 + self._collector.Collect(inspect.currentframe()) + + # Verify that one of {d1,d2} could fit and the other didn't. + var_indexes = [ + self._LocalByName(n)['members'][0]['varTableIndex'] == 0 + for n in ['unused_d1', 'unused_d2'] + ] + self.assertEqual(1, sum(var_indexes)) + + Method() + + def testClassCrossReference(self): + + class MyClass(object): + pass + + m1 = MyClass() + m2 = MyClass() + m1.other = m2 + m2.other = m1 + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + m1_var_index = self._LocalByName('m1')['varTableIndex'] + m2_var_index = self._LocalByName('m2')['varTableIndex'] + + var_table = self._collector.breakpoint['variableTable'] + self.assertDictEqual( + { + 'type': __name__ + '.MyClass', + 'members': [{ + 'name': 'other', + 'varTableIndex': m1_var_index + }] + }, var_table[m2_var_index]) + self.assertDictEqual( + { + 'type': __name__ + '.MyClass', + 'members': [{ + 'name': 'other', + 'varTableIndex': m2_var_index + }] + }, var_table[m1_var_index]) + + def testCaptureVector(self): + unused_my_list = [1, 2, 3, 4, 5] + unused_my_slice = unused_my_list[1:4] + + self._collector = 
CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertDictEqual( + { + 'name': + 'unused_my_list', + 'type': + 'list', + 'members': [{ + 'name': '[0]', + 'value': '1', + 'type': 'int' + }, { + 'name': '[1]', + 'value': '2', + 'type': 'int' + }, { + 'name': '[2]', + 'value': '3', + 'type': 'int' + }, { + 'name': '[3]', + 'value': '4', + 'type': 'int' + }, { + 'name': '[4]', + 'value': '5', + 'type': 'int' + }] + }, self._LocalByName('unused_my_list')) + self.assertDictEqual( + { + 'name': + 'unused_my_slice', + 'type': + 'list', + 'members': [{ + 'name': '[0]', + 'value': '2', + 'type': 'int' + }, { + 'name': '[1]', + 'value': '3', + 'type': 'int' + }, { + 'name': '[2]', + 'value': '4', + 'type': 'int' + }] + }, self._LocalByName('unused_my_slice')) + + def testCaptureDictionary(self): + unused_my_dict = { + 'first': 1, + 3.14: 'pi', + (5, 6): 7, + frozenset([5, 6]): 'frozen', + 'vector': ['odin', 'dva', 'tri'], + 'inner': { + 1: 'one' + }, + 'empty': {} + } + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + frozenset_name = 'frozenset({5, 6})' + self.assertCountEqual([{ + 'name': "'first'", + 'value': '1', + 'type': 'int' + }, { + 'name': '3.14', + 'value': "'pi'", + 'type': 'str' + }, { + 'name': '(5, 6)', + 'value': '7', + 'type': 'int' + }, { + 'name': frozenset_name, + 'value': "'frozen'", + 'type': 'str' + }, { + 'name': + "'vector'", + 'type': + 'list', + 'members': [{ + 'name': '[0]', + 'value': "'odin'", + 'type': 'str' + }, { + 'name': '[1]', + 'value': "'dva'", + 'type': 'str' + }, { + 'name': '[2]', + 'value': "'tri'", + 'type': 'str' + }] + }, { + 'name': "'inner'", + 'type': 'dict', + 'members': [{ + 'name': '1', + 'value': "'one'", + 'type': 'str' + }] + }, { + 'name': + "'empty'", + 'type': + 'dict', + 'members': [{ + 'status': { + 'refersTo': 'VARIABLE_NAME', + 'description': { + 'format': 'Empty dictionary' + 
} + } + }] + }], + self._LocalByName('unused_my_dict')['members']) + + def testEscapeDictionaryKey(self): + unused_dict = {} + unused_dict[u'\xe0'] = u'\xe0' + unused_dict['\x88'] = '\x88' + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + unicode_type = 'str' + unicode_name = "'\xe0'" + unicode_value = "'\xe0'" + + self.assertCountEqual([{ + 'type': 'str', + 'name': "'\\x88'", + 'value': "'\\x88'" + }, { + 'type': unicode_type, + 'name': unicode_name, + 'value': unicode_value + }], + self._LocalByName('unused_dict')['members']) + + def testOversizedList(self): + unused_big_list = ['x'] * 10000 + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + members = self._LocalByName('unused_big_list')['members'] + + self.assertLen(members, 26) + self.assertDictEqual({ + 'name': '[7]', + 'value': "'x'", + 'type': 'str' + }, members[7]) + self.assertDictEqual( + { + 'status': { + 'refersTo': 'VARIABLE_VALUE', + 'description': { + 'format': ( + 'Only first $0 items were captured. Use in an expression' + ' to see all items.'), + 'parameters': ['25'] + } + } + }, members[25]) + + def testOversizedDictionary(self): + unused_big_dict = {'item' + str(i): i**2 for i in range(26)} + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + members = self._LocalByName('unused_big_dict')['members'] + + self.assertLen(members, 26) + self.assertDictEqual( + { + 'status': { + 'refersTo': 'VARIABLE_VALUE', + 'description': { + 'format': ( + 'Only first $0 items were captured. 
Use in an expression' + ' to see all items.'), + 'parameters': ['25'] + } + } + }, members[25]) + + def testEmptyDictionary(self): + unused_empty_dict = {} + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertEqual( + { + 'name': + 'unused_empty_dict', + 'type': + 'dict', + 'members': [{ + 'status': { + 'refersTo': 'VARIABLE_NAME', + 'description': { + 'format': 'Empty dictionary' + } + } + }] + }, self._LocalByName('unused_empty_dict')) + + def testEmptyCollection(self): + for unused_c, object_type in [([], 'list'), ((), 'tuple'), (set(), 'set')]: + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertEqual( + { + 'name': + 'unused_c', + 'type': + object_type, + 'members': [{ + 'status': { + 'refersTo': 'VARIABLE_NAME', + 'description': { + 'format': 'Empty collection' + } + } + }] + }, self._Pack(self._LocalByName('unused_c'))) + + def testEmptyClass(self): + + class EmptyObject(object): + pass + + unused_empty_object = EmptyObject() + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertEqual( + { + 'name': + 'unused_empty_object', + 'type': + __name__ + '.EmptyObject', + 'members': [{ + 'status': { + 'refersTo': 'VARIABLE_NAME', + 'description': { + 'format': 'Object has no fields' + } + } + }] + }, self._Pack(self._LocalByName('unused_empty_object'))) + + def testWatchedExpressionsSuccess(self): + unused_dummy_a = 'x' + unused_dummy_b = {1: 2, 3: 'a'} + + self._collector = CaptureCollectorWithDefaultLocation({ + 'id': 'BP_ID', + 'expressions': ['1+2', 'unused_dummy_a*8', 'unused_dummy_b'] + }) + self._collector.Collect(inspect.currentframe()) + self.assertListEqual([{ + 'name': '1+2', + 'value': '3', + 'type': 'int' + }, { + 'name': 'unused_dummy_a*8', + 'value': "'xxxxxxxx'", + 'type': 'str' + }, { + 'name': 
+ 'unused_dummy_b', + 'type': + 'dict', + 'members': [{ + 'name': '1', + 'value': '2', + 'type': 'int' + }, { + 'name': '3', + 'value': "'a'", + 'type': 'str' + }] + }], self._collector.breakpoint['evaluatedExpressions']) + + def testOversizedStringExpression(self): + # This test checks that string expressions are collected first, up to the + # max size. The last 18 characters of the string will be missing due to the + # size for the name (14 bytes), type name (3 bytes), and the opening quote + # (1 byte). This test may be sensitive to minor changes in the collector + # code. If it turns out to break easily, consider simply verifying + # that the first 400 characters are collected, since that should suffice to + # ensure that we're not using the normal limit of 256 bytes. + self._collector = CaptureCollectorWithDefaultLocation({ + 'id': 'BP_ID', + 'expressions': ['unused_dummy_a'] + }) + self._collector.max_size = 500 + unused_dummy_a = '|'.join(['%04d' % i for i in range(5, 510, 5)]) + self._collector.Collect(inspect.currentframe()) + self.assertListEqual([{ + 'name': 'unused_dummy_a', + 'type': 'str', + 'value': "'{0}...".format(unused_dummy_a[0:-18]) + }], self._collector.breakpoint['evaluatedExpressions']) + + def testOversizedListExpression(self): + self._collector = CaptureCollectorWithDefaultLocation({ + 'id': 'BP_ID', + 'expressions': ['unused_dummy_a'] + }) + unused_dummy_a = list(range(0, 100)) + self._collector.Collect(inspect.currentframe()) + # Verify that the list did not get truncated. 
+ self.assertListEqual([{ + 'name': + 'unused_dummy_a', + 'type': + 'list', + 'members': [{ + 'type': 'int', + 'value': str(a), + 'name': '[{0}]'.format(a) + } for a in unused_dummy_a] + }], self._collector.breakpoint['evaluatedExpressions']) + + def testExpressionNullBytes(self): + self._collector = CaptureCollectorWithDefaultLocation({ + 'id': 'BP_ID', + 'expressions': ['\0'] + }) + self._collector.Collect(inspect.currentframe()) + + evaluated_expressions = self._collector.breakpoint['evaluatedExpressions'] + self.assertLen(evaluated_expressions, 1) + self.assertTrue(evaluated_expressions[0]['status']['isError']) + + def testSyntaxErrorExpression(self): + self._collector = CaptureCollectorWithDefaultLocation({ + 'id': 'BP_ID', + 'expressions': ['2+'] + }) + self._collector.Collect(inspect.currentframe()) + + evaluated_expressions = self._collector.breakpoint['evaluatedExpressions'] + self.assertLen(evaluated_expressions, 1) + self.assertTrue(evaluated_expressions[0]['status']['isError']) + self.assertEqual('VARIABLE_NAME', + evaluated_expressions[0]['status']['refersTo']) + + def testExpressionException(self): + unused_dummy_a = 1 + unused_dummy_b = 0 + self._collector = CaptureCollectorWithDefaultLocation({ + 'id': 'BP_ID', + 'expressions': ['unused_dummy_a/unused_dummy_b'] + }) + self._collector.Collect(inspect.currentframe()) + + zero_division_msg = 'division by zero' + + self.assertListEqual([{ + 'name': 'unused_dummy_a/unused_dummy_b', + 'status': { + 'isError': True, + 'refersTo': 'VARIABLE_VALUE', + 'description': { + 'format': 'Exception occurred: $0', + 'parameters': [zero_division_msg] + } + } + }], self._collector.breakpoint['evaluatedExpressions']) + + def testMutableExpression(self): + + def ChangeA(): + self._a += 1 + + self._a = 0 + ChangeA() + self._collector = CaptureCollectorWithDefaultLocation({ + 'id': 'BP_ID', + 'expressions': ['ChangeA()'] + }) + self._collector.Collect(inspect.currentframe()) + + self.assertEqual(1, self._a) + 
self.assertListEqual([{ + 'name': 'ChangeA()', + 'status': { + 'isError': True, + 'refersTo': 'VARIABLE_VALUE', + 'description': { + 'format': + 'Exception occurred: $0', + 'parameters': [('Only immutable methods can be ' + 'called from expressions')] + } + } + }], self._collector.breakpoint['evaluatedExpressions']) + + def testPrettyPrinters(self): + + class MyClass(object): + pass + + def PrettyPrinter1(obj): + if obj != unused_obj1: + return None + return ((('name1_%d' % i, '1_%d' % i) for i in range(2)), 'pp-type1') + + def PrettyPrinter2(obj): + if obj != unused_obj2: + return None + return ((('name2_%d' % i, '2_%d' % i) for i in range(3)), 'pp-type2') + + collector.CaptureCollector.pretty_printers += [ + PrettyPrinter1, PrettyPrinter2 + ] + + unused_obj1 = MyClass() + unused_obj2 = MyClass() + unused_obj3 = MyClass() + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + obj_vars = [ + self._Pack(self._LocalByName('unused_obj%d' % i)) for i in range(1, 4) + ] + + self.assertListEqual([{ + 'name': + 'unused_obj1', + 'type': + 'pp-type1', + 'members': [{ + 'name': 'name1_0', + 'value': "'1_0'", + 'type': 'str' + }, { + 'name': 'name1_1', + 'value': "'1_1'", + 'type': 'str' + }] + }, { + 'name': + 'unused_obj2', + 'type': + 'pp-type2', + 'members': [{ + 'name': 'name2_0', + 'value': "'2_0'", + 'type': 'str' + }, { + 'name': 'name2_1', + 'value': "'2_1'", + 'type': 'str' + }, { + 'name': 'name2_2', + 'value': "'2_2'", + 'type': 'str' + }] + }, { + 'name': + 'unused_obj3', + 'type': + __name__ + '.MyClass', + 'members': [{ + 'status': { + 'refersTo': 'VARIABLE_NAME', + 'description': { + 'format': 'Object has no fields' + } + } + }] + }], obj_vars) + + def testDateTime(self): + unused_datetime = datetime.datetime(2014, 6, 11, 2, 30) + unused_date = datetime.datetime(1980, 3, 1) + unused_time = datetime.time(18, 43, 11) + unused_timedelta = datetime.timedelta(days=3, microseconds=8237) + + 
self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertDictEqual( + { + 'name': 'unused_datetime', + 'type': 'datetime.datetime', + 'value': '2014-06-11 02:30:00' + }, self._Pack(self._LocalByName('unused_datetime'))) + + self.assertDictEqual( + { + 'name': 'unused_date', + 'type': 'datetime.datetime', + 'value': '1980-03-01 00:00:00' + }, self._Pack(self._LocalByName('unused_date'))) + + self.assertDictEqual( + { + 'name': 'unused_time', + 'type': 'datetime.time', + 'value': '18:43:11' + }, self._Pack(self._LocalByName('unused_time'))) + + self.assertDictEqual( + { + 'name': 'unused_timedelta', + 'type': 'datetime.timedelta', + 'value': '3 days, 0:00:00.008237' + }, self._Pack(self._LocalByName('unused_timedelta'))) + + def testException(self): + unused_exception = ValueError('arg1', 2, [3]) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + obj = self._Pack(self._LocalByName('unused_exception')) + + self.assertEqual('unused_exception', obj['name']) + self.assertEqual('ValueError', obj['type']) + self.assertListEqual([{ + 'value': "'arg1'", + 'type': 'str', + 'name': '[0]' + }, { + 'value': '2', + 'type': 'int', + 'name': '[1]' + }, { + 'members': [{ + 'value': '3', + 'type': 'int', + 'name': '[0]' + }], + 'type': 'list', + 'name': '[2]' + }], obj['members']) + + def testRequestLogIdCapturing(self): + collector.request_log_id_collector = lambda: 'test_log_id' + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertIn('labels', self._collector.breakpoint) + self.assertEqual( + 'test_log_id', + self._collector.breakpoint['labels'][labels.Breakpoint.REQUEST_LOG_ID]) + + def testRequestLogIdCapturingNoId(self): + collector.request_log_id_collector = lambda: None + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) 
+ self._collector.Collect(inspect.currentframe()) + + def testRequestLogIdCapturingNoCollector(self): + collector.request_log_id_collector = None + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + def testUserIdSuccess(self): + collector.user_id_collector = lambda: ('mdb_user', 'noogler') + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertIn('evaluatedUserId', self._collector.breakpoint) + self.assertEqual({ + 'kind': 'mdb_user', + 'id': 'noogler' + }, self._collector.breakpoint['evaluatedUserId']) + + def testUserIdIsNone(self): + collector.user_id_collector = lambda: (None, None) + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertNotIn('evaluatedUserId', self._collector.breakpoint) + + def testUserIdNoKind(self): + collector.user_id_collector = lambda: (None, 'noogler') + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertNotIn('evaluatedUserId', self._collector.breakpoint) + + def testUserIdNoValue(self): + collector.user_id_collector = lambda: ('mdb_user', None) + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertNotIn('evaluatedUserId', self._collector.breakpoint) + + def _LocalByName(self, name, frame=0): + for local in self._collector.breakpoint['stackFrames'][frame]['locals']: + if local['name'] == name: + return local + self.fail('Local %s not found in frame %d' % (name, frame)) + + def _Pack(self, variable): + """Embeds variables referenced through var_index.""" + packed_variable = copy.copy(variable) + + var_index = variable.get('varTableIndex') + if var_index is not None: + packed_variable.update( + 
self._collector.breakpoint['variableTable'][var_index]) + del packed_variable['varTableIndex'] + + if 'members' in packed_variable: + packed_variable['members'] = [ + self._Pack(member) for member in packed_variable['members'] + ] + + return packed_variable + + +class LogCollectorTest(absltest.TestCase): + """Unit test for log collector.""" + + def setUp(self): + self._logger = logging.getLogger('test') + + class LogVerifier(logging.Handler): + + def __init__(self): + super(LogVerifier, self).__init__() + self._received_records = [] + + def emit(self, record): + self._received_records.append(record) + + def GotMessage(self, + msg, + level=logging.INFO, + line_number=10, + func_name=None): + """Checks that the given message was logged correctly. + + This method verifies both the contents and the source location of the + message match expectations. + + Args: + msg: The expected message + level: The expected logging level. + line_number: The expected line number. + func_name: If specified, the expected log record must have a funcName + equal to this value. + Returns: + True iff the oldest unverified message matches the given attributes. 
+ """ + record = self._received_records.pop(0) + frame = inspect.currentframe().f_back + if level != record.levelno: + logging.error('Expected log level %d, got %d (%s)', level, + record.levelno, record.levelname) + return False + if msg != record.msg: + logging.error('Expected msg "%s", received "%s"', msg, record.msg) + return False + pathname = collector.NormalizePath(frame.f_code.co_filename) + if pathname != record.pathname: + logging.error('Expected pathname "%s", received "%s"', pathname, + record.pathname) + return False + if os.path.basename(pathname) != record.filename: + logging.error('Expected filename "%s", received "%s"', + os.path.basename(pathname), record.filename) + return False + if func_name and func_name != record.funcName: + logging.error('Expected function "%s", received "%s"', func_name, + record.funcName) + return False + if line_number and record.lineno != line_number: + logging.error('Expected lineno %d, received %d', line_number, + record.lineno) + return False + for attr in ['cdbg_pathname', 'cdbg_lineno']: + if hasattr(record, attr): + logging.error('Attribute %s still present in log record', attr) + return False + return True + + def CheckMessageSafe(self, msg): + """Checks that the given message was logged correctly. + + Unlike GotMessage, this will only check the contents, and will not log + an error or pop the record if the message does not match. + + Args: + msg: The expected message + Returns: + True iff the oldest unverified message matches the given attributes. 
+ """ + record = self._received_records[0] + if msg != record.msg: + print(record.msg) + return False + self._received_records.pop(0) + return True + + self._verifier = LogVerifier() + self._logger.addHandler(self._verifier) + self._logger.setLevel(logging.INFO) + collector.SetLogger(self._logger) + + # Give some time for the global quota to recover + time.sleep(0.1) + + def tearDown(self): + self._logger.removeHandler(self._verifier) + + def ResetGlobalLogQuota(self): + # The global log quota takes up to 5 seconds to fully fill back up to + # capacity (kDynamicLogCapacityFactor is 5). The capacity is 5 times the per + # second fill rate. The best we can do is a sleep, since the global + # leaky_bucket instance is inaccessible to the test. + time.sleep(5.0) + + def ResetGlobalLogBytesQuota(self): + # The global log bytes quota takes up to 2 seconds to fully fill back up to + # capacity (kDynamicLogBytesCapacityFactor is 2). The capacity is twice the + # per second fill rate. The best we can do is a sleep, since the global + # leaky_bucket instance is inaccessible to the test. + time.sleep(2.0) + + def testLogQuota(self): + # Attempt to get to a known starting state by letting the global quota fully + # recover so the ordering of tests ideally doesn't affect this test. 
+ self.ResetGlobalLogQuota() + bucket_max_capacity = 250 + log_collector = LogCollectorWithDefaultLocation({ + 'logMessageFormat': '$0', + 'expressions': ['i'] + }) + for i in range(0, bucket_max_capacity * 2): + self.assertIsNone(log_collector.Log(inspect.currentframe())) + if not self._verifier.CheckMessageSafe('LOGPOINT: %s' % i): + self.assertGreaterEqual(i, bucket_max_capacity, + 'Log quota exhausted earlier than expected') + self.assertTrue( + self._verifier.CheckMessageSafe(LOGPOINT_PAUSE_MSG), + 'Quota hit message not logged') + time.sleep(0.6) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.CheckMessageSafe('LOGPOINT: %s' % i), + 'Logging not resumed after quota recovery time') + return + self.fail('Logging was never paused when quota was exceeded') + + def testLogBytesQuota(self): + # Attempt to get to a known starting state by letting the global quota fully + # recover so the ordering of tests ideally doesn't affect this test. + self.ResetGlobalLogBytesQuota() + + # Default capacity is 40960, though based on how the leaky bucket is + # implemented, it can allow effectively twice that amount to go out in a + # very short time frame. So the third 30k message should pause. 
+ msg = ' ' * 30000 + log_collector = LogCollectorWithDefaultLocation({'logMessageFormat': msg}) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: ' + msg)) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: ' + msg)) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.CheckMessageSafe(LOGPOINT_PAUSE_MSG), + 'Quota hit message not logged') + time.sleep(0.6) + log_collector._definition['logMessageFormat'] = 'hello' + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage('LOGPOINT: hello'), + 'Logging was not resumed after quota recovery time') + + def testMissingLogLevel(self): + # Missing is equivalent to INFO. + log_collector = LogCollectorWithDefaultLocation( + {'logMessageFormat': 'hello'}) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: hello')) + + def testUndefinedLogLevel(self): + collector.log_info_message = None + log_collector = LogCollectorWithDefaultLocation({'logLevel': 'INFO'}) + self.assertDictEqual( + { + 'isError': True, + 'description': { + 'format': 'Log action on a breakpoint not supported' + } + }, log_collector.Log(inspect.currentframe())) + + def testLogInfo(self): + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': 'hello' + }) + log_collector._definition['location']['line'] = 20 + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: hello', + func_name='LogCollectorTest.testLogInfo', + line_number=20)) + + def testLogWarning(self): + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'WARNING', + 'logMessageFormat': 'hello' + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + 
self._verifier.GotMessage( + 'LOGPOINT: hello', + level=logging.WARNING, + func_name='LogCollectorTest.testLogWarning')) + + def testLogError(self): + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'ERROR', + 'logMessageFormat': 'hello' + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: hello', + level=logging.ERROR, + func_name='LogCollectorTest.testLogError')) + + def testBadExpression(self): + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': 'a=$0, b=$1', + 'expressions': ['-', '+'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + if sys.version_info.minor < 10: + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: a=, b=')) + else: + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: a=, ' + 'b=')) + + def testDollarEscape(self): + unused_integer = 12345 + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$ $$ $$$ $$$$ $0 $$0 $$$0 $$$$0 $1 hello', + 'expressions': ['unused_integer'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + msg = 'LOGPOINT: $ $ $$ $$ 12345 $0 $12345 $$0 hello' + self.assertTrue(self._verifier.GotMessage(msg)) + + def testInvalidExpressionIndex(self): + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': 'a=$0' + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: a=')) + + def testException(self): + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['[][1]'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: ')) + + def testMutableExpression(self): + + def MutableMethod(): # pylint: disable=unused-variable + self.abc = None + + log_collector = 
LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['MutableMethod()'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: ')) + + def testNone(self): + unused_none = None + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_none'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: None')) + + def testPrimitives(self): + unused_boolean = True + unused_integer = 12345 + unused_string = 'hello' + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0,$1,$2', + 'expressions': ['unused_boolean', 'unused_integer', 'unused_string'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage("LOGPOINT: True,12345,'hello'")) + + def testLongString(self): + unused_string = '1234567890' + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_string'] + }) + log_collector.max_value_len = 9 + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage("LOGPOINT: '123456789...")) + + def testLongBytes(self): + unused_bytes = bytearray([i for i in range(20)]) + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_bytes'] + }) + log_collector.max_value_len = 20 + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage(r"LOGPOINT: bytearray(b'\x00\x01\...")) + + def testDate(self): + unused_datetime = datetime.datetime(2014, 6, 11, 2, 30) + unused_date = datetime.datetime(1980, 3, 1) + unused_time = datetime.time(18, 43, 11) + unused_timedelta = datetime.timedelta(days=3, 
microseconds=8237) + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': + 'INFO', + 'logMessageFormat': + '$0;$1;$2;$3', + 'expressions': [ + 'unused_datetime', 'unused_date', 'unused_time', 'unused_timedelta' + ] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: 2014-06-11 02:30:00;1980-03-01 00:00:00;' + '18:43:11;3 days, 0:00:00.008237')) + + def testSet(self): + unused_set = set(['a']) + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_set'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage("LOGPOINT: {'a'}")) + + def testTuple(self): + unused_tuple = (1, 2, 3, 4, 5) + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_tuple'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: (1, 2, 3, 4, 5)')) + + def testList(self): + unused_list = ['a', 'b', 'c'] + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_list'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage("LOGPOINT: ['a', 'b', 'c']")) + + def testOversizedList(self): + unused_list = [1, 2, 3, 4] + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_list'] + }) + log_collector.max_list_items = 3 + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: [1, 2, 3, ...]')) + + def testSlice(self): + unused_slice = slice(1, 10) + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_slice'] + 
}) + collector.max_list_items = 3 + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: slice(1, 10, None)')) + + def testMap(self): + unused_map = {'a': 1} + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_map'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage("LOGPOINT: {'a': 1}")) + + def testObject(self): + + class MyClass(object): + + def __init__(self): + self.some = 'thing' + + unused_my = MyClass() + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_my'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage("LOGPOINT: {'some': 'thing'}")) + + def testNestedBelowLimit(self): + unused_list = [1, [2], [1, 2, 3], [1, [1, 2, 3]], 5] + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_list'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: [1, [2], [1, 2, 3], [1, [1, 2, 3]], 5]')) + + def testNestedAtLimits(self): + unused_list = [ + 1, [1, 2, 3, 4, 5], [[1, 2, 3, 4, 5], 2, 3, 4, 5], 4, 5, 6, 7, 8, 9 + ] + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_list'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: [1, [1, 2, 3, 4, 5], [[1, 2, 3, 4, 5], 2, 3, 4, 5], ' + '4, 5, 6, 7, 8, 9]')) + + def testNestedRecursionLimit(self): + unused_list = [1, [[2, [3]], 4], 5] + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_list'] + }) + 
self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage('LOGPOINT: [1, [[2, %s], 4], 5]' % type([]))) + + def testNestedRecursionItemLimits(self): + unused_list = [1, [1, [1, [2], 3, 4], 3, 4], 3, 4] + + list_type = "" + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_list'] + }) + log_collector.max_list_items = 3 + log_collector.max_sublist_items = 3 + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: [1, [1, [1, %s, 3, ...], 3, ...], 3, ...]' % list_type)) + + def testDetermineType(self): + builtin_prefix = 'builtins.' + path_prefix = 'googleclouddebugger.collector.' + test_data = ( + (builtin_prefix + 'int', 5), + (builtin_prefix + 'str', 'hello'), + (builtin_prefix + 'function', collector.DetermineType), + (path_prefix + 'LineNoFilter', collector.LineNoFilter()), + ) + + for type_string, value in test_data: + self.assertEqual(type_string, collector.DetermineType(value)) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/error_data_visibility_policy_test.py b/tests/py/error_data_visibility_policy_test.py new file mode 100644 index 0000000..6c8b6d9 --- /dev/null +++ b/tests/py/error_data_visibility_policy_test.py @@ -0,0 +1,17 @@ +"""Tests for googleclouddebugger.error_data_visibility_policy.""" + +from absl.testing import absltest +from googleclouddebugger import error_data_visibility_policy + + +class ErrorDataVisibilityPolicyTest(absltest.TestCase): + + def testIsDataVisible(self): + policy = error_data_visibility_policy.ErrorDataVisibilityPolicy( + 'An error message.') + + self.assertEqual((False, 'An error message.'), policy.IsDataVisible('foo')) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/firebase_client_test.py b/tests/py/firebase_client_test.py new file mode 100644 index 0000000..d72c68c --- /dev/null +++ 
b/tests/py/firebase_client_test.py @@ -0,0 +1,875 @@ +"""Unit tests for firebase_client module.""" + +import copy +import os +import sys +import tempfile +import time +from unittest import mock +from unittest.mock import ANY +from unittest.mock import MagicMock +from unittest.mock import call +from unittest.mock import patch +import requests +import requests_mock + +from googleclouddebugger import version +from googleclouddebugger import firebase_client + +from absl.testing import absltest +from absl.testing import parameterized + +import firebase_admin.credentials +from firebase_admin.exceptions import FirebaseError +from firebase_admin.exceptions import NotFoundError + +TEST_PROJECT_ID = 'test-project-id' +METADATA_PROJECT_URL = ('http://metadata.google.internal/computeMetadata/' + 'v1/project/project-id') + + +class FakeEvent: + + def __init__(self, event_type, path, data): + self.event_type = event_type + self.path = path + self.data = data + + +class FakeReference: + + def __init__(self): + self.subscriber = None + + def listen(self, callback): + self.subscriber = callback + + def update(self, event_type, path, data): + if self.subscriber: + event = FakeEvent(event_type, path, data) + self.subscriber(event) + + +class FirebaseClientTest(parameterized.TestCase): + """Simulates service account authentication.""" + + def setUp(self): + version.__version__ = 'test' + + self._client = firebase_client.FirebaseClient() + + self.breakpoints_changed_count = 0 + self.breakpoints = {} + + # Speed up the delays for retry loops. + for backoff in [ + self._client.connect_backoff, self._client.register_backoff, + self._client.subscribe_backoff, self._client.update_backoff + ]: + backoff.min_interval_sec /= 100000.0 + backoff.max_interval_sec /= 100000.0 + backoff._current_interval_sec /= 100000.0 + + # Set up patchers. 
+ patcher = patch('firebase_admin.initialize_app') + self._mock_initialize_app = patcher.start() + self.addCleanup(patcher.stop) + + patcher = patch('firebase_admin.delete_app') + self._mock_delete_app = patcher.start() + self.addCleanup(patcher.stop) + + patcher = patch('firebase_admin.db.reference') + self._mock_db_ref = patcher.start() + self.addCleanup(patcher.stop) + + # Set up the mocks for the database refs. + self._firebase_app = 'FIREBASE_APP_HANDLE' + self._mock_initialize_app.return_value = self._firebase_app + self._mock_schema_version_ref = MagicMock() + self._mock_schema_version_ref.get.return_value = "2" + self._mock_presence_ref = MagicMock() + self._mock_presence_ref.get.return_value = None + self._mock_active_ref = MagicMock() + self._mock_register_ref = MagicMock() + self._fake_subscribe_ref = FakeReference() + + # Setup common happy path reference sequence: + # cdbg/schema_version + # cdbg/debuggees/{debuggee_id}/registrationTimeUnixMsec + # cdbg/debuggees/{debuggee_id} + # cdbg/breakpoints/{debuggee_id}/active + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, self._mock_presence_ref, + self._mock_register_ref, self._fake_subscribe_ref + ] + + def tearDown(self): + self._client.Stop() + + def testSetupAuthDefault(self): + # By default, we try getting the project id from the metadata server. + # Note that actual credentials are not fetched. + with requests_mock.Mocker() as m: + m.get(METADATA_PROJECT_URL, text=TEST_PROJECT_ID) + + self._client.SetupAuth() + + self.assertEqual(TEST_PROJECT_ID, self._client._project_id) + + def testSetupAuthOverrideProjectIdNumber(self): + # If a project id is provided, we use it. 
+ project_id = 'project2' + self._client.SetupAuth(project_id=project_id) + + self.assertEqual(project_id, self._client._project_id) + + def testSetupAuthServiceAccountJsonAuth(self): + # We'll load credentials from the provided file (mocked for simplicity) + with mock.patch.object(firebase_admin.credentials, + 'Certificate') as firebase_certificate: + json_file = tempfile.NamedTemporaryFile() + # And load the project id from the file as well. + with open(json_file.name, 'w', encoding='utf-8') as f: + f.write(f'{{"project_id": "{TEST_PROJECT_ID}"}}') + self._client.SetupAuth(service_account_json_file=json_file.name) + + firebase_certificate.assert_called_with(json_file.name) + self.assertEqual(TEST_PROJECT_ID, self._client._project_id) + + def testSetupAuthNoProjectId(self): + # There will be an exception raised if we try to contact the metadata + # server on a non-gcp machine. + with requests_mock.Mocker() as m: + m.get(METADATA_PROJECT_URL, exc=requests.exceptions.RequestException) + + with self.assertRaises(firebase_client.NoProjectIdError): + self._client.SetupAuth() + + def testStart(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + debuggee_id = self._client._debuggee_id + + self._mock_initialize_app.assert_called_with( + None, {'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'}, + name='cdbg') + self.assertEqual([ + call(f'cdbg/schema_version', self._firebase_app), + call(f'cdbg/debuggees/{debuggee_id}/registrationTimeUnixMsec', + self._firebase_app), + call(f'cdbg/debuggees/{debuggee_id}', self._firebase_app), + call(f'cdbg/breakpoints/{debuggee_id}/active', self._firebase_app) + ], self._mock_db_ref.call_args_list) + + # Verify that the register call has been made. 
+ expected_data = copy.deepcopy(self._client._GetDebuggee()) + expected_data['registrationTimeUnixMsec'] = {'.sv': 'timestamp'} + expected_data['lastUpdateTimeUnixMsec'] = {'.sv': 'timestamp'} + self._mock_register_ref.set.assert_called_once_with(expected_data) + + def testStartCustomDbUrlConfigured(self): + self._client.SetupAuth( + project_id=TEST_PROJECT_ID, + database_url='https://custom-db.firebaseio.com') + self._client.Start() + self._client.connection_complete.wait() + + debuggee_id = self._client._debuggee_id + + self._mock_initialize_app.assert_called_once_with( + None, {'databaseURL': 'https://custom-db.firebaseio.com'}, name='cdbg') + + def testStartConnectFallsBackToDefaultRtdb(self): + # A new schema_version ref will be fetched each time + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, self._mock_schema_version_ref, + self._mock_presence_ref, self._mock_register_ref, + self._fake_subscribe_ref + ] + + # Fail on the '-cdbg' instance test, succeed on the '-default-rtdb' one. 
+ self._mock_schema_version_ref.get.side_effect = [ + NotFoundError("Not found", http_response=404), '2' + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.connection_complete.wait() + + self.assertEqual([ + call( + None, + {'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'}, + name='cdbg'), + call( + None, { + 'databaseURL': + f'https://{TEST_PROJECT_ID}-default-rtdb.firebaseio.com' + }, + name='cdbg') + ], self._mock_initialize_app.call_args_list) + + self.assertEqual(1, self._mock_delete_app.call_count) + + def testStartConnectFailsThenSucceeds(self): + # A new schema_version ref will be fetched each time + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, self._mock_schema_version_ref, + self._mock_schema_version_ref, self._mock_presence_ref, + self._mock_register_ref, self._fake_subscribe_ref + ] + + # Completely fail on the initial attempt at reaching a DB, then succeed on + # 2nd attempt. One full attempt will try the '-cdbg' db instance followed by + # the '-default-rtdb' one. + self._mock_schema_version_ref.get.side_effect = [ + NotFoundError("Not found", http_response=404), + NotFoundError("Not found", http_response=404), '2' + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.connection_complete.wait() + + self.assertEqual([ + call( + None, + {'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'}, + name='cdbg'), + call( + None, { + 'databaseURL': + f'https://{TEST_PROJECT_ID}-default-rtdb.firebaseio.com' + }, + name='cdbg'), + call( + None, + {'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'}, + name='cdbg') + ], self._mock_initialize_app.call_args_list) + + self.assertEqual(2, self._mock_delete_app.call_count) + + def testStartAlreadyPresent(self): + # Create a mock for just this test that claims the debuggee is registered. + mock_presence_ref = MagicMock() + mock_presence_ref.get.return_value = 'present!' 
+ + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, mock_presence_ref, self._mock_active_ref, + self._fake_subscribe_ref + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + debuggee_id = self._client._debuggee_id + + self.assertEqual([ + call(f'cdbg/schema_version', self._firebase_app), + call(f'cdbg/debuggees/{debuggee_id}/registrationTimeUnixMsec', + self._firebase_app), + call(f'cdbg/debuggees/{debuggee_id}/lastUpdateTimeUnixMsec', + self._firebase_app), + call(f'cdbg/breakpoints/{debuggee_id}/active', self._firebase_app) + ], self._mock_db_ref.call_args_list) + + # Verify that the register call has been made. + self._mock_active_ref.set.assert_called_once_with({'.sv': 'timestamp'}) + + def testStartRegisterRetry(self): + # A new set of db refs are fetched on each retry. + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, self._mock_presence_ref, + self._mock_register_ref, self._mock_presence_ref, + self._mock_register_ref, self._fake_subscribe_ref + ] + + # Fail once, then succeed on retry. + self._mock_register_ref.set.side_effect = [FirebaseError(1, 'foo'), None] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.registration_complete.wait() + + self.assertEqual(2, self._mock_register_ref.set.call_count) + + def testStartSubscribeRetry(self): + mock_subscribe_ref = MagicMock() + mock_subscribe_ref.listen.side_effect = FirebaseError(1, 'foo') + + # A new db ref is fetched on each retry. 
+ self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, + self._mock_presence_ref, + self._mock_register_ref, + mock_subscribe_ref, # Fail the first time + self._fake_subscribe_ref # Succeed the second time + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + self.assertEqual(5, self._mock_db_ref.call_count) + + def testMarkActiveTimer(self): + # Make sure that there are enough refs queued up. + refs = list(self._mock_db_ref.side_effect) + refs.extend([self._mock_active_ref] * 10) + self._mock_db_ref.side_effect = refs + + # Speed things WAY up rather than waiting for hours. + self._client._mark_active_interval_sec = 0.1 + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + # wait long enough for the timer to trigger a few times. + time.sleep(0.5) + + print(f'Timer triggered {self._mock_active_ref.set.call_count} times') + self.assertTrue(self._mock_active_ref.set.call_count > 3) + self._mock_active_ref.set.assert_called_with({'.sv': 'timestamp'}) + + def testBreakpointSubscription(self): + # This class will keep track of the breakpoint updates and will check + # them against expectations. 
+ class ResultChecker: + + def __init__(self, expected_results, test): + self._expected_results = expected_results + self._test = test + self._change_count = 0 + + def callback(self, new_breakpoints): + self._test.assertEqual(self._expected_results[self._change_count], + new_breakpoints) + self._change_count += 1 + + breakpoints = [ + { + 'id': 'breakpoint-0', + 'location': { + 'path': 'foo.py', + 'line': 18 + } + }, + { + 'id': 'breakpoint-1', + 'location': { + 'path': 'bar.py', + 'line': 23 + } + }, + { + 'id': 'breakpoint-2', + 'location': { + 'path': 'baz.py', + 'line': 45 + } + }, + ] + + expected_results = [[breakpoints[0]], [breakpoints[0], breakpoints[1]], + [breakpoints[0], breakpoints[1], breakpoints[2]], + [breakpoints[1], breakpoints[2]], + [breakpoints[1], breakpoints[2]]] + result_checker = ResultChecker(expected_results, self) + + self._client.on_active_breakpoints_changed = result_checker.callback + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + # Send in updates to trigger the subscription callback. + + # Initial state. + self._fake_subscribe_ref.update('put', '/', + {breakpoints[0]['id']: breakpoints[0]}) + # Add a breakpoint via patch. + self._fake_subscribe_ref.update('patch', '/', + {breakpoints[1]['id']: breakpoints[1]}) + # Add a breakpoint via put. + self._fake_subscribe_ref.update('put', f'/{breakpoints[2]["id"]}', + breakpoints[2]) + # Delete a breakpoint. + self._fake_subscribe_ref.update('put', f'/{breakpoints[0]["id"]}', None) + # Delete the breakpoint a second time; should handle this gracefully. 
+ self._fake_subscribe_ref.update('put', f'/{breakpoints[0]["id"]}', None) + + self.assertEqual(len(expected_results), result_checker._change_count) + + def testEnqueueBreakpointUpdate(self): + active_ref_mock = MagicMock() + snapshot_ref_mock = MagicMock() + final_ref_mock = MagicMock() + + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, self._mock_presence_ref, + self._mock_register_ref, self._fake_subscribe_ref, active_ref_mock, + snapshot_ref_mock, final_ref_mock + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + debuggee_id = self._client._debuggee_id + breakpoint_id = 'breakpoint-0' + + input_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'evaluatedExpressions': ['expressions go here'], + 'stackFrames': ['stuff goes here'], + 'variableTable': ['lots', 'of', 'variables'], + } + short_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'action': 'CAPTURE', + 'finalTimeUnixMsec': { + '.sv': 'timestamp' + } + } + full_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'action': 'CAPTURE', + 'evaluatedExpressions': ['expressions go here'], + 'stackFrames': ['stuff goes here'], + 'variableTable': ['lots', 'of', 'variables'], + 'finalTimeUnixMsec': { + '.sv': 'timestamp' + } + } + + self._client.EnqueueBreakpointUpdate(input_breakpoint) + + # Wait for the breakpoint to be sent. 
+ while self._client._transmission_queue: + time.sleep(0.1) + + db_ref_calls = self._mock_db_ref.call_args_list + self.assertEqual( + call(f'cdbg/breakpoints/{debuggee_id}/active/{breakpoint_id}', + self._firebase_app), db_ref_calls[4]) + self.assertEqual( + call(f'cdbg/breakpoints/{debuggee_id}/snapshot/{breakpoint_id}', + self._firebase_app), db_ref_calls[5]) + self.assertEqual( + call(f'cdbg/breakpoints/{debuggee_id}/final/{breakpoint_id}', + self._firebase_app), db_ref_calls[6]) + + active_ref_mock.delete.assert_called_once() + snapshot_ref_mock.set.assert_called_once_with(full_breakpoint) + final_ref_mock.set.assert_called_once_with(short_breakpoint) + + def testEnqueueBreakpointUpdateWithLogpoint(self): + active_ref_mock = MagicMock() + final_ref_mock = MagicMock() + + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, self._mock_presence_ref, + self._mock_register_ref, self._fake_subscribe_ref, active_ref_mock, + final_ref_mock + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + debuggee_id = self._client._debuggee_id + breakpoint_id = 'logpoint-0' + + input_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'action': 'LOG', + 'isFinalState': True, + 'status': { + 'isError': True, + 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', + }, + } + output_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'action': 'LOG', + 'status': { + 'isError': True, + 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', + }, + 'finalTimeUnixMsec': { + '.sv': 'timestamp' + } + } + + self._client.EnqueueBreakpointUpdate(input_breakpoint) + + # Wait for the breakpoint to be sent. 
+ while self._client._transmission_queue: + time.sleep(0.1) + + db_ref_calls = self._mock_db_ref.call_args_list + self.assertEqual( + call(f'cdbg/breakpoints/{debuggee_id}/active/{breakpoint_id}', + self._firebase_app), db_ref_calls[4]) + self.assertEqual( + call(f'cdbg/breakpoints/{debuggee_id}/final/{breakpoint_id}', + self._firebase_app), db_ref_calls[5]) + + active_ref_mock.delete.assert_called_once() + final_ref_mock.set.assert_called_once_with(output_breakpoint) + + # Make sure that the snapshot node was not accessed. + self.assertTrue( + call(f'cdbg/breakpoints/{debuggee_id}/snapshot/{breakpoint_id}', ANY) + not in db_ref_calls) + + def testEnqueueBreakpointUpdateRetry(self): + active_ref_mock = MagicMock() + snapshot_ref_mock = MagicMock() + final_ref_mock = MagicMock() + + # This test will have three failures, one for each of the firebase writes. + # UNAVAILABLE errors are retryable. + active_ref_mock.delete.side_effect = [ + FirebaseError('UNAVAILABLE', 'active error'), None, None, None + ] + snapshot_ref_mock.set.side_effect = [ + FirebaseError('UNAVAILABLE', 'snapshot error'), None, None + ] + final_ref_mock.set.side_effect = [ + FirebaseError('UNAVAILABLE', 'final error'), None + ] + + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, + self._mock_presence_ref, + self._mock_register_ref, + self._fake_subscribe_ref, # setup + active_ref_mock, # attempt 1 + active_ref_mock, + snapshot_ref_mock, # attempt 2 + active_ref_mock, + snapshot_ref_mock, + final_ref_mock, # attempt 3 + active_ref_mock, + snapshot_ref_mock, + final_ref_mock # attempt 4 + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + breakpoint_id = 'breakpoint-0' + + input_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'evaluatedExpressions': ['expressions go here'], + 'stackFrames': ['stuff goes here'], + 'variableTable': 
['lots', 'of', 'variables'],
+    }
+    short_breakpoint = {
+        'id': breakpoint_id,
+        'location': {
+            'path': 'foo.py',
+            'line': 18
+        },
+        'isFinalState': True,
+        'action': 'CAPTURE',
+        'finalTimeUnixMsec': {
+            '.sv': 'timestamp'
+        }
+    }
+    full_breakpoint = {
+        'id': breakpoint_id,
+        'location': {
+            'path': 'foo.py',
+            'line': 18
+        },
+        'isFinalState': True,
+        'action': 'CAPTURE',
+        'evaluatedExpressions': ['expressions go here'],
+        'stackFrames': ['stuff goes here'],
+        'variableTable': ['lots', 'of', 'variables'],
+        'finalTimeUnixMsec': {
+            '.sv': 'timestamp'
+        }
+    }
+
+    self._client.EnqueueBreakpointUpdate(input_breakpoint)
+
+    # Wait for the breakpoint to be sent. Retries will have occurred.
+    while self._client._transmission_queue:
+      time.sleep(0.1)
+
+    active_ref_mock.delete.assert_has_calls([call()] * 4)
+    snapshot_ref_mock.set.assert_has_calls([call(full_breakpoint)] * 3)
+    final_ref_mock.set.assert_has_calls([call(short_breakpoint)] * 2)
+
+  def _TestInitializeLabels(self, module_var, version_var, minor_var):
+    self._client.SetupAuth(project_id=TEST_PROJECT_ID)
+
+    self._client.InitializeDebuggeeLabels({
+        'module': 'my_module',
+        'version': '1',
+        'minorversion': '23',
+        'something_else': 'irrelevant'
+    })
+    self.assertEqual(
+        {
+            'projectid': 'test-project-id',
+            'module': 'my_module',
+            'version': '1',
+            'minorversion': '23',
+            'platform': 'default'
+        }, self._client._debuggee_labels)
+    self.assertEqual('test-project-id-my_module-1',
+                     self._client._GetDebuggeeDescription())
+
+    uniquifier1 = self._client._ComputeUniquifier(
+        {'labels': self._client._debuggee_labels})
+    self.assertTrue(uniquifier1)  # Not empty string.
+ + try: + os.environ[module_var] = 'env_module' + os.environ[version_var] = '213' + os.environ[minor_var] = '3476734' + self._client.InitializeDebuggeeLabels(None) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'env_module', + 'version': '213', + 'minorversion': '3476734', + 'platform': 'default' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-env_module-213', + self._client._GetDebuggeeDescription()) + + os.environ[module_var] = 'default' + os.environ[version_var] = '213' + os.environ[minor_var] = '3476734' + self._client.InitializeDebuggeeLabels({'minorversion': 'something else'}) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'version': '213', + 'minorversion': 'something else', + 'platform': 'default' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-213', + self._client._GetDebuggeeDescription()) + + finally: + del os.environ[module_var] + del os.environ[version_var] + del os.environ[minor_var] + + def testInitializeLegacyDebuggeeLabels(self): + self._TestInitializeLabels('GAE_MODULE_NAME', 'GAE_MODULE_VERSION', + 'GAE_MINOR_VERSION') + + def testInitializeDebuggeeLabels(self): + self._TestInitializeLabels('GAE_SERVICE', 'GAE_VERSION', + 'GAE_DEPLOYMENT_ID') + + def testInitializeCloudRunDebuggeeLabels(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + try: + os.environ['K_SERVICE'] = 'env_module' + os.environ['K_REVISION'] = '213' + self._client.InitializeDebuggeeLabels(None) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'env_module', + 'version': '213', + 'platform': 'default' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-env_module-213', + self._client._GetDebuggeeDescription()) + + finally: + del os.environ['K_SERVICE'] + del os.environ['K_REVISION'] + + def testInitializeCloudFunctionDebuggeeLabels(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + try: + os.environ['FUNCTION_NAME'] = 'fcn-name' + 
os.environ['X_GOOGLE_FUNCTION_VERSION'] = '213' + self._client.InitializeDebuggeeLabels(None) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'fcn-name', + 'version': '213', + 'platform': 'cloud_function' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-fcn-name-213', + self._client._GetDebuggeeDescription()) + + finally: + del os.environ['FUNCTION_NAME'] + del os.environ['X_GOOGLE_FUNCTION_VERSION'] + + def testInitializeCloudFunctionUnversionedDebuggeeLabels(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + try: + os.environ['FUNCTION_NAME'] = 'fcn-name' + self._client.InitializeDebuggeeLabels(None) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'fcn-name', + 'version': 'unversioned', + 'platform': 'cloud_function' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-fcn-name-unversioned', + self._client._GetDebuggeeDescription()) + + finally: + del os.environ['FUNCTION_NAME'] + + def testInitializeCloudFunctionWithRegionDebuggeeLabels(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + try: + os.environ['FUNCTION_NAME'] = 'fcn-name' + os.environ['FUNCTION_REGION'] = 'fcn-region' + self._client.InitializeDebuggeeLabels(None) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'fcn-name', + 'version': 'unversioned', + 'platform': 'cloud_function', + 'region': 'fcn-region' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-fcn-name-unversioned', + self._client._GetDebuggeeDescription()) + + finally: + del os.environ['FUNCTION_NAME'] + del os.environ['FUNCTION_REGION'] + + def testAppFilesUniquifierNoMinorVersion(self): + """Verify that uniquifier_computer is used if minor version not defined.""" + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + root = tempfile.mkdtemp('', 'fake_app_') + sys.path.insert(0, root) + try: + uniquifier1 = self._client._ComputeUniquifier({}) + + with open(os.path.join(root, 
'app.py'), 'w', encoding='utf-8') as f: + f.write('hello') + uniquifier2 = self._client._ComputeUniquifier({}) + finally: + del sys.path[0] + + self.assertNotEqual(uniquifier1, uniquifier2) + + def testAppFilesUniquifierWithMinorVersion(self): + """Verify that uniquifier_computer not used if minor version is defined.""" + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + root = tempfile.mkdtemp('', 'fake_app_') + + os.environ['GAE_MINOR_VERSION'] = '12345' + sys.path.insert(0, root) + try: + self._client.InitializeDebuggeeLabels(None) + + uniquifier1 = self._client._GetDebuggee()['uniquifier'] + + with open(os.path.join(root, 'app.py'), 'w', encoding='utf-8') as f: + f.write('hello') + uniquifier2 = self._client._GetDebuggee()['uniquifier'] + finally: + del os.environ['GAE_MINOR_VERSION'] + del sys.path[0] + + self.assertEqual(uniquifier1, uniquifier2) + + def testSourceContext(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + root = tempfile.mkdtemp('', 'fake_app_') + source_context_path = os.path.join(root, 'source-context.json') + + sys.path.insert(0, root) + try: + debuggee_no_source_context1 = self._client._GetDebuggee() + + with open(source_context_path, 'w', encoding='utf-8') as f: + f.write('not a valid JSON') + debuggee_bad_source_context = self._client._GetDebuggee() + + with open(os.path.join(root, 'fake_app.py'), 'w', encoding='utf-8') as f: + f.write('pretend') + debuggee_no_source_context2 = self._client._GetDebuggee() + + with open(source_context_path, 'w', encoding='utf-8') as f: + f.write('{"what": "source context"}') + debuggee_with_source_context = self._client._GetDebuggee() + + os.remove(source_context_path) + finally: + del sys.path[0] + + self.assertNotIn('sourceContexts', debuggee_no_source_context1) + self.assertNotIn('sourceContexts', debuggee_bad_source_context) + self.assertListEqual([{ + 'what': 'source context' + }], debuggee_with_source_context['sourceContexts']) + + uniquifiers = set() + 
uniquifiers.add(debuggee_no_source_context1['uniquifier']) + uniquifiers.add(debuggee_with_source_context['uniquifier']) + uniquifiers.add(debuggee_bad_source_context['uniquifier']) + self.assertLen(uniquifiers, 1) + uniquifiers.add(debuggee_no_source_context2['uniquifier']) + self.assertLen(uniquifiers, 2) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/glob_data_visibility_policy_test.py b/tests/py/glob_data_visibility_policy_test.py new file mode 100644 index 0000000..8670198 --- /dev/null +++ b/tests/py/glob_data_visibility_policy_test.py @@ -0,0 +1,36 @@ +"""Tests for glob_data_visibility_policy.""" + +from absl.testing import absltest +from googleclouddebugger import glob_data_visibility_policy + +RESPONSES = glob_data_visibility_policy.RESPONSES +UNKNOWN_TYPE = (False, RESPONSES['UNKNOWN_TYPE']) +BLACKLISTED = (False, RESPONSES['BLACKLISTED']) +NOT_WHITELISTED = (False, RESPONSES['NOT_WHITELISTED']) +VISIBLE = (True, RESPONSES['VISIBLE']) + + +class GlobDataVisibilityPolicyTest(absltest.TestCase): + + def testIsDataVisible(self): + blacklist_patterns = ( + 'wl1.private1', + 'wl2.*', + '*.private2', + '', + ) + whitelist_patterns = ('wl1.*', 'wl2.*') + + policy = glob_data_visibility_policy.GlobDataVisibilityPolicy( + blacklist_patterns, whitelist_patterns) + + self.assertEqual(BLACKLISTED, policy.IsDataVisible('wl1.private1')) + self.assertEqual(BLACKLISTED, policy.IsDataVisible('wl2.foo')) + self.assertEqual(BLACKLISTED, policy.IsDataVisible('foo.private2')) + self.assertEqual(NOT_WHITELISTED, policy.IsDataVisible('wl3.foo')) + self.assertEqual(VISIBLE, policy.IsDataVisible('wl1.foo')) + self.assertEqual(UNKNOWN_TYPE, policy.IsDataVisible(None)) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/imphook_test.py b/tests/py/imphook_test.py new file mode 100644 index 0000000..320e477 --- /dev/null +++ b/tests/py/imphook_test.py @@ -0,0 +1,485 @@ +"""Unit test for imphook module.""" + +import importlib +import os 
+import sys +import tempfile + +from absl.testing import absltest + +from googleclouddebugger import imphook + + +class ImportHookTest(absltest.TestCase): + """Tests for the new module import hook.""" + + def setUp(self): + self._test_package_dir = tempfile.mkdtemp('', 'imphook_') + sys.path.append(self._test_package_dir) + + self._import_callbacks_log = [] + self._callback_cleanups = [] + + def tearDown(self): + sys.path.remove(self._test_package_dir) + + for cleanup in self._callback_cleanups: + cleanup() + + # Assert no hooks or entries remained in the set. + self.assertEmpty(imphook._import_callbacks) + + def testPackageImport(self): + self._Hook(self._CreateFile('testpkg1/__init__.py')) + import testpkg1 # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg1/__init__.py'], self._import_callbacks_log) + + def testModuleImport(self): + self._CreateFile('testpkg2/__init__.py') + self._Hook(self._CreateFile('testpkg2/my.py')) + import testpkg2.my # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg2/my.py'], self._import_callbacks_log) + + def testUnrelatedImport(self): + self._CreateFile('testpkg3/__init__.py') + self._Hook(self._CreateFile('testpkg3/first.py')) + self._CreateFile('testpkg3/second.py') + import testpkg3.second # pylint: disable=g-import-not-at-top,unused-variable + self.assertEmpty(self._import_callbacks_log) + + def testDoubleImport(self): + self._Hook(self._CreateFile('testpkg4/__init__.py')) + import testpkg4 # pylint: disable=g-import-not-at-top,unused-variable + import testpkg4 # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg4/__init__.py', 'testpkg4/__init__.py'], + sorted(self._import_callbacks_log)) + + def testRemoveCallback(self): + cleanup = self._Hook(self._CreateFile('testpkg4b/__init__.py')) + cleanup() + import testpkg4b # pylint: disable=g-import-not-at-top,unused-variable + self.assertEmpty(self._import_callbacks_log) + + def 
testRemoveCallbackAfterImport(self): + cleanup = self._Hook(self._CreateFile('testpkg5/__init__.py')) + import testpkg5 # pylint: disable=g-import-not-at-top,unused-variable + cleanup() + import testpkg5 # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg5/__init__.py'], self._import_callbacks_log) + + def testTransitiveImport(self): + self._CreateFile('testpkg6/__init__.py') + self._Hook(self._CreateFile('testpkg6/first.py', 'import second')) + self._Hook(self._CreateFile('testpkg6/second.py', 'import third')) + self._Hook(self._CreateFile('testpkg6/third.py')) + import testpkg6.first # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual( + ['testpkg6/first.py', 'testpkg6/second.py', 'testpkg6/third.py'], + sorted(self._import_callbacks_log)) + + def testPackageDotModuleImport(self): + self._Hook(self._CreateFile('testpkg8/__init__.py')) + self._Hook(self._CreateFile('testpkg8/my.py')) + import testpkg8.my # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg8/__init__.py', 'testpkg8/my.py'], + sorted(self._import_callbacks_log)) + + def testNestedPackageDotModuleImport(self): + self._Hook(self._CreateFile('testpkg9a/__init__.py')) + self._Hook(self._CreateFile('testpkg9a/testpkg9b/__init__.py')) + self._CreateFile('testpkg9a/testpkg9b/my.py') + import testpkg9a.testpkg9b.my # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual( + ['testpkg9a/__init__.py', 'testpkg9a/testpkg9b/__init__.py'], + sorted(self._import_callbacks_log)) + + def testFromImport(self): + self._Hook(self._CreateFile('testpkg10/__init__.py')) + self._CreateFile('testpkg10/my.py') + from testpkg10 import my # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg10/__init__.py'], self._import_callbacks_log) + + def testTransitiveFromImport(self): + self._CreateFile('testpkg7/__init__.py') + self._Hook( + self._CreateFile('testpkg7/first.py', 'from testpkg7 import second')) 
+ self._Hook(self._CreateFile('testpkg7/second.py')) + from testpkg7 import first # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg7/first.py', 'testpkg7/second.py'], + sorted(self._import_callbacks_log)) + + def testFromNestedPackageImportModule(self): + self._Hook(self._CreateFile('testpkg11a/__init__.py')) + self._Hook(self._CreateFile('testpkg11a/testpkg11b/__init__.py')) + self._Hook(self._CreateFile('testpkg11a/testpkg11b/my.py')) + self._Hook(self._CreateFile('testpkg11a/testpkg11b/your.py')) + from testpkg11a.testpkg11b import my, your # pylint: disable=g-import-not-at-top,unused-variable,g-multiple-import + self.assertEqual([ + 'testpkg11a/__init__.py', 'testpkg11a/testpkg11b/__init__.py', + 'testpkg11a/testpkg11b/my.py', 'testpkg11a/testpkg11b/your.py' + ], sorted(self._import_callbacks_log)) + + def testDoubleNestedImport(self): + self._Hook(self._CreateFile('testpkg12a/__init__.py')) + self._Hook(self._CreateFile('testpkg12a/testpkg12b/__init__.py')) + self._Hook(self._CreateFile('testpkg12a/testpkg12b/my.py')) + from testpkg12a.testpkg12b import my # pylint: disable=g-import-not-at-top,unused-variable,g-multiple-import + from testpkg12a.testpkg12b import my # pylint: disable=g-import-not-at-top,unused-variable,g-multiple-import + self.assertEqual([ + 'testpkg12a/__init__.py', 'testpkg12a/__init__.py', + 'testpkg12a/testpkg12b/__init__.py', + 'testpkg12a/testpkg12b/__init__.py', 'testpkg12a/testpkg12b/my.py', + 'testpkg12a/testpkg12b/my.py' + ], sorted(self._import_callbacks_log)) + + def testFromPackageImportStar(self): + self._Hook(self._CreateFile('testpkg13a/__init__.py')) + self._Hook(self._CreateFile('testpkg13a/my1.py')) + self._Hook(self._CreateFile('testpkg13a/your1.py')) + # Star imports are only allowed at the top level, not inside a function in + # Python 3. Doing so would be a SyntaxError. 
+ exec('from testpkg13a import *') # pylint: disable=exec-used + self.assertEqual(['testpkg13a/__init__.py'], self._import_callbacks_log) + + def testFromPackageImportStarWith__all__(self): + self._Hook(self._CreateFile('testpkg14a/__init__.py', '__all__=["my1"]')) + self._Hook(self._CreateFile('testpkg14a/my1.py')) + self._Hook(self._CreateFile('testpkg14a/your1.py')) + exec('from testpkg14a import *') # pylint: disable=exec-used + self.assertEqual(['testpkg14a/__init__.py', 'testpkg14a/my1.py'], + sorted(self._import_callbacks_log)) + + def testImportFunction(self): + self._Hook(self._CreateFile('testpkg27/__init__.py')) + __import__('testpkg27') + self.assertEqual(['testpkg27/__init__.py'], self._import_callbacks_log) + + def testImportLib(self): + self._Hook(self._CreateFile('zero.py')) + self._Hook(self._CreateFile('testpkg15a/__init__.py')) + self._Hook(self._CreateFile('testpkg15a/first.py')) + self._Hook( + self._CreateFile('testpkg15a/testpkg15b/__init__.py', + 'assert False, "unexpected import"')) + self._Hook(self._CreateFile('testpkg15a/testpkg15c/__init__.py')) + self._Hook(self._CreateFile('testpkg15a/testpkg15c/second.py')) + + # Import top level module. + importlib.import_module('zero') + self.assertEqual(['zero.py'], self._import_callbacks_log) + self._import_callbacks_log = [] + + # Import top level package. + importlib.import_module('testpkg15a') + self.assertEqual(['testpkg15a/__init__.py'], self._import_callbacks_log) + self._import_callbacks_log = [] + + # Import package.module. + importlib.import_module('testpkg15a.first') + self.assertEqual(['testpkg15a/__init__.py', 'testpkg15a/first.py'], + sorted(self._import_callbacks_log)) + self._import_callbacks_log = [] + + # Relative module import from package context. 
+    importlib.import_module('.first', 'testpkg15a')
+    self.assertEqual(['testpkg15a/__init__.py', 'testpkg15a/first.py'],
+                     sorted(self._import_callbacks_log))
+    self._import_callbacks_log = []
+
+    # Relative module import from package context with '..'.
+    # In Python 3, the parent module has to be loaded before a relative import
+    importlib.import_module('testpkg15a.testpkg15c')
+    self._import_callbacks_log = []
+    importlib.import_module('..first', 'testpkg15a.testpkg15c')
+    self.assertEqual(
+        [
+            'testpkg15a/__init__.py',
+            # TODO: Importlib may or may not load testpkg15b,
+            # depending on the implementation. Currently on blaze, it does not
+            # load testpkg15b, but a similar non-blaze code on my workstation
+            # loads testpkg15b. We should verify this behavior.
+            # 'testpkg15a/testpkg15b/__init__.py',
+            'testpkg15a/first.py'
+        ],
+        sorted(self._import_callbacks_log))
+    self._import_callbacks_log = []
+
+    # Relative module import from nested package context.
+    importlib.import_module('.second', 'testpkg15a.testpkg15c')
+    self.assertEqual([
+        'testpkg15a/__init__.py', 'testpkg15a/testpkg15c/__init__.py',
+        'testpkg15a/testpkg15c/second.py'
+    ], sorted(self._import_callbacks_log))
+    self._import_callbacks_log = []
+
+  def testRemoveImportHookFromCallback(self):
+
+    def RunCleanup(unused_mod):
+      cleanup()
+
+    cleanup = self._Hook(self._CreateFile('testpkg15/__init__.py'), RunCleanup)
+    import testpkg15  # pylint: disable=g-import-not-at-top,unused-variable
+    import testpkg15  # pylint: disable=g-import-not-at-top,unused-variable
+    import testpkg15  # pylint: disable=g-import-not-at-top,unused-variable
+
+    # The first import should have removed the hook, so expect only one entry.
+    self.assertEqual(['testpkg15/__init__.py'], self._import_callbacks_log)
+
+  def testInitImportNoPrematureCallback(self):
+    # Verifies that the callback is not invoked before the package is fully
+    # loaded. Thus, ensuring that all the module code is available for lookup.
+    def CheckFullyLoaded(module):
+      self.assertEqual(1, getattr(module, 'validate', None), 'premature call')
+
+    self._Hook(self._CreateFile('testpkg16/my1.py'))
+    self._Hook(
+        self._CreateFile('testpkg16/__init__.py', 'import my1\nvalidate = 1'),
+        CheckFullyLoaded)
+    import testpkg16.my1  # pylint: disable=g-import-not-at-top,unused-variable
+
+    self.assertEqual(['testpkg16/__init__.py', 'testpkg16/my1.py'],
+                     sorted(self._import_callbacks_log))
+
+  def testCircularImportNoPrematureCallback(self):
+    # Verifies that the callback is not invoked before the first module is fully
+    # loaded. Thus, ensuring that all the module code is available for lookup.
+    def CheckFullyLoaded(module):
+      self.assertEqual(1, getattr(module, 'validate', None), 'premature call')
+
+    self._CreateFile('testpkg17/__init__.py')
+    self._Hook(
+        self._CreateFile('testpkg17/c1.py', 'import testpkg17.c2\nvalidate = 1',
+                         False), CheckFullyLoaded)
+    self._Hook(
+        self._CreateFile('testpkg17/c2.py', 'import testpkg17.c1\nvalidate = 1',
+                         False), CheckFullyLoaded)
+
+    import testpkg17.c1  # pylint: disable=g-import-not-at-top,unused-variable
+
+    self.assertEqual(['testpkg17/c1.py', 'testpkg17/c2.py'],
+                     sorted(self._import_callbacks_log))
+
+  def testImportException(self):
+    # An exception is thrown by the builtin importer during import.
+    self._CreateFile('testpkg18/__init__.py')
+    self._Hook(self._CreateFile('testpkg18/bad.py', 'assert False, "bad file"'))
+    self._Hook(self._CreateFile('testpkg18/good.py'))
+
+    try:
+      import testpkg18.bad  # pylint: disable=g-import-not-at-top,unused-variable
+    except AssertionError:
+      pass
+
+    import testpkg18.good  # pylint: disable=g-import-not-at-top,unused-variable
+
+    self.assertEqual(['testpkg18/good.py'], self._import_callbacks_log)
+
+  def testImportNestedException(self):
+    # An import exception is thrown and caught inside a module being imported.
+ self._CreateFile('testpkg19/__init__.py') + self._Hook( + self._CreateFile('testpkg19/m19.py', + 'try: import m19b\nexcept ImportError: pass')) + + import testpkg19.m19 # pylint: disable=g-import-not-at-top,unused-variable + + self.assertEqual(['testpkg19/m19.py'], self._import_callbacks_log) + + def testModuleImportByPathSuffix(self): + # Import module by providing only a suffix of the module's file path. + self._CreateFile('testpkg20a/__init__.py') + self._CreateFile('testpkg20a/testpkg20b/__init__.py') + self._CreateFile('testpkg20a/testpkg20b/my1.py') + self._CreateFile('testpkg20a/testpkg20b/my2.py') + self._CreateFile('testpkg20a/testpkg20b/my3.py') + + # Import just by the name of the module file. + self._Hook('my1.py') + import testpkg20a.testpkg20b.my1 # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['my1.py'], self._import_callbacks_log) + self._import_callbacks_log = [] + + # Import with only one of the enclosing package names. + self._Hook('testpkg20b/my2.py') + import testpkg20a.testpkg20b.my2 # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg20b/my2.py'], self._import_callbacks_log) + self._import_callbacks_log = [] + + # Import with all enclosing packages (the typical case). 
+ self._Hook('testpkg20b/my3.py') + import testpkg20a.testpkg20b.my3 # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg20b/my3.py'], self._import_callbacks_log) + self._import_callbacks_log = [] + + def testFromImportImportsFunction(self): + self._CreateFile('testpkg21a/__init__.py') + self._CreateFile('testpkg21a/testpkg21b/__init__.py') + self._CreateFile('testpkg21a/testpkg21b/mod.py', ('def func1():\n' + ' return 5\n' + '\n' + 'def func2():\n' + ' return 7\n')) + + self._Hook('mod.py') + from testpkg21a.testpkg21b.mod import func1, func2 # pylint: disable=g-import-not-at-top,unused-variable,g-multiple-import + self.assertEqual(['mod.py'], self._import_callbacks_log) + + def testImportSibling(self): + self._CreateFile('testpkg22/__init__.py') + self._CreateFile('testpkg22/first.py', 'import second') + self._CreateFile('testpkg22/second.py') + + self._Hook('testpkg22/second.py') + import testpkg22.first # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg22/second.py'], self._import_callbacks_log) + + def testImportSiblingSamePackage(self): + self._CreateFile('testpkg23/__init__.py') + self._CreateFile('testpkg23/testpkg23/__init__.py') + self._CreateFile( + 'testpkg23/first.py', + 'import testpkg23.second') # This refers to testpkg23.testpkg23.second + self._CreateFile('testpkg23/testpkg23/second.py') + + self._Hook('testpkg23/testpkg23/second.py') + import testpkg23.first # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg23/testpkg23/second.py'], + self._import_callbacks_log) + + def testImportSiblingFromInit(self): + self._Hook(self._CreateFile('testpkg23a/__init__.py', 'import testpkg23b')) + self._Hook( + self._CreateFile('testpkg23a/testpkg23b/__init__.py', + 'import testpkg23c')) + self._Hook(self._CreateFile('testpkg23a/testpkg23b/testpkg23c/__init__.py')) + import testpkg23a # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual([ + 
'testpkg23a/__init__.py', 'testpkg23a/testpkg23b/__init__.py', + 'testpkg23a/testpkg23b/testpkg23c/__init__.py' + ], sorted(self._import_callbacks_log)) + + def testThreadLocalCleanup(self): + self._CreateFile('testpkg24/__init__.py') + self._CreateFile('testpkg24/foo.py', 'import bar') + self._CreateFile('testpkg24/bar.py') + + # Create a hook for any arbitrary module. Doesn't need to hit. + self._Hook('xxx/yyy.py') + + import testpkg24.foo # pylint: disable=g-import-not-at-top,unused-variable + + self.assertEqual(imphook._import_local.nest_level, 0) + self.assertEmpty(imphook._import_local.names) + + def testThreadLocalCleanupWithCaughtImportError(self): + self._CreateFile('testpkg25/__init__.py') + self._CreateFile( + 'testpkg25/foo.py', + 'import bar\n' # success. + 'import baz') # success. + self._CreateFile('testpkg25/bar.py') + self._CreateFile( + 'testpkg25/baz.py', 'try:\n' + ' import testpkg25b\n' + 'except ImportError:\n' + ' pass') + + # Create a hook for any arbitrary module. Doesn't need to hit. + self._Hook('xxx/yyy.py') + + # Successful import at top level. Failed import at inner level. + import testpkg25.foo # pylint: disable=g-import-not-at-top,unused-variable + + self.assertEqual(imphook._import_local.nest_level, 0) + self.assertEmpty(imphook._import_local.names) + + def testThreadLocalCleanupWithUncaughtImportError(self): + self._CreateFile('testpkg26/__init__.py') + self._CreateFile( + 'testpkg26/foo.py', + 'import bar\n' # success. + 'import baz') # fail. + self._CreateFile('testpkg26/bar.py') + + # Create a hook for any arbitrary module. Doesn't need to hit. + self._Hook('testpkg26/bar.py') + + # Inner import will fail, and exception will be propagated here. + try: + import testpkg26.foo # pylint: disable=g-import-not-at-top,unused-variable + except ImportError: + pass + + # The hook for bar should be invoked, as bar is already loaded. 
+ self.assertEqual(['testpkg26/bar.py'], self._import_callbacks_log) + + self.assertEqual(imphook._import_local.nest_level, 0) + self.assertEmpty(imphook._import_local.names) + + def testCleanup(self): + cleanup1 = self._Hook('a/b/c.py') + cleanup2 = self._Hook('a/b/c.py') + cleanup3 = self._Hook('a/d/f.py') + cleanup4 = self._Hook('a/d/g.py') + cleanup5 = self._Hook('a/d/c.py') + self.assertLen(imphook._import_callbacks, 4) + + cleanup1() + self.assertLen(imphook._import_callbacks, 4) + cleanup2() + self.assertLen(imphook._import_callbacks, 3) + cleanup3() + self.assertLen(imphook._import_callbacks, 2) + cleanup4() + self.assertLen(imphook._import_callbacks, 1) + cleanup5() + self.assertLen(imphook._import_callbacks, 0) + + def _CreateFile(self, path, content='', rewrite_imports=True): + full_path = os.path.join(self._test_package_dir, path) + directory, unused_name = os.path.split(full_path) + + if not os.path.isdir(directory): + os.makedirs(directory) + + def RewriteImport(line): + """Converts import statements to relative form. + + Examples: + import x => from . import x + import x.y.z => from .x.y import z + print('') => print('') + + Args: + line: str, the line to convert. + + Returns: + str, the converted import statement or original line. + """ + original_line_length = len(line) + line = line.lstrip(' ') + indent = ' ' * (original_line_length - len(line)) + if line.startswith('import'): + pkg, _, mod = line.split(' ')[1].rpartition('.') + line = 'from .%s import %s' % (pkg, mod) + return indent + line + + with open(full_path, 'w') as writer: + if rewrite_imports: + content = '\n'.join(RewriteImport(l) for l in content.split('\n')) + writer.write(content) + + return path + + # TODO: add test for the module param in the callback. 
+ def _Hook(self, path, callback=lambda m: None): + cleanup = imphook.AddImportCallbackBySuffix( + path, lambda mod: + (self._import_callbacks_log.append(path), callback(mod))) + self.assertTrue(cleanup, path) + self._callback_cleanups.append(cleanup) + return cleanup + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/integration_test.py b/tests/py/integration_test.py new file mode 100644 index 0000000..1a16f30 --- /dev/null +++ b/tests/py/integration_test.py @@ -0,0 +1,670 @@ +"""Complete tests of the debugger mocking the backend.""" + +from datetime import datetime +from datetime import timedelta +import functools +import inspect +import itertools +import os +import sys +import time +from unittest import mock + +from googleapiclient import discovery +import googleclouddebugger as cdbg + +import queue + +import google.auth +from absl.testing import absltest + +from googleclouddebugger import collector +from googleclouddebugger import labels +import python_test_util + +_TEST_DEBUGGEE_ID = 'gcp:integration-test-debuggee-id' +_TEST_AGENT_ID = 'agent-id-123-abc' +_TEST_PROJECT_ID = 'test-project-id' +_TEST_PROJECT_NUMBER = '123456789' + +# Time to sleep before returning the result of an API call. +# Without a delay, the agent will continuously call ListActiveBreakpoints, +# and the mock object will use a lot of memory to record all the calls. +_REQUEST_DELAY_SECS = 0.01 + +# TODO: Modify to work with a mocked Firebase database instead. +# class IntegrationTest(absltest.TestCase): +# """Complete tests of the debugger mocking the backend. + +# These tests employ all the components of the debugger. The actual +# communication channel with the backend is mocked. This allows the test +# quickly inject breakpoints and read results. It also makes the test +# standalone and independent of the actual backend. + +# Uses the new module search algorithm (b/70226488). 
+# """ + +# class FakeHub(object): +# """Starts the debugger with a mocked communication channel.""" + +# def __init__(self): +# # Breakpoint updates posted by the debugger that haven't been processed +# # by the test case code. +# self._incoming_breakpoint_updates = queue.Queue() + +# # Running counter used to generate unique breakpoint IDs. +# self._id_counter = itertools.count() + +# self._service = mock.Mock() + +# patcher = mock.patch.object(discovery, 'build') +# self._mock_build = patcher.start() +# self._mock_build.return_value = self._service + +# patcher = mock.patch.object(google.auth, 'default') +# self._default_auth_mock = patcher.start() +# self._default_auth_mock.return_value = None, _TEST_PROJECT_ID + +# controller = self._service.controller.return_value +# debuggees = controller.debuggees.return_value +# breakpoints = debuggees.breakpoints.return_value + +# # Simulate a time delay for calls to the mock API. +# def ReturnWithDelay(val): + +# def GetVal(): +# time.sleep(_REQUEST_DELAY_SECS) +# return val + +# return GetVal + +# self._register_execute = debuggees.register.return_value.execute +# self._register_execute.side_effect = ReturnWithDelay({ +# 'debuggee': { +# 'id': _TEST_DEBUGGEE_ID +# }, +# 'agentId': _TEST_AGENT_ID +# }) + +# self._active_breakpoints = {'breakpoints': []} +# self._list_execute = breakpoints.list.return_value.execute +# self._list_execute.side_effect = ReturnWithDelay(self._active_breakpoints) + +# breakpoints.update = self._UpdateBreakpoint + +# # Start the debugger. +# cdbg.enable() + +# def SetBreakpoint(self, tag, template=None): +# """Sets a new breakpoint in this source file. + +# The line number is identified by tag. The optional template may specify +# other breakpoint parameters such as condition and watched expressions. + +# Args: +# tag: label for a source line. +# template: optional breakpoint parameters. 
+# """ +# path, line = python_test_util.ResolveTag(sys.modules[__name__], tag) +# self.SetBreakpointAtPathLine(path, line, template) + +# def SetBreakpointAtFile(self, filename, tag, template=None): +# """Sets a breakpoint in a file with the given filename. + +# The line number is identified by tag. The optional template may specify +# other breakpoint parameters such as condition and watched expressions. + +# Args: +# filename: the name of the file inside which the tag will be searched. +# Must be in the same directory as the current file. +# tag: label for a source line. +# template: optional breakpoint parameters. + +# Raises: +# Exception: when the given tag does not uniquely identify a line. +# """ +# # TODO: Move part of this to python_test_utils.py file. +# # Find the full path of filename, using the directory of the current file. +# module_path = inspect.getsourcefile(sys.modules[__name__]) +# directory, unused_name = os.path.split(module_path) +# path = os.path.join(directory, filename) + +# # Similar to ResolveTag(), but for a module that's not loaded yet. 
+# tags = python_test_util.GetSourceFileTags(path) +# if tag not in tags: +# raise Exception('tag %s not found' % tag) +# lines = tags[tag] +# if len(lines) != 1: +# raise Exception('tag %s is ambiguous (lines: %s)' % (tag, lines)) + +# self.SetBreakpointAtPathLine(path, lines[0], template) + +# def SetBreakpointAtPathLine(self, path, line, template=None): +# """Sets a new breakpoint at path:line.""" +# breakpoint = { +# 'id': 'BP_%d' % next(self._id_counter), +# 'createTime': python_test_util.DateTimeToTimestamp(datetime.utcnow()), +# 'location': { +# 'path': path, +# 'line': line +# } +# } +# breakpoint.update(template or {}) + +# self.SetActiveBreakpoints(self.GetActiveBreakpoints() + [breakpoint]) + +# def GetActiveBreakpoints(self): +# """Returns current list of active breakpoints.""" +# return self._active_breakpoints['breakpoints'] + +# def SetActiveBreakpoints(self, breakpoints): +# """Sets a new list of active breakpoints. + +# Args: +# breakpoints: list of breakpoints to return to the debuglet. +# """ +# self._active_breakpoints['breakpoints'] = breakpoints +# begin_count = self._list_execute.call_count +# while self._list_execute.call_count < begin_count + 2: +# time.sleep(_REQUEST_DELAY_SECS) + +# def GetNextResult(self): +# """Waits for the next breakpoint update from the debuglet. + +# Returns: +# First breakpoint update sent by the debuglet that hasn't been +# processed yet. + +# Raises: +# queue.Empty: if waiting for breakpoint update times out. +# """ +# try: +# return self._incoming_breakpoint_updates.get(True, 15) +# except queue.Empty: +# raise AssertionError('Timed out waiting for breakpoint update') + +# def TryGetNextResult(self): +# """Returns the first unprocessed breakpoint update from the debuglet. + +# Returns: +# First breakpoint update sent by the debuglet that hasn't been +# processed yet. If no updates are pending, returns None. 
+# """ +# try: +# return self._incoming_breakpoint_updates.get_nowait() +# except queue.Empty: +# return None + +# def _UpdateBreakpoint(self, **keywords): +# """Fake implementation of service.debuggees().breakpoints().update().""" + +# class FakeBreakpointUpdateCommand(object): + +# def __init__(self, q): +# self._breakpoint = keywords['body']['breakpoint'] +# self._queue = q + +# def execute(self): # pylint: disable=invalid-name +# self._queue.put(self._breakpoint) + +# return FakeBreakpointUpdateCommand(self._incoming_breakpoint_updates) + +# # We only need to attach the debugger exactly once. The IntegrationTest class +# # is created for each test case, so we need to keep this state global. + +# _hub = FakeHub() + +# def _FakeLog(self, message, extra=None): +# del extra # unused +# self._info_log.append(message) + +# def setUp(self): +# self._info_log = [] +# collector.log_info_message = self._FakeLog + +# def tearDown(self): +# IntegrationTest._hub.SetActiveBreakpoints([]) + +# while True: +# breakpoint = IntegrationTest._hub.TryGetNextResult() +# if breakpoint is None: +# break +# self.fail('Unexpected incoming breakpoint update: %s' % breakpoint) + +# def testBackCompat(self): +# # Verify that the old AttachDebugger() is the same as enable() +# self.assertEqual(cdbg.enable, cdbg.AttachDebugger) + +# def testBasic(self): + +# def Trigger(): +# print('Breakpoint trigger') # BPTAG: BASIC + +# IntegrationTest._hub.SetBreakpoint('BASIC') +# Trigger() +# result = IntegrationTest._hub.GetNextResult() +# self.assertEqual('Trigger', result['stackFrames'][0]['function']) +# self.assertEqual('IntegrationTest.testBasic', +# result['stackFrames'][1]['function']) + +# # Verify that any pre existing labels present in the breakpoint are preserved +# # by the agent. 
+# def testExistingLabelsSurvive(self): + +# def Trigger(): +# print('Breakpoint trigger with labels') # BPTAG: EXISTING_LABELS_SURVIVE + +# IntegrationTest._hub.SetBreakpoint( +# 'EXISTING_LABELS_SURVIVE', +# {'labels': { +# 'label_1': 'value_1', +# 'label_2': 'value_2' +# }}) +# Trigger() +# result = IntegrationTest._hub.GetNextResult() +# self.assertIn('labels', result.keys()) +# self.assertIn('label_1', result['labels']) +# self.assertIn('label_2', result['labels']) +# self.assertEqual('value_1', result['labels']['label_1']) +# self.assertEqual('value_2', result['labels']['label_2']) + +# # Verify that any pre existing labels present in the breakpoint have priority +# # if they 'collide' with labels in the agent. +# def testExistingLabelsPriority(self): + +# def Trigger(): +# print('Breakpoint trigger with labels') # BPTAG: EXISTING_LABELS_PRIORITY + +# current_labels_collector = collector.breakpoint_labels_collector +# collector.breakpoint_labels_collector = \ +# lambda: {'label_1': 'value_1', 'label_2': 'value_2'} + +# IntegrationTest._hub.SetBreakpoint( +# 'EXISTING_LABELS_PRIORITY', +# {'labels': { +# 'label_1': 'value_foobar', +# 'label_3': 'value_3' +# }}) + +# Trigger() + +# collector.breakpoint_labels_collector = current_labels_collector + +# # In this case, label_1 was in both the agent and the pre existing labels, +# # the pre existing value of value_foobar should be preserved. 
+# result = IntegrationTest._hub.GetNextResult() +# self.assertIn('labels', result.keys()) +# self.assertIn('label_1', result['labels']) +# self.assertIn('label_2', result['labels']) +# self.assertIn('label_3', result['labels']) +# self.assertEqual('value_foobar', result['labels']['label_1']) +# self.assertEqual('value_2', result['labels']['label_2']) +# self.assertEqual('value_3', result['labels']['label_3']) + +# def testRequestLogIdLabel(self): + +# def Trigger(): +# print('Breakpoint trigger req id label') # BPTAG: REQUEST_LOG_ID_LABEL + +# current_request_log_id_collector = \ +# collector.request_log_id_collector +# collector.request_log_id_collector = lambda: 'foo_bar_id' + +# IntegrationTest._hub.SetBreakpoint('REQUEST_LOG_ID_LABEL') + +# Trigger() + +# collector.request_log_id_collector = \ +# current_request_log_id_collector + +# result = IntegrationTest._hub.GetNextResult() +# self.assertIn('labels', result.keys()) +# self.assertIn(labels.Breakpoint.REQUEST_LOG_ID, result['labels']) +# self.assertEqual('foo_bar_id', +# result['labels'][labels.Breakpoint.REQUEST_LOG_ID]) + +# # Tests the issue in b/30876465 +# def testSameLine(self): + +# def Trigger(): +# print('Breakpoint trigger same line') # BPTAG: SAME_LINE + +# num_breakpoints = 5 +# _, line = python_test_util.ResolveTag(sys.modules[__name__], 'SAME_LINE') +# for _ in range(0, num_breakpoints): +# IntegrationTest._hub.SetBreakpoint('SAME_LINE') +# Trigger() +# results = [] +# for _ in range(0, num_breakpoints): +# results.append(IntegrationTest._hub.GetNextResult()) +# lines = [result['stackFrames'][0]['location']['line'] for result in results] +# self.assertListEqual(lines, [line] * num_breakpoints) + +# def testCallStack(self): + +# def Method1(): +# Method2() + +# def Method2(): +# Method3() + +# def Method3(): +# Method4() + +# def Method4(): +# Method5() + +# def Method5(): +# return 0 # BPTAG: CALL_STACK + +# IntegrationTest._hub.SetBreakpoint('CALL_STACK') +# Method1() +# result = 
IntegrationTest._hub.GetNextResult() +# self.assertEqual([ +# 'Method5', 'Method4', 'Method3', 'Method2', 'Method1', +# 'IntegrationTest.testCallStack' +# ], [frame['function'] for frame in result['stackFrames']][:6]) + +# def testInnerMethod(self): + +# def Inner1(): + +# def Inner2(): + +# def Inner3(): +# print('Inner3') # BPTAG: INNER3 + +# Inner3() + +# Inner2() + +# IntegrationTest._hub.SetBreakpoint('INNER3') +# Inner1() +# result = IntegrationTest._hub.GetNextResult() +# self.assertEqual('Inner3', result['stackFrames'][0]['function']) + +# def testClassMethodWithDecorator(self): + +# def MyDecorator(handler): + +# def Caller(self): +# return handler(self) + +# return Caller + +# class BaseClass(object): +# pass + +# class MyClass(BaseClass): + +# @MyDecorator +# def Get(self): +# param = {} # BPTAG: METHOD_WITH_DECORATOR +# return str(param) + +# IntegrationTest._hub.SetBreakpoint('METHOD_WITH_DECORATOR') +# self.assertEqual('{}', MyClass().Get()) +# result = IntegrationTest._hub.GetNextResult() +# self.assertEqual('MyClass.Get', result['stackFrames'][0]['function']) +# self.assertEqual('MyClass.Caller', result['stackFrames'][1]['function']) +# self.assertEqual( +# { +# 'name': +# 'self', +# 'type': +# __name__ + '.MyClass', +# 'members': [{ +# 'status': { +# 'refersTo': 'VARIABLE_NAME', +# 'description': { +# 'format': 'Object has no fields' +# } +# } +# }] +# }, +# python_test_util.PackFrameVariable( +# result, 'self', collection='arguments')) + +# def testGlobalDecorator(self): +# IntegrationTest._hub.SetBreakpoint('WRAPPED_GLOBAL_METHOD') +# self.assertEqual('hello', WrappedGlobalMethod()) +# result = IntegrationTest._hub.GetNextResult() + +# self.assertNotIn('status', result) + +# def testNoLambdaExpression(self): + +# def Trigger(): +# cube = lambda x: x**3 # BPTAG: LAMBDA +# cube(18) + +# num_breakpoints = 5 +# for _ in range(0, num_breakpoints): +# IntegrationTest._hub.SetBreakpoint('LAMBDA') +# Trigger() +# results = [] +# for _ in range(0, 
num_breakpoints): +# results.append(IntegrationTest._hub.GetNextResult()) +# functions = [result['stackFrames'][0]['function'] for result in results] +# self.assertListEqual(functions, ['Trigger'] * num_breakpoints) + +# def testNoGeneratorExpression(self): + +# def Trigger(): +# gen = (i for i in range(0, 5)) # BPTAG: GENEXPR +# next(gen) +# next(gen) +# next(gen) +# next(gen) +# next(gen) + +# num_breakpoints = 1 +# for _ in range(0, num_breakpoints): +# IntegrationTest._hub.SetBreakpoint('GENEXPR') +# Trigger() +# results = [] +# for _ in range(0, num_breakpoints): +# results.append(IntegrationTest._hub.GetNextResult()) +# functions = [result['stackFrames'][0]['function'] for result in results] +# self.assertListEqual(functions, ['Trigger'] * num_breakpoints) + +# def testTryBlock(self): + +# def Method(a): +# try: +# return a * a # BPTAG: TRY_BLOCK +# except Exception as unused_e: # pylint: disable=broad-except +# return a + +# IntegrationTest._hub.SetBreakpoint('TRY_BLOCK') +# Method(11) +# result = IntegrationTest._hub.GetNextResult() +# self.assertEqual('Method', result['stackFrames'][0]['function']) +# self.assertEqual([{ +# 'name': 'a', +# 'value': '11', +# 'type': 'int' +# }], result['stackFrames'][0]['arguments']) + +# def testFrameArguments(self): + +# def Method(a, b): +# return a + str(b) # BPTAG: FRAME_ARGUMENTS + +# IntegrationTest._hub.SetBreakpoint('FRAME_ARGUMENTS') +# Method('hello', 87) +# result = IntegrationTest._hub.GetNextResult() +# self.assertEqual([{ +# 'name': 'a', +# 'value': "'hello'", +# 'type': 'str' +# }, { +# 'name': 'b', +# 'value': '87', +# 'type': 'int' +# }], result['stackFrames'][0]['arguments']) +# self.assertEqual('self', result['stackFrames'][1]['arguments'][0]['name']) + +# def testFrameLocals(self): + +# class Number(object): + +# def __init__(self): +# self.n = 57 + +# def Method(a): +# b = a**2 +# c = str(a) * 3 +# return c + str(b) # BPTAG: FRAME_LOCALS + +# IntegrationTest._hub.SetBreakpoint('FRAME_LOCALS') +# x = 
{'a': 1, 'b': Number()} +# Method(8) +# result = IntegrationTest._hub.GetNextResult() +# self.assertEqual({ +# 'name': 'b', +# 'value': '64', +# 'type': 'int' +# }, python_test_util.PackFrameVariable(result, 'b')) +# self.assertEqual({ +# 'name': 'c', +# 'value': "'888'", +# 'type': 'str' +# }, python_test_util.PackFrameVariable(result, 'c')) +# self.assertEqual( +# { +# 'name': +# 'x', +# 'type': +# 'dict', +# 'members': [{ +# 'name': "'a'", +# 'value': '1', +# 'type': 'int' +# }, { +# 'name': "'b'", +# 'type': __name__ + '.Number', +# 'members': [{ +# 'name': 'n', +# 'value': '57', +# 'type': 'int' +# }] +# }] +# }, python_test_util.PackFrameVariable(result, 'x', frame=1)) +# return x + + +# # FIXME: Broken in Python 3.10 +# # def testRecursion(self): +# # +# # def RecursiveMethod(i): +# # if i == 0: +# # return 0 # BPTAG: RECURSION +# # return RecursiveMethod(i - 1) +# # +# # IntegrationTest._hub.SetBreakpoint('RECURSION') +# # RecursiveMethod(5) +# # result = IntegrationTest._hub.GetNextResult() +# # +# # for frame in range(5): +# # self.assertEqual({ +# # 'name': 'i', +# # 'value': str(frame), +# # 'type': 'int' +# # }, python_test_util.PackFrameVariable(result, 'i', frame, 'arguments')) + +# def testWatchedExpressions(self): + +# def Trigger(): + +# class MyClass(object): + +# def __init__(self): +# self.a = 1 +# self.b = 'bbb' + +# unused_my = MyClass() +# print('Breakpoint trigger') # BPTAG: WATCHED_EXPRESSION + +# IntegrationTest._hub.SetBreakpoint('WATCHED_EXPRESSION', +# {'expressions': ['unused_my']}) +# Trigger() +# result = IntegrationTest._hub.GetNextResult() + +# self.assertEqual( +# { +# 'name': +# 'unused_my', +# 'type': +# __name__ + '.MyClass', +# 'members': [{ +# 'name': 'a', +# 'value': '1', +# 'type': 'int' +# }, { +# 'name': 'b', +# 'value': "'bbb'", +# 'type': 'str' +# }] +# }, python_test_util.PackWatchedExpression(result, 0)) + +# def testBreakpointExpiration(self): # BPTAG: BREAKPOINT_EXPIRATION +# created_time = datetime.utcnow() - 
timedelta(hours=25) +# IntegrationTest._hub.SetBreakpoint( +# 'BREAKPOINT_EXPIRATION', +# {'createTime': python_test_util.DateTimeToTimestamp(created_time)}) +# result = IntegrationTest._hub.GetNextResult() + +# self.assertTrue(result['status']['isError']) + +# def testLogAction(self): + +# def Trigger(): +# for i in range(3): +# print('Log me %d' % i) # BPTAG: LOG + +# IntegrationTest._hub.SetBreakpoint( +# 'LOG', { +# 'action': 'LOG', +# 'logLevel': 'INFO', +# 'logMessageFormat': 'hello $0', +# 'expressions': ['i'] +# }) +# Trigger() +# self.assertListEqual( +# ['LOGPOINT: hello 0', 'LOGPOINT: hello 1', 'LOGPOINT: hello 2'], +# self._info_log) + +# def testDeferred(self): + +# def Trigger(): +# import integration_test_helper # pylint: disable=g-import-not-at-top +# integration_test_helper.Trigger() + +# IntegrationTest._hub.SetBreakpointAtFile('integration_test_helper.py', +# 'DEFERRED') + +# Trigger() +# result = IntegrationTest._hub.GetNextResult() +# self.assertEqual('Trigger', result['stackFrames'][0]['function']) +# self.assertEqual('Trigger', result['stackFrames'][1]['function']) +# self.assertEqual('IntegrationTest.testDeferred', +# result['stackFrames'][2]['function']) + + +def MyGlobalDecorator(fn): + + @functools.wraps(fn) + def Wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + return Wrapper + + +@MyGlobalDecorator +def WrappedGlobalMethod(): + return 'hello' # BPTAG: WRAPPED_GLOBAL_METHOD + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/integration_test_helper.py b/tests/py/integration_test_helper.py new file mode 100644 index 0000000..5a7e04b --- /dev/null +++ b/tests/py/integration_test_helper.py @@ -0,0 +1,5 @@ +"""Helper module for integration test to validate deferred breakpoints.""" + + +def Trigger(): + print('bp trigger') # BPTAG: DEFERRED diff --git a/tests/py/labels_test.py b/tests/py/labels_test.py new file mode 100644 index 0000000..b7b01dd --- /dev/null +++ b/tests/py/labels_test.py @@ -0,0 +1,29 @@ 
+"""Tests for googleclouddebugger.labels""" + +from absl.testing import absltest +from googleclouddebugger import labels + + +class LabelsTest(absltest.TestCase): + + def testDefinesLabelsCorrectly(self): + self.assertEqual(labels.Breakpoint.REQUEST_LOG_ID, 'requestlogid') + + self.assertEqual(labels.Debuggee.DOMAIN, 'domain') + self.assertEqual(labels.Debuggee.PROJECT_ID, 'projectid') + self.assertEqual(labels.Debuggee.MODULE, 'module') + self.assertEqual(labels.Debuggee.VERSION, 'version') + self.assertEqual(labels.Debuggee.MINOR_VERSION, 'minorversion') + self.assertEqual(labels.Debuggee.PLATFORM, 'platform') + self.assertEqual(labels.Debuggee.REGION, 'region') + + def testProvidesAllLabelsSet(self): + self.assertIsNotNone(labels.Breakpoint.SET_ALL) + self.assertLen(labels.Breakpoint.SET_ALL, 1) + + self.assertIsNotNone(labels.Debuggee.SET_ALL) + self.assertLen(labels.Debuggee.SET_ALL, 7) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/module_explorer_test.py b/tests/py/module_explorer_test.py new file mode 100644 index 0000000..4e1a42c --- /dev/null +++ b/tests/py/module_explorer_test.py @@ -0,0 +1,321 @@ +"""Unit test for module_explorer module.""" + +import dis +import inspect +import os +import py_compile +import shutil +import sys +import tempfile + +from absl.testing import absltest + +from googleclouddebugger import module_explorer +import python_test_util + + +class ModuleExplorerTest(absltest.TestCase): + """Unit test for module_explorer module.""" + + def setUp(self): + self._module = sys.modules[__name__] + self._code_objects = module_explorer._GetModuleCodeObjects(self._module) + + # Populate line cache for this module (needed for .par test). 
+ inspect.getsourcelines(self.testCodeObjectAtLine) + + def testGlobalMethod(self): + """Verify that global method is found.""" + self.assertIn(_GlobalMethod.__code__, self._code_objects) + + def testInnerMethodOfGlobalMethod(self): + """Verify that inner method defined in a global method is found.""" + self.assertIn(_GlobalMethod(), self._code_objects) + + def testInstanceClassMethod(self): + """Verify that instance class method is found.""" + self.assertIn(self.testInstanceClassMethod.__code__, self._code_objects) + + def testInnerMethodOfInstanceClassMethod(self): + """Verify that inner method defined in a class instance method is found.""" + + def InnerMethod(): + pass + + self.assertIn(InnerMethod.__code__, self._code_objects) + + def testStaticMethod(self): + """Verify that static class method is found.""" + self.assertIn(ModuleExplorerTest._StaticMethod.__code__, self._code_objects) + + def testInnerMethodOfStaticMethod(self): + """Verify that static class method is found.""" + self.assertIn(ModuleExplorerTest._StaticMethod(), self._code_objects) + + def testNonModuleClassMethod(self): + """Verify that instance method defined in a base class is not added.""" + self.assertNotIn(self.assertTrue.__code__, self._code_objects) + + def testDeepInnerMethod(self): + """Verify that inner of inner of inner, etc. 
method is found.""" + + def Inner1(): + + def Inner2(): + + def Inner3(): + + def Inner4(): + + def Inner5(): + pass + + return Inner5.__code__ + + return Inner4() + + return Inner3() + + return Inner2() + + self.assertIn(Inner1(), self._code_objects) + + def testNoLambdaExpression(self): + """Verify that code of lambda expression is not included.""" + + self.assertNotIn(_MethodWithLambdaExpression(), self._code_objects) + + def testNoGeneratorExpression(self): + """Verify that code of generator expression is not included.""" + + self.assertNotIn(_MethodWithGeneratorExpression(), self._code_objects) + + def testMethodOfInnerClass(self): + """Verify that method of inner class is found.""" + + class InnerClass(object): + + def InnerClassMethod(self): + pass + + self.assertIn(InnerClass().InnerClassMethod.__code__, self._code_objects) + + def testMethodOfInnerOldStyleClass(self): + """Verify that method of inner old style class is found.""" + + class InnerClass(): + + def InnerClassMethod(self): + pass + + self.assertIn(InnerClass().InnerClassMethod.__code__, self._code_objects) + + def testGlobalMethodWithClosureDecorator(self): + co = self._GetCodeObjectAtLine(self._module, + 'GLOBAL_METHOD_WITH_CLOSURE_DECORATOR') + self.assertTrue(co) + self.assertEqual('GlobalMethodWithClosureDecorator', co.co_name) + + def testClassMethodWithClosureDecorator(self): + co = self._GetCodeObjectAtLine( + self._module, 'GLOBAL_CLASS_METHOD_WITH_CLOSURE_DECORATOR') + self.assertTrue(co) + self.assertEqual('FnWithClosureDecorator', co.co_name) + + def testGlobalMethodWithClassDecorator(self): + co = self._GetCodeObjectAtLine(self._module, + 'GLOBAL_METHOD_WITH_CLASS_DECORATOR') + self.assertTrue(co) + self.assertEqual('GlobalMethodWithClassDecorator', co.co_name) + + def testClassMethodWithClassDecorator(self): + co = self._GetCodeObjectAtLine(self._module, + 'GLOBAL_CLASS_METHOD_WITH_CLASS_DECORATOR') + self.assertTrue(co) + self.assertEqual('FnWithClassDecorator', co.co_name) + + def 
testSameFileName(self): + """Verify that all found code objects are defined in the same file.""" + path = next(iter(self._code_objects)).co_filename + self.assertTrue(path) + + for code_object in self._code_objects: + self.assertEqual(path, code_object.co_filename) + + def testCodeObjectAtLine(self): + """Verify that query of code object at a specified source line.""" + test_cases = [ + (self.testCodeObjectAtLine.__code__, 'TEST_CODE_OBJECT_AT_ASSERT'), + (ModuleExplorerTest._StaticMethod(), 'INNER_OF_STATIC_METHOD'), + (_GlobalMethod(), 'INNER_OF_GLOBAL_METHOD') + ] + + for code_object, tag in test_cases: + self.assertEqual( # BPTAG: TEST_CODE_OBJECT_AT_ASSERT + code_object, self._GetCodeObjectAtLine(code_object, tag)) + + def testCodeObjectWithoutModule(self): + """Verify no crash/hang when module has no file name.""" + global global_code_object # pylint: disable=global-variable-undefined + global_code_object = compile('2+3', '', 'exec') + + self.assertFalse( + module_explorer.GetCodeObjectAtLine(self._module, 111111)[0]) + + +# TODO: Re-enable this test, without hardcoding a python version into it. +# def testCodeExtensionMismatch(self): +# """Verify module match when code object points to .py and module to .pyc.""" +# test_dir = tempfile.mkdtemp('', 'module_explorer_') +# sys.path.append(test_dir) +# try: +# # Create and compile module, remove the .py file and leave the .pyc file. 
+# module_path = os.path.join(test_dir, 'module.py') +# with open(module_path, 'w') as f: +# f.write('def f():\n pass') +# py_compile.compile(module_path) +# module_pyc_path = os.path.join(test_dir, '__pycache__', +# 'module.cpython-37.pyc') +# os.rename(module_pyc_path, module_path + 'c') +# os.remove(module_path) +# +# import module # pylint: disable=g-import-not-at-top +# self.assertEqual('.py', +# os.path.splitext(module.f.__code__.co_filename)[1]) +# self.assertEqual('.pyc', os.path.splitext(module.__file__)[1]) +# +# func_code = module.f.__code__ +# self.assertEqual(func_code, +# module_explorer.GetCodeObjectAtLine( +# module, +# next(dis.findlinestarts(func_code))[1])[1]) +# finally: +# sys.path.remove(test_dir) +# shutil.rmtree(test_dir) + + def testMaxVisitObjects(self): + default_quota = module_explorer._MAX_VISIT_OBJECTS + try: + module_explorer._MAX_VISIT_OBJECTS = 10 + self.assertLess( + len(module_explorer._GetModuleCodeObjects(self._module)), + len(self._code_objects)) + finally: + module_explorer._MAX_VISIT_OBJECTS = default_quota + + def testMaxReferentsBfsDepth(self): + default_quota = module_explorer._MAX_REFERENTS_BFS_DEPTH + try: + module_explorer._MAX_REFERENTS_BFS_DEPTH = 1 + self.assertLess( + len(module_explorer._GetModuleCodeObjects(self._module)), + len(self._code_objects)) + finally: + module_explorer._MAX_REFERENTS_BFS_DEPTH = default_quota + + def testMaxObjectReferents(self): + + class A(object): + pass + + default_quota = module_explorer._MAX_VISIT_OBJECTS + default_referents_quota = module_explorer._MAX_OBJECT_REFERENTS + try: + global large_dict + large_dict = {A(): 0 for i in range(0, 5000)} + + # First test with a referents limit too large, it will visit large_dict + # and exhaust the _MAX_VISIT_OBJECTS quota before finding all the code + # objects + module_explorer._MAX_VISIT_OBJECTS = 5000 + module_explorer._MAX_OBJECT_REFERENTS = sys.maxsize + self.assertLess( + len(module_explorer._GetModuleCodeObjects(self._module)), + 
len(self._code_objects)) + + # Now test with a referents limit that prevents large_dict from being + # explored, all the code objects should be found now that the large dict + # is skipped and isn't taking up the _MAX_VISIT_OBJECTS quota + module_explorer._MAX_OBJECT_REFERENTS = default_referents_quota + self.assertItemsEqual( + module_explorer._GetModuleCodeObjects(self._module), + self._code_objects) + finally: + module_explorer._MAX_VISIT_OBJECTS = default_quota + module_explorer._MAX_OBJECT_REFERENTS = default_referents_quota + large_dict = None + + @staticmethod + def _StaticMethod(): + + def InnerMethod(): + pass # BPTAG: INNER_OF_STATIC_METHOD + + return InnerMethod.__code__ + + def _GetCodeObjectAtLine(self, fn, tag): + """Wrapper over GetCodeObjectAtLine for tags in this module.""" + unused_path, line = python_test_util.ResolveTag(fn, tag) + return module_explorer.GetCodeObjectAtLine(self._module, line)[1] + + +def _GlobalMethod(): + + def InnerMethod(): + pass # BPTAG: INNER_OF_GLOBAL_METHOD + + return InnerMethod.__code__ + + +def ClosureDecorator(handler): + + def Caller(*args): + return handler(*args) + + return Caller + + +class ClassDecorator(object): + + def __init__(self, fn): + self._fn = fn + + def __call__(self, *args): + return self._fn(*args) + + +@ClosureDecorator +def GlobalMethodWithClosureDecorator(): + return True # BPTAG: GLOBAL_METHOD_WITH_CLOSURE_DECORATOR + + +@ClassDecorator +def GlobalMethodWithClassDecorator(): + return True # BPTAG: GLOBAL_METHOD_WITH_CLASS_DECORATOR + + +class GlobalClass(object): + + @ClosureDecorator + def FnWithClosureDecorator(self): + return True # BPTAG: GLOBAL_CLASS_METHOD_WITH_CLOSURE_DECORATOR + + @ClassDecorator + def FnWithClassDecorator(self): + return True # BPTAG: GLOBAL_CLASS_METHOD_WITH_CLASS_DECORATOR + + +def _MethodWithLambdaExpression(): + return (lambda x: x**3).__code__ + + +def _MethodWithGeneratorExpression(): + return (i for i in range(0, 2)).gi_code + + +# Used for 
testMaxObjectReferents, need to be in global scope or else the module +# explorer would not explore this +large_dict = None + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/module_search_test.py b/tests/py/module_search_test.py new file mode 100644 index 0000000..3a12c57 --- /dev/null +++ b/tests/py/module_search_test.py @@ -0,0 +1,122 @@ +"""Unit test for module_search module.""" + +import os +import sys +import tempfile + +from absl.testing import absltest + +from googleclouddebugger import module_search + + +# TODO: Add tests for whitespace in location path including in, +# extension, basename, path +class SearchModulesTest(absltest.TestCase): + + def setUp(self): + self._test_package_dir = tempfile.mkdtemp('', 'package_') + sys.path.append(self._test_package_dir) + + def tearDown(self): + sys.path.remove(self._test_package_dir) + + def testSearchValidSourcePath(self): + # These modules are on the sys.path. + self.assertEndsWith( + module_search.Search('googleclouddebugger/module_search.py'), + '/site-packages/googleclouddebugger/module_search.py') + + # inspect and dis are libraries with no real file. So, we + # can no longer match them by file path. + + def testSearchInvalidSourcePath(self): + # This is an invalid module that doesn't exist anywhere. + self.assertEqual(module_search.Search('aaaaa.py'), 'aaaaa.py') + + # This module exists, but the search input is missing the outer package + # name. + self.assertEqual(module_search.Search('absltest.py'), 'absltest.py') + + def testSearchInvalidExtension(self): + # Test that the module rejects invalid extension in the input. + with self.assertRaises(AssertionError): + module_search.Search('module_search.x') + + def testSearchPathStartsWithSep(self): + # Test that module rejects invalid leading os.sep char in the input. 
+ with self.assertRaises(AssertionError): + module_search.Search('/module_search') + + def testSearchRelativeSysPath(self): + # An entry in sys.path is in relative form, and represents the same + # directory as another absolute entry in sys.path. + for directory in ['', 'a', 'a/b']: + self._CreateFile(os.path.join(directory, '__init__.py')) + self._CreateFile('a/b/first.py') + + try: + # Inject a relative path into sys.path that refers to a directory already + # in sys.path. It should produce the same result as the non-relative form. + testdir_alias = os.path.join(self._test_package_dir, 'a/../a') + + # Add 'a/../a' to sys.path so that 'b/first.py' is reachable. + sys.path.insert(0, testdir_alias) + + # Returned result should have a successful file match and relative + # paths should be kept as-is. + result = module_search.Search('b/first.py') + self.assertEndsWith(result, 'a/../a/b/first.py') + + finally: + sys.path.remove(testdir_alias) + + def testSearchSymLinkInSysPath(self): + # An entry in sys.path is a symlink. + for directory in ['', 'a', 'a/b']: + self._CreateFile(os.path.join(directory, '__init__.py'), '') + self._CreateFile('a/b/first.py') + self._CreateSymLink('a', 'link') + + try: + # Add 'link/' to sys.path so that 'b/first.py' is reachable. + sys.path.append(os.path.join(self._test_package_dir, 'link')) + + # Returned result should have a successful file match and symbolic + # links should be kept. 
+ self.assertEndsWith(module_search.Search('b/first.py'), 'link/b/first.py') + finally: + sys.path.remove(os.path.join(self._test_package_dir, 'link')) + + def _CreateFile(self, path, contents='assert False "Unexpected import"\n'): + full_path = os.path.join(self._test_package_dir, path) + directory, unused_name = os.path.split(full_path) + + if not os.path.isdir(directory): + os.makedirs(directory) + + with open(full_path, 'w') as writer: + writer.write(contents) + + return path + + def _CreateSymLink(self, source, link_name): + full_source_path = os.path.join(self._test_package_dir, source) + full_link_path = os.path.join(self._test_package_dir, link_name) + os.symlink(full_source_path, full_link_path) + + # Since we cannot use os.path.samefile or os.path.realpath to eliminate + # symlinks reliably, we only check suffix equivalence of file paths in these + # unit tests. + def _AssertEndsWith(self, match, path): + """Asserts exactly one match ending with path.""" + self.assertLen(match, 1) + self.assertEndsWith(match[0], path) + + def _AssertEqFile(self, match, path): + """Asserts exactly one match equals to the file created with _CreateFile.""" + self.assertLen(match, 1) + self.assertEqual(match[0], os.path.join(self._test_package_dir, path)) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/module_utils_test.py b/tests/py/module_utils_test.py new file mode 100644 index 0000000..ac847ad --- /dev/null +++ b/tests/py/module_utils_test.py @@ -0,0 +1,168 @@ +"""Tests for googleclouddebugger.module_utils.""" + +import os +import sys +import tempfile + +from absl.testing import absltest + +from googleclouddebugger import module_utils + + +class TestModule(object): + """Dummy class with __name__ and __file__ attributes.""" + + def __init__(self, name, path): + self.__name__ = name + self.__file__ = path + + +def _AddSysModule(name, path): + sys.modules[name] = TestModule(name, path) + + +class ModuleUtilsTest(absltest.TestCase): + + def 
setUp(self): + self._test_package_dir = tempfile.mkdtemp('', 'package_') + self.modules = sys.modules.copy() + + def tearDown(self): + sys.modules = self.modules + self.modules = None + + def _CreateFile(self, path): + full_path = os.path.join(self._test_package_dir, path) + directory, unused_name = os.path.split(full_path) + + if not os.path.isdir(directory): + os.makedirs(directory) + + with open(full_path, 'w') as writer: + writer.write('') + + return full_path + + def _CreateSymLink(self, source, link_name): + full_source_path = os.path.join(self._test_package_dir, source) + full_link_path = os.path.join(self._test_package_dir, link_name) + os.symlink(full_source_path, full_link_path) + return full_link_path + + def _AssertEndsWith(self, a, b, msg=None): + """Assert that string a ends with string b.""" + if not a.endswith(b): + standard_msg = '%s does not end with %s' % (a, b) + self.fail(self._formatMessage(msg, standard_msg)) + + def testSimpleLoadedModuleFromSuffix(self): + # Lookup simple module. + _AddSysModule('m1', '/a/b/p1/m1.pyc') + for suffix in [ + 'm1.py', 'm1.pyc', 'm1.pyo', 'p1/m1.py', 'b/p1/m1.py', 'a/b/p1/m1.py', + '/a/b/p1/m1.py' + ]: + m1 = module_utils.GetLoadedModuleBySuffix(suffix) + self.assertTrue(m1, 'Module not found') + self.assertEqual('/a/b/p1/m1.pyc', m1.__file__) + + # Lookup simple package, no ext. + _AddSysModule('p1', '/a/b/p1/__init__.pyc') + for suffix in [ + 'p1/__init__.py', 'b/p1/__init__.py', 'a/b/p1/__init__.py', + '/a/b/p1/__init__.py' + ]: + p1 = module_utils.GetLoadedModuleBySuffix(suffix) + self.assertTrue(p1, 'Package not found') + self.assertEqual('/a/b/p1/__init__.pyc', p1.__file__) + + # Lookup via bad suffix. + for suffix in [ + 'm2.py', 'p2/m1.py', 'b2/p1/m1.py', 'a2/b/p1/m1.py', '/a2/b/p1/m1.py' + ]: + m1 = module_utils.GetLoadedModuleBySuffix(suffix) + self.assertFalse(m1, 'Module found unexpectedly') + + def testComplexLoadedModuleFromSuffix(self): + # Lookup complex module. 
+ _AddSysModule('b.p1.m1', '/a/b/p1/m1.pyc') + for suffix in [ + 'm1.py', 'p1/m1.py', 'b/p1/m1.py', 'a/b/p1/m1.py', '/a/b/p1/m1.py' + ]: + m1 = module_utils.GetLoadedModuleBySuffix(suffix) + self.assertTrue(m1, 'Module not found') + self.assertEqual('/a/b/p1/m1.pyc', m1.__file__) + + # Lookup complex package, no ext. + _AddSysModule('a.b.p1', '/a/b/p1/__init__.pyc') + for suffix in [ + 'p1/__init__.py', 'b/p1/__init__.py', 'a/b/p1/__init__.py', + '/a/b/p1/__init__.py' + ]: + p1 = module_utils.GetLoadedModuleBySuffix(suffix) + self.assertTrue(p1, 'Package not found') + self.assertEqual('/a/b/p1/__init__.pyc', p1.__file__) + + def testSimilarLoadedModuleFromSuffix(self): + # Lookup similar module, no ext. + _AddSysModule('m1', '/a/b/p2/m1.pyc') + _AddSysModule('p1.m1', '/a/b1/p1/m1.pyc') + _AddSysModule('b.p1.m1', '/a1/b/p1/m1.pyc') + _AddSysModule('a.b.p1.m1', '/a/b/p1/m1.pyc') + + m1 = module_utils.GetLoadedModuleBySuffix('/a/b/p1/m1.py') + self.assertTrue(m1, 'Module not found') + self.assertEqual('/a/b/p1/m1.pyc', m1.__file__) + + # Lookup similar package, no ext. + _AddSysModule('p1', '/a1/b1/p1/__init__.pyc') + _AddSysModule('b.p1', '/a1/b/p1/__init__.pyc') + _AddSysModule('a.b.p1', '/a/b/p1/__init__.pyc') + p1 = module_utils.GetLoadedModuleBySuffix('/a/b/p1/__init__.py') + self.assertTrue(p1, 'Package not found') + self.assertEqual('/a/b/p1/__init__.pyc', p1.__file__) + + def testDuplicateLoadedModuleFromSuffix(self): + # Lookup name dup module and package. + _AddSysModule('m1', '/m1/__init__.pyc') + _AddSysModule('m1.m1', '/m1/m1.pyc') + _AddSysModule('m1.m1.m1', '/m1/m1/m1/__init__.pyc') + _AddSysModule('m1.m1.m1.m1', '/m1/m1/m1/m1.pyc') + + # Ambiguous request, multiple modules might have matched. + m1 = module_utils.GetLoadedModuleBySuffix('/m1/__init__.py') + self.assertTrue(m1, 'Package not found') + self.assertIn(m1.__file__, ['/m1/__init__.pyc', '/m1/m1/m1/__init__.pyc']) + + # Ambiguous request, multiple modules might have matched. 
+ m1m1 = module_utils.GetLoadedModuleBySuffix('/m1/m1.py') + self.assertTrue(m1m1, 'Module not found') + self.assertIn(m1m1.__file__, ['/m1/m1.pyc', '/m1/m1/m1/m1.pyc']) + + # Not ambiguous. Only 1 match possible. + m1m1m1 = module_utils.GetLoadedModuleBySuffix('/m1/m1/m1/__init__.py') + self.assertTrue(m1m1m1, 'Package not found') + self.assertEqual('/m1/m1/m1/__init__.pyc', m1m1m1.__file__) + + # Not ambiguous. Only 1 match possible. + m1m1m1m1 = module_utils.GetLoadedModuleBySuffix('/m1/m1/m1/m1.py') + self.assertTrue(m1m1m1m1, 'Module not found') + self.assertEqual('/m1/m1/m1/m1.pyc', m1m1m1m1.__file__) + + def testMainLoadedModuleFromSuffix(self): + # Lookup complex module. + _AddSysModule('__main__', '/a/b/p/m.pyc') + m1 = module_utils.GetLoadedModuleBySuffix('/a/b/p/m.py') + self.assertTrue(m1, 'Module not found') + self.assertEqual('/a/b/p/m.pyc', m1.__file__) + + def testMainWithDotSlashLoadedModuleFromSuffix(self): + # Lookup module started via 'python3 ./m.py', notice the './' + _AddSysModule('__main__', '/a/b/p/./m.pyc') + m1 = module_utils.GetLoadedModuleBySuffix('/a/b/p/m.py') + self.assertIsNotNone(m1) + self.assertTrue(m1, 'Module not found') + self.assertEqual('/a/b/p/./m.pyc', m1.__file__) + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/native_module_test.py b/tests/py/native_module_test.py new file mode 100644 index 0000000..b3b486b --- /dev/null +++ b/tests/py/native_module_test.py @@ -0,0 +1,300 @@ +"""Unit tests for native module.""" + +import inspect +import sys +import threading +import time + +from absl.testing import absltest + +from googleclouddebugger import cdbg_native as native +import python_test_util + + +def _DoHardWork(base): + for i in range(base): + if base * i < 0: + return True + return False + + +class NativeModuleTest(absltest.TestCase): + """Unit tests for native module.""" + + def setUp(self): + # Lock for thread safety. + self._lock = threading.Lock() + + # Count hit count for the breakpoints we set. 
+ self._breakpoint_counter = 0 + + # Registers breakpoint events other than breakpoint hit. + self._breakpoint_events = [] + + # Keep track of breakpoints we set to reset them on cleanup. + self._cookies = [] + + def tearDown(self): + # Verify that we didn't get any breakpoint events that the test did + # not expect. + self.assertEqual([], self._PopBreakpointEvents()) + + self._ClearAllBreakpoints() + + def testUnconditionalBreakpoint(self): + + def Trigger(): + unused_lock = threading.Lock() + print('Breakpoint trigger') # BPTAG: UNCONDITIONAL_BREAKPOINT + + self._SetBreakpoint(Trigger, 'UNCONDITIONAL_BREAKPOINT') + Trigger() + self.assertEqual(1, self._breakpoint_counter) + + def testConditionalBreakpoint(self): + + def Trigger(): + d = {} + for i in range(1, 10): + d[i] = i**2 # BPTAG: CONDITIONAL_BREAKPOINT + + self._SetBreakpoint(Trigger, 'CONDITIONAL_BREAKPOINT', 'i % 3 == 1') + Trigger() + self.assertEqual(3, self._breakpoint_counter) + + def testClearBreakpoint(self): + """Set two breakpoint on the same line, then clear one.""" + + def Trigger(): + print('Breakpoint trigger') # BPTAG: CLEAR_BREAKPOINT + + self._SetBreakpoint(Trigger, 'CLEAR_BREAKPOINT') + self._SetBreakpoint(Trigger, 'CLEAR_BREAKPOINT') + native.ClearConditionalBreakpoint(self._cookies.pop()) + Trigger() + self.assertEqual(1, self._breakpoint_counter) + + def testMissingModule(self): + + def Test(): + native.CreateConditionalBreakpoint(None, 123123, None, + self._BreakpointEvent) + + self.assertRaises(TypeError, Test) + + def testBadModule(self): + + def Test(): + native.CreateConditionalBreakpoint('str', 123123, None, + self._BreakpointEvent) + + self.assertRaises(TypeError, Test) + + def testInvalidCondition(self): + + def Test(): + native.CreateConditionalBreakpoint(sys.modules[__name__], 123123, '2+2', + self._BreakpointEvent) + + self.assertRaises(TypeError, Test) + + def testMissingCallback(self): + + def Test(): + native.CreateConditionalBreakpoint('code.py', 123123, None, None) + + 
self.assertRaises(TypeError, Test) + + def testInvalidCallback(self): + + def Test(): + native.CreateConditionalBreakpoint('code.py', 123123, None, {}) + + self.assertRaises(TypeError, Test) + + def testMissingCookie(self): + self.assertRaises(TypeError, + lambda: native.ClearConditionalBreakpoint(None)) + + def testInvalidCookie(self): + native.ClearConditionalBreakpoint(387873457) + + def testMutableCondition(self): + + def Trigger(): + + def MutableMethod(): + self._evil = True + return True + + print('MutableMethod = %s' % MutableMethod) # BPTAG: MUTABLE_CONDITION + + self._SetBreakpoint(Trigger, 'MUTABLE_CONDITION', 'MutableMethod()') + Trigger() + self.assertEqual([native.BREAKPOINT_EVENT_CONDITION_EXPRESSION_MUTABLE], + self._PopBreakpointEvents()) + + def testGlobalConditionQuotaExceeded(self): + + def Trigger(): + print('Breakpoint trigger') # BPTAG: GLOBAL_CONDITION_QUOTA + + self._SetBreakpoint(Trigger, 'GLOBAL_CONDITION_QUOTA', '_DoHardWork(1000)') + Trigger() + self._ClearAllBreakpoints() + + self.assertListEqual( + [native.BREAKPOINT_EVENT_GLOBAL_CONDITION_QUOTA_EXCEEDED], + self._PopBreakpointEvents()) + + # Sleep for some time to let the quota recover. + time.sleep(0.1) + + def testBreakpointConditionQuotaExceeded(self): + + def Trigger(): + print('Breakpoint trigger') # BPTAG: PER_BREAKPOINT_CONDITION_QUOTA + + time.sleep(1) + + # Per-breakpoint quota is lower than the global one. Exponentially + # increase the complexity of a condition until we hit it. + base = 100 + while True: + self._SetBreakpoint(Trigger, 'PER_BREAKPOINT_CONDITION_QUOTA', + '_DoHardWork(%d)' % base) + Trigger() + self._ClearAllBreakpoints() + + events = self._PopBreakpointEvents() + if events: + self.assertEqual( + [native.BREAKPOINT_EVENT_BREAKPOINT_CONDITION_QUOTA_EXCEEDED], + events) + break + + base *= 1.2 + time.sleep(0.1) + + # Sleep for some time to let the quota recover. 
+ time.sleep(0.1) + + def testImmutableCallSuccess(self): + + def Add(a, b, c): + return a + b + c + + def Magic(): + return 'cake' + + self.assertEqual('643535', + self._CallImmutable(inspect.currentframe(), 'str(643535)')) + self.assertEqual( + 786 + 23 + 891, + self._CallImmutable(inspect.currentframe(), 'Add(786, 23, 891)')) + self.assertEqual('cake', + self._CallImmutable(inspect.currentframe(), 'Magic()')) + return Add or Magic + + def testImmutableCallMutable(self): + + def Change(): + dictionary['bad'] = True + + dictionary = {} + frame = inspect.currentframe() + self.assertRaises(SystemError, + lambda: self._CallImmutable(frame, 'Change()')) + self.assertEqual({}, dictionary) + return Change + + def testImmutableCallExceptionPropagation(self): + + def Divide(a, b): + return a / b + + frame = inspect.currentframe() + self.assertRaises(ZeroDivisionError, + lambda: self._CallImmutable(frame, 'Divide(1, 0)')) + return Divide + + def testImmutableCallInvalidFrame(self): + self.assertRaises(TypeError, lambda: native.CallImmutable(None, lambda: 1)) + self.assertRaises(TypeError, + lambda: native.CallImmutable('not a frame', lambda: 1)) + + def testImmutableCallInvalidCallable(self): + frame = inspect.currentframe() + self.assertRaises(TypeError, lambda: native.CallImmutable(frame, None)) + self.assertRaises(TypeError, + lambda: native.CallImmutable(frame, 'not a callable')) + + def _SetBreakpoint(self, method, tag, condition=None): + """Sets a breakpoint in this source file. + + The line number is identified by tag. This function does not verify that + the source line is in the specified method. + + The breakpoint may have an optional condition. + + Args: + method: method in which the breakpoint will be set. + tag: label for a source line. + condition: optional breakpoint condition. 
+ """ + unused_path, line = python_test_util.ResolveTag(type(self), tag) + + compiled_condition = None + if condition is not None: + compiled_condition = compile(condition, '', 'eval') + + cookie = native.CreateConditionalBreakpoint(method.__code__, line, + compiled_condition, + self._BreakpointEvent) + + self._cookies.append(cookie) + native.ActivateConditionalBreakpoint(cookie) + + def _ClearAllBreakpoints(self): + """Removes all previously set breakpoints.""" + for cookie in self._cookies: + native.ClearConditionalBreakpoint(cookie) + + def _CallImmutable(self, frame, expression): + """Wrapper over native.ImmutableCall for callable.""" + return native.CallImmutable(frame, + compile(expression, '', 'eval')) + + def _BreakpointEvent(self, event, frame): + """Callback on breakpoint event. + + See thread_breakpoints.h for more details of possible events. + + Args: + event: breakpoint event (see kIntegerConstants in native_module.cc). + frame: Python stack frame of breakpoint hit or None for other events. 
+ """ + with self._lock: + if event == native.BREAKPOINT_EVENT_HIT: + self.assertTrue(inspect.isframe(frame)) + self._breakpoint_counter += 1 + else: + self._breakpoint_events.append(event) + + def _PopBreakpointEvents(self): + """Gets and resets the list of breakpoint events received so far.""" + with self._lock: + events = self._breakpoint_events + self._breakpoint_events = [] + return events + + def _HasBreakpointEvents(self): + """Checks whether there are unprocessed breakpoint events.""" + with self._lock: + if self._breakpoint_events: + return True + return False + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/python_breakpoint_test.py b/tests/py/python_breakpoint_test.py new file mode 100644 index 0000000..6aff9c4 --- /dev/null +++ b/tests/py/python_breakpoint_test.py @@ -0,0 +1,652 @@ +"""Unit test for python_breakpoint module.""" + +from datetime import datetime +from datetime import timedelta +import inspect +import os +import sys +import tempfile + +from absl.testing import absltest + +from googleclouddebugger import cdbg_native as native +from googleclouddebugger import imphook +from googleclouddebugger import python_breakpoint +import python_test_util + + +class PythonBreakpointTest(absltest.TestCase): + """Unit test for python_breakpoint module.""" + + def setUp(self): + self._test_package_dir = tempfile.mkdtemp('', 'package_') + sys.path.append(self._test_package_dir) + + path, line = python_test_util.ResolveTag(type(self), 'CODE_LINE') + + self._base_time = datetime(year=2015, month=1, day=1) # BPTAG: CODE_LINE + self._template = { + 'id': 'BP_ID', + 'createTime': python_test_util.DateTimeToTimestamp(self._base_time), + 'location': { + 'path': path, + 'line': line + } + } + self._completed = set() + self._update_queue = [] + + def tearDown(self): + sys.path.remove(self._test_package_dir) + + def CompleteBreakpoint(self, breakpoint_id): + """Mock method of BreakpointsManager.""" + self._completed.add(breakpoint_id) + + def 
GetCurrentTime(self): + """Mock method of BreakpointsManager.""" + return self._base_time + + def EnqueueBreakpointUpdate(self, breakpoint): + """Mock method of HubClient.""" + self._update_queue.append(breakpoint) + + def testClear(self): + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint.Clear() + self.assertFalse(breakpoint._cookie) + + def testId(self): + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint.Clear() + self.assertEqual('BP_ID', breakpoint.GetBreakpointId()) + + def testNullBytesInCondition(self): + python_breakpoint.PythonBreakpoint( + dict(self._template, condition='\0'), self, self, None) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['status']['isError']) + self.assertTrue(self._update_queue[0]['isFinalState']) + + # Test only applies to the old module search algorithm. When using new module + # search algorithm, this test is same as testDeferredBreakpoint. 
+ def testUnknownModule(self): + pass + + def testDeferredBreakpoint(self): + with open(os.path.join(self._test_package_dir, 'defer_print.py'), 'w') as f: + f.write('def DoPrint():\n') + f.write(' print("Hello from deferred module")\n') + + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': 'defer_print.py', + 'line': 2 + }), self, self, None) + + self.assertFalse(self._completed) + self.assertEmpty(self._update_queue) + + import defer_print # pylint: disable=g-import-not-at-top + defer_print.DoPrint() + + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertGreater(len(self._update_queue[0]['stackFrames']), 3) + self.assertEqual('DoPrint', + self._update_queue[0]['stackFrames'][0]['function']) + self.assertTrue(self._update_queue[0]['isFinalState']) + + self.assertEmpty(imphook._import_callbacks) + + # Old module search algorithm rejects multiple matches. This test verifies + # that the new module search algorithm searches sys.path sequentially, and + # selects the first match (just like the Python importer). + def testSearchUsingSysPathOrder(self): + for i in range(2, 0, -1): + # Create directories and add them to sys.path. + test_dir = os.path.join(self._test_package_dir, ('inner2_%s' % i)) + os.mkdir(test_dir) + sys.path.append(test_dir) + with open(os.path.join(test_dir, 'mod2.py'), 'w') as f: + f.write('def DoPrint():\n') + f.write(' x = %s\n' % i) + f.write(' return x') + + # Loads inner2_2/mod2.py because it comes first in sys.path. + import mod2 # pylint: disable=g-import-not-at-top + + # Search will proceed in sys.path order, and the first match in sys.path + # will uniquely identify the full path of the module as inner2_2/mod2.py. 
+ python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': 'mod2.py', + 'line': 3 + }), self, self, None) + + self.assertEqual(2, mod2.DoPrint()) + + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertGreater(len(self._update_queue[0]['stackFrames']), 3) + self.assertEqual('DoPrint', + self._update_queue[0]['stackFrames'][0]['function']) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertEqual( + 'x', self._update_queue[0]['stackFrames'][0]['locals'][0]['name']) + self.assertEqual( + '2', self._update_queue[0]['stackFrames'][0]['locals'][0]['value']) + + self.assertEmpty(imphook._import_callbacks) + + # Old module search algorithm rejects multiple matches. This test verifies + # that when the new module search cannot find any match in sys.path, it + # defers the breakpoint, and then selects the first dynamically-loaded + # module that matches the given path. + def testMultipleDeferredMatches(self): + for i in range(2, 0, -1): + # Create packages, but do not add them to sys.path. + test_dir = os.path.join(self._test_package_dir, ('inner3_%s' % i)) + os.mkdir(test_dir) + with open(os.path.join(test_dir, '__init__.py'), 'w') as f: + pass + with open(os.path.join(test_dir, 'defer_print3.py'), 'w') as f: + f.write('def DoPrint():\n') + f.write(' x = %s\n' % i) + f.write(' return x') + + # This breakpoint will be deferred. It can match any one of the modules + # created above. + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': 'defer_print3.py', + 'line': 3 + }), self, self, None) + + # Lazy import module. Activates breakpoint on the loaded module. 
+ import inner3_1.defer_print3 # pylint: disable=g-import-not-at-top + self.assertEqual(1, inner3_1.defer_print3.DoPrint()) + + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertGreater(len(self._update_queue[0]['stackFrames']), 3) + self.assertEqual('DoPrint', + self._update_queue[0]['stackFrames'][0]['function']) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertEqual( + 'x', self._update_queue[0]['stackFrames'][0]['locals'][0]['name']) + self.assertEqual( + '1', self._update_queue[0]['stackFrames'][0]['locals'][0]['value']) + + self.assertEmpty(imphook._import_callbacks) + + def testNeverLoadedBreakpoint(self): + open(os.path.join(self._test_package_dir, 'never_print.py'), 'w').close() + + breakpoint = python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': 'never_print.py', + 'line': 99 + }), self, self, None) + breakpoint.Clear() + + self.assertFalse(self._completed) + self.assertEmpty(self._update_queue) + + def testDeferredNoCodeAtLine(self): + open(os.path.join(self._test_package_dir, 'defer_empty.py'), 'w').close() + + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': 'defer_empty.py', + 'line': 10 + }), self, self, None) + + self.assertFalse(self._completed) + self.assertEmpty(self._update_queue) + + import defer_empty # pylint: disable=g-import-not-at-top,unused-variable + + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + status = self._update_queue[0]['status'] + self.assertEqual(status['isError'], True) + self.assertEqual(status['refersTo'], 'BREAKPOINT_SOURCE_LOCATION') + desc = status['description'] + self.assertEqual(desc['format'], 'No code found at line $0 in $1') + params = desc['parameters'] + self.assertIn('defer_empty.py', params[1]) + self.assertEqual(params[0], '10') + self.assertEmpty(imphook._import_callbacks) + + 
def testDeferredBreakpointCancelled(self): + open(os.path.join(self._test_package_dir, 'defer_cancel.py'), 'w').close() + + breakpoint = python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': 'defer_cancel.py', + 'line': 11 + }), self, self, None) + breakpoint.Clear() + + self.assertFalse(self._completed) + self.assertEmpty(imphook._import_callbacks) + unused_no_code_line_above = 0 # BPTAG: NO_CODE_LINE_ABOVE + + # BPTAG: NO_CODE_LINE + def testNoCodeAtLine(self): + unused_no_code_line_below = 0 # BPTAG: NO_CODE_LINE_BELOW + path, line = python_test_util.ResolveTag(sys.modules[__name__], + 'NO_CODE_LINE') + path, line_above = python_test_util.ResolveTag(sys.modules[__name__], + 'NO_CODE_LINE_ABOVE') + path, line_below = python_test_util.ResolveTag(sys.modules[__name__], + 'NO_CODE_LINE_BELOW') + + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': path, + 'line': line + }), self, self, None) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + status = self._update_queue[0]['status'] + self.assertEqual(status['isError'], True) + self.assertEqual(status['refersTo'], 'BREAKPOINT_SOURCE_LOCATION') + desc = status['description'] + self.assertEqual(desc['format'], + 'No code found at line $0 in $1. 
Try lines $2 or $3.') + params = desc['parameters'] + self.assertEqual(params[0], str(line)) + self.assertIn(path, params[1]) + self.assertEqual(params[2], str(line_above)) + self.assertEqual(params[3], str(line_below)) + + def testBadExtension(self): + for path in ['unknown.so', 'unknown', 'unknown.java', 'unknown.pyc']: + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': path, + 'line': 83 + }), self, self, None) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertEqual( + { + 'isError': True, + 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', + 'description': { + 'format': ('Only files with .py extension are supported') + } + }, self._update_queue[0]['status']) + self._update_queue = [] + + def testRootInitFile(self): + for path in [ + '__init__.py', '/__init__.py', '////__init__.py', ' __init__.py ', + ' //__init__.py' + ]: + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': path, + 'line': 83 + }), self, self, None) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertEqual( + { + 'isError': True, + 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', + 'description': { + 'format': 'Multiple modules matching $0. ' + 'Please specify the module path.', + 'parameters': ['__init__.py'] + } + }, self._update_queue[0]['status']) + self._update_queue = [] + + # Old module search algorithm rejects because there are too many matches. + # The new algorithm selects the very first match in sys.path. + def testNonRootInitFile(self): + # Neither 'a' nor 'a/b' are real packages accessible via sys.path. + # Therefore, module search falls back to search '__init__.py', which matches + # the first entry in sys.path, which we artifically inject below. 
+ test_dir = os.path.join(self._test_package_dir, 'inner4') + os.mkdir(test_dir) + with open(os.path.join(test_dir, '__init__.py'), 'w') as f: + f.write('def DoPrint():\n') + f.write(' print("Hello")') + sys.path.insert(0, test_dir) + + import inner4 # pylint: disable=g-import-not-at-top,unused-variable + + for path in ['/a/__init__.py', 'a/__init__.py', 'a/b/__init__.py']: + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': path, + 'line': 2 + }), self, self, None) + + inner4.DoPrint() + + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertGreater(len(self._update_queue[0]['stackFrames']), 3) + self.assertEqual('DoPrint', + self._update_queue[0]['stackFrames'][0]['function']) + + self.assertEmpty(imphook._import_callbacks) + self._update_queue = [] + + def testBreakpointInLoadedPackageFile(self): + """Test breakpoint in a loaded package.""" + for name in ['pkg', 'pkg/pkg']: + test_dir = os.path.join(self._test_package_dir, name) + os.mkdir(test_dir) + with open(os.path.join(test_dir, '__init__.py'), 'w') as f: + f.write('def DoPrint():\n') + f.write(' print("Hello from %s")\n' % name) + + import pkg # pylint: disable=g-import-not-at-top,unused-variable + import pkg.pkg # pylint: disable=g-import-not-at-top,unused-variable + + python_breakpoint.PythonBreakpoint( + dict( + self._template, location={ + 'path': 'pkg/pkg/__init__.py', + 'line': 2 + }), self, self, None) + + pkg.pkg.DoPrint() + + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertEqual(None, self._update_queue[0].get('status')) + self._update_queue = [] + + def testInternalError(self): + """Simulate internal error when setting a new breakpoint. + + Bytecode rewriting breakpoints are not supported for methods with more + than 65K constants. 
We generate such a method and try to set breakpoint in + it. + """ + + with open(os.path.join(self._test_package_dir, 'intern_err.py'), 'w') as f: + f.write('def DoSums():\n') + f.write(' x = 0\n') + for i in range(70000): + f.write(' x = x + %d\n' % i) + f.write(' print(x)\n') + + import intern_err # pylint: disable=g-import-not-at-top,unused-variable + + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': 'intern_err.py', + 'line': 100 + }), self, self, None) + + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertEqual( + { + 'isError': True, + 'description': { + 'format': 'Internal error occurred' + } + }, self._update_queue[0]['status']) + + def testInvalidCondition(self): + python_breakpoint.PythonBreakpoint( + dict(self._template, condition='2+'), self, self, None) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + if sys.version_info.minor < 10: + self.assertEqual( + { + 'isError': True, + 'refersTo': 'BREAKPOINT_CONDITION', + 'description': { + 'format': 'Expression could not be compiled: $0', + 'parameters': ['unexpected EOF while parsing'] + } + }, self._update_queue[0]['status']) + else: + self.assertEqual( + { + 'isError': True, + 'refersTo': 'BREAKPOINT_CONDITION', + 'description': { + 'format': 'Expression could not be compiled: $0', + 'parameters': ['invalid syntax'] + } + }, self._update_queue[0]['status']) + + def testHit(self): + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint._BreakpointEvent(native.BREAKPOINT_EVENT_HIT, + inspect.currentframe()) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertGreater(len(self._update_queue[0]['stackFrames']), 3) + self.assertTrue(self._update_queue[0]['isFinalState']) + + def testHitNewTimestamp(self): + # Override to use the new 
format (i.e., without the '.%f' sub-second part) + self._template['createTime'] = python_test_util.DateTimeToTimestampNew( + self._base_time) + + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint._BreakpointEvent(native.BREAKPOINT_EVENT_HIT, + inspect.currentframe()) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertGreater(len(self._update_queue[0]['stackFrames']), 3) + self.assertTrue(self._update_queue[0]['isFinalState']) + + def testHitTimestampUnixMsec(self): + # Using the Snapshot Debugger (Firebase backend) version of creation time + self._template.pop('createTime', None); + self._template[ + 'createTimeUnixMsec'] = python_test_util.DateTimeToUnixMsec( + self._base_time) + + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint._BreakpointEvent(native.BREAKPOINT_EVENT_HIT, + inspect.currentframe()) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertGreater(len(self._update_queue[0]['stackFrames']), 3) + self.assertTrue(self._update_queue[0]['isFinalState']) + + def testDoubleHit(self): + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint._BreakpointEvent(native.BREAKPOINT_EVENT_HIT, + inspect.currentframe()) + breakpoint._BreakpointEvent(native.BREAKPOINT_EVENT_HIT, + inspect.currentframe()) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + + def testEndToEndUnconditional(self): + + def Trigger(): + pass # BPTAG: E2E_UNCONDITIONAL + + path, line = python_test_util.ResolveTag(type(self), 'E2E_UNCONDITIONAL') + breakpoint = python_breakpoint.PythonBreakpoint( + { + 'id': 'BP_ID', + 'location': { + 'path': path, + 'line': line + } + }, self, self, None) + self.assertEmpty(self._update_queue) + Trigger() + self.assertLen(self._update_queue, 1) + breakpoint.Clear() + + def 
testEndToEndConditional(self): + + def Trigger(): + for i in range(2): + self.assertLen(self._update_queue, i) # BPTAG: E2E_CONDITIONAL + + path, line = python_test_util.ResolveTag(type(self), 'E2E_CONDITIONAL') + breakpoint = python_breakpoint.PythonBreakpoint( + { + 'id': 'BP_ID', + 'location': { + 'path': path, + 'line': line + }, + 'condition': 'i == 1' + }, self, self, None) + Trigger() + breakpoint.Clear() + + def testEndToEndCleared(self): + path, line = python_test_util.ResolveTag(type(self), 'E2E_CLEARED') + breakpoint = python_breakpoint.PythonBreakpoint( + { + 'id': 'BP_ID', + 'location': { + 'path': path, + 'line': line + } + }, self, self, None) + breakpoint.Clear() + self.assertEmpty(self._update_queue) # BPTAG: E2E_CLEARED + + def testBreakpointCancellationEvent(self): + events = [ + native.BREAKPOINT_EVENT_GLOBAL_CONDITION_QUOTA_EXCEEDED, + native.BREAKPOINT_EVENT_BREAKPOINT_CONDITION_QUOTA_EXCEEDED, + native.BREAKPOINT_EVENT_CONDITION_EXPRESSION_MUTABLE + ] + for event in events: + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, + self, None) + breakpoint._BreakpointEvent(event, None) + self.assertLen(self._update_queue, 1) + self.assertEqual(set(['BP_ID']), self._completed) + + self._update_queue = [] + self._completed = set() + + def testExpirationTime(self): + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint.Clear() + self.assertEqual( + datetime(year=2015, month=1, day=2), breakpoint.GetExpirationTime()) + + def testExpirationTimeUnixMsec(self): + # Using the Snapshot Debugger (Firebase backend) version of creation time + self._template.pop('createTime', None); + self._template[ + 'createTimeUnixMsec'] = python_test_util.DateTimeToUnixMsec( + self._base_time) + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint.Clear() + self.assertEqual( + self._base_time + timedelta(hours=24), breakpoint.GetExpirationTime()) + + def 
testExpirationTimeWithExpiresIn(self): + definition = self._template.copy() + definition['expires_in'] = { + 'seconds': 300 # 5 minutes + } + + breakpoint = python_breakpoint.PythonBreakpoint(definition, self, self, + None) + breakpoint.Clear() + self.assertEqual( + datetime(year=2015, month=1, day=2), breakpoint.GetExpirationTime()) + + def testExpiration(self): + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint.ExpireBreakpoint() + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertEqual( + { + 'isError': True, + 'refersTo': 'BREAKPOINT_AGE', + 'description': { + 'format': 'The snapshot has expired' + } + }, self._update_queue[0]['status']) + + def testLogpointExpiration(self): + definition = self._template.copy() + definition['action'] = 'LOG' + breakpoint = python_breakpoint.PythonBreakpoint(definition, self, self, + None) + breakpoint.ExpireBreakpoint() + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertEqual( + { + 'isError': True, + 'refersTo': 'BREAKPOINT_AGE', + 'description': { + 'format': 'The logpoint has expired' + } + }, self._update_queue[0]['status']) + + def testNormalizePath(self): + # Removes leading '/' character. + for path in ['/__init__.py', '//__init__.py', '////__init__.py']: + self.assertEqual('__init__.py', python_breakpoint._NormalizePath(path)) + + # Removes leading and trailing whitespace. + for path in [' __init__.py', '__init__.py ', ' __init__.py ']: + self.assertEqual('__init__.py', python_breakpoint._NormalizePath(path)) + + # Removes combination of leading/trailing whitespace and '/' character. 
+ for path in [' /__init__.py', ' ///__init__.py', '////__init__.py']: + self.assertEqual('__init__.py', python_breakpoint._NormalizePath(path)) + + # Normalizes the relative path. + for path in [ + ' ./__init__.py', '././__init__.py', ' .//abc/../__init__.py', + ' ///abc///..///def/..////__init__.py' + ]: + self.assertEqual('__init__.py', python_breakpoint._NormalizePath(path)) + + # Does not remove non-leading, non-trailing space, or non-leading '/' + # characters. + self.assertEqual( + 'foo bar/baz/__init__.py', + python_breakpoint._NormalizePath('/foo bar/baz/__init__.py')) + self.assertEqual( + 'foo/bar baz/__init__.py', + python_breakpoint._NormalizePath('/foo/bar baz/__init__.py')) + self.assertEqual( + 'foo/bar/baz/__in it__.py', + python_breakpoint._NormalizePath('/foo/bar/baz/__in it__.py')) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/python_test_util.py b/tests/py/python_test_util.py new file mode 100644 index 0000000..ffac4da --- /dev/null +++ b/tests/py/python_test_util.py @@ -0,0 +1,189 @@ +"""Set of helper methods for Python debuglet unit and component tests.""" + +import inspect +import re + + +def GetModuleInfo(obj): + """Gets the source file path and breakpoint tags for a module. + + Breakpoint tag is a named label of a source line. The tag is marked + with "# BPTAG: XXX" comment. + + Args: + obj: any object inside the queried module. + + Returns: + (path, tags) tuple where tags is a dictionary mapping tag name to + line numbers where this tag appears. + """ + return (inspect.getsourcefile(obj), GetSourceFileTags(obj)) + + +def GetSourceFileTags(source): + """Gets breakpoint tags for the specified source file. + + Breakpoint tag is a named label of a source line. The tag is marked + with "# BPTAG: XXX" comment. + + Args: + source: either path to the .py file to analyze or any code related + object (e.g. module, function, code object). + + Returns: + Dictionary mapping tag name to line numbers where this tag appears. 
+ """ + if isinstance(source, str): + lines = open(source, 'r').read().splitlines() + start_line = 1 # line number is 1 based + else: + lines, start_line = inspect.getsourcelines(source) + if not start_line: # "getsourcelines" returns start_line of 0 for modules. + start_line = 1 + + tags = {} + regex = re.compile(r'# BPTAG: ([0-9a-zA-Z_]+)\s*$') + for n, line in enumerate(lines): + m = regex.search(line) + if m: + tag = m.group(1) + if tag in tags: + tags[tag].append(n + start_line) + else: + tags[tag] = [n + start_line] + + return tags + + +def ResolveTag(obj, tag): + """Resolves the breakpoint tag into source file path and a line number. + + Breakpoint tag is a named label of a source line. The tag is marked + with "# BPTAG: XXX" comment. + + Raises + + Args: + obj: any object inside the queried module. + tag: tag name to resolve. + + Raises: + Exception: if no line in the source file define the specified tag or if + more than one line define the tag. + + Returns: + (path, line) tuple, where line is the line number where the tag appears. + """ + path, tags = GetModuleInfo(obj) + if tag not in tags: + raise Exception('tag %s not found' % tag) + lines = tags[tag] + if len(lines) != 1: + raise Exception('tag %s is ambiguous (lines: %s)' % (tag, lines)) + return path, lines[0] + + +def DateTimeToTimestamp(t): + """Converts the specified time to Timestamp format. + + Args: + t: datetime instance + + Returns: + Time in Timestamp format + """ + return t.strftime('%Y-%m-%dT%H:%M:%S.%f') + 'Z' + + +def DateTimeToTimestampNew(t): + """Converts the specified time to Timestamp format in seconds granularity. 
+ + Args: + t: datetime instance + + Returns: + Time in Timestamp format in seconds granularity + """ + return t.strftime('%Y-%m-%dT%H:%M:%S') + 'Z' + +def DateTimeToUnixMsec(t): + """Returns the Unix time as in integer value in milliseconds""" + return int(t.timestamp() * 1000) + + +def PackFrameVariable(breakpoint, name, frame=0, collection='locals'): + """Finds local variable or argument by name. + + Indirections created through varTableIndex are recursively collapsed. Fails + the test case if the named variable is not found. + + Args: + breakpoint: queried breakpoint. + name: name of the local variable or argument. + frame: stack frame index to examine. + collection: 'locals' to get local variable or 'arguments' for an argument. + + Returns: + Single dictionary of variable data. + + Raises: + AssertionError: if the named variable not found. + """ + for variable in breakpoint['stackFrames'][frame][collection]: + if variable['name'] == name: + return _Pack(variable, breakpoint) + + raise AssertionError('Variable %s not found in frame %d collection %s' % + (name, frame, collection)) + + +def PackWatchedExpression(breakpoint, expression): + """Finds watched expression by index. + + Indirections created through varTableIndex are recursively collapsed. Fails + the test case if the named variable is not found. + + Args: + breakpoint: queried breakpoint. + expression: index of the watched expression. + + Returns: + Single dictionary of variable data. + """ + return _Pack(breakpoint['evaluatedExpressions'][expression], breakpoint) + + +def _Pack(variable, breakpoint): + """Recursively collapses indirections created through varTableIndex. + + Circular references by objects are not supported. If variable subtree + has circular references, this function will hang. + + Variable members are sorted by name. This helps asserting the content of + variable since Python has no guarantees over the order of keys of a + dictionary. + + Args: + variable: variable object to pack. 
Not modified. + breakpoint: queried breakpoint. + + Returns: + A new dictionary with packed variable object. + """ + packed = dict(variable) + + while 'varTableIndex' in packed: + ref = breakpoint['variableTable'][packed['varTableIndex']] + assert 'name' not in ref + assert 'value' not in packed + assert 'members' not in packed + assert 'status' not in ref and 'status' not in packed + del packed['varTableIndex'] + packed.update(ref) + + if 'members' in packed: + packed['members'] = sorted( + [_Pack(m, breakpoint) for m in packed['members']], + key=lambda m: m.get('name', '')) + + return packed diff --git a/tests/py/uniquifier_computer_test.py b/tests/py/uniquifier_computer_test.py new file mode 100644 index 0000000..3d382b7 --- /dev/null +++ b/tests/py/uniquifier_computer_test.py @@ -0,0 +1,121 @@ +"""Unit test for uniquifier_computer module.""" + +import os +import sys +import tempfile + +from absl.testing import absltest + +from googleclouddebugger import uniquifier_computer + + +class UniquifierComputerTest(absltest.TestCase): + + def _Compute(self, files): + """Creates a directory structure and computes uniquifier on it. + + Args: + files: dictionary of relative path to file content. + + Returns: + Uniquifier data lines. 
+ """ + + class Hash(object): + """Fake implementation of hash to collect raw data.""" + + def __init__(self): + self.data = b'' + + def update(self, s): + self.data += s + + root = tempfile.mkdtemp('', 'fake_app_') + for relative_path, content in files.items(): + path = os.path.join(root, relative_path) + directory = os.path.split(path)[0] + if not os.path.exists(directory): + os.makedirs(directory) + with open(path, 'w') as f: + f.write(content) + + sys.path.insert(0, root) + try: + hash_obj = Hash() + uniquifier_computer.ComputeApplicationUniquifier(hash_obj) + return [ + u.decode() for u in ( + hash_obj.data.rstrip(b'\n').split(b'\n') if hash_obj.data else []) + ] + finally: + del sys.path[0] + + def testEmpty(self): + self.assertListEqual([], self._Compute({})) + + def testBundle(self): + self.assertListEqual([ + 'first.py:1', 'in1/__init__.py:6', 'in1/a.py:3', 'in1/b.py:4', + 'in1/in2/__init__.py:7', 'in1/in2/c.py:5', 'second.py:2' + ], + self._Compute({ + 'db.app': 'abc', + 'first.py': 'a', + 'second.py': 'bb', + 'in1/a.py': 'ccc', + 'in1/b.py': 'dddd', + 'in1/in2/c.py': 'eeeee', + 'in1/__init__.py': 'ffffff', + 'in1/in2/__init__.py': 'ggggggg' + })) + + def testEmptyFile(self): + self.assertListEqual(['empty.py:0'], self._Compute({'empty.py': ''})) + + def testNonPythonFilesIgnored(self): + self.assertListEqual(['real.py:1'], + self._Compute({ + 'file.p': '', + 'file.pya': '', + 'real.py': '1' + })) + + def testNonPackageDirectoriesIgnored(self): + self.assertListEqual(['dir2/__init__.py:1'], + self._Compute({ + 'dir1/file.py': '', + 'dir2/__init__.py': 'a', + 'dir2/image.gif': '' + })) + + def testDepthLimit(self): + self.assertListEqual([ + ''.join(str(n) + '/' + for n in range(1, m + 1)) + '__init__.py:%d' % m + for m in range(9, 0, -1) + ], + self._Compute({ + '1/__init__.py': '1', + '1/2/__init__.py': '2' * 2, + '1/2/3/__init__.py': '3' * 3, + '1/2/3/4/__init__.py': '4' * 4, + '1/2/3/4/5/__init__.py': '5' * 5, + '1/2/3/4/5/6/__init__.py': '6' * 6, + 
'1/2/3/4/5/6/7/__init__.py': '7' * 7, + '1/2/3/4/5/6/7/8/__init__.py': '8' * 8, + '1/2/3/4/5/6/7/8/9/__init__.py': '9' * 9, + '1/2/3/4/5/6/7/8/9/10/__init__.py': 'a' * 10, + '1/2/3/4/5/6/7/8/9/10/11/__init__.py': 'b' * 11 + })) + + def testPrecedence(self): + self.assertListEqual(['my.py:3'], + self._Compute({ + 'my.pyo': 'a', + 'my.pyc': 'aa', + 'my.py': 'aaa' + })) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/yaml_data_visibility_config_reader_test.py b/tests/py/yaml_data_visibility_config_reader_test.py new file mode 100644 index 0000000..65d3cd0 --- /dev/null +++ b/tests/py/yaml_data_visibility_config_reader_test.py @@ -0,0 +1,121 @@ +"""Tests for yaml_data_visibility_config_reader.""" + +import os +import sys +from unittest import mock + +from io import StringIO + +from absl.testing import absltest +from googleclouddebugger import yaml_data_visibility_config_reader + + +class StringIOOpen(object): + """An open for StringIO that supports "with" semantics. + + I tried using mock.mock_open, but the read logic in the yaml.load code is + incompatible with the returned mock object, leading to a test hang/timeout. + """ + + def __init__(self, data): + self.file_obj = StringIO(data) + + def __enter__(self): + return self.file_obj + + def __exit__(self, type, value, traceback): # pylint: disable=redefined-builtin + pass + + +class YamlDataVisibilityConfigReaderTest(absltest.TestCase): + + def testOpenAndReadSuccess(self): + data = """ + blacklist: + - bl1 + """ + path_prefix = 'googleclouddebugger.' + with mock.patch( + path_prefix + 'yaml_data_visibility_config_reader.open', + create=True) as m: + m.return_value = StringIOOpen(data) + config = yaml_data_visibility_config_reader.OpenAndRead() + m.assert_called_with( + os.path.join(sys.path[0], 'debugger-blacklist.yaml'), 'r') + self.assertEqual(config.blacklist_patterns, ['bl1']) + + def testOpenAndReadFileNotFound(self): + path_prefix = 'googleclouddebugger.' 
+ with mock.patch( + path_prefix + 'yaml_data_visibility_config_reader.open', + create=True, + side_effect=IOError('IO Error')): + f = yaml_data_visibility_config_reader.OpenAndRead() + self.assertIsNone(f) + + def testReadDataSuccess(self): + data = """ + blacklist: + - bl1 + - bl2 + whitelist: + - wl1 + - wl2.* + """ + + config = yaml_data_visibility_config_reader.Read(StringIO(data)) + self.assertItemsEqual(config.blacklist_patterns, ('bl1', 'bl2')) + self.assertItemsEqual(config.whitelist_patterns, ('wl1', 'wl2.*')) + + def testYAMLLoadError(self): + + class ErrorIO(object): + + def read(self, size): + del size # Unused + raise IOError('IO Error') + + with self.assertRaises(yaml_data_visibility_config_reader.YAMLLoadError): + yaml_data_visibility_config_reader.Read(ErrorIO()) + + def testBadYamlSyntax(self): + data = """ + blacklist: whitelist: + """ + + with self.assertRaises(yaml_data_visibility_config_reader.ParseError): + yaml_data_visibility_config_reader.Read(StringIO(data)) + + def testUnknownConfigKeyError(self): + data = """ + foo: + - bar + """ + + with self.assertRaises( + yaml_data_visibility_config_reader.UnknownConfigKeyError): + yaml_data_visibility_config_reader.Read(StringIO(data)) + + def testNotAListError(self): + data = """ + blacklist: + foo: + - bar + """ + + with self.assertRaises(yaml_data_visibility_config_reader.NotAListError): + yaml_data_visibility_config_reader.Read(StringIO(data)) + + def testElementNotAStringError(self): + data = """ + blacklist: + - 5 + """ + + with self.assertRaises( + yaml_data_visibility_config_reader.ElementNotAStringError): + yaml_data_visibility_config_reader.Read(StringIO(data)) + + +if __name__ == '__main__': + absltest.main()