| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import random |
| | import re |
| | import string |
| | import ast |
| | import json |
| | from collections.abc import Sequence |
| | from typing import Union, Tuple, List, Optional |
| |
|
| | from vllm.entrypoints.openai.protocol import ( |
| | ChatCompletionRequest, |
| | DeltaMessage, |
| | DeltaFunctionCall, |
| | DeltaToolCall, |
| | ExtractedToolCallInformation, |
| | ToolCall, |
| | FunctionCall, |
| | ) |
| | from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( |
| | ToolParser |
| | ) |
| | from vllm.logger import init_logger |
| |
|
| | import pyjson5 |
| |
|
class InvalidToolCallIDError(ValueError, AssertionError):
    """Raised when a ToolCallID fails validation.

    Inherits from ``AssertionError`` as well as ``ValueError`` so callers that
    caught the ``AssertionError`` raised by the former ``assert``-based
    validation keep working, while the check itself no longer disappears
    when Python runs with ``-O``.
    """


class ToolCallID:
    """A short identifier for a tool call.

    A valid ID is exactly ``_LENGTH`` characters long and is drawn from
    lowercase ASCII letters and digits (``_ALPHABET``). Subclasses may
    override ``_LENGTH``; generation and validation both honor it.
    """

    _LENGTH = 10
    _ALPHABET = string.ascii_lowercase + string.digits
    # Draw from OS entropy so generated IDs neither depend on nor advance the
    # global ``random`` module state, which other code may seed.
    # (Defined before the ``random`` classmethod, so ``random`` here is still
    # the stdlib module.)
    _RNG = random.SystemRandom()

    def __init__(self, id_val: str, validation: bool = False):
        """Wrap ``id_val``; check its format first when ``validation`` is true.

        Raises:
            InvalidToolCallIDError: if ``validation`` is true and ``id_val``
                is not a valid ID.
        """
        self._id = id_val
        if validation:
            self._validate()

    @classmethod
    def random(cls, validation: bool = False) -> 'ToolCallID':
        """Return a new ID of ``cls._LENGTH`` random characters from ``cls._ALPHABET``."""
        chars = ''.join(cls._RNG.choice(cls._ALPHABET) for _ in range(cls._LENGTH))
        return cls(chars, validation=validation)

    def _validate(self) -> None:
        """Raise InvalidToolCallIDError unless the ID has the expected form."""
        if len(self._id) != self._LENGTH or any(c not in self._ALPHABET for c in self._id):
            raise InvalidToolCallIDError(
                f"invalid tool call id {self._id!r}: expected {self._LENGTH} "
                "characters from [a-z0-9]"
            )

    def to_string(self) -> str:
        """Return the raw ID string."""
        return self._id

    def __str__(self) -> str:
        return self.to_string()

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._id!r})"
| |
|
| |
|
# Module-level logger named after this module.
# NOTE(review): not referenced by the code visible in this file; keep only if used elsewhere.
logger = init_logger(__name__)
| |
|
| |
|
class SolarOpenToolParser(ToolParser):
    """Tool-call parser for Solar Open style markup.

    Each tool call is serialized as::

        <|tool_call:begin|>ID<|tool_call:name|>NAME<|tool_call:args|>ARGS<|tool_call:end|>

    Calls are only recognized before the first ``<|calls|>`` marker (when one
    is present). Plain assistant content is the text preceding the first
    special marker; the streaming and non-streaming paths use the same rule.

    NOTE(review): parsing requires the markers to survive detokenization. If
    they are tokenizer special tokens, requests presumably need
    ``skip_special_tokens=False``; confirm against the deployment config.
    """

    _BEGIN_TAG = "<|tool_call:begin|>"
    _NAME_TAG = "<|tool_call:name|>"
    _ARGS_TAG = "<|tool_call:args|>"
    _END_TAG = "<|tool_call:end|>"
    _SECTION_END_TAG = "<|calls|>"

    # Any of these markers ends plain assistant content.
    _SPECIAL_MARKERS = (
        "<|flush|>",
        "<|end|>",
        "<|begin|>",
        "<|tool_calls|>",
        _BEGIN_TAG,
        _NAME_TAG,
        _ARGS_TAG,
        _END_TAG,
        _SECTION_END_TAG,
    )

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse a complete model output into content and tool calls.

        Empty content is reported as None.
        """
        content, tool_calls = self._parse_text(model_output)
        return ExtractedToolCallInformation(
            tools_called=len(tool_calls) > 0,
            tool_calls=tool_calls,
            content=content if content else None,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Convert one streaming step into a DeltaMessage.

        * Until a special marker appears, delta text is forwarded as content.
          In the delta that introduces the first marker, the text before
          that marker is still sent as content.
        * A tool call's header (id, name) is sent once its args marker is
          visible. Argument text follows as it grows.

        Tool-call deltas come from diffing the calls parsed out of
        ``previous_text`` with those parsed out of ``current_text``. As a
        result, several markers, or several calls, arriving in one delta are
        all handled. This assumes ``previous_text`` is a prefix of
        ``current_text``.

        Returns None when the step yields nothing to send.
        """
        if not delta_text:
            return None

        content_delta = ""
        if self._first_marker_index(previous_text) == -1:
            marker_pos = self._first_marker_index(delta_text)
            if marker_pos == -1:
                # No markup seen yet: the whole delta is assistant content.
                return DeltaMessage(content=delta_text, tool_calls=[])
            # This delta opens the markup; the text before it is still content.
            content_delta = delta_text[:marker_pos]

        sent_calls = self._scan_tool_calls(previous_text)
        tool_call_deltas: List[DeltaToolCall] = []
        for index, (tool_id, name, args, _complete) in enumerate(
                self._scan_tool_calls(current_text)):
            if index >= len(sent_calls):
                # This call's args marker just became visible: send the header
                # together with any argument text already present.
                tool_call_deltas.append(
                    DeltaToolCall(
                        id=tool_id,
                        type="function",
                        index=index,
                        function=DeltaFunctionCall(name=name, arguments=args),
                    ))
                continue
            new_args = args[len(sent_calls[index][2]):]
            if new_args:
                tool_call_deltas.append(
                    DeltaToolCall(
                        id=None,
                        type=None,
                        index=index,
                        function=DeltaFunctionCall(name=None, arguments=new_args),
                    ))

        if not content_delta and not tool_call_deltas:
            return None
        return DeltaMessage(content=content_delta or None, tool_calls=tool_call_deltas)

    def _parse_text(self, text: str) -> Tuple[Optional[str], List[ToolCall]]:
        """Return ``(content, tool_calls)`` for a complete output.

        ``content`` is the text before the first special marker. ``tool_calls``
        holds every fully terminated call found before the first
        ``<|calls|>``.
        """
        return self._parse_content(text), self._parse_tool_calls(text)

    def _parse_content(self, text: str) -> Optional[str]:
        """Return the assistant content: the text before the first special marker.

        The whole text counts as content when it contains no marker (for
        example, a reply truncated before its end marker). This matches the
        streaming path, which forwards marker-free deltas as content. Because
        the cut is at the first marker, tool-call markup never leaks into
        content.
        """
        end = self._first_marker_index(text)
        return text if end == -1 else text[:end]

    def _first_marker_index(self, text: str) -> int:
        """Return the index of the earliest special marker in ``text``, or -1."""
        positions = [
            pos for marker in self._SPECIAL_MARKERS
            if (pos := text.find(marker)) != -1
        ]
        return min(positions, default=-1)

    def _parse_tool_call_args(self, text: str) -> str:
        """Normalize raw tool-call argument text into a string.

        The text is decoded as JSON, then as JSON5, then as a Python literal.

        * A decoded string (e.g. double-encoded JSON) is returned unwrapped.
        * Any other decoded value is re-serialized with ``json.dumps``.
        * The raw text is returned unchanged when nothing decodes, or when
          the decoded value is not JSON-serializable (``ast.literal_eval``
          can yield sets, bytes or complex numbers).
        """
        try:
            args = json.loads(text)
        except json.JSONDecodeError:
            try:
                args = pyjson5.decode(text)
            except pyjson5.Json5DecoderException:
                try:
                    args = ast.literal_eval(text)
                except Exception:
                    # Best effort: hand the model's raw text through.
                    return text
        if isinstance(args, str):
            return args
        try:
            # ensure_ascii=False keeps non-ASCII text as-is, matching the raw
            # argument text sent by the streaming path.
            return json.dumps(args, ensure_ascii=False)
        except (TypeError, ValueError):
            return text

    def _parse_tool_calls(self, text: str) -> List[ToolCall]:
        """Return every fully terminated tool call found before the first ``<|calls|>``."""
        return [
            ToolCall(
                id=tool_id,
                function=FunctionCall(
                    name=name, arguments=self._parse_tool_call_args(args)),
            )
            for tool_id, name, args, complete in self._scan_tool_calls(text)
            if complete
        ]

    def _scan_tool_calls(self, text: str) -> List[Tuple[str, str, str, bool]]:
        """Locate tool-call segments whose args marker is present.

        Scans ``text`` up to the first ``<|calls|>`` (or to the end of
        ``text``). Returns one ``(tool_id, name, raw_args, complete)`` tuple
        per call, in order.

        * ``complete`` is False only for a trailing call whose end marker has
          not appeared yet. Its ``raw_args`` then runs to the scan limit.
        * Scanning stops at the first call that lacks its name or args
          marker; that call is not reported.
        """
        calls: List[Tuple[str, str, str, bool]] = []
        limit = text.find(self._SECTION_END_TAG)
        if limit == -1:
            limit = len(text)
        pos = 0
        while True:
            begin = text.find(self._BEGIN_TAG, pos, limit)
            if begin == -1:
                break
            id_start = begin + len(self._BEGIN_TAG)
            name_pos = text.find(self._NAME_TAG, id_start, limit)
            if name_pos == -1:
                break
            name_start = name_pos + len(self._NAME_TAG)
            args_pos = text.find(self._ARGS_TAG, name_start, limit)
            if args_pos == -1:
                break
            args_start = args_pos + len(self._ARGS_TAG)
            end_pos = text.find(self._END_TAG, args_start, limit)
            complete = end_pos != -1
            args_stop = end_pos if complete else limit
            calls.append((
                text[id_start:name_pos],
                text[name_start:args_pos],
                text[args_start:args_stop],
                complete,
            ))
            if not complete:
                break
            pos = end_pos + len(self._END_TAG)
        return calls
| |
|