Kody źródłowe/Algorytm Earleya

Algorytm Earleya • Kod źródłowy
Algorytm Earleya
Kod źródłowy
Implementacja algorytmu Earleya w Pythonie.
Wikipedia
Zobacz w Wikipedii hasło Algorytm Earleya
#!/usr/bin/python
 
def parse(tokens, terminals, rules, head, debug=False):
    """Run an Earley chart parse over *tokens* and record back-pointers.

    An item is the 5-tuple ``(lhs, origin, rhs, dot, end)``: the rule
    ``lhs -> rhs`` started at input position ``origin``, has recognised
    ``rhs[:dot]`` and currently ends at position ``end``.

    Args:
        tokens: Sequence of input tokens.
        terminals: Mapping from a pre-terminal symbol to a container of the
            tokens it matches (tested with ``in``).
        rules: Mapping from a non-terminal to an iterable of right-hand
            sides.  Each right-hand side must be a tuple, because items
            (which embed it) are used as dictionary keys.
        head: Start symbol of the grammar.
        debug: When true, print every item of each parse set followed by
            its back-pointers, one item per line.

    Returns:
        A pair ``(nodes, parse_sets)``.  ``nodes`` is the set of
        ``(predecessor, reduction)`` back-pointers stored for the completed
        root item ``(None, 0, (head,), 1, len(tokens))``, or an empty tuple
        when that item was never derived.  ``parse_sets`` is the list of
        parse sets; each one exposes ``d`` (item -> set of back-pointers)
        and ``l`` (items in insertion order).  The list holds one extra,
        empty set past position ``len(tokens)``.

    Raises:
        KeyError: If, while input remains, the symbol after a dot is
            present in neither ``rules`` nor ``terminals``.
    """
    class ParseSet(object):
        """Items ending at one input position."""

        def __init__(self):
            self.d = {}   # item -> set of (predecessor, reduction) pairs
            self.l = []   # items in insertion order; doubles as a worklist

    def format_item(item):
        """Render an item as ``lhs ->origin a b *end c``; fall back to str()."""
        try:
            lhs, h, rhs, dot, i = item
            # rhs is a tuple, so convert the slices before joining them with
            # the list that carries the dot marker.
            symbols = list(rhs[:dot]) + ['*%d' % i] + list(rhs[dot:])
            return '%s ->%d %s' % (lhs, h, ' '.join(symbols))
        except (TypeError, ValueError):
            # Not an item (None, a token, ...): show it verbatim.
            return str(item)

    def add_item(item, precedents):
        """Insert *item* into its parse set, or just record new back-pointers."""
        lhs, h, rhs, dot, i = item
        current_parse_set = parse_sets[i].d
        if item in current_parse_set:
            current_parse_set[item].add(precedents)
            return
        current_parse_set[item] = set([precedents])
        parse_sets[i].l.append(item)
        if dot < len(rhs):
            symb = rhs[dot]
            # Remember who is waiting for `symb` starting at position i.
            active_items.setdefault((symb, i), []).append(item)
            # If `symb` is already known to derive the empty string, step
            # over it right away; the completer may have run before this
            # item existed.
            for nullable_item in nullable.get(symb, ()):
                add_item((lhs, h, rhs, dot + 1, i), (item, nullable_item))

    parse_sets = [ParseSet()]
    active_items = {}   # (symbol, position) -> items with the dot before symbol
    nullable = {}       # symbol -> completed items whose rhs is all-nullable
    add_item((None, 0, (head,), 0, 0), (None, None))
    for i in range(len(tokens) + 1):
        parse_sets.append(ParseSet())
        # parse_sets[i].l grows while it is being iterated; the loop picks up
        # the newly appended items, so it acts as a worklist.
        for item in parse_sets[i].l:
            lhs, h, rhs, dot, _ = item
            if dot == len(rhs):
                # Completer.
                if all(symb in nullable for symb in rhs):
                    nullable.setdefault(lhs, set()).add(item)
                for ci in active_items.get((lhs, h), ()):
                    add_item((ci[0], ci[1], ci[2], ci[3] + 1, i), (ci, item))
            else:
                symb = rhs[dot]
                if symb in rules:
                    # Predictor.
                    for new_rhs in rules[symb]:
                        add_item((symb, i, new_rhs, 0, i), (None, None))
                elif i < len(tokens) and tokens[i] in terminals[symb]:
                    # Scanner.
                    add_item((lhs, h, rhs, dot + 1, i + 1), (item, tokens[i]))
        if debug:
            for item in parse_sets[i].l:
                fields = ['%s:' % format_item(item)]
                for predecessor, reduction in parse_sets[i].d[item]:
                    fields.append('[%s, %s]' % (format_item(predecessor),
                                                format_item(reduction)))
                # One print call with a single string works identically on
                # Python 2 and Python 3.
                print(' '.join(fields))
    last = len(tokens)
    root = (None, 0, (head,), 1, last)
    return parse_sets[last].d.get(root, ()), parse_sets
 
def bracketed(symb, node):
    """Wrap *node* in square brackets, labelled with *symb*: ``[symb node]``."""
    label_and_body = (symb, node)
    return '[%s %s]' % label_and_body
 
def forest(node, parse_sets, textify, current_path=None):
    """Enumerate the textual parse trees reachable from a back-pointer node.

    Args:
        node: A ``(predecessor, reduction)`` pair.  ``predecessor`` is the
            item the dot was advanced from; ``reduction`` is either the
            completed item that justified the advance or the scanned token.
        parse_sets: Sequence indexed by input position; each element
            exposes ``d``, a mapping from item to its set of back-pointer
            nodes.
        textify: Callable ``textify(symbol, content)`` returning the string
            for one tree node.
        current_path: Set of nodes currently being expanded, used to stop
            on cycles.  Defaults to a set holding ``(None, None)``, the
            back-pointer given to predicted items.

    Returns:
        A list of strings, one per derivation.  Each string is the
        concatenation of the renderings of the symbols recognised so far by
        the item this node leads to.  A node already on ``current_path``
        yields ``['']``.
    """
    if current_path is None:
        current_path = set([(None, None)])
    if node in current_path:
        return ['']
    current_path.add(node)
    predecessor, reduction = node
    if isinstance(reduction, tuple):
        # Completed item: render each of its own derivations under its lhs.
        # Items built by parse() are always 5-tuples, so a tuple check
        # separates them from scanned tokens of any other type.
        lhs, end = reduction[0], reduction[4]
        base = [textify(lhs, kids)
                for rnode in parse_sets[end].d[reduction]
                for kids in forest(rnode, parse_sets, textify, current_path)]
    else:
        # Scanned token: label it with the symbol that stood after the dot.
        rhs, dot = predecessor[2], predecessor[3]
        base = [textify(rhs[dot], reduction)]
    result = base
    extended = []
    # Prepend every rendering of what the predecessor had already recognised.
    for pnode in parse_sets[predecessor[4]].d[predecessor]:
        variants = forest(pnode, parse_sets, textify, current_path)
        if variants:
            extended.extend([v + b for v in variants for b in base])
            result = extended
    current_path.remove(node)
    return result
 
if __name__ == '__main__':
    # Demo: parse an ambiguous English sentence and print every bracketed
    # tree found for it.
    terminals = {
        'Det': set(['an']),
        'N': set(['arrow', 'flies', 'time']),
        'V': set(['flies', 'like', 'time']),
        'Prep': set(['like']),
    }
    rules = {
        'S': [('NP', 'VP'), ('VP',)],
        'VP': [('V',), ('V', 'NP'), ('VP', 'PP')],
        'NP': [('N',), ('Det', 'N'), ('NP', 'N'), ('NP', 'PP')],
        'PP': [('Prep', 'NP')],
    }
    tokens = 'time flies like an arrow'.split()
    nodes, parse_sets = parse(tokens, terminals, rules, 'S', debug=True)
    for node in nodes:
        for tree in forest(node, parse_sets, bracketed):
            # A parenthesised single argument prints the same text on
            # Python 2 and Python 3.
            print(tree)