79 lines
2.1 KiB
Python
79 lines
2.1 KiB
Python
# -*- coding: utf-8 -*-
|
|
from __future__ import unicode_literals
|
|
|
|
import sys
|
|
import gzip
|
|
import marshal
|
|
from math import log, exp
|
|
|
|
from ..utils.frequency import AddOneProb
|
|
|
|
|
|
class Bayes(object):
    """Naive Bayes classifier built on per-class word-frequency models.

    Each class label maps to an ``AddOneProb`` instance that accumulates
    word counts for that class. ``classify`` combines the log prior of a
    class with the summed log frequencies of the input words.
    """

    def __init__(self):
        # Maps a class label to its frequency model (an ``AddOneProb``).
        self.d = {}
        # Sum of ``getsum()`` over every class model; refreshed by ``train``.
        self.total = 0

    @staticmethod
    def _versioned_name(fname):
        """Return *fname* with ``'.3'`` appended when running on Python 3.

        Presumably this keeps Python 2 and Python 3 model files apart,
        since marshal data is not guaranteed portable across versions.
        """
        if sys.version_info[0] == 3:
            return fname + '.3'
        return fname

    def save(self, fname, iszip=True):
        """Serialize the model to disk with ``marshal``.

        :param fname: target path; ``'.3'`` is appended on Python 3.
        :param iszip: when true (default) the payload is gzip-compressed,
            otherwise the raw marshal stream is written.
        """
        payload = {
            'total': self.total,
            # Store each model's attribute dict so ``load`` can restore it.
            'd': {label: model.__dict__ for label, model in self.d.items()},
        }
        fname = self._versioned_name(fname)
        # Context managers guarantee the handle is closed even on error.
        if not iszip:
            with open(fname, 'wb') as out:
                marshal.dump(payload, out)
        else:
            with gzip.open(fname, 'wb') as out:
                out.write(marshal.dumps(payload))

    def load(self, fname, iszip=True):
        """Restore a model previously written by ``save``.

        :param fname: source path; ``'.3'`` is appended on Python 3.
        :param iszip: when true (default) the file is read as gzip first,
            falling back to a plain read if that raises ``IOError``;
            when false the file is read as a raw marshal stream.

        NOTE(security): ``marshal`` is not safe against malicious or
        corrupted data; only load model files from a trusted source.
        """
        fname = self._versioned_name(fname)
        if not iszip:
            with open(fname, 'rb') as src:
                payload = marshal.load(src)
        else:
            try:
                with gzip.open(fname, 'rb') as src:
                    payload = marshal.loads(src.read())
            except IOError:
                # Not a gzip stream: retry as an uncompressed file.
                with open(fname, 'rb') as src:
                    payload = marshal.loads(src.read())
        self.total = payload['total']
        self.d = {}
        for label, state in payload['d'].items():
            model = AddOneProb()
            model.__dict__ = state
            self.d[label] = model

    def train(self, data):
        """Accumulate word counts from labelled samples.

        :param data: iterable of samples where ``sample[0]`` is an
            iterable of words and ``sample[1]`` is the class label.
        """
        for sample in data:
            label = sample[1]
            if label not in self.d:
                self.d[label] = AddOneProb()
            model = self.d[label]
            for word in sample[0]:
                model.add(word, 1)
        self.total = sum(model.getsum() for model in self.d.values())

    def classify(self, x):
        """Return the most probable class for the words in *x*.

        :param x: iterable of words to classify.
        :returns: ``(label, probability)`` for the best class, or
            ``(0, 0)`` when no class scores above zero (for example
            when the model has no classes).
        """
        # Log-space score per class: log prior plus summed log frequencies.
        scores = {}
        for label in self.d:
            model = self.d[label]
            score = log(model.getsum()) - log(self.total)
            for word in x:
                score += log(model.freq(word))
            scores[label] = score

        best_label, best_prob = 0, 0
        for label in self.d:
            try:
                # Normalize in log space: 1 / sum(exp(s_other - s_label)).
                denom = 0
                for other in self.d:
                    denom += exp(scores[other] - scores[label])
                prob = 1 / denom
            except OverflowError:
                # Another class dominates so strongly that exp() overflows.
                prob = 0
            if prob > best_prob:
                best_label, best_prob = label, prob
        return (best_label, best_prob)
|