Skip to content

Parsing

PartialIndenter

Bases: Indenter

An Indenter that doesn't reset its state every time process is called.

Source code in outlines/fsm/parsing.py
class PartialIndenter(Indenter):
    """An `Indenter` that doesn't reset its state every time `process` is called.

    `process` hands the stream straight to `_process`, so `paren_level` and
    `indent_level` keep whatever values earlier calls left behind.
    """

    def process(self, stream):
        """Return a generator over the post-processed tokens of `stream`."""
        return self._process(stream)

    def _process(self, stream):
        """Yield the tokens of `stream`, routing newline tokens through `handle_NL`.

        Raises
        ------
        UnexpectedToken
            When a closing-parenthesis token drives `paren_level` below zero.
        """
        for token in stream:
            token_type = token.type

            # The parenthesis depth is updated *before* the token is yielded;
            # doing it afterwards makes the state tracking needlessly
            # convoluted.
            if token_type in self.OPEN_PAREN_types:
                self.paren_level += 1
            elif token_type in self.CLOSE_PAREN_types:
                self.paren_level -= 1
                if self.paren_level < 0:
                    raise UnexpectedToken(token, [])

            if token_type != self.NL_type:
                yield token
            else:
                yield from self.handle_NL(token)

        # TODO: Decide what should happen once the stream is exhausted.  One
        # option is to pop every entry of `indent_level` but the first and
        # yield a `DEDENT_type` token for each; nothing is emitted for now.

    def accepts_token_type(self, token_type):
        """Return `False` only for a closing parenthesis with nothing left to close."""
        closes_nothing = (
            token_type in self.CLOSE_PAREN_types and self.paren_level < 1
        )

        # TODO: A newline token seen while `paren_level` is zero may also need
        # to be rejected under some condition that has not been worked out.

        return not closes_nothing

    def __copy__(self):
        """Build a fresh instance carrying this one's parenthesis and indent state."""
        clone = type(self)()
        clone.indent_level = copy(self.indent_level)
        clone.paren_level = self.paren_level
        return clone

    def __repr__(self):
        return f"{type(self).__name__}(paren_level={self.paren_level!r}, indent_level={self.indent_level!r})"

PartialParserState

Bases: ParserState

Source code in outlines/fsm/parsing.py
class PartialParserState(ParserState):
    """A `ParserState` that handles "partial" tokens and can skip the value stack.

    Tokens whose type is ``"partial"`` are only *checked* against the parse
    table (see `feed_token`); every other token is fed either through
    `ParserState.feed_token` or through the lighter `feed_token_no_stack`,
    depending on `use_value_stack`.
    """

    # Flag selecting between `ParserState.feed_token` (true) and
    # `feed_token_no_stack` (false) for non-partial tokens.
    __slots__ = "use_value_stack"

    def __init__(
        self,
        parse_conf,
        lexer,
        state_stack=None,
        value_stack=None,
        use_value_stack=False,
    ):
        """Initialize the base `ParserState` and record `use_value_stack`."""
        super().__init__(
            parse_conf, lexer, state_stack=state_stack, value_stack=value_stack
        )
        self.use_value_stack = use_value_stack

    def feed_token(self, token, is_end=False):
        """Feed `token` to the parser, or validate it when it is a partial token.

        For a token of type ``"partial"``, each candidate terminal listed in
        ``token.value.terminals_and_info`` is trial-fed as an empty token and
        the state stack is restored afterwards, so the stack contents are
        unchanged when this branch finishes.  Other tokens advance the parser
        through `ParserState.feed_token` when `use_value_stack` is set and
        through `feed_token_no_stack` otherwise.

        Raises
        ------
        UnexpectedToken
            If the token is partial and none of its candidate terminals can
            be fed from the current state, or if feeding a regular token
            fails.
        """
        if token.type == "partial":
            # If none of the potential terminals can transition, we need to know now
            current_state = self.state_stack[-1]
            current_lexer = get_contextual_lexer(self.lexer).lexers[current_state]

            # We have to feed the token and determine whether or not at least
            # one terminal is consistent with the stack; otherwise, we'll miss
            # invalid REDUCE cases.
            # TODO: We should track separate parses conditional on possible
            # token/symbol types, then we can coherently reuse the following
            # results instead of recomputing it later.
            can_transition = False
            for terminal_info in token.value.terminals_and_info:
                if terminal_info.terminal_name not in current_lexer.ignore_types:
                    # An empty stand-in token of the candidate type that
                    # borrows the partial token's position information.
                    test_token = Token.new_borrow_pos(
                        terminal_info.terminal_name, "", token
                    )

                    # `feed_token_no_stack` mutates the stack in place, so
                    # keep a snapshot and put it back whether or not the
                    # trial succeeds.
                    stack = copy(self.state_stack)
                    try:
                        self.feed_token_no_stack(test_token, is_end=is_end)
                        can_transition = True
                        break
                    except UnexpectedToken:
                        continue
                    finally:
                        self.state_stack = stack
                else:
                    # Terminal types ignored by the current lexer are treated
                    # as always acceptable.
                    can_transition = True

            if not can_transition:
                # Upper-case keys of the parse-table row are reported as the
                # expected symbols.
                expected = {
                    s
                    for s in self.parse_conf.states[current_state].keys()
                    if s.isupper()
                }
                raise UnexpectedToken(
                    token, expected, state=self, interactive_parser=None
                )

        elif self.use_value_stack:
            super().feed_token(token, is_end=is_end)
        else:
            self.feed_token_no_stack(token, is_end=is_end)

    def feed_token_no_stack(self, token, is_end=False):
        """
        This is a copy of `ParserState.feed_token` with all the value stack
        steps removed.  Since we're not exactly parsing in order to obtain a
        CST or anything similar, we can avoid the growing expense of tracking
        the parse tree.

        The state stack is modified in place: reductions are applied until
        `token` can be shifted, or, when `is_end` is true, until the end
        state sits on top of the stack.  `UnexpectedToken` is raised when the
        current state has no entry for ``token.type``.
        """
        state_stack = self.state_stack
        states = self.parse_conf.states
        end_state = self.parse_conf.end_state

        while True:
            state = state_stack[-1]
            try:
                action, arg = states[state][token.type]
            except KeyError:
                expected = {s for s in states[state].keys() if s.isupper()}
                raise UnexpectedToken(
                    token, expected, state=self, interactive_parser=None
                )

            assert arg != end_state

            if action is Shift:
                # shift once and return
                assert not is_end
                state_stack.append(arg)
                return
            else:
                # reduce+shift as many times as necessary
                rule = arg
                size = len(rule.expansion)
                if size:
                    # Pop one state per symbol of the rule's expansion.
                    del state_stack[-size:]

                # Follow the table entry for the rule's origin from the
                # state that is now exposed on top of the stack.
                _action, new_state = states[state_stack[-1]][rule.origin.name]
                assert _action is Shift
                state_stack.append(new_state)

                if is_end and state_stack[-1] == end_state:
                    return

    def feed_eof(self):
        """Feed an ``$END`` token, unless the input stops inside an unfinished token.

        End-of-input is legal when no token has been lexed yet, when the last
        token is not a partial token, or when at least one candidate terminal
        of the last partial token is marked `is_final`.

        Raises
        ------
        UnexpectedToken
            If the last token is partial and none of its candidates is final,
            or if feeding ``$END`` fails.
        """
        last_token = self.lexer.state.last_token

        if last_token is None:
            eof_token = self.lexer._Token("$END", "", 0, 1, 1)
        else:
            eof_token = Token.new_borrow_pos("$END", "", last_token)

        new_token_is_legal = (
            last_token is None
            or last_token.type != "partial"
            or any(ti.is_final for ti in last_token.value.terminals_and_info)
        )
        if new_token_is_legal:
            self.feed_token(eof_token, is_end=True)
        else:
            raise UnexpectedToken(eof_token, [], state=self, interactive_parser=None)

    def choices(self):
        """Return the parse-table row for `self.position`.

        NOTE(review): `position` is not defined in this class; presumably it
        is provided by `ParserState` and names the current parse state.
        """
        return self.parse_conf.parse_table.states[self.position]

    def accepts(self):
        """
        Adapted from https://github.com/lark-parser/lark/blob/be542c2ff6d968817df019b8bf03f37b3111c08c/lark/parsers/lalr_interactive_parser.py#L95
        Returns the set of possible tokens that will advance the parser into a new valid state.

        Each upper-case symbol offered by `choices` is fed, as an empty
        token, to a copy of this state; the symbols whose feed does not raise
        `UnexpectedToken` are returned.
        """
        accepts = set()
        conf_no_callbacks = copy(self.parse_conf)
        # We don't want to call callbacks here since those might have arbitrary side effects
        # and are unnecessarily slow.
        conf_no_callbacks.callbacks = {}
        for t in self.choices():
            if t.isupper():  # is terminal?
                # Probe on a copy so this state is left as it was.
                new_state = copy(self)
                new_state.parse_conf = conf_no_callbacks
                try:
                    new_state.feed_token(new_state.lexer._Token(t, ""))
                except UnexpectedToken:
                    pass
                else:
                    accepts.add(t)
        return accepts

    def __copy__(self):
        """Return a copy that shares `parse_conf` and owns its lexer and stacks.

        The lexer and state stack are copied with `copy`, the value stack
        with `deepcopy`.
        """
        return type(self)(
            self.parse_conf,
            copy(self.lexer),
            copy(self.state_stack),
            deepcopy(self.value_stack),
            use_value_stack=self.use_value_stack,
        )

    def __repr__(self):
        return f"{type(self).__name__}(lexer={self.lexer!r}, state_stack={self.state_stack!r})"

accepts()

Adapted from https://github.com/lark-parser/lark/blob/be542c2ff6d968817df019b8bf03f37b3111c08c/lark/parsers/lalr_interactive_parser.py#L95.

Returns the set of possible tokens that will advance the parser into a new valid state.

Source code in outlines/fsm/parsing.py
def accepts(self):
    """
    Adapted from https://github.com/lark-parser/lark/blob/be542c2ff6d968817df019b8bf03f37b3111c08c/lark/parsers/lalr_interactive_parser.py#L95
    Returns the set of possible tokens that will advance the parser into a new valid state.

    Every upper-case symbol offered by `choices` is fed, as an empty token,
    to a copy of this state; the ones that do not raise `UnexpectedToken`
    are collected.
    """
    # Callbacks may have arbitrary side effects and are needlessly slow, so
    # the probes run against a configuration copy that has none.
    silent_conf = copy(self.parse_conf)
    silent_conf.callbacks = {}

    accepted = set()
    for symbol in self.choices():
        if not symbol.isupper():
            # Not a terminal.
            continue

        probe = copy(self)
        probe.parse_conf = silent_conf
        try:
            probe.feed_token(probe.lexer._Token(symbol, ""))
        except UnexpectedToken:
            continue
        accepted.add(symbol)

    return accepted

feed_token_no_stack(token, is_end=False)

This is a copy of ParserState.feed_token with all the value stack steps removed. Since we're not exactly parsing in order to obtain a CST or anything similar, we can avoid the growing expense of tracking the parse tree.

Source code in outlines/fsm/parsing.py
def feed_token_no_stack(self, token, is_end=False):
    """
    This is a copy of `ParserState.feed_token` with all the value stack
    steps removed.  Since we're not exactly parsing in order to obtain a
    CST or anything similar, we can avoid the growing expense of tracking
    the parse tree.

    The state stack is updated in place.  `UnexpectedToken` is raised when
    the state on top of the stack has no entry for ``token.type``.
    """
    stack = self.state_stack
    table = self.parse_conf.states
    final_state = self.parse_conf.end_state

    while True:
        top = stack[-1]
        try:
            action, arg = table[top][token.type]
        except KeyError:
            # Upper-case keys of the row are reported as the expected symbols.
            expected = {name for name in table[top].keys() if name.isupper()}
            raise UnexpectedToken(
                token, expected, state=self, interactive_parser=None
            )

        assert arg != final_state

        if action is Shift:
            # A shift consumes the token, so it happens once and we are done.
            assert not is_end
            stack.append(arg)
            return

        # Otherwise reduce: drop one state per symbol in the rule's
        # expansion, then follow the entry for the rule's origin.  This
        # repeats until the token can be shifted.
        rule = arg
        popped = len(rule.expansion)
        if popped:
            del stack[-popped:]

        _action, goto_state = table[stack[-1]][rule.origin.name]
        assert _action is Shift
        stack.append(goto_state)

        if is_end and stack[-1] == final_state:
            return

PartialParsingFrontend

Bases: ParsingFrontend

Source code in outlines/fsm/parsing.py
class PartialParsingFrontend(ParsingFrontend):
    """A `ParsingFrontend` wired up with the partial lexer and parser classes.

    It also lazily derives lookup tables from the parse table and from the
    contextual lexer (see the properties at the bottom of the class).
    """

    def __init__(self, lexer_conf, parser_conf, options, parser=None):
        """Register the partial plugin classes, then build the base frontend.

        Only the ``"lalr"`` parser type is accepted.
        """
        assert parser_conf.parser_type == "lalr"

        for plugin_name, plugin_cls in (
            ("LALR_Parser", PartialLALRParser),
            ("BasicLexer", PartialBasicLexer),
            ("ContextualLexer", PartialContextualLexer),
            ("LexerThread", PartialLexerThread),
        ):
            options._plugins[plugin_name] = plugin_cls

        super().__init__(lexer_conf, parser_conf, options, parser=parser)

        postlex = lexer_conf.postlex
        if postlex:
            self.lexer = PartialPostLexConnector(self.lexer.lexer, postlex)

        # Caches filled on first access through the matching properties.
        self._termset_fsm_info = None
        self._symbols_to_states: Optional[
            Dict[str, Set[Tuple[ParseStateType, Action]]]
        ] = None
        self._reverse_shifts: Optional[
            Dict[ParseStateType, Dict[str, Set[ParseStateType]]]
        ] = None
        # TODO: A `_state_transition_map` cache, typed
        # `Optional[Dict[Tuple[ParseStateType, str], Set[ParseStateType]]]`,
        # was sketched here but is not populated (see `_compute_maps`).

    def _compute_maps(
        self,
    ):
        """Compute state transition and symbols-to-states maps."""
        reverse_shifts = self._reverse_shifts = {}
        symbols_to_states = self._symbols_to_states = {}

        parse_table = self.parser.parser.parse_table

        for source_state, transitions in parse_table.states.items():
            for symbol, op in transitions.items():
                if op[0] == Shift:
                    # Index shifts by their *target* state, then by symbol.
                    reverse_shifts.setdefault(op[1], {}).setdefault(
                        symbol, set()
                    ).add(source_state)
                symbols_to_states.setdefault(symbol, set()).add((source_state, op))

        # TODO: An abandoned (and very wasteful) idea was to also build a
        # `_state_transition_map` keyed by `(from_state, symbol)` for terminal
        # symbols only: shift targets would be stored directly, and for other
        # actions the last state of every sequence returned by
        # `parse_to_terminal(self, [(from_state,)], symbol)` would be stored.

    def _compute_termset_fsm_info(self):
        """Collect and return information about terminal symbol sets and their FSMs.

        Terminal symbol sets (or "termsets") are ordered sequences of terminal
        symbols that are used by each parser state.  Associated with each is a
        collection of FSMs for each terminal and a single parse state FSM that is
        the union of each terminal's FSM.

        This constructs a list of tuples containing the termset, the set of
        parse states that use the termsets, parse state FSMs, and information
        mapping the components of the parse state FSMs to their terminal symbol
        FSMs.

        """
        context_lexer = get_contextual_lexer(self)

        fsms_by_termset = {}
        parse_states_by_termset: Dict[Tuple[str, ...], Set[ParseStateType]] = {}
        for parse_state, lexer in context_lexer.lexers.items():
            scanner = lexer.scanner
            termset = tuple(term.name for term in scanner.terminals)
            fsms_by_termset[termset] = (scanner.fsm, scanner.fsms_to_trans_finals)
            parse_states_by_termset.setdefault(termset, set()).add(parse_state)

        info = []
        for termset, (fsm, fsms_to_trans_finals) in fsms_by_termset.items():
            info.append(
                (
                    termset,
                    frozenset(parse_states_by_termset[termset]),
                    fsm,
                    fsms_to_trans_finals,
                )
            )
        self._termset_fsm_info = info

    @property
    def termset_fsm_info(self):
        """The termset/FSM information, computed on first access."""
        if self._termset_fsm_info is None:
            self._compute_termset_fsm_info()
        return self._termset_fsm_info

    @property
    def symbols_to_states(self):
        """Map of each symbol to its `(from_state, op)` pairs, computed on first access."""
        if self._symbols_to_states is None:
            self._compute_maps()
        return self._symbols_to_states

    @property
    def reverse_shifts(self):
        """Map of shift target state to symbol to source states, computed on first access."""
        if self._reverse_shifts is None:
            self._compute_maps()
        return self._reverse_shifts

PartialScanner

Bases: Scanner

Source code in outlines/fsm/parsing.py
class PartialScanner(Scanner):
    """A `Scanner` that matches text by walking a union FSM of its terminals."""

    @classmethod
    @lru_cache
    def construct_terminal_fsm(cls, terminal):
        """Return the deterministic FSM for `terminal` and the pattern's `prefix_postfix`."""
        # TODO: This should really be done at the lexer/parser level so that
        # the lifetime of these objects is tied to the parser itself.
        pattern = interegular.parse_pattern(terminal.pattern.to_regexp())
        deterministic_fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce())
        return deterministic_fsm, pattern.prefix_postfix

    def __init__(self, terminals, g_regex_flags, re_, use_bytes, match_whole=False):
        """Store the scanner settings and build the union FSM over `terminals`.

        `re_` is accepted but not used here.
        """
        self.terminals = terminals
        self.g_regex_flags = g_regex_flags
        self.use_bytes = use_bytes
        self.match_whole = match_whole
        self.allowed_types = {terminal.name for terminal in self.terminals}
        self._mres = None

        terminal_fsms = []
        for terminal in self.terminals:
            terminal_fsm, prefix_postfix = self.construct_terminal_fsm(terminal)

            # TODO FIXME: We don't support this right now.
            assert prefix_postfix == (0, 0)

            terminal_fsms.append(terminal_fsm)

        self.fsm, self.fsms_to_trans_finals = fsm_union(terminal_fsms)

    def get_terminals_info(
        self, fsm_state_seq
    ) -> Tuple[Tuple[PartialTerminalInfo, ...], Tuple[PartialTerminalInfo, ...]]:
        """Get the possible terminal symbols for an FSM state sequence.

        Returns two tuples of `PartialTerminalInfo`: every candidate terminal,
        and the subset whose sub-FSM ends in a final state.
        """
        candidates = []
        finished = []
        sub_fsms = get_sub_fsms_from_seq(fsm_state_seq, self.fsms_to_trans_finals)
        for priority, (fsm_id, fsm_reads_more, in_final) in enumerate(sub_fsms):
            info = PartialTerminalInfo(
                priority, self.terminals[fsm_id].name, fsm_reads_more, in_final
            )
            candidates.append(info)
            if in_final:
                finished.append(info)

        return tuple(candidates), tuple(finished)

    def match(self, text, pos, last_fsm_state_seq: Optional[Tuple[int, ...]] = None):
        """Determine an FSM match over `text` starting at `pos` and continuing `last_fsm_state_seq`.

        Returns `None` when the FSM walk yields no states; otherwise the full
        state sequence, beginning with `last_fsm_state_seq` when one was given
        and with the FSM's initial state when not.
        """
        if last_fsm_state_seq:
            assert len(last_fsm_state_seq) > 1
            # Resume `len(last_fsm_state_seq) - 1` characters after `pos`,
            # from the state the previous walk stopped in.
            resume_pos = pos + len(last_fsm_state_seq) - 1
            start_state = last_fsm_state_seq[-1]
            prefix = last_fsm_state_seq
        else:
            resume_pos = pos
            start_state = self.fsm.initial
            prefix = (start_state,)

        remaining_text = text[resume_pos:]

        fsm_info = self.fsm.fsm_info
        transition_keys = get_token_transition_keys(
            fsm_info.alphabet_symbol_mapping,
            fsm_info.alphabet_anything_value,
            remaining_text,
        )

        state_seq = walk_fsm(
            self.fsm,
            transition_keys,
            start_state,
            full_match=self.match_whole,
        )

        if not state_seq:
            return None

        return prefix + tuple(state_seq)

get_terminals_info(fsm_state_seq)

Get the possible terminal symbols for an FSM state sequence.

Source code in outlines/fsm/parsing.py
def get_terminals_info(
    self, fsm_state_seq
) -> Tuple[Tuple[PartialTerminalInfo, ...], Tuple[PartialTerminalInfo, ...]]:
    """Get the possible terminal symbols for an FSM state sequence.

    Returns two tuples of `PartialTerminalInfo`: every candidate terminal,
    and the subset whose sub-FSM ends in a final state.
    """
    candidates = []
    finished = []
    sub_fsms = get_sub_fsms_from_seq(fsm_state_seq, self.fsms_to_trans_finals)
    for priority, (fsm_id, fsm_reads_more, in_final) in enumerate(sub_fsms):
        info = PartialTerminalInfo(
            priority, self.terminals[fsm_id].name, fsm_reads_more, in_final
        )
        candidates.append(info)
        if in_final:
            finished.append(info)

    return tuple(candidates), tuple(finished)

match(text, pos, last_fsm_state_seq=None)

Determine an FSM match over text starting at pos and continuing last_fsm_state_seq.

Source code in outlines/fsm/parsing.py
def match(self, text, pos, last_fsm_state_seq: Optional[Tuple[int, ...]] = None):
    """Determine an FSM match over `text` starting at `pos` and continuing `last_fsm_state_seq`.

    Returns `None` when the FSM walk yields no states; otherwise the full
    state sequence, beginning with `last_fsm_state_seq` when one was given
    and with the FSM's initial state when not.
    """
    if last_fsm_state_seq:
        assert len(last_fsm_state_seq) > 1
        # Resume `len(last_fsm_state_seq) - 1` characters after `pos`, from
        # the state the previous walk stopped in.
        resume_pos = pos + len(last_fsm_state_seq) - 1
        start_state = last_fsm_state_seq[-1]
        prefix = last_fsm_state_seq
    else:
        resume_pos = pos
        start_state = self.fsm.initial
        prefix = (start_state,)

    remaining_text = text[resume_pos:]

    fsm_info = self.fsm.fsm_info
    transition_keys = get_token_transition_keys(
        fsm_info.alphabet_symbol_mapping,
        fsm_info.alphabet_anything_value,
        remaining_text,
    )

    state_seq = walk_fsm(
        self.fsm,
        transition_keys,
        start_state,
        full_match=self.match_whole,
    )

    if not state_seq:
        return None

    return prefix + tuple(state_seq)

fsm_union(fsms)

Construct an FSM representing the union of the FSMs in fsms.

This is an updated version of interegular.fsm.FSM.union made to return an extra map of component FSMs to the sets of state transitions that correspond to them in the new FSM.

Source code in outlines/fsm/parsing.py
def fsm_union(
    fsms: Sequence[FSM],
) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]:
    """Construct an FSM representing the union of the FSMs in `fsms`.

    This is an updated version of `interegular.fsm.FSM.union` made to return an
    extra map of component FSMs to the sets of state transitions that
    correspond to them in the new FSM.

    The extra map is keyed by each component's index in `fsms`.  Its values
    hold the component's transitions and final states expressed in states of
    the returned FSM, plus a map from the component's own states to the
    returned FSM's states.  A component that never takes part in a transition
    has no entry.

    """

    alphabet, new_to_old = Alphabet.union(*[component.alphabet for component in fsms])

    indexed_fsms = tuple(enumerate(fsms))

    def follow(current_state, new_transition: int):
        """Return the aggregate state reached from `current_state` via `new_transition`.

        Raises `OblivionError` when no component FSM can take the transition.
        """
        successor = {}
        for idx, component in indexed_fsms:
            old_transition = new_to_old[idx][new_transition]
            if idx not in current_state:
                continue
            source = current_state[idx]
            if source in component.map and old_transition in component.map[source]:
                successor[idx] = component.map[source][old_transition]
        if not successor:
            raise OblivionError
        return successor

    # An aggregate state is a dict of {component index: component state};
    # its position in `states` is its id in the aggregate FSM.
    states = [{idx: component.initial for idx, component in indexed_fsms}]
    finals: Set[int] = set()
    transition_map: Dict[int, Dict[int, int]] = {}

    # Map component FSMs to their new state-to-state transitions, finals, and a
    # map translating component FSM states to aggregate FSM states
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ] = {}

    # `states` grows while it is being scanned, hence the explicit cursor.
    state_id = 0
    while state_id < len(states):
        state = states[state_id]

        # The aggregate state is final as soon as any component is in one of
        # its own final states.
        if any(
            state.get(idx, -1) in component.finals for idx, component in indexed_fsms
        ):
            finals.add(state_id)

        outgoing = transition_map[state_id] = {}
        for transition in alphabet.by_transition:
            try:
                successor = follow(state, transition)
            except OblivionError:
                # Reached an oblivion state; don't list it
                continue

            try:
                # TODO: Seems like this could--and should--be avoided
                target_id = states.index(successor)
            except ValueError:
                target_id = len(states)
                states.append(successor)

            outgoing[transition] = target_id

            for fsm_id, fsm_state in successor.items():
                (
                    fsm_transitions,
                    fsm_finals,
                    fsm_old_to_new,
                ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {}))
                fsm_old_to_new.setdefault(state[fsm_id], set()).add(state_id)
                fsm_old_to_new.setdefault(fsm_state, set()).add(target_id)
                fsm_transitions.add((state_id, target_id))
                if fsm_state in fsms[fsm_id].finals:
                    fsm_finals.add(target_id)

        state_id += 1

    union_fsm = FSM(
        alphabet=alphabet,
        states=range(len(states)),
        initial=0,
        finals=finals,
        map=transition_map,
        __no_validation__=True,
    )

    union_fsm, old_to_new_states = make_deterministic_fsm(union_fsm)

    # Re-express the per-component bookkeeping in the state ids handed back by
    # `make_deterministic_fsm`, visiting components in index order.
    remapped = {}
    for fsm_id in sorted(fsms_to_trans_finals):
        transitions, component_finals, old_to_new = fsms_to_trans_finals[fsm_id]
        remapped[fsm_id] = (
            {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions},
            {old_to_new_states[s] for s in component_finals},
            {
                old_state: {old_to_new_states[s] for s in new_states}
                for old_state, new_states in old_to_new.items()
            },
        )

    return union_fsm, remapped

get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)

Get the indices of the sub-FSMs in fsm that could have matched the state sequence state_seq.

Parameters:

Name Type Description Default
state_seq Sequence[int]

A state sequence.

required
fsms_to_trans_finals Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]

A map from FSM indices to tuples containing the set of their state transitions, the set of their final/accept states, and a map from each component FSM state to the corresponding states of the union FSM.

required

Returns:

Type Description
A generator returning tuples containing each sub-FSM index (in the order
they were union-ed to construct `fsm`) and booleans indicating whether or
not there is another valid transition from the last state in the sequence
for the associated sub-FSM (i.e. if the FSM can continue
accepting/matching) and whether or not the sequence ends in a final state
of the sub-FSM.
Source code in outlines/fsm/parsing.py
def get_sub_fsms_from_seq(
    state_seq: Sequence[int],
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ],
) -> Generator[Tuple[int, bool, bool], None, None]:
    """Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`.

    Parameters
    ----------
    state_seq
        A state sequence.
    fsms_to_trans_finals
        A map from FSM indices to tuples whose first item is the set of the
        sub-FSM's state transitions and whose second item is the set of its
        final/accept states.  The third item is not used here.

    Returns
    -------
    A generator returning tuples containing each sub-FSM index (in the
    iteration order of `fsms_to_trans_finals`) and booleans indicating whether
    or not there is another valid transition from the last state in the
    sequence for the associated sub-FSM (i.e. if the FSM can continue
    accepting/matching) and whether or not the sequence ends in a final state
    of the sub-FSM.
    """
    # Consecutive state pairs are the transitions the sequence went through.
    observed = set(zip(state_seq[:-1], state_seq[1:]))
    last_state = state_seq[-1]

    for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items():
        # A sub-FSM is a candidate only if it owns every observed transition.
        if not observed.issubset(transitions):
            continue

        # Is there another possible transition in this sub-FSM?
        can_continue = any(source == last_state for source, _target in transitions)
        # Is this sub-FSM in a final state?
        ends_in_final = last_state in finals

        yield (fsm_idx, can_continue, ends_in_final)

terminals_to_fsms(lp)

Construct a dict mapping terminal symbol names to their finite state machines.

Source code in outlines/fsm/parsing.py
def terminals_to_fsms(lp: PartialLark) -> Dict[str, FSM]:
    """Construct a ``dict`` mapping terminal symbol names to their finite state machines.

    A terminal whose FSM construction raises `Unsupported` is mapped to
    ``None``.
    """

    fsms_by_name = {}
    for terminal in lp.terminals:
        regex_str = terminal.pattern.to_regexp()
        pattern = interegular.parse_pattern(regex_str)
        # TODO: Use `pyparser.terminals[0].pattern.flags`?
        try:
            terminal_fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce())
        except Unsupported:
            terminal_fsm = None

        fsms_by_name[terminal.name] = terminal_fsm

    return fsms_by_name