diff --git a/.gitignore b/.gitignore index 139597f..ab9464c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,8 @@ - - +/dist/ +/src/build/ +/src/dist/ +/src/setup.cfg +__pycache__/ +*.egg-info/ +.coverage +/bazel-* diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..be7798e --- /dev/null +++ b/.pylintrc @@ -0,0 +1,429 @@ +# This Pylint rcfile contains a best-effort configuration to uphold the +# best-practices and style described in the Google Python style guide: +# https://google.github.io/styleguide/pyguide.html +# +# Its canonical open-source location is: +# https://google.github.io/styleguide/pylintrc + +[MASTER] + +# Files or directories to be skipped. They should be base names, not paths. +ignore=third_party + +# Files or directories matching the regex patterns are skipped. The regex +# matches against base names, not paths. +ignore-patterns= + +# Pickle collected data for later comparisons. +persistent=no + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=4 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +#enable= + +# Disable the message, report, category or checker with the given id(s). 
You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=abstract-method, + apply-builtin, + arguments-differ, + attribute-defined-outside-init, + backtick, + bad-option-value, + basestring-builtin, + buffer-builtin, + c-extension-no-member, + consider-using-enumerate, + cmp-builtin, + cmp-method, + coerce-builtin, + coerce-method, + delslice-method, + div-method, + duplicate-code, + eq-without-hash, + execfile-builtin, + file-builtin, + filter-builtin-not-iterating, + fixme, + getslice-method, + global-statement, + hex-method, + idiv-method, + implicit-str-concat, + import-error, + import-self, + import-star-module-level, + inconsistent-return-statements, + input-builtin, + intern-builtin, + invalid-str-codec, + locally-disabled, + long-builtin, + long-suffix, + map-builtin-not-iterating, + misplaced-comparison-constant, + missing-function-docstring, + metaclass-assignment, + next-method-called, + next-method-defined, + no-absolute-import, + no-else-break, + no-else-continue, + no-else-raise, + no-else-return, + no-init, # added + no-member, + no-name-in-module, + no-self-use, + nonzero-method, + oct-method, + old-division, + old-ne-operator, + old-octal-literal, + old-raise-syntax, + parameter-unpacking, + print-statement, + raising-string, + range-builtin-not-iterating, + raw_input-builtin, + rdiv-method, + reduce-builtin, + relative-import, + reload-builtin, + round-builtin, + setslice-method, + signature-differs, + standarderror-builtin, + 
suppressed-message, + sys-max-int, + too-few-public-methods, + too-many-ancestors, + too-many-arguments, + too-many-boolean-expressions, + too-many-branches, + too-many-instance-attributes, + too-many-locals, + too-many-nested-blocks, + too-many-public-methods, + too-many-return-statements, + too-many-statements, + trailing-newlines, + unichr-builtin, + unicode-builtin, + unnecessary-pass, + unpacking-in-except, + useless-else-on-loop, + useless-object-inheritance, + useless-suppression, + using-cmp-argument, + wrong-import-order, + xrange-builtin, + zip-builtin-not-iterating, + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[BASIC] + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# List of decorators that produce properties, such as abc.abstractproperty. 
Add +# to this list to register other decorators that produce valid properties. +property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl + +# Regular expression matching correct function names +function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression matching correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct constant names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression matching correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression matching correct module names +module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ + +# Regular expression matching correct method names +method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=10 + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. 
Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=80 + +# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt +# lines made too long by directives to pytype. + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x)( + ^\s*(\#\ )??$| + ^\s*(from\s+\S+\s+)?import\s+.+$) + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=yes + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. The internal Google style guide mandates 2 +# spaces. Google's externaly-published style guide says 4, consistent with +# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google +# projects (like TensorFlow). 
+indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=TODO + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=yes + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging,absl.logging,tensorflow.io.logging + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. 
+spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub, + TERMIOS, + Bastion, + rexec, + sets + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant, absl + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls, + class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. 
Defaults to +# "Exception" +overgeneral-exceptions=StandardError, + Exception, + BaseException diff --git a/.style.yapf b/.style.yapf new file mode 100644 index 0000000..fdd0723 --- /dev/null +++ b/.style.yapf @@ -0,0 +1,2 @@ +[style] +based_on_style = yapf diff --git a/BUILD b/BUILD new file mode 100644 index 0000000..ae821f1 --- /dev/null +++ b/BUILD @@ -0,0 +1,2 @@ +package(default_visibility = ["//visibility:public"]) + diff --git a/README.md b/README.md index 836a7ee..9f171b4 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,30 @@ -# Python Cloud Debugger Agent +# Python Snapshot Debugger Agent + +[Snapshot debugger](https://github.com/GoogleCloudPlatform/snapshot-debugger/) +agent for Python 3.6, Python 3.7, Python 3.8, Python 3.9, and Python 3.10. + + +## Project Status: Archived + +This project has been archived and is no longer supported. There will be no +further bug fixes or security patches. The repository can be forked by users +if they want to maintain it going forward. -Google [Cloud Debugger](https://cloud.google.com/debugger/) for Python 2.7, -Python 3.6, Python 3.7, Python 3.8 and Python 3.9. ## Overview -Cloud Debugger (also known as Stackdriver Debugger) lets you inspect the state +Snapshot Debugger lets you inspect the state of a running cloud application, at any code location, without stopping or slowing it down. It is not your traditional process debugger but rather an always on, whole app debugger taking snapshots from any instance of the app. -Cloud Debugger is safe for use with production apps or during development. The +Snapshot Debugger is safe for use with production apps or during development. The Python debugger agent only few milliseconds to the request latency when a debug snapshot is captured. In most cases, this is not noticeable to users. Furthermore, the Python debugger agent does not allow modification of application state in any way, and has close to zero impact on the app instances. 
-Cloud Debugger attaches to all instances of the app providing the ability to +Snapshot Debugger attaches to all instances of the app providing the ability to take debug snapshots and add logpoints. A snapshot captures the call-stack and variables from any one instance that executes the snapshot location. A logpoint writes a formatted message to the application log whenever any instance of the @@ -25,26 +33,16 @@ app executes the logpoint location. The Python debugger agent is only supported on Linux at the moment. It was tested on Debian Linux, but it should work on other distributions as well. -Cloud Debugger consists of 3 primary components: +Snapshot Debugger consists of 3 primary components: -1. The Python debugger agent (this repo implements one for CPython 2.7, 3.6, - 3.7, 3.8 and 3.9). -2. Cloud Debugger service storing and managing snapshots/logpoints. Explore the - APIs using - [APIs Explorer](https://cloud.google.com/debugger/api/reference/rest/). +1. The Python debugger agent (this repo implements one for CPython 3.6, + 3.7, 3.8, 3.9, and 3.10). +2. A Firebase Realtime Database for storing and managing snapshots/logpoints. + Explore the + [schema](https://github.com/GoogleCloudPlatform/snapshot-debugger/blob/main/docs/SCHEMA.md). 3. User interface, including a command line interface - [`gcloud debug`](https://cloud.google.com/sdk/gcloud/reference/debug/) and a - Web interface on - [Google Cloud Console](https://console.cloud.google.com/debug/). See the - [online help](https://cloud.google.com/debugger/docs/using/snapshots) on how - to use Google Cloud Console Debug page. - -## Getting Help - -1. StackOverflow: - http://stackoverflow.com/questions/tagged/google-cloud-debugger -2. Send email to: [Cloud Debugger Feedback](mailto:cdbg-feedback@google.com) -3. 
Send Feedback from Google Cloud Console + [`snapshot-dbg-cli`](https://pypi.org/project/snapshot-dbg-cli/) and a + [VSCode extension](https://github.com/GoogleCloudPlatform/snapshot-debugger/tree/main/snapshot_dbg_extension) ## Installation @@ -54,21 +52,13 @@ The easiest way to install the Python Cloud Debugger is with PyPI: pip install google-python-cloud-debugger ``` -Alternatively, download the *egg* package from -[Releases](https://github.com/GoogleCloudPlatform/cloud-debug-python/releases) -and install the debugger agent with: - -```shell -easy_install google_python_cloud_debugger-py2.7-linux-x86_64.egg -``` - You can also build the agent from source code: ```shell git clone https://github.com/GoogleCloudPlatform/cloud-debug-python.git cd cloud-debug-python/src/ ./build.sh -easy_install dist/google_python_cloud_debugger-*.egg +pip install dist/google_python_cloud_debugger-*.whl ``` Note that the build script assumes some dependencies. To install these @@ -77,17 +67,11 @@ dependencies on Debian, run this command: ```shell sudo apt-get -y -q --no-install-recommends install \ curl ca-certificates gcc build-essential cmake \ - python python-dev libpython2.7 python-setuptools + python3 python3-dev python3-pip ``` -### Python 3 - -There is support for Python 3.6, Python 3.7, Python 3.8 and Python 3.9. Python -3.0 to 3.5 are not supported, and newer versions have not been tested. - -To build for Python 3.x (x in [6-8]), the `python3.x` and `python3.x-dev` -packages are additionally needed. If Python 3.x is not the default version of -the 'python' command on your system, run the build script as `PYTHON=python3.x +If the desired target version of Python is not the default version of +the 'python3' command on your system, run the build script as `PYTHON=python3.x ./build.sh`. ### Alpine Linux @@ -100,22 +84,13 @@ minimal image with the agent installed. ### Google Cloud Platform -1. 
First, make sure that you created the VM with this option enabled: - - > Allow API access to all Google Cloud services in the same project. - - This option lets the Python debugger agent authenticate with the machine - account of the Virtual Machine. - - It is possible to use the Python debugger agent without it. Please see the - [next section](#outside-google-cloud-platform) for details. +1. First, make sure that the VM has the + [required scopes](https://github.com/GoogleCloudPlatform/snapshot-debugger/blob/main/docs/configuration.md#access-scopes). 2. Install the Python debugger agent as explained in the [Installation](#installation) section. -3. Enable the debugger in your application using one of the two options: - - _Option A_: add this code to the beginning of your `main()` function: +3. Enable the debugger in your application: ```python # Attach Python Cloud Debugger @@ -126,20 +101,7 @@ minimal image with the agent installed. pass ``` - _Option B_: run the debugger agent as a module: - -
-    python \
-        -m googleclouddebugger --module=[MODULE] --version=[VERSION] -- \
-        myapp.py
-    
- - **Note:** This option does not work well with tools such as - `multiprocessing` or `gunicorn`. These tools spawn workers in separate - processes, but the debugger does not get enabled on these worker processes. - Please use _Option A_ instead. - - Where, in both cases: + Where: * `[MODULE]` is the name of your app. This, along with the version, is used to identify the debug target in the UI.
@@ -174,7 +136,7 @@ account. 1. Use the Google Cloud Console Service Accounts [page](https://console.cloud.google.com/iam-admin/serviceaccounts/project) to create a credentials file for an existing or new service account. The - service account must have at least the `Stackdriver Debugger Agent` role. + service account must have at least the `roles/firebasedatabase.admin` role. 2. Once you have the service account credentials JSON file, deploy it alongside the Python debugger agent. @@ -188,8 +150,6 @@ account. Alternatively, you can provide the path to the credentials file directly to the debugger agent. - _Option A_: - ```python # Attach Python Cloud Debugger try: @@ -201,19 +161,6 @@ account. except ImportError: pass ``` - - _Option B_: - -
-    python \
-        -m googleclouddebugger \
-        --module=[MODULE] \
-        --version=[VERSION] \
-        --service_account_json_file=/path/to/credentials.json \
-        -- \
-        myapp.py
-    
- 4. Follow the rest of the steps in the [GCP](#google-cloud-platform) section. ### Django Web Framework @@ -238,6 +185,16 @@ Alternatively, you can pass the `--noreload` flag when running the Django using the `--noreload` flag disables the autoreload feature in Django, which means local changes to files will not be automatically picked up by Django. +## Historical note + +Version 3.x of this agent supported both the now shutdown Cloud Debugger service +(by default) and the +[Snapshot Debugger](https://github.com/GoogleCloudPlatform/snapshot-debugger/) +(Firebase RTDB backend) by setting the `use_firebase` flag to true. Version 4.0 +removed support for the Cloud Debugger service, making the Snapshot Debugger the +default. Note that the `use_firebase` flag is now obsolete, but still present for +backward compatibility. + ## Flag Reference The agent offers various flags to configure its behavior. Flags can be specified @@ -272,8 +229,59 @@ which are automatically available on machines hosted on GCP, or can be set via `gcloud auth application-default login` or the `GOOGLE_APPLICATION_CREDENTIALS` environment variable. -`breakpoint_enable_canary`: Whether to enable the -[breakpoint canary feature](https://cloud.google.com/debugger/docs/using/snapshots#with_canarying). -It expects a boolean value (`True`/`False`) or a string, with `'True'` -interpreted as `True` and any other string interpreted as `False`). If not -provided, the breakpoint canarying will not be enabled. +`firebase_db_url`: URL pointing to a configured Firebase Realtime Database for +the agent to use to store snapshot data. +https://**PROJECT_ID**-cdbg.firebaseio.com will be used if not provided, where +**PROJECT_ID** is your project ID. + +## Development + +The following instructions are intended to help with modifying the codebase. 
+ +### Testing + +#### Unit tests + +Run the `build_and_test.sh` script from the root of the repository to build and +run the unit tests using the locally installed version of Python. + +Run `bazel test tests/cpp:all` from the root of the repository to run unit +tests against the C++ portion of the codebase. + +#### Local development + +You may want to run an agent with local changes in an application in order to +validate functionality in a way that unit tests don't fully cover. To do this, +you will need to build the agent: +``` +cd src +./build.sh +cd .. +``` + +The built agent will be available in the `src/dist` directory. You can now +force the installation of the agent using: +``` +pip3 install src/dist/* --force-reinstall +``` + +You can now run your test application using the development build of the agent +in whatever way you desire. + +It is recommended that you do this within a +[virtual environment](https://docs.python.org/3/library/venv.html). + +### Build & Release (for project owners) + +Before performing a release, be sure to update the version number in +`src/googleclouddebugger/version.py`. Tag the commit that increments the +version number (e.g. `v3.1`) and create a GitHub release. + +Run the `build-dist.sh` script from the root of the repository to build, +test, and generate the distribution wheels. You may need to use `sudo` +depending on your system's Docker setup. 
+ +Build artifacts will be placed in `/dist` and can be pushed to pypi by running: +``` +twine upload dist/*.whl +``` diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 0000000..55013f2 --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,50 @@ +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "bazel_skylib", + sha256 = "74d544d96f4a5bb630d465ca8bbcfe231e3594e5aae57e1edbf17a6eb3ca2506", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz", + "https://github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz", + ], +) +load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") +bazel_skylib_workspace() + +http_archive( + name = "com_github_gflags_gflags", + sha256 = "34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf", + strip_prefix = "gflags-2.2.2", + urls = ["https://github.com/gflags/gflags/archive/v2.2.2.tar.gz"], +) + +http_archive( + name = "com_github_google_glog", + sha256 = "21bc744fb7f2fa701ee8db339ded7dce4f975d0d55837a97be7d46e8382dea5a", + strip_prefix = "glog-0.5.0", + urls = ["https://github.com/google/glog/archive/v0.5.0.zip"], +) + +# Pinning to 1.12.1, the last release that supports C++11 +http_archive( + name = "com_google_googletest", + urls = ["https://github.com/google/googletest/archive/58d77fa8070e8cec2dc1ed015d66b454c8d78850.tar.gz"], + strip_prefix = "googletest-58d77fa8070e8cec2dc1ed015d66b454c8d78850", +) + +# Used to build against Python.h +http_archive( + name = "pybind11_bazel", + strip_prefix = "pybind11_bazel-faf56fb3df11287f26dbc66fdedf60a2fc2c6631", + urls = ["https://github.com/pybind/pybind11_bazel/archive/faf56fb3df11287f26dbc66fdedf60a2fc2c6631.zip"], +) + +http_archive( + name = "pybind11", + build_file = "@pybind11_bazel//:pybind11.BUILD", + strip_prefix = "pybind11-2.9.2", + urls = ["https://github.com/pybind/pybind11/archive/v2.9.2.tar.gz"], +) 
+load("@pybind11_bazel//:python_configure.bzl", "python_configure") +python_configure(name = "local_config_python")#, python_interpreter_target = interpreter) + diff --git a/build-dist.sh b/build-dist.sh new file mode 100755 index 0000000..fffe140 --- /dev/null +++ b/build-dist.sh @@ -0,0 +1,4 @@ +DOCKER_IMAGE='quay.io/pypa/manylinux2014_x86_64' + +docker pull "$DOCKER_IMAGE" +docker container run -t --rm -v "$(pwd)":/io "$DOCKER_IMAGE" /io/src/build-wheels.sh diff --git a/build_and_test.sh b/build_and_test.sh new file mode 100755 index 0000000..4ce82b9 --- /dev/null +++ b/build_and_test.sh @@ -0,0 +1,18 @@ +#!/bin/bash -e + +# Clean up any previous generated test files. +rm -rf tests/py/__pycache__ + +cd src +./build.sh +cd .. + +python3 -m venv /tmp/cdbg-venv +source /tmp/cdbg-venv/bin/activate +pip3 install -r requirements_dev.txt +pip3 install src/dist/* --force-reinstall +python3 -m pytest tests/py +deactivate + +# Clean up any generated test files. +rm -rf tests/py/__pycache__ diff --git a/firebase-sample/app.py b/firebase-sample/app.py new file mode 100644 index 0000000..0916e7c --- /dev/null +++ b/firebase-sample/app.py @@ -0,0 +1,12 @@ +import googleclouddebugger + +googleclouddebugger.enable(use_firebase=True) + +from flask import Flask + +app = Flask(__name__) + + +@app.route("/") +def hello_world(): + return "

Hello World!

" diff --git a/firebase-sample/build-and-run.sh b/firebase-sample/build-and-run.sh new file mode 100755 index 0000000..a0cc7b1 --- /dev/null +++ b/firebase-sample/build-and-run.sh @@ -0,0 +1,20 @@ +#!/bin/bash -e + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd "${SCRIPT_DIR}/.." + +cd src +./build.sh +cd .. + +python3 -m venv /tmp/cdbg-venv +source /tmp/cdbg-venv/bin/activate +pip3 install -r requirements.txt +pip3 install src/dist/* --force-reinstall + +cd firebase-sample +pip3 install -r requirements.txt +python3 -m flask run +cd .. + +deactivate diff --git a/firebase-sample/requirements.txt b/firebase-sample/requirements.txt new file mode 100644 index 0000000..7e10602 --- /dev/null +++ b/firebase-sample/requirements.txt @@ -0,0 +1 @@ +flask diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..784eb7e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +firebase_admin>=5.3.0 +pyyaml diff --git a/requirements_dev.txt b/requirements_dev.txt new file mode 100644 index 0000000..89aa308 --- /dev/null +++ b/requirements_dev.txt @@ -0,0 +1,4 @@ +-r requirements.txt +absl-py +pytest +requests-mock diff --git a/src/build-wheels.sh b/src/build-wheels.sh new file mode 100755 index 0000000..8477d84 --- /dev/null +++ b/src/build-wheels.sh @@ -0,0 +1,94 @@ +#!/bin/bash -e + +GFLAGS_URL=https://github.com/gflags/gflags/archive/v2.2.2.tar.gz +GLOG_URL=https://github.com/google/glog/archive/v0.4.0.tar.gz + +SUPPORTED_VERSIONS=(cp36-cp36m cp37-cp37m cp38-cp38 cp39-cp39 cp310-cp310) + +ROOT=$(cd $(dirname "${BASH_SOURCE[0]}") >/dev/null; /bin/pwd -P) + +# Parallelize the build over N threads where N is the number of cores * 1.5. +PARALLEL_BUILD_OPTION="-j $(($(nproc 2> /dev/null || echo 4)*3/2))" + +# Clean up any previous build/test files. 
+rm -rf \ + ${ROOT}/build \ + ${ROOT}/dist \ + ${ROOT}/setup.cfg \ + ${ROOT}/google_python_cloud_debugger.egg-info \ + /io/dist \ + /io/tests/py/__pycache__ + +# Create directory for third-party libraries. +mkdir -p ${ROOT}/build/third_party + +# Build and install gflags to build/third_party. +pushd ${ROOT}/build/third_party +curl -L ${GFLAGS_URL} -o gflags.tar.gz +tar xzvf gflags.tar.gz +cd gflags-* +mkdir build +cd build +cmake -DCMAKE_CXX_FLAGS=-fpic \ + -DGFLAGS_NAMESPACE=google \ + -DCMAKE_INSTALL_PREFIX:PATH=${ROOT}/build/third_party \ + .. +make ${PARALLEL_BUILD_OPTION} +make install +popd + +# Build and install glog to build/third_party. +pushd ${ROOT}/build/third_party +curl -L ${GLOG_URL} -o glog.tar.gz +tar xzvf glog.tar.gz +cd glog-* +mkdir build +cd build +cmake -DCMAKE_CXX_FLAGS=-fpic \ + -DCMAKE_PREFIX_PATH=${ROOT}/build/third_party \ + -DCMAKE_INSTALL_PREFIX:PATH=${ROOT}/build/third_party \ + .. +make ${PARALLEL_BUILD_OPTION} +make install +popd + +# Extract build version from version.py +grep "^ *__version__ *=" "/io/src/googleclouddebugger/version.py" | grep -Eo "[0-9.]+" > "version.txt" +AGENT_VERSION=$(cat "version.txt") +echo "Building distribution packages for python agent version ${AGENT_VERSION}" + +# Create setup.cfg file and point to the third_party libraries we just built. +echo "[global] +verbose=1 + +[build_ext] +include_dirs=${ROOT}/build/third_party/include +library_dirs=${ROOT}/build/third_party/lib:${ROOT}/build/third_party/lib64" > ${ROOT}/setup.cfg + +# Build the Python Cloud Debugger agent. 
+pushd ${ROOT} + +for PY_VERSION in ${SUPPORTED_VERSIONS[@]}; do + echo "Building the ${PY_VERSION} agent" + "/opt/python/${PY_VERSION}/bin/pip" install -r /io/requirements_dev.txt + "/opt/python/${PY_VERSION}/bin/pip" wheel /io/src --no-deps -w /tmp/dist/ + PACKAGE_NAME="google_python_cloud_debugger-${AGENT_VERSION}" + WHL_FILENAME="${PACKAGE_NAME}-${PY_VERSION}-linux_x86_64.whl" + auditwheel repair "/tmp/dist/${WHL_FILENAME}" -w /io/dist/ + + echo "Running tests" + "/opt/python/${PY_VERSION}/bin/pip" install google-python-cloud-debugger --no-index -f /io/dist + "/opt/python/${PY_VERSION}/bin/pytest" /io/tests/py +done + +popd + +# Clean up temporary directories. +rm -rf \ + ${ROOT}/build \ + ${ROOT}/setup.cfg \ + ${ROOT}/google_python_cloud_debugger.egg-info \ + /io/tests/py/__pycache__ + +echo "Build artifacts are in the dist directory" + diff --git a/src/build.sh b/src/build.sh index 7c86c71..f61ef2f 100755 --- a/src/build.sh +++ b/src/build.sh @@ -33,8 +33,8 @@ # Home page of glog: https://github.com/google/glog # -GFLAGS_URL=https://github.com/gflags/gflags/archive/v2.1.2.tar.gz -GLOG_URL=https://github.com/google/glog/archive/v0.3.4.tar.gz +GFLAGS_URL=https://github.com/gflags/gflags/archive/v2.2.2.tar.gz +GLOG_URL=https://github.com/google/glog/archive/v0.4.0.tar.gz ROOT=$(cd $(dirname "${BASH_SOURCE[0]}") >/dev/null; /bin/pwd -P) @@ -42,7 +42,11 @@ ROOT=$(cd $(dirname "${BASH_SOURCE[0]}") >/dev/null; /bin/pwd -P) PARALLEL_BUILD_OPTION="-j $(($(nproc 2> /dev/null || echo 4)*3/2))" # Clean up any previous build files. -rm -rf ${ROOT}/build ${ROOT}/dist ${ROOT}/setup.cfg +rm -rf \ + ${ROOT}/build \ + ${ROOT}/dist \ + ${ROOT}/setup.cfg \ + ${ROOT}/google_python_cloud_debugger.egg-info # Create directory for third-party libraries. 
mkdir -p ${ROOT}/build/third_party @@ -67,9 +71,12 @@ pushd ${ROOT}/build/third_party curl -L ${GLOG_URL} -o glog.tar.gz tar xzvf glog.tar.gz cd glog-* -./configure --with-pic \ - --prefix=${ROOT}/build/third_party \ - --with-gflags=${ROOT}/build/third_party +mkdir build +cd build +cmake -DCMAKE_CXX_FLAGS=-fpic \ + -DCMAKE_PREFIX_PATH=${ROOT}/build/third_party \ + -DCMAKE_INSTALL_PREFIX:PATH=${ROOT}/build/third_party \ + .. make ${PARALLEL_BUILD_OPTION} make install popd @@ -80,11 +87,16 @@ verbose=1 [build_ext] include_dirs=${ROOT}/build/third_party/include -library_dirs=${ROOT}/build/third_party/lib" > ${ROOT}/setup.cfg +library_dirs=${ROOT}/build/third_party/lib:${ROOT}/build/third_party/lib64" > ${ROOT}/setup.cfg # Build the Python Cloud Debugger agent. pushd ${ROOT} # Use custom python command if variable is set -"${PYTHON:-python}" setup.py bdist_egg +"${PYTHON:-python3}" -m pip wheel . --no-deps -w dist popd +# Clean up temporary directories. +rm -rf \ + ${ROOT}/build \ + ${ROOT}/setup.cfg \ + ${ROOT}/google_python_cloud_debugger.egg-info diff --git a/src/googleclouddebugger/BUILD b/src/googleclouddebugger/BUILD new file mode 100644 index 0000000..c0d6ae7 --- /dev/null +++ b/src/googleclouddebugger/BUILD @@ -0,0 +1,103 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "common", + hdrs = ["common.h"], + deps = [ + "@com_github_google_glog//:glog", + "@local_config_python//:python_headers", + ], +) + +cc_library( + name = "nullable", + hdrs = ["nullable.h"], + deps = [ + ":common", + ], +) + +cc_library( + name = "python_util", + srcs = ["python_util.cc"], + hdrs = ["python_util.h"], + deps = [ + ":common", + ":nullable", + "//src/third_party:pylinetable", + ], +) + + +cc_library( + name = "python_callback", + srcs = ["python_callback.cc"], + hdrs = ["python_callback.h"], + deps = [ + ":common", + ":python_util", + ], +) + +cc_library( + name = "leaky_bucket", + srcs = ["leaky_bucket.cc"], + hdrs = ["leaky_bucket.h"], + deps = [ 
+ ":common", + ], +) + +cc_library( + name = "rate_limit", + srcs = ["rate_limit.cc"], + hdrs = ["rate_limit.h"], + deps = [ + ":common", + ":leaky_bucket", + ], +) + +cc_library( + name = "bytecode_manipulator", + srcs = ["bytecode_manipulator.cc"], + hdrs = ["bytecode_manipulator.h"], + deps = [ + ":common", + ], +) + +cc_library( + name = "bytecode_breakpoint", + srcs = ["bytecode_breakpoint.cc"], + hdrs = ["bytecode_breakpoint.h"], + deps = [ + ":bytecode_manipulator", + ":common", + ":python_callback", + ":python_util", + ], +) + +cc_library( + name = "immutability_tracer", + srcs = ["immutability_tracer.cc"], + hdrs = ["immutability_tracer.h"], + deps = [ + ":common", + ":python_util", + ], +) + +cc_library( + name = "conditional_breakpoint", + srcs = ["conditional_breakpoint.cc"], + hdrs = ["conditional_breakpoint.h"], + deps = [ + ":common", + ":immutability_tracer", + ":python_util", + ":rate_limit", + ":leaky_bucket", + ], +) diff --git a/src/googleclouddebugger/__init__.py b/src/googleclouddebugger/__init__.py index f9364eb..378f6a7 100644 --- a/src/googleclouddebugger/__init__.py +++ b/src/googleclouddebugger/__init__.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Main module for Python Cloud Debugger. The debugger is enabled in a very similar way to enabling pdb. @@ -28,9 +27,9 @@ from . import appengine_pretty_printers from . import breakpoints_manager -from . import capture_collector +from . import collector from . import error_data_visibility_policy -from . import gcp_hub_client +from . import firebase_client from . import glob_data_visibility_policy from . import yaml_data_visibility_config_reader from . 
import cdbg_native @@ -39,44 +38,41 @@ __version__ = version.__version__ _flags = None -_hub_client = None +_backend_client = None _breakpoints_manager = None def _StartDebugger(): """Configures and starts the debugger.""" - global _hub_client + global _backend_client global _breakpoints_manager cdbg_native.InitializeModule(_flags) - cdbg_native.LogInfo('Initializing Cloud Debugger Python agent version: %s' % - __version__) + cdbg_native.LogInfo( + f'Initializing Cloud Debugger Python agent version: {__version__}') + + _backend_client = firebase_client.FirebaseClient() + _backend_client.SetupAuth( + _flags.get('project_id'), _flags.get('service_account_json_file'), + _flags.get('firebase_db_url')) - _hub_client = gcp_hub_client.GcpHubClient() visibility_policy = _GetVisibilityPolicy() _breakpoints_manager = breakpoints_manager.BreakpointsManager( - _hub_client, - visibility_policy) + _backend_client, visibility_policy) # Set up loggers for logpoints. - capture_collector.SetLogger(logging.getLogger()) + collector.SetLogger(logging.getLogger()) - capture_collector.CaptureCollector.pretty_printers.append( + collector.CaptureCollector.pretty_printers.append( appengine_pretty_printers.PrettyPrinter) - _hub_client.on_active_breakpoints_changed = ( + _backend_client.on_active_breakpoints_changed = ( _breakpoints_manager.SetActiveBreakpoints) - _hub_client.on_idle = _breakpoints_manager.CheckBreakpointsExpiration - _hub_client.SetupAuth( - _flags.get('project_id'), - _flags.get('project_number'), - _flags.get('service_account_json_file')) - _hub_client.SetupCanaryMode( - _flags.get('breakpoint_enable_canary'), - _flags.get('breakpoint_allow_canary_override')) - _hub_client.InitializeDebuggeeLabels(_flags) - _hub_client.Start() + _backend_client.on_idle = _breakpoints_manager.CheckBreakpointsExpiration + + _backend_client.InitializeDebuggeeLabels(_flags) + _backend_client.Start() def _GetVisibilityPolicy(): @@ -85,7 +81,7 @@ def _GetVisibilityPolicy(): visibility_config = 
yaml_data_visibility_config_reader.OpenAndRead() except yaml_data_visibility_config_reader.Error as err: return error_data_visibility_policy.ErrorDataVisibilityPolicy( - 'Could not process debugger config: %s' % err) + f'Could not process debugger config: {err}') if visibility_config: return glob_data_visibility_policy.GlobDataVisibilityPolicy( @@ -121,16 +117,18 @@ def _DebuggerMain(): sys.path[0] = os.path.dirname(app_path) - import __main__ # pylint: disable=g-import-not-at-top + import __main__ # pylint: disable=import-outside-toplevel __main__.__dict__.clear() - __main__.__dict__.update({'__name__': '__main__', - '__file__': app_path, - '__builtins__': __builtins__}) + __main__.__dict__.update({ + '__name__': '__main__', + '__file__': app_path, + '__builtins__': __builtins__ + }) locals = globals = __main__.__dict__ # pylint: disable=redefined-builtin sys.modules['__main__'] = __main__ - with open(app_path) as f: + with open(app_path, encoding='utf-8') as f: code = compile(f.read(), app_path, 'exec') exec(code, globals, locals) # pylint: disable=exec-used diff --git a/src/googleclouddebugger/__main__.py b/src/googleclouddebugger/__main__.py index 1f55572..edfe6c0 100644 --- a/src/googleclouddebugger/__main__.py +++ b/src/googleclouddebugger/__main__.py @@ -11,10 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """Entry point for Python Cloud Debugger.""" # pylint: disable=invalid-name if __name__ == '__main__': import googleclouddebugger googleclouddebugger._DebuggerMain() - diff --git a/src/googleclouddebugger/appengine_pretty_printers.py b/src/googleclouddebugger/appengine_pretty_printers.py index 036caad..3908990 100644 --- a/src/googleclouddebugger/appengine_pretty_printers.py +++ b/src/googleclouddebugger/appengine_pretty_printers.py @@ -11,11 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Formatters for well known objects that don't show up nicely by default.""" -import six - try: from protorpc import messages # pylint: disable=g-import-not-at-top except ImportError: @@ -31,7 +28,7 @@ def PrettyPrinter(obj): """Pretty printers for AppEngine objects.""" if ndb and isinstance(obj, ndb.Model): - return six.iteritems(obj.to_dict()), 'ndb.Model(%s)' % type(obj).__name__ + return obj.to_dict().items(), 'ndb.Model(%s)' % type(obj).__name__ if messages and isinstance(obj, messages.Enum): return [('name', obj.name), ('number', obj.number)], type(obj).__name__ diff --git a/src/googleclouddebugger/application_info.py b/src/googleclouddebugger/application_info.py index 9909f37..c920cce 100644 --- a/src/googleclouddebugger/application_info.py +++ b/src/googleclouddebugger/application_info.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Module to fetch information regarding the current application. Some examples of the information the methods in this module fetch are platform @@ -65,11 +64,16 @@ def GetRegion(): # Otherwise try fetching it from the metadata server. 
try: - response = requests.get(_GCP_METADATA_REGION_URL, - headers=_GCP_METADATA_HEADER) + response = requests.get( + _GCP_METADATA_REGION_URL, headers=_GCP_METADATA_HEADER) response.raise_for_status() # Example of response text: projects/id/regions/us-central1. So we strip # everything before the last /. - return response.text.split('/')[-1] + region = response.text.split('/')[-1] + if region == 'html>': + # Sometimes we get an html response! + return None + + return region except requests.exceptions.RequestException: return None diff --git a/src/googleclouddebugger/backoff.py b/src/googleclouddebugger/backoff.py index edc024f..f12237d 100644 --- a/src/googleclouddebugger/backoff.py +++ b/src/googleclouddebugger/backoff.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Implements exponential backoff for retry timeouts.""" diff --git a/src/googleclouddebugger/breakpoints_manager.py b/src/googleclouddebugger/breakpoints_manager.py index 07f4094..e3f0421 100644 --- a/src/googleclouddebugger/breakpoints_manager.py +++ b/src/googleclouddebugger/breakpoints_manager.py @@ -11,14 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Manages lifetime of individual breakpoint objects.""" from datetime import datetime from threading import RLock -import six - from . import python_breakpoint @@ -40,9 +37,7 @@ class BreakpointsManager(object): of a captured variable. May be None if no policy is available. 
""" - def __init__(self, - hub_client, - data_visibility_policy): + def __init__(self, hub_client, data_visibility_policy): self._hub_client = hub_client self.data_visibility_policy = data_visibility_policy @@ -72,19 +67,17 @@ def SetActiveBreakpoints(self, breakpoints_data): ids = set([x['id'] for x in breakpoints_data]) # Clear breakpoints that no longer show up in active breakpoints list. - for breakpoint_id in six.viewkeys(self._active) - ids: + for breakpoint_id in self._active.keys() - ids: self._active.pop(breakpoint_id).Clear() # Create new breakpoints. self._active.update([ (x['id'], - python_breakpoint.PythonBreakpoint( - x, - self._hub_client, - self, - self.data_visibility_policy)) + python_breakpoint.PythonBreakpoint(x, self._hub_client, self, + self.data_visibility_policy)) for x in breakpoints_data - if x['id'] in ids - six.viewkeys(self._active) - self._completed]) + if x['id'] in ids - self._active.keys() - self._completed + ]) # Remove entries from completed_breakpoints_ that weren't listed in # breakpoints_data vector. 
These are confirmed to have been removed by the @@ -119,7 +112,7 @@ def CheckBreakpointsExpiration(self): expired_breakpoints = [] self._next_expiration = datetime.max - for breakpoint in six.itervalues(self._active): + for breakpoint in self._active.values(): expiration_time = breakpoint.GetExpirationTime() if expiration_time <= current_time: expired_breakpoints.append(breakpoint) diff --git a/src/googleclouddebugger/bytecode_breakpoint.cc b/src/googleclouddebugger/bytecode_breakpoint.cc index 40939ee..dd1af6e 100644 --- a/src/googleclouddebugger/bytecode_breakpoint.cc +++ b/src/googleclouddebugger/bytecode_breakpoint.cc @@ -66,7 +66,7 @@ void BytecodeBreakpoint::Detach() { } -int BytecodeBreakpoint::SetBreakpoint( +int BytecodeBreakpoint::CreateBreakpoint( PyCodeObject* code_object, int line, std::function hit_callback, @@ -82,7 +82,7 @@ int BytecodeBreakpoint::SetBreakpoint( // table in case "code_object" is already patched with another breakpoint. CodeObjectLinesEnumerator lines_enumerator( code_object->co_firstlineno, - code_object_breakpoints->original_lnotab.get()); + code_object_breakpoints->original_linedata.get()); while (lines_enumerator.line_number() != line) { if (!lines_enumerator.Next()) { LOG(ERROR) << "Line " << line << " not found in " @@ -102,6 +102,7 @@ int BytecodeBreakpoint::SetBreakpoint( breakpoint->hit_callable = PythonCallback::Wrap(hit_callback); breakpoint->error_callback = error_callback; breakpoint->cookie = cookie; + breakpoint->status = BreakpointStatus::kInactive; code_object_breakpoints->breakpoints.insert( std::make_pair(breakpoint->offset, breakpoint.get())); @@ -109,15 +110,44 @@ int BytecodeBreakpoint::SetBreakpoint( DCHECK(cookie_map_[cookie] == nullptr); cookie_map_[cookie] = breakpoint.release(); - PatchCodeObject(code_object_breakpoints); - return cookie; } +void BytecodeBreakpoint::ActivateBreakpoint(int cookie) { + if (cookie == -1) return; // no-op if invalid cookie. 
+ + auto it_breakpoint = cookie_map_.find(cookie); + if (it_breakpoint == cookie_map_.end()) { + LOG(WARNING) << "Trying to activate a breakpoint with an unknown cookie: " + << cookie; + return; // No breakpoint with this cookie. + } + + auto it_code = patches_.find(it_breakpoint->second->code_object); + if (it_code != patches_.end()) { + CodeObjectBreakpoints* code = it_code->second; + // Ensure that there is a new breakpoint that was added. + if (it_breakpoint->second->status == BreakpointStatus::kInactive) { + // Set breakpoint to active. + it_breakpoint->second->status = BreakpointStatus::kActive; + // Patch code. + PatchCodeObject(code); + } else { + LOG(WARNING) << "Breakpoint with cookie: " << cookie + << " has already been activated"; + } + } else { + LOG(DFATAL) << "Missing code object"; + } +} void BytecodeBreakpoint::ClearBreakpoint(int cookie) { + if (cookie == -1) return; // no-op if invalid cookie + auto it_breakpoint = cookie_map_.find(cookie); if (it_breakpoint == cookie_map_.end()) { + LOG(WARNING) << "Trying to clear a breakpoint with an unknown cookie: " + << cookie; return; // No breakpoint with this cookie. } @@ -141,6 +171,9 @@ void BytecodeBreakpoint::ClearBreakpoint(int cookie) { DCHECK_EQ(1, erase_count); + // Set breakpoint as done, as it was removed from code->breakpoints map. + it_breakpoint->second->status = BreakpointStatus::kDone; + PatchCodeObject(code); if (code->breakpoints.empty() && code->zombie_refs.empty()) { @@ -148,13 +181,22 @@ void BytecodeBreakpoint::ClearBreakpoint(int cookie) { patches_.erase(it_code); } } else { - DCHECK(false) << "Missing code object"; + LOG(DFATAL) << "Missing code object"; } delete it_breakpoint->second; cookie_map_.erase(it_breakpoint); } +BreakpointStatus BytecodeBreakpoint::GetBreakpointStatus(int cookie) { + auto it_breakpoint = cookie_map_.find(cookie); + if (it_breakpoint == cookie_map_.end()) { + // No breakpoint with this cookie. 
+ return BreakpointStatus::kUnknown; + } + + return it_breakpoint->second->status; +} BytecodeBreakpoint::CodeObjectBreakpoints* BytecodeBreakpoint::PreparePatchCodeObject( @@ -195,8 +237,14 @@ BytecodeBreakpoint::PreparePatchCodeObject( return nullptr; // Probably a built-in method or uninitialized code object. } - data->original_lnotab = + // Store the original (unmodified) line data. +#if PY_VERSION_HEX < 0x030A0000 + data->original_linedata = ScopedPyObject::NewReference(code_object.get()->co_lnotab); +#else + data->original_linedata = + ScopedPyObject::NewReference(code_object.get()->co_linetable); +#endif patches_[code_object] = data.get(); return data.release(); @@ -220,29 +268,38 @@ void BytecodeBreakpoint::PatchCodeObject(CodeObjectBreakpoints* code) { << " from patched " << code->zombie_refs.back().get(); Py_INCREF(code_object->co_code); + // Restore the original line data to the code object. +#if PY_VERSION_HEX < 0x030A0000 if (code_object->co_lnotab != nullptr) { code->zombie_refs.push_back(ScopedPyObject(code_object->co_lnotab)); } - code_object->co_lnotab = code->original_lnotab.get(); + code_object->co_lnotab = code->original_linedata.get(); Py_INCREF(code_object->co_lnotab); +#else + if (code_object->co_linetable != nullptr) { + code->zombie_refs.push_back(ScopedPyObject(code_object->co_linetable)); + } + code_object->co_linetable = code->original_linedata.get(); + Py_INCREF(code_object->co_linetable); +#endif return; } std::vector bytecode = PyBytesToByteArray(code->original_code.get()); - bool has_lnotab = false; - std::vector lnotab; - if (!code->original_lnotab.is_null() && - PyBytes_CheckExact(code->original_lnotab.get())) { - has_lnotab = true; - lnotab = PyBytesToByteArray(code->original_lnotab.get()); + bool has_linedata = false; + std::vector linedata; + if (!code->original_linedata.is_null() && + PyBytes_CheckExact(code->original_linedata.get())) { + has_linedata = true; + linedata = PyBytesToByteArray(code->original_linedata.get()); } 
BytecodeManipulator bytecode_manipulator( std::move(bytecode), - has_lnotab, - std::move(lnotab)); + has_linedata, + std::move(linedata)); // Add callbacks to code object constants and patch the bytecode. std::vector callbacks; @@ -254,6 +311,9 @@ void BytecodeBreakpoint::PatchCodeObject(CodeObjectBreakpoints* code) { for (auto it_entry = code->breakpoints.begin(); it_entry != code->breakpoints.end(); ++it_entry, ++const_index) { + // Skip breakpoint if it still hasn't been activated. + if (it_entry->second->status == BreakpointStatus::kInactive) continue; + int offset = it_entry->first; bool offset_found = true; const Breakpoint& breakpoint = *it_entry->second; @@ -261,17 +321,16 @@ void BytecodeBreakpoint::PatchCodeObject(CodeObjectBreakpoints* code) { callbacks.push_back(breakpoint.hit_callable.get()); -#if PY_MAJOR_VERSION >= 3 // In Python 3, since we allow upgrading of instructions to use // EXTENDED_ARG, the offsets for lines originally calculated might not be // accurate, so we need to recalculate them each insertion. 
offset_found = false; - if (bytecode_manipulator.has_lnotab()) { - ScopedPyObject lnotab(PyBytes_FromStringAndSize( - reinterpret_cast(bytecode_manipulator.lnotab().data()), - bytecode_manipulator.lnotab().size())); + if (bytecode_manipulator.has_linedata()) { + ScopedPyObject linedata(PyBytes_FromStringAndSize( + reinterpret_cast(bytecode_manipulator.linedata().data()), + bytecode_manipulator.linedata().size())); CodeObjectLinesEnumerator lines_enumerator(code_object->co_firstlineno, - lnotab.release()); + linedata.release()); while (lines_enumerator.line_number() != breakpoint.line) { if (!lines_enumerator.Next()) { break; @@ -280,13 +339,15 @@ void BytecodeBreakpoint::PatchCodeObject(CodeObjectBreakpoints* code) { } offset_found = lines_enumerator.line_number() == breakpoint.line; } -#endif if (!offset_found || !bytecode_manipulator.InjectMethodCall(offset, const_index)) { LOG(WARNING) << "Failed to insert bytecode for breakpoint " << breakpoint.cookie << " at line " << breakpoint.line; errors.push_back(breakpoint.error_callback); + it_entry->second->status = BreakpointStatus::kError; + } else { + it_entry->second->status = BreakpointStatus::kActive; } } @@ -307,14 +368,26 @@ void BytecodeBreakpoint::PatchCodeObject(CodeObjectBreakpoints* code) { << " reassigned to " << code_object->co_code << ", original was " << code->original_code.get(); - if (has_lnotab) { + // Update the line data in the code object. 
+#if PY_VERSION_HEX < 0x030A0000 + if (has_linedata) { code->zombie_refs.push_back(ScopedPyObject(code_object->co_lnotab)); ScopedPyObject lnotab_string(PyBytes_FromStringAndSize( - reinterpret_cast(bytecode_manipulator.lnotab().data()), - bytecode_manipulator.lnotab().size())); + reinterpret_cast(bytecode_manipulator.linedata().data()), + bytecode_manipulator.linedata().size())); DCHECK(!lnotab_string.is_null()); code_object->co_lnotab = lnotab_string.release(); } +#else + if (has_linedata) { + code->zombie_refs.push_back(ScopedPyObject(code_object->co_linetable)); + ScopedPyObject linetable_string(PyBytes_FromStringAndSize( + reinterpret_cast(bytecode_manipulator.linedata().data()), + bytecode_manipulator.linedata().size())); + DCHECK(!linetable_string.is_null()); + code_object->co_linetable = linetable_string.release(); + } +#endif // Invoke error callback after everything else is done. The callback may // decide to remove the breakpoint, which will change "code". diff --git a/src/googleclouddebugger/bytecode_breakpoint.h b/src/googleclouddebugger/bytecode_breakpoint.h index f7ecccf..5eaa893 100644 --- a/src/googleclouddebugger/bytecode_breakpoint.h +++ b/src/googleclouddebugger/bytecode_breakpoint.h @@ -27,6 +27,43 @@ namespace devtools { namespace cdbg { +// Enum representing the status of a breakpoint. State tracking is helpful +// for testing and debugging the bytecode breakpoints. +// ======================================================================= +// State transition map: +// +// (start) kUnknown +// |- [CreateBreakpoint] +// | +// | +// | [ActivateBreakpoint] [PatchCodeObject] +// v | | +// kInactive ----> kActive <---> kError +// | | | +// |-------| | |-------| +// | | | +// |- |- |- [ClearBreakpoint] +// v v v +// kDone +// +// ======================================================================= +enum class BreakpointStatus { + // Unknown status for the breakpoint + kUnknown = 0, + + // Breakpoint is created and is patched in the bytecode. 
+ kActive, + + // Breakpoint is created but is currently not patched in the bytecode. + kInactive, + + // Breakpoint has been cleared. + kDone, + + // Breakpoint is created but failed to be activated (patched in the bytecode). + kError +}; + // Sets breakpoints in Python code with zero runtime overhead. // BytecodeBreakpoint rewrites Python bytecode to insert a breakpoint. The // implementation is specific to CPython 2.7. @@ -41,21 +78,36 @@ class BytecodeBreakpoint { // Clears all the set breakpoints. void Detach(); - // Sets a new breakpoint in the specified code object. More than one - // breakpoint can be set at the same source location. When the breakpoint - // hits, the "callback" parameter is invoked. Every time this class fails to - // install the breakpoint, "error_callback" is invoked. Returns cookie used - // to clear the breakpoint. - int SetBreakpoint( - PyCodeObject* code_object, - int line, - std::function hit_callback, - std::function error_callback); - - // Removes a previously set breakpoint. If the cookie is invalid, this - // function does nothing. + // Creates a new breakpoint in the specified code object. More than one + // breakpoint can be created at the same source location. When the breakpoint + // hits, the "callback" parameter is invoked. Every time this method fails to + // create the breakpoint, "error_callback" is invoked and a cookie value of + // -1 is returned. If it succeeds in creating the breakpoint, returns the + // unique cookie used to activate and clear the breakpoint. Note this method + // only creates the breakpoint, to activate it you must call + // "ActivateBreakpoint". + int CreateBreakpoint(PyCodeObject* code_object, int line, + std::function hit_callback, + std::function error_callback); + + // Activates a previously created breakpoint. If it fails to set any + // breakpoint, the error callback will be invoked. 
This method is kept + // separate from "CreateBreakpoint" to ensure that the cookie is available + // before the "error_callback" is invoked. Calling this method with a cookie + // value of -1 is a no-op. Note that any breakpoints in the same function that + // previously failed to activate will retry to activate during this call. + // TODO: Provide a method "ActivateAllBreakpoints" to optimize + // the code and patch the code once, instead of multiple times. + void ActivateBreakpoint(int cookie); + + // Removes a previously set breakpoint. Calling this method with a cookie + // value of -1 is a no-op. Note that any breakpoints in the same function that + // previously failed to activate will retry to activate during this call. void ClearBreakpoint(int cookie); + // Get the status of a breakpoint. + BreakpointStatus GetBreakpointStatus(int cookie); + private: // Information about the breakpoint. struct Breakpoint { @@ -77,6 +129,9 @@ class BytecodeBreakpoint { // Breakpoint ID used to clear the breakpoint. int cookie; + + // Status of the breakpoint. + BreakpointStatus status; }; // Set of breakpoints in a particular code object and original data of @@ -107,9 +162,10 @@ class BytecodeBreakpoint { // Original value of PyCodeObject::co_code before patching. ScopedPyObject original_code; - // Original value of PythonCode::co_lnotab before patching. - // "lnotab" stands for "line numbers table" in CPython lingo. - ScopedPyObject original_lnotab; + // Original value of PythonCode::co_lnotab or PythonCode::co_linetable + // before patching. This is the line numbers table in CPython <= 3.9 and + // CPython >= 3.10 respectively + ScopedPyObject original_linedata; }; // Loads code object into "patches_" if not there yet. 
Returns nullptr if diff --git a/src/googleclouddebugger/bytecode_manipulator.cc b/src/googleclouddebugger/bytecode_manipulator.cc index 9c646e3..44cef74 100644 --- a/src/googleclouddebugger/bytecode_manipulator.cc +++ b/src/googleclouddebugger/bytecode_manipulator.cc @@ -36,18 +36,13 @@ enum PythonOpcodeType { // Single Python instruction. // -// In Python 2.7, there are 3 types of instructions: -// 1. Instruction without arguments (takes 1 byte). -// 2. Instruction with a single 16 bit argument (takes 3 bytes). -// 3. Instruction with a 32 bit argument (very uncommon; takes 6 bytes). -// // In Python 3.6, there are 4 types of instructions: // 1. Instructions without arguments, or a 8 bit argument (takes 2 bytes). // 2. Instructions with a 16 bit argument (takes 4 bytes). // 3. Instructions with a 24 bit argument (takes 6 bytes). // 4. Instructions with a 32 bit argument (takes 8 bytes). // -// To handle 32 bit arguments in Python 2, or 16-32 bit arguments in Python 3, +// To handle 16-32 bit arguments in Python 3, // a special instruction with an opcode of EXTENDED_ARG is prepended to the // actual instruction. The argument of the EXTENDED_ARG instruction is combined // with the argument of the next instruction to form the full argument. @@ -68,11 +63,7 @@ static PythonInstruction PythonInstructionNoArg(uint8_t opcode) { instruction.opcode = opcode; instruction.argument = 0; -#if PY_MAJOR_VERSION >= 3 instruction.size = 2; -#else - instruction.size = 1; -#endif return instruction; } @@ -86,7 +77,6 @@ static PythonInstruction PythonInstructionArg(uint8_t opcode, instruction.opcode = opcode; instruction.argument = argument; -#if PY_MAJOR_VERSION >= 3 if (argument <= 0xFF) { instruction.size = 2; } else if (argument <= 0xFFFF) { @@ -96,9 +86,6 @@ static PythonInstruction PythonInstructionArg(uint8_t opcode, } else { instruction.size = 8; } -#else - instruction.size = instruction.argument > 0xFFFF ? 
6 : 3; -#endif return instruction; } @@ -119,9 +106,7 @@ static int GetInstructionsSize( static PythonOpcodeType GetOpcodeType(uint8_t opcode) { switch (opcode) { case YIELD_VALUE: -#if PY_MAJOR_VERSION >= 3 case YIELD_FROM: -#endif return YIELD_OPCODE; case FOR_ITER: @@ -147,6 +132,9 @@ static PythonOpcodeType GetOpcodeType(uint8_t opcode) { #if PY_VERSION_HEX < 0x03080000 // Removed in Python 3.8. case CONTINUE_LOOP: +#endif +#if PY_VERSION_HEX >= 0x03090000 + case JUMP_IF_NOT_EXC_MATCH: #endif return BRANCH_ABSOLUTE_OPCODE; @@ -159,10 +147,18 @@ static PythonOpcodeType GetOpcodeType(uint8_t opcode) { static int GetBranchTarget(int offset, PythonInstruction instruction) { switch (GetOpcodeType(instruction.opcode)) { case BRANCH_DELTA_OPCODE: +#if PY_VERSION_HEX < 0x030A0000 return offset + instruction.size + instruction.argument; +#else + return offset + instruction.size + instruction.argument * 2; +#endif case BRANCH_ABSOLUTE_OPCODE: +#if PY_VERSION_HEX < 0x030A0000 return instruction.argument; +#else + return instruction.argument * 2; +#endif default: DCHECK(false) << "Not a branch instruction"; @@ -171,23 +167,6 @@ static int GetBranchTarget(int offset, PythonInstruction instruction) { } -#if PY_MAJOR_VERSION < 3 -// Reads 16 bit value according to Python bytecode encoding. -static uint16 ReadPythonBytecodeUInt16(std::vector::const_iterator it) { - return it[0] | (static_cast(it[1]) << 8); -} - - -// Writes 16 bit value according to Python bytecode encoding. -static void WritePythonBytecodeUInt16( - std::vector::iterator it, - uint16 data) { - it[0] = static_cast(data); - it[1] = data >> 8; -} -#endif - - // Read instruction at the specified offset. Returns kInvalidInstruction // buffer underflow. 
static PythonInstruction ReadInstruction( @@ -195,7 +174,6 @@ static PythonInstruction ReadInstruction( std::vector::const_iterator it) { PythonInstruction instruction { 0, 0, 0 }; -#if PY_MAJOR_VERSION >= 3 if (bytecode.end() - it < 2) { LOG(ERROR) << "Buffer underflow"; return kInvalidInstruction; @@ -214,39 +192,6 @@ static PythonInstruction ReadInstruction( instruction.opcode = it[0]; instruction.argument = instruction.argument << 8 | it[1]; instruction.size += 2; -#else - if (it == bytecode.end()) { - LOG(ERROR) << "Buffer underflow"; - return kInvalidInstruction; - } - - instruction.opcode = it[0]; - instruction.size = 1; - - auto it_arg = it + 1; - if (instruction.opcode == EXTENDED_ARG) { - if (bytecode.end() - it < 6) { - LOG(ERROR) << "Buffer underflow"; - return kInvalidInstruction; - } - - instruction.opcode = it[3]; - - auto it_ext = it + 4; - instruction.argument = - (static_cast(ReadPythonBytecodeUInt16(it_arg)) << 16) | - ReadPythonBytecodeUInt16(it_ext); - instruction.size = 6; - } else if (HAS_ARG(instruction.opcode)) { - if (bytecode.end() - it < 3) { - LOG(ERROR) << "Buffer underflow"; - return kInvalidInstruction; - } - - instruction.argument = ReadPythonBytecodeUInt16(it_arg); - instruction.size = 3; - } -#endif return instruction; } @@ -256,7 +201,6 @@ static PythonInstruction ReadInstruction( // instruction. 
static int WriteInstruction(std::vector::iterator it, const PythonInstruction& instruction) { -#if PY_MAJOR_VERSION >= 3 uint32_t arg = instruction.argument; int size_written = 0; // Start writing backwards from the real instruction, followed by any @@ -268,29 +212,6 @@ static int WriteInstruction(std::vector::iterator it, size_written += 2; } return size_written; -#else - if (instruction.size == 6) { - it[0] = EXTENDED_ARG; - WritePythonBytecodeUInt16(it + 1, instruction.argument >> 16); - it[3] = instruction.opcode; - WritePythonBytecodeUInt16( - it + 4, - static_cast(instruction.argument)); - return 6; - } else { - it[0] = instruction.opcode; - - if (HAS_ARG(instruction.opcode)) { - DCHECK_LE(instruction.argument, 0xFFFFU); - WritePythonBytecodeUInt16( - it + 1, - static_cast(instruction.argument)); - return 3; - } - - return 1; - } -#endif } // Write set of instructions to the specified destination. @@ -318,11 +239,11 @@ static std::vector BuildMethodCall(int const_index) { } BytecodeManipulator::BytecodeManipulator(std::vector bytecode, - const bool has_lnotab, - std::vector lnotab) - : has_lnotab_(has_lnotab) { + const bool has_linedata, + std::vector linedata) + : has_linedata_(has_linedata) { data_.bytecode = std::move(bytecode); - data_.lnotab = std::move(lnotab); + data_.linedata = std::move(linedata); strategy_ = STRATEGY_INSERT; // Default strategy. for (auto it = data_.bytecode.begin(); it < data_.bytecode.end(); ) { @@ -367,16 +288,6 @@ bool BytecodeManipulator::InjectMethodCall( } -// Use different algorithms to insert method calls for Python 2 and 3. -// Technically the algorithm for Python 3 will work with Python 2, but because -// it is more complicated and the issue of needing to upgrade branch -// instructions to use EXTENDED_ARG is less common, we stick with the existing -// algorithm for better safety. 
- - -#if PY_MAJOR_VERSION >= 3 - - // Represents a branch instruction in the original bytecode that may need to // have its offsets fixed and/or upgraded to use EXTENDED_ARG. struct UpdatedInstruction { @@ -396,21 +307,13 @@ struct Insertion { // InsertAndUpdateBranchInstructions. static const int kMaxInsertionIterations = 10; - +#if PY_VERSION_HEX < 0x030A0000 // Updates the line number table for an insertion in the bytecode. -// This is different than what the Python 2 version of InsertMethodCall() does. -// It should be more accurate, but is confined to Python 3 only for safety. -// This handles the case of adding insertion for EXTENDED_ARG better. // Example for inserting 2 bytes at offset 2: -// lnotab: [{2, 1}, {4, 1}] // {offset_delta, line_delta} -// Old algorithm: [{2, 0}, {2, 1}, {4, 1}] -// New algorithm: [{2, 1}, {6, 1}] -// In the old version, trying to get the offset to insert a breakpoint right -// before line 1 would result in an offset of 2, which is inaccurate as the -// instruction before is an EXTENDED_ARG which will now be applied to the first -// instruction inserted instead of its original target. -static void InsertAndUpdateLnotab(int offset, int size, - std::vector* lnotab) { +// lnotab: [{2, 1}, {4, 1}] // {offset_delta, line_delta} +// updated: [{2, 1}, {6, 1}] +static void InsertAndUpdateLineData(int offset, int size, + std::vector* lnotab) { int current_offset = 0; for (auto it = lnotab->begin(); it != lnotab->end(); it += 2) { current_offset += it[0]; @@ -430,6 +333,36 @@ static void InsertAndUpdateLnotab(int offset, int size, } } } +#else +// Updates the line number table for an insertion in the bytecode. 
+// Example for inserting 2 bytes at offset 2: +// linetable: [{2, 1}, {4, 1}] // {address_end_delta, line_delta} +// updated: [{2, 1}, {6, 1}] +// +// For more information on the linetable format in Python 3.10, see: +// https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt +static void InsertAndUpdateLineData(int offset, int size, + std::vector* linetable) { + int current_offset = 0; + for (auto it = linetable->begin(); it != linetable->end(); it += 2) { + current_offset += it[0]; + + if (current_offset > offset) { + int remaining_size = it[0] + size; + int remaining_lines = it[1]; + it = linetable->erase(it, it + 2); + while (remaining_size > 0xFE) { // Max address delta is listed as 254. + it = linetable->insert(it, 0xFE) + 1; + it = linetable->insert(it, 0) + 1; + remaining_size -= 0xFE; + } + it = linetable->insert(it, remaining_size) + 1; + it = linetable->insert(it, remaining_lines) + 1; + return; + } + } +} +#endif // Reserves space for instructions to be inserted into the bytecode, and // calculates the new offsets and arguments of branch instructions. @@ -506,13 +439,21 @@ static bool InsertAndUpdateBranchInstructions( // argument of 0 even when it is not required. This needs to be taken // into account when calculating the target of a branch instruction. int inst_size = std::max(instruction.size, it->original_size); +#if PY_VERSION_HEX < 0x030A0000 int32_t target = it->current_offset + inst_size + arg; +#else + int32_t target = it->current_offset + inst_size + arg * 2; +#endif need_to_update = it->current_offset < insertion.current_offset && insertion.current_offset < target; } else if (opcode_type == BRANCH_ABSOLUTE_OPCODE) { // For absolute branches, the argument needs to be updated if the // insertion before the target. 
+#if PY_VERSION_HEX < 0x030A0000 need_to_update = insertion.current_offset < arg; +#else + need_to_update = insertion.current_offset < arg * 2; +#endif } // If we are inserting the original method call instructions, we want to @@ -526,8 +467,16 @@ static bool InsertAndUpdateBranchInstructions( } if (need_to_update) { +#if PY_VERSION_HEX < 0x030A0000 + int delta = insertion.size; +#else + // Changed in version 3.10: The argument of jump, exception handling + // and loop instructions is now the instruction offset rather than the + // byte offset. + int delta = insertion.size / 2; +#endif PythonInstruction new_instruction = - PythonInstructionArg(instruction.opcode, arg + insertion.size); + PythonInstructionArg(instruction.opcode, arg + delta); int size_diff = new_instruction.size - instruction.size; if (size_diff > 0) { insertions.push_back(Insertion { size_diff, it->current_offset }); @@ -590,8 +539,8 @@ bool BytecodeManipulator::InsertMethodCall( // Insert the method call. data->bytecode.insert(data->bytecode.begin() + offset, method_call_size, NOP); WriteInstructions(data->bytecode.begin() + offset, method_call_instructions); - if (has_lnotab_) { - InsertAndUpdateLnotab(offset, method_call_size, &data->lnotab); + if (has_linedata_) { + InsertAndUpdateLineData(offset, method_call_size, &data->linedata); } // Write new branch instructions. 
@@ -603,8 +552,8 @@ bool BytecodeManipulator::InsertMethodCall( int offset = it->current_offset; if (size_diff > 0) { data->bytecode.insert(data->bytecode.begin() + offset, size_diff, NOP); - if (has_lnotab_) { - InsertAndUpdateLnotab(it->current_offset, size_diff, &data->lnotab); + if (has_linedata_) { + InsertAndUpdateLineData(it->current_offset, size_diff, &data->linedata); } } else if (size_diff < 0) { // The Python compiler sometimes prematurely adds EXTENDED_ARG with an @@ -619,113 +568,6 @@ bool BytecodeManipulator::InsertMethodCall( } -#else - - -bool BytecodeManipulator::InsertMethodCall( - BytecodeManipulator::Data* data, - int offset, - int const_index) const { - const std::vector method_call_instructions = - BuildMethodCall(const_index); - int size = GetInstructionsSize(method_call_instructions); - - bool offset_valid = false; - for (auto it = data->bytecode.begin(); it < data->bytecode.end(); ) { - const int current_offset = it - data->bytecode.begin(); - if (current_offset == offset) { - DCHECK(!offset_valid) << "Each offset should be visited only once"; - offset_valid = true; - } - - int current_fixed_offset = current_offset; - if (current_fixed_offset >= offset) { - current_fixed_offset += size; - } - - PythonInstruction instruction = ReadInstruction(data->bytecode, it); - if (instruction.opcode == kInvalidInstruction.opcode) { - return false; - } - - // Fix targets in branch instructions. 
- switch (GetOpcodeType(instruction.opcode)) { - case BRANCH_DELTA_OPCODE: { - int32 delta = static_cast(instruction.argument); - int32 target = current_offset + instruction.size + delta; - - if (target > offset) { - target += size; - } - - int32 fixed_delta = target - current_fixed_offset - instruction.size; - if (delta != fixed_delta) { - PythonInstruction new_instruction = - PythonInstructionArg(instruction.opcode, fixed_delta); - if (new_instruction.size != instruction.size) { - LOG(ERROR) << "Upgrading instruction to extended not supported"; - return false; - } - - WriteInstruction(it, new_instruction); - } - break; - } - - case BRANCH_ABSOLUTE_OPCODE: - if (static_cast(instruction.argument) > offset) { - PythonInstruction new_instruction = PythonInstructionArg( - instruction.opcode, instruction.argument + size); - if (new_instruction.size != instruction.size) { - LOG(ERROR) << "Upgrading instruction to extended not supported"; - return false; - } - - WriteInstruction(it, new_instruction); - } - break; - - default: - break; - } - - it += instruction.size; - } - - if (!offset_valid) { - LOG(ERROR) << "Offset " << offset << " is mid instruction or out of range"; - return false; - } - - // Insert the bytecode to invoke the callable. - data->bytecode.insert(data->bytecode.begin() + offset, size, NOP); - WriteInstructions(data->bytecode.begin() + offset, method_call_instructions); - - // Insert a new entry into line table to account for the new bytecode. 
- if (has_lnotab_) { - int current_offset = 0; - for (auto it = data->lnotab.begin(); it != data->lnotab.end(); it += 2) { - current_offset += it[0]; - - if (current_offset >= offset) { - int remaining_size = size; - while (remaining_size > 0) { - const int current_size = std::min(remaining_size, 0xFF); - it = data->lnotab.insert(it, static_cast(current_size)) + 1; - it = data->lnotab.insert(it, 0) + 1; - remaining_size -= current_size; - } - - break; - } - } - } - - return true; -} -#endif - - // This method does not change line numbers table. The line numbers table // is monotonically growing, which is not going to work for our case. Besides // the trampoline will virtually always fit a single instruction, so we don't diff --git a/src/googleclouddebugger/bytecode_manipulator.h b/src/googleclouddebugger/bytecode_manipulator.h index d3a7de4..31a5e46 100644 --- a/src/googleclouddebugger/bytecode_manipulator.h +++ b/src/googleclouddebugger/bytecode_manipulator.h @@ -71,17 +71,17 @@ namespace cdbg { // 19 JUMP_ABSOLUTE 3 class BytecodeManipulator { public: - BytecodeManipulator(std::vector bytecode, const bool has_lnotab, - std::vector lnotab); + BytecodeManipulator(std::vector bytecode, const bool has_linedata, + std::vector linedata); // Gets the transformed method bytecode. const std::vector& bytecode() const { return data_.bytecode; } // Returns true if this class was initialized with line numbers table. - bool has_lnotab() const { return has_lnotab_; } + bool has_linedata() const { return has_linedata_; } // Gets the method line numbers table or empty vector if not available. - const std::vector& lnotab() const { return data_.lnotab; } + const std::vector& linedata() const { return data_.linedata; } // Rewrites the method bytecode to invoke callable at the specified offset. // Return false if the method call could not be inserted. The bytecode @@ -109,8 +109,8 @@ class BytecodeManipulator { // Bytecode of a transformed method. 
std::vector bytecode; - // Method line numbers table or empty vector if "has_lnotab_" is false. - std::vector lnotab; + // Method line numbers table or empty vector if "has_linedata_" is false. + std::vector linedata; }; // Insert space into the bytecode. This space is later used to add new @@ -130,7 +130,7 @@ class BytecodeManipulator { Data data_; // True if the method has line number table. - const bool has_lnotab_; + const bool has_linedata_; // Algorithm to insert breakpoint callback into method bytecode. Strategy strategy_; diff --git a/src/googleclouddebugger/capture_collector.py b/src/googleclouddebugger/collector.py similarity index 87% rename from src/googleclouddebugger/capture_collector.py rename to src/googleclouddebugger/collector.py index fc79366..82916ab 100644 --- a/src/googleclouddebugger/capture_collector.py +++ b/src/googleclouddebugger/collector.py @@ -11,11 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Captures application state on a breakpoint hit.""" -# TODO: rename this file to collector.py. - import copy import datetime import inspect @@ -27,8 +24,6 @@ import time import types -import six - from . import cdbg_native as native from . import labels @@ -47,9 +42,8 @@ # Externally defined function to collect the end user id. 
breakpoint_labels_collector = lambda: {} -_PRIMITIVE_TYPES = (type(None), float, complex, bool, slice, bytearray, - six.text_type, - six.binary_type) + six.integer_types + six.string_types +_PRIMITIVE_TYPES = (type(None), float, complex, bool, slice, bytearray, str, + bytes, int) _DATE_TYPES = (datetime.date, datetime.time, datetime.timedelta) _VECTOR_TYPES = (tuple, list, set) @@ -280,8 +274,8 @@ def __init__(self, definition, data_visibility_policy): # because in the case where the user has not indicated a preference, we # don't want a single large object on the stack to use the entire max_size # quota and hide the rest of the data. - self.expression_capture_limits = _CaptureLimits(max_value_len=32768, - max_list_items=32768) + self.expression_capture_limits = _CaptureLimits( + max_value_len=32768, max_list_items=32768) def Collect(self, top_frame): """Collects call stack, local variables and objects. @@ -301,8 +295,9 @@ def Collect(self, top_frame): # Evaluate watched expressions. if 'expressions' in self.breakpoint: self.breakpoint['evaluatedExpressions'] = [ - self._CaptureExpression(top_frame, expression) for expression - in self.breakpoint['expressions']] + self._CaptureExpression(top_frame, expression) + for expression in self.breakpoint['expressions'] + ] while frame and (len(breakpoint_frames) < self.max_frames): line = top_line if frame == top_frame else frame.f_lineno @@ -332,7 +327,10 @@ def Collect(self, top_frame): 'description': { 'format': ('INTERNAL ERROR: Failed while capturing locals ' 'of frame $0: $1'), - 'parameters': [str(len(breakpoint_frames)), str(e)]}} + 'parameters': [str(len(breakpoint_frames)), + str(e)] + } + } # Number of entries in _var_table. Starts at 1 (index 0 is the 'buffer full' # status value). @@ -340,10 +338,12 @@ def Collect(self, top_frame): # Explore variables table in BFS fashion. The variables table will grow # inside CaptureVariable as we encounter new references. 
- while (num_vars < len(self._var_table)) and ( - self._total_size < self.max_size): + while (num_vars < len(self._var_table)) and (self._total_size < + self.max_size): self._var_table[num_vars] = self.CaptureVariable( - self._var_table[num_vars], 0, self.default_capture_limits, + self._var_table[num_vars], + 0, + self.default_capture_limits, can_enqueue=False) # Move on to the next entry in the variable table. @@ -367,20 +367,24 @@ def CaptureFrameLocals(self, frame): (arguments, locals) tuple. """ # Capture all local variables (including method arguments). - variables = {n: self.CaptureNamedVariable(n, v, 1, - self.default_capture_limits) - for n, v in six.viewitems(frame.f_locals)} + variables = { + n: self.CaptureNamedVariable(n, v, 1, self.default_capture_limits) + for n, v in frame.f_locals.items() + } # Split between locals and arguments (keeping arguments in the right order). nargs = frame.f_code.co_argcount - if frame.f_code.co_flags & inspect.CO_VARARGS: nargs += 1 - if frame.f_code.co_flags & inspect.CO_VARKEYWORDS: nargs += 1 + if frame.f_code.co_flags & inspect.CO_VARARGS: + nargs += 1 + if frame.f_code.co_flags & inspect.CO_VARKEYWORDS: + nargs += 1 frame_arguments = [] for argname in frame.f_code.co_varnames[:nargs]: - if argname in variables: frame_arguments.append(variables.pop(argname)) + if argname in variables: + frame_arguments.append(variables.pop(argname)) - return (frame_arguments, list(six.viewvalues(variables))) + return (frame_arguments, list(variables.values())) def CaptureNamedVariable(self, name, value, depth, limits): """Appends name to the product of CaptureVariable. 
@@ -400,8 +404,9 @@ def CaptureNamedVariable(self, name, value, depth, limits): name = str(id(name)) self._total_size += len(name) - v = (self.CheckDataVisibility(value) or - self.CaptureVariable(value, depth, limits)) + v = ( + self.CheckDataVisibility(value) or + self.CaptureVariable(value, depth, limits)) v['name'] = name return v @@ -449,23 +454,30 @@ def CaptureVariablesList(self, items, depth, empty_message, limits): """ v = [] for name, value in items: - if (self._total_size >= self.max_size) or ( - len(v) >= limits.max_list_items): + if (self._total_size >= self.max_size) or (len(v) >= + limits.max_list_items): v.append({ 'status': { 'refersTo': 'VARIABLE_VALUE', 'description': { - 'format': - ('Only first $0 items were captured. Use in an ' - 'expression to see all items.'), - 'parameters': [str(len(v))]}}}) + 'format': ('Only first $0 items were captured. Use in an ' + 'expression to see all items.'), + 'parameters': [str(len(v))] + } + } + }) break v.append(self.CaptureNamedVariable(name, value, depth, limits)) if not v: - return [{'status': { - 'refersTo': 'VARIABLE_NAME', - 'description': {'format': empty_message}}}] + return [{ + 'status': { + 'refersTo': 'VARIABLE_NAME', + 'description': { + 'format': empty_message + } + } + }] return v @@ -508,31 +520,34 @@ def CaptureVariableInternal(self, value, depth, limits, can_enqueue=True): return {'value': 'None'} if isinstance(value, _PRIMITIVE_TYPES): - r = _TrimString(repr(value), # Primitive type, always immutable. - min(limits.max_value_len, - self.max_size - self._total_size)) + r = _TrimString( + repr(value), # Primitive type, always immutable. + min(limits.max_value_len, self.max_size - self._total_size)) self._total_size += len(r) return {'value': r, 'type': type(value).__name__} if isinstance(value, _DATE_TYPES): r = str(value) # Safe to call str(). self._total_size += len(r) - return {'value': r, 'type': 'datetime.'+ type(value).__name__} + return {'value': r, 'type': 'datetime.' 
+ type(value).__name__} if isinstance(value, dict): # Do not use iteritems() here. If GC happens during iteration (which it # often can for dictionaries containing large variables), you will get a # RunTimeError exception. items = [(repr(k), v) for (k, v) in value.items()] - return {'members': - self.CaptureVariablesList(items, depth + 1, - EMPTY_DICTIONARY, limits), - 'type': 'dict'} + return { + 'members': + self.CaptureVariablesList(items, depth + 1, EMPTY_DICTIONARY, + limits), + 'type': + 'dict' + } if isinstance(value, _VECTOR_TYPES): fields = self.CaptureVariablesList( - (('[%d]' % i, x) for i, x in enumerate(value)), - depth + 1, EMPTY_COLLECTION, limits) + (('[%d]' % i, x) for i, x in enumerate(value)), depth + 1, + EMPTY_COLLECTION, limits) return {'members': fields, 'type': type(value).__name__} if isinstance(value, types.FunctionType): @@ -542,8 +557,8 @@ def CaptureVariableInternal(self, value, depth, limits, can_enqueue=True): if isinstance(value, Exception): fields = self.CaptureVariablesList( - (('[%d]' % i, x) for i, x in enumerate(value.args)), - depth + 1, EMPTY_COLLECTION, limits) + (('[%d]' % i, x) for i, x in enumerate(value.args)), depth + 1, + EMPTY_COLLECTION, limits) return {'members': fields, 'type': type(value).__name__} if can_enqueue: @@ -561,10 +576,13 @@ def CaptureVariableInternal(self, value, depth, limits, can_enqueue=True): continue fields, object_type = pretty_value - return {'members': + return { + 'members': self.CaptureVariablesList(fields, depth + 1, OBJECT_HAS_NO_FIELDS, limits), - 'type': object_type} + 'type': + object_type + } if not hasattr(value, '__dict__'): # TODO: keep "value" empty and populate the "type" field instead. 
@@ -574,14 +592,13 @@ def CaptureVariableInternal(self, value, depth, limits, can_enqueue=True): # Add an additional depth for the object itself items = value.__dict__.items() - if six.PY3: - # Make a list of the iterator in Python 3, to avoid 'dict changed size - # during iteration' errors from GC happening in the middle. - # Only limits.max_list_items + 1 items are copied, anything past that will - # get ignored by CaptureVariablesList(). - items = list(itertools.islice(items, limits.max_list_items + 1)) - members = self.CaptureVariablesList(items, depth + 2, - OBJECT_HAS_NO_FIELDS, limits) + # Make a list of the iterator in Python 3, to avoid 'dict changed size + # during iteration' errors from GC happening in the middle. + # Only limits.max_list_items + 1 items are copied, anything past that will + # get ignored by CaptureVariablesList(). + items = list(itertools.islice(items, limits.max_list_items + 1)) + members = self.CaptureVariablesList(items, depth + 2, OBJECT_HAS_NO_FIELDS, + limits) v = {'members': members} type_string = DetermineType(value) @@ -641,7 +658,7 @@ def _CaptureEnvironmentLabels(self): self.breakpoint['labels'] = {} if callable(breakpoint_labels_collector): - for (key, value) in six.iteritems(breakpoint_labels_collector()): + for (key, value) in breakpoint_labels_collector().items(): self._StoreLabel(key, value) def _CaptureRequestLogId(self): @@ -736,8 +753,12 @@ def Log(self, frame): """ # Return error if log methods were not configured globally. if not self._log_message: - return {'isError': True, - 'description': {'format': LOG_ACTION_NOT_SUPPORTED}} + return { + 'isError': True, + 'description': { + 'format': LOG_ACTION_NOT_SUPPORTED + } + } if self._quota_recovery_start_time: ms_elapsed = (time.time() - self._quota_recovery_start_time) * 1000 @@ -778,8 +799,10 @@ def _EvaluateExpressions(self, frame): Array of strings where each string corresponds to the breakpoint expression with the same index. 
""" - return [self._FormatExpression(frame, expression) for expression in - self._definition.get('expressions') or []] + return [ + self._FormatExpression(frame, expression) + for expression in self._definition.get('expressions') or [] + ] def _FormatExpression(self, frame, expression): """Evaluates a single watched expression and formats it into a string form. @@ -819,8 +842,7 @@ def _FormatValue(self, value, level=0): def FormatDictItem(key_value): """Formats single dictionary item.""" key, value = key_value - return (self._FormatValue(key, level + 1) + - ': ' + + return (self._FormatValue(key, level + 1) + ': ' + self._FormatValue(value, level + 1)) def LimitedEnumerate(items, formatter, level=0): @@ -840,8 +862,9 @@ def FormatList(items, formatter, level=0): return ', '.join(LimitedEnumerate(items, formatter, level=level)) if isinstance(value, _PRIMITIVE_TYPES): - return _TrimString(repr(value), # Primitive type, always immutable. - self.max_value_len) + return _TrimString( + repr(value), # Primitive type, always immutable. 
+ self.max_value_len) if isinstance(value, _DATE_TYPES): return str(value) @@ -850,11 +873,14 @@ def FormatList(items, formatter, level=0): return str(type(value)) if isinstance(value, dict): - return '{' + FormatList(six.iteritems(value), FormatDictItem) + '}' + return '{' + FormatList(value.items(), FormatDictItem) + '}' if isinstance(value, _VECTOR_TYPES): - return _ListTypeFormatString(value).format(FormatList( - value, lambda item: self._FormatValue(item, level + 1), level=level)) + return _ListTypeFormatString(value).format( + FormatList( + value, + lambda item: self._FormatValue(item, level + 1), + level=level)) if isinstance(value, types.FunctionType): return 'function ' + value.__name__ @@ -884,14 +910,18 @@ def _EvaluateExpression(frame, expression): 'refersTo': 'VARIABLE_NAME', 'description': { 'format': 'Invalid expression', - 'parameters': [str(e)]}}) + 'parameters': [str(e)] + } + }) except SyntaxError as e: return (False, { 'isError': True, 'refersTo': 'VARIABLE_NAME', 'description': { 'format': 'Expression could not be compiled: $0', - 'parameters': [e.msg]}}) + 'parameters': [e.msg] + } + }) try: return (True, native.CallImmutable(frame, code)) @@ -901,7 +931,9 @@ def _EvaluateExpression(frame, expression): 'refersTo': 'VARIABLE_VALUE', 'description': { 'format': 'Exception occurred: $0', - 'parameters': [str(e)]}}) + 'parameters': [str(e)] + } + }) def _GetFrameCodeObjectName(frame): @@ -917,8 +949,8 @@ def _GetFrameCodeObjectName(frame): # This functions under the assumption that member functions will name their # first parameter argument 'self' but has some edge-cases. if frame.f_code.co_argcount >= 1 and 'self' == frame.f_code.co_varnames[0]: - return (frame.f_locals['self'].__class__.__name__ + - '.' + frame.f_code.co_name) + return (frame.f_locals['self'].__class__.__name__ + '.' 
+ + frame.f_code.co_name) else: return frame.f_code.co_name @@ -933,6 +965,7 @@ def _FormatMessage(template, parameters): Returns: Formatted message with parameters embedded in template placeholders. """ + def GetParameter(m): try: return parameters[int(m.group(0)[1:])] @@ -947,4 +980,4 @@ def _TrimString(s, max_len): """Trims the string if it exceeds max_len.""" if len(s) <= max_len: return s - return s[:max_len+1] + '...' + return s[:max_len + 1] + '...' diff --git a/src/googleclouddebugger/error_data_visibility_policy.py b/src/googleclouddebugger/error_data_visibility_policy.py index a604578..0a04c36 100644 --- a/src/googleclouddebugger/error_data_visibility_policy.py +++ b/src/googleclouddebugger/error_data_visibility_policy.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Always returns the provided error on visibility requests. Example Usage: diff --git a/src/googleclouddebugger/firebase_client.py b/src/googleclouddebugger/firebase_client.py new file mode 100644 index 0000000..8dbe30a --- /dev/null +++ b/src/googleclouddebugger/firebase_client.py @@ -0,0 +1,699 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Communicates with Firebase RTDB backend.""" + +from collections import deque +import copy +import hashlib +import json +import os +import platform +import requests +import sys +import threading +import time +import traceback + +import firebase_admin +import firebase_admin.credentials +import firebase_admin.db +import firebase_admin.exceptions + +from . import backoff +from . import cdbg_native as native +from . import labels +from . import uniquifier_computer +from . import application_info +from . import version +# This module catches all exceptions. This is safe because it runs in +# a daemon thread (so we are not blocking Ctrl+C). We need to catch all +# exceptions because the HTTP client is unpredictable as to which +# exceptions it can throw. +# pylint: disable=broad-except + +# Set of all known debuggee labels (passed down as flags). Each value in +# the map is a list of optional environment variables that can be used to +# set the flag (flags still take precedence). +_DEBUGGEE_LABELS = { + labels.Debuggee.MODULE: [ + 'GAE_SERVICE', 'GAE_MODULE_NAME', 'K_SERVICE', 'FUNCTION_NAME' + ], + labels.Debuggee.VERSION: [ + 'GAE_VERSION', 'GAE_MODULE_VERSION', 'K_REVISION', + 'X_GOOGLE_FUNCTION_VERSION' + ], + labels.Debuggee.MINOR_VERSION: ['GAE_DEPLOYMENT_ID', 'GAE_MINOR_VERSION'] +} + +# Debuggee labels used to format debuggee description (ordered). The minor +# version is excluded for the sake of consistency with AppEngine UX. +_DESCRIPTION_LABELS = [ + labels.Debuggee.PROJECT_ID, labels.Debuggee.MODULE, labels.Debuggee.VERSION +] + +_METADATA_SERVER_URL = 'http://metadata.google.internal/computeMetadata/v1' + +_TRANSIENT_ERROR_CODES = ('UNKNOWN', 'INTERNAL', 'N/A', 'UNAVAILABLE', + 'DEADLINE_EXCEEDED', 'RESOURCE_EXHAUSTED', + 'UNAUTHENTICATED', 'PERMISSION_DENIED') + + +class NoProjectIdError(Exception): + """Used to indicate the project id cannot be determined.""" + + +class FirebaseClient(object): + """Firebase RTDB Backend client. 
+ + Registers the debuggee, subscribes for active breakpoints and sends breakpoint + updates to the backend. + + This class supports two types of authentication: application default + credentials or a manually provided JSON credentials file for a service + account. + + FirebaseClient creates a worker thread that communicates with the backend. The + thread can be stopped with a Stop function, but it is optional since the + worker thread is marked as daemon. + """ + + def __init__(self): + self.on_active_breakpoints_changed = lambda x: None + self.on_idle = lambda: None + self._debuggee_labels = {} + self._credentials = None + self._project_id = None + self._database_url = None + self._debuggee_id = None + self._canary_mode = None + self._breakpoints = {} + self._main_thread = None + self._transmission_thread = None + self._transmission_thread_startup_lock = threading.Lock() + self._transmission_queue = deque(maxlen=100) + self._mark_active_timer = None + self._mark_active_interval_sec = 60 * 60 # 1 hour in seconds + self._new_updates = threading.Event() + self._breakpoint_subscription = None + self._firebase_app = None + + # Events for unit testing. + self.connection_complete = threading.Event() + self.registration_complete = threading.Event() + self.subscription_complete = threading.Event() + + # + # Configuration options (constants only modified by unit test) + # + + # Delay before retrying failed request. + self.connect_backoff = backoff.Backoff() # Connect to the DB. + self.register_backoff = backoff.Backoff() # Register debuggee. + self.subscribe_backoff = backoff.Backoff() # Subscribe to updates. + self.update_backoff = backoff.Backoff() # Update breakpoint. + + # Maximum number of times that the message is re-transmitted before it + # is assumed to be poisonous and discarded + self.max_transmit_attempts = 10 + + def InitializeDebuggeeLabels(self, flags): + """Initialize debuggee labels from environment variables and flags. 
+ + The caller passes all the flags that the debuglet got. This function + will only use the flags used to label the debuggee. Flags take precedence + over environment variables. + + Debuggee description is formatted from available flags. + + Args: + flags: dictionary of debuglet command line flags. + """ + self._debuggee_labels = {} + + for (label, var_names) in _DEBUGGEE_LABELS.items(): + # var_names is a list of possible environment variables that may contain + # the label value. Find the first one that is set. + for name in var_names: + value = os.environ.get(name) + if value: + # Special case for module. We omit the "default" module + # to stay consistent with AppEngine. + if label == labels.Debuggee.MODULE and value == 'default': + break + self._debuggee_labels[label] = value + break + + # Special case when FUNCTION_NAME is set and X_GOOGLE_FUNCTION_VERSION + # isn't set. We set the version to 'unversioned' to be consistent with other + # agents. + # TODO: Stop assigning 'unversioned' to a GCF and find the + # actual version. + if ('FUNCTION_NAME' in os.environ and + labels.Debuggee.VERSION not in self._debuggee_labels): + self._debuggee_labels[labels.Debuggee.VERSION] = 'unversioned' + + if flags: + self._debuggee_labels.update({ + name: value + for (name, value) in flags.items() + if name in _DEBUGGEE_LABELS + }) + + self._debuggee_labels[labels.Debuggee.PROJECT_ID] = self._project_id + + platform_enum = application_info.GetPlatform() + self._debuggee_labels[labels.Debuggee.PLATFORM] = platform_enum.value + + if platform_enum == application_info.PlatformType.CLOUD_FUNCTION: + region = application_info.GetRegion() + if region: + self._debuggee_labels[labels.Debuggee.REGION] = region + + def SetupAuth(self, + project_id=None, + service_account_json_file=None, + database_url=None): + """Sets up authentication with Google APIs. + + This will use the credentials from service_account_json_file if provided, + falling back to application default credentials. 
+ See https://cloud.google.com/docs/authentication/production. + + Args: + project_id: GCP project ID (e.g. myproject). If not provided, will attempt + to retrieve it from the credentials. + service_account_json_file: JSON file to use for credentials. If not + provided, will default to application default credentials. + database_url: Firebase realtime database URL to be used. If not + provided, connect attempts to the following DBs will be made, in + order: + https://{project_id}-cdbg.firebaseio.com + https://{project_id}-default-rtdb.firebaseio.com + Raises: + NoProjectIdError: If the project id cannot be determined. + """ + if service_account_json_file: + self._credentials = firebase_admin.credentials.Certificate( + service_account_json_file) + if not project_id: + with open(service_account_json_file, encoding='utf-8') as f: + project_id = json.load(f).get('project_id') + else: + if not project_id: + try: + r = requests.get( + f'{_METADATA_SERVER_URL}/project/project-id', + headers={'Metadata-Flavor': 'Google'}, + timeout=1) + project_id = r.text + except requests.exceptions.RequestException: + native.LogInfo('Metadata server not available') + + if not project_id: + raise NoProjectIdError( + 'Unable to determine the project id from the API credentials. ' + 'Please specify the project id using the --project_id flag.') + + self._project_id = project_id + self._database_url = database_url + + def Start(self): + """Starts the worker thread.""" + self._shutdown = False + + # Spin up the main thread which will create the other necessary threads. + self._main_thread = threading.Thread(target=self._MainThreadProc) + self._main_thread.name = 'Cloud Debugger main worker thread' + self._main_thread.daemon = True + self._main_thread.start() + + def Stop(self): + """Signals the worker threads to shut down and waits until it exits.""" + self._shutdown = True + self._new_updates.set() # Wake up the transmission thread. 
+ + if self._main_thread is not None: + self._main_thread.join() + self._main_thread = None + + if self._transmission_thread is not None: + self._transmission_thread.join() + self._transmission_thread = None + + if self._mark_active_timer is not None: + self._mark_active_timer.cancel() + self._mark_active_timer = None + + if self._breakpoint_subscription is not None: + self._breakpoint_subscription.close() + self._breakpoint_subscription = None + + def EnqueueBreakpointUpdate(self, breakpoint_data): + """Asynchronously updates the specified breakpoint on the backend. + + This function returns immediately. The worker thread is actually doing + all the work. The worker thread is responsible to retry the transmission + in case of transient errors. + + The assumption is that the breakpoint is moving from Active to Final state. + + Args: + breakpoint: breakpoint in either final or non-final state. + """ + with self._transmission_thread_startup_lock: + if self._transmission_thread is None: + self._transmission_thread = threading.Thread( + target=self._TransmissionThreadProc) + self._transmission_thread.name = 'Cloud Debugger transmission thread' + self._transmission_thread.daemon = True + self._transmission_thread.start() + + self._transmission_queue.append((breakpoint_data, 0)) + self._new_updates.set() # Wake up the worker thread to send immediately. + + def _MainThreadProc(self): + """Entry point for the worker thread. + + This thread only serves to register and kick off the firebase subscription + which will run in its own thread. That thread will be owned by + self._breakpoint_subscription. 
+ """ + connection_required, delay = True, 0 + while connection_required: + time.sleep(delay) + connection_required, delay = self._ConnectToDb() + self.connection_complete.set() + + registration_required, delay = True, 0 + while registration_required: + time.sleep(delay) + registration_required, delay = self._RegisterDebuggee() + self.registration_complete.set() + + subscription_required, delay = True, 0 + while subscription_required: + time.sleep(delay) + subscription_required, delay = self._SubscribeToBreakpoints() + self.subscription_complete.set() + + self._StartMarkActiveTimer() + + while not self._shutdown: + if self.on_idle is not None: + self.on_idle() + + time.sleep(1) + + def _TransmissionThreadProc(self): + """Entry point for the transmission worker thread.""" + + while not self._shutdown: + self._new_updates.clear() + + delay = self._TransmitBreakpointUpdates() + + self._new_updates.wait(delay) + + def _MarkActiveTimerFunc(self): + """Entry point for the mark active timer.""" + + try: + self._MarkDebuggeeActive() + except: + native.LogInfo( + f'Failed to mark debuggee as active: {traceback.format_exc()}') + finally: + self._StartMarkActiveTimer() + + def _StartMarkActiveTimer(self): + self._mark_active_timer = threading.Timer(self._mark_active_interval_sec, + self._MarkActiveTimerFunc) + self._mark_active_timer.start() + + def _ConnectToDb(self): + urls = [self._database_url] if self._database_url is not None else \ + [f'https://{self._project_id}-cdbg.firebaseio.com', + f'https://{self._project_id}-default-rtdb.firebaseio.com'] + + for url in urls: + native.LogInfo(f'Attempting to connect to DB with url: {url}') + + status, firebase_app = self._TryInitializeDbForUrl(url) + if status: + native.LogInfo(f'Successfully connected to DB with url: {url}') + self._database_url = url + self._firebase_app = firebase_app + self.connect_backoff.Succeeded() + return (False, 0) # Proceed immediately to registering the debuggee. 
+ + return (True, self.connect_backoff.Failed()) + + def _TryInitializeDbForUrl(self, database_url): + # Note: if self._credentials is None, default app credentials will be used. + app = None + try: + app = firebase_admin.initialize_app( + self._credentials, {'databaseURL': database_url}, name='cdbg') + + if self._CheckSchemaVersionPresence(app): + return True, app + + except ValueError: + native.LogWarning( + f'Failed to initialize firebase: {traceback.format_exc()}') + + # This is the failure path, if we hit here we must cleanup the app handle + if app is not None: + firebase_admin.delete_app(app) + app = None + + return False, app + + def _RegisterDebuggee(self): + """Single attempt to register the debuggee. + + If the registration succeeds, sets self._debuggee_id to the registered + debuggee ID. + + Args: + service: client to use for API calls + + Returns: + (registration_required, delay) tuple + """ + debuggee = None + try: + debuggee = self._GetDebuggee() + self._debuggee_id = debuggee['id'] + except BaseException: + native.LogWarning( + f'Debuggee information not available: {traceback.format_exc()}') + return (True, self.register_backoff.Failed()) + + try: + present = self._CheckDebuggeePresence() + if present: + self._MarkDebuggeeActive() + else: + debuggee_path = f'cdbg/debuggees/{self._debuggee_id}' + native.LogInfo( + f'Registering at {self._database_url}, path: {debuggee_path}') + debuggee_data = copy.deepcopy(debuggee) + debuggee_data['registrationTimeUnixMsec'] = {'.sv': 'timestamp'} + debuggee_data['lastUpdateTimeUnixMsec'] = {'.sv': 'timestamp'} + firebase_admin.db.reference(debuggee_path, + self._firebase_app).set(debuggee_data) + + native.LogInfo( + f'Debuggee registered successfully, ID: {self._debuggee_id}') + + self.register_backoff.Succeeded() + return (False, 0) # Proceed immediately to subscribing to breakpoints. 
+ except BaseException as e: + # There is no significant benefit to handing different exceptions + # in different ways; we will log and retry regardless. + native.LogInfo(f'Failed to register debuggee: {repr(e)}') + return (True, self.register_backoff.Failed()) + + def _CheckSchemaVersionPresence(self, firebase_app): + path = f'cdbg/schema_version' + try: + snapshot = firebase_admin.db.reference(path, firebase_app).get() + # The value doesn't matter; just return true if there's any value. + return snapshot is not None + except BaseException as e: + native.LogInfo( + f'Failed to check schema version presence at {path}: {repr(e)}') + return False + + def _CheckDebuggeePresence(self): + path = f'cdbg/debuggees/{self._debuggee_id}/registrationTimeUnixMsec' + try: + snapshot = firebase_admin.db.reference(path, self._firebase_app).get() + # The value doesn't matter; just return true if there's any value. + return snapshot is not None + except BaseException as e: + native.LogInfo(f'Failed to check debuggee presence at {path}: {repr(e)}') + return False + + def _MarkDebuggeeActive(self): + active_path = f'cdbg/debuggees/{self._debuggee_id}/lastUpdateTimeUnixMsec' + try: + server_time = {'.sv': 'timestamp'} + firebase_admin.db.reference(active_path, + self._firebase_app).set(server_time) + except BaseException: + native.LogInfo( + f'Failed to mark debuggee active: {traceback.format_exc()}') + + def _SubscribeToBreakpoints(self): + # Kill any previous subscriptions first. 
+ if self._breakpoint_subscription is not None: + self._breakpoint_subscription.close() + self._breakpoint_subscription = None + + path = f'cdbg/breakpoints/{self._debuggee_id}/active' + native.LogInfo(f'Subscribing to breakpoint updates at {path}') + ref = firebase_admin.db.reference(path, self._firebase_app) + try: + self._breakpoint_subscription = ref.listen(self._ActiveBreakpointCallback) + return (False, 0) + except firebase_admin.exceptions.FirebaseError: + native.LogInfo( + f'Failed to subscribe to breakpoints: {traceback.format_exc()}') + return (True, self.subscribe_backoff.Failed()) + + def _ActiveBreakpointCallback(self, event): + if event.event_type == 'put': + if event.data is None: + # Either deleting a breakpoint or initializing with no breakpoints. + # Initializing with no breakpoints is a no-op. + # If deleting, event.path will be /{breakpointid} + if event.path != '/': + breakpoint_id = event.path[1:] + # Breakpoint may have already been deleted, so pop for possible no-op. + self._breakpoints.pop(breakpoint_id, None) + else: + if event.path == '/': + # New set of breakpoints. + self._breakpoints = {} + for (key, value) in event.data.items(): + self._AddBreakpoint(key, value) + else: + # New breakpoint. + breakpoint_id = event.path[1:] + self._AddBreakpoint(breakpoint_id, event.data) + + elif event.event_type == 'patch': + # New breakpoint or breakpoints. 
+ for (key, value) in event.data.items(): + self._AddBreakpoint(key, value) + else: + native.LogWarning('Unexpected event from Firebase: ' + f'{event.event_type} {event.path} {event.data}') + return + + native.LogInfo(f'Breakpoints list changed, {len(self._breakpoints)} active') + self.on_active_breakpoints_changed(list(self._breakpoints.values())) + + def _AddBreakpoint(self, breakpoint_id, breakpoint_data): + breakpoint_data['id'] = breakpoint_id + self._breakpoints[breakpoint_id] = breakpoint_data + + def _TransmitBreakpointUpdates(self): + """Tries to send pending breakpoint updates to the backend. + + Sends all the pending breakpoint updates. In case of transient failures, + the breakpoint is inserted back to the top of the queue. Application + failures are not retried (for example updating breakpoint in a final + state). + + Each pending breakpoint maintains a retry counter. After repeated transient + failures the breakpoint is discarded and dropped from the queue. + + Args: + service: client to use for API calls + + Returns: + (reconnect, timeout) tuple. The first element ("reconnect") is set to + true on unexpected HTTP responses. The caller should discard the HTTP + connection and create a new one. The second element ("timeout") is + set to None if all pending breakpoints were sent successfully. Otherwise + returns time interval in seconds to stall before retrying. + """ + retry_list = [] + + # There is only one consumer, so two step pop is safe. + while self._transmission_queue: + breakpoint_data, retry_count = self._transmission_queue.popleft() + + bp_id = breakpoint_data['id'] + + try: + # Something has changed on the breakpoint. + # It should be going from active to final, but let's make sure. 
+ if not breakpoint_data.get('isFinalState', False): + raise BaseException( + f'Unexpected breakpoint update requested: {breakpoint_data}') + + # If action is missing, it should be set to 'CAPTURE' + is_logpoint = breakpoint_data.get('action') == 'LOG' + is_snapshot = not is_logpoint + if is_snapshot: + breakpoint_data['action'] = 'CAPTURE' + + # Set the completion time on the server side using a magic value. + breakpoint_data['finalTimeUnixMsec'] = {'.sv': 'timestamp'} + + # First, remove from the active breakpoints. + bp_ref = firebase_admin.db.reference( + f'cdbg/breakpoints/{self._debuggee_id}/active/{bp_id}', + self._firebase_app) + bp_ref.delete() + + summary_data = breakpoint_data + # Save snapshot data for snapshots only. + if is_snapshot: + # Note that there may not be snapshot data. + bp_ref = firebase_admin.db.reference( + f'cdbg/breakpoints/{self._debuggee_id}/snapshot/{bp_id}', + self._firebase_app) + bp_ref.set(breakpoint_data) + + # Now strip potential snapshot data. + summary_data = copy.deepcopy(breakpoint_data) + summary_data.pop('evaluatedExpressions', None) + summary_data.pop('stackFrames', None) + summary_data.pop('variableTable', None) + + # Then add it to the list of final breakpoints. + bp_ref = firebase_admin.db.reference( + f'cdbg/breakpoints/{self._debuggee_id}/final/{bp_id}', + self._firebase_app) + bp_ref.set(summary_data) + + native.LogInfo(f'Breakpoint {bp_id} update transmitted successfully') + + except firebase_admin.exceptions.FirebaseError as err: + if err.code in _TRANSIENT_ERROR_CODES: + if retry_count < self.max_transmit_attempts - 1: + native.LogInfo(f'Failed to send breakpoint {bp_id} update: ' + f'{traceback.format_exc()}') + retry_list.append((breakpoint_data, retry_count + 1)) + else: + native.LogWarning( + f'Breakpoint {bp_id} retry count exceeded maximum') + else: + # This is very common if multiple instances are sending final update + # simultaneously. 
+ native.LogInfo(f'{err}, breakpoint: {bp_id}') + + except BaseException: + native.LogWarning(f'Fatal error sending breakpoint {bp_id} update: ' + f'{traceback.format_exc()}') + + self._transmission_queue.extend(retry_list) + + if not self._transmission_queue: + self.update_backoff.Succeeded() + # Nothing to send, wait until next breakpoint update. + return None + else: + return self.update_backoff.Failed() + + def _GetDebuggee(self): + """Builds the debuggee structure.""" + major_version = version.__version__.split('.', maxsplit=1)[0] + python_version = ''.join(platform.python_version().split('.')[:2]) + agent_version = f'google.com/python{python_version}-gcp/v{major_version}' + + debuggee = { + 'description': self._GetDebuggeeDescription(), + 'labels': self._debuggee_labels, + 'agentVersion': agent_version, + } + + source_context = self._ReadAppJsonFile('source-context.json') + if source_context: + debuggee['sourceContexts'] = [source_context] + + debuggee['uniquifier'] = self._ComputeUniquifier(debuggee) + + debuggee['id'] = self._ComputeDebuggeeId(debuggee) + + return debuggee + + def _ComputeDebuggeeId(self, debuggee): + """Computes a debuggee ID. + + The debuggee ID has to be identical on all instances. Therefore the + ID should not include any random elements or elements that may be + different on different instances. + + Args: + debuggee: complete debuggee message (including uniquifier) + + Returns: + Debuggee ID meeting the criteria described above. + """ + fullhash = hashlib.sha1(json.dumps(debuggee, + sort_keys=True).encode()).hexdigest() + return f'd-{fullhash[:8]}' + + def _GetDebuggeeDescription(self): + """Formats debuggee description based on debuggee labels.""" + return '-'.join(self._debuggee_labels[label] + for label in _DESCRIPTION_LABELS + if label in self._debuggee_labels) + + def _ComputeUniquifier(self, debuggee): + """Computes debuggee uniquifier. + + The debuggee uniquifier has to be identical on all instances. 
Therefore the + uniquifier should not include any random numbers and should only be based + on inputs that are guaranteed to be the same on all instances. + + Args: + debuggee: complete debuggee message without the uniquifier + + Returns: + Hex string of SHA1 hash of project information, debuggee labels and + debuglet version. + """ + uniquifier = hashlib.sha1() + + # Compute hash of application files if we don't have source context. This + # way we can still distinguish between different deployments. + if ('minorversion' not in debuggee.get('labels', []) and + 'sourceContexts' not in debuggee): + uniquifier_computer.ComputeApplicationUniquifier(uniquifier) + + return uniquifier.hexdigest() + + def _ReadAppJsonFile(self, relative_path): + """Reads JSON file from an application directory. + + Args: + relative_path: file name relative to application root directory. + + Returns: + Parsed JSON data or None if the file does not exist, can't be read or + not a valid JSON file. + """ + try: + with open( + os.path.join(sys.path[0], relative_path), 'r', encoding='utf-8') as f: + return json.load(f) + except (IOError, ValueError): + return None diff --git a/src/googleclouddebugger/gcp_hub_client.py b/src/googleclouddebugger/gcp_hub_client.py deleted file mode 100644 index 6214479..0000000 --- a/src/googleclouddebugger/gcp_hub_client.py +++ /dev/null @@ -1,579 +0,0 @@ -# Copyright 2015 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS-IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Communicates with Cloud Debugger backend over HTTP.""" - -from collections import deque -import copy -import hashlib -import inspect -import json -import logging -import os -import platform -import socket -import sys -import threading -import time -import traceback - - - -import google_auth_httplib2 -import googleapiclient -import googleapiclient.discovery -import httplib2 -import six - -import google.auth -from google.oauth2 import service_account - -from . import backoff -from . import cdbg_native as native -from . import labels -from . import uniquifier_computer -from . import application_info -from . import version -# This module catches all exception. This is safe because it runs in -# a daemon thread (so we are not blocking Ctrl+C). We need to catch all -# the exception because HTTP client is unpredictable as far as every -# exception it can throw. -# pylint: disable=broad-except - -# API scope we are requesting when service account authentication is enabled. -_CLOUD_PLATFORM_SCOPE = ['https://www.googleapis.com/auth/cloud-platform'] - -# Set of all known debuggee labels (passed down as flags). The value of -# a map is optional environment variable that can be used to set the flag -# (flags still take precedence). -_DEBUGGEE_LABELS = { - labels.Debuggee.MODULE: [ - 'GAE_SERVICE', 'GAE_MODULE_NAME', 'K_SERVICE', 'FUNCTION_NAME' - ], - labels.Debuggee.VERSION: [ - 'GAE_VERSION', 'GAE_MODULE_VERSION', 'K_REVISION', - 'X_GOOGLE_FUNCTION_VERSION' - ], - labels.Debuggee.MINOR_VERSION: ['GAE_DEPLOYMENT_ID', 'GAE_MINOR_VERSION'] -} - -# Debuggee labels used to format debuggee description (ordered). The minor -# version is excluded for the sake of consistency with AppEngine UX. -_DESCRIPTION_LABELS = [ - labels.Debuggee.PROJECT_ID, labels.Debuggee.MODULE, labels.Debuggee.VERSION -] - -# HTTP timeout when accessing the cloud debugger API. It is selected to be -# longer than the typical controller.breakpoints.list hanging get latency -# of 40 seconds. 
-_HTTP_TIMEOUT_SECONDS = 100 - -# The map from the values of flags (breakpoint_enable_canary, -# breakpoint_allow_canary_override) to canary mode. -_CANARY_MODE_MAP = { - (True, True): 'CANARY_MODE_DEFAULT_ENABLED', - (True, False): 'CANARY_MODE_ALWAYS_ENABLED', - (False, True): 'CANARY_MODE_DEFAULT_DISABLED', - (False, False): 'CANARY_MODE_ALWAYS_DISABLED', -} - - -class NoProjectIdError(Exception): - """Used to indicate the project id cannot be determined.""" - - -class GcpHubClient(object): - """Controller API client. - - Registers the debuggee, queries the active breakpoints and sends breakpoint - updates to the backend. - - This class supports two types of authentication: application default - credentials or a manually provided JSON credentials file for a service - account. - - GcpHubClient creates a worker thread that communicates with the backend. The - thread can be stopped with a Stop function, but it is optional since the - worker thread is marked as daemon. - """ - - def __init__(self): - self.on_active_breakpoints_changed = lambda x: None - self.on_idle = lambda: None - self._debuggee_labels = {} - self._service_account_auth = False - self._debuggee_id = None - self._agent_id = None - self._canary_mode = None - self._wait_token = 'init' - self._breakpoints = [] - self._main_thread = None - self._transmission_thread = None - self._transmission_thread_startup_lock = threading.Lock() - self._transmission_queue = deque(maxlen=100) - self._new_updates = threading.Event() - - # Disable logging in the discovery API to avoid excessive logging. 
- class _ChildLogFilter(logging.Filter): - """Filter to eliminate info-level logging when called from this module.""" - - def __init__(self, filter_levels=None): - super(_ChildLogFilter, self).__init__() - self._filter_levels = filter_levels or set(logging.INFO) - # Get name without extension to avoid .py vs .pyc issues - self._my_filename = os.path.splitext( - inspect.getmodule(_ChildLogFilter).__file__)[0] - - def filter(self, record): - if record.levelno not in self._filter_levels: - return True - callerframes = inspect.getouterframes(inspect.currentframe()) - for f in callerframes: - if os.path.splitext(f[1])[0] == self._my_filename: - return False - return True - self._log_filter = _ChildLogFilter({logging.INFO}) - googleapiclient.discovery.logger.addFilter(self._log_filter) - - # - # Configuration options (constants only modified by unit test) - # - - # Delay before retrying failed request. - self.register_backoff = backoff.Backoff() # Register debuggee. - self.list_backoff = backoff.Backoff() # Query active breakpoints. - self.update_backoff = backoff.Backoff() # Update breakpoint. - - # Maximum number of times that the message is re-transmitted before it - # is assumed to be poisonous and discarded - self.max_transmit_attempts = 10 - - def InitializeDebuggeeLabels(self, flags): - """Initialize debuggee labels from environment variables and flags. - - The caller passes all the flags that the debuglet got. This function - will only use the flags used to label the debuggee. Flags take precedence - over environment variables. - - Debuggee description is formatted from available flags. - - Args: - flags: dictionary of debuglet command line flags. - """ - self._debuggee_labels = {} - - for (label, var_names) in six.iteritems(_DEBUGGEE_LABELS): - # var_names is a list of possible environment variables that may contain - # the label value. Find the first one that is set. 
- for name in var_names: - value = os.environ.get(name) - if value: - # Special case for module. We omit the "default" module - # to stay consistent with AppEngine. - if label == labels.Debuggee.MODULE and value == 'default': - break - self._debuggee_labels[label] = value - break - - # Special case when FUNCTION_NAME is set and X_GOOGLE_FUNCTION_VERSION - # isn't set. We set the version to 'unversioned' to be consistent with other - # agents. - # TODO: Stop assigning 'unversioned' to a GCF and find the - # actual version. - if ('FUNCTION_NAME' in os.environ and - labels.Debuggee.VERSION not in self._debuggee_labels): - self._debuggee_labels[labels.Debuggee.VERSION] = 'unversioned' - - if flags: - self._debuggee_labels.update( - {name: value for (name, value) in six.iteritems(flags) - if name in _DEBUGGEE_LABELS}) - - self._debuggee_labels[labels.Debuggee.PROJECT_ID] = self._project_id - - platform_enum = application_info.GetPlatform() - self._debuggee_labels[labels.Debuggee.PLATFORM] = platform_enum.value - - if platform_enum == application_info.PlatformType.CLOUD_FUNCTION: - region = application_info.GetRegion() - if region: - self._debuggee_labels[labels.Debuggee.REGION] = region - - def SetupAuth(self, - project_id=None, - project_number=None, - service_account_json_file=None): - """Sets up authentication with Google APIs. - - This will use the credentials from service_account_json_file if provided, - falling back to application default credentials. - See https://cloud.google.com/docs/authentication/production. - - Args: - project_id: GCP project ID (e.g. myproject). If not provided, will attempt - to retrieve it from the credentials. - project_number: GCP project number (e.g. 72386324623). If not provided, - project_id will be used in its place. - service_account_json_file: JSON file to use for credentials. If not - provided, will default to application default credentials. - Raises: - NoProjectIdError: If the project id cannot be determined. 
- """ - if service_account_json_file: - self._credentials = ( - service_account.Credentials.from_service_account_file( - service_account_json_file, scopes=_CLOUD_PLATFORM_SCOPE)) - if not project_id: - with open(service_account_json_file) as f: - project_id = json.load(f).get('project_id') - else: - self._credentials, credentials_project_id = google.auth.default( - scopes=_CLOUD_PLATFORM_SCOPE) - project_id = project_id or credentials_project_id - - if not project_id: - raise NoProjectIdError( - 'Unable to determine the project id from the API credentials. ' - 'Please specify the project id using the --project_id flag.') - - self._project_id = project_id - self._project_number = project_number or project_id - - def SetupCanaryMode(self, breakpoint_enable_canary, - breakpoint_allow_canary_override): - """Sets up canaryMode for the debuggee according to input parameters. - - Args: - breakpoint_enable_canary: str or bool, whether to enable breakpoint - canary. Any string except 'True' is interpreted as False. - breakpoint_allow_canary_override: str or bool, whether to allow the - individually set breakpoint to override the canary behavior. Any - string except 'True' is interpreted as False. - """ - enable_canary = breakpoint_enable_canary in ('True', True) - allow_canary_override = breakpoint_allow_canary_override in ('True', True) - self._canary_mode = _CANARY_MODE_MAP[enable_canary, allow_canary_override] - - def Start(self): - """Starts the worker thread.""" - self._shutdown = False - - self._main_thread = threading.Thread(target=self._MainThreadProc) - self._main_thread.name = 'Cloud Debugger main worker thread' - self._main_thread.daemon = True - self._main_thread.start() - - def Stop(self): - """Signals the worker threads to shut down and waits until it exits.""" - self._shutdown = True - self._new_updates.set() # Wake up the transmission thread. 
- - if self._main_thread is not None: - self._main_thread.join() - self._main_thread = None - - if self._transmission_thread is not None: - self._transmission_thread.join() - self._transmission_thread = None - - def EnqueueBreakpointUpdate(self, breakpoint): - """Asynchronously updates the specified breakpoint on the backend. - - This function returns immediately. The worker thread is actually doing - all the work. The worker thread is responsible to retry the transmission - in case of transient errors. - - Args: - breakpoint: breakpoint in either final or non-final state. - """ - with self._transmission_thread_startup_lock: - if self._transmission_thread is None: - self._transmission_thread = threading.Thread( - target=self._TransmissionThreadProc) - self._transmission_thread.name = 'Cloud Debugger transmission thread' - self._transmission_thread.daemon = True - self._transmission_thread.start() - - self._transmission_queue.append((breakpoint, 0)) - self._new_updates.set() # Wake up the worker thread to send immediately. 
- - def _BuildService(self): - http = httplib2.Http(timeout=_HTTP_TIMEOUT_SECONDS) - http = google_auth_httplib2.AuthorizedHttp(self._credentials, http) - - api = googleapiclient.discovery.build( - 'clouddebugger', 'v2', http=http, cache_discovery=False) - return api.controller() - - def _MainThreadProc(self): - """Entry point for the worker thread.""" - registration_required = True - while not self._shutdown: - if registration_required: - service = self._BuildService() - registration_required, delay = self._RegisterDebuggee(service) - - if not registration_required: - registration_required, delay = self._ListActiveBreakpoints(service) - - if self.on_idle is not None: - self.on_idle() - - if not self._shutdown: - time.sleep(delay) - - def _TransmissionThreadProc(self): - """Entry point for the transmission worker thread.""" - reconnect = True - - while not self._shutdown: - self._new_updates.clear() - - if reconnect: - service = self._BuildService() - reconnect = False - - reconnect, delay = self._TransmitBreakpointUpdates(service) - - self._new_updates.wait(delay) - - def _RegisterDebuggee(self, service): - """Single attempt to register the debuggee. - - If the registration succeeds, sets self._debuggee_id to the registered - debuggee ID. - - Args: - service: client to use for API calls - - Returns: - (registration_required, delay) tuple - """ - try: - request = {'debuggee': self._GetDebuggee()} - - try: - response = service.debuggees().register(body=request).execute() - - # self._project_number will refer to the project id on initialization if - # the project number is not available. The project field in the debuggee - # will always refer to the project number. Update so the server will not - # have to do id->number translations in the future. 
- project_number = response['debuggee'].get('project') - self._project_number = project_number or self._project_number - - self._debuggee_id = response['debuggee']['id'] - self._agent_id = response['agentId'] - native.LogInfo( - 'Debuggee registered successfully, ID: %s, agent ID: %s, ' - 'canary mode: %s' % (self._debuggee_id, self._agent_id, - response['debuggee'].get('canaryMode'))) - self.register_backoff.Succeeded() - return (False, 0) # Proceed immediately to list active breakpoints. - except BaseException: - native.LogInfo('Failed to register debuggee: %s, %s' % - (request, traceback.format_exc())) - except BaseException: - native.LogWarning('Debuggee information not available: ' + - traceback.format_exc()) - - return (True, self.register_backoff.Failed()) - - def _ListActiveBreakpoints(self, service): - """Single attempt query the list of active breakpoints. - - Must not be called before the debuggee has been registered. If the request - fails, this function resets self._debuggee_id, which triggers repeated - debuggee registration. - - Args: - service: client to use for API calls - - Returns: - (registration_required, delay) tuple - """ - try: - response = service.debuggees().breakpoints().list( - debuggeeId=self._debuggee_id, - agentId=self._agent_id, - waitToken=self._wait_token, - successOnTimeout=True).execute() - if not response.get('waitExpired'): - self._wait_token = response.get('nextWaitToken') - breakpoints = response.get('breakpoints') or [] - if self._breakpoints != breakpoints: - self._breakpoints = breakpoints - native.LogInfo( - 'Breakpoints list changed, %d active, wait token: %s' % ( - len(self._breakpoints), self._wait_token)) - self.on_active_breakpoints_changed(copy.deepcopy(self._breakpoints)) - except BaseException: - native.LogInfo('Failed to query active breakpoints: ' + - traceback.format_exc()) - - # Forget debuggee ID to trigger repeated debuggee registration. 
Once the - # registration succeeds, the worker thread will retry this query - self._debuggee_id = None - - return (True, self.list_backoff.Failed()) - - self.list_backoff.Succeeded() - return (False, 0) - - def _TransmitBreakpointUpdates(self, service): - """Tries to send pending breakpoint updates to the backend. - - Sends all the pending breakpoint updates. In case of transient failures, - the breakpoint is inserted back to the top of the queue. Application - failures are not retried (for example updating breakpoint in a final - state). - - Each pending breakpoint maintains a retry counter. After repeated transient - failures the breakpoint is discarded and dropped from the queue. - - Args: - service: client to use for API calls - - Returns: - (reconnect, timeout) tuple. The first element ("reconnect") is set to - true on unexpected HTTP responses. The caller should discard the HTTP - connection and create a new one. The second element ("timeout") is - set to None if all pending breakpoints were sent successfully. Otherwise - returns time interval in seconds to stall before retrying. - """ - reconnect = False - retry_list = [] - - # There is only one consumer, so two step pop is safe. - while self._transmission_queue: - breakpoint, retry_count = self._transmission_queue.popleft() - - try: - service.debuggees().breakpoints().update( - debuggeeId=self._debuggee_id, id=breakpoint['id'], - body={'breakpoint': breakpoint}).execute() - - native.LogInfo('Breakpoint %s update transmitted successfully' % ( - breakpoint['id'])) - except googleapiclient.errors.HttpError as err: - # Treat 400 error codes (except timeout) as application error that will - # not be retried. All other errors are assumed to be transient. 
- status = err.resp.status - is_transient = ((status >= 500) or (status == 408)) - if is_transient: - if retry_count < self.max_transmit_attempts - 1: - native.LogInfo('Failed to send breakpoint %s update: %s' % - (breakpoint['id'], traceback.format_exc())) - retry_list.append((breakpoint, retry_count + 1)) - else: - native.LogWarning('Breakpoint %s retry count exceeded maximum' % - breakpoint['id']) - else: - # This is very common if multiple instances are sending final update - # simultaneously. - native.LogInfo('%s, breakpoint: %s' % (err, breakpoint['id'])) - except socket.error as err: - if retry_count < self.max_transmit_attempts - 1: - native.LogInfo( - 'Socket error %d while sending breakpoint %s update: %s' % - (err.errno, breakpoint['id'], traceback.format_exc())) - retry_list.append((breakpoint, retry_count + 1)) - else: - native.LogWarning('Breakpoint %s retry count exceeded maximum' % - breakpoint['id']) - # Socket errors shouldn't persist like this; reconnect. - reconnect = True - except BaseException: - native.LogWarning( - 'Fatal error sending breakpoint %s update: %s' % ( - breakpoint['id'], traceback.format_exc())) - reconnect = True - - self._transmission_queue.extend(retry_list) - - if not self._transmission_queue: - self.update_backoff.Succeeded() - # Nothing to send, wait until next breakpoint update. 
- return (reconnect, None) - else: - return (reconnect, self.update_backoff.Failed()) - - def _GetDebuggee(self): - """Builds the debuggee structure.""" - major_version = 'v' + version.__version__.split('.')[0] - python_version = ''.join(platform.python_version().split('.')[:2]) - agent_version = ('google.com/python%s-gcp/%s' % (python_version, - major_version)) - - debuggee = { - 'project': self._project_number, - 'description': self._GetDebuggeeDescription(), - 'labels': self._debuggee_labels, - 'agentVersion': agent_version, - 'canaryMode': self._canary_mode, - } - - source_context = self._ReadAppJsonFile('source-context.json') - if source_context: - debuggee['sourceContexts'] = [source_context] - - debuggee['uniquifier'] = self._ComputeUniquifier(debuggee) - - return debuggee - - def _GetDebuggeeDescription(self): - """Formats debuggee description based on debuggee labels.""" - return '-'.join(self._debuggee_labels[label] - for label in _DESCRIPTION_LABELS - if label in self._debuggee_labels) - - def _ComputeUniquifier(self, debuggee): - """Computes debuggee uniquifier. - - The debuggee uniquifier has to be identical on all instances. Therefore the - uniquifier should not include any random numbers and should only be based - on inputs that are guaranteed to be the same on all instances. - - Args: - debuggee: complete debuggee message without the uniquifier - - Returns: - Hex string of SHA1 hash of project information, debuggee labels and - debuglet version. - """ - uniquifier = hashlib.sha1() - - # Compute hash of application files if we don't have source context. This - # way we can still distinguish between different deployments. - if ('minorversion' not in debuggee.get('labels', []) and - 'sourceContexts' not in debuggee): - uniquifier_computer.ComputeApplicationUniquifier(uniquifier) - - return uniquifier.hexdigest() - - def _ReadAppJsonFile(self, relative_path): - """Reads JSON file from an application directory. 
- - Args: - relative_path: file name relative to application root directory. - - Returns: - Parsed JSON data or None if the file does not exist, can't be read or - not a valid JSON file. - """ - try: - with open(os.path.join(sys.path[0], relative_path), 'r') as f: - return json.load(f) - except (IOError, ValueError): - return None diff --git a/src/googleclouddebugger/glob_data_visibility_policy.py b/src/googleclouddebugger/glob_data_visibility_policy.py index 00255ef..275e69a 100644 --- a/src/googleclouddebugger/glob_data_visibility_policy.py +++ b/src/googleclouddebugger/glob_data_visibility_policy.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Determines the visibility of python data and symbols. Example Usage: @@ -33,7 +32,6 @@ import fnmatch - # Possible visibility responses RESPONSES = { 'UNKNOWN_TYPE': 'could not determine type', @@ -86,4 +84,3 @@ def _Matches(path, pattern_list): """ # Note: This code does not scale to large pattern_list sizes. 
return any(fnmatch.fnmatchcase(path, pattern) for pattern in pattern_list) - diff --git a/src/googleclouddebugger/immutability_tracer.cc b/src/googleclouddebugger/immutability_tracer.cc index 6cfa66c..c05d407 100644 --- a/src/googleclouddebugger/immutability_tracer.cc +++ b/src/googleclouddebugger/immutability_tracer.cc @@ -365,7 +365,6 @@ static OpcodeMutableStatus IsOpcodeMutable(const uint8_t opcode) { case CONTINUE_LOOP: case SETUP_LOOP: #endif -#if PY_MAJOR_VERSION >= 3 case DUP_TOP_TWO: case BINARY_MATRIX_MULTIPLY: case INPLACE_MATRIX_MULTIPLY: @@ -402,24 +401,15 @@ static OpcodeMutableStatus IsOpcodeMutable(const uint8_t opcode) { // Added back in Python 3.8 (was in 2.7 as well) case ROT_FOUR: #endif -#else - case ROT_FOUR: - case DUP_TOPX: - case UNARY_NOT: - case UNARY_CONVERT: - case BINARY_DIVIDE: - case BINARY_OR: - case INPLACE_DIVIDE: - case SLICE+0: - case SLICE+1: - case SLICE+2: - case SLICE+3: - case LOAD_LOCALS: - case EXEC_STMT: - case JUMP_ABSOLUTE: - case CALL_FUNCTION_VAR: - case CALL_FUNCTION_VAR_KW: - case MAKE_CLOSURE: +#if PY_VERSION_HEX >= 0x030A0000 + // Added in Python 3.10 + case COPY_DICT_WITHOUT_KEYS: + case GET_LEN: + case MATCH_MAPPING: + case MATCH_SEQUENCE: + case MATCH_KEYS: + case MATCH_CLASS: + case ROT_N: #endif return OPCODE_NOT_MUTABLE; @@ -450,7 +440,6 @@ static OpcodeMutableStatus IsOpcodeMutable(const uint8_t opcode) { // Removed in Python 3.8. 
case SETUP_EXCEPT: #endif -#if PY_MAJOR_VERSION >= 3 case GET_AITER: case GET_ANEXT: case BEFORE_ASYNC_WITH: @@ -490,22 +479,9 @@ static OpcodeMutableStatus IsOpcodeMutable(const uint8_t opcode) { case WITH_EXCEPT_START: case LOAD_ASSERTION_ERROR: #endif -#else - case STORE_SLICE+0: - case STORE_SLICE+1: - case STORE_SLICE+2: - case STORE_SLICE+3: - case DELETE_SLICE+0: - case DELETE_SLICE+1: - case DELETE_SLICE+2: - case DELETE_SLICE+3: - case STORE_MAP: - case PRINT_ITEM_TO: - case PRINT_ITEM: - case PRINT_NEWLINE_TO: - case PRINT_NEWLINE: - case BUILD_CLASS: - case WITH_CLEANUP: +#if PY_VERSION_HEX >= 0x030A0000 + // Added in Python 3.10 + case GEN_START: #endif return OPCODE_MUTABLE; @@ -525,16 +501,11 @@ void ImmutabilityTracer::ProcessCodeRange(const uint8_t* code_start, // We don't worry about the sizes of instructions with EXTENDED_ARG. // The argument does not really matter and so EXTENDED_ARGs can be // treated as just another instruction with an opcode. -#if PY_MAJOR_VERSION >= 3 opcodes += 2; -#else - opcodes += HAS_ARG(opcode) ? 3 : 1; -#endif DCHECK_LE(opcodes, end); break; case OPCODE_MAYBE_MUTABLE: -#if PY_MAJOR_VERSION >= 3 if (opcode == JUMP_ABSOLUTE) { // Check for a jump to itself, which happens in "while True: pass". // The tracer won't call our tracing function unless there is a jump @@ -551,7 +522,6 @@ void ImmutabilityTracer::ProcessCodeRange(const uint8_t* code_start, DCHECK_LE(opcodes, end); break; } -#endif LOG(WARNING) << "Unknown opcode " << static_cast(opcode); mutable_code_detected_ = true; return; diff --git a/src/googleclouddebugger/imphook2.py b/src/googleclouddebugger/imphook.py similarity index 93% rename from src/googleclouddebugger/imphook2.py rename to src/googleclouddebugger/imphook.py index 1aeb89f..2e80648 100644 --- a/src/googleclouddebugger/imphook2.py +++ b/src/googleclouddebugger/imphook.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - """Support for breakpoints on modules that haven't been loaded yet. -This is the new module import hook which: +This is the module import hook which: 1. Takes a partial path of the module file excluding the file extension as input (can be as short as 'foo' or longer such as 'sys/path/pkg/foo'). 2. At each (top-level-only) import statement: @@ -27,7 +26,6 @@ b. Checks sys.modules if any of these modules have a file that matches the given path, using suffix match. -For the old module import hook, see imphook.py file. """ import importlib @@ -36,10 +34,9 @@ import sys # Must be imported, otherwise import hooks don't work. import threading -import six -from six.moves import builtins # pylint: disable=redefined-builtin +import builtins -from . import module_utils2 +from . import module_utils # Callbacks to invoke when a module is imported. _import_callbacks = {} @@ -112,14 +109,12 @@ def _InstallImportHookBySuffix(): assert _real_import builtins.__import__ = _ImportHookBySuffix - if six.PY3: - # In Python 2, importlib.import_module calls __import__ internally so - # overriding __import__ is enough. In Python 3, they are separate so it also - # needs to be overwritten. - global _real_import_module - _real_import_module = importlib.import_module - assert _real_import_module - importlib.import_module = _ImportModuleHookBySuffix + # importlib.import_module and __import__ are separate in Python 3 so both + # need to be overwritten. 
+ global _real_import_module + _real_import_module = importlib.import_module + assert _real_import_module + importlib.import_module = _ImportModuleHookBySuffix def _IncrementNestLevel(): @@ -168,8 +163,11 @@ def _ProcessImportBySuffix(name, fromlist, globals): # pylint: disable=redefined-builtin, g-doc-args, g-doc-return-or-yield -def _ImportHookBySuffix( - name, globals=None, locals=None, fromlist=None, level=None): +def _ImportHookBySuffix(name, + globals=None, + locals=None, + fromlist=None, + level=None): """Callback when an import statement is executed by the Python interpreter. Argument names have to exactly match those of __import__. Otherwise calls @@ -179,12 +177,9 @@ def _ImportHookBySuffix( if level is None: # A level of 0 means absolute import, positive values means relative - # imports, and -1 means to try both an absolute and relative import. - # Since imports were disambiguated in Python 3, -1 is not a valid value. - # The default values are 0 and -1 for Python 3 and 3 respectively. - # https://docs.python.org/2/library/functions.html#__import__ + # imports. # https://docs.python.org/3/library/functions.html#__import__ - level = 0 if six.PY3 else -1 + level = 0 try: # Really import modules. @@ -272,6 +267,7 @@ def _GenerateNames(name, fromlist, globals): the execution of this import statement. The returned set may contain names that are not real modules. """ + def GetCurrentPackage(globals): """Finds the name of the package for the currently executing module.""" if not globals: @@ -375,6 +371,7 @@ def _InvokeImportCallbackBySuffix(names): to a module. The list is expected to be much smaller than the exact sys.modules so that a linear search is not as costly. """ + def GetModuleFromName(name, path): """Returns the loaded module for this name/path, or None if not found. 
@@ -433,7 +430,7 @@ def GetModuleFromName(name, path): if not os.path.isabs(mod_root): mod_root = os.path.join(os.curdir, mod_root) - if module_utils2.IsPathSuffix(mod_root, root): + if module_utils.IsPathSuffix(mod_root, root): for callback in callbacks.copy(): callback(module) break diff --git a/src/googleclouddebugger/labels.py b/src/googleclouddebugger/labels.py index d22129a..1bca819 100644 --- a/src/googleclouddebugger/labels.py +++ b/src/googleclouddebugger/labels.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Defines the keys of the well known labels used by the cloud debugger. TODO: Define these strings in a common format for all agents to diff --git a/src/googleclouddebugger/module_explorer.py b/src/googleclouddebugger/module_explorer.py index 75edb05..ac62ce4 100644 --- a/src/googleclouddebugger/module_explorer.py +++ b/src/googleclouddebugger/module_explorer.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Finds all the code objects defined by a module.""" import gc @@ -19,8 +18,6 @@ import sys import types -import six - # Maximum traversal depth when looking for all the code objects referenced by # a module or another code object. _MAX_REFERENTS_BFS_DEPTH = 15 @@ -35,9 +32,8 @@ _MAX_OBJECT_REFERENTS = 1000 # Object types to ignore when looking for the code objects. 
-_BFS_IGNORE_TYPES = (types.ModuleType, type(None), bool, float, six.binary_type, - six.text_type, types.BuiltinFunctionType, - types.BuiltinMethodType, list) + six.integer_types +_BFS_IGNORE_TYPES = (types.ModuleType, type(None), bool, float, bytes, str, int, + types.BuiltinFunctionType, types.BuiltinMethodType, list) def GetCodeObjectAtLine(module, line): @@ -56,7 +52,7 @@ def GetCodeObjectAtLine(module, line): return (False, (None, None)) prev_line = 0 - next_line = six.MAXSIZE + next_line = sys.maxsize for code_object in _GetModuleCodeObjects(module): for co_line_number in _GetLineNumbers(code_object): @@ -66,10 +62,10 @@ def GetCodeObjectAtLine(module, line): prev_line = max(prev_line, co_line_number) elif co_line_number > line: next_line = min(next_line, co_line_number) - break + # Continue because line numbers may not be sequential. prev_line = None if prev_line == 0 else prev_line - next_line = None if next_line == six.MAXSIZE else next_line + next_line = None if next_line == sys.maxsize else next_line return (False, (prev_line, next_line)) @@ -82,19 +78,27 @@ def _GetLineNumbers(code_object): Yields: The next line number in the code object. """ - # Get the line number deltas, which are the odd number entries, from the - # lnotab. See - # https://svn.python.org/projects/python/branches/pep-0384/Objects/lnotab_notes.txt - # In Python 3, this is just a byte array. In Python 2 it is a string so the - # numerical values have to be extracted from the individual characters. - if six.PY3: + + if sys.version_info.minor < 10: + # Get the line number deltas, which are the odd number entries, from the + # lnotab. See + # https://svn.python.org/projects/python/branches/pep-0384/Objects/lnotab_notes.txt + # In Python 3, prior to 3.10, this is just a byte array. 
line_incrs = code_object.co_lnotab[1::2] + current_line = code_object.co_firstlineno + for line_incr in line_incrs: + if line_incr >= 0x80: + # line_incrs is an array of 8-bit signed integers + line_incr -= 0x100 + current_line += line_incr + yield current_line else: - line_incrs = (ord(c) for c in code_object.co_lnotab[1::2]) - current_line = code_object.co_firstlineno - for line_incr in line_incrs: - current_line += line_incr - yield current_line + # Get the line numbers directly, which are the third entry in the tuples. + # https://peps.python.org/pep-0626/#the-new-co-lines-method-of-code-objects + line_numbers = [entry[2] for entry in code_object.co_lines()] + for line_number in line_numbers: + if line_number is not None: + yield line_number def _GetModuleCodeObjects(module): @@ -154,6 +158,7 @@ def _FindCodeObjectsReferents(module, start_objects, visit_recorder): Returns: List of code objects. """ + def CheckIgnoreCodeObject(code_object): """Checks if the code object can be ignored. @@ -188,14 +193,13 @@ def CheckIgnoreClass(cls): if not cls_module: return False # We can't tell for sure, so explore this class. 
- return ( - cls_module is not module and - getattr(cls_module, '__file__', None) != module.__file__) + return (cls_module is not module and + getattr(cls_module, '__file__', None) != module.__file__) code_objects = set() current = start_objects for obj in current: - visit_recorder.Record(current) + visit_recorder.Record(obj) depth = 0 while current and depth < _MAX_REFERENTS_BFS_DEPTH: @@ -213,7 +217,7 @@ def CheckIgnoreClass(cls): if isinstance(obj, types.CodeType) and CheckIgnoreCodeObject(obj): continue - if isinstance(obj, six.class_types) and CheckIgnoreClass(obj): + if isinstance(obj, type) and CheckIgnoreClass(obj): continue if isinstance(obj, types.CodeType): diff --git a/src/googleclouddebugger/module_search2.py b/src/googleclouddebugger/module_search.py similarity index 99% rename from src/googleclouddebugger/module_search2.py rename to src/googleclouddebugger/module_search.py index f7e5de8..e8d29f3 100644 --- a/src/googleclouddebugger/module_search2.py +++ b/src/googleclouddebugger/module_search.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Inclusive search for module files.""" import os @@ -68,6 +67,7 @@ def Search(path): AssertionError: if the provided path is an absolute path, or if it does not have a .py extension. """ + def SearchCandidates(p): """Generates all candidates for the fuzzy search of p.""" while p: @@ -103,4 +103,3 @@ def SearchCandidates(p): # A matching file was not found in sys.path directories. 
return path - diff --git a/src/googleclouddebugger/module_utils2.py b/src/googleclouddebugger/module_utils.py similarity index 68% rename from src/googleclouddebugger/module_utils2.py rename to src/googleclouddebugger/module_utils.py index 996209f..53f2e37 100644 --- a/src/googleclouddebugger/module_utils2.py +++ b/src/googleclouddebugger/module_utils.py @@ -11,12 +11,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Provides utility functions for module path processing.""" import os import sys +def NormalizePath(path): + """Normalizes a path. + + E.g. One example is it will convert "/a/b/./c" -> "/a/b/c" + """ + # TODO: Calling os.path.normpath "may change the meaning of a + # path that contains symbolic links" (e.g., "A/foo/../B" != "A/B" if foo is a + # symlink). This might cause trouble when matching against loaded module + # paths. We should try to avoid using it. + # Example: + # > import symlink.a + # > symlink.a.__file__ + # symlink/a.py + # > import target.a + # > starget.a.__file__ + # target/a.py + # Python interpreter treats these as two separate modules. So, we also need to + # handle them the same way. + return os.path.normpath(path) + def IsPathSuffix(mod_path, path): """Checks whether path is a full path suffix of mod_path. @@ -29,9 +48,8 @@ def IsPathSuffix(mod_path, path): Returns: True if path is a full path suffix of mod_path. False otherwise. 
""" - return (mod_path.endswith(path) and - (len(mod_path) == len(path) or - mod_path[:-len(path)].endswith(os.sep))) + return (mod_path.endswith(path) and (len(mod_path) == len(path) or + mod_path[:-len(path)].endswith(os.sep))) def GetLoadedModuleBySuffix(path): @@ -71,6 +89,11 @@ def GetLoadedModuleBySuffix(path): if not os.path.isabs(mod_root): mod_root = os.path.join(os.getcwd(), mod_root) + # In the following invocation 'python3 ./main.py' (using the ./), the + # mod_root variable will '/base/path/./main'. In order to correctly compare + # it with the root variable, it needs to be '/base/path/main'. + mod_root = NormalizePath(mod_root) + if IsPathSuffix(mod_root, root): return module diff --git a/src/googleclouddebugger/native_module.cc b/src/googleclouddebugger/native_module.cc index 4a66c4f..60a9a8a 100644 --- a/src/googleclouddebugger/native_module.cc +++ b/src/googleclouddebugger/native_module.cc @@ -176,7 +176,7 @@ static PyObject* LogError(PyObject* self, PyObject* py_args) { } -// Sets a new breakpoint in Python code. The breakpoint may have an optional +// Creates a new breakpoint in Python code. The breakpoint may have an optional // condition to evaluate. When the breakpoint hits (and the condition matches) // a callable object will be invoked from that thread. // @@ -196,7 +196,8 @@ static PyObject* LogError(PyObject* self, PyObject* py_args) { // Returns: // Integer cookie identifying this breakpoint. It needs to be specified when // clearing the breakpoint. 
-static PyObject* SetConditionalBreakpoint(PyObject* self, PyObject* py_args) { +static PyObject* CreateConditionalBreakpoint(PyObject* self, + PyObject* py_args) { PyCodeObject* code_object = nullptr; int line = -1; PyCodeObject* condition = nullptr; @@ -238,7 +239,7 @@ static PyObject* SetConditionalBreakpoint(PyObject* self, PyObject* py_args) { int cookie = -1; - cookie = g_bytecode_breakpoint.SetBreakpoint( + cookie = g_bytecode_breakpoint.CreateBreakpoint( code_object, line, std::bind( @@ -255,11 +256,11 @@ static PyObject* SetConditionalBreakpoint(PyObject* self, PyObject* py_args) { } -// Clears the breakpoint previously set by "SetConditionalBreakpoint". Must be -// called exactly once per each call to "SetConditionalBreakpoint". +// Clears a breakpoint previously created by "CreateConditionalBreakpoint". Must +// be called exactly once per each call to "CreateConditionalBreakpoint". // // Args: -// cookie: breakpoint identifier returned by "SetConditionalBreakpoint". +// cookie: breakpoint identifier returned by "CreateConditionalBreakpoint". static PyObject* ClearConditionalBreakpoint(PyObject* self, PyObject* py_args) { int cookie = -1; if (!PyArg_ParseTuple(py_args, "i", &cookie)) { @@ -271,6 +272,24 @@ static PyObject* ClearConditionalBreakpoint(PyObject* self, PyObject* py_args) { Py_RETURN_NONE; } +// Activates a previously created breakpoint by "CreateConditionalBreakpoint" +// and that haven't been cleared yet using "ClearConditionalBreakpoint". +// TODO: Optimize breakpoint activation by having one method +// "ActivateAllConditionalBreakpoints" for all previously created breakpoints. +// +// Args: +// cookie: breakpoint identifier returned by "CreateConditionalBreakpoint". 
+static PyObject* ActivateConditionalBreakpoint(PyObject* self, + PyObject* py_args) { + int cookie = -1; + if (!PyArg_ParseTuple(py_args, "i", &cookie)) { + return nullptr; + } + + g_bytecode_breakpoint.ActivateBreakpoint(cookie); + + Py_RETURN_NONE; +} // Invokes a Python callable object with immutability tracer. // @@ -369,16 +388,22 @@ static PyMethodDef g_module_functions[] = { "ERROR level logging from Python code." }, { - "SetConditionalBreakpoint", - SetConditionalBreakpoint, + "CreateConditionalBreakpoint", + CreateConditionalBreakpoint, + METH_VARARGS, + "Creates a new breakpoint in Python code." + }, + { + "ActivateConditionalBreakpoint", + ActivateConditionalBreakpoint, METH_VARARGS, - "Sets a new breakpoint in Python code." + "Activates previously created breakpoint in Python code." }, { "ClearConditionalBreakpoint", ClearConditionalBreakpoint, METH_VARARGS, - "Clears previously set breakpoint in Python code." + "Clears previously created breakpoint in Python code." }, { "CallImmutable", diff --git a/src/googleclouddebugger/python_breakpoint.py b/src/googleclouddebugger/python_breakpoint.py index a3ed6bb..62f2512 100644 --- a/src/googleclouddebugger/python_breakpoint.py +++ b/src/googleclouddebugger/python_breakpoint.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Handles a single Python breakpoint.""" from datetime import datetime @@ -19,12 +18,12 @@ import os from threading import Lock -from . import capture_collector +from . import collector from . import cdbg_native as native -from . import imphook2 +from . import imphook from . import module_explorer -from . import module_search2 -from . import module_utils2 +from . import module_search +from . import module_utils # TODO: move to messages.py module. 
# Use the following schema to define breakpoint error message constant: @@ -36,8 +35,7 @@ 'version of the service you are trying to debug.') ERROR_LOCATION_MULTIPLE_MODULES_1 = ( 'Multiple modules matching $0. Please specify the module path.') -ERROR_LOCATION_MULTIPLE_MODULES_3 = ( - 'Multiple modules matching $0 ($1, $2)') +ERROR_LOCATION_MULTIPLE_MODULES_3 = ('Multiple modules matching $0 ($1, $2)') ERROR_LOCATION_MULTIPLE_MODULES_4 = ( 'Multiple modules matching $0 ($1, $2, and $3 more)') ERROR_LOCATION_NO_CODE_FOUND_AT_LINE_2 = 'No code found at line $0 in $1' @@ -54,30 +52,40 @@ 'the snapshot to a less frequently called statement.') ERROR_CONDITION_MUTABLE_0 = ( 'Only immutable expressions can be used in snapshot conditions') -ERROR_AGE_SNAPSHOT_EXPIRED_0 = ( - 'The snapshot has expired') -ERROR_AGE_LOGPOINT_EXPIRED_0 = ( - 'The logpoint has expired') -ERROR_UNSPECIFIED_INTERNAL_ERROR = ( - 'Internal error occurred') +ERROR_AGE_SNAPSHOT_EXPIRED_0 = ('The snapshot has expired') +ERROR_AGE_LOGPOINT_EXPIRED_0 = ('The logpoint has expired') +ERROR_UNSPECIFIED_INTERNAL_ERROR = ('Internal error occurred') # Status messages for different breakpoint events (except of "hit"). 
-_BREAKPOINT_EVENT_STATUS = dict( - [(native.BREAKPOINT_EVENT_ERROR, - {'isError': True, - 'description': {'format': ERROR_UNSPECIFIED_INTERNAL_ERROR}}), - (native.BREAKPOINT_EVENT_GLOBAL_CONDITION_QUOTA_EXCEEDED, - {'isError': True, - 'refersTo': 'BREAKPOINT_CONDITION', - 'description': {'format': ERROR_CONDITION_GLOBAL_QUOTA_EXCEEDED_0}}), - (native.BREAKPOINT_EVENT_BREAKPOINT_CONDITION_QUOTA_EXCEEDED, - {'isError': True, - 'refersTo': 'BREAKPOINT_CONDITION', - 'description': {'format': ERROR_CONDITION_BREAKPOINT_QUOTA_EXCEEDED_0}}), - (native.BREAKPOINT_EVENT_CONDITION_EXPRESSION_MUTABLE, - {'isError': True, - 'refersTo': 'BREAKPOINT_CONDITION', - 'description': {'format': ERROR_CONDITION_MUTABLE_0}})]) +_BREAKPOINT_EVENT_STATUS = dict([ + (native.BREAKPOINT_EVENT_ERROR, { + 'isError': True, + 'description': { + 'format': ERROR_UNSPECIFIED_INTERNAL_ERROR + } + }), + (native.BREAKPOINT_EVENT_GLOBAL_CONDITION_QUOTA_EXCEEDED, { + 'isError': True, + 'refersTo': 'BREAKPOINT_CONDITION', + 'description': { + 'format': ERROR_CONDITION_GLOBAL_QUOTA_EXCEEDED_0 + } + }), + (native.BREAKPOINT_EVENT_BREAKPOINT_CONDITION_QUOTA_EXCEEDED, { + 'isError': True, + 'refersTo': 'BREAKPOINT_CONDITION', + 'description': { + 'format': ERROR_CONDITION_BREAKPOINT_QUOTA_EXCEEDED_0 + } + }), + (native.BREAKPOINT_EVENT_CONDITION_EXPRESSION_MUTABLE, { + 'isError': True, + 'refersTo': 'BREAKPOINT_CONDITION', + 'description': { + 'format': ERROR_CONDITION_MUTABLE_0 + } + }) +]) # The implementation of datetime.strptime imports an undocumented module called # _strptime. If it happens at the wrong time, we can get an exception about @@ -126,20 +134,7 @@ def _MultipleModulesFoundError(path, candidates): def _NormalizePath(path): """Removes surrounding whitespace, leading separator and normalize.""" - # TODO: Calling os.path.normpath "may change the meaning of a - # path that contains symbolic links" (e.g., "A/foo/../B" != "A/B" if foo is a - # symlink). 
This might cause trouble when matching against loaded module - # paths. We should try to avoid using it. - # Example: - # > import symlink.a - # > symlink.a.__file__ - # symlink/a.py - # > import target.a - # > starget.a.__file__ - # target/a.py - # Python interpreter treats these as two separate modules. So, we also need to - # handle them the same way. - return os.path.normpath(path.strip().lstrip(os.sep)) + return module_utils.NormalizePath(path.strip().lstrip(os.sep)) class PythonBreakpoint(object): @@ -186,7 +181,7 @@ def __init__(self, definition, hub_client, breakpoints_manager, self._completed = False if self.definition.get('action') == 'LOG': - self._collector = capture_collector.LogCollector(self.definition) + self._collector = collector.LogCollector(self.definition) path = _NormalizePath(self.definition['location']['path']) @@ -196,7 +191,11 @@ def __init__(self, definition, hub_client, breakpoints_manager, 'status': { 'isError': True, 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', - 'description': {'format': ERROR_LOCATION_FILE_EXTENSION_0}}}) + 'description': { + 'format': ERROR_LOCATION_FILE_EXTENSION_0 + } + } + }) return # A flat init file is too generic; path must include package name. 
@@ -207,18 +206,20 @@ def __init__(self, definition, hub_client, breakpoints_manager, 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', 'description': { 'format': ERROR_LOCATION_MULTIPLE_MODULES_1, - 'parameters': [path]}}}) + 'parameters': [path] + } + } + }) return - new_path = module_search2.Search(path) - new_module = module_utils2.GetLoadedModuleBySuffix(new_path) + new_path = module_search.Search(path) + new_module = module_utils.GetLoadedModuleBySuffix(new_path) if new_module: self._ActivateBreakpoint(new_module) else: - self._import_hook_cleanup = imphook2.AddImportCallbackBySuffix( - new_path, - self._ActivateBreakpoint) + self._import_hook_cleanup = imphook.AddImportCallbackBySuffix( + new_path, self._ActivateBreakpoint) def Clear(self): """Clears the breakpoint and releases all breakpoint resources. @@ -238,16 +239,40 @@ def GetBreakpointId(self): return self.definition['id'] def GetExpirationTime(self): - """Computes the timestamp at which this breakpoint will expire.""" - # TODO: Move this to a common method. - if '.' not in self.definition['createTime']: + """Computes the timestamp at which this breakpoint will expire. + + If no creation time can be found an expiration time in the past will be + used. + """ + return self.GetCreateTime() + self.expiration_period + + def GetCreateTime(self): + """Retrieves the creation time of this breakpoint. + + If no creation time can be found a creation time in the past will be used. + """ + if 'createTime' in self.definition: + return self.GetTimeFromRfc3339Str(self.definition['createTime']) + else: + return self.GetTimeFromUnixMsec( + self.definition.get('createTimeUnixMsec', 0)) + + def GetTimeFromRfc3339Str(self, rfc3339_str): + if '.' 
not in rfc3339_str: fmt = '%Y-%m-%dT%H:%M:%S%Z' else: fmt = '%Y-%m-%dT%H:%M:%S.%f%Z' - create_datetime = datetime.strptime( - self.definition['createTime'].replace('Z', 'UTC'), fmt) - return create_datetime + self.expiration_period + return datetime.strptime(rfc3339_str.replace('Z', 'UTC'), fmt) + + def GetTimeFromUnixMsec(self, unix_msec): + try: + return datetime.fromtimestamp(unix_msec / 1000) + except (TypeError, ValueError, OSError, OverflowError) as e: + native.LogWarning( + 'Unexpected error (%s) occured processing unix_msec %s, breakpoint: %s' + % (repr(e), str(unix_msec), self.GetBreakpointId())) + return datetime.fromtimestamp(0) def ExpireBreakpoint(self): """Expires this breakpoint.""" @@ -263,7 +288,11 @@ def ExpireBreakpoint(self): 'status': { 'isError': True, 'refersTo': 'BREAKPOINT_AGE', - 'description': {'format': message}}}) + 'description': { + 'format': message + } + } + }) def _ActivateBreakpoint(self, module): """Sets the breakpoint in the loaded module, or complete with error.""" @@ -300,16 +329,18 @@ def _ActivateBreakpoint(self, module): 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', 'description': { 'format': fmt, - 'parameters': params}}}) + 'parameters': params + } + } + }) return # Compile the breakpoint condition. condition = None if self.definition.get('condition'): try: - condition = compile(self.definition.get('condition'), - '', - 'eval') + condition = compile( + self.definition.get('condition'), '', 'eval') except (TypeError, ValueError) as e: # condition string contains null bytes. 
self._CompleteBreakpoint({ @@ -318,7 +349,10 @@ def _ActivateBreakpoint(self, module): 'refersTo': 'BREAKPOINT_CONDITION', 'description': { 'format': 'Invalid expression', - 'parameters': [str(e)]}}}) + 'parameters': [str(e)] + } + } + }) return except SyntaxError as e: @@ -328,17 +362,19 @@ def _ActivateBreakpoint(self, module): 'refersTo': 'BREAKPOINT_CONDITION', 'description': { 'format': 'Expression could not be compiled: $0', - 'parameters': [e.msg]}}}) + 'parameters': [e.msg] + } + } + }) return - native.LogInfo('Creating new Python breakpoint %s in %s, line %d' % ( - self.GetBreakpointId(), codeobj, line)) + native.LogInfo('Creating new Python breakpoint %s in %s, line %d' % + (self.GetBreakpointId(), codeobj, line)) + + self._cookie = native.CreateConditionalBreakpoint(codeobj, line, condition, + self._BreakpointEvent) - self._cookie = native.SetConditionalBreakpoint( - codeobj, - line, - condition, - self._BreakpointEvent) + native.ActivateConditionalBreakpoint(self._cookie) def _RemoveImportHook(self): """Removes the import hook if one was installed.""" @@ -395,27 +431,32 @@ def _BreakpointEvent(self, event, frame): self._CompleteBreakpoint({'status': error_status}) return - collector = capture_collector.CaptureCollector( - self.definition, self.data_visibility_policy) + capture_collector = collector.CaptureCollector(self.definition, + self.data_visibility_policy) # TODO: This is a temporary try/except. All exceptions should be # caught inside Collect and converted into breakpoint error messages. 
try: - collector.Collect(frame) + capture_collector.Collect(frame) except BaseException as e: # pylint: disable=broad-except native.LogInfo('Internal error during data capture: %s' % repr(e)) - error_status = {'isError': True, - 'description': { - 'format': ('Internal error while capturing data: %s' % - repr(e))}} + error_status = { + 'isError': True, + 'description': { + 'format': ('Internal error while capturing data: %s' % repr(e)) + } + } self._CompleteBreakpoint({'status': error_status}) return except: # pylint: disable=bare-except native.LogInfo('Unknown exception raised') - error_status = {'isError': True, - 'description': { - 'format': 'Unknown internal error'}} + error_status = { + 'isError': True, + 'description': { + 'format': 'Unknown internal error' + } + } self._CompleteBreakpoint({'status': error_status}) return - self._CompleteBreakpoint(collector.breakpoint, is_incremental=False) + self._CompleteBreakpoint(capture_collector.breakpoint, is_incremental=False) diff --git a/src/googleclouddebugger/python_util.cc b/src/googleclouddebugger/python_util.cc index 90b67ce..bc03bfc 100644 --- a/src/googleclouddebugger/python_util.cc +++ b/src/googleclouddebugger/python_util.cc @@ -23,6 +23,11 @@ #include +#if PY_VERSION_HEX >= 0x030A0000 +#include "../third_party/pylinetable.h" +#endif // PY_VERSION_HEX >= 0x030A0000 + + namespace devtools { namespace cdbg { @@ -32,17 +37,22 @@ static PyObject* g_debuglet_module = nullptr; CodeObjectLinesEnumerator::CodeObjectLinesEnumerator( PyCodeObject* code_object) { +#if PY_VERSION_HEX < 0x030A0000 Initialize(code_object->co_firstlineno, code_object->co_lnotab); +#else + Initialize(code_object->co_firstlineno, code_object->co_linetable); +#endif // PY_VERSION_HEX < 0x030A0000 } CodeObjectLinesEnumerator::CodeObjectLinesEnumerator( int firstlineno, - PyObject* lnotab) { - Initialize(firstlineno, lnotab); + PyObject* linedata) { + Initialize(firstlineno, linedata); } +#if PY_VERSION_HEX < 0x030A0000 void 
CodeObjectLinesEnumerator::Initialize( int firstlineno, PyObject* lnotab) { @@ -69,7 +79,7 @@ bool CodeObjectLinesEnumerator::Next() { while (true) { offset_ += next_entry_[0]; - line_number_ += next_entry_[1]; + line_number_ += static_cast(next_entry_[1]); bool stop = ((next_entry_[0] != 0xFF) || (next_entry_[1] != 0)) && ((next_entry_[0] != 0) || (next_entry_[1] != 0xFF)); @@ -86,7 +96,26 @@ bool CodeObjectLinesEnumerator::Next() { } } } +#else + +void CodeObjectLinesEnumerator::Initialize( + int firstlineno, + PyObject* linetable) { + Py_ssize_t length = PyBytes_Size(linetable); + _PyLineTable_InitAddressRange(PyBytes_AsString(linetable), length, firstlineno, &range_); +} +bool CodeObjectLinesEnumerator::Next() { + while (_PyLineTable_NextAddressRange(&range_)) { + if (range_.ar_line >= 0) { + line_number_ = range_.ar_line; + offset_ = range_.ar_start; + return true; + } + } + return false; +} +#endif // PY_VERSION_HEX < 0x030A0000 PyObject* GetDebugletModule() { DCHECK(g_debuglet_module != nullptr); diff --git a/src/googleclouddebugger/python_util.h b/src/googleclouddebugger/python_util.h index 57b5425..10116be 100644 --- a/src/googleclouddebugger/python_util.h +++ b/src/googleclouddebugger/python_util.h @@ -178,7 +178,7 @@ class CodeObjectLinesEnumerator { explicit CodeObjectLinesEnumerator(PyCodeObject* code_object); // Uses explicitly provided line table. - CodeObjectLinesEnumerator(int firstlineno, PyObject* lnotab); + CodeObjectLinesEnumerator(int firstlineno, PyObject* linedata); // Moves over to the next entry in code object line table. bool Next(); @@ -190,24 +190,31 @@ class CodeObjectLinesEnumerator { int32_t line_number() const { return line_number_; } private: - void Initialize(int firstlineno, PyObject* lnotab); + void Initialize(int firstlineno, PyObject* linedata); private: + // Bytecode offset of the current line. 
+ int32_t offset_; + + // Current source code line number + int32_t line_number_; + +#if PY_VERSION_HEX < 0x030A0000 // Number of remaining entries in line table. int remaining_entries_; // Pointer to the next entry of line table. const uint8_t* next_entry_; - // Bytecode offset of the current line. - int32_t offset_; - - // Current source code line number - int32_t line_number_; +#else + // Current address range in the linetable data. + PyCodeAddressRange range_; +#endif DISALLOW_COPY_AND_ASSIGN(CodeObjectLinesEnumerator); }; + template bool operator== (TPointer* ref1, const ScopedPyObjectT& ref2) { return ref2 == ref1; diff --git a/src/googleclouddebugger/uniquifier_computer.py b/src/googleclouddebugger/uniquifier_computer.py index 873b110..8395f33 100644 --- a/src/googleclouddebugger/uniquifier_computer.py +++ b/src/googleclouddebugger/uniquifier_computer.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Computes a unique identifier of the deployed application. When the application runs under AppEngine, the deployment is uniquely @@ -28,7 +27,6 @@ import os import sys - # Maximum recursion depth to follow when traversing the file system. This limit # will prevent stack overflow in case of a loop created by symbolic links. 
_MAX_DEPTH = 10 @@ -93,8 +91,7 @@ def ProcessDirectory(path, relative_path, depth=1): modules.add(file_name) ProcessApplicationFile(current_path, os.path.join(relative_path, name)) elif IsPackage(current_path): - ProcessDirectory(current_path, - os.path.join(relative_path, name), + ProcessDirectory(current_path, os.path.join(relative_path, name), depth + 1) def IsPackage(path): diff --git a/src/googleclouddebugger/version.py b/src/googleclouddebugger/version.py index cb89582..3b0f00f 100644 --- a/src/googleclouddebugger/version.py +++ b/src/googleclouddebugger/version.py @@ -4,4 +4,4 @@ # The major version should only change on breaking changes. Minor version # changes go between regular updates. Instances running debuggers with # different major versions will show up as two different debuggees. -__version__ = '2.18' +__version__ = '4.1' diff --git a/src/googleclouddebugger/yaml_data_visibility_config_reader.py b/src/googleclouddebugger/yaml_data_visibility_config_reader.py index 198af80..dc75673 100644 --- a/src/googleclouddebugger/yaml_data_visibility_config_reader.py +++ b/src/googleclouddebugger/yaml_data_visibility_config_reader.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Reads a YAML configuration file to determine visibility policy. 
Example Usage: @@ -27,7 +26,6 @@ import os import sys -import six import yaml @@ -114,8 +112,7 @@ def Read(f): try: return Config( - yaml_data.get('blacklist', ()), - yaml_data.get('whitelist', ('*'))) + yaml_data.get('blacklist', ()), yaml_data.get('whitelist', ('*'))) except UnicodeDecodeError as e: raise YAMLLoadError('%s' % e) @@ -125,10 +122,10 @@ def _CheckData(yaml_data): legal_keys = set(('blacklist', 'whitelist')) unknown_keys = set(yaml_data) - legal_keys if unknown_keys: - raise UnknownConfigKeyError( - 'Unknown keys in configuration: %s' % unknown_keys) + raise UnknownConfigKeyError('Unknown keys in configuration: %s' % + unknown_keys) - for key, data in six.iteritems(yaml_data): + for key, data in yaml_data.items(): _AssertDataIsList(key, data) diff --git a/src/setup.py b/src/setup.py index c6a1c4d..25f6095 100644 --- a/src/setup.py +++ b/src/setup.py @@ -11,15 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """Python Cloud Debugger build and packaging script.""" -# pylint: disable=g-statement-before-imports,g-import-not-at-top -try: - from ConfigParser import ConfigParser # Python 2 -except ImportError: - from configparser import ConfigParser # Python 3 -# pylint: enable=g-statement-before-imports,g-import-not-at-top +from configparser import ConfigParser from glob import glob import os import re @@ -54,8 +48,7 @@ def ReadConfig(section, value, default): 'For more details please see ' 'https://github.com/GoogleCloudPlatform/cloud-debug-python\n') -lib_dirs = ReadConfig('build_ext', - 'library_dirs', +lib_dirs = ReadConfig('build_ext', 'library_dirs', sysconfig.get_config_var('LIBDIR')).split(':') extra_compile_args = ReadConfig('cc_options', 'extra_compile_args', '').split() extra_link_args = ReadConfig('cc_options', 'extra_link_args', '').split() @@ -70,9 +63,10 @@ def ReadConfig(section, value, default): assert len(static_libs) == len(deps), (static_libs, deps, lib_dirs) cvars = sysconfig.get_config_vars() -cvars['OPT'] = str.join(' ', RemovePrefixes( - cvars.get('OPT').split(), - ['-g', '-O', '-Wstrict-prototypes'])) +cvars['OPT'] = str.join( + ' ', + RemovePrefixes( + cvars.get('OPT').split(), ['-g', '-O', '-Wstrict-prototypes'])) # Determine the current version of the package. The easiest way would be to # import "googleclouddebugger" and read its __version__ attribute. 
@@ -107,27 +101,19 @@ def ReadConfig(section, value, default): author='Google Inc.', version=version, install_requires=[ - 'enum34; python_version < "3.4"', - 'google-api-python-client==1.8.4; python_version < "3.0"', - 'google-api-python-client; python_version > "3.0"', - 'google-auth==1.8.2; python_version < "3.0"', - 'google-auth>=1.0.0; python_version > "3.0"', - 'google-auth-httplib2', - 'google-api-core==1.15.0; python_version < "3.0"', - 'google-api-core; python_version > "3.0"', + 'firebase-admin>=5.3.0', 'pyyaml', - 'six>=1.10.0', ], packages=['googleclouddebugger'], ext_modules=[cdbg_native_module], license='Apache License, Version 2.0', keywords='google cloud debugger', classifiers=[ - 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', 'Development Status :: 3 - Alpha', 'Intended Audience :: Developers', ]) diff --git a/src/third_party/BUILD b/src/third_party/BUILD new file mode 100644 index 0000000..bcce1e2 --- /dev/null +++ b/src/third_party/BUILD @@ -0,0 +1,7 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "pylinetable", + hdrs = ["pylinetable.h"], +) + diff --git a/src/third_party/pylinetable.h b/src/third_party/pylinetable.h new file mode 100644 index 0000000..ea44c64 --- /dev/null +++ b/src/third_party/pylinetable.h @@ -0,0 +1,210 @@ +/** + * Copyright (c) 2001-2023 Python Software Foundation; All Rights Reserved + * + * You may obtain a copy of the PSF License at + * + * https://docs.python.org/3/license.html + */ + +#ifndef DEVTOOLS_CDBG_DEBUGLETS_PYTHON_PYLINETABLE_H_ +#define DEVTOOLS_CDBG_DEBUGLETS_PYTHON_PYLINETABLE_H_ + +/* Python Linetable helper methods. + * They are not part of the cpython api. 
+ * This code has been extracted from: + * https://github.com/python/cpython/blob/main/Objects/codeobject.c + * + * See https://peps.python.org/pep-0626/#out-of-process-debuggers-and-profilers + * for more information about this code and its usage. + */ + +#if PY_VERSION_HEX >= 0x030B0000 +// Things are different in 3.11 than 3.10. +// See https://github.com/python/cpython/blob/main/Objects/locations.md + +typedef enum _PyCodeLocationInfoKind { + /* short forms are 0 to 9 */ + PY_CODE_LOCATION_INFO_SHORT0 = 0, + /* one lineforms are 10 to 12 */ + PY_CODE_LOCATION_INFO_ONE_LINE0 = 10, + PY_CODE_LOCATION_INFO_ONE_LINE1 = 11, + PY_CODE_LOCATION_INFO_ONE_LINE2 = 12, + + PY_CODE_LOCATION_INFO_NO_COLUMNS = 13, + PY_CODE_LOCATION_INFO_LONG = 14, + PY_CODE_LOCATION_INFO_NONE = 15 +} _PyCodeLocationInfoKind; + +/** Out of process API for initializing the location table. */ +extern void _PyLineTable_InitAddressRange( + const char *linetable, + Py_ssize_t length, + int firstlineno, + PyCodeAddressRange *range); + +/** API for traversing the line number table. 
*/ +extern int _PyLineTable_NextAddressRange(PyCodeAddressRange *range); + + +void _PyLineTable_InitAddressRange(const char *linetable, Py_ssize_t length, int firstlineno, PyCodeAddressRange *range) { + range->opaque.lo_next = linetable; + range->opaque.limit = range->opaque.lo_next + length; + range->ar_start = -1; + range->ar_end = 0; + range->opaque.computed_line = firstlineno; + range->ar_line = -1; +} + +static int +scan_varint(const uint8_t *ptr) +{ + unsigned int read = *ptr++; + unsigned int val = read & 63; + unsigned int shift = 0; + while (read & 64) { + read = *ptr++; + shift += 6; + val |= (read & 63) << shift; + } + return val; +} + +static int +scan_signed_varint(const uint8_t *ptr) +{ + unsigned int uval = scan_varint(ptr); + if (uval & 1) { + return -(int)(uval >> 1); + } + else { + return uval >> 1; + } +} + +static int +get_line_delta(const uint8_t *ptr) +{ + int code = ((*ptr) >> 3) & 15; + switch (code) { + case PY_CODE_LOCATION_INFO_NONE: + return 0; + case PY_CODE_LOCATION_INFO_NO_COLUMNS: + case PY_CODE_LOCATION_INFO_LONG: + return scan_signed_varint(ptr+1); + case PY_CODE_LOCATION_INFO_ONE_LINE0: + return 0; + case PY_CODE_LOCATION_INFO_ONE_LINE1: + return 1; + case PY_CODE_LOCATION_INFO_ONE_LINE2: + return 2; + default: + /* Same line */ + return 0; + } +} + +static int +is_no_line_marker(uint8_t b) +{ + return (b >> 3) == 0x1f; +} + + +#define ASSERT_VALID_BOUNDS(bounds) \ + assert(bounds->opaque.lo_next <= bounds->opaque.limit && \ + (bounds->ar_line == -1 || bounds->ar_line == bounds->opaque.computed_line) && \ + (bounds->opaque.lo_next == bounds->opaque.limit || \ + (*bounds->opaque.lo_next) & 128)) + +static int +next_code_delta(PyCodeAddressRange *bounds) +{ + assert((*bounds->opaque.lo_next) & 128); + return (((*bounds->opaque.lo_next) & 7) + 1) * sizeof(_Py_CODEUNIT); +} + +static void +advance(PyCodeAddressRange *bounds) +{ + ASSERT_VALID_BOUNDS(bounds); + bounds->opaque.computed_line += 
get_line_delta(reinterpret_cast(bounds->opaque.lo_next)); + if (is_no_line_marker(*bounds->opaque.lo_next)) { + bounds->ar_line = -1; + } + else { + bounds->ar_line = bounds->opaque.computed_line; + } + bounds->ar_start = bounds->ar_end; + bounds->ar_end += next_code_delta(bounds); + do { + bounds->opaque.lo_next++; + } while (bounds->opaque.lo_next < bounds->opaque.limit && + ((*bounds->opaque.lo_next) & 128) == 0); + ASSERT_VALID_BOUNDS(bounds); +} + +static inline int +at_end(PyCodeAddressRange *bounds) { + return bounds->opaque.lo_next >= bounds->opaque.limit; +} + +int +_PyLineTable_NextAddressRange(PyCodeAddressRange *range) +{ + if (at_end(range)) { + return 0; + } + advance(range); + assert(range->ar_end > range->ar_start); + return 1; +} +#elif PY_VERSION_HEX >= 0x030A0000 +void +_PyLineTable_InitAddressRange(const char *linetable, Py_ssize_t length, int firstlineno, PyCodeAddressRange *range) +{ + range->opaque.lo_next = linetable; + range->opaque.limit = range->opaque.lo_next + length; + range->ar_start = -1; + range->ar_end = 0; + range->opaque.computed_line = firstlineno; + range->ar_line = -1; +} + +static void +advance(PyCodeAddressRange *bounds) +{ + bounds->ar_start = bounds->ar_end; + int delta = ((unsigned char *)bounds->opaque.lo_next)[0]; + bounds->ar_end += delta; + int ldelta = ((signed char *)bounds->opaque.lo_next)[1]; + bounds->opaque.lo_next += 2; + if (ldelta == -128) { + bounds->ar_line = -1; + } + else { + bounds->opaque.computed_line += ldelta; + bounds->ar_line = bounds->opaque.computed_line; + } +} + +static inline int +at_end(PyCodeAddressRange *bounds) { + return bounds->opaque.lo_next >= bounds->opaque.limit; +} + +int +_PyLineTable_NextAddressRange(PyCodeAddressRange *range) +{ + if (at_end(range)) { + return 0; + } + advance(range); + while (range->ar_start == range->ar_end) { + assert(!at_end(range)); + advance(range); + } + return 1; +} +#endif + +#endif // DEVTOOLS_CDBG_DEBUGLETS_PYTHON_PYLINETABLE_H_ diff --git 
a/tests/cpp/BUILD b/tests/cpp/BUILD new file mode 100644 index 0000000..7536fe9 --- /dev/null +++ b/tests/cpp/BUILD @@ -0,0 +1,10 @@ +package(default_visibility = ["//visibility:public"]) + +cc_test( + name = "bytecode_manipulator_test", + srcs = ["bytecode_manipulator_test.cc"], + deps = [ + "//src/googleclouddebugger:bytecode_manipulator", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tests/cpp/bytecode_manipulator_test.cc b/tests/cpp/bytecode_manipulator_test.cc new file mode 100644 index 0000000..934dfef --- /dev/null +++ b/tests/cpp/bytecode_manipulator_test.cc @@ -0,0 +1,1059 @@ +/** + * Copyright 2023 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "src/googleclouddebugger/bytecode_manipulator.h" + +#include +#include + +namespace devtools { +namespace cdbg { + +static std::string FormatOpcode(uint8_t opcode) { + switch (opcode) { + case POP_TOP: return "POP_TOP"; + case ROT_TWO: return "ROT_TWO"; + case ROT_THREE: return "ROT_THREE"; + case DUP_TOP: return "DUP_TOP"; + case NOP: return "NOP"; + case UNARY_POSITIVE: return "UNARY_POSITIVE"; + case UNARY_NEGATIVE: return "UNARY_NEGATIVE"; + case UNARY_NOT: return "UNARY_NOT"; + case UNARY_INVERT: return "UNARY_INVERT"; + case BINARY_POWER: return "BINARY_POWER"; + case BINARY_MULTIPLY: return "BINARY_MULTIPLY"; + case BINARY_MODULO: return "BINARY_MODULO"; + case BINARY_ADD: return "BINARY_ADD"; + case BINARY_SUBTRACT: return "BINARY_SUBTRACT"; + case BINARY_SUBSCR: return "BINARY_SUBSCR"; + case BINARY_FLOOR_DIVIDE: return "BINARY_FLOOR_DIVIDE"; + case BINARY_TRUE_DIVIDE: return "BINARY_TRUE_DIVIDE"; + case INPLACE_FLOOR_DIVIDE: return "INPLACE_FLOOR_DIVIDE"; + case INPLACE_TRUE_DIVIDE: return "INPLACE_TRUE_DIVIDE"; + case INPLACE_ADD: return "INPLACE_ADD"; + case INPLACE_SUBTRACT: return "INPLACE_SUBTRACT"; + case INPLACE_MULTIPLY: return "INPLACE_MULTIPLY"; + case INPLACE_MODULO: return "INPLACE_MODULO"; + case STORE_SUBSCR: return "STORE_SUBSCR"; + case DELETE_SUBSCR: return "DELETE_SUBSCR"; + case BINARY_LSHIFT: return "BINARY_LSHIFT"; + case BINARY_RSHIFT: return "BINARY_RSHIFT"; + case BINARY_AND: return "BINARY_AND"; + case BINARY_XOR: return "BINARY_XOR"; + case BINARY_OR: return "BINARY_OR"; + case INPLACE_POWER: return "INPLACE_POWER"; + case GET_ITER: return "GET_ITER"; + case PRINT_EXPR: return "PRINT_EXPR"; + case INPLACE_LSHIFT: return "INPLACE_LSHIFT"; + case INPLACE_RSHIFT: return "INPLACE_RSHIFT"; + case INPLACE_AND: return "INPLACE_AND"; + case INPLACE_XOR: return "INPLACE_XOR"; + case INPLACE_OR: return "INPLACE_OR"; + case RETURN_VALUE: return "RETURN_VALUE"; + case IMPORT_STAR: return "IMPORT_STAR"; + case YIELD_VALUE: 
return "YIELD_VALUE"; + case POP_BLOCK: return "POP_BLOCK"; +#if PY_VERSION_HEX <= 0x03080000 + case END_FINALLY: return "END_FINALLY"; +#endif + case STORE_NAME: return "STORE_NAME"; + case DELETE_NAME: return "DELETE_NAME"; + case UNPACK_SEQUENCE: return "UNPACK_SEQUENCE"; + case FOR_ITER: return "FOR_ITER"; + case LIST_APPEND: return "LIST_APPEND"; + case STORE_ATTR: return "STORE_ATTR"; + case DELETE_ATTR: return "DELETE_ATTR"; + case STORE_GLOBAL: return "STORE_GLOBAL"; + case DELETE_GLOBAL: return "DELETE_GLOBAL"; + case LOAD_CONST: return "LOAD_CONST"; + case LOAD_NAME: return "LOAD_NAME"; + case BUILD_TUPLE: return "BUILD_TUPLE"; + case BUILD_LIST: return "BUILD_LIST"; + case BUILD_SET: return "BUILD_SET"; + case BUILD_MAP: return "BUILD_MAP"; + case LOAD_ATTR: return "LOAD_ATTR"; + case COMPARE_OP: return "COMPARE_OP"; + case IMPORT_NAME: return "IMPORT_NAME"; + case IMPORT_FROM: return "IMPORT_FROM"; + case JUMP_FORWARD: return "JUMP_FORWARD"; + case JUMP_IF_FALSE_OR_POP: return "JUMP_IF_FALSE_OR_POP"; + case JUMP_IF_TRUE_OR_POP: return "JUMP_IF_TRUE_OR_POP"; + case JUMP_ABSOLUTE: return "JUMP_ABSOLUTE"; + case POP_JUMP_IF_FALSE: return "POP_JUMP_IF_FALSE"; + case POP_JUMP_IF_TRUE: return "POP_JUMP_IF_TRUE"; + case LOAD_GLOBAL: return "LOAD_GLOBAL"; + case SETUP_FINALLY: return "SETUP_FINALLY"; + case LOAD_FAST: return "LOAD_FAST"; + case STORE_FAST: return "STORE_FAST"; + case DELETE_FAST: return "DELETE_FAST"; + case RAISE_VARARGS: return "RAISE_VARARGS"; + case CALL_FUNCTION: return "CALL_FUNCTION"; + case MAKE_FUNCTION: return "MAKE_FUNCTION"; + case BUILD_SLICE: return "BUILD_SLICE"; + case LOAD_CLOSURE: return "LOAD_CLOSURE"; + case LOAD_DEREF: return "LOAD_DEREF"; + case STORE_DEREF: return "STORE_DEREF"; + case CALL_FUNCTION_KW: return "CALL_FUNCTION_KW"; + case SETUP_WITH: return "SETUP_WITH"; + case EXTENDED_ARG: return "EXTENDED_ARG"; + case SET_ADD: return "SET_ADD"; + case MAP_ADD: return "MAP_ADD"; +#if PY_VERSION_HEX < 0x03080000 + case 
BREAK_LOOP: return "BREAK_LOOP"; + case CONTINUE_LOOP: return "CONTINUE_LOOP"; + case SETUP_LOOP: return "SETUP_LOOP"; + case SETUP_EXCEPT: return "SETUP_EXCEPT"; +#endif + case DUP_TOP_TWO: return "DUP_TOP_TWO"; + case BINARY_MATRIX_MULTIPLY: return "BINARY_MATRIX_MULTIPLY"; + case INPLACE_MATRIX_MULTIPLY: return "INPLACE_MATRIX_MULTIPLY"; + case GET_AITER: return "GET_AITER"; + case GET_ANEXT: return "GET_ANEXT"; + case BEFORE_ASYNC_WITH: return "BEFORE_ASYNC_WITH"; + case GET_YIELD_FROM_ITER: return "GET_YIELD_FROM_ITER"; + case LOAD_BUILD_CLASS: return "LOAD_BUILD_CLASS"; + case YIELD_FROM: return "YIELD_FROM"; + case GET_AWAITABLE: return "GET_AWAITABLE"; +#if PY_VERSION_HEX <= 0x03080000 + case WITH_CLEANUP_START: return "WITH_CLEANUP_START"; + case WITH_CLEANUP_FINISH: return "WITH_CLEANUP_FINISH"; +#endif + case SETUP_ANNOTATIONS: return "SETUP_ANNOTATIONS"; + case POP_EXCEPT: return "POP_EXCEPT"; + case UNPACK_EX: return "UNPACK_EX"; +#if PY_VERSION_HEX < 0x03070000 + case STORE_ANNOTATION: return "STORE_ANNOTATION"; +#endif + case CALL_FUNCTION_EX: return "CALL_FUNCTION_EX"; + case LOAD_CLASSDEREF: return "LOAD_CLASSDEREF"; +#if PY_VERSION_HEX <= 0x03080000 + case BUILD_LIST_UNPACK: return "BUILD_LIST_UNPACK"; + case BUILD_MAP_UNPACK: return "BUILD_MAP_UNPACK"; + case BUILD_MAP_UNPACK_WITH_CALL: return "BUILD_MAP_UNPACK_WITH_CALL"; + case BUILD_TUPLE_UNPACK: return "BUILD_TUPLE_UNPACK"; + case BUILD_SET_UNPACK: return "BUILD_SET_UNPACK"; +#endif + case SETUP_ASYNC_WITH: return "SETUP_ASYNC_WITH"; + case FORMAT_VALUE: return "FORMAT_VALUE"; + case BUILD_CONST_KEY_MAP: return "BUILD_CONST_KEY_MAP"; + case BUILD_STRING: return "BUILD_STRING"; +#if PY_VERSION_HEX <= 0x03080000 + case BUILD_TUPLE_UNPACK_WITH_CALL: return "BUILD_TUPLE_UNPACK_WITH_CALL"; +#endif +#if PY_VERSION_HEX >= 0x03070000 + case LOAD_METHOD: return "LOAD_METHOD"; + case CALL_METHOD: return "CALL_METHOD"; +#endif +#if PY_VERSION_HEX >= 0x03080000 && PY_VERSION_HEX < 0x03090000 + case 
BEGIN_FINALLY: return "BEGIN_FINALLY"; + case POP_FINALLY: return "POP_FINALLY"; +#endif +#if PY_VERSION_HEX >= 0x03080000 + case ROT_FOUR: return "ROT_FOUR"; + case END_ASYNC_FOR: return "END_ASYNC_FOR"; +#endif +#if PY_VERSION_HEX >= 0x03080000 && PY_VERSION_HEX < 0x03090000 + // Added in Python 3.8 and removed in 3.9 + case CALL_FINALLY: return "CALL_FINALLY"; +#endif +#if PY_VERSION_HEX >= 0x03090000 + case RERAISE: return "RERAISE"; + case WITH_EXCEPT_START: return "WITH_EXCEPT_START"; + case LOAD_ASSERTION_ERROR: return "LOAD_ASSERTION_ERROR"; + case LIST_TO_TUPLE: return "LIST_TO_TUPLE"; + case IS_OP: return "IS_OP"; + case CONTAINS_OP: return "CONTAINS_OP"; + case JUMP_IF_NOT_EXC_MATCH: return "JUMP_IF_NOT_EXC_MATCH"; + case LIST_EXTEND: return "LIST_EXTEND"; + case SET_UPDATE: return "SET_UPDATE"; + case DICT_MERGE: return "DICT_MERGE"; + case DICT_UPDATE: return "DICT_UPDATE"; +#endif + + default: return std::to_string(static_cast(opcode)); + } +} + +static std::string FormatBytecode(const std::vector& bytecode, + int indent) { + std::string rc; + int remaining_argument_bytes = 0; + for (auto it = bytecode.begin(); it != bytecode.end(); ++it) { + std::string line; + if (remaining_argument_bytes == 0) { + line = FormatOpcode(*it); + remaining_argument_bytes = 1; + } else { + line = std::to_string(static_cast(*it)); + --remaining_argument_bytes; + } + + if (it < bytecode.end() - 1) { + line += ','; + } + + line.resize(20, ' '); + line += "// offset "; + line += std::to_string(it - bytecode.begin()); + line += '.'; + + rc += std::string(indent, ' '); + rc += line; + + if (it < bytecode.end() - 1) { + rc += '\n'; + } + } + + return rc; +} + +static void VerifyBytecode(const BytecodeManipulator& bytecode_manipulator, + std::vector expected_bytecode) { + EXPECT_EQ(expected_bytecode, bytecode_manipulator.bytecode()) + << "Actual bytecode:\n" + << " {\n" + << FormatBytecode(bytecode_manipulator.bytecode(), 10) << "\n" + << " }"; +} + +static void
VerifyLineNumbersTable( + const BytecodeManipulator& bytecode_manipulator, + std::vector expected_linedata) { + // Convert to integers for better logging by EXPECT_EQ. + std::vector expected(expected_linedata.begin(), expected_linedata.end()); + std::vector actual( + bytecode_manipulator.linedata().begin(), + bytecode_manipulator.linedata().end()); + + EXPECT_EQ(expected, actual); +} + +TEST(BytecodeManipulatorTest, EmptyBytecode) { + BytecodeManipulator instance({}, false, {}); + EXPECT_FALSE(instance.InjectMethodCall(0, 0)); +} + + +TEST(BytecodeManipulatorTest, HasLineNumbersTable) { + BytecodeManipulator instance1({}, false, {}); + EXPECT_FALSE(instance1.has_linedata()); + + BytecodeManipulator instance2({}, true, {}); + EXPECT_TRUE(instance2.has_linedata()); +} + + + + +TEST(BytecodeManipulatorTest, InsertionSimple) { + BytecodeManipulator instance({ NOP, 0, RETURN_VALUE, 0 }, false, {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 47)); + + VerifyBytecode( + instance, + { + NOP, // offset 0. + 0, // offset 1. + LOAD_CONST, // offset 2. + 47, // offset 3. + CALL_FUNCTION, // offset 4. + 0, // offset 5. + POP_TOP, // offset 6. + 0, // offset 7. + RETURN_VALUE, // offset 8. + 0 // offset 9. + }); +} + + +TEST(BytecodeManipulatorTest, InsertionExtended) { + BytecodeManipulator instance({ NOP, 0, RETURN_VALUE, 0 }, false, {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 0x12345678)); + + VerifyBytecode( + instance, + { + NOP, // offset 0. + 0, // offset 1. + EXTENDED_ARG, // offset 2. + 0x12, // offset 3. + EXTENDED_ARG, // offset 4. + 0x34, // offset 5. + EXTENDED_ARG, // offset 6. + 0x56, // offset 7. + LOAD_CONST, // offset 8. + 0x78, // offset 9. + CALL_FUNCTION, // offset 10. + 0, // offset 11. + POP_TOP, // offset 12. + 0, // offset 13. + RETURN_VALUE, // offset 14. + 0 // offset 15.
+ }); +} + + +TEST(BytecodeManipulatorTest, InsertionBeginning) { + BytecodeManipulator instance({ NOP, 0, RETURN_VALUE, 0 }, false, {}); + ASSERT_TRUE(instance.InjectMethodCall(0, 47)); + + VerifyBytecode( + instance, + { + LOAD_CONST, // offset 0. + 47, // offset 1. + CALL_FUNCTION, // offset 2. + 0, // offset 3. + POP_TOP, // offset 4. + 0, // offset 5. + NOP, // offset 6. + 0, // offset 7. + RETURN_VALUE, // offset 8. + 0 // offset 9. + }); +} + + +TEST(BytecodeManipulatorTest, InsertionOffsetUpdates) { + BytecodeManipulator instance( + { + JUMP_FORWARD, + 12, + NOP, + 0, + JUMP_ABSOLUTE, + 34, + }, + false, + {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 47)); + +#if PY_VERSION_HEX >= 0x030A0000 + // Jump offsets are instruction offsets, not byte offsets. + VerifyBytecode( + instance, + { + JUMP_FORWARD, // offset 0. + 12 + 3, // offset 1. + LOAD_CONST, // offset 2. + 47, // offset 3. + CALL_FUNCTION, // offset 4. + 0, // offset 5. + POP_TOP, // offset 6. + 0, // offset 7. + NOP, // offset 8. + 0, // offset 9. + JUMP_ABSOLUTE, // offset 10. + 34 + 3 // offset 11. + }); +#else + VerifyBytecode( + instance, + { + JUMP_FORWARD, // offset 0. + 12 + 6, // offset 1. + LOAD_CONST, // offset 2. + 47, // offset 3. + CALL_FUNCTION, // offset 4. + 0, // offset 5. + POP_TOP, // offset 6. + 0, // offset 7. + NOP, // offset 8. + 0, // offset 9. + JUMP_ABSOLUTE, // offset 10. + 34 + 6 // offset 11. + }); +#endif +} + + +TEST(BytecodeManipulatorTest, InsertionExtendedOffsetUpdates) { + BytecodeManipulator instance( + { + EXTENDED_ARG, + 12, + EXTENDED_ARG, + 34, + EXTENDED_ARG, + 56, + JUMP_FORWARD, + 78, + NOP, + 0, + EXTENDED_ARG, + 98, + EXTENDED_ARG, + 76, + EXTENDED_ARG, + 54, + JUMP_ABSOLUTE, + 32 + }, + false, + {}); + ASSERT_TRUE(instance.InjectMethodCall(8, 11)); + +#if PY_VERSION_HEX >= 0x030A0000 + // Jump offsets are instruction offsets, not byte offsets. + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 12, // offset 1. 
+ EXTENDED_ARG, // offset 2. + 34, // offset 3. + EXTENDED_ARG, // offset 4. + 56, // offset 5. + JUMP_FORWARD, // offset 6. + 78 + 3, // offset 7. + LOAD_CONST, // offset 8. + 11, // offset 9. + CALL_FUNCTION, // offset 10. + 0, // offset 11. + POP_TOP, // offset 12. + 0, // offset 13. + NOP, // offset 14. + 0, // offset 15. + EXTENDED_ARG, // offset 16. + 98, // offset 17. + EXTENDED_ARG, // offset 18. + 76, // offset 19. + EXTENDED_ARG, // offset 20. + 54, // offset 21. + JUMP_ABSOLUTE, // offset 22. + 32 + 3 // offset 23. + }); +#else + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 12, // offset 1. + EXTENDED_ARG, // offset 2. + 34, // offset 3. + EXTENDED_ARG, // offset 4. + 56, // offset 5. + JUMP_FORWARD, // offset 6. + 78 + 6, // offset 7. + LOAD_CONST, // offset 8. + 11, // offset 9. + CALL_FUNCTION, // offset 10. + 0, // offset 11. + POP_TOP, // offset 12. + 0, // offset 13. + NOP, // offset 14. + 0, // offset 15. + EXTENDED_ARG, // offset 16. + 98, // offset 17. + EXTENDED_ARG, // offset 18. + 76, // offset 19. + EXTENDED_ARG, // offset 20. + 54, // offset 21. + JUMP_ABSOLUTE, // offset 22. + 32 + 6 // offset 23. + }); +#endif +} + + +TEST(BytecodeManipulatorTest, InsertionDeltaOffsetNoUpdate) { + BytecodeManipulator instance( + { + JUMP_FORWARD, + 2, + NOP, + 0, + RETURN_VALUE, + 0, + JUMP_FORWARD, + 2, + }, + false, {}); + ASSERT_TRUE(instance.InjectMethodCall(4, 99)); + + VerifyBytecode( + instance, + { + JUMP_FORWARD, // offset 0. + 2, // offset 1. + NOP, // offset 2. + 0, // offset 3. + LOAD_CONST, // offset 4. + 99, // offset 5. + CALL_FUNCTION, // offset 6. + 0, // offset 7. + POP_TOP, // offset 8. + 0, // offset 9. + RETURN_VALUE, // offset 10. + 0, // offset 11. + JUMP_FORWARD, // offset 12. + 2 // offset 13. 
+ }); +} + + +TEST(BytecodeManipulatorTest, InsertionAbsoluteOffsetNoUpdate) { + BytecodeManipulator instance( + { + JUMP_ABSOLUTE, + 2, + RETURN_VALUE, + 0 + }, + false, + {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 99)); + + VerifyBytecode( + instance, + { + JUMP_ABSOLUTE, // offset 0. + 2, // offset 1. + LOAD_CONST, // offset 2. + 99, // offset 3. + CALL_FUNCTION, // offset 4. + 0, // offset 5. + POP_TOP, // offset 6. + 0, // offset 7. + RETURN_VALUE, // offset 8. + 0 // offset 9. + }); +} + + +TEST(BytecodeManipulatorTest, InsertionOffsetUneededExtended) { + BytecodeManipulator instance( + { EXTENDED_ARG, 0, JUMP_FORWARD, 2, NOP, 0 }, + false, + {}); + ASSERT_TRUE(instance.InjectMethodCall(4, 11)); + +#if PY_VERSION_HEX >= 0x030A0000 + // Jump offsets are instruction offsets, not byte offsets. + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 0, // offset 1. + JUMP_FORWARD, // offset 2. + 2 + 3, // offset 3. + LOAD_CONST, // offset 4. + 11, // offset 5. + CALL_FUNCTION, // offset 6. + 0, // offset 7. + POP_TOP, // offset 8. + 0, // offset 9. + NOP, // offset 10. + 0 // offset 11. + }); +#else + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 0, // offset 1. + JUMP_FORWARD, // offset 2. + 2 + 6, // offset 3. + LOAD_CONST, // offset 4. + 11, // offset 5. + CALL_FUNCTION, // offset 6. + 0, // offset 7. + POP_TOP, // offset 8. + 0, // offset 9. + NOP, // offset 10. + 0 // offset 11. + }); +#endif +} + + +TEST(BytecodeManipulatorTest, InsertionOffsetUpgradeExtended) { + BytecodeManipulator instance({ JUMP_ABSOLUTE, 254 , NOP, 0 }, false, {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 11)); + +#if PY_VERSION_HEX >= 0x030A0000 + // Jump offsets are instruction offsets, not byte offsets. + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 1, // offset 1. + JUMP_ABSOLUTE, // offset 2. + 2, // offset 3. + LOAD_CONST, // offset 4. + 11, // offset 5. + CALL_FUNCTION, // offset 6. + 0, // offset 7. + POP_TOP, // offset 8. 
+ 0, // offset 9. + NOP, // offset 10. + 0 // offset 11. + }); +#else + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 1, // offset 1. + JUMP_ABSOLUTE, // offset 2. + 6, // offset 3. + LOAD_CONST, // offset 4. + 11, // offset 5. + CALL_FUNCTION, // offset 6. + 0, // offset 7. + POP_TOP, // offset 8. + 0, // offset 9. + NOP, // offset 10. + 0 // offset 11. + }); +#endif +} + + +TEST(BytecodeManipulatorTest, InsertionOffsetUpgradeExtendedTwice) { + BytecodeManipulator instance( + { JUMP_ABSOLUTE, 252, JUMP_ABSOLUTE, 254, NOP, 0 }, + false, + {}); + ASSERT_TRUE(instance.InjectMethodCall(4, 12)); + +#if PY_VERSION_HEX >= 0x030A0000 + // Jump offsets are instruction offsets, not byte offsets. + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 1, // offset 1. + JUMP_ABSOLUTE, // offset 2. + 1, // offset 3. + EXTENDED_ARG, // offset 4. + 1, // offset 5. + JUMP_ABSOLUTE, // offset 6. + 3, // offset 7. + LOAD_CONST, // offset 8. + 12, // offset 9. + CALL_FUNCTION, // offset 10. + 0, // offset 11. + POP_TOP, // offset 12. + 0, // offset 13. + NOP, // offset 14. + 0 // offset 15. + }); +#else + VerifyBytecode( + instance, + { + EXTENDED_ARG, // offset 0. + 1, // offset 1. + JUMP_ABSOLUTE, // offset 2. + 6, // offset 3. + EXTENDED_ARG, // offset 4. + 1, // offset 5. + JUMP_ABSOLUTE, // offset 6. + 8, // offset 7. + LOAD_CONST, // offset 8. + 12, // offset 9. + CALL_FUNCTION, // offset 10. + 0, // offset 11. + POP_TOP, // offset 12. + 0, // offset 13. + NOP, // offset 14. + 0 // offset 15. 
+ }); +#endif +} + + +TEST(BytecodeManipulatorTest, InsertionBadInstruction) { + BytecodeManipulator instance( + { NOP, 0, NOP, 0, LOAD_CONST }, + false, + {}); + EXPECT_FALSE(instance.InjectMethodCall(2, 0)); +} + + +TEST(BytecodeManipulatorTest, InsertionNegativeOffset) { + BytecodeManipulator instance({ NOP, 0, RETURN_VALUE, 0 }, false, {}); + EXPECT_FALSE(instance.InjectMethodCall(-1, 0)); +} + + +TEST(BytecodeManipulatorTest, InsertionOutOfRangeOffset) { + BytecodeManipulator instance({ NOP, 0, RETURN_VALUE, 0 }, false, {}); + EXPECT_FALSE(instance.InjectMethodCall(4, 0)); +} + + +TEST(BytecodeManipulatorTest, InsertionMidInstruction) { + BytecodeManipulator instance( + { NOP, 0, LOAD_CONST, 0, NOP, 0 }, + false, + {}); + + EXPECT_FALSE(instance.InjectMethodCall(1, 0)); + EXPECT_FALSE(instance.InjectMethodCall(3, 0)); + EXPECT_FALSE(instance.InjectMethodCall(5, 0)); +} + + +TEST(BytecodeManipulatorTest, InsertionTooManyUpgrades) { + BytecodeManipulator instance( + { + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + JUMP_ABSOLUTE, 254, + NOP, 0 + }, + false, + {}); + EXPECT_FALSE(instance.InjectMethodCall(20, 0)); +} + + +TEST(BytecodeManipulatorTest, IncompleteBytecodeInsert) { + BytecodeManipulator instance({ NOP, 0, LOAD_CONST }, false, {}); + EXPECT_FALSE(instance.InjectMethodCall(2, 0)); +} + + +TEST(BytecodeManipulatorTest, IncompleteBytecodeAppend) { + BytecodeManipulator instance( + { YIELD_VALUE, 0, NOP, 0, LOAD_CONST }, + false, {}); + EXPECT_FALSE(instance.InjectMethodCall(4, 0)); +} + + +TEST(BytecodeManipulatorTest, LineNumbersTableUpdateBeginning) { + BytecodeManipulator instance( + { NOP, 0, RETURN_VALUE, 0 }, + true, + { 2, 1, 2, 1 }); + ASSERT_TRUE(instance.InjectMethodCall(0, 99)); + + VerifyLineNumbersTable(instance, { 8, 1, 2, 1 }); +} + + +TEST(BytecodeManipulatorTest, 
LineNumbersTableUpdateLineBoundary) { + BytecodeManipulator instance( + { NOP, 0, RETURN_VALUE, 0 }, + true, + { 0, 1, 2, 1, 2, 1 }); + ASSERT_TRUE(instance.InjectMethodCall(2, 99)); + + VerifyLineNumbersTable(instance, { 0, 1, 2, 1, 8, 1 }); +} + + +TEST(BytecodeManipulatorTest, LineNumbersTableUpdateMidLine) { + BytecodeManipulator instance( + { NOP, 0, NOP, 0, RETURN_VALUE, 0 }, + true, + { 0, 1, 4, 1 }); + ASSERT_TRUE(instance.InjectMethodCall(2, 99)); + + VerifyLineNumbersTable(instance, { 0, 1, 10, 1 }); +} + + +TEST(BytecodeManipulatorTest, LineNumbersTablePastEnd) { + BytecodeManipulator instance( + { NOP, 0, NOP, 0, NOP, 0, RETURN_VALUE, 0 }, + true, + { 0, 1 }); + ASSERT_TRUE(instance.InjectMethodCall(6, 99)); + + VerifyLineNumbersTable(instance, { 0, 1 }); +} + + +TEST(BytecodeManipulatorTest, LineNumbersTableUpgradeExtended) { + BytecodeManipulator instance( + { JUMP_ABSOLUTE, 254, RETURN_VALUE, 0 }, + true, + { 2, 1, 2, 1 }); + ASSERT_TRUE(instance.InjectMethodCall(2, 99)); + + VerifyLineNumbersTable(instance, { 4, 1, 8, 1 }); +} + + +TEST(BytecodeManipulatorTest, LineNumbersTableOverflow) { + std::vector bytecode(300, 0); + BytecodeManipulator instance( + bytecode, + true, + { 254, 1 }); + ASSERT_TRUE(instance.InjectMethodCall(2, 99)); + +#if PY_VERSION_HEX >= 0x030A0000 + VerifyLineNumbersTable(instance, { 254, 0, 6, 1 }); +#else + VerifyLineNumbersTable(instance, { 255, 0, 5, 1 }); +#endif +} + + +TEST(BytecodeManipulatorTest, SuccessAppend) { + BytecodeManipulator instance( + { YIELD_VALUE, 0, LOAD_CONST, 0, NOP, 0 }, + false, + {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 57)); + + VerifyBytecode( + instance, + { + YIELD_VALUE, // offset 0. + 0, // offset 1. + JUMP_ABSOLUTE, // offset 2. + 6, // offset 3. + NOP, // offset 4. + 0, // offset 5. + LOAD_CONST, // offset 6. + 57, // offset 7. + CALL_FUNCTION, // offset 8. + 0, // offset 9. + POP_TOP, // offset 10. + 0, // offset 11. + LOAD_CONST, // offset 12. + 0, // offset 13. 
+ JUMP_ABSOLUTE, // offset 14. + 4 // offset 15. + }); +} + + +TEST(BytecodeManipulatorTest, SuccessAppendYieldFrom) { + BytecodeManipulator instance( + { YIELD_FROM, 0, LOAD_CONST, 0, NOP, 0 }, + false, + {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 57)); + + VerifyBytecode( + instance, + { + YIELD_FROM, // offset 0. + 0, // offset 1. + JUMP_ABSOLUTE, // offset 2. + 6, // offset 3. + NOP, // offset 4. + 0, // offset 5. + LOAD_CONST, // offset 6. + 57, // offset 7. + CALL_FUNCTION, // offset 8. + 0, // offset 9. + POP_TOP, // offset 10. + 0, // offset 11. + LOAD_CONST, // offset 12. + 0, // offset 13. + JUMP_ABSOLUTE, // offset 14. + 4 // offset 15. + }); +} + + +TEST(BytecodeManipulatorTest, AppendExtraPadding) { + BytecodeManipulator instance( + { + YIELD_VALUE, + 0, + EXTENDED_ARG, + 15, + EXTENDED_ARG, + 16, + EXTENDED_ARG, + 17, + LOAD_CONST, + 18, + RETURN_VALUE, + 0 + }, + false, {}); + ASSERT_TRUE(instance.InjectMethodCall(2, 0x7273)); + + VerifyBytecode( + instance, + { + YIELD_VALUE, // offset 0. + 0, // offset 1. + JUMP_ABSOLUTE, // offset 2. + 12, // offset 3. + NOP, // offset 4. Args for NOP do not matter. + 9, // offset 5. + NOP, // offset 6. + 9, // offset 7. + NOP, // offset 8. + 9, // offset 9. + RETURN_VALUE, // offset 10. + 0, // offset 11. + EXTENDED_ARG, // offset 12. + 0x72, // offset 13. + LOAD_CONST, // offset 14. + 0x73, // offset 15. + CALL_FUNCTION, // offset 16. + 0, // offset 17. + POP_TOP, // offset 18. + 0, // offset 19. + EXTENDED_ARG, // offset 20. + 15, // offset 21. + EXTENDED_ARG, // offset 22. + 16, // offset 23. + EXTENDED_ARG, // offset 24. + 17, // offset 25. + LOAD_CONST, // offset 26. + 18, // offset 27. + JUMP_ABSOLUTE, // offset 28. + 10 // offset 29. + }); +} + + +TEST(BytecodeManipulatorTest, AppendToEnd) { + std::vector bytecode = {YIELD_VALUE, 0}; + // Case where trampoline requires 4 bytes to write. 
+ bytecode.resize(300); + BytecodeManipulator instance(bytecode, false, {}); + + // This scenario could be supported in theory, but it's not. The purpose of + // this test case is to verify there are no crashes or corruption. + ASSERT_FALSE(instance.InjectMethodCall(298, 0x12)); +} + + +TEST(BytecodeManipulatorTest, NoSpaceForTrampoline) { + const std::vector test_cases[] = { + {YIELD_VALUE, 0, YIELD_VALUE, 0, NOP, 0}, + {YIELD_VALUE, 0, FOR_ITER, 0, NOP, 0}, + {YIELD_VALUE, 0, JUMP_FORWARD, 0, NOP, 0}, +#if PY_VERSION_HEX < 0x03080000 + {YIELD_VALUE, 0, SETUP_LOOP, 0, NOP, 0}, +#endif + {YIELD_VALUE, 0, SETUP_FINALLY, 0, NOP, 0}, +#if PY_VERSION_HEX < 0x03080000 + {YIELD_VALUE, 0, SETUP_LOOP, 0, NOP, 0}, + {YIELD_VALUE, 0, SETUP_EXCEPT, 0, NOP, 0}, +#endif +#if PY_VERSION_HEX >= 0x03080000 && PY_VERSION_HEX < 0x03090000 + {YIELD_VALUE, 0, CALL_FINALLY, 0, NOP, 0}, +#endif + }; + + for (const auto& test_case : test_cases) { + BytecodeManipulator instance(test_case, false, {}); + EXPECT_FALSE(instance.InjectMethodCall(2, 0)) + << "Input:\n" + << FormatBytecode(test_case, 4) << "\n" + << "Unexpected output:\n" + << FormatBytecode(instance.bytecode(), 4); + } + + // Case where trampoline requires 4 bytes to write. + std::vector bytecode(300, 0); + bytecode[0] = YIELD_VALUE; + bytecode[2] = NOP; + bytecode[4] = YIELD_VALUE; + BytecodeManipulator instance(bytecode, false, {}); + ASSERT_FALSE(instance.InjectMethodCall(2, 0x12)); +} + +// Tests that we don't allow jumping into the middle of the space reserved for +// the trampoline. See the comments in AppendMethodCall() in +// bytecode_manipulator.cc. 
+TEST(BytecodeManipulatorTest, JumpMidRelocatedInstructions) {
+  // Every case has the same layout: YIELD_VALUE at offset 0, the opcode under
+  // test (with its argument) at offset 2 and LOAD_CONST at offset 4, which is
+  // where the method call is injected below. Each opcode/argument pair is
+  // listed once.
+  // NOTE(review): `std::vector` appears without a template argument here;
+  // presumably the element type was lost from this text -- confirm against
+  // the original file.
+  std::vector test_cases[] = {
+      {YIELD_VALUE, 0, FOR_ITER, 2, LOAD_CONST, 0},
+      {YIELD_VALUE, 0, JUMP_FORWARD, 2, LOAD_CONST, 0},
+      {YIELD_VALUE, 0, SETUP_FINALLY, 2, LOAD_CONST, 0},
+      {YIELD_VALUE, 0, SETUP_WITH, 2, LOAD_CONST, 0},
+      {YIELD_VALUE, 0, JUMP_IF_FALSE_OR_POP, 6, LOAD_CONST, 0},
+      {YIELD_VALUE, 0, JUMP_IF_TRUE_OR_POP, 6, LOAD_CONST, 0},
+      {YIELD_VALUE, 0, JUMP_ABSOLUTE, 6, LOAD_CONST, 0},
+      {YIELD_VALUE, 0, POP_JUMP_IF_FALSE, 6, LOAD_CONST, 0},
+      {YIELD_VALUE, 0, POP_JUMP_IF_TRUE, 6, LOAD_CONST, 0},
+#if PY_VERSION_HEX < 0x03080000
+      {YIELD_VALUE, 0, SETUP_LOOP, 2, LOAD_CONST, 0},
+      {YIELD_VALUE, 0, CONTINUE_LOOP, 6, LOAD_CONST, 0},
+#endif
+  };
+
+  // The loop variable is a mutable reference because each case is padded in
+  // place before it is handed to the manipulator.
+  for (auto& test_case : test_cases) {
+    // Case where trampoline requires 4 bytes to write.
+    test_case.resize(300);
+    BytecodeManipulator instance(test_case, false, {});
+    // The injection must be rejected. On failure, print the input and the
+    // bytecode held by the manipulator to make the offending case obvious.
+    EXPECT_FALSE(instance.InjectMethodCall(4, 0))
+        << "Input:\n"
+        << FormatBytecode(test_case, 4) << "\n"
+        << "Unexpected output:\n"
+        << FormatBytecode(instance.bytecode(), 4);
+  }
+}
+
+
+// Test that we allow jumping to the start of the space reserved for the
+// trampoline.
+TEST(BytecodeManipulatorTest, JumpStartOfRelocatedInstructions) {
+  // Same layout as the mid-jump cases (opcode under test at offset 2,
+  // LOAD_CONST at offset 4), but with different jump arguments and without
+  // the padding to 300 bytes.
+  const std::vector test_cases[] = {
+      {YIELD_VALUE, 0, FOR_ITER, 0, LOAD_CONST, 0},
+      {YIELD_VALUE, 0, SETUP_WITH, 0, LOAD_CONST, 0},
+      {YIELD_VALUE, 0, JUMP_ABSOLUTE, 4, LOAD_CONST, 0}};
+
+  for (const auto& test_case : test_cases) {
+    BytecodeManipulator instance(test_case, false, {});
+    // Injection at offset 4 is expected to succeed for every case.
+    EXPECT_TRUE(instance.InjectMethodCall(4, 0))
+        << "Input:\n" << FormatBytecode(test_case, 4);
+  }
+}
+
+
+// Test that we allow jumping after the space reserved for the trampoline.
+TEST(BytecodeManipulatorTest, JumpAfterRelocatedInstructions) { + const std::vector test_cases[] = { + {YIELD_VALUE, 0, FOR_ITER, 2, LOAD_CONST, 0, NOP, 0}, + {YIELD_VALUE, 0, SETUP_WITH, 2, LOAD_CONST, 0, NOP, 0}, + {YIELD_VALUE, 0, JUMP_ABSOLUTE, 6, LOAD_CONST, 0, NOP, 0}}; + + for (const auto& test_case : test_cases) { + BytecodeManipulator instance(test_case, false, {}); + EXPECT_TRUE(instance.InjectMethodCall(4, 0)) + << "Input:\n" << FormatBytecode(test_case, 4); + } +} + + +TEST(BytecodeManipulatorTest, InsertionRevertOnFailure) { + const std::vector input{JUMP_FORWARD, 0, NOP, 0, JUMP_ABSOLUTE, 2}; + + BytecodeManipulator instance(input, false, {}); + ASSERT_FALSE(instance.InjectMethodCall(1, 47)); + + VerifyBytecode(instance, input); +} + + +} // namespace cdbg +} // namespace devtools diff --git a/tests/py/application_info_test.py b/tests/py/application_info_test.py new file mode 100644 index 0000000..51f7427 --- /dev/null +++ b/tests/py/application_info_test.py @@ -0,0 +1,72 @@ +"""Tests for application_info.""" + +import os +from unittest import mock + +import requests + +from googleclouddebugger import application_info +from absl.testing import absltest + + +class ApplicationInfoTest(absltest.TestCase): + + def test_get_platform_default(self): + """Returns default platform when no platform is detected.""" + self.assertEqual(application_info.PlatformType.DEFAULT, + application_info.GetPlatform()) + + def test_get_platform_gcf_name(self): + """Returns cloud_function when the FUNCTION_NAME env variable is set.""" + try: + os.environ['FUNCTION_NAME'] = 'function-name' + self.assertEqual(application_info.PlatformType.CLOUD_FUNCTION, + application_info.GetPlatform()) + finally: + del os.environ['FUNCTION_NAME'] + + def test_get_platform_gcf_target(self): + """Returns cloud_function when the FUNCTION_TARGET env variable is set.""" + try: + os.environ['FUNCTION_TARGET'] = 'function-target' + self.assertEqual(application_info.PlatformType.CLOUD_FUNCTION, + 
application_info.GetPlatform()) + finally: + del os.environ['FUNCTION_TARGET'] + + def test_get_region_none(self): + """Returns None when no region is detected.""" + self.assertIsNone(application_info.GetRegion()) + + def test_get_region_gcf(self): + """Returns correct region when the FUNCTION_REGION env variable is set.""" + try: + os.environ['FUNCTION_REGION'] = 'function-region' + self.assertEqual('function-region', application_info.GetRegion()) + finally: + del os.environ['FUNCTION_REGION'] + + @mock.patch('requests.get') + def test_get_region_metadata_server(self, mock_requests_get): + """Returns correct region if found in metadata server.""" + success_response = mock.Mock(requests.Response) + success_response.status_code = 200 + success_response.text = 'a/b/function-region' + mock_requests_get.return_value = success_response + + self.assertEqual('function-region', application_info.GetRegion()) + + @mock.patch('requests.get') + def test_get_region_metadata_server_fail(self, mock_requests_get): + """Returns None if region not found in metadata server.""" + exception = requests.exceptions.HTTPError() + failed_response = mock.Mock(requests.Response) + failed_response.status_code = 400 + failed_response.raise_for_status.side_effect = exception + mock_requests_get.return_value = failed_response + + self.assertIsNone(application_info.GetRegion()) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/backoff_test.py b/tests/py/backoff_test.py new file mode 100644 index 0000000..262976c --- /dev/null +++ b/tests/py/backoff_test.py @@ -0,0 +1,35 @@ +"""Unit test for backoff module.""" + +from absl.testing import absltest + +from googleclouddebugger import backoff + + +class BackoffTest(absltest.TestCase): + """Unit test for backoff module.""" + + def setUp(self): + self._backoff = backoff.Backoff(10, 100, 1.5) + + def testInitial(self): + self.assertEqual(10, self._backoff.Failed()) + + def testIncrease(self): + self._backoff.Failed() + 
self.assertEqual(15, self._backoff.Failed()) + + def testMaximum(self): + for _ in range(100): + self._backoff.Failed() + + self.assertEqual(100, self._backoff.Failed()) + + def testResetOnSuccess(self): + for _ in range(4): + self._backoff.Failed() + self._backoff.Succeeded() + self.assertEqual(10, self._backoff.Failed()) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/breakpoints_manager_test.py b/tests/py/breakpoints_manager_test.py new file mode 100644 index 0000000..269931f --- /dev/null +++ b/tests/py/breakpoints_manager_test.py @@ -0,0 +1,229 @@ +"""Unit test for breakpoints_manager module.""" + +from datetime import datetime +from datetime import timedelta +from unittest import mock + +from absl.testing import absltest + +from googleclouddebugger import breakpoints_manager + + +class BreakpointsManagerTest(absltest.TestCase): + """Unit test for breakpoints_manager module.""" + + def setUp(self): + self._breakpoints_manager = breakpoints_manager.BreakpointsManager( + self, None) + + path = 'googleclouddebugger.breakpoints_manager.' 
+ breakpoint_class = path + 'python_breakpoint.PythonBreakpoint' + + patcher = mock.patch(breakpoint_class) + self._mock_breakpoint = patcher.start() + self.addCleanup(patcher.stop) + + def testEmpty(self): + self.assertEmpty(self._breakpoints_manager._active) + + def testSetSingle(self): + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._mock_breakpoint.assert_has_calls( + [mock.call({'id': 'ID1'}, self, self._breakpoints_manager, None)]) + self.assertLen(self._breakpoints_manager._active, 1) + + def testSetDouble(self): + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._mock_breakpoint.assert_has_calls( + [mock.call({'id': 'ID1'}, self, self._breakpoints_manager, None)]) + self.assertLen(self._breakpoints_manager._active, 1) + + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID1' + }, { + 'id': 'ID2' + }]) + self._mock_breakpoint.assert_has_calls([ + mock.call({'id': 'ID1'}, self, self._breakpoints_manager, None), + mock.call({'id': 'ID2'}, self, self._breakpoints_manager, None) + ]) + self.assertLen(self._breakpoints_manager._active, 2) + + def testSetRepeated(self): + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self.assertEqual(1, self._mock_breakpoint.call_count) + + def testClear(self): + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._breakpoints_manager.SetActiveBreakpoints([]) + self.assertEqual(1, self._mock_breakpoint.return_value.Clear.call_count) + self.assertEmpty(self._breakpoints_manager._active) + + def testCompleteInvalidId(self): + self._breakpoints_manager.CompleteBreakpoint('ID_INVALID') + + def testComplete(self): + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._breakpoints_manager.CompleteBreakpoint('ID1') + self.assertEqual(1, self._mock_breakpoint.return_value.Clear.call_count) + + def testSetCompleted(self): + 
self._breakpoints_manager.CompleteBreakpoint('ID1') + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self.assertEqual(0, self._mock_breakpoint.call_count) + + def testCompletedCleanup(self): + self._breakpoints_manager.CompleteBreakpoint('ID1') + self._breakpoints_manager.SetActiveBreakpoints([]) + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self.assertEqual(1, self._mock_breakpoint.call_count) + + def testMultipleSetDelete(self): + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID1' + }, { + 'id': 'ID2' + }, { + 'id': 'ID3' + }, { + 'id': 'ID4' + }]) + self.assertLen(self._breakpoints_manager._active, 4) + + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID1' + }, { + 'id': 'ID2' + }, { + 'id': 'ID3' + }, { + 'id': 'ID4' + }]) + self.assertLen(self._breakpoints_manager._active, 4) + + self._breakpoints_manager.SetActiveBreakpoints([]) + self.assertEmpty(self._breakpoints_manager._active) + + def testCombination(self): + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID1' + }, { + 'id': 'ID2' + }, { + 'id': 'ID3' + }]) + self.assertLen(self._breakpoints_manager._active, 3) + + self._breakpoints_manager.CompleteBreakpoint('ID2') + self.assertEqual(1, self._mock_breakpoint.return_value.Clear.call_count) + self.assertLen(self._breakpoints_manager._active, 2) + + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID2' + }, { + 'id': 'ID3' + }, { + 'id': 'ID4' + }]) + self.assertEqual(2, self._mock_breakpoint.return_value.Clear.call_count) + self.assertLen(self._breakpoints_manager._active, 2) + + self._breakpoints_manager.CompleteBreakpoint('ID2') + self.assertEqual(2, self._mock_breakpoint.return_value.Clear.call_count) + self.assertLen(self._breakpoints_manager._active, 2) + + self._breakpoints_manager.SetActiveBreakpoints([]) + self.assertEqual(4, self._mock_breakpoint.return_value.Clear.call_count) + self.assertEmpty(self._breakpoints_manager._active) + + def 
testCheckExpirationNoBreakpoints(self): + self._breakpoints_manager.CheckBreakpointsExpiration() + + def testCheckNotExpired(self): + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID1' + }, { + 'id': 'ID2' + }]) + self._mock_breakpoint.return_value.GetExpirationTime.return_value = ( + datetime.utcnow() + timedelta(minutes=1)) + self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 0, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + def testCheckExpired(self): + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID1' + }, { + 'id': 'ID2' + }]) + self._mock_breakpoint.return_value.GetExpirationTime.return_value = ( + datetime.utcnow() - timedelta(minutes=1)) + self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 2, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + def testCheckExpirationReset(self): + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._mock_breakpoint.return_value.GetExpirationTime.return_value = ( + datetime.utcnow() + timedelta(minutes=1)) + self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 0, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + self._breakpoints_manager.SetActiveBreakpoints([{ + 'id': 'ID1' + }, { + 'id': 'ID2' + }]) + self._mock_breakpoint.return_value.GetExpirationTime.return_value = ( + datetime.utcnow() - timedelta(minutes=1)) + self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 2, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + def testCheckExpirationCacheNegative(self): + base = datetime(2015, 1, 1) + + with mock.patch.object(breakpoints_manager.BreakpointsManager, + 'GetCurrentTime') as mock_time: + mock_time.return_value = base + + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._mock_breakpoint.return_value.GetExpirationTime.return_value = ( + base + timedelta(minutes=1)) + + 
self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 0, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + # The nearest expiration time is cached, so this should have no effect. + self._mock_breakpoint.return_value.GetExpirationTime.return_value = ( + base - timedelta(minutes=1)) + self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 0, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + def testCheckExpirationCachePositive(self): + base = datetime(2015, 1, 1) + + with mock.patch.object(breakpoints_manager.BreakpointsManager, + 'GetCurrentTime') as mock_time: + self._breakpoints_manager.SetActiveBreakpoints([{'id': 'ID1'}]) + self._mock_breakpoint.return_value.GetExpirationTime.return_value = ( + base + timedelta(minutes=1)) + + mock_time.return_value = base + self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 0, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + mock_time.return_value = base + timedelta(minutes=2) + self._breakpoints_manager.CheckBreakpointsExpiration() + self.assertEqual( + 1, self._mock_breakpoint.return_value.ExpireBreakpoint.call_count) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/collector_test.py b/tests/py/collector_test.py new file mode 100644 index 0000000..abc39b2 --- /dev/null +++ b/tests/py/collector_test.py @@ -0,0 +1,1778 @@ +"""Unit test for collector module.""" + +import copy +import datetime +import inspect +import logging +import os +import sys +import time +from unittest import mock + +from absl.testing import absltest + +from googleclouddebugger import collector +from googleclouddebugger import labels + +LOGPOINT_PAUSE_MSG = ( + 'LOGPOINT: Logpoint is paused due to high log rate until log ' + 'quota is restored') + + +def CaptureCollectorWithDefaultLocation(definition, + data_visibility_policy=None): + """Makes a LogCollector with a default location. 
+ + Args: + definition: the rest of the breakpoint definition + data_visibility_policy: optional visibility policy + + Returns: + A LogCollector + """ + definition['location'] = {'path': 'collector_test.py', 'line': 10} + return collector.CaptureCollector(definition, data_visibility_policy) + + +def LogCollectorWithDefaultLocation(definition): + """Makes a LogCollector with a default location. + + Args: + definition: the rest of the breakpoint definition + + Returns: + A LogCollector + """ + definition['location'] = {'path': 'collector_test.py', 'line': 10} + return collector.LogCollector(definition) + + +class CaptureCollectorTest(absltest.TestCase): + """Unit test for capture collector.""" + + def tearDown(self): + collector.CaptureCollector.pretty_printers = [] + + def testCallStackUnlimitedFrames(self): + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.max_frames = 1000 + self._collector.Collect(inspect.currentframe()) + + self.assertGreater(len(self._collector.breakpoint['stackFrames']), 1) + self.assertLess(len(self._collector.breakpoint['stackFrames']), 100) + + def testCallStackLimitedFrames(self): + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.max_frames = 2 + self._collector.Collect(inspect.currentframe()) + + self.assertLen(self._collector.breakpoint['stackFrames'], 2) + + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertEqual('CaptureCollectorTest.testCallStackLimitedFrames', + top_frame['function']) + self.assertIn('collector_test.py', top_frame['location']['path']) + self.assertGreater(top_frame['location']['line'], 1) + + frame_below = self._collector.breakpoint['stackFrames'][1] + frame_below_line = inspect.currentframe().f_back.f_lineno + self.assertEqual(frame_below_line, frame_below['location']['line']) + + def testCallStackLimitedExpandedFrames(self): + + def CountLocals(frame): + return len(frame['arguments']) + len(frame['locals']) + + 
self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.max_frames = 3 + self._collector.max_expand_frames = 2 + self._collector.Collect(inspect.currentframe()) + + frames = self._collector.breakpoint['stackFrames'] + self.assertLen(frames, 3) + self.assertGreater(CountLocals(frames[0]), 0) + self.assertGreater(CountLocals(frames[1]), 1) + self.assertEqual(0, CountLocals(frames[2])) + + def testSimpleArguments(self): + + def Method(unused_a, unused_b): + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertListEqual([{ + 'name': 'unused_a', + 'value': '158', + 'type': 'int' + }, { + 'name': 'unused_b', + 'value': "'hello'", + 'type': 'str' + }], top_frame['arguments']) + self.assertEqual('Method', top_frame['function']) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + Method(158, 'hello') + + def testMethodWithFirstArgumentNamedSelf(self): + this = self + + def Method(self, unused_a, unused_b): # pylint: disable=unused-argument + this._collector.Collect(inspect.currentframe()) + top_frame = this._collector.breakpoint['stackFrames'][0] + this.assertListEqual([{ + 'name': 'self', + 'value': "'world'", + 'type': 'str' + }, { + 'name': 'unused_a', + 'value': '158', + 'type': 'int' + }, { + 'name': 'unused_b', + 'value': "'hello'", + 'type': 'str' + }], top_frame['arguments']) + # This is the incorrect function name, but we are validating that no + # exceptions are thrown here. 
+ this.assertEqual('str.Method', top_frame['function']) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + Method('world', 158, 'hello') + + def testMethodWithArgumentNamedSelf(self): + this = self + + def Method(unused_a, unused_b, self): # pylint: disable=unused-argument + this._collector.Collect(inspect.currentframe()) + top_frame = this._collector.breakpoint['stackFrames'][0] + this.assertListEqual([{ + 'name': 'unused_a', + 'value': '158', + 'type': 'int' + }, { + 'name': 'unused_b', + 'value': "'hello'", + 'type': 'str' + }, { + 'name': 'self', + 'value': "'world'", + 'type': 'str' + }], top_frame['arguments']) + this.assertEqual('Method', top_frame['function']) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + Method(158, 'hello', 'world') + + def testClassMethod(self): + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertListEqual([{ + 'name': 'self', + 'varTableIndex': 1 + }], top_frame['arguments']) + self.assertEqual('CaptureCollectorTest.testClassMethod', + top_frame['function']) + + def testClassMethodWithOptionalArguments(self): + + def Method(unused_a, unused_optional='notneeded'): + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertListEqual([{ + 'name': 'unused_a', + 'varTableIndex': 1 + }, { + 'name': 'unused_optional', + 'value': "'notneeded'", + 'type': 'str' + }], top_frame['arguments']) + self.assertEqual('Method', top_frame['function']) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + Method(self) + + def testClassMethodWithPositionalArguments(self): + + def Method(*unused_pos): + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertListEqual([{ + 'name': 'unused_pos', + 'type': 'tuple', + 
'members': [{ + 'name': '[0]', + 'value': '1', + 'type': 'int' + }] + }], top_frame['arguments']) + self.assertEqual('Method', top_frame['function']) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + Method(1) + + def testClassMethodWithKeywords(self): + + def Method(**unused_kwd): + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertCountEqual([{ + 'name': "'first'", + 'value': '1', + 'type': 'int' + }, { + 'name': "'second'", + 'value': '2', + 'type': 'int' + }], top_frame['arguments'][0]['members']) + self.assertEqual('Method', top_frame['function']) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + Method(first=1, second=2) + + def testNoLocalVariables(self): + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertEmpty(top_frame['locals']) + self.assertEqual('CaptureCollectorTest.testNoLocalVariables', + top_frame['function']) + + def testRuntimeError(self): + + class BadDict(dict): + + def __init__(self, d): + d['foo'] = 'bar' + super(BadDict, self).__init__(d) + + def __getattribute__(self, attr): + raise RuntimeError('Bogus error') + + class BadType(object): + + def __init__(self): + self.__dict__ = BadDict(self.__dict__) + + unused_a = BadType() + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + var_a = self._Pack(self._LocalByName('unused_a')) + self.assertDictEqual( + { + 'name': 'unused_a', + 'status': { + 'isError': True, + 'refersTo': 'VARIABLE_VALUE', + 'description': { + 'format': 'Failed to capture variable: $0', + 'parameters': ['Bogus error'] + }, + } + }, var_a) + + def testBadDictionary(self): + + class BadDict(dict): + + def items(self): + raise AttributeError('attribute error') + + class BadType(object): + + 
def __init__(self): + self.good = 1 + self.bad = BadDict() + + unused_a = BadType() + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + var_a = self._Pack(self._LocalByName('unused_a')) + members = var_a['members'] + self.assertLen(members, 2) + self.assertIn({'name': 'good', 'value': '1', 'type': 'int'}, members) + self.assertIn( + { + 'name': 'bad', + 'status': { + 'isError': True, + 'refersTo': 'VARIABLE_VALUE', + 'description': { + 'format': 'Failed to capture variable: $0', + 'parameters': ['attribute error'] + }, + } + }, members) + + def testLocalVariables(self): + unused_a = 8 + unused_b = True + unused_nothing = None + unused_s = 'hippo' + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + self.assertLen(top_frame['arguments'], 1) # just self. + self.assertCountEqual([{ + 'name': 'unused_a', + 'value': '8', + 'type': 'int' + }, { + 'name': 'unused_b', + 'value': 'True', + 'type': 'bool' + }, { + 'name': 'unused_nothing', + 'value': 'None' + }, { + 'name': 'unused_s', + 'value': "'hippo'", + 'type': 'str' + }], top_frame['locals']) + + def testLocalVariablesWithBlacklist(self): + unused_a = collector.LineNoFilter() + unused_b = 5 + + # Side effect logic for the mock data visibility object + def IsDataVisible(name): + path_prefix = 'googleclouddebugger.collector.' 
+ if name == path_prefix + 'LineNoFilter': + return (False, 'data blocked') + return (True, None) + + mock_policy = mock.MagicMock() + mock_policy.IsDataVisible.side_effect = IsDataVisible + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}, + mock_policy) + self._collector.Collect(inspect.currentframe()) + top_frame = self._collector.breakpoint['stackFrames'][0] + # Should be blocked + self.assertIn( + { + 'name': 'unused_a', + 'status': { + 'description': { + 'format': 'data blocked' + }, + 'refersTo': 'VARIABLE_NAME', + 'isError': True + } + }, top_frame['locals']) + # Should not be blocked + self.assertIn({ + 'name': 'unused_b', + 'value': '5', + 'type': 'int' + }, top_frame['locals']) + + def testWatchedExpressionsBlacklisted(self): + + class TestClass(object): + + def __init__(self): + self.a = 5 + + unused_a = TestClass() + + # Side effect logic for the mock data visibility object + def IsDataVisible(name): + if name == 'collector_test.TestClass': + return (False, 'data blocked') + return (True, None) + + mock_policy = mock.MagicMock() + mock_policy.IsDataVisible.side_effect = IsDataVisible + + self._collector = CaptureCollectorWithDefaultLocation( + { + 'id': 'BP_ID', + 'expressions': ['unused_a', 'unused_a.a'] + }, mock_policy) + self._collector.Collect(inspect.currentframe()) + # Class should be blocked + self.assertIn( + { + 'name': 'unused_a', + 'status': { + 'description': { + 'format': 'data blocked' + }, + 'refersTo': 'VARIABLE_NAME', + 'isError': True + } + }, self._collector.breakpoint['evaluatedExpressions']) + # TODO: Explicit member SHOULD also be blocked but this is + # currently not implemented. After fixing the implementation, change + # the test below to assert that it's blocked too. 
+ self.assertIn({ + 'name': 'unused_a.a', + 'type': 'int', + 'value': '5' + }, self._collector.breakpoint['evaluatedExpressions']) + + def testLocalsNonTopFrame(self): + + def Method(): + self._collector.Collect(inspect.currentframe()) + self.assertListEqual([{ + 'name': 'self', + 'varTableIndex': 1 + }], self._collector.breakpoint['stackFrames'][1]['arguments']) + self.assertCountEqual([{ + 'name': 'unused_a', + 'value': '47', + 'type': 'int' + }, { + 'name': 'Method', + 'value': 'function Method' + }], self._collector.breakpoint['stackFrames'][1]['locals']) + + unused_a = 47 + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + Method() + + def testDictionaryMaxDepth(self): + d = {} + t = d + for _ in range(10): + t['inner'] = {} + t = t['inner'] + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.default_capture_limits.max_depth = 3 + self._collector.Collect(inspect.currentframe()) + self.assertDictEqual( + { + 'name': + 'd', + 'type': + 'dict', + 'members': [{ + 'name': "'inner'", + 'type': 'dict', + 'members': [{ + 'name': "'inner'", + 'varTableIndex': 0 + }] + }] + }, self._LocalByName('d')) + + def testVectorMaxDepth(self): + l = [] + t = l + for _ in range(10): + t.append([]) + t = t[0] + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.default_capture_limits.max_depth = 3 + self._collector.Collect(inspect.currentframe()) + self.assertDictEqual( + { + 'name': + 'l', + 'type': + 'list', + 'members': [{ + 'name': '[0]', + 'type': 'list', + 'members': [{ + 'name': '[0]', + 'varTableIndex': 0 + }] + }] + }, self._LocalByName('l')) + + def testStringTrimming(self): + unused_s = '123456789' + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.default_capture_limits.max_value_len = 8 + self._collector.Collect(inspect.currentframe()) + self.assertListEqual([{ + 'name': 'unused_s', + 'value': "'12345678...", + 'type': 'str' + 
}], self._collector.breakpoint['stackFrames'][0]['locals']) + + def testBytearrayTrimming(self): + unused_bytes = bytearray(range(20)) + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.default_capture_limits.max_value_len = 20 + self._collector.Collect(inspect.currentframe()) + self.assertListEqual([{ + 'name': 'unused_bytes', + 'value': r"bytearray(b'\x00\x01\...", + 'type': 'bytearray' + }], self._collector.breakpoint['stackFrames'][0]['locals']) + + def testObject(self): + + class MyClass(object): + + def __init__(self): + self.a = 1 + self.b = 2 + + unused_my = MyClass() + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + var_index = self._LocalByName('unused_my')['varTableIndex'] + self.assertEqual( + __name__ + '.MyClass', + self._collector.breakpoint['variableTable'][var_index]['type']) + self.assertCountEqual([{ + 'name': 'a', + 'value': '1', + 'type': 'int' + }, { + 'name': 'b', + 'value': '2', + 'type': 'int' + }], self._collector.breakpoint['variableTable'][var_index]['members']) + + def testBufferFullLocalRef(self): + + class MyClass(object): + + def __init__(self, data): + self.data = data + + def Method(): + unused_m1 = MyClass('1' * 10000) + unused_m2 = MyClass('2' * 10000) + unused_m3 = MyClass('3' * 10000) + unused_m4 = MyClass('4' * 10000) + unused_m5 = MyClass('5' * 10000) + unused_m6 = MyClass('6' * 10000) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.max_frames = 1 + self._collector.max_size = 48000 + self._collector.default_capture_limits.max_value_len = 10009 + self._collector.Collect(inspect.currentframe()) + + # Verify that 5 locals fit and 1 is out of buffer. 
+ count = {True: 0, False: 0} # captured, not captured + for local in self._collector.breakpoint['stackFrames'][0]['locals']: + var_index = local['varTableIndex'] + self.assertLess(var_index, + len(self._collector.breakpoint['variableTable'])) + if local['name'].startswith('unused_m'): + count[var_index != 0] += 1 + self.assertDictEqual({True: 5, False: 1}, count) + + Method() + + def testBufferFullDictionaryRef(self): + + class MyClass(object): + + def __init__(self, data): + self.data = data + + def Method(): + unused_d1 = {'a': MyClass('1' * 10000)} + unused_d2 = {'b': MyClass('2' * 10000)} + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.max_frames = 1 + self._collector.max_size = 9000 + self._collector.default_capture_limits.max_value_len = 10009 + self._collector.Collect(inspect.currentframe()) + + # Verify that one of {d1,d2} could fit and the other didn't. + var_indexes = [ + self._LocalByName(n)['members'][0]['varTableIndex'] == 0 + for n in ['unused_d1', 'unused_d2'] + ] + self.assertEqual(1, sum(var_indexes)) + + Method() + + def testClassCrossReference(self): + + class MyClass(object): + pass + + m1 = MyClass() + m2 = MyClass() + m1.other = m2 + m2.other = m1 + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + m1_var_index = self._LocalByName('m1')['varTableIndex'] + m2_var_index = self._LocalByName('m2')['varTableIndex'] + + var_table = self._collector.breakpoint['variableTable'] + self.assertDictEqual( + { + 'type': __name__ + '.MyClass', + 'members': [{ + 'name': 'other', + 'varTableIndex': m1_var_index + }] + }, var_table[m2_var_index]) + self.assertDictEqual( + { + 'type': __name__ + '.MyClass', + 'members': [{ + 'name': 'other', + 'varTableIndex': m2_var_index + }] + }, var_table[m1_var_index]) + + def testCaptureVector(self): + unused_my_list = [1, 2, 3, 4, 5] + unused_my_slice = unused_my_list[1:4] + + self._collector = 
CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertDictEqual( + { + 'name': + 'unused_my_list', + 'type': + 'list', + 'members': [{ + 'name': '[0]', + 'value': '1', + 'type': 'int' + }, { + 'name': '[1]', + 'value': '2', + 'type': 'int' + }, { + 'name': '[2]', + 'value': '3', + 'type': 'int' + }, { + 'name': '[3]', + 'value': '4', + 'type': 'int' + }, { + 'name': '[4]', + 'value': '5', + 'type': 'int' + }] + }, self._LocalByName('unused_my_list')) + self.assertDictEqual( + { + 'name': + 'unused_my_slice', + 'type': + 'list', + 'members': [{ + 'name': '[0]', + 'value': '2', + 'type': 'int' + }, { + 'name': '[1]', + 'value': '3', + 'type': 'int' + }, { + 'name': '[2]', + 'value': '4', + 'type': 'int' + }] + }, self._LocalByName('unused_my_slice')) + + def testCaptureDictionary(self): + unused_my_dict = { + 'first': 1, + 3.14: 'pi', + (5, 6): 7, + frozenset([5, 6]): 'frozen', + 'vector': ['odin', 'dva', 'tri'], + 'inner': { + 1: 'one' + }, + 'empty': {} + } + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + frozenset_name = 'frozenset({5, 6})' + self.assertCountEqual([{ + 'name': "'first'", + 'value': '1', + 'type': 'int' + }, { + 'name': '3.14', + 'value': "'pi'", + 'type': 'str' + }, { + 'name': '(5, 6)', + 'value': '7', + 'type': 'int' + }, { + 'name': frozenset_name, + 'value': "'frozen'", + 'type': 'str' + }, { + 'name': + "'vector'", + 'type': + 'list', + 'members': [{ + 'name': '[0]', + 'value': "'odin'", + 'type': 'str' + }, { + 'name': '[1]', + 'value': "'dva'", + 'type': 'str' + }, { + 'name': '[2]', + 'value': "'tri'", + 'type': 'str' + }] + }, { + 'name': "'inner'", + 'type': 'dict', + 'members': [{ + 'name': '1', + 'value': "'one'", + 'type': 'str' + }] + }, { + 'name': + "'empty'", + 'type': + 'dict', + 'members': [{ + 'status': { + 'refersTo': 'VARIABLE_NAME', + 'description': { + 'format': 'Empty dictionary' + 
} + } + }] + }], + self._LocalByName('unused_my_dict')['members']) + + def testEscapeDictionaryKey(self): + unused_dict = {} + unused_dict[u'\xe0'] = u'\xe0' + unused_dict['\x88'] = '\x88' + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + unicode_type = 'str' + unicode_name = "'\xe0'" + unicode_value = "'\xe0'" + + self.assertCountEqual([{ + 'type': 'str', + 'name': "'\\x88'", + 'value': "'\\x88'" + }, { + 'type': unicode_type, + 'name': unicode_name, + 'value': unicode_value + }], + self._LocalByName('unused_dict')['members']) + + def testOversizedList(self): + unused_big_list = ['x'] * 10000 + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + members = self._LocalByName('unused_big_list')['members'] + + self.assertLen(members, 26) + self.assertDictEqual({ + 'name': '[7]', + 'value': "'x'", + 'type': 'str' + }, members[7]) + self.assertDictEqual( + { + 'status': { + 'refersTo': 'VARIABLE_VALUE', + 'description': { + 'format': ( + 'Only first $0 items were captured. Use in an expression' + ' to see all items.'), + 'parameters': ['25'] + } + } + }, members[25]) + + def testOversizedDictionary(self): + unused_big_dict = {'item' + str(i): i**2 for i in range(26)} + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + members = self._LocalByName('unused_big_dict')['members'] + + self.assertLen(members, 26) + self.assertDictEqual( + { + 'status': { + 'refersTo': 'VARIABLE_VALUE', + 'description': { + 'format': ( + 'Only first $0 items were captured. 
Use in an expression' + ' to see all items.'), + 'parameters': ['25'] + } + } + }, members[25]) + + def testEmptyDictionary(self): + unused_empty_dict = {} + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertEqual( + { + 'name': + 'unused_empty_dict', + 'type': + 'dict', + 'members': [{ + 'status': { + 'refersTo': 'VARIABLE_NAME', + 'description': { + 'format': 'Empty dictionary' + } + } + }] + }, self._LocalByName('unused_empty_dict')) + + def testEmptyCollection(self): + for unused_c, object_type in [([], 'list'), ((), 'tuple'), (set(), 'set')]: + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertEqual( + { + 'name': + 'unused_c', + 'type': + object_type, + 'members': [{ + 'status': { + 'refersTo': 'VARIABLE_NAME', + 'description': { + 'format': 'Empty collection' + } + } + }] + }, self._Pack(self._LocalByName('unused_c'))) + + def testEmptyClass(self): + + class EmptyObject(object): + pass + + unused_empty_object = EmptyObject() + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertEqual( + { + 'name': + 'unused_empty_object', + 'type': + __name__ + '.EmptyObject', + 'members': [{ + 'status': { + 'refersTo': 'VARIABLE_NAME', + 'description': { + 'format': 'Object has no fields' + } + } + }] + }, self._Pack(self._LocalByName('unused_empty_object'))) + + def testWatchedExpressionsSuccess(self): + unused_dummy_a = 'x' + unused_dummy_b = {1: 2, 3: 'a'} + + self._collector = CaptureCollectorWithDefaultLocation({ + 'id': 'BP_ID', + 'expressions': ['1+2', 'unused_dummy_a*8', 'unused_dummy_b'] + }) + self._collector.Collect(inspect.currentframe()) + self.assertListEqual([{ + 'name': '1+2', + 'value': '3', + 'type': 'int' + }, { + 'name': 'unused_dummy_a*8', + 'value': "'xxxxxxxx'", + 'type': 'str' + }, { + 'name': 
+ 'unused_dummy_b', + 'type': + 'dict', + 'members': [{ + 'name': '1', + 'value': '2', + 'type': 'int' + }, { + 'name': '3', + 'value': "'a'", + 'type': 'str' + }] + }], self._collector.breakpoint['evaluatedExpressions']) + + def testOversizedStringExpression(self): + # This test checks that string expressions are collected first, up to the + # max size. The last 18 characters of the string will be missing due to the + # size for the name (14 bytes), type name (3 bytes), and the opening quote + # (1 byte). This test may be sensitive to minor changes in the collector + # code. If it turns out to break easily, consider simply verifying + # that the first 400 characters are collected, since that should suffice to + # ensure that we're not using the normal limit of 256 bytes. + self._collector = CaptureCollectorWithDefaultLocation({ + 'id': 'BP_ID', + 'expressions': ['unused_dummy_a'] + }) + self._collector.max_size = 500 + unused_dummy_a = '|'.join(['%04d' % i for i in range(5, 510, 5)]) + self._collector.Collect(inspect.currentframe()) + self.assertListEqual([{ + 'name': 'unused_dummy_a', + 'type': 'str', + 'value': "'{0}...".format(unused_dummy_a[0:-18]) + }], self._collector.breakpoint['evaluatedExpressions']) + + def testOversizedListExpression(self): + self._collector = CaptureCollectorWithDefaultLocation({ + 'id': 'BP_ID', + 'expressions': ['unused_dummy_a'] + }) + unused_dummy_a = list(range(0, 100)) + self._collector.Collect(inspect.currentframe()) + # Verify that the list did not get truncated. 
+ self.assertListEqual([{ + 'name': + 'unused_dummy_a', + 'type': + 'list', + 'members': [{ + 'type': 'int', + 'value': str(a), + 'name': '[{0}]'.format(a) + } for a in unused_dummy_a] + }], self._collector.breakpoint['evaluatedExpressions']) + + def testExpressionNullBytes(self): + self._collector = CaptureCollectorWithDefaultLocation({ + 'id': 'BP_ID', + 'expressions': ['\0'] + }) + self._collector.Collect(inspect.currentframe()) + + evaluated_expressions = self._collector.breakpoint['evaluatedExpressions'] + self.assertLen(evaluated_expressions, 1) + self.assertTrue(evaluated_expressions[0]['status']['isError']) + + def testSyntaxErrorExpression(self): + self._collector = CaptureCollectorWithDefaultLocation({ + 'id': 'BP_ID', + 'expressions': ['2+'] + }) + self._collector.Collect(inspect.currentframe()) + + evaluated_expressions = self._collector.breakpoint['evaluatedExpressions'] + self.assertLen(evaluated_expressions, 1) + self.assertTrue(evaluated_expressions[0]['status']['isError']) + self.assertEqual('VARIABLE_NAME', + evaluated_expressions[0]['status']['refersTo']) + + def testExpressionException(self): + unused_dummy_a = 1 + unused_dummy_b = 0 + self._collector = CaptureCollectorWithDefaultLocation({ + 'id': 'BP_ID', + 'expressions': ['unused_dummy_a/unused_dummy_b'] + }) + self._collector.Collect(inspect.currentframe()) + + zero_division_msg = 'division by zero' + + self.assertListEqual([{ + 'name': 'unused_dummy_a/unused_dummy_b', + 'status': { + 'isError': True, + 'refersTo': 'VARIABLE_VALUE', + 'description': { + 'format': 'Exception occurred: $0', + 'parameters': [zero_division_msg] + } + } + }], self._collector.breakpoint['evaluatedExpressions']) + + def testMutableExpression(self): + + def ChangeA(): + self._a += 1 + + self._a = 0 + ChangeA() + self._collector = CaptureCollectorWithDefaultLocation({ + 'id': 'BP_ID', + 'expressions': ['ChangeA()'] + }) + self._collector.Collect(inspect.currentframe()) + + self.assertEqual(1, self._a) + 
self.assertListEqual([{ + 'name': 'ChangeA()', + 'status': { + 'isError': True, + 'refersTo': 'VARIABLE_VALUE', + 'description': { + 'format': + 'Exception occurred: $0', + 'parameters': [('Only immutable methods can be ' + 'called from expressions')] + } + } + }], self._collector.breakpoint['evaluatedExpressions']) + + def testPrettyPrinters(self): + + class MyClass(object): + pass + + def PrettyPrinter1(obj): + if obj != unused_obj1: + return None + return ((('name1_%d' % i, '1_%d' % i) for i in range(2)), 'pp-type1') + + def PrettyPrinter2(obj): + if obj != unused_obj2: + return None + return ((('name2_%d' % i, '2_%d' % i) for i in range(3)), 'pp-type2') + + collector.CaptureCollector.pretty_printers += [ + PrettyPrinter1, PrettyPrinter2 + ] + + unused_obj1 = MyClass() + unused_obj2 = MyClass() + unused_obj3 = MyClass() + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + obj_vars = [ + self._Pack(self._LocalByName('unused_obj%d' % i)) for i in range(1, 4) + ] + + self.assertListEqual([{ + 'name': + 'unused_obj1', + 'type': + 'pp-type1', + 'members': [{ + 'name': 'name1_0', + 'value': "'1_0'", + 'type': 'str' + }, { + 'name': 'name1_1', + 'value': "'1_1'", + 'type': 'str' + }] + }, { + 'name': + 'unused_obj2', + 'type': + 'pp-type2', + 'members': [{ + 'name': 'name2_0', + 'value': "'2_0'", + 'type': 'str' + }, { + 'name': 'name2_1', + 'value': "'2_1'", + 'type': 'str' + }, { + 'name': 'name2_2', + 'value': "'2_2'", + 'type': 'str' + }] + }, { + 'name': + 'unused_obj3', + 'type': + __name__ + '.MyClass', + 'members': [{ + 'status': { + 'refersTo': 'VARIABLE_NAME', + 'description': { + 'format': 'Object has no fields' + } + } + }] + }], obj_vars) + + def testDateTime(self): + unused_datetime = datetime.datetime(2014, 6, 11, 2, 30) + unused_date = datetime.datetime(1980, 3, 1) + unused_time = datetime.time(18, 43, 11) + unused_timedelta = datetime.timedelta(days=3, microseconds=8237) + + 
self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertDictEqual( + { + 'name': 'unused_datetime', + 'type': 'datetime.datetime', + 'value': '2014-06-11 02:30:00' + }, self._Pack(self._LocalByName('unused_datetime'))) + + self.assertDictEqual( + { + 'name': 'unused_date', + 'type': 'datetime.datetime', + 'value': '1980-03-01 00:00:00' + }, self._Pack(self._LocalByName('unused_date'))) + + self.assertDictEqual( + { + 'name': 'unused_time', + 'type': 'datetime.time', + 'value': '18:43:11' + }, self._Pack(self._LocalByName('unused_time'))) + + self.assertDictEqual( + { + 'name': 'unused_timedelta', + 'type': 'datetime.timedelta', + 'value': '3 days, 0:00:00.008237' + }, self._Pack(self._LocalByName('unused_timedelta'))) + + def testException(self): + unused_exception = ValueError('arg1', 2, [3]) + + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + obj = self._Pack(self._LocalByName('unused_exception')) + + self.assertEqual('unused_exception', obj['name']) + self.assertEqual('ValueError', obj['type']) + self.assertListEqual([{ + 'value': "'arg1'", + 'type': 'str', + 'name': '[0]' + }, { + 'value': '2', + 'type': 'int', + 'name': '[1]' + }, { + 'members': [{ + 'value': '3', + 'type': 'int', + 'name': '[0]' + }], + 'type': 'list', + 'name': '[2]' + }], obj['members']) + + def testRequestLogIdCapturing(self): + collector.request_log_id_collector = lambda: 'test_log_id' + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertIn('labels', self._collector.breakpoint) + self.assertEqual( + 'test_log_id', + self._collector.breakpoint['labels'][labels.Breakpoint.REQUEST_LOG_ID]) + + def testRequestLogIdCapturingNoId(self): + collector.request_log_id_collector = lambda: None + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) 
+ self._collector.Collect(inspect.currentframe()) + + def testRequestLogIdCapturingNoCollector(self): + collector.request_log_id_collector = None + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + def testUserIdSuccess(self): + collector.user_id_collector = lambda: ('mdb_user', 'noogler') + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertIn('evaluatedUserId', self._collector.breakpoint) + self.assertEqual({ + 'kind': 'mdb_user', + 'id': 'noogler' + }, self._collector.breakpoint['evaluatedUserId']) + + def testUserIdIsNone(self): + collector.user_id_collector = lambda: (None, None) + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertNotIn('evaluatedUserId', self._collector.breakpoint) + + def testUserIdNoKind(self): + collector.user_id_collector = lambda: (None, 'noogler') + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertNotIn('evaluatedUserId', self._collector.breakpoint) + + def testUserIdNoValue(self): + collector.user_id_collector = lambda: ('mdb_user', None) + self._collector = CaptureCollectorWithDefaultLocation({'id': 'BP_ID'}) + self._collector.Collect(inspect.currentframe()) + + self.assertNotIn('evaluatedUserId', self._collector.breakpoint) + + def _LocalByName(self, name, frame=0): + for local in self._collector.breakpoint['stackFrames'][frame]['locals']: + if local['name'] == name: + return local + self.fail('Local %s not found in frame %d' % (name, frame)) + + def _Pack(self, variable): + """Embeds variables referenced through var_index.""" + packed_variable = copy.copy(variable) + + var_index = variable.get('varTableIndex') + if var_index is not None: + packed_variable.update( + 
self._collector.breakpoint['variableTable'][var_index]) + del packed_variable['varTableIndex'] + + if 'members' in packed_variable: + packed_variable['members'] = [ + self._Pack(member) for member in packed_variable['members'] + ] + + return packed_variable + + +class LogCollectorTest(absltest.TestCase): + """Unit test for log collector.""" + + def setUp(self): + self._logger = logging.getLogger('test') + + class LogVerifier(logging.Handler): + + def __init__(self): + super(LogVerifier, self).__init__() + self._received_records = [] + + def emit(self, record): + self._received_records.append(record) + + def GotMessage(self, + msg, + level=logging.INFO, + line_number=10, + func_name=None): + """Checks that the given message was logged correctly. + + This method verifies both the contents and the source location of the + message match expectations. + + Args: + msg: The expected message + level: The expected logging level. + line_number: The expected line number. + func_name: If specified, the expected log record must have a funcName + equal to this value. + Returns: + True iff the oldest unverified message matches the given attributes. 
+ """ + record = self._received_records.pop(0) + frame = inspect.currentframe().f_back + if level != record.levelno: + logging.error('Expected log level %d, got %d (%s)', level, + record.levelno, record.levelname) + return False + if msg != record.msg: + logging.error('Expected msg "%s", received "%s"', msg, record.msg) + return False + pathname = collector.NormalizePath(frame.f_code.co_filename) + if pathname != record.pathname: + logging.error('Expected pathname "%s", received "%s"', pathname, + record.pathname) + return False + if os.path.basename(pathname) != record.filename: + logging.error('Expected filename "%s", received "%s"', + os.path.basename(pathname), record.filename) + return False + if func_name and func_name != record.funcName: + logging.error('Expected function "%s", received "%s"', func_name, + record.funcName) + return False + if line_number and record.lineno != line_number: + logging.error('Expected lineno %d, received %d', line_number, + record.lineno) + return False + for attr in ['cdbg_pathname', 'cdbg_lineno']: + if hasattr(record, attr): + logging.error('Attribute %s still present in log record', attr) + return False + return True + + def CheckMessageSafe(self, msg): + """Checks that the given message was logged correctly. + + Unlike GotMessage, this will only check the contents, and will not log + an error or pop the record if the message does not match. + + Args: + msg: The expected message + Returns: + True iff the oldest unverified message matches the given attributes. 
+ """ + record = self._received_records[0] + if msg != record.msg: + print(record.msg) + return False + self._received_records.pop(0) + return True + + self._verifier = LogVerifier() + self._logger.addHandler(self._verifier) + self._logger.setLevel(logging.INFO) + collector.SetLogger(self._logger) + + # Give some time for the global quota to recover + time.sleep(0.1) + + def tearDown(self): + self._logger.removeHandler(self._verifier) + + def ResetGlobalLogQuota(self): + # The global log quota takes up to 5 seconds to fully fill back up to + # capacity (kDynamicLogCapacityFactor is 5). The capacity is 5 times the per + # second fill rate. The best we can do is a sleep, since the global + # leaky_bucket instance is inaccessible to the test. + time.sleep(5.0) + + def ResetGlobalLogBytesQuota(self): + # The global log bytes quota takes up to 2 seconds to fully fill back up to + # capacity (kDynamicLogBytesCapacityFactor is 2). The capacity is twice the + # per second fill rate. The best we can do is a sleep, since the global + # leaky_bucket instance is inaccessible to the test. + time.sleep(2.0) + + def testLogQuota(self): + # Attempt to get to a known starting state by letting the global quota fully + # recover so the ordering of tests ideally doesn't affect this test. 
+ self.ResetGlobalLogQuota() + bucket_max_capacity = 250 + log_collector = LogCollectorWithDefaultLocation({ + 'logMessageFormat': '$0', + 'expressions': ['i'] + }) + for i in range(0, bucket_max_capacity * 2): + self.assertIsNone(log_collector.Log(inspect.currentframe())) + if not self._verifier.CheckMessageSafe('LOGPOINT: %s' % i): + self.assertGreaterEqual(i, bucket_max_capacity, + 'Log quota exhausted earlier than expected') + self.assertTrue( + self._verifier.CheckMessageSafe(LOGPOINT_PAUSE_MSG), + 'Quota hit message not logged') + time.sleep(0.6) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.CheckMessageSafe('LOGPOINT: %s' % i), + 'Logging not resumed after quota recovery time') + return + self.fail('Logging was never paused when quota was exceeded') + + def testLogBytesQuota(self): + # Attempt to get to a known starting state by letting the global quota fully + # recover so the ordering of tests ideally doesn't affect this test. + self.ResetGlobalLogBytesQuota() + + # Default capacity is 40960, though based on how the leaky bucket is + # implemented, it can allow effectively twice that amount to go out in a + # very short time frame. So the third 30k message should pause. 
+ msg = ' ' * 30000 + log_collector = LogCollectorWithDefaultLocation({'logMessageFormat': msg}) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: ' + msg)) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: ' + msg)) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.CheckMessageSafe(LOGPOINT_PAUSE_MSG), + 'Quota hit message not logged') + time.sleep(0.6) + log_collector._definition['logMessageFormat'] = 'hello' + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage('LOGPOINT: hello'), + 'Logging was not resumed after quota recovery time') + + def testMissingLogLevel(self): + # Missing is equivalent to INFO. + log_collector = LogCollectorWithDefaultLocation( + {'logMessageFormat': 'hello'}) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: hello')) + + def testUndefinedLogLevel(self): + collector.log_info_message = None + log_collector = LogCollectorWithDefaultLocation({'logLevel': 'INFO'}) + self.assertDictEqual( + { + 'isError': True, + 'description': { + 'format': 'Log action on a breakpoint not supported' + } + }, log_collector.Log(inspect.currentframe())) + + def testLogInfo(self): + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': 'hello' + }) + log_collector._definition['location']['line'] = 20 + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: hello', + func_name='LogCollectorTest.testLogInfo', + line_number=20)) + + def testLogWarning(self): + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'WARNING', + 'logMessageFormat': 'hello' + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + 
self._verifier.GotMessage( + 'LOGPOINT: hello', + level=logging.WARNING, + func_name='LogCollectorTest.testLogWarning')) + + def testLogError(self): + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'ERROR', + 'logMessageFormat': 'hello' + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: hello', + level=logging.ERROR, + func_name='LogCollectorTest.testLogError')) + + def testBadExpression(self): + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': 'a=$0, b=$1', + 'expressions': ['-', '+'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + if sys.version_info.minor < 10: + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: a=, b=')) + else: + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: a=, ' + 'b=')) + + def testDollarEscape(self): + unused_integer = 12345 + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$ $$ $$$ $$$$ $0 $$0 $$$0 $$$$0 $1 hello', + 'expressions': ['unused_integer'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + msg = 'LOGPOINT: $ $ $$ $$ 12345 $0 $12345 $$0 hello' + self.assertTrue(self._verifier.GotMessage(msg)) + + def testInvalidExpressionIndex(self): + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': 'a=$0' + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: a=')) + + def testException(self): + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['[][1]'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: ')) + + def testMutableExpression(self): + + def MutableMethod(): # pylint: disable=unused-variable + self.abc = None + + log_collector = 
LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['MutableMethod()'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: ')) + + def testNone(self): + unused_none = None + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_none'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: None')) + + def testPrimitives(self): + unused_boolean = True + unused_integer = 12345 + unused_string = 'hello' + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0,$1,$2', + 'expressions': ['unused_boolean', 'unused_integer', 'unused_string'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage("LOGPOINT: True,12345,'hello'")) + + def testLongString(self): + unused_string = '1234567890' + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_string'] + }) + log_collector.max_value_len = 9 + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage("LOGPOINT: '123456789...")) + + def testLongBytes(self): + unused_bytes = bytearray([i for i in range(20)]) + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_bytes'] + }) + log_collector.max_value_len = 20 + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage(r"LOGPOINT: bytearray(b'\x00\x01\...")) + + def testDate(self): + unused_datetime = datetime.datetime(2014, 6, 11, 2, 30) + unused_date = datetime.datetime(1980, 3, 1) + unused_time = datetime.time(18, 43, 11) + unused_timedelta = datetime.timedelta(days=3, 
microseconds=8237) + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': + 'INFO', + 'logMessageFormat': + '$0;$1;$2;$3', + 'expressions': [ + 'unused_datetime', 'unused_date', 'unused_time', 'unused_timedelta' + ] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: 2014-06-11 02:30:00;1980-03-01 00:00:00;' + '18:43:11;3 days, 0:00:00.008237')) + + def testSet(self): + unused_set = set(['a']) + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_set'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage("LOGPOINT: {'a'}")) + + def testTuple(self): + unused_tuple = (1, 2, 3, 4, 5) + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_tuple'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: (1, 2, 3, 4, 5)')) + + def testList(self): + unused_list = ['a', 'b', 'c'] + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_list'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage("LOGPOINT: ['a', 'b', 'c']")) + + def testOversizedList(self): + unused_list = [1, 2, 3, 4] + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_list'] + }) + log_collector.max_list_items = 3 + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: [1, 2, 3, ...]')) + + def testSlice(self): + unused_slice = slice(1, 10) + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_slice'] + 
}) + collector.max_list_items = 3 + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage('LOGPOINT: slice(1, 10, None)')) + + def testMap(self): + unused_map = {'a': 1} + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_map'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage("LOGPOINT: {'a': 1}")) + + def testObject(self): + + class MyClass(object): + + def __init__(self): + self.some = 'thing' + + unused_my = MyClass() + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_my'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue(self._verifier.GotMessage("LOGPOINT: {'some': 'thing'}")) + + def testNestedBelowLimit(self): + unused_list = [1, [2], [1, 2, 3], [1, [1, 2, 3]], 5] + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_list'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: [1, [2], [1, 2, 3], [1, [1, 2, 3]], 5]')) + + def testNestedAtLimits(self): + unused_list = [ + 1, [1, 2, 3, 4, 5], [[1, 2, 3, 4, 5], 2, 3, 4, 5], 4, 5, 6, 7, 8, 9 + ] + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_list'] + }) + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: [1, [1, 2, 3, 4, 5], [[1, 2, 3, 4, 5], 2, 3, 4, 5], ' + '4, 5, 6, 7, 8, 9]')) + + def testNestedRecursionLimit(self): + unused_list = [1, [[2, [3]], 4], 5] + + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_list'] + }) + 
self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage('LOGPOINT: [1, [[2, %s], 4], 5]' % type([]))) + + def testNestedRecursionItemLimits(self): + unused_list = [1, [1, [1, [2], 3, 4], 3, 4], 3, 4] + + list_type = "" + log_collector = LogCollectorWithDefaultLocation({ + 'logLevel': 'INFO', + 'logMessageFormat': '$0', + 'expressions': ['unused_list'] + }) + log_collector.max_list_items = 3 + log_collector.max_sublist_items = 3 + self.assertIsNone(log_collector.Log(inspect.currentframe())) + self.assertTrue( + self._verifier.GotMessage( + 'LOGPOINT: [1, [1, [1, %s, 3, ...], 3, ...], 3, ...]' % list_type)) + + def testDetermineType(self): + builtin_prefix = 'builtins.' + path_prefix = 'googleclouddebugger.collector.' + test_data = ( + (builtin_prefix + 'int', 5), + (builtin_prefix + 'str', 'hello'), + (builtin_prefix + 'function', collector.DetermineType), + (path_prefix + 'LineNoFilter', collector.LineNoFilter()), + ) + + for type_string, value in test_data: + self.assertEqual(type_string, collector.DetermineType(value)) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/error_data_visibility_policy_test.py b/tests/py/error_data_visibility_policy_test.py new file mode 100644 index 0000000..6c8b6d9 --- /dev/null +++ b/tests/py/error_data_visibility_policy_test.py @@ -0,0 +1,17 @@ +"""Tests for googleclouddebugger.error_data_visibility_policy.""" + +from absl.testing import absltest +from googleclouddebugger import error_data_visibility_policy + + +class ErrorDataVisibilityPolicyTest(absltest.TestCase): + + def testIsDataVisible(self): + policy = error_data_visibility_policy.ErrorDataVisibilityPolicy( + 'An error message.') + + self.assertEqual((False, 'An error message.'), policy.IsDataVisible('foo')) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/firebase_client_test.py b/tests/py/firebase_client_test.py new file mode 100644 index 0000000..d72c68c --- /dev/null +++ 
b/tests/py/firebase_client_test.py @@ -0,0 +1,875 @@ +"""Unit tests for firebase_client module.""" + +import copy +import os +import sys +import tempfile +import time +from unittest import mock +from unittest.mock import ANY +from unittest.mock import MagicMock +from unittest.mock import call +from unittest.mock import patch +import requests +import requests_mock + +from googleclouddebugger import version +from googleclouddebugger import firebase_client + +from absl.testing import absltest +from absl.testing import parameterized + +import firebase_admin.credentials +from firebase_admin.exceptions import FirebaseError +from firebase_admin.exceptions import NotFoundError + +TEST_PROJECT_ID = 'test-project-id' +METADATA_PROJECT_URL = ('http://metadata.google.internal/computeMetadata/' + 'v1/project/project-id') + + +class FakeEvent: + + def __init__(self, event_type, path, data): + self.event_type = event_type + self.path = path + self.data = data + + +class FakeReference: + + def __init__(self): + self.subscriber = None + + def listen(self, callback): + self.subscriber = callback + + def update(self, event_type, path, data): + if self.subscriber: + event = FakeEvent(event_type, path, data) + self.subscriber(event) + + +class FirebaseClientTest(parameterized.TestCase): + """Simulates service account authentication.""" + + def setUp(self): + version.__version__ = 'test' + + self._client = firebase_client.FirebaseClient() + + self.breakpoints_changed_count = 0 + self.breakpoints = {} + + # Speed up the delays for retry loops. + for backoff in [ + self._client.connect_backoff, self._client.register_backoff, + self._client.subscribe_backoff, self._client.update_backoff + ]: + backoff.min_interval_sec /= 100000.0 + backoff.max_interval_sec /= 100000.0 + backoff._current_interval_sec /= 100000.0 + + # Set up patchers. 
+ patcher = patch('firebase_admin.initialize_app') + self._mock_initialize_app = patcher.start() + self.addCleanup(patcher.stop) + + patcher = patch('firebase_admin.delete_app') + self._mock_delete_app = patcher.start() + self.addCleanup(patcher.stop) + + patcher = patch('firebase_admin.db.reference') + self._mock_db_ref = patcher.start() + self.addCleanup(patcher.stop) + + # Set up the mocks for the database refs. + self._firebase_app = 'FIREBASE_APP_HANDLE' + self._mock_initialize_app.return_value = self._firebase_app + self._mock_schema_version_ref = MagicMock() + self._mock_schema_version_ref.get.return_value = "2" + self._mock_presence_ref = MagicMock() + self._mock_presence_ref.get.return_value = None + self._mock_active_ref = MagicMock() + self._mock_register_ref = MagicMock() + self._fake_subscribe_ref = FakeReference() + + # Setup common happy path reference sequence: + # cdbg/schema_version + # cdbg/debuggees/{debuggee_id}/registrationTimeUnixMsec + # cdbg/debuggees/{debuggee_id} + # cdbg/breakpoints/{debuggee_id}/active + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, self._mock_presence_ref, + self._mock_register_ref, self._fake_subscribe_ref + ] + + def tearDown(self): + self._client.Stop() + + def testSetupAuthDefault(self): + # By default, we try getting the project id from the metadata server. + # Note that actual credentials are not fetched. + with requests_mock.Mocker() as m: + m.get(METADATA_PROJECT_URL, text=TEST_PROJECT_ID) + + self._client.SetupAuth() + + self.assertEqual(TEST_PROJECT_ID, self._client._project_id) + + def testSetupAuthOverrideProjectIdNumber(self): + # If a project id is provided, we use it. 
+ project_id = 'project2' + self._client.SetupAuth(project_id=project_id) + + self.assertEqual(project_id, self._client._project_id) + + def testSetupAuthServiceAccountJsonAuth(self): + # We'll load credentials from the provided file (mocked for simplicity) + with mock.patch.object(firebase_admin.credentials, + 'Certificate') as firebase_certificate: + json_file = tempfile.NamedTemporaryFile() + # And load the project id from the file as well. + with open(json_file.name, 'w', encoding='utf-8') as f: + f.write(f'{{"project_id": "{TEST_PROJECT_ID}"}}') + self._client.SetupAuth(service_account_json_file=json_file.name) + + firebase_certificate.assert_called_with(json_file.name) + self.assertEqual(TEST_PROJECT_ID, self._client._project_id) + + def testSetupAuthNoProjectId(self): + # There will be an exception raised if we try to contact the metadata + # server on a non-gcp machine. + with requests_mock.Mocker() as m: + m.get(METADATA_PROJECT_URL, exc=requests.exceptions.RequestException) + + with self.assertRaises(firebase_client.NoProjectIdError): + self._client.SetupAuth() + + def testStart(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + debuggee_id = self._client._debuggee_id + + self._mock_initialize_app.assert_called_with( + None, {'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'}, + name='cdbg') + self.assertEqual([ + call(f'cdbg/schema_version', self._firebase_app), + call(f'cdbg/debuggees/{debuggee_id}/registrationTimeUnixMsec', + self._firebase_app), + call(f'cdbg/debuggees/{debuggee_id}', self._firebase_app), + call(f'cdbg/breakpoints/{debuggee_id}/active', self._firebase_app) + ], self._mock_db_ref.call_args_list) + + # Verify that the register call has been made. 
+ expected_data = copy.deepcopy(self._client._GetDebuggee()) + expected_data['registrationTimeUnixMsec'] = {'.sv': 'timestamp'} + expected_data['lastUpdateTimeUnixMsec'] = {'.sv': 'timestamp'} + self._mock_register_ref.set.assert_called_once_with(expected_data) + + def testStartCustomDbUrlConfigured(self): + self._client.SetupAuth( + project_id=TEST_PROJECT_ID, + database_url='https://custom-db.firebaseio.com') + self._client.Start() + self._client.connection_complete.wait() + + debuggee_id = self._client._debuggee_id + + self._mock_initialize_app.assert_called_once_with( + None, {'databaseURL': 'https://custom-db.firebaseio.com'}, name='cdbg') + + def testStartConnectFallsBackToDefaultRtdb(self): + # A new schema_version ref will be fetched each time + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, self._mock_schema_version_ref, + self._mock_presence_ref, self._mock_register_ref, + self._fake_subscribe_ref + ] + + # Fail on the '-cdbg' instance test, succeed on the '-default-rtdb' one. 
+ self._mock_schema_version_ref.get.side_effect = [ + NotFoundError("Not found", http_response=404), '2' + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.connection_complete.wait() + + self.assertEqual([ + call( + None, + {'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'}, + name='cdbg'), + call( + None, { + 'databaseURL': + f'https://{TEST_PROJECT_ID}-default-rtdb.firebaseio.com' + }, + name='cdbg') + ], self._mock_initialize_app.call_args_list) + + self.assertEqual(1, self._mock_delete_app.call_count) + + def testStartConnectFailsThenSucceeds(self): + # A new schema_version ref will be fetched each time + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, self._mock_schema_version_ref, + self._mock_schema_version_ref, self._mock_presence_ref, + self._mock_register_ref, self._fake_subscribe_ref + ] + + # Completely fail on the initial attempt at reaching a DB, then succeed on + # 2nd attempt. One full attempt will try the '-cdbg' db instance followed by + # the '-default-rtdb' one. + self._mock_schema_version_ref.get.side_effect = [ + NotFoundError("Not found", http_response=404), + NotFoundError("Not found", http_response=404), '2' + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.connection_complete.wait() + + self.assertEqual([ + call( + None, + {'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'}, + name='cdbg'), + call( + None, { + 'databaseURL': + f'https://{TEST_PROJECT_ID}-default-rtdb.firebaseio.com' + }, + name='cdbg'), + call( + None, + {'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'}, + name='cdbg') + ], self._mock_initialize_app.call_args_list) + + self.assertEqual(2, self._mock_delete_app.call_count) + + def testStartAlreadyPresent(self): + # Create a mock for just this test that claims the debuggee is registered. + mock_presence_ref = MagicMock() + mock_presence_ref.get.return_value = 'present!' 
+ + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, mock_presence_ref, self._mock_active_ref, + self._fake_subscribe_ref + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + debuggee_id = self._client._debuggee_id + + self.assertEqual([ + call(f'cdbg/schema_version', self._firebase_app), + call(f'cdbg/debuggees/{debuggee_id}/registrationTimeUnixMsec', + self._firebase_app), + call(f'cdbg/debuggees/{debuggee_id}/lastUpdateTimeUnixMsec', + self._firebase_app), + call(f'cdbg/breakpoints/{debuggee_id}/active', self._firebase_app) + ], self._mock_db_ref.call_args_list) + + # Verify that the register call has been made. + self._mock_active_ref.set.assert_called_once_with({'.sv': 'timestamp'}) + + def testStartRegisterRetry(self): + # A new set of db refs are fetched on each retry. + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, self._mock_presence_ref, + self._mock_register_ref, self._mock_presence_ref, + self._mock_register_ref, self._fake_subscribe_ref + ] + + # Fail once, then succeed on retry. + self._mock_register_ref.set.side_effect = [FirebaseError(1, 'foo'), None] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.registration_complete.wait() + + self.assertEqual(2, self._mock_register_ref.set.call_count) + + def testStartSubscribeRetry(self): + mock_subscribe_ref = MagicMock() + mock_subscribe_ref.listen.side_effect = FirebaseError(1, 'foo') + + # A new db ref is fetched on each retry. 
+ self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, + self._mock_presence_ref, + self._mock_register_ref, + mock_subscribe_ref, # Fail the first time + self._fake_subscribe_ref # Succeed the second time + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + self.assertEqual(5, self._mock_db_ref.call_count) + + def testMarkActiveTimer(self): + # Make sure that there are enough refs queued up. + refs = list(self._mock_db_ref.side_effect) + refs.extend([self._mock_active_ref] * 10) + self._mock_db_ref.side_effect = refs + + # Speed things WAY up rather than waiting for hours. + self._client._mark_active_interval_sec = 0.1 + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + # wait long enough for the timer to trigger a few times. + time.sleep(0.5) + + print(f'Timer triggered {self._mock_active_ref.set.call_count} times') + self.assertTrue(self._mock_active_ref.set.call_count > 3) + self._mock_active_ref.set.assert_called_with({'.sv': 'timestamp'}) + + def testBreakpointSubscription(self): + # This class will keep track of the breakpoint updates and will check + # them against expectations. 
+ class ResultChecker: + + def __init__(self, expected_results, test): + self._expected_results = expected_results + self._test = test + self._change_count = 0 + + def callback(self, new_breakpoints): + self._test.assertEqual(self._expected_results[self._change_count], + new_breakpoints) + self._change_count += 1 + + breakpoints = [ + { + 'id': 'breakpoint-0', + 'location': { + 'path': 'foo.py', + 'line': 18 + } + }, + { + 'id': 'breakpoint-1', + 'location': { + 'path': 'bar.py', + 'line': 23 + } + }, + { + 'id': 'breakpoint-2', + 'location': { + 'path': 'baz.py', + 'line': 45 + } + }, + ] + + expected_results = [[breakpoints[0]], [breakpoints[0], breakpoints[1]], + [breakpoints[0], breakpoints[1], breakpoints[2]], + [breakpoints[1], breakpoints[2]], + [breakpoints[1], breakpoints[2]]] + result_checker = ResultChecker(expected_results, self) + + self._client.on_active_breakpoints_changed = result_checker.callback + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + # Send in updates to trigger the subscription callback. + + # Initial state. + self._fake_subscribe_ref.update('put', '/', + {breakpoints[0]['id']: breakpoints[0]}) + # Add a breakpoint via patch. + self._fake_subscribe_ref.update('patch', '/', + {breakpoints[1]['id']: breakpoints[1]}) + # Add a breakpoint via put. + self._fake_subscribe_ref.update('put', f'/{breakpoints[2]["id"]}', + breakpoints[2]) + # Delete a breakpoint. + self._fake_subscribe_ref.update('put', f'/{breakpoints[0]["id"]}', None) + # Delete the breakpoint a second time; should handle this gracefully. 
+ self._fake_subscribe_ref.update('put', f'/{breakpoints[0]["id"]}', None) + + self.assertEqual(len(expected_results), result_checker._change_count) + + def testEnqueueBreakpointUpdate(self): + active_ref_mock = MagicMock() + snapshot_ref_mock = MagicMock() + final_ref_mock = MagicMock() + + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, self._mock_presence_ref, + self._mock_register_ref, self._fake_subscribe_ref, active_ref_mock, + snapshot_ref_mock, final_ref_mock + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + debuggee_id = self._client._debuggee_id + breakpoint_id = 'breakpoint-0' + + input_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'evaluatedExpressions': ['expressions go here'], + 'stackFrames': ['stuff goes here'], + 'variableTable': ['lots', 'of', 'variables'], + } + short_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'action': 'CAPTURE', + 'finalTimeUnixMsec': { + '.sv': 'timestamp' + } + } + full_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'action': 'CAPTURE', + 'evaluatedExpressions': ['expressions go here'], + 'stackFrames': ['stuff goes here'], + 'variableTable': ['lots', 'of', 'variables'], + 'finalTimeUnixMsec': { + '.sv': 'timestamp' + } + } + + self._client.EnqueueBreakpointUpdate(input_breakpoint) + + # Wait for the breakpoint to be sent. 
+ while self._client._transmission_queue: + time.sleep(0.1) + + db_ref_calls = self._mock_db_ref.call_args_list + self.assertEqual( + call(f'cdbg/breakpoints/{debuggee_id}/active/{breakpoint_id}', + self._firebase_app), db_ref_calls[4]) + self.assertEqual( + call(f'cdbg/breakpoints/{debuggee_id}/snapshot/{breakpoint_id}', + self._firebase_app), db_ref_calls[5]) + self.assertEqual( + call(f'cdbg/breakpoints/{debuggee_id}/final/{breakpoint_id}', + self._firebase_app), db_ref_calls[6]) + + active_ref_mock.delete.assert_called_once() + snapshot_ref_mock.set.assert_called_once_with(full_breakpoint) + final_ref_mock.set.assert_called_once_with(short_breakpoint) + + def testEnqueueBreakpointUpdateWithLogpoint(self): + active_ref_mock = MagicMock() + final_ref_mock = MagicMock() + + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, self._mock_presence_ref, + self._mock_register_ref, self._fake_subscribe_ref, active_ref_mock, + final_ref_mock + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + debuggee_id = self._client._debuggee_id + breakpoint_id = 'logpoint-0' + + input_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'action': 'LOG', + 'isFinalState': True, + 'status': { + 'isError': True, + 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', + }, + } + output_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'action': 'LOG', + 'status': { + 'isError': True, + 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', + }, + 'finalTimeUnixMsec': { + '.sv': 'timestamp' + } + } + + self._client.EnqueueBreakpointUpdate(input_breakpoint) + + # Wait for the breakpoint to be sent. 
+ while self._client._transmission_queue: + time.sleep(0.1) + + db_ref_calls = self._mock_db_ref.call_args_list + self.assertEqual( + call(f'cdbg/breakpoints/{debuggee_id}/active/{breakpoint_id}', + self._firebase_app), db_ref_calls[4]) + self.assertEqual( + call(f'cdbg/breakpoints/{debuggee_id}/final/{breakpoint_id}', + self._firebase_app), db_ref_calls[5]) + + active_ref_mock.delete.assert_called_once() + final_ref_mock.set.assert_called_once_with(output_breakpoint) + + # Make sure that the snapshot node was not accessed. + self.assertTrue( + call(f'cdbg/breakpoints/{debuggee_id}/snapshot/{breakpoint_id}', ANY) + not in db_ref_calls) + + def testEnqueueBreakpointUpdateRetry(self): + active_ref_mock = MagicMock() + snapshot_ref_mock = MagicMock() + final_ref_mock = MagicMock() + + # This test will have three failures, one for each of the firebase writes. + # UNAVAILABLE errors are retryable. + active_ref_mock.delete.side_effect = [ + FirebaseError('UNAVAILABLE', 'active error'), None, None, None + ] + snapshot_ref_mock.set.side_effect = [ + FirebaseError('UNAVAILABLE', 'snapshot error'), None, None + ] + final_ref_mock.set.side_effect = [ + FirebaseError('UNAVAILABLE', 'final error'), None + ] + + self._mock_db_ref.side_effect = [ + self._mock_schema_version_ref, + self._mock_presence_ref, + self._mock_register_ref, + self._fake_subscribe_ref, # setup + active_ref_mock, # attempt 1 + active_ref_mock, + snapshot_ref_mock, # attempt 2 + active_ref_mock, + snapshot_ref_mock, + final_ref_mock, # attempt 3 + active_ref_mock, + snapshot_ref_mock, + final_ref_mock # attempt 4 + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + breakpoint_id = 'breakpoint-0' + + input_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'evaluatedExpressions': ['expressions go here'], + 'stackFrames': ['stuff goes here'], + 'variableTable': 
['lots', 'of', 'variables'], + } + short_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'action': 'CAPTURE', + 'finalTimeUnixMsec': { + '.sv': 'timestamp' + } + } + full_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'action': 'CAPTURE', + 'evaluatedExpressions': ['expressions go here'], + 'stackFrames': ['stuff goes here'], + 'variableTable': ['lots', 'of', 'variables'], + 'finalTimeUnixMsec': { + '.sv': 'timestamp' + } + } + + self._client.EnqueueBreakpointUpdate(input_breakpoint) + + # Wait for the breakpoint to be sent. Retries will have occured. + while self._client._transmission_queue: + time.sleep(0.1) + + active_ref_mock.delete.assert_has_calls([call()] * 4) + snapshot_ref_mock.set.assert_has_calls([call(full_breakpoint)] * 3) + final_ref_mock.set.assert_has_calls([call(short_breakpoint)] * 2) + + def _TestInitializeLabels(self, module_var, version_var, minor_var): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + self._client.InitializeDebuggeeLabels({ + 'module': 'my_module', + 'version': '1', + 'minorversion': '23', + 'something_else': 'irrelevant' + }) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'my_module', + 'version': '1', + 'minorversion': '23', + 'platform': 'default' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-my_module-1', + self._client._GetDebuggeeDescription()) + + uniquifier1 = self._client._ComputeUniquifier( + {'labels': self._client._debuggee_labels}) + self.assertTrue(uniquifier1) # Not empty string. 
+ + try: + os.environ[module_var] = 'env_module' + os.environ[version_var] = '213' + os.environ[minor_var] = '3476734' + self._client.InitializeDebuggeeLabels(None) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'env_module', + 'version': '213', + 'minorversion': '3476734', + 'platform': 'default' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-env_module-213', + self._client._GetDebuggeeDescription()) + + os.environ[module_var] = 'default' + os.environ[version_var] = '213' + os.environ[minor_var] = '3476734' + self._client.InitializeDebuggeeLabels({'minorversion': 'something else'}) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'version': '213', + 'minorversion': 'something else', + 'platform': 'default' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-213', + self._client._GetDebuggeeDescription()) + + finally: + del os.environ[module_var] + del os.environ[version_var] + del os.environ[minor_var] + + def testInitializeLegacyDebuggeeLabels(self): + self._TestInitializeLabels('GAE_MODULE_NAME', 'GAE_MODULE_VERSION', + 'GAE_MINOR_VERSION') + + def testInitializeDebuggeeLabels(self): + self._TestInitializeLabels('GAE_SERVICE', 'GAE_VERSION', + 'GAE_DEPLOYMENT_ID') + + def testInitializeCloudRunDebuggeeLabels(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + try: + os.environ['K_SERVICE'] = 'env_module' + os.environ['K_REVISION'] = '213' + self._client.InitializeDebuggeeLabels(None) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'env_module', + 'version': '213', + 'platform': 'default' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-env_module-213', + self._client._GetDebuggeeDescription()) + + finally: + del os.environ['K_SERVICE'] + del os.environ['K_REVISION'] + + def testInitializeCloudFunctionDebuggeeLabels(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + try: + os.environ['FUNCTION_NAME'] = 'fcn-name' + 
os.environ['X_GOOGLE_FUNCTION_VERSION'] = '213' + self._client.InitializeDebuggeeLabels(None) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'fcn-name', + 'version': '213', + 'platform': 'cloud_function' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-fcn-name-213', + self._client._GetDebuggeeDescription()) + + finally: + del os.environ['FUNCTION_NAME'] + del os.environ['X_GOOGLE_FUNCTION_VERSION'] + + def testInitializeCloudFunctionUnversionedDebuggeeLabels(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + try: + os.environ['FUNCTION_NAME'] = 'fcn-name' + self._client.InitializeDebuggeeLabels(None) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'fcn-name', + 'version': 'unversioned', + 'platform': 'cloud_function' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-fcn-name-unversioned', + self._client._GetDebuggeeDescription()) + + finally: + del os.environ['FUNCTION_NAME'] + + def testInitializeCloudFunctionWithRegionDebuggeeLabels(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + try: + os.environ['FUNCTION_NAME'] = 'fcn-name' + os.environ['FUNCTION_REGION'] = 'fcn-region' + self._client.InitializeDebuggeeLabels(None) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'fcn-name', + 'version': 'unversioned', + 'platform': 'cloud_function', + 'region': 'fcn-region' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-fcn-name-unversioned', + self._client._GetDebuggeeDescription()) + + finally: + del os.environ['FUNCTION_NAME'] + del os.environ['FUNCTION_REGION'] + + def testAppFilesUniquifierNoMinorVersion(self): + """Verify that uniquifier_computer is used if minor version not defined.""" + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + root = tempfile.mkdtemp('', 'fake_app_') + sys.path.insert(0, root) + try: + uniquifier1 = self._client._ComputeUniquifier({}) + + with open(os.path.join(root, 
'app.py'), 'w', encoding='utf-8') as f: + f.write('hello') + uniquifier2 = self._client._ComputeUniquifier({}) + finally: + del sys.path[0] + + self.assertNotEqual(uniquifier1, uniquifier2) + + def testAppFilesUniquifierWithMinorVersion(self): + """Verify that uniquifier_computer not used if minor version is defined.""" + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + root = tempfile.mkdtemp('', 'fake_app_') + + os.environ['GAE_MINOR_VERSION'] = '12345' + sys.path.insert(0, root) + try: + self._client.InitializeDebuggeeLabels(None) + + uniquifier1 = self._client._GetDebuggee()['uniquifier'] + + with open(os.path.join(root, 'app.py'), 'w', encoding='utf-8') as f: + f.write('hello') + uniquifier2 = self._client._GetDebuggee()['uniquifier'] + finally: + del os.environ['GAE_MINOR_VERSION'] + del sys.path[0] + + self.assertEqual(uniquifier1, uniquifier2) + + def testSourceContext(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + root = tempfile.mkdtemp('', 'fake_app_') + source_context_path = os.path.join(root, 'source-context.json') + + sys.path.insert(0, root) + try: + debuggee_no_source_context1 = self._client._GetDebuggee() + + with open(source_context_path, 'w', encoding='utf-8') as f: + f.write('not a valid JSON') + debuggee_bad_source_context = self._client._GetDebuggee() + + with open(os.path.join(root, 'fake_app.py'), 'w', encoding='utf-8') as f: + f.write('pretend') + debuggee_no_source_context2 = self._client._GetDebuggee() + + with open(source_context_path, 'w', encoding='utf-8') as f: + f.write('{"what": "source context"}') + debuggee_with_source_context = self._client._GetDebuggee() + + os.remove(source_context_path) + finally: + del sys.path[0] + + self.assertNotIn('sourceContexts', debuggee_no_source_context1) + self.assertNotIn('sourceContexts', debuggee_bad_source_context) + self.assertListEqual([{ + 'what': 'source context' + }], debuggee_with_source_context['sourceContexts']) + + uniquifiers = set() + 
uniquifiers.add(debuggee_no_source_context1['uniquifier']) + uniquifiers.add(debuggee_with_source_context['uniquifier']) + uniquifiers.add(debuggee_bad_source_context['uniquifier']) + self.assertLen(uniquifiers, 1) + uniquifiers.add(debuggee_no_source_context2['uniquifier']) + self.assertLen(uniquifiers, 2) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/glob_data_visibility_policy_test.py b/tests/py/glob_data_visibility_policy_test.py new file mode 100644 index 0000000..8670198 --- /dev/null +++ b/tests/py/glob_data_visibility_policy_test.py @@ -0,0 +1,36 @@ +"""Tests for glob_data_visibility_policy.""" + +from absl.testing import absltest +from googleclouddebugger import glob_data_visibility_policy + +RESPONSES = glob_data_visibility_policy.RESPONSES +UNKNOWN_TYPE = (False, RESPONSES['UNKNOWN_TYPE']) +BLACKLISTED = (False, RESPONSES['BLACKLISTED']) +NOT_WHITELISTED = (False, RESPONSES['NOT_WHITELISTED']) +VISIBLE = (True, RESPONSES['VISIBLE']) + + +class GlobDataVisibilityPolicyTest(absltest.TestCase): + + def testIsDataVisible(self): + blacklist_patterns = ( + 'wl1.private1', + 'wl2.*', + '*.private2', + '', + ) + whitelist_patterns = ('wl1.*', 'wl2.*') + + policy = glob_data_visibility_policy.GlobDataVisibilityPolicy( + blacklist_patterns, whitelist_patterns) + + self.assertEqual(BLACKLISTED, policy.IsDataVisible('wl1.private1')) + self.assertEqual(BLACKLISTED, policy.IsDataVisible('wl2.foo')) + self.assertEqual(BLACKLISTED, policy.IsDataVisible('foo.private2')) + self.assertEqual(NOT_WHITELISTED, policy.IsDataVisible('wl3.foo')) + self.assertEqual(VISIBLE, policy.IsDataVisible('wl1.foo')) + self.assertEqual(UNKNOWN_TYPE, policy.IsDataVisible(None)) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/imphook_test.py b/tests/py/imphook_test.py new file mode 100644 index 0000000..320e477 --- /dev/null +++ b/tests/py/imphook_test.py @@ -0,0 +1,485 @@ +"""Unit test for imphook module.""" + +import importlib +import os 
+import sys +import tempfile + +from absl.testing import absltest + +from googleclouddebugger import imphook + + +class ImportHookTest(absltest.TestCase): + """Tests for the new module import hook.""" + + def setUp(self): + self._test_package_dir = tempfile.mkdtemp('', 'imphook_') + sys.path.append(self._test_package_dir) + + self._import_callbacks_log = [] + self._callback_cleanups = [] + + def tearDown(self): + sys.path.remove(self._test_package_dir) + + for cleanup in self._callback_cleanups: + cleanup() + + # Assert no hooks or entries remained in the set. + self.assertEmpty(imphook._import_callbacks) + + def testPackageImport(self): + self._Hook(self._CreateFile('testpkg1/__init__.py')) + import testpkg1 # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg1/__init__.py'], self._import_callbacks_log) + + def testModuleImport(self): + self._CreateFile('testpkg2/__init__.py') + self._Hook(self._CreateFile('testpkg2/my.py')) + import testpkg2.my # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg2/my.py'], self._import_callbacks_log) + + def testUnrelatedImport(self): + self._CreateFile('testpkg3/__init__.py') + self._Hook(self._CreateFile('testpkg3/first.py')) + self._CreateFile('testpkg3/second.py') + import testpkg3.second # pylint: disable=g-import-not-at-top,unused-variable + self.assertEmpty(self._import_callbacks_log) + + def testDoubleImport(self): + self._Hook(self._CreateFile('testpkg4/__init__.py')) + import testpkg4 # pylint: disable=g-import-not-at-top,unused-variable + import testpkg4 # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg4/__init__.py', 'testpkg4/__init__.py'], + sorted(self._import_callbacks_log)) + + def testRemoveCallback(self): + cleanup = self._Hook(self._CreateFile('testpkg4b/__init__.py')) + cleanup() + import testpkg4b # pylint: disable=g-import-not-at-top,unused-variable + self.assertEmpty(self._import_callbacks_log) + + def 
testRemoveCallbackAfterImport(self): + cleanup = self._Hook(self._CreateFile('testpkg5/__init__.py')) + import testpkg5 # pylint: disable=g-import-not-at-top,unused-variable + cleanup() + import testpkg5 # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg5/__init__.py'], self._import_callbacks_log) + + def testTransitiveImport(self): + self._CreateFile('testpkg6/__init__.py') + self._Hook(self._CreateFile('testpkg6/first.py', 'import second')) + self._Hook(self._CreateFile('testpkg6/second.py', 'import third')) + self._Hook(self._CreateFile('testpkg6/third.py')) + import testpkg6.first # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual( + ['testpkg6/first.py', 'testpkg6/second.py', 'testpkg6/third.py'], + sorted(self._import_callbacks_log)) + + def testPackageDotModuleImport(self): + self._Hook(self._CreateFile('testpkg8/__init__.py')) + self._Hook(self._CreateFile('testpkg8/my.py')) + import testpkg8.my # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg8/__init__.py', 'testpkg8/my.py'], + sorted(self._import_callbacks_log)) + + def testNestedPackageDotModuleImport(self): + self._Hook(self._CreateFile('testpkg9a/__init__.py')) + self._Hook(self._CreateFile('testpkg9a/testpkg9b/__init__.py')) + self._CreateFile('testpkg9a/testpkg9b/my.py') + import testpkg9a.testpkg9b.my # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual( + ['testpkg9a/__init__.py', 'testpkg9a/testpkg9b/__init__.py'], + sorted(self._import_callbacks_log)) + + def testFromImport(self): + self._Hook(self._CreateFile('testpkg10/__init__.py')) + self._CreateFile('testpkg10/my.py') + from testpkg10 import my # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg10/__init__.py'], self._import_callbacks_log) + + def testTransitiveFromImport(self): + self._CreateFile('testpkg7/__init__.py') + self._Hook( + self._CreateFile('testpkg7/first.py', 'from testpkg7 import second')) 
+ self._Hook(self._CreateFile('testpkg7/second.py')) + from testpkg7 import first # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg7/first.py', 'testpkg7/second.py'], + sorted(self._import_callbacks_log)) + + def testFromNestedPackageImportModule(self): + self._Hook(self._CreateFile('testpkg11a/__init__.py')) + self._Hook(self._CreateFile('testpkg11a/testpkg11b/__init__.py')) + self._Hook(self._CreateFile('testpkg11a/testpkg11b/my.py')) + self._Hook(self._CreateFile('testpkg11a/testpkg11b/your.py')) + from testpkg11a.testpkg11b import my, your # pylint: disable=g-import-not-at-top,unused-variable,g-multiple-import + self.assertEqual([ + 'testpkg11a/__init__.py', 'testpkg11a/testpkg11b/__init__.py', + 'testpkg11a/testpkg11b/my.py', 'testpkg11a/testpkg11b/your.py' + ], sorted(self._import_callbacks_log)) + + def testDoubleNestedImport(self): + self._Hook(self._CreateFile('testpkg12a/__init__.py')) + self._Hook(self._CreateFile('testpkg12a/testpkg12b/__init__.py')) + self._Hook(self._CreateFile('testpkg12a/testpkg12b/my.py')) + from testpkg12a.testpkg12b import my # pylint: disable=g-import-not-at-top,unused-variable,g-multiple-import + from testpkg12a.testpkg12b import my # pylint: disable=g-import-not-at-top,unused-variable,g-multiple-import + self.assertEqual([ + 'testpkg12a/__init__.py', 'testpkg12a/__init__.py', + 'testpkg12a/testpkg12b/__init__.py', + 'testpkg12a/testpkg12b/__init__.py', 'testpkg12a/testpkg12b/my.py', + 'testpkg12a/testpkg12b/my.py' + ], sorted(self._import_callbacks_log)) + + def testFromPackageImportStar(self): + self._Hook(self._CreateFile('testpkg13a/__init__.py')) + self._Hook(self._CreateFile('testpkg13a/my1.py')) + self._Hook(self._CreateFile('testpkg13a/your1.py')) + # Star imports are only allowed at the top level, not inside a function in + # Python 3. Doing so would be a SyntaxError. 
+ exec('from testpkg13a import *') # pylint: disable=exec-used + self.assertEqual(['testpkg13a/__init__.py'], self._import_callbacks_log) + + def testFromPackageImportStarWith__all__(self): + self._Hook(self._CreateFile('testpkg14a/__init__.py', '__all__=["my1"]')) + self._Hook(self._CreateFile('testpkg14a/my1.py')) + self._Hook(self._CreateFile('testpkg14a/your1.py')) + exec('from testpkg14a import *') # pylint: disable=exec-used + self.assertEqual(['testpkg14a/__init__.py', 'testpkg14a/my1.py'], + sorted(self._import_callbacks_log)) + + def testImportFunction(self): + self._Hook(self._CreateFile('testpkg27/__init__.py')) + __import__('testpkg27') + self.assertEqual(['testpkg27/__init__.py'], self._import_callbacks_log) + + def testImportLib(self): + self._Hook(self._CreateFile('zero.py')) + self._Hook(self._CreateFile('testpkg15a/__init__.py')) + self._Hook(self._CreateFile('testpkg15a/first.py')) + self._Hook( + self._CreateFile('testpkg15a/testpkg15b/__init__.py', + 'assert False, "unexpected import"')) + self._Hook(self._CreateFile('testpkg15a/testpkg15c/__init__.py')) + self._Hook(self._CreateFile('testpkg15a/testpkg15c/second.py')) + + # Import top level module. + importlib.import_module('zero') + self.assertEqual(['zero.py'], self._import_callbacks_log) + self._import_callbacks_log = [] + + # Import top level package. + importlib.import_module('testpkg15a') + self.assertEqual(['testpkg15a/__init__.py'], self._import_callbacks_log) + self._import_callbacks_log = [] + + # Import package.module. + importlib.import_module('testpkg15a.first') + self.assertEqual(['testpkg15a/__init__.py', 'testpkg15a/first.py'], + sorted(self._import_callbacks_log)) + self._import_callbacks_log = [] + + # Relative module import from package context. 
+ importlib.import_module('.first', 'testpkg15a') + self.assertEqual(['testpkg15a/__init__.py', 'testpkg15a/first.py'], + sorted(self._import_callbacks_log)) + self._import_callbacks_log = [] + + # Relative module import from package context with '..'. + # In Python 3, the parent module has to be loaded before a relative import + importlib.import_module('testpkg15a.testpkg15c') + self._import_callbacks_log = [] + importlib.import_module('..first', 'testpkg15a.testpkg15c') + self.assertEqual( + [ + 'testpkg15a/__init__.py', + # TODO: Importlib may or may not load testpkg15b, + # depending on the implementation. Currently on blaze, it does not + # load testpkg15b, but a similar non-blaze code on my workstation + # loads testpkg15b. We should verify this behavior. + # 'testpkg15a/testpkg15b/__init__.py', + 'testpkg15a/first.py' + ], + sorted(self._import_callbacks_log)) + self._import_callbacks_log = [] + + # Relative module import from nested package context. + importlib.import_module('.second', 'testpkg15a.testpkg15c') + self.assertEqual([ + 'testpkg15a/__init__.py', 'testpkg15a/testpkg15c/__init__.py', + 'testpkg15a/testpkg15c/second.py' + ], sorted(self._import_callbacks_log)) + self._import_callbacks_log = [] + + def testRemoveImportHookFromCallback(self): + + def RunCleanup(unused_mod): + cleanup() + + cleanup = self._Hook(self._CreateFile('testpkg15/__init__.py'), RunCleanup) + import testpkg15 # pylint: disable=g-import-not-at-top,unused-variable + import testpkg15 # pylint: disable=g-import-not-at-top,unused-variable + import testpkg15 # pylint: disable=g-import-not-at-top,unused-variable + + # The first import should have removed the hook, so expect only one entry. + self.assertEqual(['testpkg15/__init__.py'], self._import_callbacks_log) + + def testInitImportNoPrematureCallback(self): + # Verifies that the callback is not invoked before the package is fully + # loaded. Thus, assuring that the all module code is available for lookup. 
+ def CheckFullyLoaded(module): + self.assertEqual(1, getattr(module, 'validate', None), 'premature call') + + self._Hook(self._CreateFile('testpkg16/my1.py')) + self._Hook( + self._CreateFile('testpkg16/__init__.py', 'import my1\nvalidate = 1'), + CheckFullyLoaded) + import testpkg16.my1 # pylint: disable=g-import-not-at-top,unused-variable + + self.assertEqual(['testpkg16/__init__.py', 'testpkg16/my1.py'], + sorted(self._import_callbacks_log)) + + def testCircularImportNoPrematureCallback(self): + # Verifies that the callback is not invoked before the first module is fully + # loaded. Thus, assuring that the all module code is available for lookup. + def CheckFullyLoaded(module): + self.assertEqual(1, getattr(module, 'validate', None), 'premature call') + + self._CreateFile('testpkg17/__init__.py') + self._Hook( + self._CreateFile('testpkg17/c1.py', 'import testpkg17.c2\nvalidate = 1', + False), CheckFullyLoaded) + self._Hook( + self._CreateFile('testpkg17/c2.py', 'import testpkg17.c1\nvalidate = 1', + False), CheckFullyLoaded) + + import testpkg17.c1 # pylint: disable=g-import-not-at-top,unused-variable + + self.assertEqual(['testpkg17/c1.py', 'testpkg17/c2.py'], + sorted(self._import_callbacks_log)) + + def testImportException(self): + # An exception is thrown by the builtin importer during import. + self._CreateFile('testpkg18/__init__.py') + self._Hook(self._CreateFile('testpkg18/bad.py', 'assert False, "bad file"')) + self._Hook(self._CreateFile('testpkg18/good.py')) + + try: + import testpkg18.bad # pylint: disable=g-import-not-at-top,unused-variable + except AssertionError: + pass + + import testpkg18.good # pylint: disable=g-import-not-at-top,unused-variable + + self.assertEqual(['testpkg18/good.py'], self._import_callbacks_log) + + def testImportNestedException(self): + # An import exception is thrown and caught inside a module being imported. 
+ self._CreateFile('testpkg19/__init__.py') + self._Hook( + self._CreateFile('testpkg19/m19.py', + 'try: import m19b\nexcept ImportError: pass')) + + import testpkg19.m19 # pylint: disable=g-import-not-at-top,unused-variable + + self.assertEqual(['testpkg19/m19.py'], self._import_callbacks_log) + + def testModuleImportByPathSuffix(self): + # Import module by providing only a suffix of the module's file path. + self._CreateFile('testpkg20a/__init__.py') + self._CreateFile('testpkg20a/testpkg20b/__init__.py') + self._CreateFile('testpkg20a/testpkg20b/my1.py') + self._CreateFile('testpkg20a/testpkg20b/my2.py') + self._CreateFile('testpkg20a/testpkg20b/my3.py') + + # Import just by the name of the module file. + self._Hook('my1.py') + import testpkg20a.testpkg20b.my1 # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['my1.py'], self._import_callbacks_log) + self._import_callbacks_log = [] + + # Import with only one of the enclosing package names. + self._Hook('testpkg20b/my2.py') + import testpkg20a.testpkg20b.my2 # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg20b/my2.py'], self._import_callbacks_log) + self._import_callbacks_log = [] + + # Import with all enclosing packages (the typical case). 
+ self._Hook('testpkg20b/my3.py') + import testpkg20a.testpkg20b.my3 # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg20b/my3.py'], self._import_callbacks_log) + self._import_callbacks_log = [] + + def testFromImportImportsFunction(self): + self._CreateFile('testpkg21a/__init__.py') + self._CreateFile('testpkg21a/testpkg21b/__init__.py') + self._CreateFile('testpkg21a/testpkg21b/mod.py', ('def func1():\n' + ' return 5\n' + '\n' + 'def func2():\n' + ' return 7\n')) + + self._Hook('mod.py') + from testpkg21a.testpkg21b.mod import func1, func2 # pylint: disable=g-import-not-at-top,unused-variable,g-multiple-import + self.assertEqual(['mod.py'], self._import_callbacks_log) + + def testImportSibling(self): + self._CreateFile('testpkg22/__init__.py') + self._CreateFile('testpkg22/first.py', 'import second') + self._CreateFile('testpkg22/second.py') + + self._Hook('testpkg22/second.py') + import testpkg22.first # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg22/second.py'], self._import_callbacks_log) + + def testImportSiblingSamePackage(self): + self._CreateFile('testpkg23/__init__.py') + self._CreateFile('testpkg23/testpkg23/__init__.py') + self._CreateFile( + 'testpkg23/first.py', + 'import testpkg23.second') # This refers to testpkg23.testpkg23.second + self._CreateFile('testpkg23/testpkg23/second.py') + + self._Hook('testpkg23/testpkg23/second.py') + import testpkg23.first # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual(['testpkg23/testpkg23/second.py'], + self._import_callbacks_log) + + def testImportSiblingFromInit(self): + self._Hook(self._CreateFile('testpkg23a/__init__.py', 'import testpkg23b')) + self._Hook( + self._CreateFile('testpkg23a/testpkg23b/__init__.py', + 'import testpkg23c')) + self._Hook(self._CreateFile('testpkg23a/testpkg23b/testpkg23c/__init__.py')) + import testpkg23a # pylint: disable=g-import-not-at-top,unused-variable + self.assertEqual([ + 
'testpkg23a/__init__.py', 'testpkg23a/testpkg23b/__init__.py', + 'testpkg23a/testpkg23b/testpkg23c/__init__.py' + ], sorted(self._import_callbacks_log)) + + def testThreadLocalCleanup(self): + self._CreateFile('testpkg24/__init__.py') + self._CreateFile('testpkg24/foo.py', 'import bar') + self._CreateFile('testpkg24/bar.py') + + # Create a hook for any arbitrary module. Doesn't need to hit. + self._Hook('xxx/yyy.py') + + import testpkg24.foo # pylint: disable=g-import-not-at-top,unused-variable + + self.assertEqual(imphook._import_local.nest_level, 0) + self.assertEmpty(imphook._import_local.names) + + def testThreadLocalCleanupWithCaughtImportError(self): + self._CreateFile('testpkg25/__init__.py') + self._CreateFile( + 'testpkg25/foo.py', + 'import bar\n' # success. + 'import baz') # success. + self._CreateFile('testpkg25/bar.py') + self._CreateFile( + 'testpkg25/baz.py', 'try:\n' + ' import testpkg25b\n' + 'except ImportError:\n' + ' pass') + + # Create a hook for any arbitrary module. Doesn't need to hit. + self._Hook('xxx/yyy.py') + + # Successful import at top level. Failed import at inner level. + import testpkg25.foo # pylint: disable=g-import-not-at-top,unused-variable + + self.assertEqual(imphook._import_local.nest_level, 0) + self.assertEmpty(imphook._import_local.names) + + def testThreadLocalCleanupWithUncaughtImportError(self): + self._CreateFile('testpkg26/__init__.py') + self._CreateFile( + 'testpkg26/foo.py', + 'import bar\n' # success. + 'import baz') # fail. + self._CreateFile('testpkg26/bar.py') + + # Create a hook for any arbitrary module. Doesn't need to hit. + self._Hook('testpkg26/bar.py') + + # Inner import will fail, and exception will be propagated here. + try: + import testpkg26.foo # pylint: disable=g-import-not-at-top,unused-variable + except ImportError: + pass + + # The hook for bar should be invoked, as bar is already loaded. 
+ self.assertEqual(['testpkg26/bar.py'], self._import_callbacks_log) + + self.assertEqual(imphook._import_local.nest_level, 0) + self.assertEmpty(imphook._import_local.names) + + def testCleanup(self): + cleanup1 = self._Hook('a/b/c.py') + cleanup2 = self._Hook('a/b/c.py') + cleanup3 = self._Hook('a/d/f.py') + cleanup4 = self._Hook('a/d/g.py') + cleanup5 = self._Hook('a/d/c.py') + self.assertLen(imphook._import_callbacks, 4) + + cleanup1() + self.assertLen(imphook._import_callbacks, 4) + cleanup2() + self.assertLen(imphook._import_callbacks, 3) + cleanup3() + self.assertLen(imphook._import_callbacks, 2) + cleanup4() + self.assertLen(imphook._import_callbacks, 1) + cleanup5() + self.assertLen(imphook._import_callbacks, 0) + + def _CreateFile(self, path, content='', rewrite_imports=True): + full_path = os.path.join(self._test_package_dir, path) + directory, unused_name = os.path.split(full_path) + + if not os.path.isdir(directory): + os.makedirs(directory) + + def RewriteImport(line): + """Converts import statements to relative form. + + Examples: + import x => from . import x + import x.y.z => from .x.y import z + print('') => print('') + + Args: + line: str, the line to convert. + + Returns: + str, the converted import statement or original line. + """ + original_line_length = len(line) + line = line.lstrip(' ') + indent = ' ' * (original_line_length - len(line)) + if line.startswith('import'): + pkg, _, mod = line.split(' ')[1].rpartition('.') + line = 'from .%s import %s' % (pkg, mod) + return indent + line + + with open(full_path, 'w') as writer: + if rewrite_imports: + content = '\n'.join(RewriteImport(l) for l in content.split('\n')) + writer.write(content) + + return path + + # TODO: add test for the module param in the callback. 
+ def _Hook(self, path, callback=lambda m: None): + cleanup = imphook.AddImportCallbackBySuffix( + path, lambda mod: + (self._import_callbacks_log.append(path), callback(mod))) + self.assertTrue(cleanup, path) + self._callback_cleanups.append(cleanup) + return cleanup + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/integration_test.py b/tests/py/integration_test.py new file mode 100644 index 0000000..1a16f30 --- /dev/null +++ b/tests/py/integration_test.py @@ -0,0 +1,670 @@ +"""Complete tests of the debugger mocking the backend.""" + +from datetime import datetime +from datetime import timedelta +import functools +import inspect +import itertools +import os +import sys +import time +from unittest import mock + +from googleapiclient import discovery +import googleclouddebugger as cdbg + +import queue + +import google.auth +from absl.testing import absltest + +from googleclouddebugger import collector +from googleclouddebugger import labels +import python_test_util + +_TEST_DEBUGGEE_ID = 'gcp:integration-test-debuggee-id' +_TEST_AGENT_ID = 'agent-id-123-abc' +_TEST_PROJECT_ID = 'test-project-id' +_TEST_PROJECT_NUMBER = '123456789' + +# Time to sleep before returning the result of an API call. +# Without a delay, the agent will continuously call ListActiveBreakpoints, +# and the mock object will use a lot of memory to record all the calls. +_REQUEST_DELAY_SECS = 0.01 + +# TODO: Modify to work with a mocked Firebase database instead. +# class IntegrationTest(absltest.TestCase): +# """Complete tests of the debugger mocking the backend. + +# These tests employ all the components of the debugger. The actual +# communication channel with the backend is mocked. This allows the test +# quickly inject breakpoints and read results. It also makes the test +# standalone and independent of the actual backend. + +# Uses the new module search algorithm (b/70226488). 
+# """ + +# class FakeHub(object): +# """Starts the debugger with a mocked communication channel.""" + +# def __init__(self): +# # Breakpoint updates posted by the debugger that haven't been processed +# # by the test case code. +# self._incoming_breakpoint_updates = queue.Queue() + +# # Running counter used to generate unique breakpoint IDs. +# self._id_counter = itertools.count() + +# self._service = mock.Mock() + +# patcher = mock.patch.object(discovery, 'build') +# self._mock_build = patcher.start() +# self._mock_build.return_value = self._service + +# patcher = mock.patch.object(google.auth, 'default') +# self._default_auth_mock = patcher.start() +# self._default_auth_mock.return_value = None, _TEST_PROJECT_ID + +# controller = self._service.controller.return_value +# debuggees = controller.debuggees.return_value +# breakpoints = debuggees.breakpoints.return_value + +# # Simulate a time delay for calls to the mock API. +# def ReturnWithDelay(val): + +# def GetVal(): +# time.sleep(_REQUEST_DELAY_SECS) +# return val + +# return GetVal + +# self._register_execute = debuggees.register.return_value.execute +# self._register_execute.side_effect = ReturnWithDelay({ +# 'debuggee': { +# 'id': _TEST_DEBUGGEE_ID +# }, +# 'agentId': _TEST_AGENT_ID +# }) + +# self._active_breakpoints = {'breakpoints': []} +# self._list_execute = breakpoints.list.return_value.execute +# self._list_execute.side_effect = ReturnWithDelay(self._active_breakpoints) + +# breakpoints.update = self._UpdateBreakpoint + +# # Start the debugger. +# cdbg.enable() + +# def SetBreakpoint(self, tag, template=None): +# """Sets a new breakpoint in this source file. + +# The line number is identified by tag. The optional template may specify +# other breakpoint parameters such as condition and watched expressions. + +# Args: +# tag: label for a source line. +# template: optional breakpoint parameters. 
+# """ +# path, line = python_test_util.ResolveTag(sys.modules[__name__], tag) +# self.SetBreakpointAtPathLine(path, line, template) + +# def SetBreakpointAtFile(self, filename, tag, template=None): +# """Sets a breakpoint in a file with the given filename. + +# The line number is identified by tag. The optional template may specify +# other breakpoint parameters such as condition and watched expressions. + +# Args: +# filename: the name of the file inside which the tag will be searched. +# Must be in the same directory as the current file. +# tag: label for a source line. +# template: optional breakpoint parameters. + +# Raises: +# Exception: when the given tag does not uniquely identify a line. +# """ +# # TODO: Move part of this to python_test_utils.py file. +# # Find the full path of filename, using the directory of the current file. +# module_path = inspect.getsourcefile(sys.modules[__name__]) +# directory, unused_name = os.path.split(module_path) +# path = os.path.join(directory, filename) + +# # Similar to ResolveTag(), but for a module that's not loaded yet. 
+# tags = python_test_util.GetSourceFileTags(path) +# if tag not in tags: +# raise Exception('tag %s not found' % tag) +# lines = tags[tag] +# if len(lines) != 1: +# raise Exception('tag %s is ambiguous (lines: %s)' % (tag, lines)) + +# self.SetBreakpointAtPathLine(path, lines[0], template) + +# def SetBreakpointAtPathLine(self, path, line, template=None): +# """Sets a new breakpoint at path:line.""" +# breakpoint = { +# 'id': 'BP_%d' % next(self._id_counter), +# 'createTime': python_test_util.DateTimeToTimestamp(datetime.utcnow()), +# 'location': { +# 'path': path, +# 'line': line +# } +# } +# breakpoint.update(template or {}) + +# self.SetActiveBreakpoints(self.GetActiveBreakpoints() + [breakpoint]) + +# def GetActiveBreakpoints(self): +# """Returns current list of active breakpoints.""" +# return self._active_breakpoints['breakpoints'] + +# def SetActiveBreakpoints(self, breakpoints): +# """Sets a new list of active breakpoints. + +# Args: +# breakpoints: list of breakpoints to return to the debuglet. +# """ +# self._active_breakpoints['breakpoints'] = breakpoints +# begin_count = self._list_execute.call_count +# while self._list_execute.call_count < begin_count + 2: +# time.sleep(_REQUEST_DELAY_SECS) + +# def GetNextResult(self): +# """Waits for the next breakpoint update from the debuglet. + +# Returns: +# First breakpoint update sent by the debuglet that hasn't been +# processed yet. + +# Raises: +# queue.Empty: if waiting for breakpoint update times out. +# """ +# try: +# return self._incoming_breakpoint_updates.get(True, 15) +# except queue.Empty: +# raise AssertionError('Timed out waiting for breakpoint update') + +# def TryGetNextResult(self): +# """Returns the first unprocessed breakpoint update from the debuglet. + +# Returns: +# First breakpoint update sent by the debuglet that hasn't been +# processed yet. If no updates are pending, returns None. 
+# """ +# try: +# return self._incoming_breakpoint_updates.get_nowait() +# except queue.Empty: +# return None + +# def _UpdateBreakpoint(self, **keywords): +# """Fake implementation of service.debuggees().breakpoints().update().""" + +# class FakeBreakpointUpdateCommand(object): + +# def __init__(self, q): +# self._breakpoint = keywords['body']['breakpoint'] +# self._queue = q + +# def execute(self): # pylint: disable=invalid-name +# self._queue.put(self._breakpoint) + +# return FakeBreakpointUpdateCommand(self._incoming_breakpoint_updates) + +# # We only need to attach the debugger exactly once. The IntegrationTest class +# # is created for each test case, so we need to keep this state global. + +# _hub = FakeHub() + +# def _FakeLog(self, message, extra=None): +# del extra # unused +# self._info_log.append(message) + +# def setUp(self): +# self._info_log = [] +# collector.log_info_message = self._FakeLog + +# def tearDown(self): +# IntegrationTest._hub.SetActiveBreakpoints([]) + +# while True: +# breakpoint = IntegrationTest._hub.TryGetNextResult() +# if breakpoint is None: +# break +# self.fail('Unexpected incoming breakpoint update: %s' % breakpoint) + +# def testBackCompat(self): +# # Verify that the old AttachDebugger() is the same as enable() +# self.assertEqual(cdbg.enable, cdbg.AttachDebugger) + +# def testBasic(self): + +# def Trigger(): +# print('Breakpoint trigger') # BPTAG: BASIC + +# IntegrationTest._hub.SetBreakpoint('BASIC') +# Trigger() +# result = IntegrationTest._hub.GetNextResult() +# self.assertEqual('Trigger', result['stackFrames'][0]['function']) +# self.assertEqual('IntegrationTest.testBasic', +# result['stackFrames'][1]['function']) + +# # Verify that any pre existing labels present in the breakpoint are preserved +# # by the agent. 
+# def testExistingLabelsSurvive(self): + +# def Trigger(): +# print('Breakpoint trigger with labels') # BPTAG: EXISTING_LABELS_SURVIVE + +# IntegrationTest._hub.SetBreakpoint( +# 'EXISTING_LABELS_SURVIVE', +# {'labels': { +# 'label_1': 'value_1', +# 'label_2': 'value_2' +# }}) +# Trigger() +# result = IntegrationTest._hub.GetNextResult() +# self.assertIn('labels', result.keys()) +# self.assertIn('label_1', result['labels']) +# self.assertIn('label_2', result['labels']) +# self.assertEqual('value_1', result['labels']['label_1']) +# self.assertEqual('value_2', result['labels']['label_2']) + +# # Verify that any pre existing labels present in the breakpoint have priority +# # if they 'collide' with labels in the agent. +# def testExistingLabelsPriority(self): + +# def Trigger(): +# print('Breakpoint trigger with labels') # BPTAG: EXISTING_LABELS_PRIORITY + +# current_labels_collector = collector.breakpoint_labels_collector +# collector.breakpoint_labels_collector = \ +# lambda: {'label_1': 'value_1', 'label_2': 'value_2'} + +# IntegrationTest._hub.SetBreakpoint( +# 'EXISTING_LABELS_PRIORITY', +# {'labels': { +# 'label_1': 'value_foobar', +# 'label_3': 'value_3' +# }}) + +# Trigger() + +# collector.breakpoint_labels_collector = current_labels_collector + +# # In this case, label_1 was in both the agent and the pre existing labels, +# # the pre existing value of value_foobar should be preserved. 
+# result = IntegrationTest._hub.GetNextResult() +# self.assertIn('labels', result.keys()) +# self.assertIn('label_1', result['labels']) +# self.assertIn('label_2', result['labels']) +# self.assertIn('label_3', result['labels']) +# self.assertEqual('value_foobar', result['labels']['label_1']) +# self.assertEqual('value_2', result['labels']['label_2']) +# self.assertEqual('value_3', result['labels']['label_3']) + +# def testRequestLogIdLabel(self): + +# def Trigger(): +# print('Breakpoint trigger req id label') # BPTAG: REQUEST_LOG_ID_LABEL + +# current_request_log_id_collector = \ +# collector.request_log_id_collector +# collector.request_log_id_collector = lambda: 'foo_bar_id' + +# IntegrationTest._hub.SetBreakpoint('REQUEST_LOG_ID_LABEL') + +# Trigger() + +# collector.request_log_id_collector = \ +# current_request_log_id_collector + +# result = IntegrationTest._hub.GetNextResult() +# self.assertIn('labels', result.keys()) +# self.assertIn(labels.Breakpoint.REQUEST_LOG_ID, result['labels']) +# self.assertEqual('foo_bar_id', +# result['labels'][labels.Breakpoint.REQUEST_LOG_ID]) + +# # Tests the issue in b/30876465 +# def testSameLine(self): + +# def Trigger(): +# print('Breakpoint trigger same line') # BPTAG: SAME_LINE + +# num_breakpoints = 5 +# _, line = python_test_util.ResolveTag(sys.modules[__name__], 'SAME_LINE') +# for _ in range(0, num_breakpoints): +# IntegrationTest._hub.SetBreakpoint('SAME_LINE') +# Trigger() +# results = [] +# for _ in range(0, num_breakpoints): +# results.append(IntegrationTest._hub.GetNextResult()) +# lines = [result['stackFrames'][0]['location']['line'] for result in results] +# self.assertListEqual(lines, [line] * num_breakpoints) + +# def testCallStack(self): + +# def Method1(): +# Method2() + +# def Method2(): +# Method3() + +# def Method3(): +# Method4() + +# def Method4(): +# Method5() + +# def Method5(): +# return 0 # BPTAG: CALL_STACK + +# IntegrationTest._hub.SetBreakpoint('CALL_STACK') +# Method1() +# result = 
IntegrationTest._hub.GetNextResult() +# self.assertEqual([ +# 'Method5', 'Method4', 'Method3', 'Method2', 'Method1', +# 'IntegrationTest.testCallStack' +# ], [frame['function'] for frame in result['stackFrames']][:6]) + +# def testInnerMethod(self): + +# def Inner1(): + +# def Inner2(): + +# def Inner3(): +# print('Inner3') # BPTAG: INNER3 + +# Inner3() + +# Inner2() + +# IntegrationTest._hub.SetBreakpoint('INNER3') +# Inner1() +# result = IntegrationTest._hub.GetNextResult() +# self.assertEqual('Inner3', result['stackFrames'][0]['function']) + +# def testClassMethodWithDecorator(self): + +# def MyDecorator(handler): + +# def Caller(self): +# return handler(self) + +# return Caller + +# class BaseClass(object): +# pass + +# class MyClass(BaseClass): + +# @MyDecorator +# def Get(self): +# param = {} # BPTAG: METHOD_WITH_DECORATOR +# return str(param) + +# IntegrationTest._hub.SetBreakpoint('METHOD_WITH_DECORATOR') +# self.assertEqual('{}', MyClass().Get()) +# result = IntegrationTest._hub.GetNextResult() +# self.assertEqual('MyClass.Get', result['stackFrames'][0]['function']) +# self.assertEqual('MyClass.Caller', result['stackFrames'][1]['function']) +# self.assertEqual( +# { +# 'name': +# 'self', +# 'type': +# __name__ + '.MyClass', +# 'members': [{ +# 'status': { +# 'refersTo': 'VARIABLE_NAME', +# 'description': { +# 'format': 'Object has no fields' +# } +# } +# }] +# }, +# python_test_util.PackFrameVariable( +# result, 'self', collection='arguments')) + +# def testGlobalDecorator(self): +# IntegrationTest._hub.SetBreakpoint('WRAPPED_GLOBAL_METHOD') +# self.assertEqual('hello', WrappedGlobalMethod()) +# result = IntegrationTest._hub.GetNextResult() + +# self.assertNotIn('status', result) + +# def testNoLambdaExpression(self): + +# def Trigger(): +# cube = lambda x: x**3 # BPTAG: LAMBDA +# cube(18) + +# num_breakpoints = 5 +# for _ in range(0, num_breakpoints): +# IntegrationTest._hub.SetBreakpoint('LAMBDA') +# Trigger() +# results = [] +# for _ in range(0, 
num_breakpoints): +# results.append(IntegrationTest._hub.GetNextResult()) +# functions = [result['stackFrames'][0]['function'] for result in results] +# self.assertListEqual(functions, ['Trigger'] * num_breakpoints) + +# def testNoGeneratorExpression(self): + +# def Trigger(): +# gen = (i for i in range(0, 5)) # BPTAG: GENEXPR +# next(gen) +# next(gen) +# next(gen) +# next(gen) +# next(gen) + +# num_breakpoints = 1 +# for _ in range(0, num_breakpoints): +# IntegrationTest._hub.SetBreakpoint('GENEXPR') +# Trigger() +# results = [] +# for _ in range(0, num_breakpoints): +# results.append(IntegrationTest._hub.GetNextResult()) +# functions = [result['stackFrames'][0]['function'] for result in results] +# self.assertListEqual(functions, ['Trigger'] * num_breakpoints) + +# def testTryBlock(self): + +# def Method(a): +# try: +# return a * a # BPTAG: TRY_BLOCK +# except Exception as unused_e: # pylint: disable=broad-except +# return a + +# IntegrationTest._hub.SetBreakpoint('TRY_BLOCK') +# Method(11) +# result = IntegrationTest._hub.GetNextResult() +# self.assertEqual('Method', result['stackFrames'][0]['function']) +# self.assertEqual([{ +# 'name': 'a', +# 'value': '11', +# 'type': 'int' +# }], result['stackFrames'][0]['arguments']) + +# def testFrameArguments(self): + +# def Method(a, b): +# return a + str(b) # BPTAG: FRAME_ARGUMENTS + +# IntegrationTest._hub.SetBreakpoint('FRAME_ARGUMENTS') +# Method('hello', 87) +# result = IntegrationTest._hub.GetNextResult() +# self.assertEqual([{ +# 'name': 'a', +# 'value': "'hello'", +# 'type': 'str' +# }, { +# 'name': 'b', +# 'value': '87', +# 'type': 'int' +# }], result['stackFrames'][0]['arguments']) +# self.assertEqual('self', result['stackFrames'][1]['arguments'][0]['name']) + +# def testFrameLocals(self): + +# class Number(object): + +# def __init__(self): +# self.n = 57 + +# def Method(a): +# b = a**2 +# c = str(a) * 3 +# return c + str(b) # BPTAG: FRAME_LOCALS + +# IntegrationTest._hub.SetBreakpoint('FRAME_LOCALS') +# x = 
{'a': 1, 'b': Number()} +# Method(8) +# result = IntegrationTest._hub.GetNextResult() +# self.assertEqual({ +# 'name': 'b', +# 'value': '64', +# 'type': 'int' +# }, python_test_util.PackFrameVariable(result, 'b')) +# self.assertEqual({ +# 'name': 'c', +# 'value': "'888'", +# 'type': 'str' +# }, python_test_util.PackFrameVariable(result, 'c')) +# self.assertEqual( +# { +# 'name': +# 'x', +# 'type': +# 'dict', +# 'members': [{ +# 'name': "'a'", +# 'value': '1', +# 'type': 'int' +# }, { +# 'name': "'b'", +# 'type': __name__ + '.Number', +# 'members': [{ +# 'name': 'n', +# 'value': '57', +# 'type': 'int' +# }] +# }] +# }, python_test_util.PackFrameVariable(result, 'x', frame=1)) +# return x + + +# # FIXME: Broken in Python 3.10 +# # def testRecursion(self): +# # +# # def RecursiveMethod(i): +# # if i == 0: +# # return 0 # BPTAG: RECURSION +# # return RecursiveMethod(i - 1) +# # +# # IntegrationTest._hub.SetBreakpoint('RECURSION') +# # RecursiveMethod(5) +# # result = IntegrationTest._hub.GetNextResult() +# # +# # for frame in range(5): +# # self.assertEqual({ +# # 'name': 'i', +# # 'value': str(frame), +# # 'type': 'int' +# # }, python_test_util.PackFrameVariable(result, 'i', frame, 'arguments')) + +# def testWatchedExpressions(self): + +# def Trigger(): + +# class MyClass(object): + +# def __init__(self): +# self.a = 1 +# self.b = 'bbb' + +# unused_my = MyClass() +# print('Breakpoint trigger') # BPTAG: WATCHED_EXPRESSION + +# IntegrationTest._hub.SetBreakpoint('WATCHED_EXPRESSION', +# {'expressions': ['unused_my']}) +# Trigger() +# result = IntegrationTest._hub.GetNextResult() + +# self.assertEqual( +# { +# 'name': +# 'unused_my', +# 'type': +# __name__ + '.MyClass', +# 'members': [{ +# 'name': 'a', +# 'value': '1', +# 'type': 'int' +# }, { +# 'name': 'b', +# 'value': "'bbb'", +# 'type': 'str' +# }] +# }, python_test_util.PackWatchedExpression(result, 0)) + +# def testBreakpointExpiration(self): # BPTAG: BREAKPOINT_EXPIRATION +# created_time = datetime.utcnow() - 
def MyGlobalDecorator(fn):
  """Wraps `fn` in a pass-through function that keeps `fn`'s metadata.

  Args:
    fn: callable to wrap.

  Returns:
    A function that forwards all positional and keyword arguments to `fn`
    and returns its result. The wrapper carries `fn`'s name and docstring.
  """

  def Wrapper(*args, **kwargs):
    # Pure forwarding: the wrapper only adds one extra stack frame.
    return fn(*args, **kwargs)

  # Same effect as decorating Wrapper with @functools.wraps(fn).
  return functools.wraps(fn)(Wrapper)
class LabelsTest(absltest.TestCase):
  """Checks the label constants exposed by googleclouddebugger.labels."""

  def testDefinesLabelsCorrectly(self):
    """Every label constant must equal its expected string value."""
    expected_values = (
        (labels.Breakpoint.REQUEST_LOG_ID, 'requestlogid'),
        (labels.Debuggee.DOMAIN, 'domain'),
        (labels.Debuggee.PROJECT_ID, 'projectid'),
        (labels.Debuggee.MODULE, 'module'),
        (labels.Debuggee.VERSION, 'version'),
        (labels.Debuggee.MINOR_VERSION, 'minorversion'),
        (labels.Debuggee.PLATFORM, 'platform'),
        (labels.Debuggee.REGION, 'region'),
    )
    for actual, expected in expected_values:
      self.assertEqual(actual, expected)

  def testProvidesAllLabelsSet(self):
    """SET_ALL must exist on both label groups and have the expected size."""
    expected_sizes = (
        (labels.Breakpoint.SET_ALL, 1),
        (labels.Debuggee.SET_ALL, 7),
    )
    for label_set, size in expected_sizes:
      self.assertIsNotNone(label_set)
      self.assertLen(label_set, size)
  def testInnerMethodOfStaticMethod(self):
    """Verify that inner method defined in a static class method is found."""
    # _StaticMethod() returns the code object of the function nested inside
    # it, so this checks the nested function rather than the static method.
    self.assertIn(ModuleExplorerTest._StaticMethod(), self._code_objects)
method is found.""" + + def Inner1(): + + def Inner2(): + + def Inner3(): + + def Inner4(): + + def Inner5(): + pass + + return Inner5.__code__ + + return Inner4() + + return Inner3() + + return Inner2() + + self.assertIn(Inner1(), self._code_objects) + + def testNoLambdaExpression(self): + """Verify that code of lambda expression is not included.""" + + self.assertNotIn(_MethodWithLambdaExpression(), self._code_objects) + + def testNoGeneratorExpression(self): + """Verify that code of generator expression is not included.""" + + self.assertNotIn(_MethodWithGeneratorExpression(), self._code_objects) + + def testMethodOfInnerClass(self): + """Verify that method of inner class is found.""" + + class InnerClass(object): + + def InnerClassMethod(self): + pass + + self.assertIn(InnerClass().InnerClassMethod.__code__, self._code_objects) + + def testMethodOfInnerOldStyleClass(self): + """Verify that method of inner old style class is found.""" + + class InnerClass(): + + def InnerClassMethod(self): + pass + + self.assertIn(InnerClass().InnerClassMethod.__code__, self._code_objects) + + def testGlobalMethodWithClosureDecorator(self): + co = self._GetCodeObjectAtLine(self._module, + 'GLOBAL_METHOD_WITH_CLOSURE_DECORATOR') + self.assertTrue(co) + self.assertEqual('GlobalMethodWithClosureDecorator', co.co_name) + + def testClassMethodWithClosureDecorator(self): + co = self._GetCodeObjectAtLine( + self._module, 'GLOBAL_CLASS_METHOD_WITH_CLOSURE_DECORATOR') + self.assertTrue(co) + self.assertEqual('FnWithClosureDecorator', co.co_name) + + def testGlobalMethodWithClassDecorator(self): + co = self._GetCodeObjectAtLine(self._module, + 'GLOBAL_METHOD_WITH_CLASS_DECORATOR') + self.assertTrue(co) + self.assertEqual('GlobalMethodWithClassDecorator', co.co_name) + + def testClassMethodWithClassDecorator(self): + co = self._GetCodeObjectAtLine(self._module, + 'GLOBAL_CLASS_METHOD_WITH_CLASS_DECORATOR') + self.assertTrue(co) + self.assertEqual('FnWithClassDecorator', co.co_name) + + def 
  def testCodeObjectAtLine(self):
    """Verify that the code object owning a tagged source line is returned."""
    # Each entry pairs the code object expected at a BPTAG-marked line with
    # the tag that marks that line.
    test_cases = [
        (self.testCodeObjectAtLine.__code__, 'TEST_CODE_OBJECT_AT_ASSERT'),
        (ModuleExplorerTest._StaticMethod(), 'INNER_OF_STATIC_METHOD'),
        (_GlobalMethod(), 'INNER_OF_GLOBAL_METHOD')
    ]

    for code_object, tag in test_cases:
      # The BPTAG comment below must stay on a line inside this method: the
      # first test case expects that line to resolve to this method's code.
      self.assertEqual(  # BPTAG: TEST_CODE_OBJECT_AT_ASSERT
          code_object, self._GetCodeObjectAtLine(code_object, tag))
+# module_path = os.path.join(test_dir, 'module.py') +# with open(module_path, 'w') as f: +# f.write('def f():\n pass') +# py_compile.compile(module_path) +# module_pyc_path = os.path.join(test_dir, '__pycache__', +# 'module.cpython-37.pyc') +# os.rename(module_pyc_path, module_path + 'c') +# os.remove(module_path) +# +# import module # pylint: disable=g-import-not-at-top +# self.assertEqual('.py', +# os.path.splitext(module.f.__code__.co_filename)[1]) +# self.assertEqual('.pyc', os.path.splitext(module.__file__)[1]) +# +# func_code = module.f.__code__ +# self.assertEqual(func_code, +# module_explorer.GetCodeObjectAtLine( +# module, +# next(dis.findlinestarts(func_code))[1])[1]) +# finally: +# sys.path.remove(test_dir) +# shutil.rmtree(test_dir) + + def testMaxVisitObjects(self): + default_quota = module_explorer._MAX_VISIT_OBJECTS + try: + module_explorer._MAX_VISIT_OBJECTS = 10 + self.assertLess( + len(module_explorer._GetModuleCodeObjects(self._module)), + len(self._code_objects)) + finally: + module_explorer._MAX_VISIT_OBJECTS = default_quota + + def testMaxReferentsBfsDepth(self): + default_quota = module_explorer._MAX_REFERENTS_BFS_DEPTH + try: + module_explorer._MAX_REFERENTS_BFS_DEPTH = 1 + self.assertLess( + len(module_explorer._GetModuleCodeObjects(self._module)), + len(self._code_objects)) + finally: + module_explorer._MAX_REFERENTS_BFS_DEPTH = default_quota + + def testMaxObjectReferents(self): + + class A(object): + pass + + default_quota = module_explorer._MAX_VISIT_OBJECTS + default_referents_quota = module_explorer._MAX_OBJECT_REFERENTS + try: + global large_dict + large_dict = {A(): 0 for i in range(0, 5000)} + + # First test with a referents limit too large, it will visit large_dict + # and exhaust the _MAX_VISIT_OBJECTS quota before finding all the code + # objects + module_explorer._MAX_VISIT_OBJECTS = 5000 + module_explorer._MAX_OBJECT_REFERENTS = sys.maxsize + self.assertLess( + len(module_explorer._GetModuleCodeObjects(self._module)), + 
len(self._code_objects)) + + # Now test with a referents limit that prevents large_dict from being + # explored, all the code objects should be found now that the large dict + # is skipped and isn't taking up the _MAX_VISIT_OBJECTS quota + module_explorer._MAX_OBJECT_REFERENTS = default_referents_quota + self.assertItemsEqual( + module_explorer._GetModuleCodeObjects(self._module), + self._code_objects) + finally: + module_explorer._MAX_VISIT_OBJECTS = default_quota + module_explorer._MAX_OBJECT_REFERENTS = default_referents_quota + large_dict = None + + @staticmethod + def _StaticMethod(): + + def InnerMethod(): + pass # BPTAG: INNER_OF_STATIC_METHOD + + return InnerMethod.__code__ + + def _GetCodeObjectAtLine(self, fn, tag): + """Wrapper over GetCodeObjectAtLine for tags in this module.""" + unused_path, line = python_test_util.ResolveTag(fn, tag) + return module_explorer.GetCodeObjectAtLine(self._module, line)[1] + + +def _GlobalMethod(): + + def InnerMethod(): + pass # BPTAG: INNER_OF_GLOBAL_METHOD + + return InnerMethod.__code__ + + +def ClosureDecorator(handler): + + def Caller(*args): + return handler(*args) + + return Caller + + +class ClassDecorator(object): + + def __init__(self, fn): + self._fn = fn + + def __call__(self, *args): + return self._fn(*args) + + +@ClosureDecorator +def GlobalMethodWithClosureDecorator(): + return True # BPTAG: GLOBAL_METHOD_WITH_CLOSURE_DECORATOR + + +@ClassDecorator +def GlobalMethodWithClassDecorator(): + return True # BPTAG: GLOBAL_METHOD_WITH_CLASS_DECORATOR + + +class GlobalClass(object): + + @ClosureDecorator + def FnWithClosureDecorator(self): + return True # BPTAG: GLOBAL_CLASS_METHOD_WITH_CLOSURE_DECORATOR + + @ClassDecorator + def FnWithClassDecorator(self): + return True # BPTAG: GLOBAL_CLASS_METHOD_WITH_CLASS_DECORATOR + + +def _MethodWithLambdaExpression(): + return (lambda x: x**3).__code__ + + +def _MethodWithGeneratorExpression(): + return (i for i in range(0, 2)).gi_code + + +# Used for 
# TODO: Add tests for whitespace in location path including in,
# extension, basename, path
class SearchModulesTest(absltest.TestCase):
  """Tests module_search.Search against a scratch directory on sys.path."""

  def setUp(self):
    # Scratch directory in which individual tests create packages and files.
    self._test_package_dir = tempfile.mkdtemp('', 'package_')
    sys.path.append(self._test_package_dir)

  def tearDown(self):
    import shutil  # pylint: disable=g-import-not-at-top

    sys.path.remove(self._test_package_dir)
    # mkdtemp leaves deletion to the caller; remove the tree so test runs do
    # not accumulate 'package_*' directories. Best effort by design.
    shutil.rmtree(self._test_package_dir, ignore_errors=True)

  def testSearchValidSourcePath(self):
    # These modules are on the sys.path.
    self.assertEndsWith(
        module_search.Search('googleclouddebugger/module_search.py'),
        '/site-packages/googleclouddebugger/module_search.py')

    # inspect and dis are libraries with no real file. So, we
    # can no longer match them by file path.

  def testSearchInvalidSourcePath(self):
    # This is an invalid module that doesn't exist anywhere.
    self.assertEqual(module_search.Search('aaaaa.py'), 'aaaaa.py')

    # This module exists, but the search input is missing the outer package
    # name.
    self.assertEqual(module_search.Search('absltest.py'), 'absltest.py')

  def testSearchInvalidExtension(self):
    # Test that the module rejects invalid extension in the input.
    with self.assertRaises(AssertionError):
      module_search.Search('module_search.x')

  def testSearchPathStartsWithSep(self):
    # Test that module rejects invalid leading os.sep char in the input.
    with self.assertRaises(AssertionError):
      module_search.Search('/module_search')

  def testSearchRelativeSysPath(self):
    # An entry in sys.path is in relative form, and represents the same
    # directory as another absolute entry in sys.path.
    for directory in ['', 'a', 'a/b']:
      self._CreateFile(os.path.join(directory, '__init__.py'))
    self._CreateFile('a/b/first.py')

    # Inject a relative path into sys.path that refers to a directory already
    # in sys.path. It should produce the same result as the non-relative form.
    # Computed before the try block so the cleanup below can always name it.
    testdir_alias = os.path.join(self._test_package_dir, 'a/../a')

    # Add 'a/../a' to sys.path so that 'b/first.py' is reachable.
    sys.path.insert(0, testdir_alias)
    try:
      # Returned result should have a successful file match and relative
      # paths should be kept as-is.
      result = module_search.Search('b/first.py')
      self.assertEndsWith(result, 'a/../a/b/first.py')
    finally:
      sys.path.remove(testdir_alias)

  def testSearchSymLinkInSysPath(self):
    # An entry in sys.path is a symlink.
    for directory in ['', 'a', 'a/b']:
      self._CreateFile(os.path.join(directory, '__init__.py'), '')
    self._CreateFile('a/b/first.py')
    self._CreateSymLink('a', 'link')

    # Add 'link/' to sys.path so that 'b/first.py' is reachable.
    link_dir = os.path.join(self._test_package_dir, 'link')
    sys.path.append(link_dir)
    try:
      # Returned result should have a successful file match and symbolic
      # links should be kept.
      self.assertEndsWith(module_search.Search('b/first.py'), 'link/b/first.py')
    finally:
      sys.path.remove(link_dir)

  def _CreateFile(self, path, contents='assert False "Unexpected import"\n'):
    """Writes `contents` to `path` under the scratch directory.

    Missing parent directories are created.

    NOTE(review): the default contents are not valid Python (there is no
    comma after `False`), so importing such a file raises SyntaxError rather
    than AssertionError. Either way an unexpected import fails; confirm
    whether the missing comma is intentional before changing the literal.

    Args:
      path: file path relative to the scratch directory.
      contents: text to write into the file.

    Returns:
      The `path` argument, unchanged.
    """
    full_path = os.path.join(self._test_package_dir, path)
    os.makedirs(os.path.dirname(full_path), exist_ok=True)

    with open(full_path, 'w') as writer:
      writer.write(contents)

    return path

  def _CreateSymLink(self, source, link_name):
    """Creates symlink `link_name` pointing at `source` in the scratch dir."""
    full_source_path = os.path.join(self._test_package_dir, source)
    full_link_path = os.path.join(self._test_package_dir, link_name)
    os.symlink(full_source_path, full_link_path)

  # Since we cannot use os.path.samefile or os.path.realpath to eliminate
  # symlinks reliably, we only check suffix equivalence of file paths in these
  # unit tests.
  def _AssertEndsWith(self, match, path):
    """Asserts exactly one match ending with path."""
    self.assertLen(match, 1)
    self.assertEndsWith(match[0], path)

  def _AssertEqFile(self, match, path):
    """Asserts exactly one match equals to the file created with _CreateFile."""
    self.assertLen(match, 1)
    self.assertEqual(match[0], os.path.join(self._test_package_dir, path))
setUp(self): + self._test_package_dir = tempfile.mkdtemp('', 'package_') + self.modules = sys.modules.copy() + + def tearDown(self): + sys.modules = self.modules + self.modules = None + + def _CreateFile(self, path): + full_path = os.path.join(self._test_package_dir, path) + directory, unused_name = os.path.split(full_path) + + if not os.path.isdir(directory): + os.makedirs(directory) + + with open(full_path, 'w') as writer: + writer.write('') + + return full_path + + def _CreateSymLink(self, source, link_name): + full_source_path = os.path.join(self._test_package_dir, source) + full_link_path = os.path.join(self._test_package_dir, link_name) + os.symlink(full_source_path, full_link_path) + return full_link_path + + def _AssertEndsWith(self, a, b, msg=None): + """Assert that string a ends with string b.""" + if not a.endswith(b): + standard_msg = '%s does not end with %s' % (a, b) + self.fail(self._formatMessage(msg, standard_msg)) + + def testSimpleLoadedModuleFromSuffix(self): + # Lookup simple module. + _AddSysModule('m1', '/a/b/p1/m1.pyc') + for suffix in [ + 'm1.py', 'm1.pyc', 'm1.pyo', 'p1/m1.py', 'b/p1/m1.py', 'a/b/p1/m1.py', + '/a/b/p1/m1.py' + ]: + m1 = module_utils.GetLoadedModuleBySuffix(suffix) + self.assertTrue(m1, 'Module not found') + self.assertEqual('/a/b/p1/m1.pyc', m1.__file__) + + # Lookup simple package, no ext. + _AddSysModule('p1', '/a/b/p1/__init__.pyc') + for suffix in [ + 'p1/__init__.py', 'b/p1/__init__.py', 'a/b/p1/__init__.py', + '/a/b/p1/__init__.py' + ]: + p1 = module_utils.GetLoadedModuleBySuffix(suffix) + self.assertTrue(p1, 'Package not found') + self.assertEqual('/a/b/p1/__init__.pyc', p1.__file__) + + # Lookup via bad suffix. + for suffix in [ + 'm2.py', 'p2/m1.py', 'b2/p1/m1.py', 'a2/b/p1/m1.py', '/a2/b/p1/m1.py' + ]: + m1 = module_utils.GetLoadedModuleBySuffix(suffix) + self.assertFalse(m1, 'Module found unexpectedly') + + def testComplexLoadedModuleFromSuffix(self): + # Lookup complex module. 
+ _AddSysModule('b.p1.m1', '/a/b/p1/m1.pyc') + for suffix in [ + 'm1.py', 'p1/m1.py', 'b/p1/m1.py', 'a/b/p1/m1.py', '/a/b/p1/m1.py' + ]: + m1 = module_utils.GetLoadedModuleBySuffix(suffix) + self.assertTrue(m1, 'Module not found') + self.assertEqual('/a/b/p1/m1.pyc', m1.__file__) + + # Lookup complex package, no ext. + _AddSysModule('a.b.p1', '/a/b/p1/__init__.pyc') + for suffix in [ + 'p1/__init__.py', 'b/p1/__init__.py', 'a/b/p1/__init__.py', + '/a/b/p1/__init__.py' + ]: + p1 = module_utils.GetLoadedModuleBySuffix(suffix) + self.assertTrue(p1, 'Package not found') + self.assertEqual('/a/b/p1/__init__.pyc', p1.__file__) + + def testSimilarLoadedModuleFromSuffix(self): + # Lookup similar module, no ext. + _AddSysModule('m1', '/a/b/p2/m1.pyc') + _AddSysModule('p1.m1', '/a/b1/p1/m1.pyc') + _AddSysModule('b.p1.m1', '/a1/b/p1/m1.pyc') + _AddSysModule('a.b.p1.m1', '/a/b/p1/m1.pyc') + + m1 = module_utils.GetLoadedModuleBySuffix('/a/b/p1/m1.py') + self.assertTrue(m1, 'Module not found') + self.assertEqual('/a/b/p1/m1.pyc', m1.__file__) + + # Lookup similar package, no ext. + _AddSysModule('p1', '/a1/b1/p1/__init__.pyc') + _AddSysModule('b.p1', '/a1/b/p1/__init__.pyc') + _AddSysModule('a.b.p1', '/a/b/p1/__init__.pyc') + p1 = module_utils.GetLoadedModuleBySuffix('/a/b/p1/__init__.py') + self.assertTrue(p1, 'Package not found') + self.assertEqual('/a/b/p1/__init__.pyc', p1.__file__) + + def testDuplicateLoadedModuleFromSuffix(self): + # Lookup name dup module and package. + _AddSysModule('m1', '/m1/__init__.pyc') + _AddSysModule('m1.m1', '/m1/m1.pyc') + _AddSysModule('m1.m1.m1', '/m1/m1/m1/__init__.pyc') + _AddSysModule('m1.m1.m1.m1', '/m1/m1/m1/m1.pyc') + + # Ambiguous request, multiple modules might have matched. + m1 = module_utils.GetLoadedModuleBySuffix('/m1/__init__.py') + self.assertTrue(m1, 'Package not found') + self.assertIn(m1.__file__, ['/m1/__init__.pyc', '/m1/m1/m1/__init__.pyc']) + + # Ambiguous request, multiple modules might have matched. 
def _DoHardWork(base):
  """Runs a loop of `base` iterations and returns False.

  The quota tests in this file use calls such as '_DoHardWork(1000)' as
  breakpoint condition expressions, so the amount of work done per call is
  what matters, not the result.

  Args:
    base: number of loop iterations to run.

  Returns:
    False. base * i is never negative for i in range(base), so the early
    return is not taken.
  """
  for i in range(base):
    if base * i < 0:
      return True
  return False
+ self._breakpoint_counter = 0 + + # Registers breakpoint events other than breakpoint hit. + self._breakpoint_events = [] + + # Keep track of breakpoints we set to reset them on cleanup. + self._cookies = [] + + def tearDown(self): + # Verify that we didn't get any breakpoint events that the test did + # not expect. + self.assertEqual([], self._PopBreakpointEvents()) + + self._ClearAllBreakpoints() + + def testUnconditionalBreakpoint(self): + + def Trigger(): + unused_lock = threading.Lock() + print('Breakpoint trigger') # BPTAG: UNCONDITIONAL_BREAKPOINT + + self._SetBreakpoint(Trigger, 'UNCONDITIONAL_BREAKPOINT') + Trigger() + self.assertEqual(1, self._breakpoint_counter) + + def testConditionalBreakpoint(self): + + def Trigger(): + d = {} + for i in range(1, 10): + d[i] = i**2 # BPTAG: CONDITIONAL_BREAKPOINT + + self._SetBreakpoint(Trigger, 'CONDITIONAL_BREAKPOINT', 'i % 3 == 1') + Trigger() + self.assertEqual(3, self._breakpoint_counter) + + def testClearBreakpoint(self): + """Set two breakpoint on the same line, then clear one.""" + + def Trigger(): + print('Breakpoint trigger') # BPTAG: CLEAR_BREAKPOINT + + self._SetBreakpoint(Trigger, 'CLEAR_BREAKPOINT') + self._SetBreakpoint(Trigger, 'CLEAR_BREAKPOINT') + native.ClearConditionalBreakpoint(self._cookies.pop()) + Trigger() + self.assertEqual(1, self._breakpoint_counter) + + def testMissingModule(self): + + def Test(): + native.CreateConditionalBreakpoint(None, 123123, None, + self._BreakpointEvent) + + self.assertRaises(TypeError, Test) + + def testBadModule(self): + + def Test(): + native.CreateConditionalBreakpoint('str', 123123, None, + self._BreakpointEvent) + + self.assertRaises(TypeError, Test) + + def testInvalidCondition(self): + + def Test(): + native.CreateConditionalBreakpoint(sys.modules[__name__], 123123, '2+2', + self._BreakpointEvent) + + self.assertRaises(TypeError, Test) + + def testMissingCallback(self): + + def Test(): + native.CreateConditionalBreakpoint('code.py', 123123, None, None) + + 
self.assertRaises(TypeError, Test) + + def testInvalidCallback(self): + + def Test(): + native.CreateConditionalBreakpoint('code.py', 123123, None, {}) + + self.assertRaises(TypeError, Test) + + def testMissingCookie(self): + self.assertRaises(TypeError, + lambda: native.ClearConditionalBreakpoint(None)) + + def testInvalidCookie(self): + native.ClearConditionalBreakpoint(387873457) + + def testMutableCondition(self): + + def Trigger(): + + def MutableMethod(): + self._evil = True + return True + + print('MutableMethod = %s' % MutableMethod) # BPTAG: MUTABLE_CONDITION + + self._SetBreakpoint(Trigger, 'MUTABLE_CONDITION', 'MutableMethod()') + Trigger() + self.assertEqual([native.BREAKPOINT_EVENT_CONDITION_EXPRESSION_MUTABLE], + self._PopBreakpointEvents()) + + def testGlobalConditionQuotaExceeded(self): + + def Trigger(): + print('Breakpoint trigger') # BPTAG: GLOBAL_CONDITION_QUOTA + + self._SetBreakpoint(Trigger, 'GLOBAL_CONDITION_QUOTA', '_DoHardWork(1000)') + Trigger() + self._ClearAllBreakpoints() + + self.assertListEqual( + [native.BREAKPOINT_EVENT_GLOBAL_CONDITION_QUOTA_EXCEEDED], + self._PopBreakpointEvents()) + + # Sleep for some time to let the quota recover. + time.sleep(0.1) + + def testBreakpointConditionQuotaExceeded(self): + + def Trigger(): + print('Breakpoint trigger') # BPTAG: PER_BREAKPOINT_CONDITION_QUOTA + + time.sleep(1) + + # Per-breakpoint quota is lower than the global one. Exponentially + # increase the complexity of a condition until we hit it. + base = 100 + while True: + self._SetBreakpoint(Trigger, 'PER_BREAKPOINT_CONDITION_QUOTA', + '_DoHardWork(%d)' % base) + Trigger() + self._ClearAllBreakpoints() + + events = self._PopBreakpointEvents() + if events: + self.assertEqual( + [native.BREAKPOINT_EVENT_BREAKPOINT_CONDITION_QUOTA_EXCEEDED], + events) + break + + base *= 1.2 + time.sleep(0.1) + + # Sleep for some time to let the quota recover. 
+ time.sleep(0.1) + + def testImmutableCallSuccess(self): + + def Add(a, b, c): + return a + b + c + + def Magic(): + return 'cake' + + self.assertEqual('643535', + self._CallImmutable(inspect.currentframe(), 'str(643535)')) + self.assertEqual( + 786 + 23 + 891, + self._CallImmutable(inspect.currentframe(), 'Add(786, 23, 891)')) + self.assertEqual('cake', + self._CallImmutable(inspect.currentframe(), 'Magic()')) + return Add or Magic + + def testImmutableCallMutable(self): + + def Change(): + dictionary['bad'] = True + + dictionary = {} + frame = inspect.currentframe() + self.assertRaises(SystemError, + lambda: self._CallImmutable(frame, 'Change()')) + self.assertEqual({}, dictionary) + return Change + + def testImmutableCallExceptionPropagation(self): + + def Divide(a, b): + return a / b + + frame = inspect.currentframe() + self.assertRaises(ZeroDivisionError, + lambda: self._CallImmutable(frame, 'Divide(1, 0)')) + return Divide + + def testImmutableCallInvalidFrame(self): + self.assertRaises(TypeError, lambda: native.CallImmutable(None, lambda: 1)) + self.assertRaises(TypeError, + lambda: native.CallImmutable('not a frame', lambda: 1)) + + def testImmutableCallInvalidCallable(self): + frame = inspect.currentframe() + self.assertRaises(TypeError, lambda: native.CallImmutable(frame, None)) + self.assertRaises(TypeError, + lambda: native.CallImmutable(frame, 'not a callable')) + + def _SetBreakpoint(self, method, tag, condition=None): + """Sets a breakpoint in this source file. + + The line number is identified by tag. This function does not verify that + the source line is in the specified method. + + The breakpoint may have an optional condition. + + Args: + method: method in which the breakpoint will be set. + tag: label for a source line. + condition: optional breakpoint condition. 
+ """ + unused_path, line = python_test_util.ResolveTag(type(self), tag) + + compiled_condition = None + if condition is not None: + compiled_condition = compile(condition, '', 'eval') + + cookie = native.CreateConditionalBreakpoint(method.__code__, line, + compiled_condition, + self._BreakpointEvent) + + self._cookies.append(cookie) + native.ActivateConditionalBreakpoint(cookie) + + def _ClearAllBreakpoints(self): + """Removes all previously set breakpoints.""" + for cookie in self._cookies: + native.ClearConditionalBreakpoint(cookie) + + def _CallImmutable(self, frame, expression): + """Wrapper over native.ImmutableCall for callable.""" + return native.CallImmutable(frame, + compile(expression, '', 'eval')) + + def _BreakpointEvent(self, event, frame): + """Callback on breakpoint event. + + See thread_breakpoints.h for more details of possible events. + + Args: + event: breakpoint event (see kIntegerConstants in native_module.cc). + frame: Python stack frame of breakpoint hit or None for other events. 
+ """ + with self._lock: + if event == native.BREAKPOINT_EVENT_HIT: + self.assertTrue(inspect.isframe(frame)) + self._breakpoint_counter += 1 + else: + self._breakpoint_events.append(event) + + def _PopBreakpointEvents(self): + """Gets and resets the list of breakpoint events received so far.""" + with self._lock: + events = self._breakpoint_events + self._breakpoint_events = [] + return events + + def _HasBreakpointEvents(self): + """Checks whether there are unprocessed breakpoint events.""" + with self._lock: + if self._breakpoint_events: + return True + return False + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/python_breakpoint_test.py b/tests/py/python_breakpoint_test.py new file mode 100644 index 0000000..6aff9c4 --- /dev/null +++ b/tests/py/python_breakpoint_test.py @@ -0,0 +1,652 @@ +"""Unit test for python_breakpoint module.""" + +from datetime import datetime +from datetime import timedelta +import inspect +import os +import sys +import tempfile + +from absl.testing import absltest + +from googleclouddebugger import cdbg_native as native +from googleclouddebugger import imphook +from googleclouddebugger import python_breakpoint +import python_test_util + + +class PythonBreakpointTest(absltest.TestCase): + """Unit test for python_breakpoint module.""" + + def setUp(self): + self._test_package_dir = tempfile.mkdtemp('', 'package_') + sys.path.append(self._test_package_dir) + + path, line = python_test_util.ResolveTag(type(self), 'CODE_LINE') + + self._base_time = datetime(year=2015, month=1, day=1) # BPTAG: CODE_LINE + self._template = { + 'id': 'BP_ID', + 'createTime': python_test_util.DateTimeToTimestamp(self._base_time), + 'location': { + 'path': path, + 'line': line + } + } + self._completed = set() + self._update_queue = [] + + def tearDown(self): + sys.path.remove(self._test_package_dir) + + def CompleteBreakpoint(self, breakpoint_id): + """Mock method of BreakpointsManager.""" + self._completed.add(breakpoint_id) + + def 
GetCurrentTime(self): + """Mock method of BreakpointsManager.""" + return self._base_time + + def EnqueueBreakpointUpdate(self, breakpoint): + """Mock method of HubClient.""" + self._update_queue.append(breakpoint) + + def testClear(self): + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint.Clear() + self.assertFalse(breakpoint._cookie) + + def testId(self): + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint.Clear() + self.assertEqual('BP_ID', breakpoint.GetBreakpointId()) + + def testNullBytesInCondition(self): + python_breakpoint.PythonBreakpoint( + dict(self._template, condition='\0'), self, self, None) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['status']['isError']) + self.assertTrue(self._update_queue[0]['isFinalState']) + + # Test only applies to the old module search algorithm. When using new module + # search algorithm, this test is same as testDeferredBreakpoint. 
+ def testUnknownModule(self): + pass + + def testDeferredBreakpoint(self): + with open(os.path.join(self._test_package_dir, 'defer_print.py'), 'w') as f: + f.write('def DoPrint():\n') + f.write(' print("Hello from deferred module")\n') + + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': 'defer_print.py', + 'line': 2 + }), self, self, None) + + self.assertFalse(self._completed) + self.assertEmpty(self._update_queue) + + import defer_print # pylint: disable=g-import-not-at-top + defer_print.DoPrint() + + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertGreater(len(self._update_queue[0]['stackFrames']), 3) + self.assertEqual('DoPrint', + self._update_queue[0]['stackFrames'][0]['function']) + self.assertTrue(self._update_queue[0]['isFinalState']) + + self.assertEmpty(imphook._import_callbacks) + + # Old module search algorithm rejects multiple matches. This test verifies + # that the new module search algorithm searches sys.path sequentially, and + # selects the first match (just like the Python importer). + def testSearchUsingSysPathOrder(self): + for i in range(2, 0, -1): + # Create directories and add them to sys.path. + test_dir = os.path.join(self._test_package_dir, ('inner2_%s' % i)) + os.mkdir(test_dir) + sys.path.append(test_dir) + with open(os.path.join(test_dir, 'mod2.py'), 'w') as f: + f.write('def DoPrint():\n') + f.write(' x = %s\n' % i) + f.write(' return x') + + # Loads inner2_2/mod2.py because it comes first in sys.path. + import mod2 # pylint: disable=g-import-not-at-top + + # Search will proceed in sys.path order, and the first match in sys.path + # will uniquely identify the full path of the module as inner2_2/mod2.py. 
+ python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': 'mod2.py', + 'line': 3 + }), self, self, None) + + self.assertEqual(2, mod2.DoPrint()) + + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertGreater(len(self._update_queue[0]['stackFrames']), 3) + self.assertEqual('DoPrint', + self._update_queue[0]['stackFrames'][0]['function']) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertEqual( + 'x', self._update_queue[0]['stackFrames'][0]['locals'][0]['name']) + self.assertEqual( + '2', self._update_queue[0]['stackFrames'][0]['locals'][0]['value']) + + self.assertEmpty(imphook._import_callbacks) + + # Old module search algorithm rejects multiple matches. This test verifies + # that when the new module search cannot find any match in sys.path, it + # defers the breakpoint, and then selects the first dynamically-loaded + # module that matches the given path. + def testMultipleDeferredMatches(self): + for i in range(2, 0, -1): + # Create packages, but do not add them to sys.path. + test_dir = os.path.join(self._test_package_dir, ('inner3_%s' % i)) + os.mkdir(test_dir) + with open(os.path.join(test_dir, '__init__.py'), 'w') as f: + pass + with open(os.path.join(test_dir, 'defer_print3.py'), 'w') as f: + f.write('def DoPrint():\n') + f.write(' x = %s\n' % i) + f.write(' return x') + + # This breakpoint will be deferred. It can match any one of the modules + # created above. + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': 'defer_print3.py', + 'line': 3 + }), self, self, None) + + # Lazy import module. Activates breakpoint on the loaded module. 
+ import inner3_1.defer_print3 # pylint: disable=g-import-not-at-top + self.assertEqual(1, inner3_1.defer_print3.DoPrint()) + + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertGreater(len(self._update_queue[0]['stackFrames']), 3) + self.assertEqual('DoPrint', + self._update_queue[0]['stackFrames'][0]['function']) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertEqual( + 'x', self._update_queue[0]['stackFrames'][0]['locals'][0]['name']) + self.assertEqual( + '1', self._update_queue[0]['stackFrames'][0]['locals'][0]['value']) + + self.assertEmpty(imphook._import_callbacks) + + def testNeverLoadedBreakpoint(self): + open(os.path.join(self._test_package_dir, 'never_print.py'), 'w').close() + + breakpoint = python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': 'never_print.py', + 'line': 99 + }), self, self, None) + breakpoint.Clear() + + self.assertFalse(self._completed) + self.assertEmpty(self._update_queue) + + def testDeferredNoCodeAtLine(self): + open(os.path.join(self._test_package_dir, 'defer_empty.py'), 'w').close() + + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': 'defer_empty.py', + 'line': 10 + }), self, self, None) + + self.assertFalse(self._completed) + self.assertEmpty(self._update_queue) + + import defer_empty # pylint: disable=g-import-not-at-top,unused-variable + + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + status = self._update_queue[0]['status'] + self.assertEqual(status['isError'], True) + self.assertEqual(status['refersTo'], 'BREAKPOINT_SOURCE_LOCATION') + desc = status['description'] + self.assertEqual(desc['format'], 'No code found at line $0 in $1') + params = desc['parameters'] + self.assertIn('defer_empty.py', params[1]) + self.assertEqual(params[0], '10') + self.assertEmpty(imphook._import_callbacks) + + 
def testDeferredBreakpointCancelled(self): + open(os.path.join(self._test_package_dir, 'defer_cancel.py'), 'w').close() + + breakpoint = python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': 'defer_cancel.py', + 'line': 11 + }), self, self, None) + breakpoint.Clear() + + self.assertFalse(self._completed) + self.assertEmpty(imphook._import_callbacks) + unused_no_code_line_above = 0 # BPTAG: NO_CODE_LINE_ABOVE + + # BPTAG: NO_CODE_LINE + def testNoCodeAtLine(self): + unused_no_code_line_below = 0 # BPTAG: NO_CODE_LINE_BELOW + path, line = python_test_util.ResolveTag(sys.modules[__name__], + 'NO_CODE_LINE') + path, line_above = python_test_util.ResolveTag(sys.modules[__name__], + 'NO_CODE_LINE_ABOVE') + path, line_below = python_test_util.ResolveTag(sys.modules[__name__], + 'NO_CODE_LINE_BELOW') + + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': path, + 'line': line + }), self, self, None) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + status = self._update_queue[0]['status'] + self.assertEqual(status['isError'], True) + self.assertEqual(status['refersTo'], 'BREAKPOINT_SOURCE_LOCATION') + desc = status['description'] + self.assertEqual(desc['format'], + 'No code found at line $0 in $1. 
Try lines $2 or $3.') + params = desc['parameters'] + self.assertEqual(params[0], str(line)) + self.assertIn(path, params[1]) + self.assertEqual(params[2], str(line_above)) + self.assertEqual(params[3], str(line_below)) + + def testBadExtension(self): + for path in ['unknown.so', 'unknown', 'unknown.java', 'unknown.pyc']: + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': path, + 'line': 83 + }), self, self, None) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertEqual( + { + 'isError': True, + 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', + 'description': { + 'format': ('Only files with .py extension are supported') + } + }, self._update_queue[0]['status']) + self._update_queue = [] + + def testRootInitFile(self): + for path in [ + '__init__.py', '/__init__.py', '////__init__.py', ' __init__.py ', + ' //__init__.py' + ]: + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': path, + 'line': 83 + }), self, self, None) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertEqual( + { + 'isError': True, + 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', + 'description': { + 'format': 'Multiple modules matching $0. ' + 'Please specify the module path.', + 'parameters': ['__init__.py'] + } + }, self._update_queue[0]['status']) + self._update_queue = [] + + # Old module search algorithm rejects because there are too many matches. + # The new algorithm selects the very first match in sys.path. + def testNonRootInitFile(self): + # Neither 'a' nor 'a/b' are real packages accessible via sys.path. + # Therefore, module search falls back to search '__init__.py', which matches + # the first entry in sys.path, which we artifically inject below. 
+ test_dir = os.path.join(self._test_package_dir, 'inner4') + os.mkdir(test_dir) + with open(os.path.join(test_dir, '__init__.py'), 'w') as f: + f.write('def DoPrint():\n') + f.write(' print("Hello")') + sys.path.insert(0, test_dir) + + import inner4 # pylint: disable=g-import-not-at-top,unused-variable + + for path in ['/a/__init__.py', 'a/__init__.py', 'a/b/__init__.py']: + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': path, + 'line': 2 + }), self, self, None) + + inner4.DoPrint() + + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertGreater(len(self._update_queue[0]['stackFrames']), 3) + self.assertEqual('DoPrint', + self._update_queue[0]['stackFrames'][0]['function']) + + self.assertEmpty(imphook._import_callbacks) + self._update_queue = [] + + def testBreakpointInLoadedPackageFile(self): + """Test breakpoint in a loaded package.""" + for name in ['pkg', 'pkg/pkg']: + test_dir = os.path.join(self._test_package_dir, name) + os.mkdir(test_dir) + with open(os.path.join(test_dir, '__init__.py'), 'w') as f: + f.write('def DoPrint():\n') + f.write(' print("Hello from %s")\n' % name) + + import pkg # pylint: disable=g-import-not-at-top,unused-variable + import pkg.pkg # pylint: disable=g-import-not-at-top,unused-variable + + python_breakpoint.PythonBreakpoint( + dict( + self._template, location={ + 'path': 'pkg/pkg/__init__.py', + 'line': 2 + }), self, self, None) + + pkg.pkg.DoPrint() + + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertEqual(None, self._update_queue[0].get('status')) + self._update_queue = [] + + def testInternalError(self): + """Simulate internal error when setting a new breakpoint. + + Bytecode rewriting breakpoints are not supported for methods with more + than 65K constants. 
We generate such a method and try to set breakpoint in + it. + """ + + with open(os.path.join(self._test_package_dir, 'intern_err.py'), 'w') as f: + f.write('def DoSums():\n') + f.write(' x = 0\n') + for i in range(70000): + f.write(' x = x + %d\n' % i) + f.write(' print(x)\n') + + import intern_err # pylint: disable=g-import-not-at-top,unused-variable + + python_breakpoint.PythonBreakpoint( + dict(self._template, location={ + 'path': 'intern_err.py', + 'line': 100 + }), self, self, None) + + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertEqual( + { + 'isError': True, + 'description': { + 'format': 'Internal error occurred' + } + }, self._update_queue[0]['status']) + + def testInvalidCondition(self): + python_breakpoint.PythonBreakpoint( + dict(self._template, condition='2+'), self, self, None) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + if sys.version_info.minor < 10: + self.assertEqual( + { + 'isError': True, + 'refersTo': 'BREAKPOINT_CONDITION', + 'description': { + 'format': 'Expression could not be compiled: $0', + 'parameters': ['unexpected EOF while parsing'] + } + }, self._update_queue[0]['status']) + else: + self.assertEqual( + { + 'isError': True, + 'refersTo': 'BREAKPOINT_CONDITION', + 'description': { + 'format': 'Expression could not be compiled: $0', + 'parameters': ['invalid syntax'] + } + }, self._update_queue[0]['status']) + + def testHit(self): + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint._BreakpointEvent(native.BREAKPOINT_EVENT_HIT, + inspect.currentframe()) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertGreater(len(self._update_queue[0]['stackFrames']), 3) + self.assertTrue(self._update_queue[0]['isFinalState']) + + def testHitNewTimestamp(self): + # Override to use the new 
format (i.e., without the '.%f' sub-second part) + self._template['createTime'] = python_test_util.DateTimeToTimestampNew( + self._base_time) + + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint._BreakpointEvent(native.BREAKPOINT_EVENT_HIT, + inspect.currentframe()) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertGreater(len(self._update_queue[0]['stackFrames']), 3) + self.assertTrue(self._update_queue[0]['isFinalState']) + + def testHitTimestampUnixMsec(self): + # Using the Snapshot Debugger (Firebase backend) version of creation time + self._template.pop('createTime', None); + self._template[ + 'createTimeUnixMsec'] = python_test_util.DateTimeToUnixMsec( + self._base_time) + + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint._BreakpointEvent(native.BREAKPOINT_EVENT_HIT, + inspect.currentframe()) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertGreater(len(self._update_queue[0]['stackFrames']), 3) + self.assertTrue(self._update_queue[0]['isFinalState']) + + def testDoubleHit(self): + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint._BreakpointEvent(native.BREAKPOINT_EVENT_HIT, + inspect.currentframe()) + breakpoint._BreakpointEvent(native.BREAKPOINT_EVENT_HIT, + inspect.currentframe()) + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + + def testEndToEndUnconditional(self): + + def Trigger(): + pass # BPTAG: E2E_UNCONDITIONAL + + path, line = python_test_util.ResolveTag(type(self), 'E2E_UNCONDITIONAL') + breakpoint = python_breakpoint.PythonBreakpoint( + { + 'id': 'BP_ID', + 'location': { + 'path': path, + 'line': line + } + }, self, self, None) + self.assertEmpty(self._update_queue) + Trigger() + self.assertLen(self._update_queue, 1) + breakpoint.Clear() + + def 
testEndToEndConditional(self): + + def Trigger(): + for i in range(2): + self.assertLen(self._update_queue, i) # BPTAG: E2E_CONDITIONAL + + path, line = python_test_util.ResolveTag(type(self), 'E2E_CONDITIONAL') + breakpoint = python_breakpoint.PythonBreakpoint( + { + 'id': 'BP_ID', + 'location': { + 'path': path, + 'line': line + }, + 'condition': 'i == 1' + }, self, self, None) + Trigger() + breakpoint.Clear() + + def testEndToEndCleared(self): + path, line = python_test_util.ResolveTag(type(self), 'E2E_CLEARED') + breakpoint = python_breakpoint.PythonBreakpoint( + { + 'id': 'BP_ID', + 'location': { + 'path': path, + 'line': line + } + }, self, self, None) + breakpoint.Clear() + self.assertEmpty(self._update_queue) # BPTAG: E2E_CLEARED + + def testBreakpointCancellationEvent(self): + events = [ + native.BREAKPOINT_EVENT_GLOBAL_CONDITION_QUOTA_EXCEEDED, + native.BREAKPOINT_EVENT_BREAKPOINT_CONDITION_QUOTA_EXCEEDED, + native.BREAKPOINT_EVENT_CONDITION_EXPRESSION_MUTABLE + ] + for event in events: + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, + self, None) + breakpoint._BreakpointEvent(event, None) + self.assertLen(self._update_queue, 1) + self.assertEqual(set(['BP_ID']), self._completed) + + self._update_queue = [] + self._completed = set() + + def testExpirationTime(self): + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint.Clear() + self.assertEqual( + datetime(year=2015, month=1, day=2), breakpoint.GetExpirationTime()) + + def testExpirationTimeUnixMsec(self): + # Using the Snapshot Debugger (Firebase backend) version of creation time + self._template.pop('createTime', None); + self._template[ + 'createTimeUnixMsec'] = python_test_util.DateTimeToUnixMsec( + self._base_time) + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint.Clear() + self.assertEqual( + self._base_time + timedelta(hours=24), breakpoint.GetExpirationTime()) + + def 
testExpirationTimeWithExpiresIn(self): + definition = self._template.copy() + definition['expires_in'] = { + 'seconds': 300 # 5 minutes + } + + breakpoint = python_breakpoint.PythonBreakpoint(definition, self, self, + None) + breakpoint.Clear() + self.assertEqual( + datetime(year=2015, month=1, day=2), breakpoint.GetExpirationTime()) + + def testExpiration(self): + breakpoint = python_breakpoint.PythonBreakpoint(self._template, self, self, + None) + breakpoint.ExpireBreakpoint() + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertEqual( + { + 'isError': True, + 'refersTo': 'BREAKPOINT_AGE', + 'description': { + 'format': 'The snapshot has expired' + } + }, self._update_queue[0]['status']) + + def testLogpointExpiration(self): + definition = self._template.copy() + definition['action'] = 'LOG' + breakpoint = python_breakpoint.PythonBreakpoint(definition, self, self, + None) + breakpoint.ExpireBreakpoint() + self.assertEqual(set(['BP_ID']), self._completed) + self.assertLen(self._update_queue, 1) + self.assertTrue(self._update_queue[0]['isFinalState']) + self.assertEqual( + { + 'isError': True, + 'refersTo': 'BREAKPOINT_AGE', + 'description': { + 'format': 'The logpoint has expired' + } + }, self._update_queue[0]['status']) + + def testNormalizePath(self): + # Removes leading '/' character. + for path in ['/__init__.py', '//__init__.py', '////__init__.py']: + self.assertEqual('__init__.py', python_breakpoint._NormalizePath(path)) + + # Removes leading and trailing whitespace. + for path in [' __init__.py', '__init__.py ', ' __init__.py ']: + self.assertEqual('__init__.py', python_breakpoint._NormalizePath(path)) + + # Removes combination of leading/trailing whitespace and '/' character. 
+ for path in [' /__init__.py', ' ///__init__.py', '////__init__.py']: + self.assertEqual('__init__.py', python_breakpoint._NormalizePath(path)) + + # Normalizes the relative path. + for path in [ + ' ./__init__.py', '././__init__.py', ' .//abc/../__init__.py', + ' ///abc///..///def/..////__init__.py' + ]: + self.assertEqual('__init__.py', python_breakpoint._NormalizePath(path)) + + # Does not remove non-leading, non-trailing space, or non-leading '/' + # characters. + self.assertEqual( + 'foo bar/baz/__init__.py', + python_breakpoint._NormalizePath('/foo bar/baz/__init__.py')) + self.assertEqual( + 'foo/bar baz/__init__.py', + python_breakpoint._NormalizePath('/foo/bar baz/__init__.py')) + self.assertEqual( + 'foo/bar/baz/__in it__.py', + python_breakpoint._NormalizePath('/foo/bar/baz/__in it__.py')) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/python_test_util.py b/tests/py/python_test_util.py new file mode 100644 index 0000000..ffac4da --- /dev/null +++ b/tests/py/python_test_util.py @@ -0,0 +1,189 @@ +"""Set of helper methods for Python debuglet unit and component tests.""" + +import inspect +import re + + +def GetModuleInfo(obj): + """Gets the source file path and breakpoint tags for a module. + + Breakpoint tag is a named label of a source line. The tag is marked + with "# BPTAG: XXX" comment. + + Args: + obj: any object inside the queried module. + + Returns: + (path, tags) tuple where tags is a dictionary mapping tag name to + line numbers where this tag appears. + """ + return (inspect.getsourcefile(obj), GetSourceFileTags(obj)) + + +def GetSourceFileTags(source): + """Gets breakpoint tags for the specified source file. + + Breakpoint tag is a named label of a source line. The tag is marked + with "# BPTAG: XXX" comment. + + Args: + source: either path to the .py file to analyze or any code related + object (e.g. module, function, code object). + + Returns: + Dictionary mapping tag name to line numbers where this tag appears. 
+ """ + if isinstance(source, str): + lines = open(source, 'r').read().splitlines() + start_line = 1 # line number is 1 based + else: + lines, start_line = inspect.getsourcelines(source) + if not start_line: # "getsourcelines" returns start_line of 0 for modules. + start_line = 1 + + tags = {} + regex = re.compile(r'# BPTAG: ([0-9a-zA-Z_]+)\s*$') + for n, line in enumerate(lines): + m = regex.search(line) + if m: + tag = m.group(1) + if tag in tags: + tags[tag].append(n + start_line) + else: + tags[tag] = [n + start_line] + + return tags + + +def ResolveTag(obj, tag): + """Resolves the breakpoint tag into source file path and a line number. + + Breakpoint tag is a named label of a source line. The tag is marked + with "# BPTAG: XXX" comment. + + Raises + + Args: + obj: any object inside the queried module. + tag: tag name to resolve. + + Raises: + Exception: if no line in the source file define the specified tag or if + more than one line define the tag. + + Returns: + (path, line) tuple, where line is the line number where the tag appears. + """ + path, tags = GetModuleInfo(obj) + if tag not in tags: + raise Exception('tag %s not found' % tag) + lines = tags[tag] + if len(lines) != 1: + raise Exception('tag %s is ambiguous (lines: %s)' % (tag, lines)) + return path, lines[0] + + +def DateTimeToTimestamp(t): + """Converts the specified time to Timestamp format. + + Args: + t: datetime instance + + Returns: + Time in Timestamp format + """ + return t.strftime('%Y-%m-%dT%H:%M:%S.%f') + 'Z' + + +def DateTimeToTimestampNew(t): + """Converts the specified time to Timestamp format in seconds granularity. 
+ + Args: + t: datetime instance + + Returns: + Time in Timestamp format in seconds granularity + """ + return t.strftime('%Y-%m-%dT%H:%M:%S') + 'Z' + +def DateTimeToUnixMsec(t): + """Returns the Unix time as in integer value in milliseconds""" + return int(t.timestamp() * 1000) + + +def PackFrameVariable(breakpoint, name, frame=0, collection='locals'): + """Finds local variable or argument by name. + + Indirections created through varTableIndex are recursively collapsed. Fails + the test case if the named variable is not found. + + Args: + breakpoint: queried breakpoint. + name: name of the local variable or argument. + frame: stack frame index to examine. + collection: 'locals' to get local variable or 'arguments' for an argument. + + Returns: + Single dictionary of variable data. + + Raises: + AssertionError: if the named variable not found. + """ + for variable in breakpoint['stackFrames'][frame][collection]: + if variable['name'] == name: + return _Pack(variable, breakpoint) + + raise AssertionError('Variable %s not found in frame %d collection %s' % + (name, frame, collection)) + + +def PackWatchedExpression(breakpoint, expression): + """Finds watched expression by index. + + Indirections created through varTableIndex are recursively collapsed. Fails + the test case if the named variable is not found. + + Args: + breakpoint: queried breakpoint. + expression: index of the watched expression. + + Returns: + Single dictionary of variable data. + """ + return _Pack(breakpoint['evaluatedExpressions'][expression], breakpoint) + + +def _Pack(variable, breakpoint): + """Recursively collapses indirections created through varTableIndex. + + Circular references by objects are not supported. If variable subtree + has circular references, this function will hang. + + Variable members are sorted by name. This helps asserting the content of + variable since Python has no guarantees over the order of keys of a + dictionary. + + Args: + variable: variable object to pack. 
Not modified. + breakpoint: queried breakpoint. + + Returns: + A new dictionary with packed variable object. + """ + packed = dict(variable) + + while 'varTableIndex' in packed: + ref = breakpoint['variableTable'][packed['varTableIndex']] + assert 'name' not in ref + assert 'value' not in packed + assert 'members' not in packed + assert 'status' not in ref and 'status' not in packed + del packed['varTableIndex'] + packed.update(ref) + + if 'members' in packed: + packed['members'] = sorted( + [_Pack(m, breakpoint) for m in packed['members']], + key=lambda m: m.get('name', '')) + + return packed diff --git a/tests/py/uniquifier_computer_test.py b/tests/py/uniquifier_computer_test.py new file mode 100644 index 0000000..3d382b7 --- /dev/null +++ b/tests/py/uniquifier_computer_test.py @@ -0,0 +1,121 @@ +"""Unit test for uniquifier_computer module.""" + +import os +import sys +import tempfile + +from absl.testing import absltest + +from googleclouddebugger import uniquifier_computer + + +class UniquifierComputerTest(absltest.TestCase): + + def _Compute(self, files): + """Creates a directory structure and computes uniquifier on it. + + Args: + files: dictionary of relative path to file content. + + Returns: + Uniquifier data lines. 
+ """ + + class Hash(object): + """Fake implementation of hash to collect raw data.""" + + def __init__(self): + self.data = b'' + + def update(self, s): + self.data += s + + root = tempfile.mkdtemp('', 'fake_app_') + for relative_path, content in files.items(): + path = os.path.join(root, relative_path) + directory = os.path.split(path)[0] + if not os.path.exists(directory): + os.makedirs(directory) + with open(path, 'w') as f: + f.write(content) + + sys.path.insert(0, root) + try: + hash_obj = Hash() + uniquifier_computer.ComputeApplicationUniquifier(hash_obj) + return [ + u.decode() for u in ( + hash_obj.data.rstrip(b'\n').split(b'\n') if hash_obj.data else []) + ] + finally: + del sys.path[0] + + def testEmpty(self): + self.assertListEqual([], self._Compute({})) + + def testBundle(self): + self.assertListEqual([ + 'first.py:1', 'in1/__init__.py:6', 'in1/a.py:3', 'in1/b.py:4', + 'in1/in2/__init__.py:7', 'in1/in2/c.py:5', 'second.py:2' + ], + self._Compute({ + 'db.app': 'abc', + 'first.py': 'a', + 'second.py': 'bb', + 'in1/a.py': 'ccc', + 'in1/b.py': 'dddd', + 'in1/in2/c.py': 'eeeee', + 'in1/__init__.py': 'ffffff', + 'in1/in2/__init__.py': 'ggggggg' + })) + + def testEmptyFile(self): + self.assertListEqual(['empty.py:0'], self._Compute({'empty.py': ''})) + + def testNonPythonFilesIgnored(self): + self.assertListEqual(['real.py:1'], + self._Compute({ + 'file.p': '', + 'file.pya': '', + 'real.py': '1' + })) + + def testNonPackageDirectoriesIgnored(self): + self.assertListEqual(['dir2/__init__.py:1'], + self._Compute({ + 'dir1/file.py': '', + 'dir2/__init__.py': 'a', + 'dir2/image.gif': '' + })) + + def testDepthLimit(self): + self.assertListEqual([ + ''.join(str(n) + '/' + for n in range(1, m + 1)) + '__init__.py:%d' % m + for m in range(9, 0, -1) + ], + self._Compute({ + '1/__init__.py': '1', + '1/2/__init__.py': '2' * 2, + '1/2/3/__init__.py': '3' * 3, + '1/2/3/4/__init__.py': '4' * 4, + '1/2/3/4/5/__init__.py': '5' * 5, + '1/2/3/4/5/6/__init__.py': '6' * 6, + 
'1/2/3/4/5/6/7/__init__.py': '7' * 7, + '1/2/3/4/5/6/7/8/__init__.py': '8' * 8, + '1/2/3/4/5/6/7/8/9/__init__.py': '9' * 9, + '1/2/3/4/5/6/7/8/9/10/__init__.py': 'a' * 10, + '1/2/3/4/5/6/7/8/9/10/11/__init__.py': 'b' * 11 + })) + + def testPrecedence(self): + self.assertListEqual(['my.py:3'], + self._Compute({ + 'my.pyo': 'a', + 'my.pyc': 'aa', + 'my.py': 'aaa' + })) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/py/yaml_data_visibility_config_reader_test.py b/tests/py/yaml_data_visibility_config_reader_test.py new file mode 100644 index 0000000..65d3cd0 --- /dev/null +++ b/tests/py/yaml_data_visibility_config_reader_test.py @@ -0,0 +1,121 @@ +"""Tests for yaml_data_visibility_config_reader.""" + +import os +import sys +from unittest import mock + +from io import StringIO + +from absl.testing import absltest +from googleclouddebugger import yaml_data_visibility_config_reader + + +class StringIOOpen(object): + """An open for StringIO that supports "with" semantics. + + I tried using mock.mock_open, but the read logic in the yaml.load code is + incompatible with the returned mock object, leading to a test hang/timeout. + """ + + def __init__(self, data): + self.file_obj = StringIO(data) + + def __enter__(self): + return self.file_obj + + def __exit__(self, type, value, traceback): # pylint: disable=redefined-builtin + pass + + +class YamlDataVisibilityConfigReaderTest(absltest.TestCase): + + def testOpenAndReadSuccess(self): + data = """ + blacklist: + - bl1 + """ + path_prefix = 'googleclouddebugger.' + with mock.patch( + path_prefix + 'yaml_data_visibility_config_reader.open', + create=True) as m: + m.return_value = StringIOOpen(data) + config = yaml_data_visibility_config_reader.OpenAndRead() + m.assert_called_with( + os.path.join(sys.path[0], 'debugger-blacklist.yaml'), 'r') + self.assertEqual(config.blacklist_patterns, ['bl1']) + + def testOpenAndReadFileNotFound(self): + path_prefix = 'googleclouddebugger.' 
+ with mock.patch( + path_prefix + 'yaml_data_visibility_config_reader.open', + create=True, + side_effect=IOError('IO Error')): + f = yaml_data_visibility_config_reader.OpenAndRead() + self.assertIsNone(f) + + def testReadDataSuccess(self): + data = """ + blacklist: + - bl1 + - bl2 + whitelist: + - wl1 + - wl2.* + """ + + config = yaml_data_visibility_config_reader.Read(StringIO(data)) + self.assertItemsEqual(config.blacklist_patterns, ('bl1', 'bl2')) + self.assertItemsEqual(config.whitelist_patterns, ('wl1', 'wl2.*')) + + def testYAMLLoadError(self): + + class ErrorIO(object): + + def read(self, size): + del size # Unused + raise IOError('IO Error') + + with self.assertRaises(yaml_data_visibility_config_reader.YAMLLoadError): + yaml_data_visibility_config_reader.Read(ErrorIO()) + + def testBadYamlSyntax(self): + data = """ + blacklist: whitelist: + """ + + with self.assertRaises(yaml_data_visibility_config_reader.ParseError): + yaml_data_visibility_config_reader.Read(StringIO(data)) + + def testUnknownConfigKeyError(self): + data = """ + foo: + - bar + """ + + with self.assertRaises( + yaml_data_visibility_config_reader.UnknownConfigKeyError): + yaml_data_visibility_config_reader.Read(StringIO(data)) + + def testNotAListError(self): + data = """ + blacklist: + foo: + - bar + """ + + with self.assertRaises(yaml_data_visibility_config_reader.NotAListError): + yaml_data_visibility_config_reader.Read(StringIO(data)) + + def testElementNotAStringError(self): + data = """ + blacklist: + - 5 + """ + + with self.assertRaises( + yaml_data_visibility_config_reader.ElementNotAStringError): + yaml_data_visibility_config_reader.Read(StringIO(data)) + + +if __name__ == '__main__': + absltest.main()