| |
| |
| |
| |
|
|
| import time |
| import io |
| import requests |
| from PIL import Image |
| import gradio as gr |
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
# Hugging Face Hub repository the tokenizer and model weights are loaded from.
MODEL_ID: str = "apple/FastVLM-0.5B"

# Placeholder ID spliced into the token sequence where the image goes
# (see _build_inputs). Presumably the model's remote code swaps it for the
# vision features -- confirm against the model card.
IMAGE_TOKEN_INDEX: int = -200

# Device the model is placed on; also shown in the per-run status line.
DEVICE: str = "cpu"
|
|
| |
# Built-in example images offered in the dropdown: display label -> image URL.
# Labels are user-facing only; the code looks entries up through this dict and
# never by a hard-coded label, so they can be reworded freely.
SAMPLES = {
    # Natural photo. COCO val2017 image 000000039769 shows two cats lying on a
    # couch, so the label describes that rather than a dog.
    "Cats on a couch (COCO)": "http://images.cocodataset.org/val2017/000000039769.jpg",
    # Charts
    "Chart — Blind wine tasting": "https://huggingface.co/datasets/lytang/ChartMuseum/resolve/main/images/wine_blind_taste.png",
    "Chart — Life expectancy (Africa vs Asia)": "https://huggingface.co/datasets/lytang/ChartMuseum/resolve/main/images/life-expectancy-africa-vs-asia.png",
    # Scanned document
    "Document page — example": "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/1.jpg",
}
|
|
# Task label (shown in the radio group) -> instruction sent to the model.
# A value of None marks the free-form mode, where the user's own question is
# used instead of a canned prompt.
# NOTE: the key text of the free-form entry is compared literally elsewhere in
# this file, so the keys are kept byte-for-byte.
TASK_PROMPTS = {
    "Explain": "Describe this image in detail.",
    "Extract numbers": "Extract every number you can see with its label/context. "
    "Return a concise YAML list with fields: value, what_it_refers_to.",
    "Write alt-text": "Write high-quality alt-text (<=200 chars) that would help a blind user understand "
    "the key content and purpose of this image.",
    "Ask a questionβ¦": None,
}
|
|
| |
# Load the tokenizer and model once at import time so every request reuses them.
# SECURITY NOTE: trust_remote_code=True executes Python shipped in the model
# repository; only keep it for repositories you trust.
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    # Put every module on the single configured device, in full precision.
    device_map={"": DEVICE},
    torch_dtype=torch.float32,
)
|
|
| |
def _fetch_image(url: str) -> Image.Image:
    """Download *url* and return it as an RGB PIL image.

    Raises whatever ``requests`` raises on network failure or a non-2xx
    status, and whatever PIL raises when the payload is not a decodable image.
    """
    response = requests.get(url, timeout=20)
    # Turn HTTP error statuses into exceptions instead of decoding an error page.
    response.raise_for_status()
    payload = io.BytesIO(response.content)
    picture = Image.open(payload)
    return picture.convert("RGB")
|
|
def _build_inputs(prompt: str):
    """Build ``(input_ids, attention_mask)`` for a single-image user turn.

    The prompt is rendered through the tokenizer's chat template with an
    ``<image>`` marker in front of it. The text on each side of the marker is
    tokenized separately and IMAGE_TOKEN_INDEX is spliced in between, so the
    marker itself is never tokenized as text.

    Returns:
        A pair of tensors on ``model.device``: the token IDs and an all-ones
        attention mask of the same shape.
    """
    chat = [{"role": "user", "content": f"<image>\n{prompt}"}]
    templated = tok.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
    # Split on the first marker only; any later "<image>" stays plain text.
    before, after = templated.split("<image>", 1)

    def _encode(text):
        # The chat template already supplied the special tokens.
        return tok(text, return_tensors="pt", add_special_tokens=False).input_ids

    head_ids = _encode(before)
    tail_ids = _encode(after)
    placeholder = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=head_ids.dtype)

    input_ids = torch.cat([head_ids, placeholder, tail_ids], dim=1).to(model.device)
    return input_ids, torch.ones_like(input_ids, device=model.device)
|
|
def _prepare_pixels(pil_image: Image.Image):
    """Preprocess *pil_image* with the model's own vision-tower image processor.

    Returns the processor's ``pixel_values`` tensor, moved to the model's
    device and cast to the model's dtype.
    """
    processor = model.get_vision_tower().image_processor
    batch = processor(images=pil_image, return_tensors="pt")
    pixel_values = batch["pixel_values"]
    return pixel_values.to(model.device, dtype=model.dtype)
|
|
@torch.inference_mode()
def run_inference(choice: str, task: str, user_q: str, max_new_tokens: int, temperature: float):
    """Run one image + prompt through the model for the Gradio callback.

    Args:
        choice: Key into SAMPLES selecting the example image.
        task: Key into TASK_PROMPTS. An entry whose prompt is None (or an
            unknown task) selects free-form mode, which uses ``user_q``.
        user_q: The user's question for free-form mode; may be empty or None.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature. Values > 0 enable sampling; 0 (or
            less) selects deterministic greedy decoding.

    Returns:
        ``(image, text, status)``: the PIL image (None if it could not be
        loaded), the model's answer or an error message, and a one-line
        timing summary (empty on image-load failure).
    """
    # UI boundary: report any download/decode failure in the output box
    # instead of surfacing a traceback to the user.
    try:
        img = _fetch_image(SAMPLES[choice])
    except Exception as e:
        return None, f"Could not load image: {e}", ""

    # Free-form mode is identified by its None prompt rather than by comparing
    # label text, so it keeps working if the label is reworded.
    prompt = TASK_PROMPTS.get(task)
    if prompt is None:
        prompt = (user_q or "").strip() or "Answer questions about this image."

    input_ids, attention_mask = _build_inputs(prompt)
    px = _prepare_pixels(img)

    # Temperature only has an effect when sampling is enabled, and a value of
    # 0 is not a valid sampling temperature, so choose the decoding mode
    # explicitly and pass `temperature` only when it is actually used.
    temperature = float(temperature)
    gen_kwargs = {"max_new_tokens": int(max_new_tokens), "do_sample": temperature > 0}
    if gen_kwargs["do_sample"]:
        gen_kwargs["temperature"] = temperature

    t0 = time.perf_counter()
    out = model.generate(
        inputs=input_ids,
        attention_mask=attention_mask,
        images=px,
        **gen_kwargs,
    )
    elapsed = time.perf_counter() - t0

    # Depending on how the model's generate() is implemented, the returned
    # sequence may or may not start with the prompt tokens. Strip the prompt
    # only when it is really there so the token count is never inflated or
    # negative, and the image placeholder ID is never handed to the decoder.
    sequence = out[0]
    prompt_len = input_ids.shape[-1]
    echoes_prompt = sequence.shape[-1] >= prompt_len and torch.equal(
        sequence[:prompt_len], input_ids[0]
    )
    new_ids = sequence[prompt_len:] if echoes_prompt else sequence

    text = tok.decode(new_ids, skip_special_tokens=True)

    gen_len = int(new_ids.shape[-1])
    rate = gen_len / elapsed if elapsed > 0 else 0
    status = (
        f"⏱️ {elapsed:.2f}s • new tokens: {gen_len} • ~{rate:.1f} tok/s "
        f"• device: {DEVICE.upper()}"
    )

    return img, text.strip(), status
|
|
| |
# Gradio UI: pick an example image and a task, run the model, and show the
# image, the answer, and a timing line.
# NOTE(review): the original indentation was lost, so the exact nesting of the
# Row/Accordion containers below is a best guess -- confirm the layout.
# Markdown bodies are written as plain concatenated strings so that no source
# indentation can leak into the rendered text.
with gr.Blocks(title="FastVLM Screenshot Explainer (CPU)") as demo:
    gr.Markdown(
        "# ⚡ FastVLM Screenshot Explainer — CPU-only (no uploads)\n"
        "Click an example image, pick a task, and go.\n"
        "Model: **apple/FastVLM-0.5B** (research license).\n"
    )

    with gr.Row():
        choice = gr.Dropdown(
            label="Choose example image",
            choices=list(SAMPLES),
            # Default to the first sample, whatever its label is.
            value=next(iter(SAMPLES)),
        )
        task = gr.Radio(
            label="Task",
            choices=list(TASK_PROMPTS),
            value="Explain",
            info="“Ask a question…” enables free-form VQA.",
        )
    user_q = gr.Textbox(
        label="If asking a question, type it here",
        placeholder="e.g., What is the trend from 1950 to 2000?",
    )
    with gr.Accordion("Generation settings", open=False):
        # Positional slider arguments are (minimum, maximum, initial value).
        max_new = gr.Slider(32, 256, 128, step=8, label="max_new_tokens")
        temp = gr.Slider(0.0, 1.0, 0.2, step=0.05, label="temperature")

    go = gr.Button("Explain / Answer", variant="primary")
    with gr.Row():
        img_out = gr.Image(label="Image", interactive=False)
        txt_out = gr.Textbox(label="Model output", lines=14)
    meta = gr.Markdown()

    # Inputs map positionally onto run_inference's parameters; its three return
    # values fill the image, text and status components in that order.
    go.click(run_inference, [choice, task, user_q, max_new, temp], [img_out, txt_out, meta])

    gr.Markdown(
        "**Notes**\n"
        # The device and dtype are hard-coded in this file, so no automatic
        # GPU/float16 behaviour is promised here.
        "- Runs on CPU only (float32); the device and dtype are fixed in the code, "
        "so a GPU runtime is not picked up automatically.\n"
        # `<image>` is wrapped in backticks so Markdown shows it literally
        # instead of treating it as an HTML tag.
        "- Model + usage based on the official model card’s `trust_remote_code` API "
        "and `<image>` token handling.\n"
        "- **License:** Apple AML Research License — *research & non-commercial use only*.\n"
    )
|
|
# Start the Gradio server only when this file is executed as a script, so that
# importing the module does not launch the UI.
if __name__ == "__main__":
    demo.launch()
|
|