# Gradio demo: medical image Q&A with Lingshu-7B or MedGemma-27B-IT.
# Assumed runtime dependencies (not pinned in this file): gradio, torch, transformers,
# pillow, and accelerate (required for device_map="auto").
import os

import gradio as gr
import torch
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, pipeline

# Load the Hugging Face token securely from Space Secrets (set HF_TOKEN in the Space settings)
HF_TOKEN = os.getenv("HF_TOKEN")

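# Model handles are cached in module-level globals and loaded lazily on first use.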
lingshu_model = None
lingshu_processor = None
medgemma_pipe = None

def load_lingshu():
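    """Load the Lingshu-7B model and processor once; later calls reuse the cached instances."""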
    global lingshu_model, lingshu_processor
    if lingshu_model is None or lingshu_processor is None:
        lingshu_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "lingshu-medical-mllm/Lingshu-7B",
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        lingshu_processor = AutoProcessor.from_pretrained("lingshu-medical-mllm/Lingshu-7B")
    return lingshu_model, lingshu_processor

def load_medgemma():
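    """Load the MedGemma-27B-IT image-text-to-text pipeline once; later calls reuse the cached instance."""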
    global medgemma_pipe
    if medgemma_pipe is None:
        medgemma_pipe = pipeline(
            "image-text-to-text",
            model="google/medgemma-27b-it",
            torch_dtype=torch.bfloat16,
            device="cuda",  # assumes a CUDA GPU is available in the runtime
            token=HF_TOKEN
        )
    return medgemma_pipe

def inference(image, question, selected_model):
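    """Answer a question about a medical image with the selected model and return the generated text."""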
    if image is None or question is None or question.strip() == "":
        return "Please upload a medical image and enter your question or prompt."
    if selected_model == "Lingshu-7B":
        model, processor = load_lingshu()
        messages = [
            {"role": "user", "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": question}
            ]}
        ]
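        # Render the chat messages into the model's text prompt before tokenization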
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=128)
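            # Keep only the newly generated tokens (drop the echoed prompt) before decoding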
            trim_ids = generated_ids[:, inputs.input_ids.shape[1]:]
            out_text = processor.batch_decode(trim_ids, skip_special_tokens=True)
        return out_text[0] if out_text else "No response."
    elif selected_model == "MedGemma-27B-IT":
        pipe = load_medgemma()
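        # MedGemma takes chat-style messages; the image goes inside the user turn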
        messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are a medical expert."}]},
            {"role": "user", "content": [
                {"type": "text", "text": question},
                {"type": "image", "image": image}
            ]}
        ]
        try:
            res = pipe(text=messages, max_new_tokens=200)
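            # The pipeline returns the full conversation; take the content of the final (assistant) turn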
            return res[0]["generated_text"][-1]["content"]
        except Exception as e:
            return f"MedGemma error: {str(e)}"
    return "Please select a valid model."

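# Gradio UI: model selector, image upload, prompt box, and an output box wired to inference()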
with gr.Blocks() as demo:
    gr.Markdown("## 🩺 Multi-Modality Medical AI Doctor Companion\nUpload a medical image, type your question, and select a model to generate an automated analysis or report.")
    model_radio = gr.Radio(label="Model", choices=["Lingshu-7B", "MedGemma-27B-IT"], value="Lingshu-7B")
    image_input = gr.Image(type="pil", label="Medical Image")
    text_input = gr.Textbox(lines=2, label="Prompt", value="Describe this image.")
    outbox = gr.Textbox(lines=10, label="AI Answer / Report", interactive=False)
    run_btn = gr.Button("Run Analysis")
    run_btn.click(inference, [image_input, text_input, model_radio], outbox)

demo.launch()