Sky-T1-32B-Preview, 450$로 학습한 o1급 오픈소스 모델 알아보기

JeongYun Lee·2025년 1월 23일
0

AI

목록 보기
13/13

Sky-T1-32B-Preview는 NovaSky(UC Berkeley)에서 2025년 1월에 공개한 오픈소스 모델이다. 데이터셋 생성부터 파인튜닝까지의 전체 코드를 다 공유하며, 학습에 사용된 비용이 450$밖에 되지 않는다고 밝혀서 주목을 받고 있다.

실제로 깃헙을 확인해보면 재현가능한 수준의 코드를 공유하고 있다. 코드와 함께 학습 데이터를 구성하는 방법과 SFT 파트를 구분해서 살펴보았다.


💫 학습 데이터 구성

기본적으로 Sky-T1 모델은 reasoning 기반의 학습이 이루어졌다. reasoning을 할 수 있는 학습데이더는 QwQ-32B-Preview 모델(o1-preview와 비슷한 추론 기능을 갖춘 Qwen의 오픈소스 모델)을 기반으로 data synthesis를 통해 생성했다고 한다. 기반데이터는 다음과 같다.

  • 코딩 데이터: APPs와 TACO에서 수집한 5000개
  • 수학 데이터: AIME, MATH, Olympiads 등 NuminalMATH(math dataset) 데이터셋에서 약 10000개
  • 과학 및 퍼즐 데이터: STILL-2에서 약 1000개

기반데이터로 data synthesis하고 reject sampling하고 다시 reformatting하고 reject sampling한 데이터가 최종적으로 17000개라는 듯

Method

(1) data synthesis

  • 실행 파일: skythought/tools/inference_and_check.py

QA셋 형태의 추론이 필요한 다양한 도메인의 기반 데이터(data mixture)를 가져와서 QwQ모델에 넣고 reasoning을 수행하는 정답을 도출하게 하도록 한다. 이후 실제 정답과 QwQ가 생성한 정답이 맞는지 확인하고 선택, 제거를 하는 reject sampling을 진행한다.

data mixture와 관련해서 추론의 방식이 다른 학습 데이터셋들을 사용할 경우 데이터의 양과 난이도를 '적절히' 조절해서 섞어야 두 영역 모두 성능 향상이 가능하다고 하는데, 여기서 적절한 비율을 찾는게 하나의 태스크가 될 수 있을 것 같다.

(2) data formatting

  • 실행 파일: skythought/tools/convert_format.py

reject sampling을 거쳐서 reasoning이 포함된 데이터들을 다시 GPT-4o-mini를 활용해서 rewrite(reformatting과정)해주는 과정이다. 이때 맞춰주는 format은 다음과 같다.

# util/prompts.py
convert_prompt = "Another solution is written in an unstructured way. Your job is to convert them into two sections: \
    <|begin_of_thought|> \
    (Thought process, you should copy exactly the thinking process of the original solution.) \
    <|end_of_thought|> \
    <|begin_of_solution|> \
    (Final formatted, precise, and clear solution; make sure there is only one solution in this section; If it is a coding problem, make sure there is only one code block) \
    <|end_of_solution|> \
    Here is an example demonstration of a different question, you can refer to its format: \
    {example} \
    Important: You should almost copy all the contents word-by-word of the original solution. Just convert them into two sections. \
    Make sure you include: <|begin_of_slow_thought|>, <|end_of_slow_thought|>,  <|begin_of_solution|>,<|end_of_solution|>  These four headers explicitly.  \
    Content to be converted: {content}"

<|begin_of_thought|> ~ <|end_of_thought|>로 논리적인 사고 과정을 설명하도록 하고 <|begin_of_solution|> ~ <|end_of_solution|>에서는 정확한 최종 솔루션을 넣어주는 구조이다.

convert_prompt_example도 함께 넣어준다.

# util/prompts.py
convert_prompt_example = ("<|begin_of_thought|>\n\n"
    "Okay, so I've got this problem here. Mr. Wang leaves home at 6 AM, riding his bike at 12 km/h, "
    "and he stops to rest for 6 minutes after every 30 minutes of riding. Then, when he arrives at a park "
    "that's 16.8 km away, I need to find out the angle between the hour and minute hands on his watch.\n\n"
    "Alright, first things first, I need to figure out how long it takes Mr. Wang to ride 16.8 km, including "
    "his rest periods.\n\n"
    "So, his speed is 12 km/h. To find out how long it takes to go 16.8 km without any stops, I can use the formula "
    "time = distance/speed. That would be 16.8 divided by 12, which is 1.4 hours. To make it easier, that's 1 hour and 24 minutes.\n\n"
    "But wait, he doesn't ride straight through. He stops for 6 minutes after every 30 minutes of riding. So, I need to see how many "
    "of those 30-minute riding periods are there in his total riding time.\n\n"
    "In 1 hour and 24 minutes of riding, how many 30-minute segments are there? Well, 1 hour is 60 minutes, plus 24 minutes makes 84 minutes "
    "total riding time. So, 84 divided by 30 is 2.8. That means he has two full 30-minute riding periods and a partial one.\n\n"
    "After each full 30-minute riding period, he rests for 6 minutes. So, for two full periods, he rests twice, which is 12 minutes of rest.\n\n"
    "Now, for the partial riding period. Since 2 times 30 minutes is 60 minutes, and he has 84 minutes of riding, the remaining riding time is 84 minus 60, "
    "which is 24 minutes. So, he rides for 24 minutes without another rest because he doesn't complete another 30-minute segment.\n\n"
    "So, total time taken is riding time plus rest time. That's 84 minutes riding plus 12 minutes resting, totaling 96 minutes.\n\n"
    "Wait a minute, but he stops after every 30 minutes of riding, but in the last partial period of 24 minutes, does he rest again? I think he only rests after "
    "completing 30 minutes of riding, so in this case, since the last riding period is only 24 minutes, he doesn't take an additional rest after that.\n\n"
    "So, total time should be 84 minutes riding plus 12 minutes resting, which is indeed 96 minutes, or 1 hour and 36 minutes.\n\n"
    "So, he leaves at 6 AM and takes 1 hour and 36 minutes to reach the park, arriving at 7:36 AM.\n\n"
    "Now, I need to find the angle between the hour and minute hands at 7:36.\n\n"
    "To find the angle between the hour and minute hands, I can use the formula:\n\n"
    "|30H - 5.5M|\n\n"
    "where H is the hour and M is the minutes.\n\n"
    "At 7:36, H is 7 and M is 36.\n\n"
    "So, plugging in:\n\n"
    "30*7 = 210\n\n"
    "5.5*36 = 198\n\n"
    "210 - 198 = 12\n\n"
    "So, the angle is 12 degrees.\n\n"
    "Wait, but I should make sure that's the smaller angle. Sometimes, the larger angle is considered, but usually, the smaller one is what is asked for.\n\n"
    "So, the angle between the hour and minute hands at 7:36 AM is 12 degrees.\n\n"
    "I think that's the answer.<|end_of_thought|>\n\n"
    "<|begin_of_solution|>\n\n"
    "Mr. Wang leaves home at 6 AM and rides at a speed of 12 km/h, stopping to rest for 6 minutes after every 30 minutes of riding. "
    "He arrives at a park 16.8 km away. To determine the angle between the hour and minute hands on his watch when he arrives, we first calculate the total time taken.\n\n"
    "1. **Riding time without stops**:\n\n"
    "$$\\text{Time} = \\frac{\\text{Distance}}{\\text{Speed}} = \\frac{16.8 \\text{ km}}{12 \\text{ km/h}} = 1.4 \\text{ hours} = 84 \\text{ minutes}$$\n\n"
    "2. **Rest periods**:\n\n"
    "  - He rests for 6 minutes after every 30 minutes of riding.\n\n"
    "  - In 84 minutes of riding, he completes 2 full 30-minute segments and a partial 24-minute segment.\n\n"
    "  - He rests twice, totaling 12 minutes of rest.\n\n"
    "3. **Total time**:\n\n"
    "$$\\text{Total time} = 84 \\text{ minutes (riding)} + 12 \\text{ minutes (rest)} = 96 \\text{ minutes} = 1 \\text{ hour and } 36 \\text{ minutes}$$\n\n"
    "  - He arrives at 7:36 AM.\n\n"
    "4. **Angle between hour and minute hands at 7:36**:\n\n"
    "  - Use the formula:\n\n"
    "$$\\text{Angle} = |30H - 5.5M|$$\n\n"
    "  - At 7:36, $H = 7$ and $M = 36$:\n\n"
    "$$\\text{Angle} = |30 \\times 7 - 5.5 \\times 36| = |210 - 198| = 12 \\text{ degrees}$$\n\n"
    "Thus, the angle between the hour and minute hands on his watch is $\\boxed{12}$.<|end_of_solution|>\n")

(3) 최종 Reject sampling

  • 실행 파일: skythought/tools/inference_and_check.py

formatting한 데이터를 가져와서 다시 reject sampling 진행한다. 최종적으로 사용한 데이터는 17000건 정도라고 한다.

(4) Convert to ShareGPT format for training

  • 실행 파일: skythought/tools/convert_to_data.py

최종 샘플링한 데이터를 학습 가능한 구조인 ShareGPT format으로 맞춰주는 작업이다.

[
    {
        "conversation": [
            {"role": "user", "content": "What is AI?"},
            {"role": "assistant", "content": "AI stands for artificial intelligence, which refers to machines that can mimic human intelligence."},
            {"role": "user", "content": "Can you give an example?"},
            {"role": "assistant", "content": "Sure, an example of AI is a chatbot like me that can answer your questions."}
        ]
    },
    {
        "conversation": [
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "The capital of France is Paris."}
        ]
    }
]

다만, 내가 아는 ShareGPT format은 위와 같이 같이 user, assistant 등 대화의 주체를 명확히 구분해서 작성하는 포맷인데, 여기서 사용한 utils.prompts.py의 system_prompt는 다음과 같다.


# From https://arxiv.org/pdf/2412.09413
system_prompt = "Your role as an assistant involves thoroughly exploring questions through a systematic long \
thinking process before providing the final precise and accurate solutions. This requires \
engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, \
backtracing, and iteration to develop well-considered thinking process. \
Please structure your response into two main sections: Thought and Solution. \
In the Thought section, detail your reasoning process using the specified format: \
<|begin_of_thought|> {thought with steps separated with '\n\n'} \
<|end_of_thought|> \
Each step should include detailed considerations such as analisying questions, summarizing \
relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining \
any errors, and revisiting previous steps. \
In the Solution section, based on various attempts, explorations, and reflections from the Thought \
section, systematically present the final solution that you deem correct. The solution should \
remain a logical, accurate, concise expression style and detail necessary step needed to reach the \
conclusion, formatted as follows: \
<|begin_of_solution|> \
{final formatted, precise, and clear solution} \
<|end_of_solution|> \
Now, try to solve the following question through the above guidelines:" 

하지만, 최종 데이터셋을 보면 conversations에 fromvalue로 잘 지정되어 있는 걸을 알 수 있다. 뭔가 내가 놓친 부분이 있는거 같은데 일단 넘어가겠다.

💫 SFT

base model은 Qwen2.5-32B-Instruct로, 추론 기능이 없는 모델이다. 연구진이 작성한 블로그에 보면, 7B정도의 사이즈가 작은 모델로 테스트 했을 때는 큰 성능 개선을 확인할 수 없었고, 32B정도에서는 뚜렷한 성능 개선이 있었다고 한다.

학습 파라미터는 다음과 같다.

  • 3 epochs
  • learning rate of 1e-5
  • batch size of 96

GPU는 AWS의 Lambda Cloud를 사용했고, H100 8대로 19시간 학습을 했다고 한다.

백엔드에서는 Deepspeed Zero-3 offload로 가속화를 했는데, 이는 MS의 대규모 분산 딥러닝을 위한 데이터 병렬 처리+메모리 최적화 라이브러리라고 한다. 모델 파라미터, 그래디언트, 옵티마이저 상태를 GPU들에 분산 저장한다. 이때 offload는 일부 데이터를 CPU 메모리로 이동시켜 GPU 메모리 사용량을 줄이는 기술이다.

최종 코드는 Llama-Factory를 사용했다고 한다. Llama-Factory는 No-Code training을 지원하며, GPU와 파라미터 설정 등 파인 튜닝에 필요한 다양한 밑작업들을 간단하게 진행할 수 있도록 하는 프레임워크이다.

UnSloth와 대응해서 사용할 수 있는 프레임워크라고 이해했는데, Llama-Factory에서도 UnSloth를 지원한다고 해서 살짝 헷갈렸다. 지금 이해하기로는 UnSloth는 양자화 모델의 학습을 지원하는 라이브러리인데 단독으로 사용할 수 있지만, 마치 langchain을 사용하듯 Llama-Factory에서 사용할 수 있다...뭐 이런 개념이 아닐까 생각한다. 이건 테스트를 해봐야 알 것 같다.


기존의 대규모 모델들은 오픈소스라고 해도 이렇게 데이터셋 구축까지의 모든 코드를 공개하진 않았던 것 같다. 다만, 위에서도 언급했듯 다양한 형식의 데이터셋을 어떻게 '적절히'구성하는 지에 대한 의문은 여전히 있으며, fine-tuning을 할 때 full finetuning인지, PEFT 방식을 사용했는지에 대한 명확한 언급은 없었다. 아마도 이 부분은 코드에서 확인해야 할 것 같다.

또한 학습 과정에서 사용된 비용이 450$라고 했는데, 여기서 궁금했던 점은 분명 이 모델을 만들기까지 버전 0.1, 0.2, 0.3 등 수많은 테스트를 거쳤을 텐데, 그 테스트 비용을 총 합친게 450$일까? 라는 의문이다. 아마도 최종 버전을 만들 때 사용한 클라우드 비용만 450달러가 아닐까...하는 추측은 있다.

그럼에도 연구실이나 작은 규모의 회사에서도 32B정도의 모델을 reasoning기반의 fine-tuning을 할 수 있는 가능성을 보여준다는 점에서 의의가 있다고 생각한다.

profile
궁금한 건 많지만, 천천히 알아가는 중입니다

0개의 댓글

관련 채용 정보