Oleg Lavrovsky committed
Commit b9acf2f · unverified · 1 Parent(s): c46c72f

OpenAI type completions

Files changed (1)
  1. app.py +82 -34
app.py CHANGED
@@ -1,7 +1,7 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel
+ from pydantic import BaseModel, ValidationError

from torch import cuda
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -40,6 +40,10 @@ class ModelResponse(BaseModel):
    confidence: float
    processing_time: float

+ class Completion(BaseModel):
+     model: str
+     prompt: str
+     max_tokens: int = 65536

@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -88,6 +92,81 @@ app.add_middleware(
    allow_headers=["*"],
)

+
+ def fit_to_length(text, min_length=3, max_length=100):
+     """Truncate text if too long."""
+     text = text[:max_length]
+     if len(text) == max_length:
+         logger.warning("Warning: text truncated")
+     if len(text) < min_length:
+         logger.warning("Warning: empty text, aborting")
+         return None
+     return text
+
+
+ def get_model_reponse(text: str):
+     """Process the text content."""
+
+     # Prepare the model input
+     messages_think = [
+         {"role": "user", "content": text}
+     ]
+     text = tokenizer.apply_chat_template(
+         messages_think,
+         tokenize=False,
+         add_generation_prompt=True,
+         top_p=0.9,
+         temperature=0.8,
+     )
+     model_inputs = tokenizer(
+         [text],
+         return_tensors="pt",
+         add_special_tokens=False
+     ).to(model.device)
+
+     # Generate the output
+     generated_ids = model.generate(
+         **model_inputs,
+         max_new_tokens=512
+     )
+
+     # Get and decode the output
+     output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :]
+
+     # Return just the text
+     return tokenizer.decode(output_ids, skip_special_tokens=True)
+
+
+ @app.post("/v1/models/apertus")
+ async def completion(data: Completion):
+     """Generate an OpenAPI-style completion"""
+     if model is None or tokenizer is None:
+         raise HTTPException(status_code=503, detail="Model not loaded")
+
+     try:
+         text = fit_to_length(input_data.text, input_data.max_length)
+
+         result = get_model_reponse(text, model)
+
+         return {
+             "choices": [
+                 {
+                     "text": result,
+                     "_index": 0,
+                     "logprobs": None,
+                     "finish_reason": "length"
+                 }
+             ],
+             "usage": {
+                 "prompt_tokens": len(text),
+                 "completion_tokens": len(result),
+                 "total_tokens": len(text) + len(result)
+             }
+         }
+     except ValidationError as e:
+         raise HTTPException(status_code=400, detail="Invalid input data") from e
+
+
@app.get("/predict", response_model=ModelResponse)
async def predict(q: str):
    """Generate a model response for input text"""
@@ -100,40 +179,9 @@ async def predict(q: str):

    input_data = TextInput(text=q)

-     # Truncate text if too long
-     text = input_data.text[:input_data.max_length]
-     if len(text) == input_data.max_length:
-         logger.warning("Warning: text truncated")
-     if len(text) < input_data.min_length:
-         logger.warning("Warning: empty text, aborting")
-         return None
-
-     # Prepare the model input
-     messages_think = [
-         {"role": "user", "content": text}
-     ]
-     text = tokenizer.apply_chat_template(
-         messages_think,
-         tokenize=False,
-         add_generation_prompt=True,
-         top_p=0.9,
-         temperature=0.8,
-     )
-     model_inputs = tokenizer(
-         [text],
-         return_tensors="pt",
-         add_special_tokens=False
-     ).to(model.device)
-
-     # Generate the output
-     generated_ids = model.generate(
-         **model_inputs,
-         max_new_tokens=512
-     )
+     text = fit_to_length(input_data.text, input_data.max_length)

-     # Get and decode the output
-     output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :]
-     result = tokenizer.decode(output_ids, skip_special_tokens=True)
+     result = get_model_reponse(text, model)

    # Checkpoint
    processing_time = time.time() - start_time
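Note on the new handler: as committed, completion() reads input_data, which is only defined inside predict(), and calls get_model_reponse(text, model) even though the helper takes a single argument (predict() uses the same two-argument call). A minimal sketch of the presumably intended wiring, assuming the prompt comes from the validated Completion payload:

@app.post("/v1/models/apertus")
async def completion(data: Completion):
    """Generate an OpenAI-style completion"""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Assumption: the prompt comes from the validated request body,
    # not from input_data (which only exists inside predict()).
    text = fit_to_length(data.prompt)
    if text is None:
        raise HTTPException(status_code=400, detail="Prompt too short")

    # The helper takes only the text; model and tokenizer are module globals.
    result = get_model_reponse(text)

    return {
        "choices": [{"text": result, "_index": 0, "logprobs": None, "finish_reason": "length"}],
        "usage": {"prompt_tokens": len(text), "completion_tokens": len(result),
                  "total_tokens": len(text) + len(result)},
    }

The explicit ValidationError catch is omitted in this sketch because FastAPI rejects an invalid Completion body before the handler runs.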
 
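For a quick check of the new route, a sketch of a client call using the requests library; the base URL is an assumption (point it at wherever the FastAPI app is served), and the field names follow the Completion model and response dict above:

import requests

payload = {
    "model": "apertus",                      # required by the Completion schema
    "prompt": "Say hello in three languages.",
    "max_tokens": 512,                       # optional; schema default is 65536
}

# Base URL is an assumption; adjust to the running app's host and port.
resp = requests.post("http://localhost:8000/v1/models/apertus", json=payload)
resp.raise_for_status()
body = resp.json()

print(body["choices"][0]["text"])   # generated completion text
print(body["usage"])                # character-based token counts, per the handler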