from langchain.chains import ConversationChain, LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import gradio as gr
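# Note: these imports follow the legacy pre-0.1 LangChain layout this app was
# written against; on langchain>=0.1 the same classes live in the
# langchain_community package (e.g. langchain_community.llms.HuggingFacePipeline).
# meta-llama/Llama-2-7b-chat-hf is a gated checkpoint, so downloading it needs
# an accepted license on the Hub and an auth token available to the Space.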
# Model configuration
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
MEDITRON_MODEL = "epfl-llm/meditron-7b"
# System prompts
SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.

Ask 1-2 follow-up questions at a time to gather more details about:
- Detailed description of symptoms
- Duration (when did it start?)
- Severity (scale of 1-10)
- Aggravating or alleviating factors
- Related symptoms
- Medical history
- Current medications and allergies

After collecting sufficient information (4-5 exchanges), summarize the findings and suggest when the user should seek professional care. Do NOT make specific diagnoses or recommend specific treatments.

Respond empathetically and clearly. Always be professional and thorough."""
MEDITRON_PROMPT = """<|im_start|>system
You are a specialized medical assistant focusing ONLY on suggesting over-the-counter medicines and home remedies based on patient information.
Based on the following patient information, provide ONLY:
1. One specific over-the-counter medicine with proper adult dosing instructions
2. One practical home remedy that might help
3. Clear guidance on when to seek professional medical care
Be concise, practical, and focus only on general symptom relief. Do not diagnose. Include a disclaimer that you are not a licensed medical professional.
<|im_end|>
<|im_start|>user
Patient information: {patient_info}
<|im_end|>
<|im_start|>assistant
"""
| print("Loading Llama-2 model...") | |
| # Create LangChain wrapper for Llama-2 | |
| llama_tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL) | |
| llama_model = AutoModelForCausalLM.from_pretrained( | |
| LLAMA_MODEL, | |
| torch_dtype=torch.float16, | |
| device_map="auto" | |
| ) | |
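# device_map="auto" requires the accelerate package. In float16 a 7B model
# takes roughly 14 GB of GPU memory, and this app loads two of them, so it
# realistically needs a large GPU runtime; accelerate will offload layers to
# CPU when VRAM runs out, at a significant speed cost.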
# Create a pipeline for LangChain
llama_pipeline = pipeline(
    "text-generation",
    model=llama_model,
    tokenizer=llama_tokenizer,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    do_sample=True
)
llama_llm = HuggingFacePipeline(pipeline=llama_pipeline)
print("Llama-2 model loaded successfully!")
| print("Loading Meditron model...") | |
| meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL) | |
| meditron_model = AutoModelForCausalLM.from_pretrained( | |
| MEDITRON_MODEL, | |
| torch_dtype=torch.float16, | |
| device_map="auto" | |
| ) | |
| # Create a pipeline for Meditron | |
| meditron_pipeline = pipeline( | |
| "text-generation", | |
| model=meditron_model, | |
| tokenizer=meditron_tokenizer, | |
| max_new_tokens=256, | |
| temperature=0.7, | |
| top_p=0.9, | |
| do_sample=True | |
| ) | |
| meditron_llm = HuggingFacePipeline(pipeline=meditron_pipeline) | |
| print("Meditron model loaded successfully!") | |
# Create LangChain conversation with memory
memory = ConversationBufferMemory(return_messages=True)
conversation = ConversationChain(
    llm=llama_llm,
    memory=memory,
    verbose=True
)
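# ConversationChain ships with its own default preamble ("The following is a
# friendly conversation...") and replays the entire buffered history on every
# turn, so the SYSTEM_PROMPT prepended in generate_response() below gets
# repeated inside the stored transcript as the conversation grows.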
# Create a template for the Meditron model
meditron_template = PromptTemplate(
    input_variables=["patient_info"],
    template=MEDITRON_PROMPT
)
meditron_chain = LLMChain(
    llm=meditron_llm,
    prompt=meditron_template,
    verbose=True
)
# Track conversation turns
conversation_turns = 0
patient_data = []
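# Note: these module-level globals (and the LangChain memory above) are shared
# across all connected Gradio clients, so the app effectively supports one
# conversation at a time; see the reset_conversation() sketch below for a way
# to start a fresh session.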
def generate_response(message, history):
    global conversation_turns, patient_data
    conversation_turns += 1

    # Store patient message
    patient_data.append(message)

    # Format the prompt with system instructions
    if conversation_turns >= 4:
        # Add summarization instruction after 4 turns
        prompt = f"{SYSTEM_PROMPT}\n\nNow summarize what you've learned and suggest when professional care may be needed.\n\n{message}"
    else:
        prompt = f"{SYSTEM_PROMPT}\n\n{message}"

    # Generate response using LangChain conversation
    llama_response = conversation.predict(input=prompt)

    # After 4 turns, add medicine suggestions from Meditron
    if conversation_turns >= 4:
        # Collect full patient conversation
        full_patient_info = "\n".join(patient_data) + "\n\nSummary: " + llama_response

        # Get medicine suggestions using LangChain
        medicine_suggestions = meditron_chain.run(patient_info=full_patient_info)

        # Format final response
        final_response = (
            f"{llama_response}\n\n"
            f"--- MEDICATION AND HOME CARE SUGGESTIONS ---\n\n"
            f"{medicine_suggestions}"
        )
        return final_response

    return llama_response
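# A small helper not present in the original app: a minimal sketch of how the
# shared state could be reset so a new consultation starts fresh without
# restarting the Space. The name reset_conversation is a hypothetical addition.
def reset_conversation():
    global conversation_turns, patient_data
    conversation_turns = 0
    patient_data = []
    memory.clear()  # ConversationBufferMemory.clear() empties the stored transcript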
# Create the Gradio interface
demo = gr.ChatInterface(
    fn=generate_response,
    title="Medical Assistant with Medicine Suggestions",
    description="Tell me about your symptoms, and after gathering enough information, I'll suggest potential remedies.",
    examples=[
        "I have a cough and my throat hurts",
        "I've been having headaches for a week",
        "My stomach has been hurting since yesterday"
    ],
    theme="soft"
)
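# Optional (assumed setup, not in the original): enable Gradio's request queue
# so long generations don't block concurrent visitors on a busy Space.
# demo = demo.queue()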
if __name__ == "__main__":
    demo.launch()