HusainNaqvijobs committed on
Commit cf2080f · verified · 1 Parent(s): 0b26654
Files changed (1)
  1. app.py +78 -0
app.py CHANGED
@@ -0,0 +1,78 @@
+ import gradio as gr
+ from PIL import Image
+ import torch
+
+ # Lingshu-7B imports
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+
+ # MedGemma imports
+ from transformers import pipeline
+
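+ # Load Lingshu-7B and its processor: bfloat16 weights, FlashAttention-2, sharded across available devices.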
+ def load_lingshu_model():
+     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+         "lingshu-medical-mllm/Lingshu-7B",
+         torch_dtype=torch.bfloat16,
+         attn_implementation="flash_attention_2",
+         device_map="auto"
+     )
+     processor = AutoProcessor.from_pretrained("lingshu-medical-mllm/Lingshu-7B")
+     return model, processor
+
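+ # Build a transformers image-text-to-text pipeline for MedGemma-27B-IT in bfloat16 on the GPU.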
+ def load_medgemma_model():
+     pipe = pipeline(
+         "image-text-to-text",
+         model="google/medgemma-27b-it",
+         torch_dtype=torch.bfloat16,
+         device="cuda"
+     )
+     return pipe
+
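+ # Model handles start unloaded; each is created lazily on the first request that selects it.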
+ lingshu_model, lingshu_processor = None, None
+ medgemma_pipe = None
+
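+ # Load the requested model once and cache it in the module-level globals.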
+ def setup_models(selected_model):
+     global lingshu_model, lingshu_processor, medgemma_pipe
+     if selected_model == "Lingshu-7B" and lingshu_model is None:
+         lingshu_model, lingshu_processor = load_lingshu_model()
+     if selected_model == "MedGemma-27B-IT" and medgemma_pipe is None:
+         medgemma_pipe = load_medgemma_model()
+
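+ # Route one image + prompt query to the selected model and return the generated text.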
+ def med_ai_inference(img, prompt, model_type):
+     setup_models(model_type)
+     if model_type == "Lingshu-7B":
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image", "image": img},
+                     {"type": "text", "text": prompt}
+                 ]
+             }
+         ]
+         text = lingshu_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         inputs = lingshu_processor(text=[text], images=[img], padding=True, return_tensors="pt").to(lingshu_model.device)
+         with torch.no_grad():
+             generated_ids = lingshu_model.generate(**inputs, max_new_tokens=128)
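+         # Keep only the newly generated tokens (everything after the prompt) before decoding.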
+         trim_ids = generated_ids[:, inputs.input_ids.shape[1]:]
+         out_text = lingshu_processor.batch_decode(trim_ids, skip_special_tokens=True)
+         return out_text[0]
+     if model_type == "MedGemma-27B-IT":
+         # MedGemma expects messages
+         messages = [
+             {"role": "system", "content": [{"type": "text", "text": "You are a medical expert."}]},
+             {"role": "user", "content": [{"type": "text", "text": prompt}, {"type": "image", "image": img}]}
+         ]
+         res = medgemma_pipe(text=messages, max_new_tokens=200)
+         return res[0]["generated_text"][-1]["content"]
+
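+ # Gradio UI: model selector, image upload, prompt box, and an output box wired to med_ai_inference.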
+ with gr.Blocks() as demo:
+     gr.Markdown("# Medical AI Companion")
+     gr.Markdown("Upload a medical image, enter your medical question or prompt, and select a model to generate an automated report or answer.")
+     model_radio = gr.Radio(label="Model", choices=["Lingshu-7B", "MedGemma-27B-IT"], value="Lingshu-7B")
+     img_input = gr.Image(type="pil", label="Medical Image")
+     text_input = gr.Textbox(lines=2, label="Prompt", value="Describe this image.")
+     outbox = gr.Textbox(lines=10, label="AI Report / Answer", interactive=False)
+     run_btn = gr.Button("Analyze")
+     run_btn.click(med_ai_inference, [img_input, text_input, model_radio], outbox)
+
+ demo.launch()