Hi @senwu and @lorr1, I am trying to run the NumbersStation/nsql-llama-2-7B model (loaded from Hugging Face) on the schema shown below, using the same prompt format as in the notebooks. The model runs for a couple of minutes and then returns only "SELECT SUM". The schema and prompt are below; please let me know if I am missing something or should change anything.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed loading code: the model and tokenizer come from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-llama-2-7B")
model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-llama-2-7B")

table_schema = """CREATE TABLE ACCOUNT (
ACCOUNT_ID text,
ACCOUNT_NAME text,
INDUSTRY text,
ACCOUNT_SEGMENT text,
BILLING_COUNTRY text,
)
CREATE TABLE DISCOUNT (
OPPORTUNITY_NUMBER number,
OPPORTUNITY_NAME text,
FISCAL_YEAR_NUMBER text,
OPP_LEVEL_DISCOUNTS number,
OPP_LEVEL_DISCOUNT_VALID text,
LIST_PRICE_MS number,
)
CREATE TABLE PIPELINE (
OPPORTUNITY_ID text,
OPPORTUNITY_LINE_ID text,
OPPORTUNITY_NAME text,
OPPORTUNITY_FORECAST_CATEGORY text,
ACCOUNT_ID text,
ACCOUNT_NAME text,
OPPORTUNITY_CLOSE_DATE DATE,
OPPORTUNITY_CLOSE_FISCAL_YEAR INTEGER,
OPPORTUNITY_CLOSE_FISCAL_QTR_NUMBER INTEGER,
OPPORTUNITY_CLOSE_FISCAL_YEAR_QTR text,
FINAL_GLOBAL_CUSTOMER_NAME text,
SYB_AMOUNT number
)
"""
question = "What is the total SYB amount for year 2023 by account segment?"

# Schema, then the request as SQL comments, then "SELECT" for the model to continue
prompt = f"""{table_schema}
-- Generate a valid SQL query for the following request:
-- Request: {question}
SELECT"""

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# model = model.to(mps_device)
generated_ids = model.generate(input_ids, max_length=500)
output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# Keep the text after the last "SELECT" in the decoded sequence
output = 'SELECT' + output.split('SELECT')[-1]
print(output)
```
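The ACCOUNT and DISCOUNT definitions above end their column lists with a trailing comma (`BILLING_COUNTRY text,` and `LIST_PRICE_MS number,` right before `)`), which is not valid DDL. It is not certain that this affects NSQL's output, but it is cheap to rule out. Below is a minimal sketch using Python's built-in `sqlite3` module; the helper name `check_schema` is hypothetical, and splitting before each `CREATE TABLE` assumes the schema has no `;` separators, as in the prompt above.

```python
import re
import sqlite3


def check_schema(schema: str) -> list[str]:
    """Try each CREATE TABLE statement in SQLite and collect parse errors.

    Assumes statements are not separated by ';', so the text is split just
    before every 'CREATE TABLE' and each piece is executed on its own.
    """
    conn = sqlite3.connect(":memory:")  # throwaway in-memory database
    errors = []
    try:
        for stmt in re.split(r"(?=CREATE TABLE)", schema):
            stmt = stmt.strip()
            if not stmt:
                continue  # skip the empty piece before the first CREATE TABLE
            try:
                conn.execute(stmt)
            except sqlite3.Error as exc:
                # Report the table header line together with SQLite's message
                errors.append(f"{stmt.splitlines()[0]} -> {exc}")
    finally:
        conn.close()
    return errors


if __name__ == "__main__":
    demo = """CREATE TABLE ACCOUNT (
ACCOUNT_ID text,
BILLING_COUNTRY text,
)
CREATE TABLE PIPELINE (
ACCOUNT_ID text,
SYB_AMOUNT number
)
"""
    for err in check_schema(demo):
        print(err)
```

Running `check_schema(table_schema)` on the real prompt string should flag the two tables with trailing commas; removing those commas makes it return an empty list.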
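The model returns SELECT(SUM), whereas the expected output is a join query: SYB_AMOUNT is in PIPELINE while ACCOUNT_SEGMENT is only in ACCOUNT, so the two tables have to be joined on ACCOUNT_ID.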
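In Hugging Face `transformers`, the `max_length` argument of `generate()` bounds the total sequence length (prompt plus new tokens) for a decoder-only model like this one, whereas `max_new_tokens` bounds only the generated tokens. With a schema of this size, the prompt alone may use much of the 500-token limit, which could cut the SQL off after a few tokens. That is a hypothesis to check, not a confirmed cause. A minimal sketch of the arithmetic, with hypothetical helper names and illustrative numbers:

```python
def new_token_budget(prompt_tokens: int, max_length: int) -> int:
    """Number of tokens generate() can still add under max_length.

    For decoder-only models, max_length counts prompt tokens plus generated
    tokens, so whatever the prompt uses is subtracted from the budget.
    """
    return max(0, max_length - prompt_tokens)


def completion_ids(generated: list[int], prompt_tokens: int) -> list[int]:
    """Drop the echoed prompt from generate()'s output, keeping only new tokens."""
    # For decoder-only models, generate() returns the prompt ids followed by the continuation
    return generated[prompt_tokens:]


if __name__ == "__main__":
    # Hypothetical numbers for illustration; with the real model use input_ids.shape[1]
    print(new_token_budget(prompt_tokens=470, max_length=500))
    print(completion_ids([1, 2, 3, 4, 5], prompt_tokens=3))
```

With the real model, `prompt_tokens` is `input_ids.shape[1]`. If the budget turns out to be small, calling `model.generate(input_ids, max_new_tokens=256)` (256 is an arbitrary choice) and building the query as `"SELECT " + tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True).lstrip()` keeps just the completion. Decoding only the new tokens also avoids `split('SELECT')[-1]` discarding the outer query if the model emits a subquery.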
Any help would be appreciated. Thanks in advance!