AbhishekParanjape committed
Commit ef981ba · 1 Parent(s): 4de9017

semantic chunker

Files changed (2)
  1. rag_system.py +11 -4
  2. semantic_chunking.py +420 -0
rag_system.py CHANGED
@@ -18,6 +18,7 @@ import json
 import base64
 from openai import OpenAI
 import re
+from semantic_chunking import SemanticChunker
 
 # Load environment variables
 load_dotenv()
@@ -198,11 +199,17 @@ class DocumentIngestion:
             st.error("No embedding model available. Please install sentence-transformers or provide OpenAI API key.")
             raise Exception("No embedding model available")
 
-        self.text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=1000,
-            chunk_overlap=200,
-            separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
+        self.text_splitter = SemanticChunker(
+            embeddings_model=self.embeddings,
+            chunk_size=4,                # 4 sentences per base chunk
+            overlap=1,                   # 1 sentence overlap
+            similarity_threshold=0.75,   # Semantic similarity threshold
+            min_chunk_size=150,          # Minimum 150 characters
+            max_chunk_size=1500,         # Maximum 1500 characters
+            debug=True                   # Show statistics in Streamlit
         )
+
+        st.info(f"🧠 Using semantic chunking with {self.embedding_type} embeddings")
         self.persist_directory = os.getenv("CHROMA_PERSIST_DIRECTORY", "./chroma_db")
         os.makedirs(self.persist_directory, exist_ok=True)
 
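With this change, DocumentIngestion keeps the same split_documents(documents) interface it used with RecursiveCharacterTextSplitter, so the rest of the ingestion pipeline should not need to change. A minimal sketch of the downstream call, assuming an already-constructed DocumentIngestion instance (the ingestion code is not part of this diff, so the names and arguments below are illustrative):

```python
# Illustrative only: the calling code is not shown in this diff.
# SemanticChunker.split_documents() accepts and returns LangChain Document objects,
# so it is called exactly like the RecursiveCharacterTextSplitter it replaces.
from langchain.schema import Document

ingestion = DocumentIngestion()  # hypothetical construction; real arguments may differ
docs = [Document(page_content="...", metadata={"source": "handbook.pdf"})]

chunks = ingestion.text_splitter.split_documents(docs)
for chunk in chunks:
    print(chunk.metadata.get("chunking_method"), len(chunk.page_content))
```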
semantic_chunking.py ADDED
@@ -0,0 +1,420 @@
+"""
+Semantic Chunker Module for RAG Systems
+=======================================
+
+A drop-in replacement for RecursiveCharacterTextSplitter that uses semantic similarity
+to create more coherent chunks. Designed to work seamlessly with existing LangChain
+and Streamlit RAG systems.
+
+Author: AI Assistant
+Compatible with: LangChain, BGE embeddings, OpenAI embeddings, Streamlit
+"""
+
+import numpy as np
+import re
+from typing import List, Dict, Any, Optional, Union
+from langchain.schema import Document
+import streamlit as st
+from sklearn.metrics.pairwise import cosine_similarity
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class SemanticChunker:
+    """
+    Advanced semantic document chunker that creates coherent chunks based on
+    semantic similarity rather than fixed character counts.
+
+    Perfect for university documents, research papers, and policy documents
+    where maintaining semantic coherence is crucial.
+    """
+
+    def __init__(self,
+                 embeddings_model,
+                 chunk_size: int = 4,
+                 overlap: int = 1,
+                 similarity_threshold: float = 0.75,
+                 min_chunk_size: int = 150,
+                 max_chunk_size: int = 1500,
+                 sentence_split_pattern: Optional[str] = None,
+                 debug: bool = False):
+        """
+        Initialize the semantic chunker.
+
+        Args:
+            embeddings_model: Your existing embeddings model (BGE, OpenAI, etc.)
+            chunk_size: Base number of sentences per chunk (default: 4)
+            overlap: Number of sentences to overlap between chunks (default: 1)
+            similarity_threshold: Cosine similarity threshold for extending chunks (0.0-1.0)
+            min_chunk_size: Minimum characters per chunk (skip smaller chunks)
+            max_chunk_size: Maximum characters per chunk (prevent overly large chunks)
+            sentence_split_pattern: Custom regex pattern for sentence splitting
+            debug: Enable debug logging and statistics
+        """
+        self.embeddings_model = embeddings_model
+        self.chunk_size = chunk_size
+        self.overlap = overlap
+        self.similarity_threshold = similarity_threshold
+        self.min_chunk_size = min_chunk_size
+        self.max_chunk_size = max_chunk_size
+        self.debug = debug
+
+        # Default sentence splitting pattern optimized for academic/university documents
+        self.sentence_pattern = sentence_split_pattern or r'[.!?]+\s+'
+
+        # Statistics tracking
+        self.stats = {
+            "total_documents": 0,
+            "total_chunks": 0,
+            "avg_chunk_size": 0,
+            "chunking_methods": {},
+            "embedding_errors": 0
+        }
+
+        if self.debug:
+            logger.info(f"Initialized SemanticChunker with threshold={similarity_threshold}")
+
+    def _detect_embedding_model_type(self) -> str:
+        """Detect the type of embedding model being used."""
+        if hasattr(self.embeddings_model, 'model'):
+            # Likely sentence-transformers model (BGE, etc.)
+            model_name = getattr(self.embeddings_model.model, 'model_name', 'sentence-transformers')
+            return f"sentence-transformers ({model_name})"
+        elif hasattr(self.embeddings_model, 'client'):
+            # Likely OpenAI
+            return "OpenAI"
+        else:
+            return "Unknown"
+
+    def _preprocess_text_for_splitting(self, text: str) -> str:
+        """
+        Preprocess text to handle common formatting issues in university documents.
+        """
+        # Fix common formatting issues
+        fixes = [
+            # Add space after periods before capital letters
+            (r'([a-z])\.([A-Z])', r'\1. \2'),
+            # Add space after numbers with periods
+            (r'([0-9]+)\.([A-Z])', r'\1. \2'),
+            # Fix missing spaces after question/exclamation marks
+            (r'([a-z])\?([A-Z])', r'\1? \2'),
+            (r'([a-z])\!([A-Z])', r'\1! \2'),
+            # Clean up multiple spaces
+            (r'\s+', ' '),
+            # Fix bullet points
+            (r'•\s*([A-Z])', r'• \1'),
+            (r'-\s*([A-Z])', r'- \1'),
+        ]
+
+        processed_text = text
+        for pattern, replacement in fixes:
+            processed_text = re.sub(pattern, replacement, processed_text)
+
+        return processed_text.strip()
+
+    def _split_into_sentences(self, text: str) -> List[str]:
+        """
+        Advanced sentence splitting optimized for academic documents.
+        """
+        # Preprocess text
+        text = self._preprocess_text_for_splitting(text)
+
+        # Split on sentence boundaries
+        raw_sentences = re.split(self.sentence_pattern, text)
+
+        # Clean and filter sentences
+        sentences = []
+        for sentence in raw_sentences:
+            sentence = sentence.strip()
+
+            # Filter out very short sentences, pure numbers, or empty strings
+            if len(sentence) >= 10 and not sentence.isdigit() and not re.match(r'^[^\w]*$', sentence):
+                sentences.append(sentence)
+
+        if self.debug:
+            logger.info(f"Split text into {len(sentences)} sentences")
+
+        return sentences
+
+    def _get_embeddings(self, texts: List[str]) -> Optional[np.ndarray]:
+        """
+        Get embeddings from the provided model with error handling.
+        """
+        try:
+            if hasattr(self.embeddings_model, 'model'):
+                # sentence-transformers model (BGE, etc.)
+                embeddings = self.embeddings_model.model.encode(texts)
+                return np.array(embeddings)
+            elif hasattr(self.embeddings_model, 'embed_documents'):
+                # OpenAI or similar API-based embeddings
+                embeddings = self.embeddings_model.embed_documents(texts)
+                return np.array(embeddings)
+            else:
+                # Try direct call
+                embeddings = self.embeddings_model(texts)
+                return np.array(embeddings)
+
+        except Exception as e:
+            self.stats["embedding_errors"] += 1
+            if self.debug:
+                logger.error(f"Error generating embeddings: {e}")
+
+            # Show warning in Streamlit if available
+            try:
+                st.warning(f"⚠️ Embedding error, falling back to simple chunking: {str(e)[:100]}...")
+            except Exception:
+                pass  # Not running inside a Streamlit session
+
+            return None
+
+    def _calculate_semantic_boundaries(self, embeddings: np.ndarray, sentences: List[str]) -> List[int]:
+        """
+        Find natural semantic boundaries in the text based on embedding similarities.
+        """
+        boundaries = [0]  # Always start with first sentence
+
+        # Calculate similarities between consecutive sentences
+        similarities = []
+        for i in range(len(embeddings) - 1):
+            sim = cosine_similarity(
+                embeddings[i:i+1],
+                embeddings[i+1:i+2]
+            )[0][0]
+            similarities.append(sim)
+
+        # Find significant drops in similarity (topic boundaries)
+        if len(similarities) > 1:
+            mean_sim = np.mean(similarities)
+            std_sim = np.std(similarities)
+            threshold = mean_sim - (0.5 * std_sim)  # Adaptive threshold
+
+            for i, sim in enumerate(similarities):
+                if sim < threshold:
+                    boundaries.append(i + 1)
+
+        boundaries.append(len(sentences))  # Always end with last sentence
+
+        return sorted(list(set(boundaries)))  # Remove duplicates and sort
+
+    def _create_chunks_from_boundaries(self, sentences: List[str], boundaries: List[int],
+                                       embeddings: Optional[np.ndarray], metadata: Dict[str, Any]) -> List[Document]:
+        """
+        Create document chunks based on semantic boundaries.
+        """
+        chunks = []
+
+        for i in range(len(boundaries) - 1):
+            start_idx = boundaries[i]
+            end_idx = boundaries[i + 1]
+
+            # Create base chunk
+            chunk_sentences = sentences[start_idx:end_idx]
+
+            # Try to extend chunk if semantically similar
+            if embeddings is not None and end_idx < len(sentences):
+                current_embedding = np.mean(embeddings[start_idx:end_idx], axis=0, keepdims=True)
+
+                # Check if we can extend the chunk
+                extended_end = end_idx
+                while extended_end < len(sentences):
+                    next_sentence_embedding = embeddings[extended_end:extended_end+1]
+                    similarity = cosine_similarity(current_embedding, next_sentence_embedding)[0][0]
+
+                    if similarity > self.similarity_threshold:
+                        # Check size limit
+                        test_chunk = ' '.join(sentences[start_idx:extended_end+1])
+                        if len(test_chunk) <= self.max_chunk_size:
+                            extended_end += 1
+                            # Update current embedding
+                            current_embedding = np.mean(embeddings[start_idx:extended_end], axis=0, keepdims=True)
+                        else:
+                            break
+                    else:
+                        break
+
+                # Use extended chunk if we found extensions
+                if extended_end > end_idx:
+                    chunk_sentences = sentences[start_idx:extended_end]
+
+            # Create chunk text
+            chunk_text = ' '.join(chunk_sentences)
+
+            # Only add chunks that meet minimum size requirement
+            if len(chunk_text) >= self.min_chunk_size:
+                chunk_metadata = metadata.copy()
+                chunk_metadata.update({
+                    "chunk_index": len(chunks),
+                    "sentence_count": len(chunk_sentences),
+                    "start_sentence": start_idx,
+                    "end_sentence": start_idx + len(chunk_sentences) - 1,
+                    "chunking_method": "semantic_boundary",
+                    "similarity_threshold": self.similarity_threshold,
+                    "chunk_size_chars": len(chunk_text)
+                })
+
+                chunks.append(Document(page_content=chunk_text, metadata=chunk_metadata))
+
+        return chunks
+
+    def _create_simple_chunks(self, sentences: List[str], metadata: Dict[str, Any]) -> List[Document]:
+        """
+        Fallback to simple sentence-based chunking when embeddings are unavailable.
+        """
+        chunks = []
+
+        for i in range(0, len(sentences), max(1, self.chunk_size - self.overlap)):
+            chunk_sentences = sentences[i:i + self.chunk_size]
+            chunk_text = ' '.join(chunk_sentences)
+
+            if len(chunk_text) >= self.min_chunk_size:
+                chunk_metadata = metadata.copy()
+                chunk_metadata.update({
+                    "chunk_index": len(chunks),
+                    "sentence_count": len(chunk_sentences),
+                    "start_sentence": i,
+                    "end_sentence": i + len(chunk_sentences) - 1,
+                    "chunking_method": "simple_fallback",
+                    "chunk_size_chars": len(chunk_text)
+                })
+
+                chunks.append(Document(page_content=chunk_text, metadata=chunk_metadata))
+
+        return chunks
+
+    def split_documents(self, documents: List[Document]) -> List[Document]:
+        """
+        Main method: Split documents into semantically coherent chunks.
+
+        Args:
+            documents: List of LangChain Document objects
+
+        Returns:
+            List of Document objects with semantic chunks
+        """
+        all_chunks = []
+        self.stats["total_documents"] = len(documents)
+
+        for doc_idx, doc in enumerate(documents):
+            try:
+                # Split document into sentences
+                sentences = self._split_into_sentences(doc.page_content)
+
+                if not sentences:
+                    if self.debug:
+                        logger.warning(f"No sentences found in document {doc_idx}")
+                    continue
+
+                # Handle very short documents
+                if len(sentences) < self.chunk_size:
+                    chunk_text = ' '.join(sentences)
+                    if len(chunk_text) >= self.min_chunk_size:
+                        chunk_metadata = doc.metadata.copy()
+                        chunk_metadata.update({
+                            "chunk_index": 0,
+                            "total_chunks": 1,
+                            "sentence_count": len(sentences),
+                            "chunking_method": "single_chunk",
+                            "chunk_size_chars": len(chunk_text)
+                        })
+                        all_chunks.append(Document(page_content=chunk_text, metadata=chunk_metadata))
+                    continue
+
+                # Generate embeddings
+                embeddings = self._get_embeddings(sentences)
+
+                if embeddings is not None:
+                    # Create semantic chunks from the detected topic boundaries
+                    boundaries = self._calculate_semantic_boundaries(embeddings, sentences)
+                    chunks = self._create_chunks_from_boundaries(sentences, boundaries, embeddings, doc.metadata)
+                    method = "semantic"
+                else:
+                    # Fallback to simple chunking
+                    chunks = self._create_simple_chunks(sentences, doc.metadata)
+                    method = "simple_fallback"
+
+                # Update statistics
+                self.stats["chunking_methods"][method] = self.stats["chunking_methods"].get(method, 0) + 1
+
+                # Update total chunks count in each chunk's metadata
+                for chunk in chunks:
+                    chunk.metadata["total_chunks"] = len(chunks)
+                    chunk.metadata["source_document_index"] = doc_idx
+
+                all_chunks.extend(chunks)
+
+                if self.debug:
+                    logger.info(f"Document {doc_idx}: {len(sentences)} sentences → {len(chunks)} chunks ({method})")
+
+            except Exception as e:
+                logger.error(f"Error processing document {doc_idx}: {e}")
+                if self.debug:
+                    st.error(f"Error processing document {doc_idx}: {e}")
+
+        # Update final statistics
+        self.stats["total_chunks"] = len(all_chunks)
+        if all_chunks:
+            chunk_sizes = [len(chunk.page_content) for chunk in all_chunks]
+            self.stats["avg_chunk_size"] = sum(chunk_sizes) / len(chunk_sizes)
+
+        if self.debug:
+            logger.info(f"Created {len(all_chunks)} total chunks from {len(documents)} documents")
+
+        return all_chunks
+
+    def get_statistics(self) -> Dict[str, Any]:
+        """Get chunking statistics for analysis."""
+        return self.stats.copy()
+
+    def display_statistics(self):
+        """Display chunking statistics in Streamlit (if available)."""
+        try:
+            with st.expander("📊 Semantic Chunking Statistics"):
+                col1, col2 = st.columns(2)
+
+                with col1:
+                    st.metric("Total Documents", self.stats["total_documents"])
+                    st.metric("Total Chunks", self.stats["total_chunks"])
+
+                with col2:
+                    st.metric("Avg Chunk Size", f"{self.stats['avg_chunk_size']:.0f} chars")
+                    st.metric("Embedding Errors", self.stats["embedding_errors"])
+
+                if self.stats["chunking_methods"]:
+                    st.write("**Chunking Methods Used:**")
+                    for method, count in self.stats["chunking_methods"].items():
+                        percentage = (count / self.stats["total_documents"]) * 100 if self.stats["total_documents"] > 0 else 0
+                        st.write(f"  - {method}: {count} documents ({percentage:.1f}%)")
+
+                st.write("**Configuration:**")
+                st.json({
+                    "chunk_size": self.chunk_size,
+                    "overlap": self.overlap,
+                    "similarity_threshold": self.similarity_threshold,
+                    "min_chunk_size": self.min_chunk_size,
+                    "max_chunk_size": self.max_chunk_size,
+                    "embedding_model": self._detect_embedding_model_type()
+                })
+
+        except Exception:
+            # Streamlit is imported at module level, so a failure here means we are not
+            # running inside a Streamlit session; print to the console instead.
+            print("\n=== Semantic Chunking Statistics ===")
+            print(f"Documents processed: {self.stats['total_documents']}")
+            print(f"Chunks created: {self.stats['total_chunks']}")
+            print(f"Average chunk size: {self.stats['avg_chunk_size']:.0f} characters")
+            print(f"Embedding errors: {self.stats['embedding_errors']}")
+            print(f"Chunking methods: {self.stats['chunking_methods']}")
+
+
+def create_semantic_chunker(embeddings_model, **kwargs) -> SemanticChunker:
+    """
+    Convenience function to create a semantic chunker with sensible defaults.
+
+    Args:
+        embeddings_model: Your existing embeddings model
+        **kwargs: Additional parameters to pass to SemanticChunker
+
+    Returns:
+        SemanticChunker instance ready to use
+    """
+    return SemanticChunker(embeddings_model=embeddings_model, **kwargs)
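
A standalone usage sketch, outside the Streamlit app. The LocalEmbeddings wrapper below is hypothetical; SemanticChunker only probes the object it is given for a .model attribute with an encode() method (sentence-transformers style), an embed_documents() method (OpenAI style), or, failing that, calls the object directly, so any wrapper exposing one of those interfaces works:

```python
# Assumptions: sentence-transformers is installed; the wrapper class is illustrative.
from sentence_transformers import SentenceTransformer
from langchain.schema import Document
from semantic_chunking import create_semantic_chunker


class LocalEmbeddings:
    """Minimal wrapper matching the `.model.encode()` path used by _get_embeddings."""
    def __init__(self, name: str = "BAAI/bge-small-en-v1.5"):
        self.model = SentenceTransformer(name)


chunker = create_semantic_chunker(
    LocalEmbeddings(),
    similarity_threshold=0.8,
    min_chunk_size=20,   # lowered so the short demo text below still yields chunks
    debug=False,         # no Streamlit widgets when run from the command line
)

docs = [Document(
    page_content=("Enrollment opens in August. Students must submit transcripts. "
                  "The library offers extended hours during exams. Quiet zones are on level three."),
    metadata={"source": "example"},
)]

for chunk in chunker.split_documents(docs):
    print(chunk.metadata["chunking_method"], "->", chunk.page_content)
print(chunker.get_statistics())
```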