"""
Self-RAG Demo - Simplified Educational Demonstration
Shows concept of adaptive retrieval and self-correction
⚠️ EDUCATIONAL DEMO - Uses rule-based logic, not trained model
Author: Demetrios Chiuratto Agourakis
License: MIT
"""
import gradio as gr
from typing import List, Dict, Tuple
# =================================================================
# SYNTHETIC DOCUMENT DATABASE
# =================================================================
DOCUMENTS = {
    "marie_curie_1": "Marie Curie discovered radium and polonium in 1898. She was the first woman to win a Nobel Prize.",
    "marie_curie_2": "Marie Curie won two Nobel Prizes: Physics in 1903 (shared with Pierre Curie and Henri Becquerel) and Chemistry in 1911.",
    "marie_curie_3": "The Curie Institute in Paris continues Marie Curie's research legacy in cancer treatment.",
    "einstein_1": "Albert Einstein developed the theory of relativity. E=mc² is his famous equation.",
    "einstein_2": "Einstein won the Nobel Prize in Physics in 1921 for the photoelectric effect, not relativity.",
    "quantum_1": "Quantum mechanics describes physics at atomic scales. Key contributors include Niels Bohr and Werner Heisenberg.",
    "covid_1": "COVID-19 is caused by the SARS-CoV-2 virus. It was first identified in late 2019 in Wuhan, China.",
    "covid_2": "COVID-19 vaccines were developed using mRNA technology by Moderna and BioNTech/Pfizer.",
    "brazil_1": "Brazil's GDP in 2023 was approximately $2.1 trillion USD, making it the largest economy in Latin America.",
    "brazil_2": "The capital of Brazil is Brasília. The largest city is São Paulo.",
}
# =================================================================
# SELF-RAG SIMULATOR
# =================================================================
class SelfRAGSimulator:
    """Simplified Self-RAG demonstration using rules."""

    def __init__(self):
        self.documents = DOCUMENTS
        self.steps = []
    def should_retrieve(self, query: str) -> Tuple[bool, str]:
        """Decide if retrieval is needed (simplified)."""
        query_lower = query.lower()
        # Simple math or general knowledge
        if any(word in query_lower for word in ["what is", "define", "2+2", "hello"]):
            if any(word in query_lower for word in ["in", "year", "when", "who", "where"]):
                return True, "[Retrieve] - Query requires factual information"
            return False, "[No Retrieve] - Simple query, no retrieval needed"
        # Factual questions
        if any(word in query_lower for word in ["when", "where", "who", "how many", "what year"]):
            return True, "[Retrieve] - Factual question detected"
        # Opinion or creative
        if any(word in query_lower for word in ["should", "opinion", "think", "feel", "story"]):
            return False, "[No Retrieve] - Opinion/creative query"
        # Default: retrieve
        return True, "[Retrieve] - Complex query may need information"
    def retrieve_docs(self, query: str, top_k: int = 3) -> List[Tuple[str, str, int]]:
        """Retrieve relevant documents (simple keyword matching)."""
        query_words = set(query.lower().split())
        scored_docs = []
        for doc_id, doc_text in self.documents.items():
            doc_words = set(doc_text.lower().split())
            # Simple overlap score: count of shared whitespace-delimited tokens
            overlap = len(query_words.intersection(doc_words))
            if overlap > 0:
                scored_docs.append((doc_id, doc_text, overlap))
        # Sort by score, highest first
        scored_docs.sort(key=lambda x: x[2], reverse=True)
        return scored_docs[:top_k]
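
    # Worked example (illustrative): for the query
    # "When did Marie Curie win her Nobel Prizes?", document
    # "marie_curie_2" shares the tokens {"marie", "curie", "nobel"}
    # and scores 3. Naive whitespace tokenization keeps punctuation,
    # so "prizes?" and "prizes:" do not match.
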
    def evaluate_relevance(self, query: str, doc: str) -> Tuple[bool, str]:
        """Check if a document is relevant (simplified)."""
        query_words = set(query.lower().split())
        doc_words = set(doc.lower().split())
        overlap = len(query_words.intersection(doc_words))
        if overlap >= 2:
            return True, f"[Relevant] - {overlap} matching keywords"
        else:
            return False, f"[Irrelevant] - Only {overlap} matching keywords"
    def generate_answer(self, query: str, relevant_docs: List[str]) -> str:
        """Generate an answer from docs (rule-based simulation)."""
        if not relevant_docs:
            return "I don't have enough information to answer this question."
        query_words = set(query.lower().split())
        # Simple extraction: return the first sentence that shares
        # at least two tokens with the query
        for doc in relevant_docs:
            for sentence in doc.split('. '):
                sentence_words = set(sentence.lower().split())
                if len(query_words.intersection(sentence_words)) >= 2:
                    return sentence.strip()
        return relevant_docs[0].split('.')[0]  # Fallback: first sentence of top doc
    def check_support(self, answer: str, docs: List[str]) -> Tuple[bool, str]:
        """Check if the answer is supported by the documents."""
        answer_words = set(answer.lower().split())
        for doc in docs:
            doc_words = set(doc.lower().split())
            overlap = len(answer_words.intersection(doc_words))
            # If most answer tokens appear in a single doc, call it supported
            if overlap >= len(answer_words) * 0.6:
                return True, "[Supported] - Answer verified against documents"
        return False, "[Not Supported] - Answer not well-supported, needs revision"
    def revise_answer(self, query: str, original_answer: str, docs: List[str]) -> str:
        """Revise the answer to be better supported."""
        # Look for more specific information in a longer document
        for doc in docs:
            if len(doc) > len(original_answer):
                # Extract the portion that mentions a query term
                for sent in doc.split('. '):
                    if any(word in sent.lower() for word in query.lower().split()):
                        return sent.strip()
        return f"{original_answer} (Based on provided documents)"
    def run_self_rag(self, query: str) -> Dict:
        """Run the complete Self-RAG pipeline."""
        self.steps = []
        result = {
            "query": query,
            "decision": "",
            "retrieved_docs": [],
            "relevant_docs": [],
            "answer": "",
            "verification": "",
            "final_answer": "",
            "steps": []
        }
        # Step 1: Decide if retrieval is needed
        should_retrieve, decision = self.should_retrieve(query)
        result["decision"] = decision
        self.steps.append(f"**Step 1:** {decision}")
        if not should_retrieve:
            result["answer"] = "This query doesn't require document retrieval. (Demo: would use LLM's parametric knowledge)"
            result["final_answer"] = result["answer"]
            self.steps.append("**Step 2:** Generated answer without retrieval")
            result["steps"] = self.steps
            return result
        # Step 2: Retrieve documents
        retrieved = self.retrieve_docs(query)
        result["retrieved_docs"] = [(doc_id, text) for doc_id, text, score in retrieved]
        self.steps.append(f"**Step 2:** Retrieved {len(retrieved)} documents")
        # Step 3: Evaluate relevance of each retrieved document
        relevant_docs = []
        for i, (doc_id, text, score) in enumerate(retrieved, 1):
            is_relevant, relevance_msg = self.evaluate_relevance(query, text)
            self.steps.append(f"**Step 3.{i}:** Doc '{doc_id}' - {relevance_msg}")
            if is_relevant:
                relevant_docs.append(text)
                result["relevant_docs"].append((doc_id, text, relevance_msg))
        if not relevant_docs:
            result["answer"] = "No relevant documents found."
            result["final_answer"] = result["answer"]
            self.steps.append("**Step 4:** No relevant documents, cannot answer")
            result["steps"] = self.steps
            return result
        # Step 4: Generate an initial answer
        answer = self.generate_answer(query, relevant_docs)
        result["answer"] = answer
        self.steps.append("**Step 4:** Generated initial answer")
        # Step 5: Check whether the answer is supported
        is_supported, support_msg = self.check_support(answer, relevant_docs)
        result["verification"] = support_msg
        self.steps.append(f"**Step 5:** {support_msg}")
        # Step 6: Self-correct if needed
        if not is_supported:
            revised_answer = self.revise_answer(query, answer, relevant_docs)
            result["final_answer"] = revised_answer
            self.steps.append("**Step 6:** Self-corrected answer")
            # Re-check support for the revised answer
            _, support_msg2 = self.check_support(revised_answer, relevant_docs)
            self.steps.append(f"**Step 7:** {support_msg2}")
        else:
            result["final_answer"] = answer
            self.steps.append("**Step 6:** Answer is well-supported, no correction needed")
        result["steps"] = self.steps
        return result
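
# Minimal standalone usage sketch (the Gradio UI below wires this up):
#
#     sim = SelfRAGSimulator()
#     out = sim.run_self_rag("When did Marie Curie win her Nobel Prizes?")
#     print(out["final_answer"])  # first matching sentence from the corpus
#     for step in out["steps"]:   # the reflection trace shown in the UI
#         print(step)
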
# =================================================================
# VISUALIZATION
# =================================================================
def format_result(result: Dict) -> Tuple[str, str, str]:
    """Format Self-RAG result for display."""
    # Steps
    steps_md = "## 🔄 Self-RAG Process\n\n"
    for step in result["steps"]:
        steps_md += f"{step}\n\n"
    # Documents
    docs_md = "## 📄 Retrieved Documents\n\n"
    if result["retrieved_docs"]:
        for i, (doc_id, text) in enumerate(result["retrieved_docs"], 1):
            docs_md += f"### Document {i}: `{doc_id}`\n"
            docs_md += f"{text}\n\n"
            # Mark whether the relevance check kept this document
            is_in_relevant = any(doc_id == rd[0] for rd in result["relevant_docs"])
            if is_in_relevant:
                docs_md += "✅ **Marked as Relevant**\n\n"
            else:
                docs_md += "❌ **Marked as Irrelevant**\n\n"
    else:
        docs_md += "*No documents retrieved*\n\n"
    # Answer
    answer_md = "## 💡 Final Answer\n\n"
    answer_md += f"**Query:** {result['query']}\n\n"
    if result.get("answer") and result["answer"] != result["final_answer"]:
        answer_md += f"**Initial Answer:** {result['answer']}\n\n"
        answer_md += f"**After Self-Correction:** {result['final_answer']}\n\n"
    else:
        answer_md += f"**Answer:** {result['final_answer']}\n\n"
    answer_md += f"**Verification:** {result.get('verification', 'N/A')}\n\n"
    # Stats
    stats_md = "### 📊 Statistics\n\n"
    stats_md += f"- **Decision:** {result['decision']}\n"
    stats_md += f"- **Documents Retrieved:** {len(result['retrieved_docs'])}\n"
    stats_md += f"- **Relevant Documents:** {len(result['relevant_docs'])}\n"
    stats_md += f"- **Self-Correction:** {'Yes' if result.get('answer') != result['final_answer'] else 'No'}\n"
    answer_md += f"\n{stats_md}"
    return steps_md, docs_md, answer_md
# =================================================================
# GRADIO INTERFACE
# =================================================================
def process_query(query: str):
    """Process a query with Self-RAG."""
    if not query or len(query) < 3:
        return "⚠️ Please enter a valid query", "", ""
    simulator = SelfRAGSimulator()
    result = simulator.run_self_rag(query)
    steps, docs, answer = format_result(result)
    return steps, docs, answer
with gr.Blocks(theme=gr.themes.Soft(), title="Self-RAG Demo") as demo:
    gr.Markdown("""
    # 🔄 Self-RAG Demo: Adaptive Retrieval with Self-Correction

    **Educational demonstration of Self-Reflective RAG**

    ⚠️ **DEMO DISCLAIMER:** This is a simplified educational demo using rule-based logic.
    Production Self-RAG uses trained models with reflection tokens.

    ### What You'll See:
    1. **Adaptive Retrieval** - The system decides whether retrieval is needed
    2. **Relevance Evaluation** - Which documents are actually useful?
    3. **Answer Verification** - Is the answer supported by the sources?
    4. **Self-Correction** - Automatic answer revision when needed

    ### Key Innovation:
    In the paper's benchmarks, Self-RAG is reported to be **5-15% more accurate** than traditional RAG while using roughly **40% fewer retrievals**.

    **Reference:** Asai et al. (2023), "Self-RAG: Learning to Retrieve, Generate, and Critique through Self-Reflection" - arXiv:2310.11511
    """)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Your Query")
            query_input = gr.Textbox(
                label="Ask a Question",
                placeholder="e.g., When did Marie Curie win her Nobel Prizes?",
                lines=3
            )
            submit_btn = gr.Button("🚀 Run Self-RAG", variant="primary", size="lg")
            gr.Markdown("""
            ### Example Queries:
            - **Factual:** "When did Marie Curie win her Nobel Prizes?"
            - **Simple:** "What is 2+2?" (won't retrieve)
            - **Complex:** "What was Brazil's GDP in 2023?"
            - **Ambiguous:** "Tell me about Einstein"
            """)
            gr.Examples(
                examples=[
                    ["When did Marie Curie win her Nobel Prizes?"],
                    ["What is 2+2?"],
                    ["What was Brazil's GDP in 2023?"],
                    ["Who developed the COVID-19 vaccines?"],
                    ["Tell me a story about space"],
                ],
                inputs=[query_input]
            )
        with gr.Column(scale=2):
            gr.Markdown("### Results")
            with gr.Tab("🔄 Process Steps"):
                steps_output = gr.Markdown()
            with gr.Tab("📄 Documents"):
                docs_output = gr.Markdown()
            with gr.Tab("💡 Final Answer"):
                answer_output = gr.Markdown()
    submit_btn.click(
        fn=process_query,
        inputs=[query_input],
        outputs=[steps_output, docs_output, answer_output]
    )
gr.Markdown("""
---
### πŸŽ“ Understanding Self-RAG
**Traditional RAG:**
```
Query β†’ Retrieve ALL β†’ Generate Answer
```
**Self-RAG:**
```
Query β†’ [Decide: Retrieve?]
β†’ If yes: Retrieve β†’ [Evaluate: Relevant?]
β†’ Generate β†’ [Check: Supported?]
β†’ If no: [Self-Correct]
β†’ Final Answer
```
**Why It's Better:**
- 🎯 **Smarter:** Only retrieves when needed (40% savings)
- βœ… **More Accurate:** Verifies answer quality (+5-15%)
- πŸ” **Explainable:** Shows reasoning process
- πŸ”„ **Self-Improving:** Corrects mistakes automatically
**Real-World Applications:**
- Medical Q&A (must be accurate and sourced)
- Legal research (citations critical)
- Fact-checking (verify claims)
- Educational tools (teach verification)
**This Demo:**
- Uses simplified rule-based logic
- Production Self-RAG trains models with reflection tokens
- See paper for full methodology: https://arxiv.org/abs/2310.11511
**Created by:** Demetrios Chiuratto Agourakis | [GitHub](https://github.com/Agourakis82)
**⚠️ REMINDER:** Educational demonstration only. Production deployment requires trained models.
""")
if __name__ == "__main__":
    demo.launch()