dabbledabble-IND-da-air Subh775 commited on
Commit
3168689
·
0 Parent(s):

Duplicate from Subh775/Threat-Detection-RFDETR

Browse files

Co-authored-by: Subhansh Malviya <Subh775@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ metrics_plot.png filter=lfs diff=lfs merge=lfs -text
37
+ class_distribution.png filter=lfs diff=lfs merge=lfs -text
38
+ sample_images_annotated.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ base_model:
6
+ - qualcomm/RF-DETR
7
+ pipeline_tag: object-detection
8
+ tags:
9
+ - surveillance
10
+ - Threat_detection
11
+ ---
12
+
13
+ # RF-DETR based Threat Detection Model
14
+
15
+ <a href="https://opensource.org/licenses/MIT">
16
+ <img src="https://img.shields.io/badge/License-MIT-yellow.svg" alt="License">
17
+ </a>
18
+ <a href="https://github.com/roboflow/rf-detr">
19
+ <img src="https://img.shields.io/badge/RF--DETR-Nano-purple?logo=roboflow&logoColor=white" alt="Model">
20
+ </a>
21
+ <a href="#performance-metrics">
22
+ <img src="https://img.shields.io/badge/mAP%4050-84.8%25-darkgreen?style=flat" alt="mAP">
23
+ </a>
24
+ <a href="https://github.com/subh-775/Threat_Detection_YOLO-vs-RF-DETR">
25
+ <img src="https://img.shields.io/badge/-code-black?logo=github" alt="Code">
26
+ </a>
27
+
28
+
29
+ ## Transformers for Object Detection
30
+
31
+ The paradigm has shifted! While CNNs traditionally dominated object detection with faster inference times, **RF-DETR** (Roboflow's Detection Transformer) has revolutionized the field. This transformer-based architecture not only **outperforms CNNs** in accuracy but also delivers **faster inference** for real-time applications.
32
+
33
+ This repository contains a **fine-tuned RF-DETR Nano model** specifically trained for **threat detection**, capable of identifying four critical threat categories with high precision and speed.
34
+
35
+ ## Predicted Results
36
+ ![predictions](https://cdn-uploads.huggingface.co/production/uploads/66c6048d0bf40704e4159a23/MDRT7LUt1RQE60CGW8to4.jpeg)
37
+
38
+ ### Video Inferencing
39
+ <video muted autoplay loop controls src="https://cdn-uploads.huggingface.co/production/uploads/66c6048d0bf40704e4159a23/5Kt3KghZaanzOVaVB6JS9.mp4" width=800></video>
40
+
41
+
42
+ ## Model Overview
43
+
44
+ **RF-DETR Threat Detection** is a specialized computer vision model designed for security and surveillance applications. Built on Roboflow's cutting-edge RF-DETR architecture, this model can accurately detect and classify potential threats in real-time scenarios.
45
+
46
+ The threat categories are as:
47
+
48
+ | Class ID | Threat Type | Description |
49
+ |----------|-------------|-------------|
50
+ | 1 | **Gun** | Any type of firearm weapon including pistols, rifles, and other firearms |
51
+ | 2 | **Explosive** | Fire, explosion scenarios, and explosive devices |
52
+ | 3 | **Grenade** | Hand grenades and similar explosive devices |
53
+ | 4 | **Knife** | Bladed weapons including knives, daggers, and sharp objects |
54
+
55
+ ## Training Dataset
56
+
57
+ Our custom threat detection dataset was meticulously curated and annotated to ensure robust model performance across diverse scenarios.
58
+
59
+ ### Class Distribution
60
+ ![class_distribution](https://cdn-uploads.huggingface.co/production/uploads/66c6048d0bf40704e4159a23/5t7k-SJfuZWXJTek_RPWh.png)
61
+
62
+ ### Sample Annotations (Actual)
63
+ ![sample_images_annotated](https://cdn-uploads.huggingface.co/production/uploads/66c6048d0bf40704e4159a23/Mf65kxTEwfq9HPMlzwO5y.png)
64
+
65
+ The model is trained to detect threats across various scales, from small concealed weapons to larger explosive devices.
66
+
67
+ ## Performance Metrics
68
+
69
+ ### Training Performance
70
+ ![Training Metrics](metrics_plot.png)
71
+
72
+ The training process demonstrates excellent convergence with:
73
+ - **Consistent loss reduction** over 50 epochs
74
+ - **Stable validation performance** indicating good generalization
75
+ - **Balanced precision and recall** across all threat categories
76
+
77
+ ### Validation Results
78
+
79
+ | Metric | Gun | Explosive | Grenade | Knife | **Overall** |
80
+ |--------|-----|-----------|---------|-------|-------------|
81
+ | **mAP@50:95** | 62.3% | 47.2% | 80.5% | 54.4% | **61.1%** |
82
+ | **mAP@50** | 90.1% | 69.6% | 93.7% | 85.8% | **84.8%** |
83
+ | **Precision** | 92.4% | 54.6% | 97.2% | 91.1% | **83.8%** |
84
+ | **Recall** | 85.0% | 85.0% | 85.0% | 85.0% | **85.0%** |
85
+
86
+ ### Test Results
87
+
88
+ | Metric | Gun | Explosive | Grenade | Knife | **Overall** |
89
+ |--------|-----|-----------|---------|-------|-------------|
90
+ | **mAP@50:95** | 65.3% | 35.7% | 83.2% | 49.8% | **58.5%** |
91
+ | **mAP@50** | 93.1% | 60.5% | 91.1% | 79.7% | **81.1%** |
92
+ | **Precision** | 96.7% | 49.7% | 93.1% | 86.5% | **81.5%** |
93
+ | **Recall** | 83.0% | 83.0% | 83.0% | 83.0% | **83.0%** |
94
+
95
+ ### Key Performance Highlights
96
+
97
+ - **84.8% mAP@50** on validation set
98
+ - **Fast inference** with RF-DETR Nano architecture
99
+ - **Excellent precision** for Gun (96.7%) and Grenade (93.1%) detection
100
+ - **Consistent recall** of 83-85% across all threat categories
101
+ - **Robust generalization** from validation to test performance
102
+
103
+ ## Model Architecture
104
+
105
+ - **Base Architecture**: RF-DETR Nano
106
+ - **Input Resolution**: 640×640 pixels
107
+ - **Backbone**: Optimized transformer encoder
108
+ - **Detection Head**: Custom 4-class threat detection
109
+ - **Inference Speed**: ~50ms per image (GPU)
110
+ - **Model Size**: Lightweight for edge deployment
111
+
112
+ ## Training Details
113
+
114
+ ### Training Configuration
115
+ - **Epochs**: 50
116
+ - **Batch Size**: Optimized for available GPU memory
117
+ - **Optimizer**: AdamW with learning rate scheduling
118
+ - **Data Augmentation**: Advanced augmentation pipeline for robust training
119
+ - **Loss Function**: Multi-scale detection loss with class balancing
120
+
121
+ ### Training Strategy
122
+ 1. **Progressive Training**: Started with lower resolution, gradually increased
123
+ 2. **Class Balancing**: Weighted loss to handle class imbalance
124
+ 3. **Data Augmentation**: Extensive augmentation to improve generalization
125
+ 4. **Early Stopping**: Monitored validation mAP to prevent overfitting
126
+
127
+ ## Model Files
128
+
129
+ - `checkpoint_best_total.pth` - Main model weights
130
+
131
+ ### Inference Instructions
132
+
133
+ ```python
134
+ pip install -q rfdetr==1.2.1 supervision==0.26.1
135
+ ```
136
+ - You can use: [video_processing.py](https://huggingface.co/Subh775/Threat-Detection-RFDETR/blob/main/video_processing.py) to process large videos
137
+
138
+ - Below is the script to process a single image
139
+
140
+ ```python
141
+ import numpy as np
142
+ import supervision as sv
143
+ import torch
144
+ import requests
145
+ from PIL import Image
146
+ import os
147
+
148
+ from rfdetr import RFDETRNano
149
+
150
+ THREAT_CLASSES = {
151
+ 1: "Gun",
152
+ 2: "Explosive",
153
+ 3: "Grenade",
154
+ 4: "Knife"
155
+ }
156
+
157
+ image = Image.open("Path_to_image")
158
+
159
+ # pre-trained weights
160
+ weights_url = "https://huggingface.co/Subh775/Threat-Detection-RFDETR/resolve/main/checkpoint_best_total.pth"
161
+ weights_filename = "checkpoint_best_total.pth"
162
+
163
+ # Download weights if not already present
164
+ if not os.path.exists(weights_filename):
165
+ print(f"Downloading weights from {weights_url}")
166
+ response = requests.get(weights_url, stream=True)
167
+ response.raise_for_status()
168
+ with open(weights_filename, 'wb') as f:
169
+ for chunk in response.iter_content(chunk_size=8192):
170
+ f.write(chunk)
171
+ print("Download complete.")
172
+
173
+ model = RFDETRNano(resolution=640, pretrain_weights=weights_filename)
174
+ model.optimize_for_inference()
175
+
176
+ detections = model.predict(image, threshold=0.5)
177
+
178
+ color = sv.ColorPalette.from_hex([
179
+ "#1E90FF", "#32CD32", "#FF0000", "#FF8C00"
180
+ ])
181
+
182
+ text_scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
183
+ thickness = sv.calculate_optimal_line_thickness(resolution_wh=image.size)
184
+
185
+ bbox_annotator = sv.BoxAnnotator(color=color, thickness=thickness)
186
+ label_annotator = sv.LabelAnnotator(
187
+ color=color,
188
+ text_color=sv.Color.BLACK,
189
+ text_scale=text_scale,
190
+ smart_position=True
191
+ )
192
+
193
+ labels = []
194
+ for class_id, confidence in zip(detections.class_id, detections.confidence):
195
+ class_name = THREAT_CLASSES.get(class_id, f"unknown_class_{class_id}")
196
+ labels.append(f"{class_name} {confidence:.2f}")
197
+
198
+ annotated_image = image.copy()
199
+ annotated_image = bbox_annotator.annotate(annotated_image, detections)
200
+ annotated_image = label_annotator.annotate(annotated_image, detections, labels)
201
+ annotated_image.thumbnail((800, 800))
202
+ annotated_image
203
+ ```
204
+
205
+ ## Acknowledgments
206
+
207
+ - **Roboflow** for the RF-DETR architecture
208
+ - **Hugging Face** for model hosting and distribution
209
+ - **PyTorch** ecosystem for deep learning framework
210
+ - **Supervision** library for computer vision utilities
211
+
212
+ **Disclaimer**: This model is designed for research purpose only. It's predictions cannot be taken into account for deployment right now.
checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15ce8511d2a3b5bb26c1e1cc8e16ecf11c6c3b6cc8e47fb4a7a56bb951d6c1da
3
+ size 483371770
checkpoint0009.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4aab3f79814bdb5706a1307a68a3b3f78f9bf22afbe34b62e5d26b456ff8bf1
3
+ size 483381074
checkpoint0019.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:729c30ee4cfedd0bd8ebc1d5d000abdbf8191f27522560b7777dd0aa9faf79da
3
+ size 483381074
checkpoint0029.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b351c26d6c238fa5c77dfae8bd03379bfdfc0e68d98e420d6b433f4a7a4f03b
3
+ size 483381074
checkpoint0039.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5d4fdbe25e9d7667040d62a52a38fed4797d1915e5e9df986ff80fc5d37802d
3
+ size 483381074
checkpoint0049.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2494d99845cff1db236aa58be866d2748887ffefeaa17534ba48e25f820df6b3
3
+ size 483381074
checkpoint_best_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:955f076abe2011face8d84d34af1f003feec3455c22fa953250270291df082d7
3
+ size 362564037
checkpoint_best_regular.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9bafd5495a9a583d8a70a9ad66231bb1bf1665fc482f293dc56838cb01cad7fd
3
+ size 362571545
checkpoint_best_total.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:481932df97ca0e3cbe73756a11a67dde6a617c548048baf67e29a9d640ed4a81
3
+ size 120825206
eval/000.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b734fccb5a7ff64331ceaae8d1bedfebf2ca7cf3c930868988168b883fb9a2d
3
+ size 1196560
eval/latest.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5fb5a3f509742430db35d4ca0c0c3afea92e9f33e7904b0eaf37f4b170f7c75b
3
+ size 1193116
events.out.tfevents.1759683351.7100dacbcb24.36.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4392706209d595ace7553f208c3cecfeb915a9ae816726b8cedc30454680e8c7
3
+ size 21772
log.txt ADDED
The diff for this file is too large to render. See raw diff
 
metrics_plot.png ADDED

Git LFS Details

  • SHA256: d8ce23e63dcc594b1c38f7f8bf4a4e191b102bb59b13f49344ccc26d750fed62
  • Pointer size: 131 Bytes
  • Size of remote file: 238 kB
results.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"class_map": {"valid": [{"class": "Gun", "map@50:95": 0.622645348724507, "map@50": 0.9009196450890545, "precision": 0.9239130434782609, "recall": 0.85}, {"class": "explosion", "map@50:95": 0.4717619370767294, "map@50": 0.6955606171466157, "precision": 0.5463258785942492, "recall": 0.85}, {"class": "grenade", "map@50:95": 0.8045233282168012, "map@50": 0.9368179747971453, "precision": 0.9716312056737588, "recall": 0.85}, {"class": "knife", "map@50:95": 0.5443087767779382, "map@50": 0.8584450143594061, "precision": 0.9111111111111111, "recall": 0.85}, {"class": "all", "map@50:95": 0.6108098476989939, "map@50": 0.8479358128480553, "precision": 0.838245309714345, "recall": 0.85}], "test": [{"class": "Gun", "map@50:95": 0.6530858403662217, "map@50": 0.9310724801164318, "precision": 0.9672131147540983, "recall": 0.8300000000000001}, {"class": "explosion", "map@50:95": 0.3571121332596689, "map@50": 0.605062635831805, "precision": 0.4966887417218543, "recall": 0.8300000000000001}, {"class": "grenade", "map@50:95": 0.8318189072306177, "map@50": 0.9109914811178564, "precision": 0.9305555555555556, "recall": 0.8300000000000001}, {"class": "knife", "map@50:95": 0.49779614891880253, "map@50": 0.796572352660021, "precision": 0.8653846153846154, "recall": 0.8300000000000001}, {"class": "all", "map@50:95": 0.5849532574438278, "map@50": 0.8109247374315286, "precision": 0.814960506854031, "recall": 0.8300000000000001}]}, "map": 0.8479358128480553, "precision": 0.838245309714345, "recall": 0.85}
video_processing.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pip install -q rfdetr==1.2.1 supervision==0.26.1
2
+
3
+ # RF-DETR video processing for threat detection.
4
+ # Inference time depends on frame resolution (e.g., ~50 ms/frame on GPU for 640×640).
5
+
6
+
7
+ import numpy as np
8
+ import supervision as sv
9
+ import torch
10
+ import requests
11
+ from PIL import Image
12
+ import os
13
+ import cv2
14
+ from tqdm import tqdm
15
+ import time
16
+
17
+ from rfdetr import RFDETRNano
18
+
19
+ THREAT_CLASSES = {
20
+ 1: "Gun",
21
+ 2: "Explosive",
22
+ 3: "Grenade",
23
+ 4: "Knife"
24
+ }
25
+
26
+ # Enable GPU if available
27
+ if torch.cuda.is_available():
28
+ print(f"GPU: {torch.cuda.get_device_name(0)}")
29
+ # print(f"CUDA Version: {torch.version.cuda}")
30
+ # print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
31
+
32
+ # Optimize for batch processing
33
+ torch.backends.cudnn.benchmark = True
34
+ torch.backends.cudnn.deterministic = False
35
+ else:
36
+ print("CUDA not available, using CPU")
37
+
38
+ # Configuration
39
+ INPUT_VIDEO = "test_video.mp4"
40
+
41
+ base, ext = os.path.splitext(INPUT_VIDEO)
42
+ OUTPUT_VIDEO = f"{base}_detr{ext}"
43
+
44
+ THRESHOLD = 0.5
45
+ BATCH_SIZE = 32
46
+
47
+ # Auto-adjust batch size based on GPU memory
48
+ if torch.cuda.is_available():
49
+ gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
50
+
51
+ print(f"Using batch size: {BATCH_SIZE}")
52
+
53
+ # Download weights
54
+ weights_url = "https://huggingface.co/Subh775/Threat-Detection-RFDETR/resolve/main/checkpoint_best_total.pth"
55
+ weights_filename = "checkpoint_best_total.pth"
56
+
57
+ if not os.path.exists(weights_filename):
58
+ print(f"Downloading weights from {weights_url}")
59
+ response = requests.get(weights_url, stream=True)
60
+ response.raise_for_status()
61
+ with open(weights_filename, 'wb') as f:
62
+ for chunk in response.iter_content(chunk_size=8192):
63
+ f.write(chunk)
64
+ print("Download complete.")
65
+
66
+ print("Loading model...")
67
+ model = RFDETRNano(resolution=640, pretrain_weights=weights_filename)
68
+ model.optimize_for_inference()
69
+
70
+ # Setup annotators
71
+ color = sv.ColorPalette.from_hex([
72
+ "#1E90FF", "#32CD32", "#FF0000", "#FF8C00"
73
+ ])
74
+
75
+ bbox_annotator = sv.BoxAnnotator(color=color, thickness=3)
76
+ label_annotator = sv.LabelAnnotator(
77
+ color=color,
78
+ text_color=sv.Color.BLACK,
79
+ text_scale=1.0,
80
+ text_thickness=2,
81
+ smart_position=True
82
+ )
83
+
84
+ def process_frame_batch(frames):
85
+ """Process a batch of frames for better GPU utilization"""
86
+ batch_results = []
87
+
88
+ # Convert all frames to PIL images
89
+ pil_images = []
90
+ for frame in frames:
91
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
92
+ pil_image = Image.fromarray(rgb_frame)
93
+ pil_images.append(pil_image)
94
+
95
+ # Process each image in the batch (RF-DETR processes them efficiently)
96
+ batch_detections = []
97
+ for pil_image in pil_images:
98
+ detections = model.predict(pil_image, threshold=THRESHOLD)
99
+ batch_detections.append(detections)
100
+
101
+ # Annotate all images in the batch
102
+ annotated_frames = []
103
+ for pil_image, detections in zip(pil_images, batch_detections):
104
+ # Create labels
105
+ labels = []
106
+ for class_id, confidence in zip(detections.class_id, detections.confidence):
107
+ class_name = THREAT_CLASSES.get(class_id, f"unknown_class_{class_id}")
108
+ labels.append(f"{class_name} {confidence:.2f}")
109
+
110
+ # Annotate
111
+ annotated_pil = pil_image.copy()
112
+ annotated_pil = bbox_annotator.annotate(annotated_pil, detections)
113
+ annotated_pil = label_annotator.annotate(annotated_pil, detections, labels)
114
+
115
+ # Convert back to BGR
116
+ annotated_frame = cv2.cvtColor(np.array(annotated_pil), cv2.COLOR_RGB2BGR)
117
+ annotated_frames.append(annotated_frame)
118
+
119
+ return annotated_frames, batch_detections
120
+
121
+ # Open video
122
+ cap = cv2.VideoCapture(INPUT_VIDEO)
123
+ if not cap.isOpened():
124
+ print(f"Error: Could not open video file {INPUT_VIDEO}")
125
+ exit()
126
+
127
+ # Get video properties
128
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
129
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
130
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
131
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
132
+
133
+ print(f"Video: {width}x{height}, {fps} FPS, {total_frames} frames")
134
+ print(f"Processing in batches of {BATCH_SIZE} frames")
135
+
136
+ # Setup video writer
137
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
138
+ out = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (width, height))
139
+
140
+ # Batch processing
141
+ print("Processing video with batch inference...")
142
+ frame_buffer = []
143
+ total_detections = 0
144
+ processed_frames = 0
145
+ processing_times = []
146
+
147
+ with tqdm(total=total_frames, desc="Batch processing") as pbar:
148
+ while True:
149
+ ret, frame = cap.read()
150
+ if not ret:
151
+ # Process remaining frames in buffer
152
+ if frame_buffer:
153
+ start_time = time.time()
154
+ annotated_frames, batch_detections = process_frame_batch(frame_buffer)
155
+ processing_time = time.time() - start_time
156
+ processing_times.append(processing_time)
157
+
158
+ # Write remaining frames
159
+ for annotated_frame, detections in zip(annotated_frames, batch_detections):
160
+ out.write(annotated_frame)
161
+ total_detections += len(detections)
162
+
163
+ processed_frames += len(frame_buffer)
164
+ pbar.update(len(frame_buffer))
165
+ break
166
+
167
+ # Add frame to buffer
168
+ frame_buffer.append(frame)
169
+
170
+ # Process when buffer is full
171
+ if len(frame_buffer) >= BATCH_SIZE:
172
+ start_time = time.time()
173
+
174
+ # Process batch
175
+ annotated_frames, batch_detections = process_frame_batch(frame_buffer)
176
+
177
+ processing_time = time.time() - start_time
178
+ processing_times.append(processing_time)
179
+
180
+ # Write frames
181
+ batch_threats = 0
182
+ for annotated_frame, detections in zip(annotated_frames, batch_detections):
183
+ out.write(annotated_frame)
184
+ batch_threats += len(detections)
185
+ total_detections += len(detections)
186
+
187
+ processed_frames += len(frame_buffer)
188
+
189
+ # Update progress
190
+ batch_fps = len(frame_buffer) / processing_time if processing_time > 0 else 0
191
+ pbar.set_postfix({
192
+ 'Batch FPS': f"{batch_fps:.1f}",
193
+ 'Threats': batch_threats,
194
+ 'Total': total_detections
195
+ })
196
+ pbar.update(len(frame_buffer))
197
+
198
+ # Clear buffer
199
+ frame_buffer = []
200
+
201
+ # Clear GPU cache every 10 batches
202
+ if torch.cuda.is_available() and processed_frames % (BATCH_SIZE * 10) == 0:
203
+ torch.cuda.empty_cache()
204
+
205
+ # Cleanup
206
+ cap.release()
207
+ out.release()
208
+
209
+ if torch.cuda.is_available():
210
+ torch.cuda.empty_cache()
211
+
212
+ # Performance summary
213
+ total_time = sum(processing_times)
214
+ avg_fps = processed_frames / total_time if total_time > 0 else 0
215
+ speedup = avg_fps / fps if fps > 0 else 0
216
+
217
+ print(f"Output: {OUTPUT_VIDEO}")
218
+ print(f"Stats:")
219
+ print(f" • Processed: {processed_frames} frames")
220
+ print(f" • Detections: {total_detections}")
221
+ print(f" • Batch size: {BATCH_SIZE}")
222
+ print(f" • Average speed: {avg_fps:.1f} FPS")
223
+ print(f" • Speedup: {speedup:.1f}x real-time")
224
+ print(f" • Processing time: {total_time:.1f}s")