diff --git a/src/PartitionedParallelSampling.jl b/src/PartitionedParallelSampling.jl index ab2a821..7bfea64 100644 --- a/src/PartitionedParallelSampling.jl +++ b/src/PartitionedParallelSampling.jl @@ -25,16 +25,17 @@ using Parameters: @with_kw using DocStringExtensions import BAT -using BAT: AbstractMeasureOrDensity, AbstractPosteriorMeasure, PosteriorMeasure +using BAT: BATMeasure, AbstractPosteriorMeasure, LBQIntegral using BAT: bat_sample, AbstractSamplingAlgorithm, MCMCSampling, OrderedResampling using BAT: bat_integrate, IntegrationAlgorithm, DensitySampleVector using BAT: bat_transform, AbstractTransformTarget, PriorToUniform +using BAT: BATContext # BAT-experimental: -using BAT: truncate_density +using BAT: truncate_batmeasure # BAT-internal: -using BAT: var_bounds, spatialvolume +using BAT: measure_support, spatialvolume using AHMI: AHMIntegration diff --git a/src/partitioned_sampling.jl b/src/partitioned_sampling.jl index 63bc5f0..bd7c5b7 100644 --- a/src/partitioned_sampling.jl +++ b/src/partitioned_sampling.jl @@ -32,24 +32,21 @@ end export PartitionedSampling -function BAT.bat_sample_impl(rng::AbstractRNG, target::PosteriorMeasure, algorithm::PartitionedSampling) - - density_notrafo = convert(AbstractMeasureOrDensity, target) - shaped_density, trafo = bat_transform(algorithm.trafo, density_notrafo) - density = unshaped(shaped_density) +function BAT.bat_sample_impl(m::BATMeasure, algorithm::PartitionedSampling, context::BATContext) + transformed_m, trafo = BAT.transform_and_unshape(algorithm.trafo, m, context) @info "Generating Exploration Samples" - exploration_samples = bat_sample(density, algorithm.exploration_sampler).result + exploration_samples = bat_sample(transformed_m, algorithm.exploration_sampler).result @info "Constructing Partition Tree" partition_tree, cost_values = partition_space(exploration_samples, algorithm.npartitions, algorithm.partitioner) # Convert 'partition_tree' structure into a set of truncated posteriors: - posteriors_array = convert_to_posterior(density, partition_tree, extend_bounds = algorithm.partitioner.extend_bounds) + posteriors_array = convert_to_posterior(transformed_m, partition_tree, extend_bounds = algorithm.partitioner.extend_bounds) @info "Sampling Subspaces" iterator_subspaces = [ [subspace_ind, posteriors_array[subspace_ind], algorithm.sampler] for subspace_ind in Base.OneTo(algorithm.npartitions)] - samples_subspaces_run = pmap(inp -> sample_subspace(inp...), iterator_subspaces) + samples_subspaces_run = pmap(inp -> sample_subspace(inp..., context), iterator_subspaces) unconv_mask = [samples_subspace.isvalid for samples_subspace in samples_subspaces_run] # returns "false" if subspace was not converged during tuning cycle unconv_ind = findall(x->x==false, unconv_mask) @@ -74,7 +71,8 @@ function BAT.bat_sample_impl(rng::AbstractRNG, target::PosteriorMeasure, algorit exploration_samples_rep = bat_sample( samples_subspaces_run[rep_ind].samples, - OrderedResampling(nsamples=algorithm.exploration_sampler.nsteps) + OrderedResampling(nsamples=algorithm.exploration_sampler.nsteps), + context ).result partition_tree_rep, _ = partition_space(exploration_samples_rep, 2, algorithm.partitioner) @@ -109,22 +107,21 @@ function BAT.bat_sample_impl(rng::AbstractRNG, target::PosteriorMeasure, algorit samples_subspaces = pmap(inp -> integrate_subspace(inp, algorithm.integrator), samples_subspaces) @info "Combining Samples" - samples = deepcopy(samples_subspaces[1].samples) + transformed_smpls = deepcopy(samples_subspaces[1].samples) info = deepcopy(samples_subspaces[1].info) # Save indices from different subspaces: - info.samples_ind[1] = 1:length(samples) + info.samples_ind[1] = 1:length(transformed_smpls) for subspace in samples_subspaces[2:end] - start_ind, stop_ind = length(samples)+1, length(samples)+length(subspace.samples) + start_ind, stop_ind = length(transformed_smpls)+1, length(transformed_smpls)+length(subspace.samples) subspace.info.samples_ind[1] = start_ind:stop_ind - append!(samples, subspace.samples) + append!(transformed_smpls, subspace.samples) append!(info, subspace.info) end - samples_trafo = varshape(shaped_density).(samples) - samples_notrafo = inverse(trafo).(samples_trafo) + smpls = inverse(trafo).(transformed_smpls) return ( - result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, + result = smpls, result_trafo = samples_trafo, trafo = trafo, info = info, exp_samples = exploration_samples, part_tree = partition_tree, cost_values = cost_values, @@ -134,14 +131,15 @@ end function sample_subspace( space_id::Integer, - posterior::PosteriorMeasure, + posterior::LBQIntegral, sampling_algorithm::A, + context::BATContext ) where {N<:NamedTuple, A<:AbstractSamplingAlgorithm, I<:IntegrationAlgorithm} @info "Sampling subspace #$space_id" sampling_wc_start = Dates.Time(Dates.now()) sampling_cpu_time = @CPUelapsed begin - sampling_output = bat_sample(posterior, sampling_algorithm) + sampling_output = bat_sample(posterior, sampling_algorithm, context) samples_subspace = sampling_output.result isvalid = sampling_output.isvalid end @@ -162,6 +160,7 @@ end function integrate_subspace( sampling_reuslt::N, integration_algorithm::I, + context::BATContext ) where {N<:NamedTuple, A<:AbstractSamplingAlgorithm, I<:IntegrationAlgorithm} samples_subspace = sampling_reuslt.samples @@ -169,7 +168,7 @@ function integrate_subspace( integration_wc_start = Dates.Time(Dates.now()) integration_cpu_time = @CPUelapsed begin # ToDo: Use samples_subspace_trafo for integration instead of samples_subspace? - integras_subspace = bat_integrate(samples_subspace, integration_algorithm).result + integras_subspace = bat_integrate(samples_subspace, integration_algorithm, context).result end integration_wc_stop = Dates.Time(Dates.now()) @@ -193,9 +192,9 @@ function integrate_subspace( end function convert_to_posterior_resampled( - posterior::PosteriorMeasure, + posterior::LBQIntegral, partition_tree::SpacePartTree, - posterior_subspace::PosteriorMeasure; + posterior_subspace::LBQIntegral; extend_bounds::Bool=true ) @@ -218,8 +217,8 @@ function convert_to_posterior_resampled( posterior_array = map( x -> begin bounds = HyperRectBounds(x[:,1], x[:,2]) - prior_dist = truncate_density(getprior(posterior), bounds) - PosteriorMeasure(getlikelihood(posterior), prior_dist) + prior_dist = truncate_batmeasure(getprior(posterior), bounds) + LBQIntegral(getlikelihood(posterior), prior_dist) end, subspaces_rect_bounds) @@ -228,12 +227,13 @@ end -function convert_to_posterior(posterior::PosteriorMeasure, partition_tree::SpacePartTree; extend_bounds::Bool=true) +function convert_to_posterior(posterior::LBQIntegral, partition_tree::SpacePartTree; extend_bounds::Bool=true) if extend_bounds # Exploration samples might not always cover properly tails of the distribution. # We will extend boudnaries of the partition tree with original bounds which are: - vol = spatialvolume(var_bounds(posterior)) + #!!!!!!!!!!!!!!!!! + vol = spatialvolume(measure_support(posterior)) lo_bounds = vol.lo hi_bounds = vol.hi extend_tree_bounds!(partition_tree, lo_bounds, hi_bounds) @@ -244,7 +244,7 @@ function convert_to_posterior(posterior::PosteriorMeasure, partition_tree::Space posterior_array = map(subspaces_rect_bounds) do x bounds = StructArray{Interval}((x[:,1], x[:,2])) - truncate_density(posterior, bounds) + truncate_batmeasure(posterior, bounds) end return posterior_array diff --git a/test/test_partitioned_sampling.jl b/test/test_partitioned_sampling.jl index 8243b86..c627f18 100644 --- a/test/test_partitioned_sampling.jl +++ b/test/test_partitioned_sampling.jl @@ -4,7 +4,7 @@ using PartitionedParallelSampling using Test using Distributions, LinearAlgebra, DensityInterface, ValueShapes -using BAT: AbstractPosteriorMeasure, PosteriorMeasure +using BAT: AbstractPosteriorMeasure, LBQIntegral using BAT: bat_sample, MCMCSampling, MetropolisHastings, IIDSampling using BAT: bat_transform, PriorToUniform, NoWhitening @@ -26,7 +26,7 @@ using PartitionedParallelSampling: convert_to_posterior params -> logpdf(model, params.a) end) - posterior = PosteriorMeasure(likelihood, prior) + posterior = LBQIntegral(likelihood, prior) transformed_posterior, trafo = bat_transform(PriorToUniform(), posterior) #Sampling and integration algorithms