diff --git a/llm_exl2_dynamic_gen.py b/llm_exl2_dynamic_gen.py index 8f6daad..339ac4d 100644 --- a/llm_exl2_dynamic_gen.py +++ b/llm_exl2_dynamic_gen.py @@ -417,38 +417,102 @@ async def stream_response(prompt_id, timeout=180): break +def handle_exception(exception_prompt_id, e, stream=None): + global prompt_ids2jobs, prompt_length, partial_responses + + exception_message = f"There was a problem processing your request. Exception occurred: {str(e)}" + + if stream: + exception_response = { + "id": f"chatcmpl-{exception_prompt_id}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": repo_str, + "choices": [ + { + "index": 0, + "delta": { + "content": exception_message + }, + "finish_reason": None + } + ] + } + partial_responses[exception_prompt_id] = [exception_response] + + partial_response_data = { + "finish_reason": "stop" + } + responses[exception_prompt_id] = partial_response_data + + else: + exception_response = { + "id": f"chatcmpl-{exception_prompt_id}", + "object": "chat.completion", + "created": int(time.time()), + "model": repo_str, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": exception_message, + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + } + } + + responses[exception_prompt_id] = exception_response + + if exception_prompt_id in prompt_ids2jobs: + generator.cancel(prompt_ids2jobs[exception_prompt_id]) + + prompt_ids2jobs.pop(exception_prompt_id, None) + prompt_length.pop(exception_prompt_id, None) + status_area.update(f"Problem processing the request for prompt: {exception_prompt_id}", line=STATUS_LINES-1) + + def process_prompts(): global partial_responses global prompt_ids2jobs, prompt_length, cancelled_request_ids - try: - while True: + while True: + try: while not prompts.empty() or len(prompt_length): while len(prompt_length) < max_batch_size and not prompts.empty(): prompt_id, prompt, max_tokens, stream, temperature, outlines_dict = prompts.get() stop_at = outlines_dict.get("stop_at", None) - if outlines_dict["type"] == "choices": - filters = [ChoiceFilter(outlines_dict["choices"], hf_tokenizer)] - elif outlines_dict["type"] == "json": - filters = [JSONFilter(outlines_dict["json"], hf_tokenizer)] - elif outlines_dict["type"] == "regex": - # Validation of regex - filters = [RegexFilter(outlines_dict["regex"], hf_tokenizer)] - else: - filters = [] + + try: + if outlines_dict["type"] == "choices": + filters = [ChoiceFilter(outlines_dict["choices"], hf_tokenizer)] + elif outlines_dict["type"] == "json": + filters = [JSONFilter(outlines_dict["json"], hf_tokenizer)] + elif outlines_dict["type"] == "regex": + # Validation of regex + filters = [RegexFilter(outlines_dict["regex"], hf_tokenizer)] + else: + filters = [] + except Exception as e: + handle_exception(prompt_id, e, stream) + ids = tokenizer.encode(prompt, encode_special_tokens = True) prompt_tokens = ids.shape[-1] new_tokens = prompt_tokens + max_tokens #print("Processing prompt: " + str(prompt_id) + " Req tokens: " + str(new_tokens)) status_area.update(f"Processing prompt: {prompt_id} Req tokens: {new_tokens}", line=STATUS_LINES-1) - # Truncate if new_tokens exceed max_context + + # If prompt exceeds allowed length, abort generation and respond with message if new_tokens > max_context: - # Calculate how many tokens to truncate - ids = tokenizer.encode("Say, 'Prompt exceeds allowed length. Please try again.'") - # Update new_tokens after truncation - prompt_tokens = ids.shape[-1] - new_tokens = prompt_tokens + max_tokens - print("Truncating prompt: " + str(prompt_id) + " Req tokens: " + str(new_tokens)) + exception_message = 'Prompt exceeds allowed length. Please try again.' + handle_exception(prompt_id, exception_message, stream) + continue + prompt_length[prompt_id] = prompt_tokens #streamer.append(stream) #prompt_ids.append(prompt_id) @@ -596,28 +660,27 @@ def process_prompts(): for item in temp_storage: prompts.put(item) - - else: # Sleep for a short duration when there's no work time.sleep(0.1) # Sleep for 100 milliseconds - except Exception as e: - print("Reset server due to ", e) - print(traceback.format_exc()) - for prompt_id in prompt_ids2jobs: - job = prompt_ids2jobs[prompt_id] - if(job.streamer): - ## Generator, yield here.. - partial_response_data = { - "finish_reason": "stop" - } - responses[prompt_id] = partial_response_data - else: - print("Error handling for full generation current not implemented") - generator.cancel(job) - prompt_ids2jobs = {} - prompt_length = {} + except Exception as e: + print("Reset server due to ", e) + print(traceback.format_exc()) + for prompt_id in prompt_ids2jobs: + job = prompt_ids2jobs[prompt_id] + if(job.streamer): + ## Generator, yield here.. + partial_response_data = { + "finish_reason": "stop" + } + + responses[prompt_id] = partial_response_data + else: + print("Error handling for full generation current not implemented") + generator.cancel(job) + prompt_ids2jobs = {} + prompt_length = {} # Start worker thread worker = Thread(target=process_prompts) diff --git a/llm_exl2_dynamic_gen_lora.py b/llm_exl2_dynamic_gen_lora.py index ec973eb..fda3e28 100644 --- a/llm_exl2_dynamic_gen_lora.py +++ b/llm_exl2_dynamic_gen_lora.py @@ -448,37 +448,102 @@ async def stream_response(prompt_id, timeout=180): break +def handle_exception(exception_prompt_id, e, stream=None): + global prompt_ids2jobs, prompt_length, partial_responses + + exception_message = f"There was a problem processing your request. Exception occurred: {str(e)}" + + if stream: + exception_response = { + "id": f"chatcmpl-{exception_prompt_id}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": repo_str, + "choices": [ + { + "index": 0, + "delta": { + "content": exception_message + }, + "finish_reason": None + } + ] + } + partial_responses[exception_prompt_id] = [exception_response] + + partial_response_data = { + "finish_reason": "stop" + } + responses[exception_prompt_id] = partial_response_data + + else: + exception_response = { + "id": f"chatcmpl-{exception_prompt_id}", + "object": "chat.completion", + "created": int(time.time()), + "model": repo_str, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": exception_message, + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + } + } + + responses[exception_prompt_id] = exception_response + + if exception_prompt_id in prompt_ids2jobs: + generator.cancel(prompt_ids2jobs[exception_prompt_id]) + + prompt_ids2jobs.pop(exception_prompt_id, None) + prompt_length.pop(exception_prompt_id, None) + status_area.update(f"Problem processing the request for prompt: {exception_prompt_id}", line=STATUS_LINES-1) + + def process_prompts(): global partial_responses global prompt_ids2jobs, prompt_length, prompt_model, cancelled_request_ids - try: - - while True: + + while True: + try: while not prompts.empty() or len(prompt_length): while len(prompt_length) < max_batch_size and not prompts.empty(): prompt_id, prompt, max_tokens, stream, temperature, rmodel, outlines_dict = prompts.get() stop_at = outlines_dict.get("stop_at", None) - if outlines_dict["type"] == "choices": - filters = [ChoiceFilter(outlines_dict["choices"], hf_tokenizer)] - elif outlines_dict["type"] == "json": - filters = [JSONFilter(outlines_dict["json"], hf_tokenizer)] - elif outlines_dict["type"] == "regex": - filters = [RegexFilter(outlines_dict["regex"], hf_tokenizer)] - else: - filters = [] + + try: + if outlines_dict["type"] == "choices": + filters = [ChoiceFilter(outlines_dict["choices"], hf_tokenizer)] + elif outlines_dict["type"] == "json": + filters = [JSONFilter(outlines_dict["json"], hf_tokenizer)] + elif outlines_dict["type"] == "regex": + # Validation of regex + filters = [RegexFilter(outlines_dict["regex"], hf_tokenizer)] + else: + filters = [] + except Exception as e: + handle_exception(prompt_id, e, stream) + ids = tokenizer.encode(prompt, encode_special_tokens = True) prompt_tokens = ids.shape[-1] new_tokens = prompt_tokens + max_tokens #print("Processing prompt: " + str(prompt_id) + " Req tokens: " + str(new_tokens)) status_area.update(f"Processing prompt: {prompt_id} Req tokens: {new_tokens}", line=STATUS_LINES-1) - # Truncate if new_tokens exceed max_context + + # If prompt exceeds allowed length, abort generation and respond with message if new_tokens > max_context: - # Calculate how many tokens to truncate - ids = tokenizer.encode("Say, 'Prompt exceeds allowed length. Please try again.'") - # Update new_tokens after truncation - prompt_tokens = ids.shape[-1] - new_tokens = prompt_tokens + max_tokens - print("Truncating prompt: " + str(prompt_id) + " Req tokens: " + str(new_tokens)) + exception_message = 'Prompt exceeds allowed length. Please try again.' + handle_exception(prompt_id, exception_message, stream) + continue + prompt_length[prompt_id] = prompt_tokens #streamer.append(stream) #prompt_ids.append(prompt_id) @@ -652,25 +717,25 @@ def process_prompts(): else: # Sleep for a short duration when there's no work time.sleep(0.1) # Sleep for 100 milliseconds - except Exception as e: - print("Reset server due to ", e) - print(traceback.format_exc()) - for prompt_id in prompt_ids2jobs: - job = prompt_ids2jobs[prompt_id] - if(job.streamer): - ## Generator, yield here.. - partial_response_data = { - "finish_reason": "stop" - } - - responses[prompt_id] = partial_response_data - else: - print("Error handling for full generation current not implemented") - generators[job.model].cancel(job) - #generator.cancel(job) - prompt_ids2jobs = {} - prompt_length = {} - prompt_model = {} + except Exception as e: + print("Reset server due to ", e) + print(traceback.format_exc()) + for prompt_id in prompt_ids2jobs: + job = prompt_ids2jobs[prompt_id] + if(job.streamer): + ## Generator, yield here.. + partial_response_data = { + "finish_reason": "stop" + } + + responses[prompt_id] = partial_response_data + else: + print("Error handling for full generation current not implemented") + generators[job.model].cancel(job) + #generator.cancel(job) + prompt_ids2jobs = {} + prompt_length = {} + prompt_model = {} # Start worker thread worker = Thread(target=process_prompts)