tuandunghcmut committed on
Commit e0ce601 (verified) · Parent(s): 6fc37d0

Update README.md

Files changed (1):
  1. README.md +1020 -405
README.md CHANGED
@@ -8,18 +8,24 @@ base_model:
8
  - unsloth/Qwen2.5-Coder-1.5B-Instruct
9
  pipeline_tag: text-generation
10
  ---
 
11
 
12
- # Using tuandunghcmut/Qwen25_Coder_MultipleChoice
13
- The project "Knowledge Distillation About YAML-Based Structured Multi-Step Reasoning from a Teacher Model GPT-4o to a Small LLM: Qwen2.5 Coder 1.5B-Instruct" focuses on distilling structured multi-step reasoning from GPT-4o into a smaller model.
14
 
15
- This document provides everything you need to get started with tuandunghcmut/Qwen25_Coder_MultipleChoice, a model designed for multiple-choice coding questions.
16
 
17
- I plan to refactor the project into a well-structured GitHub repository, expand the dataset, and re-train it later with distributed training for better scalability.
18
- ## Installation and Setup
19
 
20
- ### Prerequisites
 
21
 
22
- Make sure you have Python 3.8+ installed. Then install the required packages:
 
 
23
 
24
  ```bash
25
  # Install core dependencies
@@ -35,175 +41,381 @@ pip install flash-attn --no-build-isolation
35
  pip install datasets pyyaml
36
  ```
37
 
38
- ### Flash Attention Setup
39
 
40
- Flash Attention provides a significant speedup for transformer models. To use it with the Qwen model:
41
-
42
- 1. Install Flash Attention as shown above
43
- 2. Enable it when loading the model:
44
 
 
45
  ```python
46
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch
47
-
48
- # Enable Flash Attention during model loading
49
- model = AutoModelForCausalLM.from_pretrained(
50
- "tuandunghcmut/Qwen25_Coder_MultipleChoice",
51
- torch_dtype=torch.bfloat16,
52
- device_map="auto",
53
- trust_remote_code=True,
54
- use_flash_attention_2=True # Enable Flash Attention
55
- )
 
 
56
  ```
57
 
58
- Flash Attention will provide:
59
- - 2-3x faster inference speed
60
- - Lower memory usage
61
- - Compatible with 4-bit quantization for even more efficiency
62
-
63
- ### Environment Variables
64
-
65
- If you're using Hugging Face Hub models, you may want to set up your access token:
66
-
67
- ```bash
68
- # Set environment variable for Hugging Face token
69
- export HF_TOKEN="your_huggingface_token_here"
70
 
71
- # Or in Python
72
- import os
73
- os.environ["HF_TOKEN"] = "your_huggingface_token_here"
 
 
74
  ```
75
 
76
- ### GPU Setup
 
 
 
77
 
78
- For optimal performance, you'll need a CUDA-compatible GPU. Check your installation:
 
 
79
 
80
- ```bash
81
- # Verify CUDA is available
82
- python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
 
83
 
84
- # Print CUDA device info
85
- python -c "import torch; print('CUDA device count:', torch.cuda.device_count()); print('CUDA device name:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU')"
 
 
86
  ```
87
 
88
- ## Required Classes
 
 
89
 
90
- Below are the essential classes needed to work with the model. Copy these into your Python files to use them in your project.
91
 
92
- ### PromptCreator
 
93
 
94
- This class formats prompts for multiple-choice questions:
 
 
95
 
 
96
  ```python
97
  class PromptCreator:
98
- """
99
- Creates and formats prompts for multiple choice questions
100
- Supports different prompt styles for training and inference
101
- """
102
-
103
  # Prompt types
104
  BASIC = "basic" # Simple answer-only format
105
  YAML_REASONING = "yaml" # YAML formatted reasoning
106
- TEACHER_REASONED = "teacher" # Same YAML format as YAML_REASONING but using teacher completions for training
107
 
108
  def __init__(self, prompt_type=BASIC):
 
 
109
  self.prompt_type = prompt_type
110
- # Initialize parser mode based on prompt type
111
- if prompt_type == self.YAML_REASONING or prompt_type == self.TEACHER_REASONED:
112
- self.parser_mode = "yaml"
113
- else:
114
- self.parser_mode = "basic"
115
 
116
  def format_choices(self, choices):
117
- """Format choices with letter prefixes"""
118
- return "\n".join([f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)])
 
 
119
 
120
  def get_max_letter(self, choices):
121
- """Get the last valid letter based on choice count"""
122
  return chr(65 + len(choices) - 1)
123
 
124
  def create_inference_prompt(self, question, choices):
125
- """Create a prompt for inference based on the configured prompt type"""
126
  formatted_choices = self.format_choices(choices)
127
  max_letter = self.get_max_letter(choices)
128
-
129
- if self.prompt_type == self.BASIC:
130
- return self._create_basic_prompt(question, formatted_choices, max_letter)
131
- elif self.prompt_type == self.YAML_REASONING or self.prompt_type == self.TEACHER_REASONED:
132
  return self._create_yaml_prompt(question, formatted_choices, max_letter)
133
  else:
134
  return self._create_basic_prompt(question, formatted_choices, max_letter)
135
 
136
  def _create_basic_prompt(self, question, formatted_choices, max_letter):
137
- """Create a basic prompt that just asks for an answer letter"""
138
  return f"""
 
139
  {question}
140
 
 
141
  {formatted_choices}
142
 
143
- Select the correct answer from A through {max_letter}:
144
  """
145
 
146
  def _create_yaml_prompt(self, question, formatted_choices, max_letter):
147
- """Create a prompt with YAML formatted reasoning structure"""
148
  return f"""
 
149
  {question}
150
 
 
151
  {formatted_choices}
152
 
153
- Think through this step-by-step:
154
- - Understand what the question is asking
155
- - Analyze each option carefully
156
- - Reason about why each option might be correct or incorrect
157
- - Select the most appropriate answer
158
 
159
- Your response should be in YAML format:
160
  understanding: |
161
- <your understanding of the question>
162
  analysis: |
163
  <your analysis of each option>
164
  reasoning: |
165
- <your reasoning about the correct answer>
166
  conclusion: |
167
  <your final conclusion>
168
- answer: <single letter A through {max_letter} representing your final answer>
 
 
169
  """
170
 
171
  def create_training_prompt(self, question, choices):
172
- """Create a prompt for training based on the configured prompt type"""
173
  formatted_choices = self.format_choices(choices)
174
  max_letter = self.get_max_letter(choices)
175
-
176
- if self.prompt_type == self.BASIC:
177
- return self._create_basic_training_prompt(question, formatted_choices, max_letter)
178
- elif self.prompt_type == self.YAML_REASONING or self.prompt_type == self.TEACHER_REASONED:
179
- return self._create_yaml_training_prompt(question, formatted_choices, max_letter)
180
  else:
181
- return self._create_basic_training_prompt(question, formatted_choices, max_letter)
 
 
182
 
183
  def _create_basic_training_prompt(self, question, formatted_choices, max_letter):
184
  """Create a basic training prompt"""
185
  return f"""
 
186
  {question}
187
 
 
188
  {formatted_choices}
189
 
190
- Select the correct answer from A through {max_letter}:
191
  """
192
 
193
  def _create_yaml_training_prompt(self, question, formatted_choices, max_letter):
194
- """Create a training prompt with YAML formatted reasoning structure"""
195
  return f"""
 
196
  {question}
197
 
 
198
  {formatted_choices}
199
 
200
- Think through this step-by-step:
201
- - Understand what the question is asking
202
- - Analyze each option carefully
203
- - Reason about why each option might be correct or incorrect
204
- - Select the most appropriate answer
205
 
206
- Your response should be in YAML format:
207
  understanding: |
208
  <your understanding of the question>
209
  analysis: |
@@ -212,420 +424,823 @@ reasoning: |
212
  <your reasoning about the correct answer>
213
  conclusion: |
214
  <your final conclusion>
215
- answer: <single letter A through {max_letter} representing your final answer>
216
  """
217
 
218
  def set_prompt_type(self, prompt_type):
219
- """Set the prompt type and update parser mode accordingly"""
 
 
 
220
  self.prompt_type = prompt_type
221
- if prompt_type == self.YAML_REASONING or prompt_type == self.TEACHER_REASONED:
222
- self.parser_mode = "yaml"
223
- else:
224
- self.parser_mode = "basic"
225
-
226
  def is_teacher_mode(self):
227
- """Check if prompt type is teacher mode"""
228
- return self.prompt_type == self.TEACHER_REASONED
229
  ```
230
 
231
- ### ResponseParser
232
-
233
- This class extracts answers from model responses:
234
-
235
  ```python
236
  class ResponseParser:
237
- """
238
- Parser for model responses with support for different formats
239
- Extracts answers and reasoning from model outputs
240
- """
241
 
242
  # Parser modes
243
- BASIC = "basic" # Extract single letter answer
244
- YAML = "yaml" # Parse YAML formatted response with reasoning
245
 
246
  def __init__(self, parser_mode=BASIC):
247
- """Initialize with parser mode (basic or yaml)"""
248
  self.parser_mode = parser_mode
249
-
250
  def parse(self, response_text):
251
- """Parse the response text and extract answer and reasoning"""
252
  if self.parser_mode == self.YAML:
253
  return self._parse_yaml_response(response_text)
254
  else:
255
  return self._parse_basic_response(response_text)
256
 
257
  def _parse_basic_response(self, response_text):
258
- """
259
- Parse a basic response to extract the answer letter
260
-
261
- Returns:
262
- tuple: (answer_letter, None)
263
- """
264
- # Look for just the letter at the end of text
265
  import re
266
 
267
- # Try to find the last occurrence of letters A-Z by themselves
268
- matches = re.findall(r'\b([A-Z])\b', response_text)
269
- if matches:
270
- return matches[-1], None # Return the last matching letter
271
-
272
- # Try to find "The answer is X" pattern
273
- answer_match = re.search(r'[Tt]he answer is[:\s]+([A-Z])', response_text)
274
  if answer_match:
275
- return answer_match.group(1), None
276
-
277
- # If nothing else works, just get the last uppercase letter
278
- uppercase_letters = re.findall(r'[A-Z]', response_text)
279
- if uppercase_letters:
280
- return uppercase_letters[-1], None
281
-
282
- return None, None # No answer found
 
 
 
 
283
 
284
  def _parse_yaml_response(self, response_text):
285
- """
286
- Parse a YAML formatted response to extract the answer and reasoning
287
-
288
- Returns:
289
- tuple: (answer_letter, reasoning_dict)
290
- """
291
  import re
292
  import yaml
293
 
294
- # First try to extract just the answer field
295
- answer_match = re.search(r'answer:\s*([A-Z])', response_text)
296
- answer = answer_match.group(1) if answer_match else None
 
 
 
 
 
 
 
 
 
 
297
 
298
- # Try to extract the entire YAML
299
- try:
300
- # Remove potential code block markers
301
- yaml_text = response_text
302
- if "```yaml" in yaml_text:
303
- yaml_text = yaml_text.split("```yaml")[1]
304
- if "```" in yaml_text:
305
- yaml_text = yaml_text.split("```")[0]
306
- elif "```" in yaml_text:
307
- # Assume the whole thing is a code block
308
- parts = yaml_text.split("```")
309
- if len(parts) >= 3:
310
- yaml_text = parts[1]
311
-
312
- # Parse the YAML
313
- parsed_yaml = yaml.safe_load(yaml_text)
314
-
315
- # If successful, use the answer from the YAML, and return the parsed structure
316
- if isinstance(parsed_yaml, dict) and "answer" in parsed_yaml:
317
- return parsed_yaml.get("answer"), parsed_yaml
318
- except Exception:
319
- # If YAML parsing fails, we already have the answer from regex
320
- pass
321
 
322
- return answer, None
323
 
324
  def set_parser_mode(self, parser_mode):
325
  """Set the parser mode"""
326
  self.parser_mode = parser_mode
 
327
 
328
  @classmethod
329
  def from_prompt_type(cls, prompt_type):
330
- """
331
- Create a ResponseParser with the appropriate mode based on prompt type
332
-
333
- Args:
334
- prompt_type: The prompt type (e.g., PromptCreator.YAML_REASONING)
335
-
336
- Returns:
337
- ResponseParser: A parser configured for the prompt type
338
- """
339
- if prompt_type in ["yaml", "teacher"]:
340
- return cls("yaml")
341
  else:
342
- return cls("basic")
343
  ```
344
 
345
- ### QwenModelHandler
346
-
347
- This class handles model loading and inference:
348
-
349
  ```python
350
- class QwenModelHandler:
351
- def __init__(self, model_name="unsloth/Qwen2.5-7B", max_seq_length=768,
352
- quantization=None, device_map="auto", cache_dir=None,
353
- use_flash_attention=True):
354
- """
355
- Initialize a handler for Qwen models
 
 
356
 
357
- Args:
358
- model_name: Model identifier (local path or Hugging Face model ID)
359
- max_seq_length: Maximum sequence length
360
- quantization: Quantization method ("4bit", "8bit", or None)
361
- device_map: Device mapping strategy
362
- cache_dir: Directory to cache downloaded models
363
- use_flash_attention: Whether to use Flash Attention 2 for faster inference
364
- """
365
- self.model_name = model_name
366
- self.max_seq_length = max_seq_length
367
- self.quantization = quantization
368
- self.device_map = device_map
369
- self.cache_dir = cache_dir
370
- self.use_flash_attention = use_flash_attention
371
 
372
- self.model = None
373
- self.tokenizer = None
374
 
375
- # Load the model and tokenizer
376
- self._load_model()
377
 
378
- def _load_model(self):
379
- """Load the model and tokenizer with appropriate settings"""
380
- from transformers import AutoModelForCausalLM, AutoTokenizer
381
- import torch
 
 
382
 
383
- # Load tokenizer
384
- self.tokenizer = AutoTokenizer.from_pretrained(
385
- self.model_name,
386
- trust_remote_code=True,
387
- cache_dir=self.cache_dir
388
- )
389
 
390
- # Prepare model loading kwargs
391
- model_kwargs = {
392
- "trust_remote_code": True,
393
- "cache_dir": self.cache_dir,
394
- "device_map": self.device_map,
 
 
 
 
395
  }
396
 
397
- # Add Flash Attention if requested and available
398
- if self.use_flash_attention:
399
- try:
400
- import flash_attn
401
- model_kwargs["use_flash_attention_2"] = True
402
- print("Flash Attention 2 enabled!")
403
- except ImportError:
404
- print("Flash Attention not available. For faster inference, install with: pip install flash-attn")
405
-
406
- # Add quantization if specified
407
- if self.quantization == "4bit":
408
- try:
409
- from transformers import BitsAndBytesConfig
410
- model_kwargs["quantization_config"] = BitsAndBytesConfig(
411
- load_in_4bit=True,
412
- bnb_4bit_compute_dtype=torch.bfloat16
413
- )
414
- except ImportError:
415
- print("bitsandbytes not available, loading without 4-bit quantization")
416
- elif self.quantization == "8bit":
417
- model_kwargs["load_in_8bit"] = True
418
- else:
419
- model_kwargs["torch_dtype"] = torch.bfloat16
420
 
421
- # Load the model
422
- self.model = AutoModelForCausalLM.from_pretrained(
423
- self.model_name,
424
- **model_kwargs
425
- )
 
 
426
 
427
- def generate_with_streaming(self, prompt, temperature=0.7, max_tokens=1024, stream=True):
428
- """
429
- Generate text from the model with optional streaming
430
 
431
- Args:
432
- prompt: Input text prompt
433
- temperature: Temperature for sampling (0 for deterministic)
434
- max_tokens: Maximum number of tokens to generate
435
- stream: Whether to stream the output
436
 
437
- Returns:
438
- String containing the generated text, or generator if streaming
439
- """
440
- import torch
 
 
441
 
442
- # Tokenize prompt
443
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
444
- input_ids = inputs.input_ids
445
- attention_mask = inputs.attention_mask
446
-
447
- # Set generation parameters
448
- generation_config = {
449
- "max_new_tokens": max_tokens,
450
- "temperature": temperature,
451
- "do_sample": temperature > 0,
452
- "top_p": 0.95 if temperature > 0 else 1.0,
453
- "repetition_penalty": 1.1,
454
- "pad_token_id": self.tokenizer.eos_token_id,
455
- }
456
 
457
- # If not streaming, do normal generation
458
- if not stream:
459
- with torch.no_grad():
460
- outputs = self.model.generate(
461
- input_ids=input_ids,
462
- attention_mask=attention_mask,
463
- **generation_config
 
 
464
  )
 
465
 
466
- # Decode the generated text (skip the prompt)
467
- generated_text = self.tokenizer.decode(
468
- outputs[0][input_ids.shape[1]:],
469
- skip_special_tokens=True
470
- )
471
 
472
- return generated_text
 
 
 
473
 
474
- # If streaming, yield generated tokens one by one
475
- else:
476
- generated = []
477
-
478
- # Initialize generator
479
- with torch.no_grad():
480
- generated_ids = self.model.generate(
481
- input_ids=input_ids,
482
- attention_mask=attention_mask,
483
- **generation_config,
484
- streamer=None # Would need a custom streamer here if available
485
- )
486
 
487
- # Decode the entire sequence at once (not truly streaming, but simpler)
488
- full_text = self.tokenizer.decode(
489
- generated_ids[0][input_ids.shape[1]:],
490
- skip_special_tokens=True
 
 
491
  )
492
 
493
- return full_text
 
 
494
  ```
495
 
496
- ## Hardware Requirements and Optimization
497
 
498
- ### Flash Attention Benefits
499
 
500
- Flash Attention is a highly optimized implementation of the attention mechanism that:
501
 
502
- 1. **Speeds up inference by 2-3x** compared to standard attention
503
- 2. **Reduces memory usage** by avoiding materializing large attention matrices
504
- 3. **Works perfectly with 4-bit quantization** for even further optimization
505
- 4. **Scales better with sequence length**, which is important for complex coding questions
506
 
507
- For the best performance, make sure to:
508
- - Install Flash Attention (`pip install flash-attn`)
509
- - Enable it when loading the model (see QwenModelHandler class)
510
- - Use with CUDA-compatible NVIDIA GPUs
 
 
511
 
512
- ### Hardware Recommendations
 
 
 
513
 
514
- For optimal performance, we recommend:
 
515
 
516
- - **GPU**: NVIDIA GPU with at least 8GB VRAM (16GB+ recommended for larger models)
517
- - **RAM**: 16GB+ system RAM
518
- - **Storage**: At least 10GB free disk space for model files
519
- - **CPU**: Modern multi-core processor (for preprocessing)
520
 
521
- ### Reducing Memory Usage
522
 
523
- If you're facing memory constraints:
524
 
525
  ```python
526
- # Use 4-bit quantization with Flash Attention for optimal memory-efficiency
 
 
527
  model_handler = QwenModelHandler(
528
  model_name="tuandunghcmut/Qwen25_Coder_MultipleChoice",
 
529
  quantization="4bit",
530
- use_flash_attention=True
531
  )
532
 
533
- # Further optimize with unsloth
534
- try:
535
- from unsloth.models import FastLanguageModel
536
- FastLanguageModel.for_inference(model_handler.model)
537
- print("Using unsloth for additional optimization")
538
- except ImportError:
539
- print("unsloth not available")
 
 
540
  ```
541
 
542
- ## Usage Example
543
 
544
- Here's how to use these classes with Flash Attention enabled:
545
 
546
  ```python
547
- # 1. Load the model with Flash Attention and 4-bit quantization
548
- from transformers import AutoModelForCausalLM, AutoTokenizer
549
- import torch
 
 
550
 
551
- hub_model_id = "tuandunghcmut/Qwen25_Coder_MultipleChoice"
 
 
 
552
 
553
- # Create model handler with Flash Attention and 4-bit quantization
 
 
554
  model_handler = QwenModelHandler(
555
- model_name=hub_model_id,
556
- max_seq_length=2048,
557
- quantization="4bit",
558
- use_flash_attention=True
 
 
559
  )
560
 
561
- # Optional: Use unsloth for even faster inference
562
- try:
563
- from unsloth.models import FastLanguageModel
564
- FastLanguageModel.for_inference(model_handler.model)
565
- print("Using unsloth for faster inference")
566
- except ImportError:
567
- print("unsloth not available, using standard inference")
568
 
569
- # 2. Create prompt creator with YAML reasoning format
570
- prompt_creator = PromptCreator(PromptCreator.YAML_REASONING)
571
 
572
- # 3. Example question
573
- question = "Which of the following correctly defines a list comprehension in Python?"
574
- choices = [
575
- "[x**2 for x in range(10)]",
576
- "for(x in range(10)) { return x**2; }",
577
- "map(lambda x: x**2, range(10))",
578
- "[for x in range(10): x**2]"
579
- ]
 
 
580
 
581
- # 4. Create prompt and generate answer
582
- prompt = prompt_creator.create_inference_prompt(question, choices)
583
- response = model_handler.generate_with_streaming(prompt, temperature=0.0, stream=False)
 
 
584
 
585
- # 5. Parse the response
586
- parser = ResponseParser(prompt_creator.parser_mode)
587
- answer, reasoning = parser.parse(response)
588
 
589
- print(f"Question: {question}")
590
- print(f"Answer: {answer}")
591
- if reasoning:
592
- print(f"Reasoning: {reasoning}")
 
 
593
  ```
594
 
595
- ## Troubleshooting
 
 
596
 
597
- ### Common Issues
598
 
599
- 1. **Flash Attention Installation Issues**: If you encounter problems installing `flash-attn`:
600
- ```bash
601
- # Try with specific CUDA version (e.g., for CUDA 11.8)
602
- pip install flash-attn==2.3.4+cu118 --no-build-isolation
603
-
604
- # For older GPUs
605
- pip install flash-attn==2.3.4 --no-build-isolation
606
- ```
607
 
608
- 2. **CUDA Out of Memory**: Try combining 4-bit quantization with Flash Attention.
609
- ```python
610
- model_handler = QwenModelHandler(
611
- model_name=hub_model_id,
612
- quantization="4bit",
613
- use_flash_attention=True
614
- )
615
- ```
616
 
617
- 3. **Module Not Found Errors**: Make sure you've installed all required packages.
618
- ```bash
619
- pip install transformers torch unsloth datasets pyyaml bitsandbytes flash-attn
620
- ```
621
 
622
- 4. **Parsing Errors**: If the model isn't producing valid YAML responses, try adjusting the temperature:
623
- ```python
624
- response = model_handler.generate_with_streaming(prompt, temperature=0.0, stream=False)
625
- ```
626
 
627
- ### Getting Help
628
 
629
- If you encounter issues, check the [model repository on Hugging Face](https://huggingface.co/tuandunghcmut/Qwen25_Coder_MultipleChoice) for updates and community discussions.
630
 
631
- This guide provides you with all the necessary code and optimization techniques to use the model effectively for multiple-choice coding questions.
 
8
  - unsloth/Qwen2.5-Coder-1.5B-Instruct
9
  pipeline_tag: text-generation
10
  ---
11
+ # Qwen25_Coder_MultipleChoice
12
 
13
+ * This project focuses on distilling YAML-based structured multi-step reasoning capabilities from the GPT-4o teacher model into the smaller Qwen2.5 Coder 1.5B-Instruct LLM.
 
14
 
15
+ * This document provides guidance on getting started with `tuandunghcmut/Qwen25_Coder_MultipleChoice`, a model fine-tuned for multiple-choice coding questions.
16
 
17
+ * Future plans include refactoring the project into a well-structured GitHub repository, expanding the dataset, and retraining the model using distributed training for improved scalability.
 
18
 
19
+ * A demonstration notebook is available on Google Colab (click the badge below). Please note that the training code has been omitted from this notebook. It is intended solely for testing and inference using the latest checkpoint.
20
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1Q4jtRjIkFWIAM82pAg4OBPCLjpQ8ndpI/view?usp=sharing)
21
 
22
+ * Note: The Qwen2.5 Coder 1.5B-Instruct model might be too small for this task, and the current training dataset may be insufficient. Future iterations will explore using a larger model and more extensive data. However, the current model successfully adheres to the desired YAML format and demonstrates structured reasoning.
23
+
24
+ * The guide below provides an explanation of the code presented in the notebook.
25
+
26
+ ## Installation
27
+
28
+ First, install the required dependencies:
29
 
30
  ```bash
31
  # Install core dependencies
 
41
  pip install datasets pyyaml
42
  ```
43
 
44
+ ## Key Classes
45
 
46
+ The project provides several key classes for working with the model:
 
 
 
47
 
48
+ ### 1. QwenModelHandler
49
  ```python
50
+ class QwenModelHandler:
51
+ """Handler for Qwen models with inference and saving capabilities using Unsloth"""
52
+
53
+ def __init__(self, model_name="unsloth/Qwen2.5-7B", max_seq_length=768,
54
+ quantization=None, device_map="auto", cache_dir=None):
55
+ """
56
+ Initialize model and tokenizer using Unsloth
57
+
58
+ Args:
59
+ model_name: Name or path of the model (preferably an unsloth model)
60
+ max_seq_length: Maximum sequence length for the model
61
+ quantization: Quantization type (None, '4bit', '8bit') - for compatibility
62
+ device_map: Device mapping strategy
63
+ cache_dir: Cache directory for models
64
+ """
65
  ```
66
 
67
+ This class handles the core model operations (a usage sketch follows this list):
68
+ - Model loading and initialization
69
+ - Text generation with streaming support
70
+ - Perplexity calculation
71
+ - Model saving and pushing to HuggingFace Hub
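
For orientation, here is a minimal usage sketch of the handler. It assumes the full `QwenModelHandler` definition shown later in this README has already been executed (for example, copied from the notebook) and that `unsloth` and its dependencies are installed; the output directory and Hub repository names are placeholders.

```python
# Sketch of the main QwenModelHandler capabilities listed above.
handler = QwenModelHandler(
    model_name="tuandunghcmut/Qwen25_Coder_MultipleChoice",
    max_seq_length=2048,
    quantization="4bit",  # optional 4-bit loading via Unsloth
)

# Non-streaming generation returns the decoded completion as a string
completion = handler.generate_with_streaming(
    prompt="Explain what a Python list comprehension is.",
    temperature=0.0,
    max_tokens=256,
    stream=False,
)
print(completion)

# Perplexity of a prompt/answer pair
print(handler.calculate_perplexity(prompt="2 + 2 = ?", answer="4"))

# Save locally, or push to the Hub (paths and repo id are placeholders)
handler.save_model("./qwen_mc_lora", save_method="lora")
# handler.push_to_hub("your-username/your-repo", token="hf_...", save_method="lora")
```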
 
 
 
 
 
 
 
72
 
73
+ ### 2. PromptCreator
74
+ ```python
75
+ class PromptCreator:
76
+ """Creates and formats prompts for multiple choice questions"""
77
+
78
+ # Prompt types
79
+ BASIC = "basic" # Simple answer-only format
80
+ YAML_REASONING = "yaml" # YAML formatted reasoning
81
+ TEACHER_REASONED = "teacher" # Same YAML format but using teacher completions
82
  ```
83
 
84
+ This class manages prompt creation with three modes (illustrated in the sketch after this list):
85
+ - Basic: Simple answer-only format
86
+ - YAML Reasoning: Structured reasoning in YAML format
87
+ - Teacher Reasoned: YAML format with teacher completions for training
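
As a quick illustration (assuming the `PromptCreator` definition from this README has been run; the question is made up), the sketch below builds a YAML-reasoning prompt and then switches to the basic format:

```python
creator = PromptCreator(PromptCreator.YAML_REASONING)

question = "Which keyword defines a function in Python?"
choices = ["func", "def", "lambda", "fn"]

# QUESTION / CHOICES sections followed by the YAML-format instructions
yaml_prompt = creator.create_inference_prompt(question, choices)
print(yaml_prompt)

# Switch to the answer-only format for the same question
creator.set_prompt_type(PromptCreator.BASIC)
print(creator.create_inference_prompt(question, choices))
```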
88
 
89
+ ### 3. ResponseParser
90
+ ```python
91
+ class ResponseParser:
92
+ """Parser for model responses with support for different formats"""
93
+
94
+ # Parser modes
95
+ BASIC = "basic" # Extract single letter answer
96
+ YAML = "yaml" # Parse YAML formatted response with reasoning
97
+ ```
98
 
99
+ This class handles response parsing (see the example after this list):
100
+ - Extracts answers from model responses
101
+ - Parses YAML-formatted reasoning
102
+ - Supports both basic and YAML formats
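
The sketch below (again assuming the `ResponseParser` definition from this README has been run) parses a hand-written YAML response; the model output shown is illustrative only:

```python
parser = ResponseParser(parser_mode=ResponseParser.YAML)

# Hand-written example of the YAML output format the prompts request
model_output = """understanding: |
  The question asks which keyword defines a function in Python.
analysis: |
  Only option B is the keyword used to define a function.
reasoning: |
  'def' introduces a function definition; the other options are not keywords.
conclusion: |
  Option B is correct.
answer: B"""

answer, reasoning = parser.parse(model_output)
print(answer)     # "B"
print(reasoning)  # understanding, reasoning and conclusion assembled into one string
```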
103
 
104
+ ### 4. MultipleChoiceTester
105
+ ```python
106
+ class MultipleChoiceTester:
107
+ """Framework for testing Qwen models on multiple choice questions"""
108
+
109
+ def __init__(self, model_handler, prompt_creator=None):
110
+ """
111
+ Initialize with model handler and prompt configuration
112
+
113
+ Args:
114
+ model_handler: The QwenModelHandler instance
115
+ prompt_creator: Optional PromptCreator instance
116
+ """
117
  ```
118
 
119
+ This class provides a complete testing framework (a short end-to-end sketch follows this list):
120
+ - Single example inference
121
+ - Batch processing
122
+ - Dataset evaluation
123
+ - Performance metrics tracking
124
+ - Results saving and visualization
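
For a compact end-to-end sketch (assuming the class definitions from this README have been run; the question list is made up), the example below scores a small hand-made set of questions and writes the results to disk:

```python
handler = QwenModelHandler(
    model_name="tuandunghcmut/Qwen25_Coder_MultipleChoice",
    max_seq_length=2048,
)
tester = MultipleChoiceTester(
    handler,
    prompt_creator=PromptCreator(PromptCreator.YAML_REASONING),
)

examples = [
    {
        "question": "Which keyword defines a function in Python?",
        "choices": ["func", "def", "lambda", "fn"],
        "answer": "B",
    },
    # ... more examples with the same keys
]

# A plain list of example dicts works here when num_examples is left at its default
summary = tester.evaluate_dataset(examples, temperature=0.0, max_tokens=512, batch_size=2)
print(f"Accuracy: {summary['accuracy']:.2%}")

# Writes a timestamped JSON file under ./results
tester.save_results(summary, output_dir="./results")
```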
125
 
126
+ ## Full Class Implementations
127
 
128
+ <details>
129
+ <summary>Click to expand/collapse full class implementations</summary>
130
 
131
+ ### 1. QwenModelHandler
132
+ ```python
133
+ class QwenModelHandler:
134
+ """Handler for Qwen models with inference and saving capabilities using Unsloth"""
135
+
136
+ def __init__(self, model_name="unsloth/Qwen2.5-7B", max_seq_length=768,
137
+ quantization=None, device_map="auto", cache_dir=None):
138
+ self.model_name = model_name
139
+ self.max_seq_length = max_seq_length
140
+ self.device_map = device_map
141
+ self.quantization = quantization
142
+ self.cache_dir = cache_dir
143
+
144
+ # Convert quantization parameter to load_in_4bit parameter for Unsloth
145
+ self.load_in_4bit = quantization == "4bit"
146
+
147
+ # Load tokenizer and model
148
+ self.tokenizer, self.model = self._load_model()
149
+ self.response_parser = ResponseParser()
150
+
151
+ def _load_model(self):
152
+ """Load model and tokenizer with Unsloth for optimization"""
153
+ from unsloth import FastLanguageModel
154
+ import torch
155
+
156
+ print(f"Loading {self.model_name} with Unsloth, max_seq_length={self.max_seq_length}")
157
+
158
+ # Set dtype based on hardware
159
+ dtype = None # None for auto detection
160
+
161
+ # Load model and tokenizer with Unsloth
162
+ model, tokenizer = FastLanguageModel.from_pretrained(
163
+ model_name=self.model_name,
164
+ max_seq_length=self.max_seq_length,
165
+ dtype=dtype,
166
+ load_in_4bit=self.load_in_4bit,
167
+ cache_dir=self.cache_dir,
168
+ )
169
+
170
+ return tokenizer, model
171
+
172
+ def generate_with_streaming(self, prompt, temperature=0.7, max_tokens=1024, stream=True):
173
+ """Generate completion with optional streaming using Unsloth's optimized inference"""
174
+ # Enable faster inference
175
+ from unsloth import FastLanguageModel
176
+ FastLanguageModel.for_inference(self.model)
177
+
178
+ # Format as chat
179
+ messages = [{"role": "user", "content": prompt}]
180
+ chat_text = self.tokenizer.apply_chat_template(
181
+ messages,
182
+ tokenize=False,
183
+ add_generation_prompt=True
184
+ )
185
+
186
+ # Tokenize input
187
+ model_inputs = self.tokenizer([chat_text], return_tensors="pt").to(self.model.device)
188
+
189
+ # Generate with streaming if requested
190
+ if stream:
191
+ from transformers import TextIteratorStreamer
192
+ import threading
193
+
194
+ # Set up streamer
195
+ streamer = TextIteratorStreamer(
196
+ self.tokenizer,
197
+ skip_prompt=True,
198
+ skip_special_tokens=True
199
+ )
200
+
201
+ # Start generation in a thread
202
+ generation_kwargs = {
203
+ "input_ids": model_inputs.input_ids,
204
+ "attention_mask": model_inputs.attention_mask,
205
+ "temperature": temperature,
206
+ "max_new_tokens": max_tokens,
207
+ "streamer": streamer,
208
+ "do_sample": temperature > 0.0,
209
+ "use_cache": True,
210
+ "min_p": 0.1 if temperature > 0.0 else None,
211
+ }
212
+
213
+ thread = threading.Thread(target=self.model.generate, kwargs=generation_kwargs)
214
+ thread.start()
215
+
216
+ return streamer
217
+ else:
218
+ # Generate without streaming
219
+ generated_ids = self.model.generate(
220
+ input_ids=model_inputs.input_ids,
221
+ attention_mask=model_inputs.attention_mask,
222
+ temperature=temperature,
223
+ max_new_tokens=max_tokens,
224
+ do_sample=temperature > 0.0,
225
+ use_cache=True,
226
+ min_p=0.1 if temperature > 0.0 else None,
227
+ )
228
+
229
+ # Decode the generated text
230
+ generated_text = self.tokenizer.decode(
231
+ generated_ids[0][model_inputs.input_ids.shape[1]:],
232
+ skip_special_tokens=True
233
+ )
234
+
235
+ return generated_text
236
+
237
+ def calculate_perplexity(self, prompt, answer, temperature=0.0):
238
+ """Calculate perplexity for a prompt and answer pair"""
239
+ import torch
240
+
241
+ # Format chat for perplexity calculation
242
+ messages = [
243
+ {"role": "user", "content": prompt},
244
+ {"role": "assistant", "content": answer}
245
+ ]
246
+ chat_text = self.tokenizer.apply_chat_template(
247
+ messages,
248
+ tokenize=False
249
+ )
250
+
251
+ # Tokenize the text
252
+ encodings = self.tokenizer(chat_text, return_tensors="pt").to(self.model.device)
253
+
254
+ # Calculate loss
255
+ with torch.no_grad():
256
+ outputs = self.model(**encodings, labels=encodings.input_ids)
257
+
258
+ # Get loss and calculate perplexity
259
+ neg_log_likelihood = outputs.loss.item()
260
+ perplexity = torch.exp(torch.tensor(neg_log_likelihood)).item()
261
+
262
+ return perplexity
263
+
264
+ def save_model(self, output_dir, save_method="lora"):
265
+ """Save model to disk using Unsloth's optimized methods"""
266
+ import os
267
+
268
+ os.makedirs(output_dir, exist_ok=True)
269
+
270
+ # Use Unsloth's saving methods
271
+ if save_method == "lora":
272
+ self.model.save_pretrained(output_dir)
273
+ self.tokenizer.save_pretrained(output_dir)
274
+ elif save_method == "merged_16bit":
275
+ self.model.save_pretrained_merged(output_dir, self.tokenizer, save_method="merged_16bit")
276
+ elif save_method == "merged_4bit":
277
+ self.model.save_pretrained_merged(output_dir, self.tokenizer, save_method="merged_4bit")
278
+ elif save_method == "gguf":
279
+ self.model.save_pretrained_gguf(output_dir, self.tokenizer, quantization_method="q4_k_m")
280
+ else:
281
+ raise ValueError(f"Unknown save method: {save_method}")
282
+
283
+ print(f"Model saved to {output_dir} using method {save_method}")
284
+ return output_dir
285
+
286
+ def push_to_hub(self, repo_id, token=None, save_method="lora", private=False):
287
+ """Push model to Hugging Face Hub using Unsloth's optimized methods"""
288
+ if save_method == "lora":
289
+ self.model.push_to_hub_merged(repo_id, self.tokenizer, save_method="lora", token=token)
290
+ elif save_method == "merged_16bit":
291
+ self.model.push_to_hub_merged(repo_id, self.tokenizer, save_method="merged_16bit", token=token)
292
+ elif save_method == "merged_4bit":
293
+ self.model.push_to_hub_merged(repo_id, self.tokenizer, save_method="merged_4bit", token=token)
294
+ elif save_method == "gguf":
295
+ self.model.push_to_hub_gguf(
296
+ repo_id,
297
+ self.tokenizer,
298
+ quantization_method=["q4_k_m", "q5_k_m"],
299
+ token=token
300
+ )
301
+ else:
302
+ raise ValueError(f"Unknown save method: {save_method}")
303
+
304
+ print(f"Model successfully pushed to: https://huggingface.co/{repo_id}")
305
+ return f"https://huggingface.co/{repo_id}"
306
+ ```
307
 
308
+ ### 2. PromptCreator
309
  ```python
310
  class PromptCreator:
311
+ """Creates and formats prompts for multiple choice questions"""
312
+
 
 
 
313
  # Prompt types
314
  BASIC = "basic" # Simple answer-only format
315
  YAML_REASONING = "yaml" # YAML formatted reasoning
316
+ TEACHER_REASONED = "teacher" # Same YAML format but using teacher completions
317
 
318
  def __init__(self, prompt_type=BASIC):
319
+ if prompt_type == self.TEACHER_REASONED:
320
+ prompt_type = self.YAML_REASONING
321
  self.prompt_type = prompt_type
322
+ self.original_type = prompt_type
 
 
 
 
323
 
324
  def format_choices(self, choices):
325
+ """Format choices as a lettered list"""
326
+ return "\n".join(
327
+ [f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)]
328
+ )
329
 
330
  def get_max_letter(self, choices):
331
+ """Get the maximum letter based on number of choices"""
332
  return chr(65 + len(choices) - 1)
333
 
334
  def create_inference_prompt(self, question, choices):
335
+ """Create a prompt for inference based on current prompt type"""
336
  formatted_choices = self.format_choices(choices)
337
  max_letter = self.get_max_letter(choices)
338
+
339
+ if self.prompt_type == self.YAML_REASONING:
 
 
340
  return self._create_yaml_prompt(question, formatted_choices, max_letter)
341
  else:
342
  return self._create_basic_prompt(question, formatted_choices, max_letter)
343
 
344
  def _create_basic_prompt(self, question, formatted_choices, max_letter):
345
+ """Create a basic prompt asking for just the answer letter"""
346
  return f"""
347
+ QUESTION:
348
  {question}
349
 
350
+ CHOICES:
351
  {formatted_choices}
352
 
353
+ Answer with a single letter from A through {max_letter} without any additional explanation or commentary.
354
  """
355
 
356
  def _create_yaml_prompt(self, question, formatted_choices, max_letter):
357
+ """Create a prompt requesting YAML-formatted reasoning"""
358
  return f"""
359
+ QUESTION:
360
  {question}
361
 
362
+ CHOICES:
363
  {formatted_choices}
364
 
365
+ Analyze this question step-by-step and provide a detailed explanation.
366
+ Your response MUST be in YAML format as follows:
 
 
 
367
 
 
368
  understanding: |
369
+ <your understanding of what the question is asking>
370
  analysis: |
371
  <your analysis of each option>
372
  reasoning: |
373
+ <your step-by-step reasoning process>
374
  conclusion: |
375
  <your final conclusion>
376
+ answer: <single letter A through {max_letter}>
377
+
378
+ The answer field MUST contain ONLY a single character letter.
379
  """
380
 
381
  def create_training_prompt(self, question, choices):
382
+ """Create a prompt for training with the current prompt type"""
383
  formatted_choices = self.format_choices(choices)
384
  max_letter = self.get_max_letter(choices)
385
+
386
+ if self.prompt_type == self.YAML_REASONING:
387
+ return self._create_yaml_training_prompt(
388
+ question, formatted_choices, max_letter
389
+ )
390
  else:
391
+ return self._create_basic_training_prompt(
392
+ question, formatted_choices, max_letter
393
+ )
394
 
395
  def _create_basic_training_prompt(self, question, formatted_choices, max_letter):
396
  """Create a basic training prompt"""
397
  return f"""
398
+ QUESTION:
399
  {question}
400
 
401
+ CHOICES:
402
  {formatted_choices}
403
 
404
+ The answer is a single letter (A, B, C, etc.). Only provide ONE character as your answer:
405
  """
406
 
407
  def _create_yaml_training_prompt(self, question, formatted_choices, max_letter):
408
+ """Create a YAML-formatted training prompt"""
409
  return f"""
410
+ QUESTION:
411
  {question}
412
 
413
+ CHOICES:
414
  {formatted_choices}
415
 
416
+ Analyze this question step-by-step and provide a detailed explanation.
417
+ Follow the YAML format in your response:
 
 
 
418
 
 
419
  understanding: |
420
  <your understanding of the question>
421
  analysis: |
 
424
  <your reasoning about the correct answer>
425
  conclusion: |
426
  <your final conclusion>
427
+ answer: <single letter A through {max_letter}>
428
  """
429
 
430
  def set_prompt_type(self, prompt_type):
431
+ """Set the prompt type"""
432
+ self.original_type = prompt_type
433
+ if prompt_type == self.TEACHER_REASONED:
434
+ pass
435
  self.prompt_type = prompt_type
436
+ return self
437
+
 
 
 
438
  def is_teacher_mode(self):
439
+ """Check if we're using teacher mode"""
440
+ return self.original_type == self.TEACHER_REASONED
441
  ```
442
 
443
+ ### 3. ResponseParser
 
 
 
444
  ```python
445
  class ResponseParser:
446
+ """Parser for model responses with support for different formats"""
 
 
 
447
 
448
  # Parser modes
449
+ BASIC = "basic" # Extract single letter answer
450
+ YAML = "yaml" # Parse YAML formatted response with reasoning
451
 
452
  def __init__(self, parser_mode=BASIC):
 
453
  self.parser_mode = parser_mode
454
+
455
  def parse(self, response_text):
456
+ """Parse the model's response according to the current mode"""
457
  if self.parser_mode == self.YAML:
458
  return self._parse_yaml_response(response_text)
459
  else:
460
  return self._parse_basic_response(response_text)
461
 
462
  def _parse_basic_response(self, response_text):
463
+ """Parse basic response looking for a letter answer"""
 
 
 
 
 
 
464
  import re
465
 
466
+ # Try to extract a single letter answer (A-Z)
467
+ answer_match = re.search(r"(?:^|\s)([A-Z])(?:\s|$|\.)", response_text)
 
 
 
 
 
468
  if answer_match:
469
+ answer = answer_match.group(1)
470
+ else:
471
+ # Take first character if it's a letter
472
+ if response_text and response_text[0].isalpha():
473
+ answer = response_text[0].upper()
474
+ else:
475
+ answer = None
476
+
477
+ # For basic mode, we don't extract detailed reasoning
478
+ reasoning = ""
479
+
480
+ return answer, reasoning
481
 
482
  def _parse_yaml_response(self, response_text):
483
+ """Parse YAML formatted response extracting answer and reasoning"""
 
 
 
 
 
484
  import re
485
  import yaml
486
 
487
+ # First try to find answer in YAML format
488
+ yaml_match = re.search(r"answer:\s*([A-Z])", response_text)
489
+ if yaml_match:
490
+ answer = yaml_match.group(1)
491
+ else:
492
+ # Fall back to basic extraction if YAML parsing fails
493
+ answer_match = re.search(r"(?:^|\s)([A-Z])(?:\s|$|\.)", response_text)
494
+ if answer_match:
495
+ answer = answer_match.group(1)
496
+ elif response_text and response_text[0].isalpha():
497
+ answer = response_text[0].upper()
498
+ else:
499
+ answer = None
500
 
501
+ # Try to parse reasoning from YAML format
502
+ reasoning = ""
503
+ if "reasoning:" in response_text:
504
+ yaml_content = yaml.safe_load("---\n" + response_text)
505
+ if isinstance(yaml_content, dict) and "reasoning" in yaml_content:
506
+ reasoning = yaml_content["reasoning"]
507
+
508
+ # Add other YAML fields if available
509
+ if "understanding" in yaml_content:
510
+ reasoning = f"Understanding: {yaml_content['understanding']}\n\n{reasoning}"
511
+ if "conclusion" in yaml_content:
512
+ reasoning = f"{reasoning}\n\nConclusion: {yaml_content['conclusion']}"
513
+ else:
514
+ # Use the full response as reasoning if not in YAML format
515
+ reasoning = response_text
 
 
516
 
517
+ return answer, reasoning
518
 
519
  def set_parser_mode(self, parser_mode):
520
  """Set the parser mode"""
521
  self.parser_mode = parser_mode
522
+ return self
523
 
524
  @classmethod
525
  def from_prompt_type(cls, prompt_type):
526
+ """Create a parser instance with mode matching the prompt type"""
527
+ if prompt_type == PromptCreator.YAML_REASONING or prompt_type == PromptCreator.TEACHER_REASONED:
528
+ return cls(parser_mode=cls.YAML)
 
 
529
  else:
530
+ return cls(parser_mode=cls.BASIC)
531
  ```
532
 
533
+ ### 4. MultipleChoiceTester
 
 
 
534
  ```python
535
+ class MultipleChoiceTester:
536
+ """Framework for testing Qwen models on multiple choice questions"""
537
+
538
+ def __init__(self, model_handler, prompt_creator=None):
539
+ self.model_handler = model_handler
540
+ self.prompt_creator = prompt_creator or PromptCreator(PromptCreator.BASIC)
541
+ self.response_parser = ResponseParser.from_prompt_type(self.prompt_creator.prompt_type)
542
+
543
+ def infer_example(self, example, temperature=0.7, max_tokens=1024, prompt_type=None, stream=False):
544
+ """Inference on a single example for visualization/demonstration"""
545
+ # Allow temporary override of prompt type
546
+ original_prompt_type = None
547
+ if prompt_type is not None:
548
+ original_prompt_type = self.prompt_creator.prompt_type
549
+ self.prompt_creator.set_prompt_type(prompt_type)
550
+ self.response_parser = ResponseParser.from_prompt_type(prompt_type)
551
 
552
+ # Prepare data
553
+ question = example["question"]
554
+
555
+ # Handle different formats of choices
556
+ if isinstance(example["choices"], list):
557
+ choices = example["choices"]
558
+ elif isinstance(example["choices"], str) and example["choices"].startswith("["):
559
+ import ast
560
+ choices = ast.literal_eval(example["choices"]) if "[" in example["choices"] else example["choices"].split(",")
561
+ else:
562
+ choices = str(example["choices"]).split(",")
 
 
 
563
 
564
+ # Generate the prompt using prompt creator
565
+ prompt = self.prompt_creator.create_inference_prompt(question, choices)
566
 
567
+ # Start timing
568
+ start_time = time.time()
569
 
570
+ if stream:
571
+ # Use streaming generation
572
+ streamer = self.model_handler.generate_with_streaming(
573
+ prompt=prompt,
574
+ temperature=temperature,
575
+ max_tokens=max_tokens,
576
+ stream=True
577
+ )
578
+
579
+ # Collect output from streamer
580
+ raw_response = ""
581
+ print("Model response:")
582
+ for text_chunk in streamer:
583
+ print(text_chunk, end="", flush=True)
584
+ raw_response += text_chunk
585
+ print("\n")
586
+ else:
587
+ # Generate without streaming
588
+ raw_response = self.model_handler.generate_with_streaming(
589
+ prompt=prompt,
590
+ temperature=temperature,
591
+ max_tokens=max_tokens,
592
+ stream=False
593
+ )
594
 
595
+ response_time = time.time() - start_time
596
+
597
+ # Parse the response using the response parser
598
+ predicted_answer, reasoning = self.response_parser.parse(raw_response)
 
 
599
 
600
+ # Prepare results
601
+ result = {
602
+ "question": question,
603
+ "choices": choices,
604
+ "predicted_answer": predicted_answer,
605
+ "reasoning": reasoning,
606
+ "response_time": response_time,
607
+ "raw_response": raw_response,
608
+ "prompt_type": self.prompt_creator.prompt_type,
609
  }
610
 
611
+ # Add task_id if available
612
+ if "task_id" in example:
613
+ result["task_id"] = example["task_id"]
614
+
615
+ # Calculate metrics if label is provided
616
+ if "answer" in example:
617
+ label = example["answer"]
618
+ result["correct_answer"] = label
619
+ result["is_correct"] = predicted_answer == label
620
+
621
+ # Calculate perplexity if requested
622
+ if hasattr(self.model_handler, "calculate_perplexity"):
623
+ perplexity = self.model_handler.calculate_perplexity(prompt, raw_response)
624
+ result["perplexity"] = perplexity
 
 
625
 
626
+ # Restore original prompt type if it was overridden
627
+ if original_prompt_type is not None:
628
+ self.prompt_creator.set_prompt_type(original_prompt_type)
629
+ self.response_parser = ResponseParser.from_prompt_type(original_prompt_type)
630
+
631
+ return result
632
+
633
+ def infer_batch(self, examples, temperature=0.7, max_tokens=1024, prompt_type=None, batch_size=4):
634
+ """Inference on a batch of examples"""
635
+ # Allow temporary override of prompt type
636
+ original_prompt_type = None
637
+ if prompt_type is not None:
638
+ original_prompt_type = self.prompt_creator.prompt_type
639
+ self.prompt_creator.set_prompt_type(prompt_type)
640
+ self.response_parser = ResponseParser.from_prompt_type(prompt_type)
641
 
642
+ # Prepare all prompts
643
+ prompts = []
644
+ metadata = []
645
 
646
+ for i, example in enumerate(examples):
647
+ # Extract data
648
+ question = example["question"]
 
 
649
 
650
+ # Handle different formats of choices
651
+ if isinstance(example["choices"], list):
652
+ choices = example["choices"]
653
+ elif isinstance(example["choices"], str) and example["choices"].startswith("["):
654
+ import ast
655
+ choices = ast.literal_eval(example["choices"]) if "[" in example["choices"] else example["choices"].split(",")
656
+ else:
657
+ choices = str(example["choices"]).split(",")
658
+
659
+ # Generate the prompt using prompt creator
660
+ prompt = self.prompt_creator.create_inference_prompt(question, choices)
661
+ prompts.append(prompt)
662
+
663
+ # Store metadata for later
664
+ meta = {
665
+ "question": question,
666
+ "choices": choices,
667
+ "index": i,
668
+ }
669
+
670
+ # Add label if available
671
+ if "answer" in example:
672
+ meta["label"] = example["answer"]
673
+
674
+ if "task_id" in example:
675
+ meta["task_id"] = example["task_id"]
676
+
677
+ metadata.append(meta)
678
 
679
+ # Process in batches
680
+ results = []
681
+ correct_count = 0
682
+ total_count = 0
683
+ perplexities = []
 
 
684
 
685
+ for i in range(0, len(prompts), batch_size):
686
+ batch_prompts = prompts[i:i+batch_size]
687
+ batch_meta = metadata[i:i+batch_size]
688
+
689
+ # Process batch
690
+ start_time = time.time()
691
+ batch_responses = []
692
+
693
+ for prompt in batch_prompts:
694
+ response = self.model_handler.generate_with_streaming(
695
+ prompt=prompt,
696
+ temperature=temperature,
697
+ max_tokens=max_tokens,
698
+ stream=False
699
  )
700
+ batch_responses.append(response)
701
 
702
+ batch_time = time.time() - start_time
 
 
 
 
703
 
704
+ # Process each response in the batch
705
+ for j, (response, meta) in enumerate(zip(batch_responses, batch_meta)):
706
+ # Parse response
707
+ predicted_answer, reasoning = self.response_parser.parse(response)
708
+
709
+ # Create result
710
+ result = {
711
+ "question": meta["question"],
712
+ "choices": meta["choices"],
713
+ "predicted_answer": predicted_answer,
714
+ "reasoning": reasoning,
715
+ "raw_response": response,
716
+ "prompt_type": self.prompt_creator.prompt_type,
717
+ "response_time": batch_time / len(batch_prompts),
718
+ }
719
+
720
+ # Add task_id if available
721
+ if "task_id" in meta:
722
+ result["task_id"] = meta["task_id"]
723
+
724
+ # Add metrics if label available
725
+ if "label" in meta:
726
+ label = meta["label"]
727
+ result["correct_answer"] = label
728
+ result["is_correct"] = predicted_answer == label
729
+
730
+ # Update counts for accuracy
731
+ total_count += 1
732
+ if result["is_correct"]:
733
+ correct_count += 1
734
+
735
+ # Calculate perplexity if possible
736
+ if hasattr(self.model_handler, "calculate_perplexity"):
737
+ prompt = batch_prompts[j]
738
+ perplexity = self.model_handler.calculate_perplexity(prompt, response)
739
+ result["perplexity"] = perplexity
740
+ perplexities.append(perplexity)
741
+
742
+ results.append(result)
743
 
744
+ # Calculate aggregate metrics
745
+ summary_metrics = {}
746
+ if total_count > 0:
747
+ summary_metrics["accuracy"] = correct_count / total_count
748
+ summary_metrics["correct_count"] = correct_count
749
+ summary_metrics["total_count"] = total_count
 
 
 
 
 
 
750
 
751
+ if perplexities:
752
+ summary_metrics["avg_perplexity"] = sum(perplexities) / len(perplexities)
753
+ summary_metrics["min_perplexity"] = min(perplexities)
754
+ summary_metrics["max_perplexity"] = max(perplexities)
755
+
756
+ # Restore original prompt type if it was overridden
757
+ if original_prompt_type is not None:
758
+ self.prompt_creator.set_prompt_type(original_prompt_type)
759
+ self.response_parser = ResponseParser.from_prompt_type(original_prompt_type)
760
+
761
+ return results, summary_metrics
762
+
763
+ def evaluate_dataset(self, dataset, temperature=0.7, max_tokens=1024, num_examples=None,
764
+ verbose=True, prompt_type=None, batch_size=4, log_to_wandb=False):
765
+ """Inference on a whole dataset with metrics calculation"""
766
+ # Allow overriding the prompt type for this evaluation
767
+ original_prompt_type = self.prompt_creator.prompt_type
768
+ if prompt_type is not None:
769
+ self.prompt_creator.set_prompt_type(prompt_type)
770
+ self.response_parser = ResponseParser.from_prompt_type(prompt_type)
771
+
772
+ # Select subset if specified
773
+ if num_examples is not None:
774
+ dataset = dataset.select(range(min(num_examples, len(dataset))))
775
+
776
+ results = []
777
+ correct_count = 0
778
+ total_count = 0
779
+ perplexities = []
780
+
781
+ # Process examples in batches
782
+ for i in range(0, len(dataset), batch_size):
783
+ batch_examples = dataset[i:i+batch_size]
784
+
785
+ if verbose:
786
+ batch_desc = f"Batch {i//batch_size + 1}/{(len(dataset) + batch_size - 1) // batch_size}"
787
+ print(f"\nProcessing {batch_desc} with {len(batch_examples)} examples...")
788
+
789
+ # Infer batch
790
+ batch_results, batch_metrics = self.infer_batch(
791
+ examples=batch_examples,
792
+ temperature=temperature,
793
+ max_tokens=max_tokens,
794
+ batch_size=batch_size
795
  )
796
 
797
+ # Update metrics
798
+ results.extend(batch_results)
799
+ if "correct_count" in batch_metrics:
800
+ correct_count += batch_metrics["correct_count"]
801
+ total_count += batch_metrics["total_count"]
802
+
803
+ if verbose:
804
+ batch_accuracy = batch_metrics["accuracy"]
805
+ overall_accuracy = correct_count / total_count
806
+ print(f"Batch accuracy: {batch_accuracy:.2%}, Overall: {overall_accuracy:.2%} ({correct_count}/{total_count})")
807
+
808
+ # Collect perplexities
809
+ if "avg_perplexity" in batch_metrics:
810
+ for result in batch_results:
811
+ if "perplexity" in result:
812
+ perplexities.append(result["perplexity"])
813
+
814
+ # Calculate final accuracy
815
+ accuracy = correct_count / total_count if total_count > 0 else 0.0
816
+
817
+ if verbose:
818
+ prompt_type_str = self.prompt_creator.prompt_type
819
+ print(f"\nFinal accuracy with {prompt_type_str} prompts: {accuracy:.2%} ({correct_count}/{total_count})")
820
+ if perplexities:
821
+ avg_perplexity = sum(perplexities) / len(perplexities)
822
+ print(f"Average perplexity: {avg_perplexity:.4f}")
823
+
824
+ # Prepare comprehensive summary
825
+ summary = {
826
+ "accuracy": accuracy,
827
+ "correct_count": correct_count,
828
+ "total_count": total_count,
829
+ "prompt_type": self.prompt_creator.prompt_type,
830
+ "results": results,
831
+ }
832
+
833
+ # Add perplexity metrics if available
834
+ if perplexities:
835
+ summary["avg_perplexity"] = sum(perplexities) / len(perplexities)
836
+ summary["min_perplexity"] = min(perplexities)
837
+ summary["max_perplexity"] = max(perplexities)
838
+
839
+ # Log results to wandb if requested
840
+ if log_to_wandb and wandb.run is not None:
841
+ metrics = {
842
+ "test/accuracy": accuracy,
843
+ "test/correct_count": correct_count,
844
+ "test/total_count": total_count,
845
+ }
846
+ if perplexities:
847
+ metrics["test/avg_perplexity"] = summary["avg_perplexity"]
848
+ metrics["test/min_perplexity"] = summary["min_perplexity"]
849
+ metrics["test/max_perplexity"] = summary["max_perplexity"]
850
+
851
+ wandb.log(metrics)
852
+
853
+ # Create a table of results for visualization if task_id exists
854
+ if "task_id" in dataset.features:
855
+ columns = ["task_id", "question", "correct_answer", "predicted_answer", "is_correct"]
856
+ table = wandb.Table(columns=columns)
857
+
858
+ for res in results[:min(100, len(results))]:
859
+ table.add_data(
860
+ res.get("task_id", "unknown"),
861
+ res["question"][:100] + "...",
862
+ res.get("correct_answer", ""),
863
+ res.get("predicted_answer", ""),
864
+ res.get("is_correct", False)
865
+ )
866
+
867
+ wandb.log({"test_samples": table})
868
+
869
+ # Restore original prompt type
870
+ self.prompt_creator.set_prompt_type(original_prompt_type)
871
+ self.response_parser = ResponseParser.from_prompt_type(original_prompt_type)
872
+
873
+ return summary
874
+
875
+ def save_results(self, results, output_dir="./results"):
876
+ """Save evaluation results to file"""
877
+ os.makedirs(output_dir, exist_ok=True)
878
+
879
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
880
+ results_file = os.path.join(output_dir, f"results_{timestamp}.json")
881
+
882
+ # Create serializable results
883
+ serializable_results = {
884
+ "accuracy": results.get("accuracy", 0.0),
885
+ "correct_count": results.get("correct_count", 0),
886
+ "total_count": results.get("total_count", 0),
887
+ "timestamp": timestamp,
888
+ "prompt_type": results.get("prompt_type", "unknown"),
889
+ }
890
+
891
+ # Add perplexity metrics if available
892
+ if "avg_perplexity" in results:
893
+ serializable_results["avg_perplexity"] = results["avg_perplexity"]
894
+ serializable_results["min_perplexity"] = results["min_perplexity"]
895
+ serializable_results["max_perplexity"] = results["max_perplexity"]
896
+
897
+ # Process individual results
898
+ serializable_results["individual_results"] = []
899
+ for result in results["results"]:
900
+ # Skip perplexity in individual results to save space
901
+ result_copy = result.copy()
902
+ if "perplexity" in result_copy:
903
+ del result_copy["perplexity"]
904
+
905
+ # Convert choices if needed
906
+ choices = result_copy["choices"]
907
+ if not isinstance(choices, list):
908
+ try:
909
+ import ast
910
+ result_copy["choices"] = ast.literal_eval(choices)
911
+ except (SyntaxError, ValueError):
912
+ pass
913
+
914
+ serializable_results["individual_results"].append(result_copy)
915
+
916
+ # Save to file
917
+ with open(results_file, "w") as f:
918
+ import json
919
+ json.dump(serializable_results, f, indent=2)
920
+
921
+ print(f"Results saved to {results_file}")
922
+ return results_file
923
  ```
924
 
925
+ </details>
926
 
927
+ ## Quick Start
928
 
929
+ Here's a simple example of how to use the model:
930
 
931

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the model and tokenizer
model_id = "tuandunghcmut/Qwen25_Coder_MultipleChoice"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)

# Example question
question = "What is the correct way to open a file in Python for reading?"
choices = [
    "open('file.txt', 'r')",
    "file.open('file.txt', 'read')",
    "read('file.txt')",
    "File.open('file.txt')"
]

# Format the prompt: build the lettered choice list (A., B., ...) first
formatted_choices = "\n".join(f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices))

prompt = f"""
QUESTION:
{question}

CHOICES:
{formatted_choices}

Answer with a single letter from A through {chr(65 + len(choices) - 1)} without any additional explanation or commentary.
"""

# Generate response and decode only the newly generated tokens
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=10)
new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
response = tokenizer.decode(new_tokens, skip_special_tokens=True)

print(f"Model's answer: {response}")
```
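
Because the prompt asks for a single letter, a small amount of post-processing makes the answer easier to consume programmatically. This helper is not part of the project, just a convenience you might add:

```python
import re

def extract_choice_letter(text, num_choices=4):
    """Return the first standalone choice letter (A, B, ...) in the text, or None."""
    valid = "".join(chr(65 + i) for i in range(num_choices))  # e.g. "ABCD"
    match = re.search(rf"\b([{valid}])\b", text.strip().upper())
    return match.group(1) if match else None

print(extract_choice_letter(response))  # e.g. 'A'
```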

## Advanced Usage

### Using the MultipleChoiceTester Framework

For more advanced usage, you can use the provided `MultipleChoiceTester` framework:

```python
from save import QwenModelHandler, MultipleChoiceTester, PromptCreator

# Initialize the model handler
model_handler = QwenModelHandler(
    model_name="tuandunghcmut/Qwen25_Coder_MultipleChoice",
    max_seq_length=2048,
    quantization="4bit",
    device_map="auto"
)

# Create a prompt creator with YAML reasoning format
prompt_creator = PromptCreator(PromptCreator.YAML_REASONING)

# Initialize the tester
tester = MultipleChoiceTester(model_handler, prompt_creator=prompt_creator)

# Example question
example = {
    "question": "What is the correct way to open a file in Python for reading?",
    "choices": [
        "open('file.txt', 'r')",
        "file.open('file.txt', 'read')",
        "read('file.txt')",
        "File.open('file.txt')"
    ],
    "answer": "A"  # Optional ground truth
}

# Get prediction with reasoning
result = tester.infer_example(example, temperature=0.0001, stream=True)
print(f"Predicted answer: {result['predicted_answer']}")
print("Reasoning:")
print(result['reasoning'])
```

### Batch Processing

You can also process multiple questions in batches:

```python
# List of examples
examples = [
    {
        "question": "What is the correct way to open a file in Python for reading?",
        "choices": ["open('file.txt', 'r')", "file.open('file.txt', 'read')", "read('file.txt')", "File.open('file.txt')"],
        "answer": "A"
    },
    # Add more examples...
]

# Process batch
results, metrics = tester.infer_batch(examples, batch_size=4)
print(f"Batch accuracy: {metrics['accuracy']:.2%}")
```
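
If your questions live in a Hugging Face dataset, the batch can be assembled directly from a split. The dataset name and column names below are placeholders, not part of this project; adapt them to your data:

```python
from datasets import load_dataset

# Hypothetical dataset and column names -- substitute your own.
dataset = load_dataset("your-org/your-mcq-dataset", split="test")

examples = [
    {
        "question": row["question"],
        "choices": row["choices"],
        "answer": row["answer"],  # omit if you have no ground truth
    }
    for row in dataset
]

results, metrics = tester.infer_batch(examples, batch_size=4)
print(f"Accuracy on {len(examples)} questions: {metrics['accuracy']:.2%}")
```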

### Streaming Inference

The model supports streaming inference, which emits output in real time as the model generates its response. This is particularly useful for interactive applications and whenever you want to watch the reasoning process unfold.

#### Basic Streaming Usage

Here's how to use streaming inference:

```python
# Initialize model handler and tester as before
model_handler = QwenModelHandler(
    model_name="tuandunghcmut/Qwen25_Coder_MultipleChoice",
    max_seq_length=2048
)
tester = MultipleChoiceTester(model_handler)

# Example with streaming
example = {
    "question": "Which Python method is used to remove whitespace from both ends of a string?",
    "choices": [
        "strip()",
        "trim()",
        "clean()",
        "remove_whitespace()"
    ],
    "answer": "A"
}

# Enable streaming with stream=True
result = tester.infer_example(
    example,
    temperature=0.0001,
    max_tokens=1024,
    stream=True  # Enable streaming
)

# The output is printed in real time as the model generates it.
# You can also access the complete response after generation.
print("\nFinal result:")
print(f"Predicted answer: {result['predicted_answer']}")
print("Complete reasoning:")
print(result['reasoning'])
```
1079
 
1080
+ #### Advanced Streaming Patterns
 
1081
 
1082
+ ##### 1. Custom Stream Processing
1083
+
1084
+ You can process the streamed output in custom ways:
1085
+
1086

```python
def process_stream(streamer):
    """Custom stream processing function"""
    collected_text = ""
    for chunk in streamer:
        # Process each chunk as it arrives
        collected_text += chunk
        # You can do custom processing here,
        # for example parse partial YAML, update a UI, etc.
        yield chunk, collected_text

# Use custom stream processing
result = tester.infer_example(
    example,
    temperature=0.0001,
    stream=True
)

# Process the stream with custom logic
for chunk, full_text in process_stream(result['stream']):
    # Do something with each chunk
    print(f"Chunk: {chunk}")
    print(f"Full text so far: {full_text}")
```

##### 2. YAML Streaming with Real-time Parsing

When using the YAML reasoning format, you can parse the output as it streams:

```python
import yaml
from io import StringIO

def parse_yaml_stream(streamer):
    """Parse YAML content as it streams"""
    buffer = StringIO()
    for chunk in streamer:
        buffer.write(chunk)
        try:
            # Try to parse the current buffer as YAML
            yaml_content = yaml.safe_load(buffer.getvalue())
            if yaml_content:
                yield chunk, yaml_content
        except yaml.YAMLError:
            # Not enough content for valid YAML yet
            continue

# Use YAML streaming with parsing
result = tester.infer_example(
    example,
    temperature=0.0001,
    prompt_type=PromptCreator.YAML_REASONING,
    stream=True
)

# Process YAML content as it streams
for chunk, yaml_content in parse_yaml_stream(result['stream']):
    if isinstance(yaml_content, dict):
        # Access YAML fields as they become available
        if 'understanding' in yaml_content:
            print(f"Understanding: {yaml_content['understanding']}")
        if 'reasoning' in yaml_content:
            print(f"Reasoning: {yaml_content['reasoning']}")
        if 'answer' in yaml_content:
            print(f"Answer: {yaml_content['answer']}")
```

##### 3. Streaming with Progress Tracking

You can track generation progress and timing:

```python
import time

def stream_with_progress(streamer):
    """Stream with progress tracking"""
    start_time = time.time()
    tokens_generated = 0

    for chunk in streamer:
        # Approximate the token count by splitting on whitespace
        tokens_generated += len(chunk.split())
        elapsed = time.time() - start_time
        tokens_per_second = tokens_generated / elapsed if elapsed > 0 else 0

        yield {
            'chunk': chunk,
            'tokens': tokens_generated,
            'tokens_per_second': tokens_per_second,
            'elapsed': elapsed
        }

# Use streaming with progress tracking
result = tester.infer_example(
    example,
    temperature=0.0001,
    stream=True
)

for progress in stream_with_progress(result['stream']):
    print(f"Generated {progress['tokens']} tokens "
          f"({progress['tokens_per_second']:.2f} tokens/sec)")
    print(f"Chunk: {progress['chunk']}")
```
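
The word-split count above is only an approximation of the real token count. For exact numbers you can tokenize each chunk with the model's own tokenizer (this assumes a `tokenizer` object is available, e.g. the one loaded in the Quick Start section):

```python
def count_tokens(chunk, tokenizer):
    """Exact token count for a streamed text chunk using the model's tokenizer."""
    return len(tokenizer(chunk, add_special_tokens=False)["input_ids"])

# e.g. inside the progress loop:
#     tokens_generated += count_tokens(chunk, tokenizer)
```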

#### Implementation Details

The streaming implementation uses Unsloth's optimized inference and has the following key features (a minimal, generic sketch of the underlying streaming pattern follows the list):

1. **Efficient Token Generation**
   - Uses Unsloth's `FastLanguageModel` for optimized inference
   - Implements streaming using `TextIteratorStreamer`
   - Supports both greedy and temperature-based sampling

2. **Memory Management**
   - Streams tokens without storing the entire response in memory
   - Efficiently handles long responses
   - Supports batch processing with streaming

3. **Performance Optimizations**
   - Uses `use_cache=True` for faster generation
   - Implements `min_p` sampling for better quality
   - Supports 4-bit quantization for reduced memory usage

4. **Error Handling**
   - Gracefully handles streaming interruptions
   - Provides partial results if generation is interrupted
   - Maintains context for resumed generation
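
To make this concrete, below is a minimal, framework-agnostic sketch of the same streaming pattern using Hugging Face's `TextIteratorStreamer` directly. It assumes the plain `transformers` model, tokenizer, and prompt from the Quick Start section rather than the Unsloth-optimized path the framework uses internally:

```python
from threading import Thread
from transformers import TextIteratorStreamer

# Reuses `model`, `tokenizer`, and `prompt` from the Quick Start example above.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# generate() blocks, so run it in a background thread and consume the streamer here.
generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=512, use_cache=True)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

for chunk in streamer:
    print(chunk, end="", flush=True)  # text arrives incrementally as tokens are generated
thread.join()
```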
1214
+ The streaming output will show the model's reasoning process in real-time, including:
1215
+ - Understanding of the question
1216
+ - Analysis of each option
1217
+ - Step-by-step reasoning
1218
+ - Final conclusion
1219
+ - Answer selection
1220
+
1221
+ This is particularly useful for:
1222
+ - Debugging model behavior
1223
+ - Creating interactive demos
1224
+ - Understanding the model's reasoning process
1225
+ - Providing immediate feedback to users
1226
+ - Building real-time applications
1227
 
## Model Features

- **YAML-Based Reasoning**: The model provides structured reasoning in YAML format (an illustrative response is sketched below)
- **Multiple Prompt Types**: Supports both basic and YAML-formatted reasoning prompts
- **Batch Processing**: Efficiently processes multiple questions at once
- **Performance Metrics**: Tracks accuracy, perplexity, and response times
- **Streaming Support**: Real-time output streaming for interactive use
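
The snippet below is a hypothetical sketch of what a YAML reasoning response might look like and how to read it with PyYAML. The exact schema is defined by the YAML reasoning prompt; the field names simply mirror the `understanding`/`reasoning`/`answer` keys used in the streaming example above and may differ from your actual outputs.

```python
import yaml

# Hypothetical response text -- field names are illustrative, not a guaranteed schema.
raw_response = """
understanding: The question asks for the built-in way to open a file for reading.
reasoning: open('file.txt', 'r') is the standard built-in call; the other options are not valid Python APIs.
answer: A
"""

parsed = yaml.safe_load(raw_response)
print(parsed["answer"])     # -> 'A'
print(parsed["reasoning"])  # full reasoning text
```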
 
 
 
## Examples

Check out the [example notebook](https://colab.research.google.com/drive/1YOUR_NOTEBOOK_ID) for more detailed usage examples and demonstrations.
 
 
 
## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## License

This project is licensed under the MIT License - see the LICENSE file for details.