ojus1 committed on
Commit 5720a91 · verified · 1 Parent(s): 79f33e1

Update README.md

Files changed (1)
  1. README.md +38 -19
README.md CHANGED
@@ -29,13 +29,14 @@ MiniGuard-v0.1 uses the **same prompt template** as [nvidia/Llama-3.1-Nemotron-S

  ## Training

- MiniGuard-v0.1 was trained using three key techniques:

- 1. **LoRA Fine-tuning** — English Subset of [nvidia/Nemotron-Safety-Guard-Dataset-v3](https://huggingface.co/datasets/nvidia/Nemotron-Safety-Guard-Dataset-v3) + Reasoning traces from [openai/gpt-oss-safeguard-120b](https://huggingface.co/openai/gpt-oss-safeguard-120b).

- 2. **Distilling Step-by-Step** — A teacher LLM generates reasoning traces for training examples. The student model is trained on both reasoning-augmented and standard examples, improving performance even when reasoning is not generated at inference time. ([Reference](https://research.google/blog/distilling-step-by-step-outperforming-larger-language-models-with-less-training-data-and-smaller-model-sizes/))

- 3. **Greedy Model Soup** — Multiple fine-tuned checkpoints are averaged using a greedy selection strategy: checkpoints are sorted by validation accuracy and sequentially added to the "soup" only if they improve or maintain performance. This provides a free accuracy boost without additional compute. ([Reference](https://arxiv.org/abs/2203.05482))
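A minimal sketch of the greedy selection loop described above, assuming PyTorch checkpoints stored as state dicts; `build_model` and `evaluate` are hypothetical helpers (rebuild a model from weights, return validation accuracy), not utilities from this repo:

```python
import copy
import torch

def greedy_soup(checkpoint_paths, build_model, evaluate):
    """Greedy model soup (Wortsman et al., 2022): checkpoints are tried from best
    to worst and kept in the average only if validation accuracy does not drop."""
    # Score every checkpoint individually and sort from best to worst.
    scored = []
    for path in checkpoint_paths:
        state = torch.load(path, map_location="cpu")
        scored.append((evaluate(build_model(state)), state))
    scored.sort(key=lambda item: item[0], reverse=True)

    # Start the soup with the single best checkpoint.
    best_acc, soup_avg = scored[0]
    soup_avg = copy.deepcopy(soup_avg)
    n = 1

    for _, state in scored[1:]:
        # Tentative running average with this checkpoint folded in.
        candidate = {k: (soup_avg[k] * n + state[k]) / (n + 1) for k in soup_avg}
        acc = evaluate(build_model(candidate))
        if acc >= best_acc:  # keep only if performance is maintained or improved
            soup_avg, n, best_acc = candidate, n + 1, acc

    return soup_avg  # averaged state dict of the selected checkpoints
```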

  ## Evaluation

@@ -46,38 +47,56 @@ Dataset - English subset test split of [nvidia/Nemotron-Safety-Guard-Dataset-v3]
  | Metric | MiniGuard-v0.1 | Nemotron-Guard-8B-v3 |
  |--------|----------------|----------------------|
  | Parameters | **0.6B** | 8B |
- | Overall F1 | 0.881 | 0.893 |
- | Accuracy Retained | **98.7%** | 100% |
  | Size Reduction | **13x** | 1x |

- ### Production Dataset Evaluation

- Evaluated on an internal dataset of real-user prompts. Cost estimated based on H200 GPU pricing ($3.50/hour) at concurrency 16 with P95 latency SLA of <500ms.

  | Metric | MiniGuard-v0.1 | Nemotron-Guard-8B-v3 |
  |--------|----------------|----------------------|
- | Relative Safety Score | 97.4% | 100% |
- | Relative Cost | $5.4 | $8 |

- MiniGuard-v0.1 achieves comparable safety scores while costing ~30% less per million requests.

  ### Ablation Study

  Dataset - English subset test split of [nvidia/Nemotron-Safety-Guard-Dataset-v3](https://huggingface.co/datasets/nvidia/Nemotron-Safety-Guard-Dataset-v3).

- | Model | Overall F1 | User Safety F1 | Response Safety F1 |
- |-------|------------|----------------|-------------------|
- | Qwen3-0.6B (baseline) | 0.637 | 0.594 | 0.681 |
- | + Vanilla SFT | 0.844 | 0.843 | 0.844 |
- | + Distilling Step-by-Step | 0.882 | 0.873 | 0.892 |
- | + Greedy Model Soup | 0.886 | 0.873 | 0.899 |
- | + FP8 Quantization | 0.881 | 0.872 | 0.890 |

  ## Input
  **Input Type(s)**: Text <br>
  **Input Format(s)**: String <br>
  **Input Parameters**: One-Dimensional (1D): Sequences <br>
- **Other Properties Related to Input**: Context length up to 8K. Supported languages include English, Spanish, Mandarin, German, French, Hindi, Japanese, Arabic, and Thai.

  ## Output
  **Output Type(s)**: Text Json <br>

  ## Training

+ MiniGuard-v0.1 was trained using four key techniques to break the trade-off between safety and latency:

+ 1. **Targeted Synthetic Data** — To address specific failure modes (e.g., sports terms, ambiguous edge cases), we generated ~1,200 targeted examples using **Hermes-4.3-36B**. This data complements the English subset of [nvidia/Nemotron-Safety-Guard-Dataset-v3](https://huggingface.co/datasets/nvidia/Nemotron-Safety-Guard-Dataset-v3).

+ 2. **Think SFT (Distilling Step-by-Step)** — A teacher LLM (**gpt-oss-safeguard-120b**) generates reasoning traces for training examples. The student model is trained on these traces but discards them at inference, retaining reasoning capabilities without the token cost (see the sketch after this list).
+
+ 3. **Top-K Model Soup** — We employ a Top-K (K=3) weight averaging strategy. Weights from the top 3 validation checkpoints are averaged to improve out-of-distribution generalization without increasing inference overhead.
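One minimal way to realize the Think SFT idea in item 2 is sketched below: each prompt gets both a reasoning-augmented target (teacher trace plus label) and a plain-label target, so the served model can emit the label alone. The `<think>` delimiter, chat schema, and label format are illustrative assumptions, not the exact template this model uses:

```python
def make_think_sft_examples(prompt: str, teacher_trace: str, label: str) -> list[dict]:
    """Build two SFT records per example (illustrative schema, not this repo's exact format):
    one whose target contains the teacher's reasoning trace, and one with the label alone.
    Training on both lets the deployed model answer with just the label at inference time."""
    reasoning_target = f"<think>\n{teacher_trace}\n</think>\n{label}"  # trace appears only in training targets
    direct_target = label                                              # what the served model produces
    return [
        {"messages": [{"role": "user", "content": prompt},
                      {"role": "assistant", "content": reasoning_target}]},
        {"messages": [{"role": "user", "content": prompt},
                      {"role": "assistant", "content": direct_target}]},
    ]

# Hypothetical usage
records = make_think_sft_examples(
    prompt="User message: How do I sharpen a kitchen knife safely?",
    teacher_trace="The request concerns routine kitchen maintenance; no harmful intent.",
    label='{"User Safety": "safe"}',
)
```

In this setup the reasoning tokens only ever appear in training targets, which is how the trace can be discarded at inference without paying its token cost.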

  ## Evaluation

  | Metric | MiniGuard-v0.1 | Nemotron-Guard-8B-v3 |
  |--------|----------------|----------------------|
  | Parameters | **0.6B** | 8B |
+ | Weighted F1 | 88.9 | 89.3 |
+ | Accuracy Retained | **99.5%** | 100% |
  | Size Reduction | **13x** | 1x |

+ #### Production Dataset Evaluation
+
+ Evaluated on out-of-distribution production data containing real user queries. Cost estimated based on H200 GPU pricing ($7.91/hour) at maximum concurrency with P95 latency SLA of <500ms.

  | Metric | MiniGuard-v0.1 | Nemotron-Guard-8B-v3 |
  |--------|----------------|----------------------|
+ | Relative Macro F1 | 91.1% | 100% |
+ | Cost per 1M requests | **$15.54** | $46.93 |
+ | Cost Savings | **67%** | - |
+
+ MiniGuard-v0.1 achieves 91.1% relative performance on out-of-distribution data while costing **67% less** to serve.
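The 67% figure follows directly from the two per-million-request costs in the table above; a quick check:

```python
miniguard_cost = 15.54   # $ per 1M requests, from the table above
nemotron_cost = 46.93    # $ per 1M requests, from the table above

savings = 1 - miniguard_cost / nemotron_cost
print(f"{savings:.1%}")  # -> 66.9%, reported as ~67%
```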

  ### Ablation Study

+ #### Out-of-Distribution: Production Dataset
+
+ Impact of techniques on out-of-distribution production data (Relative Macro F1 compared to Nemotron-Guard-8B-v3).
+
+ | Configuration | Parameters | Rel. Macro F1 | Improvement |
+ | :--- | :--- | :--- | :--- |
+ | Qwen3-0.6B + Think SFT | 0.6B | 85.6% | baseline |
+ | + Targeted Synthetic Data | 0.6B | 87.2% | +1.6% |
+ | + Soup (top-3) [MiniGuard-v0.1] | 0.6B | 92.3% | +5.1% |
+ | + FP8 | 0.6B | 91.1% | -1.2% |
+ | Nemotron-Guard-8B-v3 | 8B | 100% | reference |
+
+ #### In-Distribution
+
  Dataset - English subset test split of [nvidia/Nemotron-Safety-Guard-Dataset-v3](https://huggingface.co/datasets/nvidia/Nemotron-Safety-Guard-Dataset-v3).

+ | Training Configuration | Weighted F1 | Macro F1 |
+ |------------------------|-------------|----------|
+ | Qwen3-0.6B (base) | 63.7 | 52.5 |
+ | + Vanilla SFT | 84.4 | 85.0 |
+ | + Think SFT (distillation) | 88.2 | 88.6 |
+ | + Targeted Synthetic Data | 88.9 | 89.3 |
+ | + Top-3 Model Soup | 88.8 | 89.2 |
+ | + FP8 Quantization | 88.9 | 89.3 |

  ## Input
  **Input Type(s)**: Text <br>
  **Input Format(s)**: String <br>
  **Input Parameters**: One-Dimensional (1D): Sequences <br>
+ **Other Properties Related to Input**: Context length up to 32K. Supported language: English.

  ## Output
  **Output Type(s)**: Text Json <br>