126 lines
4.3 KiB
Python
126 lines
4.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
from __future__ import unicode_literals
|
|
|
|
import sys
|
|
import gzip
|
|
import marshal
|
|
from math import log
|
|
|
|
from ..utils import frequency
|
|
|
|
|
|
class CharacterBasedGenerativeModel(object):
|
|
|
|
def __init__(self):
|
|
self.l1 = 0.0
|
|
self.l2 = 0.0
|
|
self.l3 = 0.0
|
|
self.status = ('b', 'm', 'e', 's')
|
|
self.uni = frequency.NormalProb()
|
|
self.bi = frequency.NormalProb()
|
|
self.tri = frequency.NormalProb()
|
|
|
|
def save(self, fname, iszip=True):
|
|
d = {}
|
|
for k, v in self.__dict__.items():
|
|
if hasattr(v, '__dict__'):
|
|
d[k] = v.__dict__
|
|
else:
|
|
d[k] = v
|
|
if sys.version_info[0] == 3:
|
|
fname = fname + '.3'
|
|
if not iszip:
|
|
marshal.dump(d, open(fname, 'wb'))
|
|
else:
|
|
f = gzip.open(fname, 'wb')
|
|
f.write(marshal.dumps(d))
|
|
f.close()
|
|
|
|
def load(self, fname, iszip=True):
|
|
if sys.version_info[0] == 3:
|
|
fname = fname + '.3'
|
|
if not iszip:
|
|
d = marshal.load(open(fname, 'rb'))
|
|
else:
|
|
try:
|
|
f = gzip.open(fname, 'rb')
|
|
d = marshal.loads(f.read())
|
|
except IOError:
|
|
f = open(fname, 'rb')
|
|
d = marshal.loads(f.read())
|
|
f.close()
|
|
for k, v in d.items():
|
|
if hasattr(self.__dict__[k], '__dict__'):
|
|
self.__dict__[k].__dict__ = v
|
|
else:
|
|
self.__dict__[k] = v
|
|
|
|
def div(self, v1, v2):
|
|
if v2 == 0:
|
|
return 0
|
|
return float(v1)/v2
|
|
|
|
def train(self, data):
|
|
for sentence in data:
|
|
now = [('', 'BOS'), ('', 'BOS')]
|
|
self.bi.add((('', 'BOS'), ('', 'BOS')), 1)
|
|
self.uni.add(('', 'BOS'), 2)
|
|
for word, tag in sentence:
|
|
now.append((word, tag))
|
|
self.uni.add((word, tag), 1)
|
|
self.bi.add(tuple(now[1:]), 1)
|
|
self.tri.add(tuple(now), 1)
|
|
now.pop(0)
|
|
tl1 = 0.0
|
|
tl2 = 0.0
|
|
tl3 = 0.0
|
|
samples = sorted(self.tri.samples(), key=lambda x: self.tri.get(x)[1])
|
|
for now in samples:
|
|
c3 = self.div(self.tri.get(now)[1]-1, self.bi.get(now[:2])[1]-1)
|
|
c2 = self.div(self.bi.get(now[1:])[1]-1, self.uni.get(now[1])[1]-1)
|
|
c1 = self.div(self.uni.get(now[2])[1]-1, self.uni.getsum()-1)
|
|
if c3 >= c1 and c3 >= c2:
|
|
tl3 += self.tri.get(now)[1]
|
|
elif c2 >= c1 and c2 >= c3:
|
|
tl2 += self.tri.get(now)[1]
|
|
elif c1 >= c2 and c1 >= c3:
|
|
tl1 += self.tri.get(now)[1]
|
|
self.l1 = self.div(tl1, tl1+tl2+tl3)
|
|
self.l2 = self.div(tl2, tl1+tl2+tl3)
|
|
self.l3 = self.div(tl3, tl1+tl2+tl3)
|
|
|
|
def log_prob(self, s1, s2, s3):
|
|
uni = self.l1*self.uni.freq(s3)
|
|
bi = self.div(self.l2*self.bi.get((s2, s3))[1], self.uni.get(s2)[1])
|
|
tri = self.div(self.l3*self.tri.get((s1, s2, s3))[1],
|
|
self.bi.get((s1, s2))[1])
|
|
if uni+bi+tri == 0:
|
|
return float('-inf')
|
|
return log(uni+bi+tri)
|
|
|
|
def tag(self, data):
|
|
now = [((('', 'BOS'), ('', 'BOS')), 0.0, [])]
|
|
for w in data:
|
|
stage = {}
|
|
not_found = True
|
|
for s in self.status:
|
|
if self.uni.freq((w, s)) != 0:
|
|
not_found = False
|
|
break
|
|
if not_found:
|
|
for s in self.status:
|
|
for pre in now:
|
|
stage[(pre[0][1], (w, s))] = (pre[1], pre[2]+[s])
|
|
now = list(map(lambda x: (x[0], x[1][0], x[1][1]),
|
|
stage.items()))
|
|
continue
|
|
for s in self.status:
|
|
for pre in now:
|
|
p = pre[1]+self.log_prob(pre[0][0], pre[0][1], (w, s))
|
|
if (not (pre[0][1],
|
|
(w, s)) in stage) or p > stage[(pre[0][1],
|
|
(w, s))][0]:
|
|
stage[(pre[0][1], (w, s))] = (p, pre[2]+[s])
|
|
now = list(map(lambda x: (x[0], x[1][0], x[1][1]), stage.items()))
|
|
return zip(data, max(now, key=lambda x: x[1])[2])
|