import os
import base64
import json
import ast
import re
from io import BytesIO
import types
import sys

# Force CPU-only & disable bitsandbytes CUDA checks in this environment
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
os.environ.setdefault("BITSANDBYTES_DISABLE_CUDA_CHECK", "1")

import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Stub bitsandbytes to avoid GPU driver checks in CPU-only environments
fake_bnb = types.ModuleType("bitsandbytes")


def _bnb_unavailable(*args, **kwargs):
    raise ImportError("bitsandbytes is not available in this CPU-only deployment")


fake_bnb.__all__ = ["_bnb_unavailable"]
fake_bnb._bnb_unavailable = _bnb_unavailable
sys.modules["bitsandbytes"] = fake_bnb

from transformers import AutoModel, AutoTokenizer

app = FastAPI(title="CCCD OCR with Vintern-1B-v2")

MODEL_NAME = "5CD-AI/Vintern-1B-v2"

# Force CPU-only to avoid NVIDIA driver / CUDA issues on Spaces
DEVICE = "cpu"
DTYPE = torch.float32

print(f"Loading model `{MODEL_NAME}` on {DEVICE} ...")

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    use_fast=False,
)

model = AutoModel.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model.eval().to(DEVICE)

generation_config = dict(
    max_new_tokens=512,
    do_sample=False,
    num_beams=3,
    repetition_penalty=3.5,
)

# =========================
# Image preprocessing (from notebook)
# =========================
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size: int):
    mean, std = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose(
        [
            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ]
    )
    return transform


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid whose aspect ratio best matches the input image."""
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # On a tie, prefer the grid with more tiles when the image is large enough.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Split the image into a grid of square `image_size` tiles matching its aspect ratio."""
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # All (cols, rows) grids whose tile count lies between min_num and max_num.
    target_ratios = set(
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    )
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size
    )

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
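# Illustrative sketch of the tiling behavior (defined but never called at import
# time; the 1280x720 input size is an assumption for demonstration, while
# max_num=6 matches what ocr_by_llm uses below): a 16:9 image best matches the
# (2, 1) grid, so it is resized to 896x448 and split into two 448x448 tiles,
# plus one 448x448 thumbnail because use_thumbnail=True.
def _tiling_demo():
    demo = Image.new("RGB", (1280, 720), color="white")
    tiles = dynamic_preprocess(demo, image_size=448, use_thumbnail=True, max_num=6)
    print(f"{len(tiles)} tiles of size {tiles[0].size}")  # 3 tiles of size (448, 448)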
def load_image_from_base64(base64_string: str, input_size=448, max_num=12):
    """Decode a (possibly data-URI prefixed) base64 image into stacked tile tensors."""
    if base64_string.startswith("data:image"):
        base64_string = base64_string.split(",", 1)[1]
    image_data = base64.b64decode(base64_string)
    image = Image.open(BytesIO(image_data)).convert("RGB")

    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(
        image, image_size=input_size, use_thumbnail=True, max_num=max_num
    )
    pixel_values = [transform(img) for img in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values


# =========================
# Prompt & helpers
# =========================
# Vietnamese instruction prompt, kept in Vietnamese because the Vintern-1B-v2
# model targets Vietnamese text. In summary, it tells the model: you are an OCR
# and data-extraction system for Vietnamese citizen ID cards (CCCD); return
# exactly one plain JSON object with exactly the five keys so_no, ho_va_ten,
# ngay_sinh, que_quan, noi_thuong_tru; map each key to the value after the
# corresponding printed label on the card (not the QR code); keep dd/mm/yyyy
# dates as-is; set unreadable fields to null rather than guessing; trim
# whitespace and preserve Vietnamese diacritics and casing.
PROMPT = """
Bạn là hệ thống OCR + trích xuất dữ liệu từ ảnh Căn cước công dân (CCCD) Việt Nam.
Nhiệm vụ: đọc đúng chữ trên thẻ và trả về CHỈ 1 đối tượng JSON theo schema quy định.

QUY TẮC BẮT BUỘC:
1) Chỉ trả về JSON thuần (không markdown, không giải thích, không thêm ký tự nào ngoài JSON).
2) Chỉ được có đúng 5 khóa sau (đúng chính tả, đúng chữ thường, có dấu gạch dưới):
   - "so_no"
   - "ho_va_ten"
   - "ngay_sinh"
   - "que_quan"
   - "noi_thuong_tru"
   Không được thêm bất kỳ khóa nào khác.
3) Mapping trường (lấy theo NHÃN in trên thẻ, không lấy từ QR):
   - so_no: lấy giá trị ngay sau nhãn "Số / No." (hoặc "Số/No.").
   - ho_va_ten: lấy giá trị ngay sau nhãn "Họ và tên / Full name".
   - ngay_sinh: lấy giá trị ngay sau nhãn "Ngày sinh / Date of birth"; nếu có định dạng dd/mm/yyyy thì giữ đúng dd/mm/yyyy.
   - que_quan: lấy giá trị ngay sau nhãn "Quê quán / Place of origin".
   - noi_thuong_tru: lấy giá trị ngay sau nhãn "Nơi thường trú / Place of residence".
4) Nếu trường nào không đọc được rõ/chắc chắn: đặt null. Không được suy đoán.
5) Chuẩn hoá: trim khoảng trắng đầu/cuối; giữ nguyên dấu tiếng Việt và chữ hoa/thường như trong ảnh.

CHỈ TRẢ VỀ THEO MẪU JSON NÀY:
{
  "so_no": "... hoặc null",
  "ho_va_ten": "... hoặc null",
  "ngay_sinh": "... hoặc null",
  "que_quan": "... hoặc null",
  "noi_thuong_tru": "... hoặc null"
}
"""
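# Optional sanity check (a sketch only, not wired into the endpoint): the five
# keys come straight from the prompt's schema, so a parsed response can be
# validated against them before being trusted downstream.
EXPECTED_KEYS = {"so_no", "ho_va_ten", "ngay_sinh", "que_quan", "noi_thuong_tru"}


def _matches_schema(obj) -> bool:
    """Return True if obj is a dict with exactly the five schema keys."""
    return isinstance(obj, dict) and set(obj.keys()) == EXPECTED_KEYS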
def parse_response_to_json(response_text: str):
    """Best-effort extraction of a JSON object from the model's raw text reply."""
    if not response_text:
        return None

    s = response_text.strip()
    # Unwrap a reply that arrives as a single quoted string.
    if s.startswith('"') and s.endswith('"'):
        s = s[1:-1].replace('\\"', '"')

    # 1) Try strict JSON first.
    try:
        obj = json.loads(s)
        if isinstance(obj, dict):
            return obj
    except json.JSONDecodeError:
        pass

    # 2) Fall back to a Python-literal parse (handles single-quoted dicts).
    try:
        obj = ast.literal_eval(s)
        if isinstance(obj, dict):
            return obj
    except (ValueError, SyntaxError):
        pass

    # 3) Last resort: pull out the first {...} span and retry both parsers.
    json_pattern = r"\{[\s\S]*\}"
    m = re.search(json_pattern, s)
    if m:
        chunk = m.group(0).strip()
        try:
            obj = ast.literal_eval(chunk)
            if isinstance(obj, dict):
                return obj
        except Exception:
            pass
        try:
            chunk2 = chunk.replace("'", '"')
            obj = json.loads(chunk2)
            if isinstance(obj, dict):
                return obj
        except Exception:
            pass

    # Nothing parseable: return the raw text so the caller still gets a payload.
    return {"text": response_text}


def normalize_base64(image_base64: str) -> str:
    """Strip an optional data-URI prefix such as `data:image/png;base64,`."""
    if not image_base64:
        return image_base64
    image_base64 = image_base64.strip()
    if image_base64.startswith("data:"):
        parts = image_base64.split(",", 1)
        if len(parts) == 2:
            return parts[1]
    return image_base64


def ocr_by_llm(image_base64: str, prompt: str) -> str:
    pixel_values = load_image_from_base64(image_base64, max_num=6)
    pixel_values = pixel_values.to(dtype=torch.float32, device=DEVICE)
    with torch.no_grad():
        response_message = model.chat(
            tokenizer,
            pixel_values,
            prompt,
            generation_config,
        )
    del pixel_values
    return response_message


class OCRRequest(BaseModel):
    image_base64: str


@app.post("/ocr")
def ocr_endpoint(req: OCRRequest):
    image_base64 = normalize_base64(req.image_base64)
    if not image_base64:
        raise HTTPException(status_code=400, detail="image_base64 is required")
    try:
        response_message = ocr_by_llm(image_base64, PROMPT)
        parsed = parse_response_to_json(response_message)
        return {"response_message": parsed}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.on_event("startup")
async def startup_log():
    """
    Log basic information about available endpoints when the app starts.
    """
    print("============================================")
    print("CCCD OCR API is running")
    print("Main endpoint: POST /ocr")
    print("Docs (Swagger): GET /docs")
    print("Redoc: GET /redoc")
    print("============================================")
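# Minimal client sketch, kept as a comment so it never runs at import time.
# Assumptions (placeholders, not confirmed by this file): the app is served by
# uvicorn on localhost:7860 and a local file `test.jpg` exists.
#
#   import base64, requests
#
#   with open("test.jpg", "rb") as f:
#       payload = {"image_base64": base64.b64encode(f.read()).decode("utf-8")}
#   r = requests.post("http://localhost:7860/ocr", json=payload, timeout=300)
#   print(r.json()["response_message"])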