Commit ·
3168689
0
Parent(s):
Duplicate from Subh775/Threat-Detection-RFDETR
Browse filesCo-authored-by: Subhansh Malviya <Subh775@users.noreply.huggingface.co>
- .gitattributes +38 -0
- README.md +212 -0
- checkpoint.pth +3 -0
- checkpoint0009.pth +3 -0
- checkpoint0019.pth +3 -0
- checkpoint0029.pth +3 -0
- checkpoint0039.pth +3 -0
- checkpoint0049.pth +3 -0
- checkpoint_best_ema.pth +3 -0
- checkpoint_best_regular.pth +3 -0
- checkpoint_best_total.pth +3 -0
- eval/000.pth +3 -0
- eval/latest.pth +3 -0
- events.out.tfevents.1759683351.7100dacbcb24.36.0 +3 -0
- log.txt +0 -0
- metrics_plot.png +3 -0
- results.json +1 -0
- video_processing.py +224 -0
.gitattributes
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
metrics_plot.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
class_distribution.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
sample_images_annotated.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
base_model:
|
| 6 |
+
- qualcomm/RF-DETR
|
| 7 |
+
pipeline_tag: object-detection
|
| 8 |
+
tags:
|
| 9 |
+
- surveillance
|
| 10 |
+
- Threat_detection
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# RF-DETR based Threat Detection Model
|
| 14 |
+
|
| 15 |
+
<a href="https://opensource.org/licenses/MIT">
|
| 16 |
+
<img src="https://img.shields.io/badge/License-MIT-yellow.svg" alt="License">
|
| 17 |
+
</a>
|
| 18 |
+
<a href="https://github.com/roboflow/rf-detr">
|
| 19 |
+
<img src="https://img.shields.io/badge/RF--DETR-Nano-purple?logo=roboflow&logoColor=white" alt="Model">
|
| 20 |
+
</a>
|
| 21 |
+
<a href="#performance-metrics">
|
| 22 |
+
<img src="https://img.shields.io/badge/mAP%4050-84.8%25-darkgreen?style=flat" alt="mAP">
|
| 23 |
+
</a>
|
| 24 |
+
<a href="https://github.com/subh-775/Threat_Detection_YOLO-vs-RF-DETR">
|
| 25 |
+
<img src="https://img.shields.io/badge/-code-black?logo=github" alt="Code">
|
| 26 |
+
</a>
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
## Transformers for Object Detection
|
| 30 |
+
|
| 31 |
+
The paradigm has shifted! While CNNs traditionally dominated object detection with faster inference times, **RF-DETR** (Roboflow's Detection Transformer) has revolutionized the field. This transformer-based architecture not only **outperforms CNNs** in accuracy but also delivers **faster inference** for real-time applications.
|
| 32 |
+
|
| 33 |
+
This repository contains a **fine-tuned RF-DETR Nano model** specifically trained for **threat detection**, capable of identifying four critical threat categories with high precision and speed.
|
| 34 |
+
|
| 35 |
+
## Predicted Results
|
| 36 |
+

|
| 37 |
+
|
| 38 |
+
### Video Inferencing
|
| 39 |
+
<video muted autoplay loop controls src="https://cdn-uploads.huggingface.co/production/uploads/66c6048d0bf40704e4159a23/5Kt3KghZaanzOVaVB6JS9.mp4" width=800></video>
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
## Model Overview
|
| 43 |
+
|
| 44 |
+
**RF-DETR Threat Detection** is a specialized computer vision model designed for security and surveillance applications. Built on Roboflow's cutting-edge RF-DETR architecture, this model can accurately detect and classify potential threats in real-time scenarios.
|
| 45 |
+
|
| 46 |
+
The threat categories are as follows:
|
| 47 |
+
|
| 48 |
+
| Class ID | Threat Type | Description |
|
| 49 |
+
|----------|-------------|-------------|
|
| 50 |
+
| 1 | **Gun** | Any type of firearm weapon including pistols, rifles, and other firearms |
|
| 51 |
+
| 2 | **Explosive** | Fire, explosion scenarios, and explosive devices |
|
| 52 |
+
| 3 | **Grenade** | Hand grenades and similar explosive devices |
|
| 53 |
+
| 4 | **Knife** | Bladed weapons including knives, daggers, and sharp objects |
|
| 54 |
+
|
| 55 |
+
## Training Dataset
|
| 56 |
+
|
| 57 |
+
Our custom threat detection dataset was meticulously curated and annotated to ensure robust model performance across diverse scenarios.
|
| 58 |
+
|
| 59 |
+
### Class Distribution
|
| 60 |
+

|
| 61 |
+
|
| 62 |
+
### Sample Annotations (Actual)
|
| 63 |
+

|
| 64 |
+
|
| 65 |
+
The model is trained to detect threats across various scales, from small concealed weapons to larger explosive devices.
|
| 66 |
+
|
| 67 |
+
## Performance Metrics
|
| 68 |
+
|
| 69 |
+
### Training Performance
|
| 70 |
+

|
| 71 |
+
|
| 72 |
+
The training process demonstrates excellent convergence with:
|
| 73 |
+
- **Consistent loss reduction** over 50 epochs
|
| 74 |
+
- **Stable validation performance** indicating good generalization
|
| 75 |
+
- **Balanced precision and recall** across all threat categories
|
| 76 |
+
|
| 77 |
+
### Validation Results
|
| 78 |
+
|
| 79 |
+
| Metric | Gun | Explosive | Grenade | Knife | **Overall** |
|
| 80 |
+
|--------|-----|-----------|---------|-------|-------------|
|
| 81 |
+
| **mAP@50:95** | 62.3% | 47.2% | 80.5% | 54.4% | **61.1%** |
|
| 82 |
+
| **mAP@50** | 90.1% | 69.6% | 93.7% | 85.8% | **84.8%** |
|
| 83 |
+
| **Precision** | 92.4% | 54.6% | 97.2% | 91.1% | **83.8%** |
|
| 84 |
+
| **Recall** | 85.0% | 85.0% | 85.0% | 85.0% | **85.0%** |
|
| 85 |
+
|
| 86 |
+
### Test Results
|
| 87 |
+
|
| 88 |
+
| Metric | Gun | Explosive | Grenade | Knife | **Overall** |
|
| 89 |
+
|--------|-----|-----------|---------|-------|-------------|
|
| 90 |
+
| **mAP@50:95** | 65.3% | 35.7% | 83.2% | 49.8% | **58.5%** |
|
| 91 |
+
| **mAP@50** | 93.1% | 60.5% | 91.1% | 79.7% | **81.1%** |
|
| 92 |
+
| **Precision** | 96.7% | 49.7% | 93.1% | 86.5% | **81.5%** |
|
| 93 |
+
| **Recall** | 83.0% | 83.0% | 83.0% | 83.0% | **83.0%** |
|
| 94 |
+
|
| 95 |
+
### Key Performance Highlights
|
| 96 |
+
|
| 97 |
+
- **84.8% mAP@50** on validation set
|
| 98 |
+
- **Fast inference** with RF-DETR Nano architecture
|
| 99 |
+
- **Excellent precision** for Gun (96.7%) and Grenade (93.1%) detection on the test set
|
| 100 |
+
- **Consistent recall** of 83-85% across all threat categories
|
| 101 |
+
- **Robust generalization** from validation to test performance
|
| 102 |
+
|
| 103 |
+
## Model Architecture
|
| 104 |
+
|
| 105 |
+
- **Base Architecture**: RF-DETR Nano
|
| 106 |
+
- **Input Resolution**: 640×640 pixels
|
| 107 |
+
- **Backbone**: Optimized transformer encoder
|
| 108 |
+
- **Detection Head**: Custom 4-class threat detection
|
| 109 |
+
- **Inference Speed**: ~50ms per image (GPU)
|
| 110 |
+
- **Model Size**: Lightweight for edge deployment
|
| 111 |
+
|
| 112 |
+
## Training Details
|
| 113 |
+
|
| 114 |
+
### Training Configuration
|
| 115 |
+
- **Epochs**: 50
|
| 116 |
+
- **Batch Size**: Optimized for available GPU memory
|
| 117 |
+
- **Optimizer**: AdamW with learning rate scheduling
|
| 118 |
+
- **Data Augmentation**: Advanced augmentation pipeline for robust training
|
| 119 |
+
- **Loss Function**: Multi-scale detection loss with class balancing
|
| 120 |
+
|
| 121 |
+
### Training Strategy
|
| 122 |
+
1. **Progressive Training**: Started with lower resolution, gradually increased
|
| 123 |
+
2. **Class Balancing**: Weighted loss to handle class imbalance
|
| 124 |
+
3. **Data Augmentation**: Extensive augmentation to improve generalization
|
| 125 |
+
4. **Early Stopping**: Monitored validation mAP to prevent overfitting
|
| 126 |
+
|
| 127 |
+
## Model Files
|
| 128 |
+
|
| 129 |
+
- `checkpoint_best_total.pth` - Main model weights
|
| 130 |
+
|
| 131 |
+
### Inference Instructions
|
| 132 |
+
|
| 133 |
+
```bash
|
| 134 |
+
pip install -q rfdetr==1.2.1 supervision==0.26.1
|
| 135 |
+
```
|
| 136 |
+
- You can use: [video_processing.py](https://huggingface.co/Subh775/Threat-Detection-RFDETR/blob/main/video_processing.py) to process large videos
|
| 137 |
+
|
| 138 |
+
- Below is the script to process a single image
|
| 139 |
+
|
| 140 |
+
```python
|
| 141 |
+
import numpy as np
|
| 142 |
+
import supervision as sv
|
| 143 |
+
import torch
|
| 144 |
+
import requests
|
| 145 |
+
from PIL import Image
|
| 146 |
+
import os
|
| 147 |
+
|
| 148 |
+
from rfdetr import RFDETRNano
|
| 149 |
+
|
| 150 |
+
THREAT_CLASSES = {
|
| 151 |
+
1: "Gun",
|
| 152 |
+
2: "Explosive",
|
| 153 |
+
3: "Grenade",
|
| 154 |
+
4: "Knife"
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
image = Image.open("Path_to_image")
|
| 158 |
+
|
| 159 |
+
# pre-trained weights
|
| 160 |
+
weights_url = "https://huggingface.co/Subh775/Threat-Detection-RFDETR/resolve/main/checkpoint_best_total.pth"
|
| 161 |
+
weights_filename = "checkpoint_best_total.pth"
|
| 162 |
+
|
| 163 |
+
# Download weights if not already present
|
| 164 |
+
if not os.path.exists(weights_filename):
|
| 165 |
+
print(f"Downloading weights from {weights_url}")
|
| 166 |
+
response = requests.get(weights_url, stream=True)
|
| 167 |
+
response.raise_for_status()
|
| 168 |
+
with open(weights_filename, 'wb') as f:
|
| 169 |
+
for chunk in response.iter_content(chunk_size=8192):
|
| 170 |
+
f.write(chunk)
|
| 171 |
+
print("Download complete.")
|
| 172 |
+
|
| 173 |
+
model = RFDETRNano(resolution=640, pretrain_weights=weights_filename)
|
| 174 |
+
model.optimize_for_inference()
|
| 175 |
+
|
| 176 |
+
detections = model.predict(image, threshold=0.5)
|
| 177 |
+
|
| 178 |
+
color = sv.ColorPalette.from_hex([
|
| 179 |
+
"#1E90FF", "#32CD32", "#FF0000", "#FF8C00"
|
| 180 |
+
])
|
| 181 |
+
|
| 182 |
+
text_scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
|
| 183 |
+
thickness = sv.calculate_optimal_line_thickness(resolution_wh=image.size)
|
| 184 |
+
|
| 185 |
+
bbox_annotator = sv.BoxAnnotator(color=color, thickness=thickness)
|
| 186 |
+
label_annotator = sv.LabelAnnotator(
|
| 187 |
+
color=color,
|
| 188 |
+
text_color=sv.Color.BLACK,
|
| 189 |
+
text_scale=text_scale,
|
| 190 |
+
smart_position=True
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
labels = []
|
| 194 |
+
for class_id, confidence in zip(detections.class_id, detections.confidence):
|
| 195 |
+
class_name = THREAT_CLASSES.get(class_id, f"unknown_class_{class_id}")
|
| 196 |
+
labels.append(f"{class_name} {confidence:.2f}")
|
| 197 |
+
|
| 198 |
+
annotated_image = image.copy()
|
| 199 |
+
annotated_image = bbox_annotator.annotate(annotated_image, detections)
|
| 200 |
+
annotated_image = label_annotator.annotate(annotated_image, detections, labels)
|
| 201 |
+
annotated_image.thumbnail((800, 800))
|
| 202 |
+
annotated_image
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
## Acknowledgments
|
| 206 |
+
|
| 207 |
+
- **Roboflow** for the RF-DETR architecture
|
| 208 |
+
- **Hugging Face** for model hosting and distribution
|
| 209 |
+
- **PyTorch** ecosystem for deep learning framework
|
| 210 |
+
- **Supervision** library for computer vision utilities
|
| 211 |
+
|
| 212 |
+
**Disclaimer**: This model is intended for research purposes only. Its predictions should not be relied upon in production deployments at this time.
|
checkpoint.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:15ce8511d2a3b5bb26c1e1cc8e16ecf11c6c3b6cc8e47fb4a7a56bb951d6c1da
|
| 3 |
+
size 483371770
|
checkpoint0009.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f4aab3f79814bdb5706a1307a68a3b3f78f9bf22afbe34b62e5d26b456ff8bf1
|
| 3 |
+
size 483381074
|
checkpoint0019.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:729c30ee4cfedd0bd8ebc1d5d000abdbf8191f27522560b7777dd0aa9faf79da
|
| 3 |
+
size 483381074
|
checkpoint0029.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8b351c26d6c238fa5c77dfae8bd03379bfdfc0e68d98e420d6b433f4a7a4f03b
|
| 3 |
+
size 483381074
|
checkpoint0039.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a5d4fdbe25e9d7667040d62a52a38fed4797d1915e5e9df986ff80fc5d37802d
|
| 3 |
+
size 483381074
|
checkpoint0049.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2494d99845cff1db236aa58be866d2748887ffefeaa17534ba48e25f820df6b3
|
| 3 |
+
size 483381074
|
checkpoint_best_ema.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:955f076abe2011face8d84d34af1f003feec3455c22fa953250270291df082d7
|
| 3 |
+
size 362564037
|
checkpoint_best_regular.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9bafd5495a9a583d8a70a9ad66231bb1bf1665fc482f293dc56838cb01cad7fd
|
| 3 |
+
size 362571545
|
checkpoint_best_total.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:481932df97ca0e3cbe73756a11a67dde6a617c548048baf67e29a9d640ed4a81
|
| 3 |
+
size 120825206
|
eval/000.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9b734fccb5a7ff64331ceaae8d1bedfebf2ca7cf3c930868988168b883fb9a2d
|
| 3 |
+
size 1196560
|
eval/latest.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5fb5a3f509742430db35d4ca0c0c3afea92e9f33e7904b0eaf37f4b170f7c75b
|
| 3 |
+
size 1193116
|
events.out.tfevents.1759683351.7100dacbcb24.36.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4392706209d595ace7553f208c3cecfeb915a9ae816726b8cedc30454680e8c7
|
| 3 |
+
size 21772
|
log.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
metrics_plot.png
ADDED
|
Git LFS Details
|
results.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"class_map": {"valid": [{"class": "Gun", "map@50:95": 0.622645348724507, "map@50": 0.9009196450890545, "precision": 0.9239130434782609, "recall": 0.85}, {"class": "explosion", "map@50:95": 0.4717619370767294, "map@50": 0.6955606171466157, "precision": 0.5463258785942492, "recall": 0.85}, {"class": "grenade", "map@50:95": 0.8045233282168012, "map@50": 0.9368179747971453, "precision": 0.9716312056737588, "recall": 0.85}, {"class": "knife", "map@50:95": 0.5443087767779382, "map@50": 0.8584450143594061, "precision": 0.9111111111111111, "recall": 0.85}, {"class": "all", "map@50:95": 0.6108098476989939, "map@50": 0.8479358128480553, "precision": 0.838245309714345, "recall": 0.85}], "test": [{"class": "Gun", "map@50:95": 0.6530858403662217, "map@50": 0.9310724801164318, "precision": 0.9672131147540983, "recall": 0.8300000000000001}, {"class": "explosion", "map@50:95": 0.3571121332596689, "map@50": 0.605062635831805, "precision": 0.4966887417218543, "recall": 0.8300000000000001}, {"class": "grenade", "map@50:95": 0.8318189072306177, "map@50": 0.9109914811178564, "precision": 0.9305555555555556, "recall": 0.8300000000000001}, {"class": "knife", "map@50:95": 0.49779614891880253, "map@50": 0.796572352660021, "precision": 0.8653846153846154, "recall": 0.8300000000000001}, {"class": "all", "map@50:95": 0.5849532574438278, "map@50": 0.8109247374315286, "precision": 0.814960506854031, "recall": 0.8300000000000001}]}, "map": 0.8479358128480553, "precision": 0.838245309714345, "recall": 0.85}
|
video_processing.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pip install -q rfdetr==1.2.1 supervision==0.26.1
|
| 2 |
+
|
| 3 |
+
# RF-DETR video processing for threat detection.
|
| 4 |
+
# Inference time depends on frame resolution (e.g., ~50 ms/frame on GPU for 640×640).
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import supervision as sv
|
| 9 |
+
import torch
|
| 10 |
+
import requests
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import os
|
| 13 |
+
import cv2
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
from rfdetr import RFDETRNano
|
| 18 |
+
|
| 19 |
+
# Maps the class ids returned by the detector to human-readable labels.
THREAT_CLASSES = {
    1: "Gun",
    2: "Explosive",
    3: "Grenade",
    4: "Knife",
}

# Report the compute device and tune cuDNN when a GPU is present.
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Favour throughput over reproducibility: cuDNN autotuning on,
    # deterministic mode off.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
else:
    print("CUDA not available, using CPU")

# Configuration
INPUT_VIDEO = "test_video.mp4"

# The output is written next to the input, with "_detr" inserted before the
# extension (e.g. "test_video.mp4" -> "test_video_detr.mp4").
base, ext = os.path.splitext(INPUT_VIDEO)
OUTPUT_VIDEO = f"{base}_detr{ext}"

THRESHOLD = 0.5  # confidence threshold passed to model.predict

# Number of frames buffered before they are processed and written out.
# NOTE: this value is fixed; it is not adjusted to the available GPU memory.
BATCH_SIZE = 32

print(f"Using batch size: {BATCH_SIZE}")
|
| 52 |
+
|
| 53 |
+
# Download weights
|
| 54 |
+
weights_url = "https://huggingface.co/Subh775/Threat-Detection-RFDETR/resolve/main/checkpoint_best_total.pth"
weights_filename = "checkpoint_best_total.pth"

if not os.path.exists(weights_filename):
    print(f"Downloading weights from {weights_url}")

    # Stream into a temporary ".part" file and rename it only after the whole
    # body has been written. Otherwise an interrupted download would leave a
    # truncated checkpoint that the os.path.exists() check above accepts on
    # the next run.
    partial_filename = f"{weights_filename}.part"
    try:
        # The timeout prevents a stalled connection from hanging forever.
        with requests.get(weights_url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(partial_filename, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        os.replace(partial_filename, weights_filename)
    finally:
        # Only present here if the download failed part-way through.
        if os.path.exists(partial_filename):
            os.remove(partial_filename)
    print("Download complete.")

print("Loading model...")
model = RFDETRNano(resolution=640, pretrain_weights=weights_filename)
model.optimize_for_inference()

# Setup annotators: one shared four-colour palette for boxes and labels.
color = sv.ColorPalette.from_hex([
    "#1E90FF", "#32CD32", "#FF0000", "#FF8C00"
])

bbox_annotator = sv.BoxAnnotator(color=color, thickness=3)
label_annotator = sv.LabelAnnotator(
    color=color,
    text_color=sv.Color.BLACK,
    text_scale=1.0,
    text_thickness=2,
    smart_position=True,
)
|
| 83 |
+
|
| 84 |
+
def process_frame_batch(frames):
    """Detect threats in each frame and draw boxes and labels on it.

    Despite the name, inference is not batched: every frame is passed to
    ``model.predict`` individually. The "batch" is only the group of frames
    the caller buffered before calling this function.

    Args:
        frames: Iterable of images in OpenCV's BGR channel order (each frame
            is converted BGR -> RGB before prediction).

    Returns:
        A tuple ``(annotated_frames, batch_detections)``:
        ``annotated_frames`` is a list of annotated BGR arrays, one per input
        frame and in the same order; ``batch_detections`` is the list of
        objects returned by ``model.predict`` for those frames.
    """
    annotated_frames = []
    batch_detections = []

    for frame in frames:
        # Prediction and annotation work on an RGB PIL image.
        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        detections = model.predict(pil_image, threshold=THRESHOLD)
        batch_detections.append(detections)

        # One "<class name> <confidence>" label per detection; ids missing
        # from THREAT_CLASSES get an explicit "unknown_class_<id>" name.
        labels = []
        for class_id, confidence in zip(detections.class_id, detections.confidence):
            class_name = THREAT_CLASSES.get(class_id, f"unknown_class_{class_id}")
            labels.append(f"{class_name} {confidence:.2f}")

        annotated_pil = bbox_annotator.annotate(pil_image, detections)
        annotated_pil = label_annotator.annotate(annotated_pil, detections, labels)

        # Convert back to BGR so the frame can be written with OpenCV.
        annotated_frames.append(
            cv2.cvtColor(np.array(annotated_pil), cv2.COLOR_RGB2BGR)
        )

    return annotated_frames, batch_detections
|
| 120 |
+
|
| 121 |
+
# Open video
|
| 122 |
+
cap = cv2.VideoCapture(INPUT_VIDEO)
if not cap.isOpened():
    print(f"Error: Could not open video file {INPUT_VIDEO}")
    # Exit with a failure status so shells and pipelines can detect the error.
    raise SystemExit(1)

# Get video properties
# Keep the frame rate as a float: truncating e.g. 29.97 to 29 would change the
# duration of the written video relative to the source.
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

print(f"Video: {width}x{height}, {fps:g} FPS, {total_frames} frames")
print(f"Processing in batches of {BATCH_SIZE} frames")

# Setup video writer
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (width, height))
if not out.isOpened():
    # Fail before running inference rather than silently writing nothing.
    cap.release()
    print(f"Error: Could not open output video file {OUTPUT_VIDEO}")
    raise SystemExit(1)

# Batch processing
print("Processing video with batch inference...")
frame_buffer = []
total_detections = 0
processed_frames = 0
processing_times = []

try:
    # Show an open-ended progress bar when the frame count is not positive.
    with tqdm(total=total_frames if total_frames > 0 else None,
              desc="Batch processing") as pbar:
        while True:
            ret, frame = cap.read()
            if ret:
                frame_buffer.append(frame)

            # Flush when the buffer is full, or at end of stream when frames
            # are still pending.
            if frame_buffer and (not ret or len(frame_buffer) >= BATCH_SIZE):
                start_time = time.perf_counter()
                annotated_frames, batch_detections = process_frame_batch(frame_buffer)
                processing_time = time.perf_counter() - start_time
                processing_times.append(processing_time)

                # Write frames
                batch_threats = 0
                for annotated_frame, detections in zip(annotated_frames, batch_detections):
                    out.write(annotated_frame)
                    batch_threats += len(detections)
                total_detections += batch_threats
                processed_frames += len(frame_buffer)

                # Update progress
                batch_fps = len(frame_buffer) / processing_time if processing_time > 0 else 0
                pbar.set_postfix({
                    'Batch FPS': f"{batch_fps:.1f}",
                    'Threats': batch_threats,
                    'Total': total_detections
                })
                pbar.update(len(frame_buffer))

                # Clear buffer
                frame_buffer = []

                # Clear GPU cache every 10 batches
                if torch.cuda.is_available() and processed_frames % (BATCH_SIZE * 10) == 0:
                    torch.cuda.empty_cache()

            if not ret:
                break
finally:
    # Release the capture and writer even if processing raised, so the output
    # file is closed.
    cap.release()
    out.release()

if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Performance summary (timings cover process_frame_batch only, not
# decoding or writing)
total_time = sum(processing_times)
avg_fps = processed_frames / total_time if total_time > 0 else 0
speedup = avg_fps / fps if fps > 0 else 0

print(f"Output: {OUTPUT_VIDEO}")
print(f"Stats:")
print(f" • Processed: {processed_frames} frames")
print(f" • Detections: {total_detections}")
print(f" • Batch size: {BATCH_SIZE}")
print(f" • Average speed: {avg_fps:.1f} FPS")
print(f" • Speedup: {speedup:.1f}x real-time")
print(f" • Processing time: {total_time:.1f}s")
|