HusainNaqvijobs committed · verified
Commit 6fd7019 · 1 Parent(s): 577ece6
Files changed (1)
  1. app.py +66 -51
app.py CHANGED
@@ -1,6 +1,9 @@
  import gradio as gr
  from PIL import Image
  import torch

  # Lingshu-7B imports
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
@@ -8,71 +11,83 @@ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
  # MedGemma imports
  from transformers import pipeline

- def load_lingshu_model():
-     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-         "lingshu-medical-mllm/Lingshu-7B",
-         torch_dtype=torch.bfloat16,
-         attn_implementation="flash_attention_2",
-         device_map="auto"
-     )
-     processor = AutoProcessor.from_pretrained("lingshu-medical-mllm/Lingshu-7B")
-     return model, processor
-
- def load_medgemma_model():
-     pipe = pipeline(
-         "image-text-to-text",
-         model="google/medgemma-27b-it",
-         torch_dtype=torch.bfloat16,
-         device="cuda"
-     )
-     return pipe
-
  lingshu_model, lingshu_processor = None, None
  medgemma_pipe = None

- def setup_models(selected_model):
-     global lingshu_model, lingshu_processor, medgemma_pipe
-     if selected_model == "Lingshu-7B" and lingshu_model is None:
-         lingshu_model, lingshu_processor = load_lingshu_model()
-     if selected_model == "MedGemma-27B-IT" and medgemma_pipe is None:
-         medgemma_pipe = load_medgemma_model()

- def med_ai_inference(img, prompt, model_type):
-     setup_models(model_type)
-     if model_type == "Lingshu-7B":
          messages = [
-             {
-                 "role": "user",
-                 "content": [
-                     {"type": "image", "image": img},
-                     {"type": "text", "text": prompt}
-                 ]
-             }
          ]
-         text = lingshu_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         inputs = lingshu_processor(text=[text], images=[img], padding=True, return_tensors="pt").to(lingshu_model.device)
          with torch.no_grad():
-             generated_ids = lingshu_model.generate(**inputs, max_new_tokens=128)
          trim_ids = generated_ids[:, inputs.input_ids.shape[1]:]
-         out_text = lingshu_processor.batch_decode(trim_ids, skip_special_tokens=True)
-         return out_text[0]
-     if model_type == "MedGemma-27B-IT":
-         # MedGemma expects messages
          messages = [
              {"role": "system", "content": [{"type": "text", "text": "You are a medical expert."}]},
-             {"role": "user", "content": [{"type": "text", "text": prompt}, {"type": "image", "image": img}]}
          ]
-         res = medgemma_pipe(text=messages, max_new_tokens=200)
-         return res[0]["generated_text"][-1]["content"]

  with gr.Blocks() as demo:
-     gr.Markdown("# Medical AI Companion")
-     gr.Markdown("Upload a medical image, type your medical question or prompt, and select a model for automated report/answer.")
      model_radio = gr.Radio(label="Model", choices=["Lingshu-7B", "MedGemma-27B-IT"], value="Lingshu-7B")
-     img_input = gr.Image(type="pil", label="Medical Image")
      text_input = gr.Textbox(lines=2, label="Prompt", value="Describe this image.")
-     outbox = gr.Textbox(lines=10, label="AI Report / Answer", interactive=False)
-     run_btn = gr.Button("Analyze")
-     run_btn.click(med_ai_inference, [img_input, text_input, model_radio], outbox)

  demo.launch()
  import gradio as gr
  from PIL import Image
  import torch
+ import os
+
+ # Your Hugging Face token for gated model access
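+ # (assumption: HF_TOKEN is referenced in load_medgemma() below but never defined in this
+ #  commit; a minimal sketch reads it from the environment / Space secret of the same name)
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+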
  # Lingshu-7B imports
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

  # MedGemma imports
  from transformers import pipeline

+ # Caching models and processors to avoid repeat loading
  lingshu_model, lingshu_processor = None, None
  medgemma_pipe = None

+ # Load Lingshu-7B
+ def load_lingshu():
+     global lingshu_model, lingshu_processor
+     if lingshu_model is None or lingshu_processor is None:
+         lingshu_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+             "lingshu-medical-mllm/Lingshu-7B",
+             torch_dtype=torch.bfloat16,
+             attn_implementation="flash_attention_2",
+             device_map="auto"
+         )
+         lingshu_processor = AutoProcessor.from_pretrained("lingshu-medical-mllm/Lingshu-7B")
+     return lingshu_model, lingshu_processor
+
+ # Load MedGemma-27B-IT with token for gated access
+ def load_medgemma():
+     global medgemma_pipe
+     if medgemma_pipe is None:
+         medgemma_pipe = pipeline(
+             "image-text-to-text",
+             model="google/medgemma-27b-it",
+             torch_dtype=torch.bfloat16,
+             device="cuda",
+             use_auth_token=HF_TOKEN
+         )
+     return medgemma_pipe

+ def inference(image, question, selected_model):
+     # Check image and question validity
+     if image is None or question is None or question.strip() == "":
+         return "Please upload a medical image and enter your question/prompt."
+     if selected_model == "Lingshu-7B":
+         model, processor = load_lingshu()
          messages = [
+             {"role": "user", "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": question}
+             ]}
          ]
+         text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         inputs = processor(
+             text=[text],
+             images=[image],
+             padding=True,
+             return_tensors="pt"
+         ).to(model.device)
          with torch.no_grad():
+             generated_ids = model.generate(**inputs, max_new_tokens=128)
          trim_ids = generated_ids[:, inputs.input_ids.shape[1]:]
+         out_text = processor.batch_decode(trim_ids, skip_special_tokens=True)
+         return out_text[0] if out_text else "No response."
+     elif selected_model == "MedGemma-27B-IT":
+         pipe = load_medgemma()
          messages = [
              {"role": "system", "content": [{"type": "text", "text": "You are a medical expert."}]},
+             {"role": "user", "content": [
+                 {"type": "text", "text": question},
+                 {"type": "image", "image": image}
+             ]}
          ]
+         try:
+             res = pipe(text=messages, max_new_tokens=200)
+             return res[0]["generated_text"][-1]["content"]
+         except Exception as e:
+             return f"MedGemma error: {str(e)}"
+     return "Please select a valid model."

  with gr.Blocks() as demo:
+     gr.Markdown("## 🩺 Multi-Modality Medical AI Doctor Companion\nUpload a medical image, type your question, and select a model to generate automated analysis/report.")
      model_radio = gr.Radio(label="Model", choices=["Lingshu-7B", "MedGemma-27B-IT"], value="Lingshu-7B")
+     image_input = gr.Image(type="pil", label="Medical Image")
      text_input = gr.Textbox(lines=2, label="Prompt", value="Describe this image.")
+     outbox = gr.Textbox(lines=10, label="AI Answer / Report", interactive=False)
+     run_btn = gr.Button("Run Analysis")
+     run_btn.click(inference, [image_input, text_input, model_radio], outbox)

  demo.launch()