Update app.py
app.py (CHANGED)
@@ -21,6 +21,71 @@ EMOTION_LABELS = {
     'LABEL_2': 'Neutral'
 }

+def chunk_text(text, max_length=512):
+    """Split text into chunks of maximum token length."""
+    tokens = bert_tokenizer.encode(text, add_special_tokens=False)
+    chunks = []
+
+    for i in range(0, len(tokens), max_length - 2):  # -2 to account for [CLS] and [SEP] tokens
+        chunk = tokens[i:i + max_length - 2]
+        # Add special tokens
+        chunk = [bert_tokenizer.cls_token_id] + chunk + [bert_tokenizer.sep_token_id]
+        chunks.append(chunk)
+
+    return chunks
+
+def get_embedding_for_text(text):
+    """Get embedding for a single text."""
+    chunks = chunk_text(text)
+    chunk_embeddings = []
+
+    for chunk in chunks:
+        # Convert to tensor and add batch dimension
+        input_ids = torch.tensor([chunk]).to(bert_model.device)
+        attention_mask = torch.ones_like(input_ids)
+
+        with torch.no_grad():
+            outputs = bert_model(input_ids, attention_mask=attention_mask)
+
+        # Get [CLS] token embedding for this chunk
+        chunk_embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
+        chunk_embeddings.append(chunk_embedding[0])
+
+    # Average embeddings from all chunks
+    if chunk_embeddings:
+        return np.mean(chunk_embeddings, axis=0)
+    return np.zeros(bert_model.config.hidden_size)  # fallback
+
+def generate_embeddings(texts):
+    """Generate embeddings for a list of texts."""
+    embeddings = []
+
+    for text in texts:
+        try:
+            embedding = get_embedding_for_text(text)
+            embeddings.append(embedding)
+        except Exception as e:
+            st.warning(f"Error processing text: {str(e)}")
+            # Add zero embedding as fallback
+            embeddings.append(np.zeros(bert_model.config.hidden_size))
+
+    return np.array(embeddings)
+
+def classify_emotion(text):
+    """Classify emotion for a single text."""
+    try:
+        chunks = chunk_text(text)
+        if not chunks:
+            return "unknown"
+
+        # Use first chunk for classification (decoded without special tokens)
+        chunk_str = bert_tokenizer.decode(chunks[0], skip_special_tokens=True)
+        result = emotion_classifier(chunk_str)[0]
+        return result['label']
+    except Exception as e:
+        st.warning(f"Error in emotion classification: {str(e)}")
+        return "unknown"
+
 def format_topics(topic_model, topic_counts):
     """Convert topic numbers to readable labels."""
     formatted_topics = []
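Taken together, the helpers added above turn each document into a single vector by averaging the [CLS] embeddings of its 512-token chunks. Below is a minimal, illustrative sketch (not part of the commit) of how such precomputed embeddings are typically handed to BERTopic so it skips its own embedding step; `poems` is a hypothetical list of poem strings, and the BERTopic parameters here are placeholders for the ones configured further down in app.py.

# Illustrative sketch only, not from this commit.
# Assumes `poems` is a list of poem strings and that bert_model,
# bert_tokenizer, and generate_embeddings() from the hunk above are in scope.
from bertopic import BERTopic

embeddings = generate_embeddings(poems)      # shape: (n_docs, hidden_size)
topic_model = BERTopic(verbose=False)        # placeholder parameters
topics, probs = topic_model.fit_transform(poems, embeddings=embeddings)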
@@ -50,10 +115,26 @@ def format_emotions(emotion_counts):
         })
     return formatted_emotions

-# [Previous functions remain the same until process_and_summarize]
-
 def process_and_summarize(uploaded_file, top_n=50):
-    #
+    # Determine the file type
+    if uploaded_file.name.endswith(".csv"):
+        df = pd.read_csv(uploaded_file)
+    elif uploaded_file.name.endswith(".xlsx"):
+        df = pd.read_excel(uploaded_file)
+    else:
+        st.error("Unsupported file format.")
+        return None, None
+
+    # Validate required columns
+    required_columns = ['country', 'poem']
+    missing_columns = [col for col in required_columns if col not in df.columns]
+    if missing_columns:
+        st.error(f"Missing columns: {', '.join(missing_columns)}")
+        return None, None
+
+    # Parse and preprocess the file
+    df['country'] = df['country'].str.strip()
+    df = df.dropna(subset=['country', 'poem'])

     # Initialize BERTopic with specific parameters
     topic_model = BERTopic(
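The new validation in process_and_summarize expects an uploaded .csv or .xlsx file with at least a `country` and a `poem` column. A small, made-up example of a file that would pass the checks above:

# Hypothetical sample input; the file name and rows are invented for illustration.
import pandas as pd

sample = pd.DataFrame({
    "country": ["Lebanon ", "Egypt"],                         # stray whitespace is stripped by the app
    "poem": ["First poem text ...", "Second poem text ..."],
})
sample.to_csv("poems_sample.csv", index=False)                # upload this file in the app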
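For context, a hedged sketch of how a Streamlit front end might feed an upload into the updated function; the widget label is invented, and the two return values are inferred only from the `return None, None` error paths in this commit, not from code shown here.

# Illustrative wiring only, not from this commit.
import streamlit as st

uploaded_file = st.file_uploader("Upload poems (.csv or .xlsx)", type=["csv", "xlsx"])
if uploaded_file is not None:
    summaries, topic_model = process_and_summarize(uploaded_file, top_n=50)
    if summaries is not None:
        st.write(summaries)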