Priyanka-Kumavat-At-TE's picture
Upload 7 files
dfcdf7b
raw history blame
No virus
5.28 kB
import math
import os
import random
import sys
from random import randint

import jprops
import matplotlib.pyplot as plt
import numpy as np

from matumizi.util import *
from matumizi.mlutil import *
"""
Markov chain classifier
"""
class MarkovChainClassifier():
    """
    Two class classifier built on first order Markov chain state transition
    probabilities. One transition matrix is kept per class label and a record
    is classified by the summed log odds of its state transitions under the
    two matrices.
    """

    def __init__(self, configFile):
        """
        constructor

        Parameters
            configFile: config file path
        """
        defValues = {}
        defValues["common.model.directory"] = ("model", None)
        defValues["common.model.file"] = (None, None)
        defValues["common.verbose"] = (False, None)
        defValues["common.states"] = (None, "missing state list")
        defValues["train.data.file"] = (None, "missing training data file")
        defValues["train.data.class.labels"] = (["F", "T"], None)
        defValues["train.data.key.len"] = (1, None)
        defValues["train.model.save"] = (False, None)
        defValues["train.score.method"] = ("accuracy", None)
        defValues["predict.data.file"] = (None, None)
        defValues["predict.use.saved.model"] = (True, None)
        defValues["predict.log.odds.threshold"] = (0, None)
        defValues["validate.data.file"] = (None, "missing validation data file")
        defValues["validate.use.saved.model"] = (False, None)
        defValues["valid.accuracy.metric"] = ("acc", None)
        self.config = Configuration(configFile, defValues)

        self.clabels = self.config.getStringListConfig("train.data.class.labels")[0]
        self.states = self.config.getStringListConfig("common.states")[0]
        self.nstates = len(self.states)
        # class label -> (nstates x nstates) state transition matrix
        self.stTranPr = self.__freshCounts()

    def __freshCounts(self):
        """
        builds one transition count matrix per class label, every cell
        starting at 1 so that no transition ends up with zero probability
        """
        return {cl: np.ones((self.nstates, self.nstates)) for cl in self.clabels}

    def __transitions(self, rec, start):
        """
        generates (from state index, to state index) for each pair of
        consecutive states in a record

        Parameters
            rec : record as list of fields
            start : index of the first state field in the record
        """
        for i in range(start, len(rec) - 1):
            yield self.states.index(rec[i]), self.states.index(rec[i + 1])

    def __logOdds(self, rec, start, pc, nc):
        """
        sum over all transitions in a record of the log of the ratio of
        transition probability under class pc to that under class nc

        Parameters
            rec : record as list of fields
            start : index of the first state field in the record
            pc : positive class label
            nc : negative class label
        """
        ptp = self.stTranPr[pc]
        ntp = self.stTranPr[nc]
        lodds = 0
        for fst, tst in self.__transitions(rec, start):
            lodds += math.log(ptp[fst][tst] / ntp[fst][tst])
        return lodds

    def train(self):
        """
        train model, optionally saving it as configured
        """
        # start from fresh counts, otherwise a repeated call would add new
        # counts on top of already normalized probabilities
        self.stTranPr = self.__freshCounts()

        # state transition counts; class label field follows the key fields
        tdfPath = self.config.getStringConfig("train.data.file")[0]
        klen = self.config.getIntConfig("train.data.key.len")[0]
        for rec in fileRecGen(tdfPath):
            counts = self.stTranPr[rec[klen]]
            for fst, tst in self.__transitions(rec, klen + 1):
                counts[fst][tst] += 1

        # normalize each row to probability
        for cl in self.clabels:
            stp = self.stTranPr[cl]
            stp /= stp.sum(axis=1, keepdims=True)

        # save, each class as a label line followed by the matrix rows
        if self.config.getBooleanConfig("train.model.save")[0]:
            mdPath = self.config.getStringConfig("common.model.directory")[0]
            assert os.path.exists(mdPath), "model save directory does not exist"
            mfPath = self.config.getStringConfig("common.model.file")[0]
            mfPath = os.path.join(mdPath, mfPath)
            with open(mfPath, "w") as fh:
                for cl in self.clabels:
                    fh.write("label:" + cl + "\n")
                    for r in self.stTranPr[cl]:
                        fh.write(",".join(toStrList(r, 6)) + "\n")

    def validate(self):
        """
        validate using model and print the performance score
        """
        # validation has its own saved model switch, separate from prediction
        useSavedModel = self.config.getBooleanConfig("validate.use.saved.model")[0]
        if useSavedModel:
            self.__restoreModel()
        else:
            self.train()

        vdfPath = self.config.getStringConfig("validate.data.file")[0]
        accMetric = self.config.getStringConfig("valid.accuracy.metric")[0]
        yac, ypr = self.__getPrediction(vdfPath, True)
        if isinstance(self.clabels[0], str):
            yac = self.__toIntClabel(yac)
            ypr = self.__toIntClabel(ypr)
        score = perfMetric(accMetric, yac, ypr)
        print(formatFloat(3, score, "perf score"))

    def predict(self):
        """
        predict using model, returning list of predicted class labels
        """
        useSavedModel = self.config.getBooleanConfig("predict.use.saved.model")[0]
        if useSavedModel:
            self.__restoreModel()
        else:
            self.train()

        # predict
        pdfPath = self.config.getStringConfig("predict.data.file")[0]
        _, ypr = self.__getPrediction(pdfPath)
        return ypr

    def __restoreModel(self):
        """
        restore model from the file written by train
        """
        mdPath = self.config.getStringConfig("common.model.directory")[0]
        assert os.path.exists(mdPath), "model save directory does not exist"
        mfPath = self.config.getStringConfig("common.model.file")[0]
        mfPath = os.path.join(mdPath, mfPath)

        rows = None
        cl = None
        for rec in fileRecGen(mfPath):
            if len(rec) == 1:
                # single field record is a label line, it closes the
                # matrix of the previous class, if any
                if rows is not None:
                    self.stTranPr[cl] = np.array(rows)
                cl = rec[0].split(":")[1]
                rows = list()
            else:
                rows.append(asFloatList(rec))

        # matrix of the last class; nothing to store if no label was found
        if cl is not None:
            self.stTranPr[cl] = np.array(rows)

    def __getPrediction(self, fpath, validate=False):
        """
        get predictions, returning tuple of actual and predicted class labels.
        Actual labels are collected only for validation; otherwise each record
        is printed prefixed by its predicted label.

        Parameters
            fpath : data file path
            validate: True if validation
        """
        nc = self.clabels[0]
        pc = self.clabels[1]
        thold = self.config.getFloatConfig("predict.log.odds.threshold")[0]
        klen = self.config.getIntConfig("train.data.key.len")[0]
        # validation records carry the class label right after the key fields
        offset = klen + 1 if validate else klen

        ypr = list()
        yac = list()
        for rec in fileRecGen(fpath):
            lodds = self.__logOdds(rec, offset, pc, nc)
            prc = pc if lodds > thold else nc
            ypr.append(prc)
            if validate:
                yac.append(rec[klen])
            else:
                print(prc + "\t" + ",".join(rec))
        return (yac, ypr)

    def __toIntClabel(self, labels):
        """
        convert string class label to int, the int being the position of
        the label in the configured class label list

        Parameters
            labels : class label values
        """
        return [self.clabels.index(l) for l in labels]