isha0110 commited on
Commit
aa08558
·
verified ·
1 Parent(s): 51ff5eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +336 -1
app.py CHANGED
@@ -1,4 +1,339 @@
1
- <div style='display: flex; gap: 8px; margin-top: 10px;'>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  <span style='background: {"#10b98120" if detected else "#f3f4f6"}; color: {"#10b981" if detected else "#9ca3af"}; padding: 4px 10px; border-radius: 6px; font-size: 12px; font-weight: 600;'>
3
  {("✓ DETECTED" if detected else "○ Not Detected")}
4
  </span>
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import AutoTokenizer, AutoModel
5
+ import numpy as np
6
+ import os
7
+ from typing import Dict, List, Tuple, Optional
8
+ import json
9
+ from datetime import datetime
10
+ from collections import defaultdict
11
+
12
+ print("🎭 Emotion Classifier Starting...")
13
+
14
+ # ========== CONFIG ==========
15
+ MODEL_NAME = "roberta-base"
16
+ EMOTIONS = ["anger", "fear", "joy", "sadness", "surprise"]
17
+ BEST_THRESHOLDS = [0.24722222, 0.61666667, 0.59722222, 0.44166667, 0.46111111]
18
+ MAX_LEN = 200
19
+ MODEL_PATH = "roberta.pth"
20
+
21
+ # Emotion metadata with richer information
22
+ EMOTION_META = {
23
+ "anger": {
24
+ "emoji": "😠", "color": "#ef4444", "light_color": "#fee2e2",
25
+ "gradient": "linear-gradient(135deg, #ef4444 0%, #dc2626 100%)",
26
+ "description": "Frustration, irritation, or rage",
27
+ "keywords": ["angry", "furious", "mad", "annoyed", "irritated"],
28
+ "intensity_labels": ["Mild irritation", "Moderate anger", "Strong anger", "Intense rage"]
29
+ },
30
+ "fear": {
31
+ "emoji": "😨", "color": "#8b5cf6", "light_color": "#ede9fe",
32
+ "gradient": "linear-gradient(135deg, #8b5cf6 0%, #7c3aed 100%)",
33
+ "description": "Anxiety, worry, or terror",
34
+ "keywords": ["scared", "afraid", "terrified", "anxious", "worried"],
35
+ "intensity_labels": ["Slight concern", "Moderate fear", "Strong fear", "Extreme terror"]
36
+ },
37
+ "joy": {
38
+ "emoji": "😊", "color": "#fbbf24", "light_color": "#fef3c7",
39
+ "gradient": "linear-gradient(135deg, #fbbf24 0%, #f59e0b 100%)",
40
+ "description": "Happiness, excitement, or delight",
41
+ "keywords": ["happy", "excited", "delighted", "joyful", "thrilled"],
42
+ "intensity_labels": ["Mild pleasure", "Moderate happiness", "Strong joy", "Intense euphoria"]
43
+ },
44
+ "sadness": {
45
+ "emoji": "😢", "color": "#3b82f6", "light_color": "#dbeafe",
46
+ "gradient": "linear-gradient(135deg, #3b82f6 0%, #2563eb 100%)",
47
+ "description": "Sorrow, grief, or disappointment",
48
+ "keywords": ["sad", "depressed", "unhappy", "miserable", "heartbroken"],
49
+ "intensity_labels": ["Mild sadness", "Moderate sorrow", "Strong grief", "Deep despair"]
50
+ },
51
+ "surprise": {
52
+ "emoji": "😲", "color": "#ec4899", "light_color": "#fce7f3",
53
+ "gradient": "linear-gradient(135deg, #ec4899 0%, #db2777 100%)",
54
+ "description": "Astonishment, shock, or amazement",
55
+ "keywords": ["surprised", "shocked", "amazed", "astonished", "startled"],
56
+ "intensity_labels": ["Mild surprise", "Moderate shock", "Strong astonishment", "Complete disbelief"]
57
+ }
58
+ }
59
+
60
+ # ========== MODEL CLASS ==========
61
+ class RobertaEmotion(nn.Module):
62
+ def __init__(self):
63
+ super().__init__()
64
+ self.backbone = AutoModel.from_pretrained(MODEL_NAME)
65
+ self.dropout = nn.Dropout(0.35)
66
+ self.head = nn.Linear(768, 5)
67
+
68
+ def forward(self, input_ids, attention_mask):
69
+ outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
70
+ pooled = outputs.pooler_output if hasattr(outputs, "pooler_output") else outputs.last_hidden_state[:, 0]
71
+ x = self.dropout(pooled)
72
+ return self.head(x)
73
+
74
+ # ========== GLOBAL STATE ==========
75
+ class ModelState:
76
+ def __init__(self):
77
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
78
+ self.model = None
79
+ self.tokenizer = None
80
+ self.ready = False
81
+ self.predictions_count = 0
82
+ self.history = []
83
+ self.emotion_stats = defaultdict(int)
84
+
85
+ state = ModelState()
86
+
87
+ # ========== UTILITY FUNCTIONS ==========
88
+ def sigmoid(x: np.ndarray) -> np.ndarray:
89
+ """Apply sigmoid with numerical stability"""
90
+ return 1 / (1 + np.exp(-np.clip(x, -500, 500)))
91
+
92
+ def get_intensity_level(prob: float) -> Tuple[str, str]:
93
+ """Get intensity level and color"""
94
+ if prob >= 0.85: return "Very High", "#10b981"
95
+ elif prob >= 0.70: return "High", "#3b82f6"
96
+ elif prob >= 0.50: return "Moderate", "#f59e0b"
97
+ elif prob >= 0.30: return "Low", "#9ca3af"
98
+ else: return "Very Low", "#d1d5db"
99
+
100
+ def get_emotion_intensity_label(emotion: str, prob: float) -> str:
101
+ """Get human-readable intensity for emotion"""
102
+ labels = EMOTION_META[emotion]["intensity_labels"]
103
+ if prob >= 0.85: return labels[3]
104
+ elif prob >= 0.65: return labels[2]
105
+ elif prob >= 0.45: return labels[1]
106
+ else: return labels[0]
107
+
108
+ # ========== MODEL LOADING ==========
109
+ def load_model() -> Tuple[str, str]:
110
+ """Load model and return status HTML"""
111
+ try:
112
+ print("📦 Loading model...")
113
+ state.model = RobertaEmotion()
114
+
115
+ if os.path.exists(MODEL_PATH):
116
+ state.model.load_state_dict(torch.load(MODEL_PATH, map_location=state.device))
117
+ print("✅ Trained model loaded")
118
+ status = "success"
119
+ status_text = "Trained Model Loaded"
120
+ else:
121
+ print("⚠️ No trained weights found")
122
+ status = "warning"
123
+ status_text = "Model Initialized (No Weights)"
124
+
125
+ state.model.to(state.device)
126
+ state.model.eval()
127
+ state.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
128
+ state.ready = True
129
+
130
+ device_emoji = "🚀" if state.device == "cuda" else "💻"
131
+ bg_color = "#d1fae5" if status == "success" else "#fef3c7"
132
+ border_color = "#10b981" if status == "success" else "#f59e0b"
133
+ icon = "✅" if status == "success" else "⚠️"
134
+
135
+ html = f"""
136
+ <div style='background: {bg_color}; padding: 20px; border-radius: 12px; border-left: 5px solid {border_color}; margin: 10px 0;'>
137
+ <div style='display: flex; align-items: center; gap: 15px;'>
138
+ <div style='font-size: 48px;'>{icon}</div>
139
+ <div style='flex: 1;'>
140
+ <h3 style='margin: 0 0 8px 0; color: #1f2937;'>{status_text}</h3>
141
+ <div style='color: #4b5563; font-size: 14px;'>
142
+ <strong>Device:</strong> {device_emoji} {state.device.upper()} |
143
+ <strong>Architecture:</strong> RoBERTa-base |
144
+ <strong>F1 Score:</strong> 0.872
145
+ </div>
146
+ <div style='margin-top: 8px; padding: 8px; background: rgba(255,255,255,0.5); border-radius: 6px; font-size: 13px; color: #374151;'>
147
+ 🎯 Ready to analyze emotions! You can now enter text or try the examples below.
148
+ </div>
149
+ </div>
150
+ </div>
151
+ </div>
152
+ """
153
+
154
+ controls_html = _create_controls_panel()
155
+
156
+ print(f"✅ Model ready on {state.device}")
157
+ return html, controls_html
158
+
159
+ except Exception as e:
160
+ state.ready = False
161
+ error_html = f"""
162
+ <div style='background: #fee2e2; padding: 20px; border-radius: 12px; border-left: 5px solid #ef4444;'>
163
+ <div style='display: flex; align-items: center; gap: 15px;'>
164
+ <div style='font-size: 48px;'>❌</div>
165
+ <div>
166
+ <h3 style='margin: 0 0 8px 0; color: #991b1b;'>Error Loading Model</h3>
167
+ <div style='color: #7f1d1d; font-size: 14px; font-family: monospace;'>{str(e)}</div>
168
+ </div>
169
+ </div>
170
+ </div>
171
+ """
172
+ print(f"❌ Error: {e}")
173
+ return error_html, ""
174
+
175
+ # ========== VISUALIZATION FUNCTIONS ==========
176
+ def _create_controls_panel() -> str:
177
+ """Create interactive controls panel"""
178
+ return f"""
179
+ <div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; border-radius: 10px; color: white; margin: 10px 0;'>
180
+ <div style='display: flex; justify-content: space-around; align-items: center; flex-wrap: wrap; gap: 15px;'>
181
+ <div style='text-align: center;'>
182
+ <div style='font-size: 28px; font-weight: bold;'>{state.predictions_count}</div>
183
+ <div style='font-size: 12px; opacity: 0.9;'>Total Predictions</div>
184
+ </div>
185
+ <div style='text-align: center;'>
186
+ <div style='font-size: 28px; font-weight: bold;'>5</div>
187
+ <div style='font-size: 12px; opacity: 0.9;'>Emotion Classes</div>
188
+ </div>
189
+ <div style='text-align: center;'>
190
+ <div style='font-size: 28px; font-weight: bold;'>0.872</div>
191
+ <div style='font-size: 12px; opacity: 0.9;'>F1 Score</div>
192
+ </div>
193
+ <div style='text-align: center;'>
194
+ <div style='font-size: 28px; font-weight: bold;'>{state.device.upper()}</div>
195
+ <div style='font-size: 12px; opacity: 0.9;'>Device</div>
196
+ </div>
197
+ </div>
198
+ </div>
199
+ """
200
+
201
+ def _create_radar_chart_svg(probs: np.ndarray) -> str:
202
+ """Create SVG radar chart for emotions"""
203
+ size = 400
204
+ padding = 80
205
+ center = size / 2
206
+ max_radius = (size / 2) - padding
207
+
208
+ # Calculate points
209
+ angles = [i * 2 * np.pi / 5 - np.pi/2 for i in range(5)]
210
+ points = []
211
+ for i, prob in enumerate(probs):
212
+ r = prob * max_radius
213
+ x = center + r * np.cos(angles[i])
214
+ y = center + r * np.sin(angles[i])
215
+ points.append(f"{x},{y}")
216
+
217
+ points_str = " ".join(points)
218
+
219
+ # Create axis lines
220
+ axis_lines = ""
221
+ for i in range(5):
222
+ x = center + max_radius * np.cos(angles[i])
223
+ y = center + max_radius * np.sin(angles[i])
224
+ axis_lines += f'<line x1="{center}" y1="{center}" x2="{x}" y2="{y}" stroke="#d1d5db" stroke-width="2"/>'
225
+
226
+ # Create circles (reference lines)
227
+ circles = ""
228
+ for level in [0.25, 0.5, 0.75, 1.0]:
229
+ r = level * max_radius
230
+ circles += f'<circle cx="{center}" cy="{center}" r="{r}" fill="none" stroke="#e5e7eb" stroke-width="1.5" stroke-dasharray="4,4"/>'
231
+
232
+ # Create value points on the data polygon
233
+ data_points = ""
234
+ for i, prob in enumerate(probs):
235
+ r = prob * max_radius
236
+ x = center + r * np.cos(angles[i])
237
+ y = center + r * np.sin(angles[i])
238
+ data_points += f'<circle cx="{x}" cy="{y}" r="6" fill="#667eea" stroke="white" stroke-width="2"/>'
239
+
240
+ # Labels with better positioning
241
+ labels = ""
242
+ for i, emotion in enumerate(EMOTIONS):
243
+ label_radius = max_radius + 30
244
+ x = center + label_radius * np.cos(angles[i])
245
+ y = center + label_radius * np.sin(angles[i])
246
+
247
+ emoji = EMOTION_META[emotion]["emoji"]
248
+ prob_pct = probs[i] * 100
249
+
250
+ # Adjust text anchor based on position
251
+ if x < center - 10:
252
+ anchor = "end"
253
+ elif x > center + 10:
254
+ anchor = "start"
255
+ else:
256
+ anchor = "middle"
257
+
258
+ labels += f'''
259
+ <g>
260
+ <text x="{x}" y="{y - 15}" text-anchor="{anchor}" font-size="28" dominant-baseline="middle">{emoji}</text>
261
+ <text x="{x}" y="{y + 8}" text-anchor="{anchor}" font-size="14" font-weight="600" fill="#1f2937" dominant-baseline="middle">{emotion.capitalize()}</text>
262
+ <text x="{x}" y="{y + 24}" text-anchor="{anchor}" font-size="12" fill="#6b7280" dominant-baseline="middle">{prob_pct:.0f}%</text>
263
+ </g>
264
+ '''
265
+
266
+ return f"""
267
+ <div style="width: 100%; max-width: 450px; margin: 20px auto; padding: 10px; overflow: visible;">
268
+ <svg width="100%" height="100%" viewBox="0 0 {size} {size}" preserveAspectRatio="xMidYMid meet" style="overflow: visible;">
269
+ <defs>
270
+ <filter id="shadow">
271
+ <feDropShadow dx="0" dy="2" stdDeviation="3" flood-opacity="0.3"/>
272
+ </filter>
273
+ </defs>
274
+ {circles}
275
+ {axis_lines}
276
+ <polygon points="{points_str}" fill="rgba(102, 126, 234, 0.25)" stroke="#667eea" stroke-width="3" filter="url(#shadow)"/>
277
+ {data_points}
278
+ {labels}
279
+ </svg>
280
+ </div>
281
+ """
282
+
283
+ def _create_emotion_card(emotion: str, prob: float, detected: bool, threshold: float, rank: int) -> str:
284
+ """Create enhanced emotion card with ranking"""
285
+ meta = EMOTION_META[emotion]
286
+ emoji = meta["emoji"]
287
+ color = meta["color"]
288
+ light_color = meta["light_color"]
289
+ gradient = meta["gradient"]
290
+ desc = meta["description"]
291
+
292
+ prob_pct = prob * 100
293
+ intensity, intensity_color = get_intensity_level(prob)
294
+ intensity_label = get_emotion_intensity_label(emotion, prob)
295
+
296
+ # Medal for top 3
297
+ medals = {1: "🥇", 2: "🥈", 3: "🥉"}
298
+ rank_display = medals.get(rank, f"#{rank}")
299
+
300
+ if detected:
301
+ border = f"3px solid {color}"
302
+ shadow = "0 6px 12px rgba(0,0,0,0.15)"
303
+ bg = "white"
304
+ else:
305
+ border = "2px solid #e5e7eb"
306
+ shadow = "0 2px 4px rgba(0,0,0,0.05)"
307
+ bg = "#fafafa"
308
+
309
+ return f"""
310
+ <div style='background: {bg}; padding: 18px; margin: 12px 0; border-radius: 12px; border: {border}; box-shadow: {shadow}; transition: all 0.3s;'>
311
+ <div style='display: flex; justify-content: space-between; align-items: center; margin-bottom: 10px;'>
312
+ <div style='display: flex; align-items: center; gap: 12px;'>
313
+ <span style='font-size: 36px;'>{emoji}</span>
314
+ <div>
315
+ <div style='display: flex; align-items: center; gap: 8px;'>
316
+ <span style='font-weight: bold; font-size: 20px; color: #1f2937; text-transform: capitalize;'>{emotion}</span>
317
+ <span style='background: {light_color}; color: {color}; padding: 2px 8px; border-radius: 12px; font-size: 12px; font-weight: bold;'>{rank_display}</span>
318
+ </div>
319
+ <div style='font-size: 13px; color: #6b7280; margin-top: 2px;'>{intensity_label}</div>
320
+ </div>
321
+ </div>
322
+ <div style='text-align: right;'>
323
+ <div style='font-size: 24px; font-weight: bold; color: {color};'>{prob_pct:.1f}%</div>
324
+ <div style='font-size: 11px; color: {intensity_color}; font-weight: 600;'>{intensity}</div>
325
+ </div>
326
+ </div>
327
+
328
+ <div style='position: relative; background: #f3f4f6; height: 28px; border-radius: 14px; overflow: hidden; margin: 10px 0;'>
329
+ <div style='position: absolute; height: 100%; background: {gradient}; width: {min(prob_pct, 100)}%; transition: width 0.8s cubic-bezier(0.4, 0, 0.2, 1); border-radius: 14px;'></div>
330
+ <div style='position: absolute; width: 100%; height: 100%; display: flex; align-items: center; padding: 0 12px; justify-content: space-between;'>
331
+ <span style='font-size: 12px; font-weight: 600; color: {"white" if prob_pct > 50 else "#1f2937"};'>{desc}</span>
332
+ <span style='font-size: 11px; color: {"rgba(255,255,255,0.8)" if prob_pct > 50 else "#6b7280"};'>Threshold: {threshold:.1%}</span>
333
+ </div>
334
+ </div>
335
+
336
+ <div style='display: flex; gap: 8px; margin-top: 10px;'>
337
  <span style='background: {"#10b98120" if detected else "#f3f4f6"}; color: {"#10b981" if detected else "#9ca3af"}; padding: 4px 10px; border-radius: 6px; font-size: 12px; font-weight: 600;'>
338
  {("✓ DETECTED" if detected else "○ Not Detected")}
339
  </span>