#
# EfProb code associated with the article:
#
# A Channel-based Exact Inference Algorithm for Bayesian Networks
#
# See: https://arxiv.org/abs/1804.08032
#
# Copyright: Bart Jacobs, http://www.cs.ru.nl/~bart/, April 2018
#
from pgm_efprob import *
# from pgmpy.models import BayesianModel
# from pgmpy.factors.discrete import TabularCPD
# from pgmpy.inference import VariableElimination
from pgmpy.readwrite.BIF import *
import timeit

# Network taken from the bnlearn repository:
#    http://www.bnlearn.com/bnrepository/
#
# The BIF file must be downloaded and unzipped into the working directory
# before running this script:
#    http://www.bnlearn.com/bnrepository/insurance/insurance.bif.gz
#
reader = BIFReader("insurance.bif")

#
# Convert the parsed BIF description into a pgmpy Bayesian model
#
model = reader.get_model()

#
# Optional: extract the graph of the model and render it
# (disabled; uncomment to display the network)
#
# graph = pydot_graph_of_pgm(model)
# graph_image(graph, "insurance")

#
# Variable-elimination engine for the model; it serves as the baseline
# against which stretch-and-infer is compared below
#
inference = VariableElimination(model)


#
# Number of iterations
#
N = 5

#
# Accumulated times in seconds, for variable elimination (ve)
# and for stretch-and-infer transformations (tr)
#
tve = 0
ttr = 0


def _timed(thunk):
    """Call ``thunk()`` exactly once and return ``(result, elapsed_seconds)``.

    The call is measured with ``timeit.timeit`` (as before), but only the
    call itself is inside the timed region. Printing the result is left to
    the caller, so console output and string formatting do not count
    towards the inference time.
    """
    box = []
    elapsed = timeit.timeit(lambda: box.append(thunk()), number=1)
    return box[0], elapsed


#
# Run the different inference algorithms multiple times,
#   with randomly generated evidence nodes and observation.
#
# NOTE(review): random, pick_from_list, stretch_and_infer and
#   VariableElimination are not imported explicitly in this file;
#   presumably they come from `from pgm_efprob import *` -- confirm.
#
for i in range(N):
    # select new evidence and observation
    evidence_dictionary = {}
    evidence = {}
    evidence_num = random.randint(1, 5)
    # select random nodes; the first one is for observation
    picks = pick_from_list(model.nodes, evidence_num + 1)
    observed = picks[0]
    print("\n* Observation and evidence nodes: ", observed, picks[1:])
    for e in picks[1:]:
        # form the right predicate, to be used as evidence for
        # stretch-and-infer: indicator of the first state (index 0) of e
        ls = model.get_cardinality(e) * [0]
        ls[0] = 1
        evidence_dictionary[e] = ls
        # the same evidence in pgmpy form: node e is in state 0
        evidence[e] = 0
    # stretch-and-infer (only the call is timed, not the printing)
    tr_result, t1 = _timed(lambda: stretch_and_infer(model, observed,
                                                     evidence_dictionary,
                                                     silent=True))
    print(tr_result)
    ttr += t1
    # variable elimination (only the query is timed, not the printing)
    ve_result, t2 = _timed(lambda: inference.query([observed],
                                                   evidence=evidence)[observed])
    print(ve_result)
    tve += t2
    print("Iteration:", i, " ", t1, t2)

print("\nTotal inference time is: ", tve+ttr, " for ", N, " runs")
print("of which for variable elimination: ", tve)
print("and for stretch-and-infer: ", ttr)
# guard the ratio: a zero stretch-and-infer total would otherwise crash here
if ttr > 0:
    print("How much faster is stretch-and-infer:", tve/ttr)
else:
    print("How much faster is stretch-and-infer: undefined (no time measured)")

