import os
import base64
import json
import ast
import re
from io import BytesIO
import types
import sys

# Force CPU-only & disable bitsandbytes CUDA checks in this environment
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
os.environ.setdefault("BITSANDBYTES_DISABLE_CUDA_CHECK", "1")

import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Stub bitsandbytes to avoid GPU driver checks in CPU-only environments
fake_bnb = types.ModuleType("bitsandbytes")


def _bnb_unavailable(*args, **kwargs):
    raise ImportError("bitsandbytes is not available in this CPU-only deployment")


fake_bnb.__all__ = ["_bnb_unavailable"]
fake_bnb._bnb_unavailable = _bnb_unavailable
sys.modules["bitsandbytes"] = fake_bnb

from transformers import AutoModel, AutoTokenizer
app = FastAPI(title="CCCD OCR with Vintern-1B-v2")

MODEL_NAME = "5CD-AI/Vintern-1B-v2"

# Force CPU-only to avoid NVIDIA driver / CUDA issues on Spaces
DEVICE = "cpu"
DTYPE = torch.float32

print(f"Loading model `{MODEL_NAME}` on {DEVICE} ...")

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    use_fast=False,
)

model = AutoModel.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model.eval().to(DEVICE)

generation_config = dict(
    max_new_tokens=512,
    do_sample=False,
    num_beams=3,
    repetition_penalty=3.5,
)
# =========================
# Image preprocessing (from notebook)
# =========================
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size: int):
    mean, std = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose(
        [
            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ]
    )
    return transform
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    target_ratios = set(
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    )
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size
    )

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
def load_image_from_base64(base64_string: str, input_size=448, max_num=12):
    if base64_string.startswith("data:image"):
        base64_string = base64_string.split(",", 1)[1]
    image_data = base64.b64decode(base64_string)
    image = Image.open(BytesIO(image_data)).convert("RGB")

    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(
        image, image_size=input_size, use_thumbnail=True, max_num=max_num
    )
    pixel_values = [transform(img) for img in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
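# Note on the tensor returned above: with use_thumbnail=True and more than one tile,
# load_image_from_base64 yields a float tensor of shape (N, 3, input_size, input_size),
# where N is the number of input_size x input_size tiles plus one thumbnail of the whole image.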
# =========================
# Prompt & helpers
# =========================
# The prompt below is kept in Vietnamese because the model is tuned for Vietnamese.
# In short, it instructs the model to read the printed labels on a Vietnamese citizen
# ID card (CCCD) and return ONLY a JSON object with exactly five keys
# ("so_no", "ho_va_ten", "ngay_sinh", "que_quan", "noi_thuong_tru"), mapped from the
# labels on the card (not from the QR code); dates keep the dd/mm/yyyy format; any
# field that cannot be read with confidence must be null; values are trimmed while
# preserving Vietnamese diacritics and letter case.
PROMPT = """<image>
Bạn là hệ thống OCR + trích xuất dữ liệu từ ảnh Căn cước công dân (CCCD) Việt Nam.
Nhiệm vụ: đọc đúng chữ trên thẻ và trả về CHỈ 1 đối tượng JSON theo schema quy định.
QUY TẮC BẮT BUỘC:
1) Chỉ trả về JSON thuần (không markdown, không giải thích, không thêm ký tự nào ngoài JSON).
2) Chỉ được có đúng 5 khóa sau (đúng chính tả, đúng chữ thường, có dấu gạch dưới):
- "so_no"
- "ho_va_ten"
- "ngay_sinh"
- "que_quan"
- "noi_thuong_tru"
Không được thêm bất kỳ khóa nào khác.
3) Mapping trường (lấy theo NHÃN in trên thẻ, không lấy từ QR):
- so_no: lấy giá trị ngay sau nhãn "Số / No." (hoặc "Số/No.").
- ho_va_ten: lấy giá trị ngay sau nhãn "Họ và tên / Full name".
- ngay_sinh: lấy giá trị ngay sau nhãn "Ngày sinh / Date of birth"; nếu có định dạng dd/mm/yyyy thì giữ đúng dd/mm/yyyy.
- que_quan: lấy giá trị ngay sau nhãn "Quê quán / Place of origin".
- noi_thuong_tru: lấy giá trị ngay sau nhãn "Nơi thường trú / Place of residence".
4) Nếu trường nào không đọc được rõ/chắc chắn: đặt null. Không được suy đoán.
5) Chuẩn hoá: trim khoảng trắng đầu/cuối; giữ nguyên dấu tiếng Việt và chữ hoa/thường như trong ảnh.
CHỈ TRẢ VỀ THEO MẪU JSON NÀY:
{
  "so_no": "... hoặc null",
  "ho_va_ten": "... hoặc null",
  "ngay_sinh": "... hoặc null",
  "que_quan": "... hoặc null",
  "noi_thuong_tru": "... hoặc null"
}
"""
def parse_response_to_json(response_text: str):
    if not response_text:
        return None

    s = response_text.strip()
    # Strip a wrapping pair of double quotes and unescape inner quotes, if present
    if s.startswith('"') and s.endswith('"'):
        s = s[1:-1].replace('\\"', '"')

    # 1) Try strict JSON
    try:
        obj = json.loads(s)
        if isinstance(obj, dict):
            return obj
    except json.JSONDecodeError:
        pass

    # 2) Try a Python-literal dict (e.g. single-quoted keys/values)
    try:
        obj = ast.literal_eval(s)
        if isinstance(obj, dict):
            return obj
    except (ValueError, SyntaxError):
        pass

    # 3) Fall back to the first {...} block found anywhere in the text
    json_pattern = r"\{[\s\S]*\}"
    m = re.search(json_pattern, s)
    if m:
        chunk = m.group(0).strip()
        try:
            obj = ast.literal_eval(chunk)
            if isinstance(obj, dict):
                return obj
        except Exception:
            pass
        try:
            chunk2 = chunk.replace("'", '"')
            obj = json.loads(chunk2)
            if isinstance(obj, dict):
                return obj
        except Exception:
            pass

    # Nothing parseable: return the raw text so the caller still gets a response
    return {"text": response_text}
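# Illustrative example (hypothetical model output, not a recorded response): a string such as
# '{"so_no": "001234567890", "ho_va_ten": "NGUYEN VAN A", "ngay_sinh": "01/01/1990",
#  "que_quan": null, "noi_thuong_tru": null}' parses directly into a dict, while output
# that cannot be parsed at all is returned as {"text": <raw response>}.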
def normalize_base64(image_base64: str) -> str:
    if not image_base64:
        return image_base64
    image_base64 = image_base64.strip()
    if image_base64.startswith("data:"):
        parts = image_base64.split(",", 1)
        if len(parts) == 2:
            return parts[1]
    return image_base64
def ocr_by_llm(image_base64: str, prompt: str) -> str:
    pixel_values = load_image_from_base64(image_base64, max_num=6)
    pixel_values = pixel_values.to(dtype=torch.float32, device=DEVICE)
    with torch.no_grad():
        response_message = model.chat(
            tokenizer,
            pixel_values,
            prompt,
            generation_config,
        )
    del pixel_values
    return response_message
class OCRRequest(BaseModel):
    image_base64: str


@app.post("/ocr")
def ocr_endpoint(req: OCRRequest):
    image_base64 = normalize_base64(req.image_base64)
    if not image_base64:
        raise HTTPException(status_code=400, detail="image_base64 is required")
    try:
        response_message = ocr_by_llm(image_base64, PROMPT)
        parsed = parse_response_to_json(response_message)
        return {"response_message": parsed}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
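# Example request (a sketch; "<your-space-url>" is a placeholder for wherever this app is hosted):
#
#   curl -X POST <your-space-url>/ocr \
#        -H "Content-Type: application/json" \
#        -d '{"image_base64": "<base64-encoded CCCD image, optionally as a data URL>"}'
#
# A successful response looks like {"response_message": {...}} with the five parsed fields,
# or {"response_message": {"text": "..."}} if the model output could not be parsed as JSON.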
@app.on_event("startup")
async def startup_log():
    """
    Log basic information about available endpoints when the app starts.
    """
    print("============================================")
    print("CCCD OCR API is running")
    print("Main endpoint: POST /ocr")
    print("Docs (Swagger): GET /docs")
    print("Redoc: GET /redoc")
    print("============================================")
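# Optional local entry point: a minimal sketch, assuming uvicorn is available and the
# hosting platform does not already start the server itself (Hugging Face Spaces usually
# launches the app via its own configured command). Port 7860 is the conventional Spaces port.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)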