| """ | |
| Self-RAG Demo - Simplified Educational Demonstration | |
| Shows concept of adaptive retrieval and self-correction | |
| β οΈ EDUCATIONAL DEMO - Uses rule-based logic, not trained model | |
| Author: Demetrios Chiuratto Agourakis | |
| License: MIT | |
| """ | |
import gradio as gr
from typing import Dict, List, Tuple
# =================================================================
# SYNTHETIC DOCUMENT DATABASE
# =================================================================
DOCUMENTS = {
    "marie_curie_1": "Marie Curie discovered radium and polonium in 1898. She was the first woman to win a Nobel Prize.",
    "marie_curie_2": "Marie Curie won two Nobel Prizes: Physics in 1903 (shared with Pierre Curie and Henri Becquerel) and Chemistry in 1911.",
    "marie_curie_3": "The Curie Institute in Paris continues Marie Curie's research legacy in cancer treatment.",
    "einstein_1": "Albert Einstein developed the theory of relativity. E=mc² is his famous equation.",
    "einstein_2": "Einstein won the Nobel Prize in Physics in 1921 for the photoelectric effect, not relativity.",
    "quantum_1": "Quantum mechanics describes physics at atomic scales. Key contributors include Niels Bohr and Werner Heisenberg.",
    "covid_1": "COVID-19 is caused by the SARS-CoV-2 virus. It was first identified in late 2019 in Wuhan, China.",
    "covid_2": "COVID-19 vaccines were developed using mRNA technology by Moderna and BioNTech/Pfizer.",
    "brazil_1": "Brazil's GDP in 2023 was approximately $2.1 trillion USD, making it the largest economy in Latin America.",
    "brazil_2": "The capital of Brazil is Brasília. The largest city is São Paulo.",
}
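
# In a real Self-RAG system these documents would sit in a vector store behind
# an embedding-based retriever; a hard-coded dict keeps the demo self-contained.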

# =================================================================
# SELF-RAG SIMULATOR
# =================================================================

class SelfRAGSimulator:
    """Simplified Self-RAG demonstration using rules."""

    def __init__(self):
        self.documents = DOCUMENTS
        self.steps = []

    def should_retrieve(self, query: str) -> Tuple[bool, str]:
        """Decide if retrieval is needed (simplified)."""
        query_lower = query.lower()
        # Simple math or general knowledge
        if any(word in query_lower for word in ["what is", "define", "2+2", "hello"]):
            if any(word in query_lower for word in ["in", "year", "when", "who", "where"]):
                return True, "[Retrieve] - Query requires factual information"
            return False, "[No Retrieve] - Simple query, no retrieval needed"
        # Factual questions
        if any(word in query_lower for word in ["when", "where", "who", "how many", "what year"]):
            return True, "[Retrieve] - Factual question detected"
        # Opinion or creative
        if any(word in query_lower for word in ["should", "opinion", "think", "feel", "story"]):
            return False, "[No Retrieve] - Opinion/creative query"
        # Default: retrieve
        return True, "[Retrieve] - Complex query may need information"

    def retrieve_docs(self, query: str, top_k: int = 3) -> List[Tuple[str, str, int]]:
        """Retrieve relevant documents (simple keyword matching)."""
        query_words = set(query.lower().split())
        scored_docs = []
        for doc_id, doc_text in self.documents.items():
            doc_words = set(doc_text.lower().split())
            # Simple overlap score
            overlap = len(query_words.intersection(doc_words))
            if overlap > 0:
                scored_docs.append((doc_id, doc_text, overlap))
        # Sort by score, best first
        scored_docs.sort(key=lambda x: x[2], reverse=True)
        return scored_docs[:top_k]
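
    # Note: set(query.lower().split()) is naive whitespace tokenization, so
    # punctuation blocks matches ("prizes?" != "prizes:"). A real retriever
    # would use embeddings (or at least BM25) instead of raw word overlap.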

    def evaluate_relevance(self, query: str, doc: str) -> Tuple[bool, str]:
        """Check if a document is relevant (simplified)."""
        query_words = set(query.lower().split())
        doc_words = set(doc.lower().split())
        overlap = len(query_words.intersection(doc_words))
        if overlap >= 2:
            return True, f"[Relevant] - {overlap} matching keywords"
        return False, f"[Irrelevant] - Only {overlap} matching keywords"

    def generate_answer(self, query: str, relevant_docs: List[str]) -> str:
        """Generate an answer from docs (rule-based simulation)."""
        if not relevant_docs:
            return "I don't have enough information to answer this question."
        query_words = set(query.lower().split())
        # Simple extraction: return the first sentence sharing enough words
        for doc in relevant_docs:
            for sentence in doc.split('. '):
                sentence_words = set(sentence.lower().split())
                if len(query_words.intersection(sentence_words)) >= 2:
                    return sentence.strip()
        return relevant_docs[0].split('.')[0]  # Fallback: first sentence of top doc

    def check_support(self, answer: str, docs: List[str]) -> Tuple[bool, str]:
        """Check if the answer is supported by the documents."""
        answer_words = set(answer.lower().split())
        for doc in docs:
            doc_words = set(doc.lower().split())
            overlap = len(answer_words.intersection(doc_words))
            # If most answer words appear in the doc, treat it as supported
            if overlap >= len(answer_words) * 0.6:
                return True, "[Supported] - Answer verified against documents"
        return False, "[Not Supported] - Answer not well-supported, needs revision"

    def revise_answer(self, query: str, original_answer: str, docs: List[str]) -> str:
        """Revise the answer so it is better supported."""
        # Look for a longer, more specific sentence mentioning query terms
        for doc in docs:
            if len(doc) > len(original_answer):
                for sent in doc.split('. '):
                    if any(word in sent.lower() for word in query.lower().split()):
                        return sent.strip()
        return f"{original_answer} (Based on provided documents)"

    def run_self_rag(self, query: str) -> Dict:
        """Run the complete Self-RAG pipeline."""
        self.steps = []
        result = {
            "query": query,
            "decision": "",
            "retrieved_docs": [],
            "relevant_docs": [],
            "answer": "",
            "verification": "",
            "final_answer": "",
            "steps": []
        }

        # Step 1: Decide if retrieval is needed
        should_retrieve, decision = self.should_retrieve(query)
        result["decision"] = decision
        self.steps.append(f"**Step 1:** {decision}")

        if not should_retrieve:
            result["answer"] = "This query doesn't require document retrieval. (Demo: would use the LLM's parametric knowledge)"
            result["final_answer"] = result["answer"]
            self.steps.append("**Step 2:** Generated answer without retrieval")
            result["steps"] = self.steps
            return result

        # Step 2: Retrieve documents
        retrieved = self.retrieve_docs(query)
        result["retrieved_docs"] = [(doc_id, text) for doc_id, text, _score in retrieved]
        self.steps.append(f"**Step 2:** Retrieved {len(retrieved)} documents")

        # Step 3: Evaluate relevance of each retrieved document
        relevant_docs = []
        for i, (doc_id, text, _score) in enumerate(retrieved, 1):
            is_relevant, relevance_msg = self.evaluate_relevance(query, text)
            self.steps.append(f"**Step 3.{i}:** Doc '{doc_id}' - {relevance_msg}")
            if is_relevant:
                relevant_docs.append(text)
                result["relevant_docs"].append((doc_id, text, relevance_msg))

        if not relevant_docs:
            result["answer"] = "No relevant documents found."
            result["final_answer"] = result["answer"]
            self.steps.append("**Step 4:** No relevant documents, cannot answer")
            result["steps"] = self.steps
            return result

        # Step 4: Generate answer
        answer = self.generate_answer(query, relevant_docs)
        result["answer"] = answer
        self.steps.append("**Step 4:** Generated initial answer")

        # Step 5: Check if the answer is supported
        is_supported, support_msg = self.check_support(answer, relevant_docs)
        result["verification"] = support_msg
        self.steps.append(f"**Step 5:** {support_msg}")

        # Step 6: Self-correct if needed
        if not is_supported:
            revised_answer = self.revise_answer(query, answer, relevant_docs)
            result["final_answer"] = revised_answer
            self.steps.append("**Step 6:** Self-corrected answer")
            # Re-check the revised answer
            _is_supported2, support_msg2 = self.check_support(revised_answer, relevant_docs)
            self.steps.append(f"**Step 7:** {support_msg2}")
        else:
            result["final_answer"] = answer
            self.steps.append("**Step 6:** Answer is well-supported, no correction needed")

        result["steps"] = self.steps
        return result
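
# Quick smoke test without the UI (e.g., from a Python shell):
#   sim = SelfRAGSimulator()
#   out = sim.run_self_rag("When did Marie Curie win her Nobel Prizes?")
#   print(out["final_answer"])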

# =================================================================
# VISUALIZATION
# =================================================================

def format_result(result: Dict) -> Tuple[str, str, str]:
    """Format a Self-RAG result for display."""
    # Steps
    steps_md = "## 🔄 Self-RAG Process\n\n"
    for step in result["steps"]:
        steps_md += f"{step}\n\n"

    # Documents
    docs_md = "## 📄 Retrieved Documents\n\n"
    if result["retrieved_docs"]:
        for i, (doc_id, text) in enumerate(result["retrieved_docs"], 1):
            docs_md += f"### Document {i}: `{doc_id}`\n"
            docs_md += f"{text}\n\n"
            # Mark whether the document passed the relevance check
            is_in_relevant = any(doc_id == rd[0] for rd in result["relevant_docs"])
            if is_in_relevant:
                docs_md += "✅ **Marked as Relevant**\n\n"
            else:
                docs_md += "❌ **Marked as Irrelevant**\n\n"
    else:
        docs_md += "*No documents retrieved*\n\n"

    # Answer
    answer_md = "## 💡 Final Answer\n\n"
    answer_md += f"**Query:** {result['query']}\n\n"
    if result.get("answer") and result["answer"] != result["final_answer"]:
        answer_md += f"**Initial Answer:** {result['answer']}\n\n"
        answer_md += f"**After Self-Correction:** {result['final_answer']}\n\n"
    else:
        answer_md += f"**Answer:** {result['final_answer']}\n\n"
    answer_md += f"**Verification:** {result.get('verification', 'N/A')}\n\n"

    # Stats
    stats_md = "### 📊 Statistics\n\n"
    stats_md += f"- **Decision:** {result['decision']}\n"
    stats_md += f"- **Documents Retrieved:** {len(result['retrieved_docs'])}\n"
    stats_md += f"- **Relevant Documents:** {len(result['relevant_docs'])}\n"
    stats_md += f"- **Self-Correction:** {'Yes' if result.get('answer') != result['final_answer'] else 'No'}\n"
    answer_md += f"\n{stats_md}"

    return steps_md, docs_md, answer_md

# =================================================================
# GRADIO INTERFACE
# =================================================================

def process_query(query: str):
    """Process a query with Self-RAG."""
    if not query or len(query) < 3:
        return "⚠️ Please enter a valid query", "", ""
    simulator = SelfRAGSimulator()
    result = simulator.run_self_rag(query)
    steps, docs, answer = format_result(result)
    return steps, docs, answer

with gr.Blocks(theme=gr.themes.Soft(), title="Self-RAG Demo") as demo:
    gr.Markdown("""
    # 🔄 Self-RAG Demo: Adaptive Retrieval with Self-Correction

    **Educational demonstration of Self-Reflective RAG**

    ⚠️ **DEMO DISCLAIMER:** This is a simplified educational demo using rule-based logic.
    Production Self-RAG uses trained models with reflection tokens.

    ### What You'll See:
    1. **Adaptive Retrieval** - The system decides whether retrieval is needed
    2. **Relevance Evaluation** - Which documents are actually useful?
    3. **Answer Verification** - Is the answer supported by its sources?
    4. **Self-Correction** - Automatic answer revision when needed

    ### Key Innovation:
    In the paper's evaluations, Self-RAG is more accurate than traditional RAG while also retrieving less often.

    **Reference:** Asai et al. (2023), "Self-RAG: Learning to Retrieve, Generate, and Critique through Self-Reflection" - arXiv:2310.11511
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Your Query")
            query_input = gr.Textbox(
                label="Ask a Question",
                placeholder="e.g., When did Marie Curie win her Nobel Prizes?",
                lines=3
            )
            submit_btn = gr.Button("🔄 Run Self-RAG", variant="primary", size="lg")

            gr.Markdown("""
            ### Example Queries:
            - **Factual:** "When did Marie Curie win her Nobel Prizes?"
            - **Simple:** "What is 2+2?" (won't retrieve)
            - **Complex:** "What was Brazil's GDP in 2023?"
            - **Ambiguous:** "Tell me about Einstein"
            """)

            gr.Examples(
                examples=[
                    ["When did Marie Curie win her Nobel Prizes?"],
                    ["What is 2+2?"],
                    ["What was Brazil's GDP in 2023?"],
                    ["Who developed the COVID-19 vaccines?"],
                    ["Tell me a story about space"],
                ],
                inputs=[query_input]
            )

        with gr.Column(scale=2):
            gr.Markdown("### Results")
            with gr.Tab("📋 Process Steps"):
                steps_output = gr.Markdown()
            with gr.Tab("📄 Documents"):
                docs_output = gr.Markdown()
            with gr.Tab("💡 Final Answer"):
                answer_output = gr.Markdown()

    submit_btn.click(
        fn=process_query,
        inputs=[query_input],
        outputs=[steps_output, docs_output, answer_output]
    )
| gr.Markdown(""" | |
| --- | |
| ### π Understanding Self-RAG | |
| **Traditional RAG:** | |
| ``` | |
| Query β Retrieve ALL β Generate Answer | |
| ``` | |
| **Self-RAG:** | |
| ``` | |
| Query β [Decide: Retrieve?] | |
| β If yes: Retrieve β [Evaluate: Relevant?] | |
| β Generate β [Check: Supported?] | |
| β If no: [Self-Correct] | |
| β Final Answer | |
| ``` | |
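
    A rough sketch of this control flow as pseudocode (conceptual only: real
    Self-RAG emits these decisions as trained reflection tokens, and the
    function names below are made up for illustration):

    ```python
    def self_rag(query):
        if not critic_retrieve(query):            # [Retrieve] / [No Retrieve]
            return generate(query)                # answer from parametric knowledge
        docs = [d for d in retrieve(query)
                if critic_relevant(query, d)]     # [IsRel] filter
        answer = generate(query, docs)
        if not critic_supported(answer, docs):    # [IsSup] check
            answer = revise(query, answer, docs)  # one self-correction pass
        return answer
    ```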

    **Why It's Better:**
    - 🎯 **Smarter:** Only retrieves when the query actually needs it
    - ✅ **More Accurate:** Verifies that answers are supported by sources
    - 🔍 **Explainable:** Shows its reasoning process step by step
    - 🔄 **Self-Improving:** Corrects its own mistakes automatically

    **Real-World Applications:**
    - Medical Q&A (must be accurate and sourced)
    - Legal research (citations are critical)
    - Fact-checking (verify claims)
    - Educational tools (teach verification)

    **This Demo:**
    - Uses simplified rule-based logic
    - Production Self-RAG trains models with reflection tokens
    - See the paper for the full methodology: https://arxiv.org/abs/2310.11511

    **Created by:** Demetrios Chiuratto Agourakis | [GitHub](https://github.com/Agourakis82)

    ⚠️ **REMINDER:** Educational demonstration only. Production deployment requires trained models.
    """)
| if __name__ == "__main__": | |
| demo.launch() | |