# Transition probabilities, indexed as trans[previous_tag][next_tag].
# "start" and "end" are the artificial states before the first and after the
# last word of the sentence.
trans = {
    "start": {"DT": 0.5, "NN": 0.45, "VBD": 0.05},
    "DT": {"DT": 0.01, "NN": 0.85, "VBD": 0.05, "end": 0.09},
    "NN": {"DT": 0.05, "NN": 0.2, "VBD": 0.7, "end": 0.05},
    "VBD": {"DT": 0.1, "NN": 0.4, "VBD": 0.1, "end": 0.4},
}

# Observation (emission) probabilities, indexed as obs[tag][word].
obs = {
    "DT": {"Mary": 0.05, "had": 0.01, "cats": 0.01},
    "NN": {"Mary": 0.2, "had": 0.05, "cats": 0.3},
    "VBD": {"Mary": 0.02, "had": 0.2, "cats": 0.03},
}

# The sentence to tag and the tag set to choose from.
seq = ["Mary", "had", "cats"]
tags = ["DT", "NN", "VBD"]

class Cell:
    """A single matrix entry: a probability plus a link to a predecessor cell."""

    def __init__(self, prob, backpointer, i, j):
        """Store the probability, the predecessor and the position (i, j)."""
        # Position tuple, kept in the order the caller passed it.
        self.identifier = (i, j)
        self.prob = prob
        # Predecessor cell this entry was derived from; None when there is none.
        self.backpointer = backpointer

    def __repr__(self):
        """Return the probability, followed by the predecessor's identifier if set."""
        text = str(self.prob)
        if self.backpointer:
            text += str(self.backpointer.identifier)
        return text

# Viterbi matrix: one row per tag, and one column per word plus one extra
# trailing column. Every cell starts with probability 0 and no backpointer;
# its identifier is (column index, row index).
m = [
    [Cell(0, None, col, row) for col in range(len(seq) + 1)]
    for row in range(len(tags))
]
def print_m(m):
    """Print the matrix to stdout, one row per line."""
    for tag_row in m:
        print(tag_row)


# Forward pass: walk over the words, i.e. the columns of the matrix.
for col, word in enumerate(seq):
    # Walk over the tags, i.e. the rows of the current column.
    for row, tag in enumerate(tags):
        emission = obs[tag][word]
        if col == 0:
            # First column: filled from the "start" transition alone.
            m[row][col].prob = trans["start"][tag] * emission
            continue

        # General case: try every cell of the previous column as predecessor
        # and remember the most probable one as (probability, row).
        best = None
        for prev_row, prev_tag in enumerate(tags):
            prev_prob = m[prev_row][col - 1].prob
            step = trans[prev_tag][tag]
            p = prev_prob * step * emission
            # Trace output of every candidate that was evaluated.
            print(word, tag, prev_tag, prev_prob, step, emission, p)
            # Strict comparison: on a tie the earlier predecessor wins.
            if best is None or best[0] < p:
                best = (p, prev_row)

        # Store the winning probability and point back to its predecessor.
        m[row][col].prob = best[0]
        m[row][col].backpointer = m[best[1]][col - 1]

print("Viterbi matrix")
print_m(m)

# Final step: multiply every cell in the column of the last word by the
# probability of moving from its tag to "end". Each entry is
# (probability, tag, row index, column index).
last_col = len(seq) - 1
endProbs = [
    (m[row][last_col].prob * trans[tag]["end"], tag, row, last_col)
    for row, tag in enumerate(tags)
]

# Most probable ending first.
endProbs.sort(reverse=True)
print(endProbs)

# Report the probability of the most probable tag sequence.
print("Wahrscheinlichkeit der wahrscheinlichsten Tag-Sequenz:", endProbs[0][0])

# Backtrace: start at the best final cell and follow the backpointers to the
# first column, collecting the tags from last to first.
seq_tags = []
i, j = endProbs[0][2:4]  # row index and column index of the best final cell
while True:
    print(i, j)
    tag = tags[i]
    print(tag)
    # Prepend, because the walk runs from the end of the sentence to its start.
    seq_tags.insert(0, tag)
    print(seq_tags)
    previous = m[i][j].backpointer
    if previous is None:
        # Cells without a predecessor end the walk.
        break
    # Identifiers are stored as (column, row), hence the swapped unpacking.
    j, i = previous.identifier


