itlevy commited on
Commit
c8f9dc1
·
verified ·
1 Parent(s): 4c23945

Update README and toolcall_parser

Browse files
README.md CHANGED
@@ -279,7 +279,7 @@ We evaluate the model using temperature=`0.6`, top_p=`0.95`, and 64k sequence le
279
 
280
  | Reasoning Mode | pass@1 (avg. over 16 runs) |
281
  |--------------|------------|
282
- | Reasoning On | |
283
 
284
  ### GPQA
285
 
@@ -297,19 +297,19 @@ We evaluate the model using temperature=`0.6`, top_p=`0.95`, and 64k sequence le
297
 
298
  | Reasoning Mode | pass@1 (avg. over 2 runs) |
299
  |--------------|------------|
300
- | Reasoning On | |
301
 
302
  ### IFEval
303
 
304
  | Reasoning Mode | Strict:Instruction |
305
  |--------------|------------|
306
- | Reasoning On | |
307
 
308
  ### ArenaHard
309
 
310
  | Reasoning Mode | pass@1 (avg. over 1 runs) |
311
  |--------------|------------|
312
- | Reasoning On | |
313
 
314
  ### Humanity's Last Exam (Text-Only Subset)
315
 
 
279
 
280
  | Reasoning Mode | pass@1 (avg. over 16 runs) |
281
  |--------------|------------|
282
+ | Reasoning On | 79.0 |
283
 
284
  ### GPQA
285
 
 
297
 
298
  | Reasoning Mode | pass@1 (avg. over 2 runs) |
299
  |--------------|------------|
300
+ | Reasoning On | 77.39 |
301
 
302
  ### IFEval
303
 
304
  | Reasoning Mode | Strict:Instruction |
305
  |--------------|------------|
306
+ | Reasoning On | 85.86 |
307
 
308
  ### ArenaHard
309
 
310
  | Reasoning Mode | pass@1 (avg. over 1 runs) |
311
  |--------------|------------|
312
+ | Reasoning On | 94.6 |
313
 
314
  ### Humanity's Last Exam (Text-Only Subset)
315
 
llama_nemotron_toolcall_parser_no_streaming.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import json
3
+ import re
4
+ from collections.abc import Sequence
5
+ from typing import Union
6
+
7
+ import partial_json_parser
8
+ from partial_json_parser.core.options import Allow
9
+
10
+ from vllm.entrypoints.openai.protocol import (
11
+ ChatCompletionRequest,
12
+ DeltaFunctionCall, DeltaMessage,
13
+ DeltaToolCall,
14
+ ExtractedToolCallInformation,
15
+ FunctionCall,
16
+ ToolCall,
17
+ )
18
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
19
+ ToolParser,
20
+ ToolParserManager,
21
+ )
22
+ from vllm.logger import init_logger
23
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
24
+ from vllm.utils import random_uuid
25
+
26
+ logger = init_logger(__name__)
27
+
28
+
29
+ @ToolParserManager.register_module("llama_nemotron_xml")
30
+ class LlamaNemotronXMLToolParser(ToolParser):
31
+
32
+ def __init__(self, tokenizer: AnyTokenizer):
33
+ super().__init__(tokenizer)
34
+
35
+ self.current_tool_name_sent: bool = False
36
+ self.prev_tool_call_arr: list[dict] = []
37
+ self.current_tool_id: int = -1 # Potentially for streaming
38
+ self.streamed_args_for_tool: list[str] = [] # Potentially for streaming
39
+
40
+ self.tool_call_start_token: str = "<tool_call>"
41
+ self.tool_call_end_token: str = "</tool_call>"
42
+
43
+ # Regex to find full <tool_call>...</tool_call> blocks and capture their content
44
+ self.tool_call_block_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
45
+ # Regex to find <tool>...</tool> within a tool_call block content
46
+ self.name_regex = re.compile(r"<tool>(.*?)</tool>", re.DOTALL)
47
+ # Regex to find <key>value</key> pairs within the tool_call block content (excluding <tool> tags)
48
+ self.param_regex = re.compile(r"<([^/>\s]+)>(.*?)</\1>", re.DOTALL)
49
+
50
+ def extract_tool_calls(
51
+ self,
52
+ model_output: str,
53
+ request: ChatCompletionRequest,
54
+ ) -> ExtractedToolCallInformation:
55
+
56
+ tool_call_start_index = model_output.find(self.tool_call_start_token)
57
+
58
+ if tool_call_start_index == -1:
59
+ return ExtractedToolCallInformation(
60
+ tools_called=False,
61
+ tool_calls=[],
62
+ content=model_output,
63
+ )
64
+
65
+ content = model_output[:tool_call_start_index].strip()
66
+ tool_calls_str_content = model_output[tool_call_start_index:]
67
+
68
+ parsed_tool_calls = []
69
+
70
+ try:
71
+ # Find all occurrences of <tool_call>...</tool_call>
72
+ xml_tool_call_contents = self.tool_call_block_regex.findall(tool_calls_str_content)
73
+
74
+ for tool_content_str in xml_tool_call_contents:
75
+ name_match = self.name_regex.search(tool_content_str)
76
+ if not name_match:
77
+ logger.warning(f"Could not find tool name in XML block: {tool_content_str}")
78
+ continue
79
+ tool_name = name_match.group(1).strip()
80
+
81
+ parsed_arguments = {}
82
+
83
+ # Find all parameter tags in the tool_call content, excluding the <tool> tag
84
+ param_matches = self.param_regex.finditer(tool_content_str)
85
+
86
+ for match in param_matches:
87
+ param_name = match.group(1).strip()
88
+ param_value_str = match.group(2).strip()
89
+
90
+ # Skip the <tool> tag since it's not a parameter
91
+ if param_name == "tool":
92
+ continue
93
+
94
+ target_type = None
95
+ # Try to get type from request.tools schema
96
+ if request.tools:
97
+ for tool_def in request.tools:
98
+ if tool_def.function.name == tool_name:
99
+ if tool_def.function.parameters and \
100
+ isinstance(tool_def.function.parameters, dict) and \
101
+ "properties" in tool_def.function.parameters and \
102
+ isinstance(tool_def.function.parameters["properties"], dict) and \
103
+ param_name in tool_def.function.parameters["properties"] and \
104
+ isinstance(tool_def.function.parameters["properties"][param_name], dict):
105
+ target_type = tool_def.function.parameters["properties"][param_name].get("type")
106
+ break
107
+
108
+ typed_param_value = param_value_str # Default to string
109
+ if target_type:
110
+ try:
111
+ if target_type == "string":
112
+ typed_param_value = param_value_str
113
+ elif target_type == "integer":
114
+ typed_param_value = int(param_value_str)
115
+ elif target_type == "number":
116
+ typed_param_value = float(param_value_str)
117
+ elif target_type == "boolean":
118
+ typed_param_value = param_value_str.lower() == 'true'
119
+ elif target_type in ["object", "array"]:
120
+ try:
121
+ typed_param_value = json.loads(param_value_str)
122
+ except json.JSONDecodeError:
123
+ # Fallback for non-strict JSON like Python dict/list string
124
+ typed_param_value = ast.literal_eval(param_value_str)
125
+ else: # Unknown type, keep as string
126
+ typed_param_value = param_value_str
127
+ except (ValueError, SyntaxError, json.JSONDecodeError) as e:
128
+ logger.warning(
129
+ f"Could not convert param '{param_name}' with value '{param_value_str}' "
130
+ f"to type '{target_type}'. Error: {e}. Using string value."
131
+ )
132
+ typed_param_value = param_value_str
133
+ else: # No schema type, try ast.literal_eval
134
+ try:
135
+ # For values like "true", "123", "['a', 'b']"
136
+ # ast.literal_eval('some_string_without_quotes') will raise SyntaxError
137
+ if (param_value_str.startswith("'") and param_value_str.endswith("'")) or \
138
+ (param_value_str.startswith('"') and param_value_str.endswith('"')) or \
139
+ (param_value_str.startswith('[') and param_value_str.endswith(']')) or \
140
+ (param_value_str.startswith('{') and param_value_str.endswith('}')) or \
141
+ param_value_str.lower() in ['true', 'false', 'none'] or \
142
+ param_value_str.replace('.', '', 1).isdigit() or \
143
+ (param_value_str.startswith('-') and param_value_str[1:].replace('.', '', 1).isdigit()):
144
+ typed_param_value = ast.literal_eval(param_value_str)
145
+ else: # It's likely a plain string not meant for ast.literal_eval
146
+ typed_param_value = param_value_str
147
+ except (ValueError, SyntaxError):
148
+ typed_param_value = param_value_str # Keep as string if ast.literal_eval fails
149
+
150
+ parsed_arguments[param_name] = typed_param_value
151
+
152
+ parsed_tool_calls.append(ToolCall(
153
+ id=f"call_{random_uuid()}",
154
+ type="function",
155
+ function=FunctionCall(
156
+ name=tool_name,
157
+ arguments=json.dumps(parsed_arguments, ensure_ascii=False),
158
+ ),
159
+ ))
160
+
161
+ return ExtractedToolCallInformation(
162
+ tools_called=len(parsed_tool_calls) > 0,
163
+ tool_calls=parsed_tool_calls,
164
+ content=content if content else None,
165
+ )
166
+
167
+ except Exception:
168
+ logger.exception(f"Error in extracting XML tool call from response. Response: {model_output}")
169
+ # Fallback to original model output if parsing fails catastrophically
170
+ return ExtractedToolCallInformation(
171
+ tools_called=False,
172
+ tool_calls=[],
173
+ content=model_output,
174
+ )
175
+
176
+ def extract_tool_calls_streaming(
177
+ self,
178
+ previous_text: str,
179
+ current_text: str,
180
+ delta_text: str,
181
+ previous_token_ids: Sequence[int],
182
+ current_token_ids: Sequence[int],
183
+ delta_token_ids: Sequence[int],
184
+ request: ChatCompletionRequest,
185
+ ) -> Union[DeltaMessage, None]:
186
+
187
+ raise NotImplementedError("Tool calling is not supported in streaming mode!")
188
+
189
+
190
+ @ToolParserManager.register_module("llama_nemotron_json")
191
+ class LlamaNemotronJSONToolParser(ToolParser):
192
+
193
+ def __init__(self, tokenizer: AnyTokenizer):
194
+ super().__init__(tokenizer)
195
+
196
+ self.current_tool_name_sent: bool = False
197
+ self.prev_tool_call_arr: list[dict] = []
198
+ self.current_tool_id: int = -1
199
+ self.streamed_args_for_tool: list[str] = []
200
+
201
+ self.tool_call_start_token: str = "<TOOLCALL>"
202
+ self.tool_call_end_token: str = "</TOOLCALL>"
203
+
204
+ self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)
205
+
206
+ def extract_tool_calls(
207
+ self,
208
+ model_output: str,
209
+ request: ChatCompletionRequest,
210
+ ) -> ExtractedToolCallInformation:
211
+
212
+ if self.tool_call_start_token not in model_output:
213
+ return ExtractedToolCallInformation(
214
+ tools_called=False,
215
+ tool_calls=[],
216
+ content=model_output,
217
+ )
218
+
219
+ else:
220
+
221
+ try:
222
+ str_tool_calls = self.tool_call_regex.findall(model_output)[0].strip()
223
+ if not str_tool_calls.startswith("["):
224
+ str_tool_calls = "[" + str_tool_calls
225
+ if not str_tool_calls.endswith("]"):
226
+ str_tool_calls = "]" + str_tool_calls
227
+ json_tool_calls = json.loads(str_tool_calls)
228
+ tool_calls = []
229
+ for tool_call in json_tool_calls:
230
+ try:
231
+ tool_calls.append(ToolCall(
232
+ type="function",
233
+ function=FunctionCall(
234
+ name=tool_call["name"],
235
+ arguments=json.dumps(tool_call["arguments"], ensure_ascii=False) \
236
+ if isinstance(tool_call["arguments"], dict) else tool_call["arguments"],
237
+ ),
238
+ ))
239
+ except:
240
+ continue
241
+
242
+ content = model_output[:model_output.rfind(self.tool_call_start_token)]
243
+
244
+ return ExtractedToolCallInformation(
245
+ tools_called=True,
246
+ tool_calls=tool_calls,
247
+ content=content if content else None,
248
+ )
249
+
250
+ except Exception:
251
+ logger.exception(f"Error in extracting tool call from response. Response: {model_output}")
252
+ return ExtractedToolCallInformation(
253
+ tools_called=False,
254
+ tool_calls=[],
255
+ content=model_output,
256
+ )
257
+
258
+ def extract_tool_calls_streaming(
259
+ self,
260
+ previous_text: str,
261
+ current_text: str,
262
+ delta_text: str,
263
+ previous_token_ids: Sequence[int],
264
+ current_token_ids: Sequence[int],
265
+ delta_token_ids: Sequence[int],
266
+ request: ChatCompletionRequest,
267
+ ) -> Union[DeltaMessage, None]:
268
+
269
+ raise NotImplementedError("Tool calling is not supported in streaming mode!")
270
+
271
+
272
+ @ToolParserManager.register_module("llama_nemotron_pythonic")
273
+ class LlamaNemotronPythonicToolParser(ToolParser):
274
+
275
+ def __init__(self, tokenizer: AnyTokenizer):
276
+ super().__init__(tokenizer)
277
+
278
+ self.current_tool_name_sent: bool = False
279
+ self.prev_tool_call_arr: list[dict] = []
280
+ self.current_tool_id: int = -1
281
+ self.streamed_args_for_tool: list[str] = []
282
+
283
+ self.tool_call_start_token: str = "<TOOLCALL>"
284
+ self.tool_call_end_token: str = "</TOOLCALL>"
285
+
286
+ self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)
287
+ # Regex to parse pythonic function calls: function_name(arg1="value1", arg2=123, arg3=True)
288
+ self.function_call_regex = re.compile(r"(\w+)\((.*?)\)$", re.DOTALL)
289
+
290
+ def parse_function_arguments(self, args_str: str) -> dict:
291
+ """Parse pythonic function arguments string into a dictionary"""
292
+ if not args_str.strip():
293
+ return {}
294
+
295
+ # Use ast.parse to safely parse the function call arguments
296
+ # We'll construct a temporary function call and parse it
297
+ try:
298
+ # Create a dummy function call to parse arguments
299
+ dummy_code = f"dummy_func({args_str})"
300
+ parsed = ast.parse(dummy_code, mode='eval')
301
+
302
+ # Extract arguments from the AST
303
+ call_node = parsed.body
304
+ if not isinstance(call_node, ast.Call):
305
+ return {}
306
+
307
+ arguments = {}
308
+
309
+ # Handle keyword arguments
310
+ for keyword in call_node.keywords:
311
+ if keyword.arg is None: # **kwargs
312
+ continue
313
+
314
+ # Convert AST value to Python value
315
+ try:
316
+ value = ast.literal_eval(keyword.value)
317
+ arguments[keyword.arg] = value
318
+ except (ValueError, TypeError):
319
+ # If literal_eval fails, try to get the raw value
320
+ if isinstance(keyword.value, ast.Name):
321
+ arguments[keyword.arg] = keyword.value.id
322
+ elif isinstance(keyword.value, ast.Constant):
323
+ arguments[keyword.arg] = keyword.value.value
324
+ else:
325
+ # Fallback: convert to string
326
+ arguments[keyword.arg] = ast.unparse(keyword.value)
327
+
328
+ # Handle positional arguments (less common in tool calls but supported)
329
+ for i, arg in enumerate(call_node.args):
330
+ try:
331
+ value = ast.literal_eval(arg)
332
+ arguments[f"arg_{i}"] = value
333
+ except (ValueError, TypeError):
334
+ if isinstance(arg, ast.Name):
335
+ arguments[f"arg_{i}"] = arg.id
336
+ elif isinstance(arg, ast.Constant):
337
+ arguments[f"arg_{i}"] = arg.value
338
+ else:
339
+ arguments[f"arg_{i}"] = ast.unparse(arg)
340
+
341
+ return arguments
342
+
343
+ except (SyntaxError, ValueError) as e:
344
+ logger.warning(f"Failed to parse function arguments '{args_str}': {e}")
345
+ return {}
346
+
347
+ def extract_tool_calls(
348
+ self,
349
+ model_output: str,
350
+ request: ChatCompletionRequest,
351
+ ) -> ExtractedToolCallInformation:
352
+
353
+ if self.tool_call_start_token not in model_output:
354
+ return ExtractedToolCallInformation(
355
+ tools_called=False,
356
+ tool_calls=[],
357
+ content=model_output,
358
+ )
359
+
360
+ tool_call_start_index = model_output.find(self.tool_call_start_token)
361
+ content = model_output[:tool_call_start_index].strip()
362
+
363
+ try:
364
+ # Extract content between <TOOLCALL> tags
365
+ tool_call_matches = self.tool_call_regex.findall(model_output)
366
+ if not tool_call_matches:
367
+ return ExtractedToolCallInformation(
368
+ tools_called=False,
369
+ tool_calls=[],
370
+ content=model_output,
371
+ )
372
+
373
+ tool_calls_content = tool_call_matches[0].strip()
374
+
375
+ # Split by lines to get individual function calls
376
+ function_lines = [line.strip() for line in tool_calls_content.split('\n') if line.strip()]
377
+
378
+ parsed_tool_calls = []
379
+
380
+ for func_line in function_lines:
381
+ # Parse each function call
382
+ match = self.function_call_regex.match(func_line)
383
+ if not match:
384
+ logger.warning(f"Could not parse function call: {func_line}")
385
+ continue
386
+
387
+ function_name = match.group(1)
388
+ args_str = match.group(2)
389
+
390
+ # Parse arguments
391
+ parsed_arguments = self.parse_function_arguments(args_str)
392
+
393
+ # Apply type conversion based on schema if available
394
+ if request.tools:
395
+ for tool_def in request.tools:
396
+ if tool_def.function.name == function_name:
397
+ schema_properties = {}
398
+ if (tool_def.function.parameters and
399
+ isinstance(tool_def.function.parameters, dict) and
400
+ "properties" in tool_def.function.parameters and
401
+ isinstance(tool_def.function.parameters["properties"], dict)):
402
+ schema_properties = tool_def.function.parameters["properties"]
403
+
404
+ # Convert arguments based on schema types
405
+ for arg_name, arg_value in parsed_arguments.items():
406
+ if arg_name in schema_properties:
407
+ param_info = schema_properties[arg_name]
408
+ target_type = param_info.get("type")
409
+
410
+ try:
411
+ if target_type == "string" and not isinstance(arg_value, str):
412
+ parsed_arguments[arg_name] = str(arg_value)
413
+ elif target_type == "integer" and not isinstance(arg_value, int):
414
+ parsed_arguments[arg_name] = int(arg_value)
415
+ elif target_type == "number" and not isinstance(arg_value, (int, float)):
416
+ parsed_arguments[arg_name] = float(arg_value)
417
+ elif target_type == "boolean" and not isinstance(arg_value, bool):
418
+ if isinstance(arg_value, str):
419
+ parsed_arguments[arg_name] = arg_value.lower() in ['true', '1', 'yes']
420
+ else:
421
+ parsed_arguments[arg_name] = bool(arg_value)
422
+ elif target_type in ["object", "array"]:
423
+ if isinstance(arg_value, str):
424
+ try:
425
+ parsed_arguments[arg_name] = json.loads(arg_value)
426
+ except json.JSONDecodeError:
427
+ # Keep as string if JSON parsing fails
428
+ pass
429
+ except (ValueError, TypeError) as e:
430
+ logger.warning(f"Type conversion failed for {arg_name}: {e}")
431
+ # Keep original value if conversion fails
432
+ break
433
+
434
+ parsed_tool_calls.append(ToolCall(
435
+ id=f"call_{random_uuid()}",
436
+ type="function",
437
+ function=FunctionCall(
438
+ name=function_name,
439
+ arguments=json.dumps(parsed_arguments, ensure_ascii=False),
440
+ ),
441
+ ))
442
+
443
+ return ExtractedToolCallInformation(
444
+ tools_called=len(parsed_tool_calls) > 0,
445
+ tool_calls=parsed_tool_calls,
446
+ content=content if content else None,
447
+ )
448
+
449
+ except Exception:
450
+ logger.exception(f"Error in extracting pythonic tool call from response. Response: {model_output}")
451
+ return ExtractedToolCallInformation(
452
+ tools_called=False,
453
+ tool_calls=[],
454
+ content=model_output,
455
+ )
456
+
457
+ def extract_tool_calls_streaming(
458
+ self,
459
+ previous_text: str,
460
+ current_text: str,
461
+ delta_text: str,
462
+ previous_token_ids: Sequence[int],
463
+ current_token_ids: Sequence[int],
464
+ delta_token_ids: Sequence[int],
465
+ request: ChatCompletionRequest,
466
+ ) -> Union[DeltaMessage, None]:
467
+
468
+ raise NotImplementedError("Tool calling is not supported in streaming mode!")